一、max()函数

函数定义:torch.max(input, dim, max=None, max_indices=None,keepdim=False)

参数:

  • input:进行max操作的Tensor变量
  • dim:需要查找最大值得维度(这里很迷,后面重点介绍)
  • max:结果张量,用于存储查找到的最大值
  • max_indices:结果张量,用于存储查找到最大值所处的索引
  • keepdim=False:返回值与原Tensor的size保持一致

1. 简单应用

t1=torch.LongTensor([3,9,6,2,5])

print("-------max-------")

print(torch.max(t1))

print("-------max dim-------")

print(torch.max(t1,dim=0))

输出结果为:

Pytorch的max()与min()函数_torch

可以看到,加了dim参数后,返回值中多了一个indices Tensor,这个张量用于存储下最大值的下标,例子中最大值9的下标为1。

2. 二维Tensor

对二维Tensor使用max/min函数,必须搞清楚的就是dim参数,先说结论:

①. dim为0,用于查找每列的最大值。返回行下标索引。

②. dim为1,用于查找每行的最大值。返回列下标索引。

③. 不添加dim参数,返回所有值中的最大值,且无索引。这里放在4.中展示。

从这里看就有些奇怪了,因为众所周知,二维情况下,第0维为行,第1维为列。为什么dim为0时返回每列的最大值。

先看一个例子,以一个两行三列的Tensor(size=2x3)维例:
 

t=torch.randn(2,3)
print(t)
print("-------max dim=0 -------")
print(torch.max(t,dim=0))
print("-------max dim=1 -------")
print(torch.max(t,dim=1))

 

输出结果为: 

Pytorch的max()与min()函数_二维_02

当dim=0时,输出最大值为,第一列最大值0.6301,第二列最大值0.8937,第三列最大值0.3851。

当dim=1时,输出最大值为,第一行最大值0.8937,第二行最大值0.6301。

 

我们结论是正确的,我们从下标来分析一下:

首先,当dim=0时,三个列最大值的下标,分别为[1][0]、[0][1]、[1][2]。(以及返回的索引张量[1,0,1])

           当dim=1时,两个行最大值的下标,分别为[0][1]、[1][0]。(以及返回的索引张量[1,0])

我们能够看到,max()得到的最大值,本质上,是除了dim维以外,取其余维度逐一遍历分组(红色下标),组内补上每一个dim维后的几个数据的内部比较。

对dim参数的结论:

在其他维度均确定的情况下,比较所有dim维对应的数据,找到其中的最大值,并返回索引。

我们根据此例进行分析:

当dim=0时,除了dim等于的第0维,还有第1维,遍历第1维,得到[0],[1],[2]。再补上第0维,根据遍历第1维得到的,三个1维下标分为三组。第一组([0][0],[1][0])、第二组([0][1],[1][1])、第三组([0][2],[1][2])。进行内部比较,得到三个组内最大值,即[0.6301,0.8937,0.3851],得到索引[1,0,1]。所以,也就是每一列的最大值了。

同理可以分析该例子中,dim=1的情况。

但是对于二维Tensor来说,记住结论比理解这个更容易。当三维及以上时,理解 这个就变得很重要了。

3、二维以上Tensor使用

这里主要使用病分析一个,三维的Tensor使用max操作来验证我们上面的结论。(对三维张量第0维顺着层,第1维顺着行,第2维度顺着列)。

例子:

>>> t=torch.randn(2,2,2)
>>> print(t)
tensor([[[ 1.0462,  0.0361],
         [ 0.3875, -0.1129]],

        [[ 0.6716, -1.5034],
         [-1.4784,  0.8816]]])
>>> print("-------max dim=0 -------")
-------max dim=0 -------
>>> print(torch.max(t,dim=0))
torch.return_types.max(
values=tensor([[1.0462, 0.0361],
        [0.3875, 0.8816]]),
indices=tensor([[0, 0],
        [0, 1]]))
>>> print("-------max dim=1 -------")
-------max dim=1 -------
>>> print(torch.max(t,dim=1))
torch.return_types.max(
values=tensor([[1.0462, 0.0361],
        [0.6716, 0.8816]]),
indices=tensor([[0, 0],
        [0, 1]]))
>>> print("-------max dim=2 -------")
-------max dim=2 -------
>>> print(torch.max(t,dim=2))
torch.return_types.max(
values=tensor([[1.0462, 0.3875],
        [0.6716, 0.8816]]),
indices=tensor([[0, 0],
        [0, 1]]))
>>> 

输出结果:

 

分析:

      ①. 对于dim=0,遍历除了第0维外,得到的第1维、第2维组合有[0][0],[0][1],[1][0],[1][1]。所以分为的组有第一组([0][0][0],[1][0][0]),第二组([0][0][1],[1][0][1]),第三组([0][1][0],[1][1][0]),第四组([0][1][1],[1][1][1])。

      数据就是(0.9560,0.0632),(1.6869,0.3790),(1.1282,0.8084),(0.8298,-1.4528)

      max结果得到:最大值[0.9560,1.6869,1.1282,0.8298],索引[0,0,0,0]

      ②. 对于dim-1,遍历除了第1维外,得到的第0维、第2维组合有[0]_[0],[0]_[1],[1]_[0],[1]_[1]。将第1维加入后,既可以分为([0][0][0],[0][1][0]),([0][0][1],[0][1][1]),([1][0][0],[1][1][0]),([1][0][1],[1][1][1]).

      得到的数据组为(0.9560,1.1282),(1.6869,0.8298),(0.0632,0.8084),(0.3790,-1.4528)

      max结果得到: 最大值[1.1282,1.6869,0.8084,0.3790],索引[1,0,1,0]

4. 无dim参数的max()函数

当使用torch.max()函数时,不添加dim函数,则返回所有元素中值最大值(格式为size为1Tensor),且无索引。

例子:


t=torch.randn(2,2,2)

print(t)

print(torch.max(t))

输出结果:

Pytorch的max()与min()函数_二维_03

结果输出,所有元素中的最大值。

二、min()函数

与max相同,但是返回为最小值。

-------end------