Shortcuts

Source code for torch.nn.modules.linear

import math

import torch
from torch import Tensor
from torch.nn.parameter import Parameter
from .. import functional as F
from .. import init
from .module import Module


class Identity(Module):
    r"""Pass-through module: :meth:`forward` returns its input unchanged.

    The constructor accepts arbitrary positional and keyword arguments and
    ignores all of them.

    Args:
        args: any positional arguments (ignored)
        kwargs: any keyword arguments (ignored)

    Examples::

        >>> m = nn.Identity(54, unused_argument1=0.1, unused_argument2=False)
        >>> input = torch.randn(128, 20)
        >>> output = m(input)
        >>> print(output.size())
        torch.Size([128, 20])

    """

    def __init__(self, *args, **kwargs):
        # Constructor arguments are deliberately dropped; only the base
        # Module initialisation runs.
        super().__init__()

    def forward(self, input: Tensor) -> Tensor:
        # No copy and no computation: the same object is handed back.
        return input


class Linear(Module):
    r"""Applies a linear transformation to the incoming data: :math:`y = xA^T + b`

    Args:
        in_features: size of each input sample
        out_features: size of each output sample
        bias: If set to ``False``, the layer will not learn an additive bias.
            Default: ``True``

    Shape:
        - Input: :math:`(N, *, H_{in})` where :math:`*` means any number of
          additional dimensions and :math:`H_{in} = \text{in\_features}`
        - Output: :math:`(N, *, H_{out})` where all but the last dimension
          are the same shape as the input and :math:`H_{out} = \text{out\_features}`.

    Attributes:
        weight: the learnable weights of the module of shape
            :math:`(\text{out\_features}, \text{in\_features})`. The values are
            initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`, where
            :math:`k = \frac{1}{\text{in\_features}}`
        bias:   the learnable bias of the module of shape :math:`(\text{out\_features})`.
                If :attr:`bias` is ``True``, the values are initialized from
                :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
                :math:`k = \frac{1}{\text{in\_features}}`. When
                ``in_features`` is 0 the bias is initialized to zero.

    Examples::

        >>> m = nn.Linear(20, 30)
        >>> input = torch.randn(128, 20)
        >>> output = m(input)
        >>> print(output.size())
        torch.Size([128, 30])
    """
    __constants__ = ['in_features', 'out_features']
    in_features: int
    out_features: int
    weight: Tensor

    def __init__(self, in_features: int, out_features: int, bias: bool = True) -> None:
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        # Storage is allocated uninitialised; reset_parameters() fills it.
        self.weight = Parameter(torch.empty(out_features, in_features))
        if bias:
            self.bias = Parameter(torch.empty(out_features))
        else:
            # Register the name anyway so ``self.bias`` always exists (as None).
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Re-initialise ``weight`` (and ``bias``, if present) in place.

        Both are drawn from a uniform distribution bounded by
        ``1 / sqrt(fan_in)``. If ``fan_in`` is 0 the weight has no elements
        and is left untouched, and the bias is set to zero instead of
        dividing by zero.
        """
        fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
        if fan_in > 0:
            # kaiming_uniform_ with a=sqrt(5) gives the bound 1/sqrt(fan_in)
            # that the class docstring documents.
            init.kaiming_uniform_(self.weight, a=math.sqrt(5))
            bound = 1 / math.sqrt(fan_in)
        else:
            # Zero input features: there is nothing to initialise in the
            # weight, and 1/sqrt(0) would raise ZeroDivisionError.
            bound = 0.0
        if self.bias is not None:
            init.uniform_(self.bias, -bound, bound)

    def forward(self, input: Tensor) -> Tensor:
        """Return ``F.linear(input, self.weight, self.bias)``."""
        return F.linear(input, self.weight, self.bias)

    def extra_repr(self) -> str:
        """Return the feature sizes and bias flag shown in the module repr."""
        return 'in_features={}, out_features={}, bias={}'.format(
            self.in_features, self.out_features, self.bias is not None
        )
# This class exists solely for Transformer; it has an annotation stating
# that bias is never None, which appeases TorchScript
class _LinearWithBias(Linear):
    """A :class:`Linear` layer that is always constructed with a bias.

    Args:
        in_features: size of each input sample
        out_features: size of each output sample
    """
    # Non-optional annotation: the constructor below always passes bias=True.
    bias: Tensor

    def __init__(self, in_features: int, out_features: int) -> None:
        super().__init__(in_features, out_features, bias=True)
class Bilinear(Module):
    r"""Applies a bilinear transformation to the incoming data:
    :math:`y = x_1^T A x_2 + b`

    Args:
        in1_features: size of each first input sample
        in2_features: size of each second input sample
        out_features: size of each output sample
        bias: If set to ``False``, the layer will not learn an additive bias.
            Default: ``True``

    Shape:
        - Input1: :math:`(N, *, H_{in1})` where :math:`H_{in1}=\text{in1\_features}` and
          :math:`*` means any number of additional dimensions. All but the last dimension
          of the inputs should be the same.
        - Input2: :math:`(N, *, H_{in2})` where :math:`H_{in2}=\text{in2\_features}`.
        - Output: :math:`(N, *, H_{out})` where :math:`H_{out}=\text{out\_features}`
          and all but the last dimension are the same shape as the input.

    Attributes:
        weight: the learnable weights of the module of shape
            :math:`(\text{out\_features}, \text{in1\_features}, \text{in2\_features})`.
            The values are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`, where
            :math:`k = \frac{1}{\text{in1\_features}}`
        bias:   the learnable bias of the module of shape :math:`(\text{out\_features})`.
                If :attr:`bias` is ``True``, the values are initialized from
                :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`, where
                :math:`k = \frac{1}{\text{in1\_features}}`

    Examples::

        >>> m = nn.Bilinear(20, 30, 40)
        >>> input1 = torch.randn(128, 20)
        >>> input2 = torch.randn(128, 30)
        >>> output = m(input1, input2)
        >>> print(output.size())
        torch.Size([128, 40])
    """
    __constants__ = ['in1_features', 'in2_features', 'out_features']
    in1_features: int
    in2_features: int
    out_features: int
    weight: Tensor

    def __init__(self, in1_features: int, in2_features: int, out_features: int, bias: bool = True) -> None:
        super().__init__()
        self.in1_features = in1_features
        self.in2_features = in2_features
        self.out_features = out_features
        # Storage is allocated uninitialised; reset_parameters() fills it.
        self.weight = Parameter(torch.empty(out_features, in1_features, in2_features))
        if bias:
            self.bias = Parameter(torch.empty(out_features))
        else:
            # Register the name anyway so ``self.bias`` always exists (as None).
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Re-initialise ``weight`` (and ``bias``, if present) in place.

        Both are drawn uniformly from ``[-bound, bound]`` with
        ``bound = 1 / sqrt(in1_features)``.
        """
        # weight is (out_features, in1_features, in2_features), so size(1)
        # is in1_features.
        # NOTE(review): in1_features == 0 makes this division raise
        # ZeroDivisionError; behaviour kept as-is, confirm whether an
        # explicit validation error is wanted instead.
        bound = 1 / math.sqrt(self.weight.size(1))
        init.uniform_(self.weight, -bound, bound)
        if self.bias is not None:
            init.uniform_(self.bias, -bound, bound)

    def forward(self, input1: Tensor, input2: Tensor) -> Tensor:
        """Return ``F.bilinear(input1, input2, self.weight, self.bias)``."""
        return F.bilinear(input1, input2, self.weight, self.bias)

    def extra_repr(self) -> str:
        """Return the feature sizes and bias flag shown in the module repr."""
        return 'in1_features={}, in2_features={}, out_features={}, bias={}'.format(
            self.in1_features, self.in2_features, self.out_features, self.bias is not None
        )
# TODO: PartialLinear - maybe in sparse?

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources