目录
一、tf.stcak()
tf.stack(
values,
axis=0,
name='stack'
)
将一列秩为R的张量叠加成一个秩为(R+1)的张量。 将值中的张量列表沿轴维进行打包,将其打包成一个比值中的每个张量的秩高1的张量。给出形状张量长度N的列表(A, B, C);如果axis == 0,则输出张量的形状为(N, A, B, C);如果axis == 1,则输出张量的形状为(A, N, B, C)等。
例:
x = tf.constant([1, 4])
y = tf.constant([2, 5])
z = tf.constant([3, 6])
tf.stack([x, y, z]) # [[1, 4], [2, 5], [3, 6]] (Pack along first dim.)
tf.stack([x, y, z], axis=1) # [[1, 2, 3], [4, 5, 6]]
这是unstack的反面。相当于numpy的:
tf.stack([x, y, z]) = np.stack([x, y, z])
参数:
- value: 具有相同形状和类型的张量对象列表。
- axis: 一个整型数。要堆叠的轴。默认为第一个维度。负值环绕,所以有效范围是[-(R+1), R+1)。
- name: 此操作的名称(可选)。
返回值:
-
output
: 与值类型相同的叠加张量。
可能产生的异常:
-
ValueError
: Ifaxis
is out of the range [-(R+1), R+1).
二、tf.unstcak()
tf.unstack(
value,
num=None,
axis=0,
name='unstack'
)
将秩为R张量的给定维数分解为秩为(R-1)张量。通过沿着轴维对num张量进行切分,从值中解压缩num张量。如果没有指定num(默认值),则从值的形状推断它。如果value.shape[axis]未知,将引发ValueError。例如,给定一个形状张量(A, B, C, D);如果axis == 0,那么输出中的第i张量就是切片值[i,:,:,:],而输出中的每个张量都有形状(B, C, D)。(注意,与split不同的是,未打包的维度已经没有了)。如果axis == 1,则输出中的第i张量为切片值[:,i,:,:],输出中的每个张量都有形状(A, C, D)等。这是堆栈的反面。
参数:
- value: 要被解压的秩大于0的张量。
- num: 一个int类型, 一个整型数。尺寸轴的长度。如果没有(默认值)就自动推断。
- axis: 一个整型数。沿着整型数展开堆栈。默认为第一个维度。负值环绕,所以有效范围是[-R, R]。
- name: 操作的名称(可选)。
返回值:
- 张量对象的列表从值中分解。
异常:
-
ValueError
: Ifnum
is unspecified and cannot be inferred. -
ValueError
: Ifaxis
is out of the range [-R, R).
三、实例
1、tf.stack
import tensorflow as tf
a = tf.constant([1, 2, 3])
b = tf.constant([4, 5, 6])
c = tf.stack( [a,b], axis=0)
with tf.Session() as sess:
print(sess.run(c))
Output:
----------
[[1 2 3]
[4 5 6]]
----------
如果设置axis = 1
Output:
--------
[[1 4]
[2 5]
[3 6]]
--------
2、tf.unstack
import tensorflow as tf
c = tf.constant([[1, 2, 3],
[4, 5, 6]])
d = tf.unstack(c, axis=0)
e = tf.unstack(c, axis=1)
with tf.Session() as sess:
print(sess.run(d))
print(sess.run(e))
Output:
----------------------------------------------
[array([1, 2, 3]), array([4, 5, 6])]
[array([1, 4]), array([2, 5]), array([3, 6])]
----------------------------------------------
原链接: https://tensorflow.google.cn/api_docs/python/tf/stack
https://tensorflow.google.cn/versions/r1.9/api_docs/python/tf/unstack?hl=en