函数原型
tf.linalg.band_part(
    input, num_lower, num_upper, name=None
)
函数说明

band_part函数主要用于处理方形矩阵的副对角线上的元素。以对角线为中心,对副对角线上的元素进行取舍(是否用0填充)。

参数num_lower表示下三角矩阵保留的副对角线的数量,比如num_lower=2表示下三角矩阵从第二条副对角线开始,之后的所有的副对角线的元素全部用0填充。类似的,参数num_upper表示上三角矩阵保留的副对角线的数量。注意,如果为负数,则表示全部保留。

函数使用
>>> a = [[1, 2, 3, 4],
	 [2, 1, 5, 6],
	 [3, 5, 1, 7],
         [4, 6, 7, 1]]
>>> b = tf.constant(a)
>>> b
<tf.Tensor: shape=(4, 4), dtype=int32, numpy=
array([[1, 2, 3, 4],
       [2, 1, 5, 6],
       [3, 5, 1, 7],
       [4, 6, 7, 1]])>
>>> c = tf.linalg.band_part(b, 1, -1)
>>> c
<tf.Tensor: shape=(4, 4), dtype=int32, numpy=
array([[1, 2, 3, 4],
       [2, 1, 5, 6],
       [0, 5, 1, 7],
       [0, 0, 7, 1]])>
>>> d = tf.linalg.band_part(b, 2, 2)
>>> d
<tf.Tensor: shape=(4, 4), dtype=int32, numpy=
array([[1, 2, 3, 0],
       [2, 1, 5, 6],
       [3, 5, 1, 7],
       [0, 6, 7, 1]])>
Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐