Attention

Although attention-based models such as transformers are widely used in natural language processing (NLP) and image processing, their potential for complex-valued problems such as signal processing remains largely untapped. This module provides complex-valued variants of several attention-based techniques.

class complextorch.nn.modules.attention.CVMultiheadAttention(n_heads: int, d_model: int, d_k: int, d_v: int, dropout: float = 0.1, SoftMaxClass: ~torch.nn.modules.module.Module = <class 'complextorch.nn.modules.softmax.CVSoftMax'>)

Complex-Valued Multihead Attention

Multi-head self-attention extended to complex-valued tensors.

By default, CVMultiheadAttention uses complextorch.nn.CVSoftMax, which applies the traditional softmax to the magnitude of the complex-valued tensor while leaving the phase information unchanged.
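The magnitude-only softmax described above can be illustrated with a minimal sketch. It uses plain PyTorch complex tensors rather than the library's CVTensor, and the helper name `magnitude_softmax` is hypothetical; it is not the library's implementation, only an illustration of the idea of normalizing magnitudes while keeping the phase.

```python
import torch


def magnitude_softmax(z: torch.Tensor, dim: int = -1) -> torch.Tensor:
    """Apply a real softmax to |z| along `dim` and reattach the phase of z.

    Hypothetical helper for illustration only (not part of complextorch).
    """
    mag = torch.softmax(z.abs(), dim=dim)   # real-valued, sums to 1 along `dim`
    return torch.polar(mag, z.angle())      # mag * exp(i * phase)


# Three complex entries with distinct magnitudes and phases.
z = torch.tensor([[1 + 1j, -2 + 2j, 0 - 3j]], dtype=torch.complex64)
out = magnitude_softmax(z)
```

In this sketch the output magnitudes sum to one along the chosen dimension, while each entry keeps the phase of the corresponding input entry.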

forward(q: CVTensor, k: CVTensor, v: CVTensor) CVTensor

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class complextorch.nn.modules.attention.CVScaledDotProductAttention(temperature: float, attn_dropout: float = 0.1, SoftMaxClass: ~torch.nn.modules.module.Module = <class 'complextorch.nn.modules.softmax.CVSoftMax'>)

Complex-Valued Scaled Dot-Product Attention

The ever-popular scaled dot-product attention is the backbone of many attention-based methods, most notably the transformer.

Implements the operation:

\[\text{Attention}(Q, K, V) = \mathcal{S}(Q K^T / t) V\]

where \(Q, K, V\) are complex-valued tensors, \(t\) is the temperature (typically \(t = \sqrt{d_{attn}}\)), and \(\mathcal{S}\) is the softmax function.

The traditional softmax function cannot be applied directly to complex-valued tensors, so a complex-valued variant must be used instead. This library includes several options for complex-valued softmax and similar masking functions.

By default, CVScaledDotProductAttention uses complextorch.nn.CVSoftMax, which applies the traditional softmax to the magnitude of the complex-valued tensor while leaving the phase information unchanged.
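As a rough illustration of the operation \(\mathcal{S}(Q K^T / t) V\), the sketch below computes it with plain PyTorch complex tensors. It makes several assumptions that are not confirmed by this documentation: the function name `complex_sdp_attention` is hypothetical, dropout is omitted, the transpose is taken without conjugation (as the formula is written), and the softmax is a simple magnitude softmax that preserves phase. It is not the library's implementation.

```python
import math

import torch


def complex_sdp_attention(q, k, v, temperature):
    """Sketch of S(Q K^T / t) V for complex tensors (hypothetical helper).

    q, k: (..., seq_len, d_attn) complex tensors; v: (..., seq_len, d_v).
    """
    # Plain (non-conjugate) transpose of the last two dims, scaled by t.
    scores = q @ k.transpose(-2, -1) / temperature
    # Assumed softmax variant: softmax over magnitudes, phase left unchanged.
    weights = torch.polar(torch.softmax(scores.abs(), dim=-1), scores.angle())
    return weights @ v, weights


torch.manual_seed(0)
batch, seq_len, d_attn = 2, 5, 8
q = torch.randn(batch, seq_len, d_attn, dtype=torch.complex64)
k = torch.randn(batch, seq_len, d_attn, dtype=torch.complex64)
v = torch.randn(batch, seq_len, d_attn, dtype=torch.complex64)

# Temperature set to sqrt(d_attn), the typical choice mentioned above.
out, weights = complex_sdp_attention(q, k, v, temperature=math.sqrt(d_attn))
```

Here `out` has the same shape as `v`, and the magnitudes of `weights` sum to one across the key dimension, while the phases of the attention scores carry through to the weighted sum.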

forward(q: CVTensor, k: CVTensor, v: CVTensor) CVTensor

Implements the complex-valued scaled dot-product attention operation.

Parameters:
  • q (CVTensor) – complex-valued query tensor

  • k (CVTensor) – complex-valued key tensor

  • v (CVTensor) – complex-valued value tensor

Returns:

\(\mathcal{S}(Q K^T / t) V\)

Return type:

CVTensor