torch.chain_matmul
torch.chain_matmul(*matrices)

Returns the matrix product of the $N$ 2-D tensors. This product is efficiently computed using the matrix chain order algorithm, which selects the multiplication order that incurs the lowest cost in terms of arithmetic operations ([CLRS]). Note that since this is a function to compute the product, $N$ needs to be greater than or equal to 2; if it is equal to 2, then a trivial matrix-matrix product is returned. If $N$ is 1, then this is a no-op: the original matrix is returned as is.
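The matrix chain order algorithm cited from [CLRS] is a dynamic program over sub-chains. What follows is a minimal plain-Python sketch of that textbook technique. It is not PyTorch's internal implementation, and the function name `matrix_chain_order` and the `dims` convention (matrix $i$ has shape `dims[i-1]` × `dims[i]`) are assumptions made for illustration. Cost is counted as scalar multiplications, and the function returns the cheapest cost together with one optimal parenthesization.

```python
def matrix_chain_order(dims):
    """Hypothetical helper: return (min_cost, parenthesization) for the
    product A1 A2 ... An, where matrix Ai has shape dims[i-1] x dims[i].

    Illustrative sketch of the textbook dynamic program; cost counts
    scalar multiplications. Not PyTorch's actual code.
    """
    n = len(dims) - 1  # number of matrices in the chain
    if n < 1:
        raise ValueError("need at least one matrix")

    # cost[i][j]: cheapest cost to compute matrices i..j (0-indexed, inclusive)
    cost = [[0] * n for _ in range(n)]
    # split[i][j]: k such that the best order is (i..k)(k+1..j)
    split = [[0] * n for _ in range(n)]

    for length in range(2, n + 1):  # length of the sub-chain
        for i in range(n - length + 1):
            j = i + length - 1
            cost[i][j] = float("inf")
            for k in range(i, j):
                # cost of both halves plus the final (dims[i] x dims[k+1]) @
                # (dims[k+1] x dims[j+1]) multiplication
                c = (cost[i][k] + cost[k + 1][j]
                     + dims[i] * dims[k + 1] * dims[j + 1])
                if c < cost[i][j]:
                    cost[i][j] = c
                    split[i][j] = k

    def paren(i, j):
        # Rebuild the optimal parenthesization from the split table.
        if i == j:
            return f"A{i + 1}"
        k = split[i][j]
        return f"({paren(i, k)} {paren(k + 1, j)})"

    return cost[0][n - 1], paren(0, n - 1)
```

For `dims = [50, 5, 100, 10]`, this returns `(7500, "(A1 (A2 A3))")`: multiplying right-first costs 7,500 scalar multiplications versus 75,000 for left-to-right. For the shapes used in the example below (3×4, 4×5, 5×6, 6×7), left-to-right is already optimal at 276. A single matrix yields cost 0, consistent with the no-op case described above.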
- Parameters
matrices (Tensors...) – a sequence of 2 or more 2-D tensors whose product is to be determined.
- Returns
if the $i^{th}$ tensor was of dimensions $p_{i} \times p_{i + 1}$, then the product would be of dimensions $p_{1} \times p_{N + 1}$.
- Return type
Tensor
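The shape rule in the Returns field can be checked without running any multiplication. The sketch below uses a hypothetical helper, `chain_result_shape`, which is not part of PyTorch. It assumes each matrix is described by a `(rows, cols)` tuple, verifies that adjacent inner dimensions agree, and returns $p_{1} \times p_{N + 1}$ as a tuple.

```python
def chain_result_shape(shapes):
    """Hypothetical helper: shape of A1 @ A2 @ ... @ AN for 2-D shapes.

    If the i-th matrix is p_i x p_{i+1}, the product is p_1 x p_{N+1}.
    Raises ValueError if a shape is not 2-D or neighbours do not chain.
    """
    if not shapes:
        raise ValueError("need at least one matrix")
    for i, shape in enumerate(shapes):
        if len(shape) != 2:
            raise ValueError(f"matrix {i + 1} is not 2-D: {shape}")
    # Each adjacent pair must share its inner dimension: cols(Ai) == rows(Ai+1).
    for i in range(len(shapes) - 1):
        if shapes[i][1] != shapes[i + 1][0]:
            raise ValueError(
                f"matrices {i + 1} and {i + 2} are incompatible: "
                f"{shapes[i]} vs {shapes[i + 1]}"
            )
    # Rows of the first matrix by columns of the last one.
    return (shapes[0][0], shapes[-1][1])
```

For the shapes used in the example, `chain_result_shape([(3, 4), (4, 5), (5, 6), (6, 7)])` returns `(3, 7)`, matching the 3×7 tensor printed there.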
Example:
>>> import torch
>>> a = torch.randn(3, 4)
>>> b = torch.randn(4, 5)
>>> c = torch.randn(5, 6)
>>> d = torch.randn(6, 7)
>>> torch.chain_matmul(a, b, c, d)
tensor([[ -2.3375,  -3.9790,  -4.1119,  -6.6577,   9.5609, -11.5095,  -3.2614],
        [ 21.4038,   3.3378,  -8.4982,  -5.2457, -10.2561,  -2.4684,   2.7163],
        [ -0.9647,  -5.8917,  -2.3213,  -5.2284,  12.8615, -12.2816,  -2.5095]])