torch.median
torch.median(input) → Tensor

Returns the median value of all elements in the input tensor.

Warning
This function produces deterministic (sub)gradients, unlike median(dim=0).
Parameters
input (Tensor) – the input tensor.
Example:
>>> a = torch.randn(1, 3)
>>> a
tensor([[ 1.5219, -1.5212,  0.2202]])
>>> torch.median(a)
tensor(0.2202)
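The example above has an odd number of elements, so the median is unambiguous. The sketch below (the tensor values are chosen only for illustration) shows the even-length case: torch.median returns the lower of the two middle values rather than their average. When the average of the two middle values is wanted, torch.quantile with q=0.5 (linear interpolation by default) gives it.

```python
import torch

# Four elements: after sorting, the two middle values are 2.0 and 3.0.
x = torch.tensor([1.0, 4.0, 2.0, 3.0])

# torch.median returns the lower of the two middle values (no averaging).
lower_median = torch.median(x)

# torch.quantile with q=0.5 interpolates linearly between the two middle values.
mean_of_middles = torch.quantile(x, 0.5)

print(lower_median)     # tensor(2.)
print(mean_of_middles)  # tensor(2.5000)
```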
torch.median(input, dim=-1, keepdim=False, out=None) → (Tensor, LongTensor)
Returns a namedtuple (values, indices), where values is the median value of each row of the input tensor in the given dimension dim (that is, the median computed along dim), and indices is the index location of each median value found.

By default, dim is the last dimension of the input tensor.

If keepdim is True, the output tensors are of the same size as input except in the dimension dim, where they are of size 1. Otherwise, dim is squeezed (see torch.squeeze()), resulting in the output tensors having 1 fewer dimension than input.

Warning
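indices does not necessarily contain the first occurrence of each median value found, unless it is unique. The exact implementation details are device-specific, so in general do not expect the same result when run on CPU and GPU. For the same reason, do not expect the gradients to be deterministic.

Parameters
input (Tensor) – the input tensor.
dim (int) – the dimension to reduce.
keepdim (bool) – whether the output tensors have dim retained or not.
out ((Tensor, Tensor), optional) – the result tuple of two output tensors (values, indices).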
Example:
>>> a = torch.randn(4, 5)
>>> a
tensor([[ 0.2505, -0.3982, -0.9948,  0.3518, -1.3131],
        [ 0.3180, -0.6993,  1.0436,  0.0438,  0.2270],
        [-0.2751,  0.7303,  0.2192,  0.3321,  0.2488],
        [ 1.0778, -1.9510,  0.7048,  0.4742, -0.7125]])
>>> torch.median(a, 1)
torch.return_types.median(values=tensor([-0.3982,  0.2270,  0.2488,  0.4742]), indices=tensor([1, 4, 4, 3]))
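The following is a minimal sketch of the dim and keepdim behavior described above. The random seed, the tensor shapes, and the duplicate-valued tensor b are arbitrary choices for illustration. The sketch checks the output shapes with and without keepdim, then uses torch.gather to confirm that indices points at entries equal to values. That check still holds when a median value occurs more than once. This matters because, as the warning notes, indices is not guaranteed to refer to the first occurrence.

```python
import torch

torch.manual_seed(0)
a = torch.randn(4, 5)

# Reduce along dim=1: one median per row, and dim 1 is squeezed away.
values, indices = torch.median(a, dim=1)
print(values.shape, indices.shape)  # torch.Size([4]) torch.Size([4])

# keepdim=True keeps dim 1 as a size-1 dimension.
values_k, indices_k = torch.median(a, dim=1, keepdim=True)
print(values_k.shape)  # torch.Size([4, 1])

# Squeezing the kept dimension recovers the keepdim=False result.
assert torch.equal(values, values_k.squeeze(1))

# gather along dim 1 picks a[i, indices_k[i, 0]] for each row i.
picked = torch.gather(a, 1, indices_k)
assert torch.equal(picked, values_k)

# With repeated median values, indices may point at any occurrence,
# but the gathered entry still equals the median.
b = torch.tensor([[1.0, 2.0, 2.0, 2.0, 3.0]])
vb, ib = torch.median(b, dim=1, keepdim=True)
assert torch.equal(torch.gather(b, 1, ib), vb)
print(vb)  # tensor([[2.]])
```

The gather check deliberately compares values rather than index positions, so it does not rely on any particular tie-breaking rule.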