Attention
Attention-based models such as transformers are widely used in natural language processing (NLP) and image processing, but their potential for complex-valued problems such as signal processing remains relatively untapped. Here, we include complex-valued variants of several attention-based techniques.
- class complextorch.nn.modules.attention.CVMultiheadAttention(n_heads: int, d_model: int, d_k: int, d_v: int, dropout: float = 0.1, SoftMaxClass: torch.nn.modules.module.Module = <class 'complextorch.nn.modules.softmax.CVSoftMax'>)
Complex-Valued Multihead Attention
Multihead self-attention extended to complex-valued tensors.
By default, CVMultiheadAttention employs complextorch.nn.CVSoftMax, which applies the traditional softmax to the magnitude of the complex-valued tensor while leaving the phase information unchanged.
- forward(q: CVTensor, k: CVTensor, v: CVTensor) → CVTensor
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
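Read literally, the default softmax described for CVMultiheadAttention normalises the magnitudes and reattaches each element's phase. As a sketch of that description (written from the prose above, not taken from the library source), for a complex vector \(z\) with elements \(z_j\):
\[\mathcal{S}(z)_j = \frac{e^{|z_j|}}{\sum_k e^{|z_k|}} \, e^{i \arg(z_j)}\]
so that the magnitudes \(|\mathcal{S}(z)_j|\) sum to one over \(j\) while \(\arg(\mathcal{S}(z)_j) = \arg(z_j)\).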
- class complextorch.nn.modules.attention.CVScaledDotProductAttention(temperature: float, attn_dropout: float = 0.1, SoftMaxClass: torch.nn.modules.module.Module = <class 'complextorch.nn.modules.softmax.CVSoftMax'>)
Complex-Valued Scaled Dot-Product Attention
The ever-popular scaled dot-product attention is the backbone of many attention-based methods, most notably the transformer.
Implements the operation:
\[\text{Attention}(Q, K, V) = \mathcal{S}(Q K^T / t) V\]
where \(Q, K, V\) are complex-valued tensors, \(t\) is the temperature (typically \(t = \sqrt{d_{attn}}\)), and \(\mathcal{S}\) is the softmax function.
For complex-valued inputs, the traditional softmax function cannot be applied directly, so a variant must be used instead. This library includes several options for complex-valued softmax and similar masking functions.
By default, CVScaledDotProductAttention employs complextorch.nn.CVSoftMax, which applies the traditional softmax to the magnitude of the complex-valued tensor while leaving the phase information unchanged.
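To make the operation above concrete, the following is a minimal, dependency-free sketch in plain Python. It uses built-in complex numbers and lists of rows rather than the library's CVTensor, and it omits dropout. The names magnitude_softmax and cv_attention are hypothetical helpers introduced here for illustration; they are not part of complextorch. The sketch assumes the magnitude-softmax behaviour described in the text and the plain transpose shown in the formula.

```python
import cmath
import math


def magnitude_softmax(row):
    """Softmax over magnitudes; each entry keeps its original phase.

    Hypothetical helper for illustration, not the library's CVSoftMax.
    """
    mags = [abs(z) for z in row]
    peak = max(mags)  # subtract the maximum for numerical stability
    exps = [math.exp(m - peak) for m in mags]
    total = sum(exps)
    return [cmath.rect(e / total, cmath.phase(z)) for e, z in zip(exps, row)]


def cv_attention(q, k, v, temperature):
    """Evaluate S(Q K^T / t) V for complex matrices stored as lists of rows."""
    # scores[i][j] = (sum_d q[i][d] * k[j][d]) / t, using the plain transpose
    # from the formula (no complex conjugation is assumed here).
    scores = [
        [sum(a * b for a, b in zip(q_row, k_row)) / temperature for k_row in k]
        for q_row in q
    ]
    weights = [magnitude_softmax(row) for row in scores]
    # out[i][c] = sum_j weights[i][j] * v[j][c]
    n_cols = len(v[0])
    return [
        [sum(w * v_row[c] for w, v_row in zip(w_row, v)) for c in range(n_cols)]
        for w_row in weights
    ]


# One query attending over two key/value pairs, with d_attn = 2.
q = [[1 + 1j, 1 + 0j]]
k = [[1 - 1j, 0j], [0j, 1j]]
v = [[1 + 0j, 2j], [3 + 0j, 4 + 0j]]
out = cv_attention(q, k, v, temperature=math.sqrt(2))
```

Because the attention weights stay complex, the phase of each score rotates the corresponding value vector before the weighted sum. With a real-valued softmax this rotation would not occur.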