一、概述
tf.nest的公共API称空间。
函数列表:
-
assert_same_structure(...)
: 断言两个结构以相同的方式嵌套。 -
flatten(...)
: 从给定的嵌套结构返回平面列表。 -
is_nested(...)
: 如果输入是collection .abc,则返回true。序列(字符串除外)。 -
map_structure(...)
: 对结构中的每个条目应用func并返回一个新结构。 -
pack_sequence_as(...)
: 返回压缩到给定结构中的给定扁平序列。
2、tf.nest.assert_same_structure
断言两个结构以相同的方式嵌套。
tf.nest.assert_same_structure(
nest1,
nest2,
check_types=True,
expand_composites=False
)
注意,具有相同名称和字段的namedtuple总是被认为具有相同的浅结构(即使check_types=True)。例如,这段代码将打印True:
def nt(a, b):
return collections.namedtuple('foo', 'a b')(a, b)
print(assert_same_structure(nt(0, 1), nt(2, 3)))
参数:
-
nest1
:一个任意嵌套的结构。 -
nest2
:一个任意嵌套的结构。 - check_types:如果序列的类型为True(默认值)也被选中,包括字典的键。如果设置为False,例如,如果对象的列表和元组具有相同的大小,则它们看起来是相同的。注意,具有相同名称和字段的namedtuple总是被认为具有相同的浅结构。如果这两种类型都是list子类型(允许可跟踪依赖项跟踪中的“list”和“_ListWrapper”进行相等比较),那么这两种类型也将被认为是相同的。
- expand_composites:如果为真,则复合张量,如tf。SparseTensor和tf。拉格张量被展开成它们的分量张量。
可能产生的异常:
-
ValueError
: If the two structures do not have the same number of elements or if the two structures are not nested in the same way. -
TypeError
: If the two structures differ in the type of sequence in any of their substructures. Only possible ifcheck_types
isTrue
.
3、tf.nest.flatten
从给定的嵌套结构返回平面列表。
tf.nest.flatten(
structure,
expand_composites=False
)
如果嵌套不是序列、元组或dict,则返回一个单元素列表:[nest]。在dict实例的情况下,序列由值组成,按键排序,以确保确定性行为。对于OrderedDict实例也是如此:忽略它们的序列顺序,而使用键的排序顺序。在pack_sequence_as中遵循相同的约定。这将正确地重新打包已压扁的dict和OrderedDict,并允许压扁OrderedDict,然后使用相应的普通dict重新打包,反之亦然。具有不可排序键的字典不能被压扁。在运行此函数时,用户不能修改nest中使用的任何集合。
参数:
-
structure
:任意嵌套结构或标量对象。注意,numpy数组被认为是标量。 - expand_composites:如果为真,则复合张量,如tf。SparseTensor和tf。拉格张量被展开成它们的分量张量。
返回值:
一个Python列表,输入的扁平版本。
可能产生的异常:
-
TypeError
: The nest is or contains a dict with non-sortable keys.
4、tf.nest.is_nested
如果输入是collection.abc,则返回true。序列(字符串除外)。
tf.nest.is_nested(seq)
参数:
- 一个输入序列。
返回值:
- 如果序列不是字符串而是集合,则为True。顺序或dict。
5、tf.nest.map_structure
对结构中的每个条目应用func并返回一个新结构。
tf.nest.map_structure(
func,
*structure,
**kwargs
)
应用func(x[0], x[1],…),其中x[i]是结构中的一个条目[i]。结构中的所有结构必须具有相同的特性,返回值将包含具有相同结构布局的结果。
参数:
-
func
:一个可调用的函数,它接受的参数和结构一样多。 - *
structure
:标量、构造标量的元组或列表以及/或其他元组/列表或标量。注意:numpy数组被认为是标量。 - **kwargs:有效的关键字args是:
- check_types:如果设置为True(默认值),结构中的迭代器类型必须相同(例如map_structure(func,[1],(1,)),这会引发类型错误异常)。为了让这个参数为假。注意,具有相同名称和字段的namedtuple总是被认为具有相同的浅结构。
- expand_composites:如果设置为True,则复合张量,如tf。SparseTensor和tf。拉格张量被展开成它们的分量张量。如果为False(默认值),则不展开复合张量。
返回值:
- 一种新的结构,具有与结构相同的圆度,其值对应于func(x[0], x[1],…),其中x[i]是结构[i]中对应位置的一个值。如果有不同的序列类型,且check_types为False,则将使用第一个结构的序列类型。
可能产生的异常:
-
TypeError
: Iffunc
is not callable or if the structures do not match each other by depth tree. -
ValueError
: If no structure is provided or if the structures do not match each other by type. -
ValueError
: If wrong keyword arguments are provided.
6、tf.nest.pack_sequence_as
返回压缩到给定结构中的给定扁平序列。
tf.nest.pack_sequence_as(
structure,
flat_sequence,
expand_composites=False
)
如果结构是标量,则flat_sequence必须是单元素列表;在本例中,返回值是flat_sequence[0]。如果结构是或包含dict实例,则将对键进行排序,以确定顺序打包平面序列。对于OrderedDict实例也是如此:忽略它们的序列顺序,而使用键的排序顺序。在flatten中遵循相同的约定。这将正确地重新打包已压扁的dict和OrderedDict,并允许压扁OrderedDict,然后使用相应的普通dict重新打包,反之亦然。具有不可排序键的字典不能被压扁。
参数:
-
structure
:嵌套结构,其结构由嵌套列表、元组和dict给出。注意:numpy数组和字符串被认为是标量。 - flat_sequence:要打包的扁平序列。
- expand_composites:如果为真,则复合张量,如tf。SparseTensor和tf。拉格张量被展开成它们的分量张量。
返回值:
-
packed
:flat_sequence转换为与结构相同的递归结构。
可能产生的异常:
-
ValueError
: Ifflat_sequence
andstructure
have different element counts. -
TypeError
:structure
is or contains a dict with non-sortable keys.