torch.flatten
torch.flatten(input, start_dim=0, end_dim=-1) → Tensor
Flattens a contiguous range of dims in a tensor.
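- Parameters
  - input (Tensor) – the input tensor.
  - start_dim (int) – the first dim to flatten.
  - end_dim (int) – the last dim to flatten.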
Example:
>>> t = torch.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
>>> torch.flatten(t)
tensor([1, 2, 3, 4, 5, 6, 7, 8])
>>> torch.flatten(t, start_dim=1)
tensor([[1, 2, 3, 4],
        [5, 6, 7, 8]])
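To make the effect of start_dim and end_dim on the result's shape concrete, here is a minimal pure-Python sketch. The helper name flattened_shape is hypothetical and not part of PyTorch; it only mirrors the shape arithmetic (the dims from start_dim through end_dim, inclusive, collapse into one dim whose size is their product) and assumes the input has at least one dimension.

```python
from math import prod


def flattened_shape(shape, start_dim=0, end_dim=-1):
    """Hypothetical helper (not a PyTorch API): the shape obtained by
    collapsing dims start_dim..end_dim (inclusive) into a single dim."""
    shape = tuple(shape)
    ndim = len(shape)
    if ndim == 0:
        raise ValueError("this sketch assumes at least one dimension")

    # Map negative indices to their non-negative equivalents.
    start = start_dim + ndim if start_dim < 0 else start_dim
    end = end_dim + ndim if end_dim < 0 else end_dim

    if not (0 <= start < ndim and 0 <= end < ndim):
        raise IndexError("dimension out of range")
    if start > end:
        raise ValueError("start_dim cannot come after end_dim")

    # Dims before start and after end are kept; the range in between
    # is replaced by one dim holding the product of its sizes.
    merged = prod(shape[start:end + 1])
    return shape[:start] + (merged,) + shape[end + 1:]


# The tensor t in the example above has shape (2, 2, 2).
print(flattened_shape((2, 2, 2)))               # (8,)
print(flattened_shape((2, 2, 2), start_dim=1))  # (2, 4)
print(flattened_shape((2, 3, 4, 5), 1, 2))      # (2, 12, 5)
```

The first two calls correspond to the shapes of the two results in the example above; the third shows a flatten that leaves both the leading and trailing dims untouched.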