torch.where
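torch.where(condition, x, y) → Tensor
Return a tensor of elements selected from either x or y, depending on condition.
The operation is defined as:

$$
\text{out}_i = \begin{cases} \text{x}_i & \text{if } \text{condition}_i \\ \text{y}_i & \text{otherwise} \end{cases}
$$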
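To make the element-wise rule concrete, here is a minimal sketch using plain Python lists. The helper `where_reference` is hypothetical and not part of PyTorch. It assumes equal-length 1-D inputs, so it leaves out the broadcasting that `torch.where` performs.

```python
def where_reference(condition, x, y):
    """Hypothetical reference for 1-D sequences of equal length (no broadcasting)."""
    if not (len(condition) == len(x) == len(y)):
        raise ValueError("condition, x and y must have the same length")
    # out[i] = x[i] if condition[i] else y[i]
    return [xi if ci else yi for ci, xi, yi in zip(condition, x, y)]


print(where_reference([True, False, True], [1, 2, 3], [10, 20, 30]))  # [1, 20, 3]
```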
Note
The tensors condition, x, y must be broadcastable.
Parameters
- condition (BoolTensor) – When True (nonzero), yield x; otherwise, yield y.
- x (Tensor) – values selected at indices where condition is True.
- y (Tensor) – values selected at indices where condition is False.
Returns
A tensor whose shape is the broadcast shape of condition, x, and y.
Return type
Tensor
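The following sketch illustrates the broadcasting requirement and the resulting output shape. The shapes are arbitrary choices for illustration: condition has shape (3, 1), x has shape (1, 4), and y is a 0-d tensor. It assumes PyTorch 1.8 or later, which provides `torch.broadcast_shapes`.

```python
import torch

# condition: (3, 1), x: (1, 4), y: 0-d tensor; all three broadcast to (3, 4).
condition = torch.tensor([[True], [False], [True]])
x = torch.arange(4.0).reshape(1, 4)
y = torch.tensor(-1.0)

out = torch.where(condition, x, y)

# The result takes the broadcast shape of all three inputs.
print(torch.broadcast_shapes(condition.shape, x.shape, y.shape))  # torch.Size([3, 4])
print(out.shape)  # torch.Size([3, 4])
# Rows 0 and 2 are copied from x; row 1 is filled with y (-1.0).
print(out)
```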
Example:
>>> import torch
>>> x = torch.randn(3, 2)
>>> y = torch.ones(3, 2)
>>> x
tensor([[-0.4620,  0.3139],
        [ 0.3898, -0.7197],
        [ 0.0478, -0.1657]])
>>> torch.where(x > 0, x, y)
tensor([[ 1.0000,  0.3139],
        [ 0.3898,  1.0000],
        [ 0.0478,  1.0000]])
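Because torch.randn draws new values on every call, the numbers in the example above change from run to run. A minimal deterministic variant with fixed input values, chosen here only for illustration, makes the selection easy to check by hand:

```python
import torch

# Fixed values (illustrative) instead of torch.randn, so the output is reproducible.
x = torch.tensor([[-0.5, 0.3], [0.4, -0.7], [0.05, -0.2]])
y = torch.ones(3, 2)

# Keep the positive entries of x; take the matching entry of y everywhere else.
out = torch.where(x > 0, x, y)
print(out)
```

Each non-positive entry of x (-0.5, -0.7, -0.2) is replaced by 1.0 from y, and the positive entries are kept unchanged.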
torch.where(condition) → tuple of LongTensor
torch.where(condition) is identical to torch.nonzero(condition, as_tuple=True).
Note
See also torch.nonzero().
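A short sketch of the single-argument form follows. The 2 x 3 boolean mask and the `values` tensor are arbitrary examples. It checks the result against torch.nonzero(condition, as_tuple=True) and then uses the returned tuple of index tensors to gather entries.

```python
import torch

condition = torch.tensor([[True, False, True],
                          [False, False, True]])

# Single-argument form: one LongTensor of indices per dimension of condition.
rows, cols = torch.where(condition)
print(rows)  # tensor([0, 0, 1])
print(cols)  # tensor([0, 2, 2])

# Same result as torch.nonzero with as_tuple=True.
expected = torch.nonzero(condition, as_tuple=True)
assert all(torch.equal(a, b) for a, b in zip((rows, cols), expected))

# The index tuple can be used directly to select the matching entries of another tensor.
values = torch.arange(6).reshape(2, 3)
print(values[rows, cols])  # tensor([0, 2, 5])
```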