torch.solve

torch.solve(input, A, out=None) -> (Tensor, Tensor)

This function returns the solution to the system of linear equations represented by AX = B, together with the LU factorization of A, in that order, as a namedtuple (solution, LU).

LU contains L and U factors for LU factorization of A.
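
To illustrate how an LU factorization of A can be used to solve AX = B, here is a minimal plain-Python sketch. It is not PyTorch's implementation: the helper names lu_factor and lu_solve are hypothetical, and the use of partial (row) pivoting with L and U packed into a single matrix is an assumption made for the example.

```python
# Illustrative only: a plain-Python LU solve with partial pivoting.
# lu_factor and lu_solve are hypothetical helper names, not PyTorch functions.

def lu_factor(A):
    """Return (LU, perm): LU packs the unit-lower L and upper U of the row-permuted A."""
    n = len(A)
    LU = [[float(v) for v in row] for row in A]
    perm = list(range(n))
    for col in range(n):
        # Partial pivoting: pick the row with the largest entry in this column.
        pivot = max(range(col, n), key=lambda r: abs(LU[r][col]))
        if LU[pivot][col] == 0.0:
            raise ValueError("matrix is singular")
        if pivot != col:
            LU[col], LU[pivot] = LU[pivot], LU[col]
            perm[col], perm[pivot] = perm[pivot], perm[col]
        for row in range(col + 1, n):
            LU[row][col] /= LU[col][col]  # multiplier, stored in the L part
            for j in range(col + 1, n):
                LU[row][j] -= LU[row][col] * LU[col][j]
    return LU, perm


def lu_solve(LU, perm, B):
    """Solve A X = B for X given the factorization of A; B has shape m x k."""
    n, k = len(LU), len(B[0])
    # Apply the same row permutation to B that pivoting applied to A.
    X = [[float(B[perm[i]][c]) for c in range(k)] for i in range(n)]
    for c in range(k):
        # Forward substitution with the unit-lower factor L.
        for i in range(n):
            for j in range(i):
                X[i][c] -= LU[i][j] * X[j][c]
        # Back substitution with the upper factor U.
        for i in reversed(range(n)):
            for j in range(i + 1, n):
                X[i][c] -= LU[i][j] * X[j][c]
            X[i][c] /= LU[i][i]
    return X


A = [[4.0, 3.0], [6.0, 3.0]]
B = [[10.0], [12.0]]
LU, perm = lu_factor(A)
X = lu_solve(LU, perm, B)
```

Once A has been factored, the same LU and perm can be reused for further right-hand sides, which is one reason a solver might return the factorization alongside the solution.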

torch.solve(B, A) accepts 2D inputs B and A, or inputs that are batches of 2D matrices. If the inputs are batches, the outputs solution and LU are batched as well.

Note

Irrespective of the original strides, the returned matrices solution and LU will be transposed, i.e. with strides like B.contiguous().transpose(-1, -2).stride() and A.contiguous().transpose(-1, -2).stride() respectively.

Parameters
  • input (Tensor) – input matrix B of size (*, m, k), where * is zero or more batch dimensions.

  • A (Tensor) – input square matrix of size (*, m, m), where * is zero or more batch dimensions.

  • out ((Tensor, Tensor), optional) – optional output tuple.

Example:

>>> A = torch.tensor([[6.80, -2.11,  5.66,  5.97,  8.23],
                      [-6.05, -3.30,  5.36, -4.44,  1.08],
                      [-0.45,  2.58, -2.70,  0.27,  9.04],
                      [8.32,  2.71,  4.35,  -7.17,  2.14],
                      [-9.67, -5.14, -7.26,  6.08, -6.87]]).t()
>>> B = torch.tensor([[4.02,  6.19, -8.22, -7.57, -3.03],
                      [-1.56,  4.00, -8.67,  1.75,  2.86],
                      [9.81, -4.09, -4.57, -8.61,  8.99]]).t()
>>> X, LU = torch.solve(B, A)
>>> torch.dist(B, torch.mm(A, X))
tensor(1.00000e-06 *
       7.0977)

>>> # Batched solver example
>>> A = torch.randn(2, 3, 1, 4, 4)
>>> B = torch.randn(2, 3, 1, 4, 6)
>>> X, LU = torch.solve(B, A)
>>> torch.dist(B, A.matmul(X))
tensor(1.00000e-06 *
   3.6386)
