torch.max
-
torch.max(input) → Tensor
Returns the maximum value of all elements in the input tensor.
Warning
This function produces deterministic (sub)gradients, unlike max(dim=0).
- Parameters
input (Tensor) – the input tensor.
Example:
>>> a = torch.randn(1, 3)
>>> a
tensor([[ 0.6763,  0.7445, -2.2369]])
>>> torch.max(a)
tensor(0.7445)
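The following is a minimal sketch of how the full reduction behaves under autograd, using fixed inputs instead of random ones so the output is reproducible. The tensor names `a`, `m`, and `t` are illustrative. How the gradient is shared among tied maxima may differ between PyTorch versions, so the tied case only checks that non-maximal entries receive no gradient.

```python
import torch

# Unique maximum: the gradient of the full reduction is one-hot at the max.
a = torch.tensor([[0.5, 2.0, -1.0]], requires_grad=True)
m = torch.max(a)           # 0-dim tensor holding the maximum
m.backward()
print(m.item())            # 2.0
print(a.grad)              # tensor([[0., 1., 0.]])

# Tied maxima: the split of the gradient among the ties is treated here as
# version-dependent (an assumption), so only the non-maximal entry is checked.
t = torch.tensor([3.0, 1.0, 3.0], requires_grad=True)
torch.max(t).backward()
print(t.grad[1].item())    # 0.0
```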
-
torch.max(input, dim, keepdim=False, out=None) -> (Tensor, LongTensor)
Returns a namedtuple (values, indices) where values is the maximum value of each row of the input tensor in the given dimension dim, and indices is the index location of each maximum value found (argmax).
Warning
indices does not necessarily contain the first occurrence of each maximal value found, unless it is unique. The exact implementation details are device-specific. In general, do not expect the same result when run on CPU and GPU. For the same reason, do not expect the gradients to be deterministic.
If keepdim is True, the output tensors are of the same size as input except in the dimension dim, where they are of size 1. Otherwise, dim is squeezed (see torch.squeeze()), resulting in the output tensors having 1 fewer dimension than input.
- Parameters
input (Tensor) – the input tensor.
dim (int) – the dimension to reduce.
keepdim (bool) – whether the output tensors have dim retained or not. Default: False.
out (tuple, optional) – the result tuple of two output tensors (max, max_indices).
Example:
>>> a = torch.randn(4, 4)
>>> a
tensor([[-1.2360, -0.2942, -0.1222,  0.8475],
        [ 1.1949, -1.1127, -2.2379, -0.6702],
        [ 1.5717, -0.9207,  0.1297, -1.8768],
        [-0.6172,  1.0036, -0.6060, -0.2432]])
>>> torch.max(a, 1)
torch.return_types.max(values=tensor([0.8475, 1.1949, 1.5717, 1.0036]), indices=tensor([3, 0, 0, 1]))
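As a complement to the random example, the sketch below uses a small fixed tensor with unique row maxima to show the namedtuple fields, the effect of keepdim, and how indices relates back to values. The variable names (`x`, `result`, `kept`, `gathered`) are illustrative only.

```python
import torch

x = torch.tensor([[1.0, 5.0, 3.0],
                  [4.0, 2.0, 6.0]])

# Reduce along dim=1: one maximum per row.
result = torch.max(x, dim=1)
values, indices = result        # the namedtuple also unpacks positionally
print(result.values)            # tensor([5., 6.])
print(result.indices)           # tensor([1, 2])

# keepdim=True keeps the reduced dimension with size 1.
kept = torch.max(x, dim=1, keepdim=True)
print(kept.values.shape)        # torch.Size([2, 1])

# Looking up x at the returned indices recovers the maxima.
gathered = x.gather(1, kept.indices)
print(torch.equal(gathered, kept.values))   # True
```

Because each row here has a unique maximum, the indices are the same on any device. With tied values, the warning above applies.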
-
torch.max(input, other, out=None) → Tensor
Each element of the tensor input is compared with the corresponding element of the tensor other and an element-wise maximum is taken.
The shapes of input and other don’t need to match, but they must be broadcastable.
Note
When the shapes do not match, the shape of the returned output tensor follows the broadcasting rules.
- Parameters
input (Tensor) – the input tensor.
other (Tensor) – the second input tensor.
out (Tensor, optional) – the output tensor.
Example:
>>> a = torch.randn(4)
>>> a
tensor([ 0.2942, -0.7416,  0.2653, -0.1584])
>>> b = torch.randn(4)
>>> b
tensor([ 0.8722, -1.7421, -0.4141, -0.5055])
>>> torch.max(a, b)
tensor([ 0.8722, -0.7416,  0.2653, -0.1584])
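The example above uses two tensors of the same shape. The sketch below illustrates the broadcasting case with a column of shape (3, 1) and a row of shape (1, 2). The names `col`, `row`, `out`, and `reference` are illustrative, and the torch.where comparison is included only as a cross-check, not as a description of how torch.max is implemented.

```python
import torch

col = torch.tensor([[0.0], [2.0], [4.0]])   # shape (3, 1)
row = torch.tensor([[1.0, 3.0]])            # shape (1, 2)

# Shapes differ but are broadcastable, so the result has shape (3, 2).
out = torch.max(col, row)
print(out.shape)    # torch.Size([3, 2])
print(out)
# tensor([[1., 3.],
#         [2., 3.],
#         [4., 4.]])

# Cross-check: pick the larger element pairwise after broadcasting.
reference = torch.where(col > row, col, row)
print(torch.equal(out, reference))   # True
```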