MultiheadAttention

class torch.nn.MultiheadAttention(embed_dim, num_heads, dropout=0.0, bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None)

Allows the model to jointly attend to information from different representation subspaces. See reference: Attention Is All You Need

\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \dots, \text{head}_h) W^O \quad \text{where} \quad \text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)

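The formula above leaves Attention undefined. In the referenced paper it is scaled dot-product attention; a sketch in the same notation follows, where d_k (the per-head key dimension) is a symbol taken from the paper rather than from this page. For this module it is reasonable to read d_k as embed_dim // num_heads, though that is stated here as an assumption.

```latex
\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{QK^T}{\sqrt{d_k}}\right)V
```
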
Parameters
  • embed_dim – total dimension of the model.

  • num_heads – number of parallel attention heads.

  • dropout – dropout probability applied to attn_output_weights. Default: 0.0.

  • bias – if True, adds bias to the input and output projection layers. Default: True.

  • add_bias_kv – if True, adds bias to the key and value sequences at dim=0. Default: False.

  • add_zero_attn – if True, adds a new batch of zeros to the key and value sequences at dim=1. Default: False.

  • kdim – total number of features in key. Default: None.

  • vdim – total number of features in value. Default: None.

Note: if kdim and vdim are None, they will be set to embed_dim such that query, key, and value have the same number of features.

Examples:

>>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
>>> attn_output, attn_output_weights = multihead_attn(query, key, value)
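
The snippet above leaves embed_dim, num_heads, and the input tensors undefined. A minimal runnable sketch follows; the concrete sizes (16, 4, 5, 7, 2) are arbitrary values chosen for illustration, not values from this page.

```python
import torch
import torch.nn as nn

# Illustrative sizes (assumptions): embed_dim should be divisible by num_heads.
embed_dim, num_heads = 16, 4
L, S, N = 5, 7, 2  # target length, source length, batch size

multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)

# Inputs use the (sequence, batch, embedding) layout described under "Shape".
query = torch.randn(L, N, embed_dim)
key = torch.randn(S, N, embed_dim)
value = torch.randn(S, N, embed_dim)

attn_output, attn_output_weights = multihead_attn(query, key, value)

print(attn_output.shape)          # (L, N, E)
print(attn_output_weights.shape)  # (N, L, S)
```

With the default dropout of 0.0, each row of attn_output_weights is a softmax distribution over the S source positions, so it should sum to 1.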
forward(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, key_padding_mask: Optional[torch.Tensor] = None, need_weights: bool = True, attn_mask: Optional[torch.Tensor] = None) → Tuple[torch.Tensor, Optional[torch.Tensor]]
Parameters
  • query, key, value – map a query and a set of key-value pairs to an output. See “Attention Is All You Need” for more details.

  • key_padding_mask – if provided, the specified padding elements in the key will be ignored by the attention. When given a boolean mask, key positions where the value is True are ignored. When given a byte mask, key positions where the value is non-zero are ignored.

  • need_weights – if True, also return attn_output_weights. Default: True.

  • attn_mask – 2D or 3D mask that prevents attention to certain positions. A 2D mask is broadcast across all batches, while a 3D mask allows a different mask to be specified for each batch entry.
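
As an illustration of the two mask arguments, here is a hedged sketch of self-attention that combines a boolean key_padding_mask with a boolean causal attn_mask. The sizes and the choice of which position counts as padding are assumptions made for this example.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
embed_dim, num_heads = 8, 2   # illustrative sizes (assumptions)
L, S, N = 4, 4, 2             # self-attention, so L == S

mha = nn.MultiheadAttention(embed_dim, num_heads)
x = torch.randn(L, N, embed_dim)

# key_padding_mask has shape (N, S); True marks a padded key to ignore.
# Here only the last position of the first batch entry is treated as padding.
key_padding_mask = torch.tensor([[False, False, False, True],
                                 [False, False, False, False]])

# attn_mask has shape (L, S); True marks a (query, key) pair that is blocked.
# The strict upper triangle blocks attention to later positions.
causal_mask = torch.triu(torch.ones(L, S, dtype=torch.bool), diagonal=1)

out, weights = mha(x, x, x,
                   key_padding_mask=key_padding_mask,
                   attn_mask=causal_mask)

print(weights[0])  # column 3 and the upper triangle are zero
```

Every query row keeps at least one unmasked key here. A row in which all keys are masked would have no valid softmax input, so such combinations are best avoided.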

Shape:
  • Inputs:

      • query: (L, N, E), where L is the target sequence length, N is the batch size, and E is the embedding dimension.

      • key: (S, N, E), where S is the source sequence length, N is the batch size, and E is the embedding dimension.

      • value: (S, N, E), where S is the source sequence length, N is the batch size, and E is the embedding dimension.

      • key_padding_mask: (N, S), where N is the batch size and S is the source sequence length. If a ByteTensor is provided, positions with non-zero values are ignored while positions with zero values are unchanged. If a BoolTensor is provided, positions with the value True are ignored while positions with the value False are unchanged.

      • attn_mask: a 2D mask of shape (L, S) or a 3D mask of shape (N*num_heads, L, S), where N is the batch size, L is the target sequence length, and S is the source sequence length. attn_mask ensures that position i is only allowed to attend to the unmasked positions. If a ByteTensor is provided, positions with non-zero values are masked out while positions with zero values are unchanged. If a BoolTensor is provided, positions with True are masked out while positions with False are unchanged. If a FloatTensor is provided, it is added to the attention weights.

  • Outputs:

      • attn_output: (L, N, E), where L is the target sequence length, N is the batch size, and E is the embedding dimension.

      • attn_output_weights: (N, L, S), where N is the batch size, L is the target sequence length, and S is the source sequence length.
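
The 3D attn_mask shape (N*num_heads, L, S) can be sketched as follows. The example builds one mask per batch entry and repeats it for each head. It assumes that the heads of a given batch entry are contiguous along the first dimension, which is why repeat_interleave is used; the sizes and the name per_batch_mask are illustrative.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
embed_dim, num_heads = 8, 2   # illustrative sizes (assumptions)
L, S, N = 3, 5, 2

mha = nn.MultiheadAttention(embed_dim, num_heads)
query = torch.randn(L, N, embed_dim)
key = torch.randn(S, N, embed_dim)
value = torch.randn(S, N, embed_dim)

# One boolean mask per batch entry, shape (N, L, S); True blocks a position.
# Batch entry 0 may not attend to source position 0; entry 1 is unrestricted.
per_batch_mask = torch.zeros(N, L, S, dtype=torch.bool)
per_batch_mask[0, :, 0] = True

# Expand to (N*num_heads, L, S), assuming heads of one entry sit side by side.
attn_mask = per_batch_mask.repeat_interleave(num_heads, dim=0)

attn_output, attn_output_weights = mha(query, key, value, attn_mask=attn_mask)

print(attn_mask.shape)            # (N*num_heads, L, S)
print(attn_output.shape)          # (L, N, E)
print(attn_output_weights.shape)  # (N, L, S)
```

If the mask is the same for every batch entry, a 2D mask of shape (L, S) is simpler and avoids depending on the head ordering.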
