Higher-order Linear Attention


TL;DR

This paper proposes Higher-order Linear Attention (HLA), a new attention mechanism that realizes higher-order interactions through compact prefix sufficient statistics while preserving linear-time complexity and streaming computation. It thereby addresses the quadratic-complexity bottleneck of standard attention without sacrificing expressiveness.

Key Definitions

The foundation of modern large language models (LLMs) is the Transformer architecture and its core component—scaled dot-product attention. However, its computational and memory complexity grows as $O(n^2)$ with sequence length $n$, which severely limits the use of these models in long-context settings.

To address this bottleneck, the field has seen a variety of efficient alternatives, including Linear Attention, modern recurrent neural networks (RNNs), and State Space Models (SSMs). These methods typically achieve linear-time complexity and $O(1)$ state updates at inference time. However, most of them are limited to first-order or kernel-based approximations, which may constrain model expressiveness.
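As a point of reference for the first-order baseline mentioned above, the sketch below shows causal linear attention with an $O(1)$-per-step state update. It is illustrative only (not from the paper) and assumes an identity feature map in place of a learned kernel: the state is a running $d \times d_v$ matrix $\sum_{i\leq t}\mathbf{k}_i\mathbf{v}_i^{\top}$ plus a normalizer vector $\sum_{i\leq t}\mathbf{k}_i$.

```python
# Minimal sketch of causal first-order linear attention (illustrative;
# assumes an identity feature map phi(x) = x rather than a learned kernel).
import numpy as np

def linear_attention_stream(Q, K, V, eps=1e-6):
    """o_t = (q_t^T S_t) / (q_t^T z_t + eps), with
    S_t = sum_{i<=t} k_i v_i^T and z_t = sum_{i<=t} k_i."""
    n, d = Q.shape
    d_v = V.shape[1]
    S = np.zeros((d, d_v))   # running sum of k_i v_i^T  -- O(d d_v) state
    z = np.zeros(d)          # running sum of k_i (normalizer state)
    out = np.empty((n, d_v))
    for t in range(n):
        S += np.outer(K[t], V[t])          # O(d d_v) update, O(1) in n
        z += K[t]
        out[t] = (Q[t] @ S) / (Q[t] @ z + eps)
    return out
```

The per-step cost is independent of $t$, which is exactly the $O(1)$ state-update property the text attributes to this family of methods.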

The core problem this paper aims to solve is: how can we design a mechanism that has the data dependence and higher-order interaction capabilities of attention, while also achieving efficient streaming computation and parallel training like modern recurrent architectures?

Method

The core contribution of this paper is Higher-order Linear Attention (HLA), which enables streaming computation of higher-order interactions through compact prefix summaries.

Second-order HLA

As a starting point, the paper considers second-order tensor attention:

\[\mathbf{T}_{2} := (\mathbf{Q}\mathbf{K}^{\top})(\mathbf{Q}\mathbf{K}^{\top})^{\top} = \mathbf{Q}(\mathbf{K}^{\top}\mathbf{K})\mathbf{Q}^{\top} \in \mathbb{R}^{n \times n}\]

The key observation is that it depends on the second-order moment of the keys, $\mathbf{K}^{\top}\mathbf{K}$. This motivates streaming computation by maintaining prefix summaries. At time step $t$, the following summaries are maintained:

\[\mathbf{S}_{t}^{K} \coloneqq \sum_{i\leq t}\mathbf{k}_{i}\mathbf{k}_{i}^{\top} \in \mathbb{R}^{d\times d}, \qquad \mathbf{C}_{t}^{QV} \coloneqq \sum_{i\leq t}\mathbf{q}_{i}\mathbf{v}_{i}^{\top} \in \mathbb{R}^{d\times d_v}, \qquad \mathbf{m}_{t}^{Q} \coloneqq \sum_{i\leq t}\mathbf{q}_{i} \in \mathbb{R}^{d}\]

Updating these summaries costs $O(d^{2} + d\,d_{v})$ per step, independent of sequence length.

Based on these summaries, the output of second-order HLA (by default in unnormalized form) at time step $t$ is defined as:

\[\mathbf{o}_{t} \coloneqq \mathbf{q}_{t}^{\top}\mathbf{S}_{t}^{K}\mathbf{C}_{t}^{QV}\]

Normalization is also possible:

\[\mathbf{o}_{t} = \frac{\mathbf{q}_{t}^{\top}\mathbf{S}_{t}^{K}\mathbf{C}_{t}^{QV}}{\mathbf{q}_{t}^{\top}\mathbf{S}_{t}^{K}\mathbf{m}_{t}^{Q}+\varepsilon}\]

Here, $\mathbf{S}_t^K$ acts as a data-dependent, learnable metric matrix, enriching the model’s expressiveness. When $\mathbf{S}_t^K = \mathbf{I}$, this form reduces to a linear attention mechanism.
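The streaming recurrence above can be sketched in a few lines of numpy. This is an illustration of the formulas just given, not the authors' implementation; the state names follow the summaries $\mathbf{S}_t^K$, $\mathbf{C}_t^{QV}$, $\mathbf{m}_t^Q$.

```python
# Sketch of streaming second-order HLA (unnormalized and normalized forms),
# maintaining S_t^K = sum k_i k_i^T, C_t^{QV} = sum q_i v_i^T, m_t^Q = sum q_i.
import numpy as np

def hla2_stream(Q, K, V, eps=1e-6, normalize=True):
    n, d = Q.shape
    d_v = V.shape[1]
    S = np.zeros((d, d))     # S_t^K : second moment of the keys
    C = np.zeros((d, d_v))   # C_t^{QV}: query-value cross summary
    m = np.zeros(d)          # m_t^Q : running sum of queries
    out = np.empty((n, d_v))
    for t in range(n):
        S += np.outer(K[t], K[t])    # O(d^2) update
        C += np.outer(Q[t], V[t])    # O(d d_v) update
        m += Q[t]
        num = Q[t] @ S @ C           # q_t^T S_t^K C_t^{QV}
        if normalize:
            out[t] = num / (Q[t] @ S @ m + eps)
        else:
            out[t] = num
    return out
```

Setting `S = np.eye(d)` at every step recovers the first-order linear-attention form, matching the reduction noted above.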

Innovation 1: Causal masking via extended summaries

Standard attention mechanisms require causal masking in computation to ensure that, in autoregressive tasks, the output at the current time step depends only on past information. Applying masking directly in HLA would break the factorized computation structure.

To solve this, the paper introduces two additional extended prefix summaries:

\[\mathbf{G}_{t} \coloneqq \sum_{i\leq t}\left(\mathbf{k}_{i}\mathbf{k}_{i}^{\top}\right)\mathbf{C}_{i-1}^{QV} \in \mathbb{R}^{d\times d_v}\]

\[\mathbf{h}_{t} \coloneqq \sum_{i\leq t}\left(\mathbf{k}_{i}\mathbf{k}_{i}^{\top}\right)\mathbf{m}_{i-1}^{Q} \in \mathbb{R}^{d}\]

With these correction terms, the strictly causal second-order HLA output can be computed exactly without materializing any $n \times n$ matrix. For example, the unnormalized causal output is:

\[\mathbf{o}_{t} = \mathbf{q}_{t}^{\top}(\mathbf{S}_{t}^{K}\mathbf{C}_{t}^{QV} - \mathbf{G}_{t})\]

All summaries, including $\mathbf{G}_t$ and $\mathbf{h}_t$, support constant-time online updates, preserving the efficiency of streaming computation.
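A sketch of the causal recurrence (unnormalized form only; $\mathbf{h}_t$ would be maintained analogously for the normalized form). The one subtlety is update order: $\mathbf{G}_t$ must be accumulated against the previous step's $\mathbf{C}_{t-1}^{QV}$ before $\mathbf{C}^{QV}$ itself is updated.

```python
# Sketch of strictly causal second-order HLA via the extended summary G_t.
# Unnormalized output: o_t = q_t^T (S_t^K C_t^{QV} - G_t).
import numpy as np

def hla2_causal_stream(Q, K, V):
    n, d = Q.shape
    d_v = V.shape[1]
    S = np.zeros((d, d))     # S_t^K
    C = np.zeros((d, d_v))   # C_t^{QV}
    G = np.zeros((d, d_v))   # G_t = sum_{i<=t} (k_i k_i^T) C_{i-1}^{QV}
    out = np.empty((n, d_v))
    for t in range(n):
        kk = np.outer(K[t], K[t])
        G += kk @ C                  # uses C_{t-1}: update G before C
        S += kk
        C += np.outer(Q[t], V[t])
        out[t] = Q[t] @ (S @ C - G)  # exact causal output, no n x n matrix
    return out
```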

Innovation 2: Parallel training via associative scans

Purely recurrent models are inefficient to train on GPUs. To enable efficient parallel training, the paper defines an associative operator \(\oplus\) for HLA state updates and uses associative scans (e.g., the Blelloch scan) to compute the prefix states.

This method can partition the sequence into blocks and perform scans in parallel within and across blocks, producing activations exactly identical to those of a serial loop, thereby enabling efficient and exact parallel training. This framework can also be extended to the case with exponential decay $\gamma$.
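To make the scan concrete, here is one plausible associative combiner for the cascaded state $(\mathbf{S}, \mathbf{C}, \mathbf{G})$, reconstructed from the summary definitions above; the paper's exact operator may differ. For a segment $A$ followed by a segment $B$, the pure sums add, while $B$'s keys additionally see $A$'s accumulated $\mathbf{C}$, giving $\mathbf{G}_{AB} = \mathbf{G}_A + \mathbf{G}_B + \mathbf{S}_B\mathbf{C}_A$.

```python
# Plausible associative operator for the cascaded HLA state (S, C, G),
# reconstructed from the prefix-summary definitions (not the paper's code).
import numpy as np

def segment_state(Q, K, V):
    """Serial (S, C, G) state for one contiguous segment: a leaf of the scan."""
    d, d_v = K.shape[1], V.shape[1]
    S = np.zeros((d, d)); C = np.zeros((d, d_v)); G = np.zeros((d, d_v))
    for t in range(len(K)):
        kk = np.outer(K[t], K[t])
        G += kk @ C              # strictly earlier C within the segment
        S += kk
        C += np.outer(Q[t], V[t])
    return S, C, G

def combine(a, b):
    """Associative 'oplus': segment A followed by segment B."""
    Sa, Ca, Ga = a
    Sb, Cb, Gb = b
    # B's keys see A's accumulated C, hence the cross term Sb @ Ca.
    return Sa + Sb, Ca + Cb, Ga + Gb + Sb @ Ca
```

Because `combine` is associative, the per-chunk states can be reduced in any parallel bracketing (e.g., a Blelloch-style tree) and still reproduce the serial loop exactly, which is the property the paragraph above relies on.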

(The paper formalizes this blockwise procedure in an algorithm titled "Masked (Second Order) HLA with Within-Chunk Scan.")

Asymmetric Higher-Order Linear Attention (AHLA)

This paper also proposes a complementary variant, AHLA. It computes the left-cascaded product $\mathbf{Q}(\mathbf{K}^\top\mathbf{Q})(\mathbf{K}^\top\mathbf{V})$ instead of the symmetric form used in HLA. AHLA likewise supports streaming computation and causal masking, but maintains a different set of prefix summaries.

Its streaming output is $\mathbf{o}_{t}^{\textsc{AHLA}} = \mathbf{q}_{t}^{\top}\mathbf{E}_{t}$, where $\mathbf{E}_{t}$ is a cascaded prefix summary. AHLA's per-step update cost is $O(d\,d_{v})$, which in some regimes makes it cheaper than second-order HLA's $O(d^{2} + d\,d_{v})$.
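As a hedged illustration, here is one plausible streaming realization under an assumed causal convention, $\mathbf{E}_t = \sum_{i\leq t}\mathbf{k}_i\,\mathbf{q}_i^{\top}\mathbf{A}_i$ with $\mathbf{A}_i = \sum_{j\leq i}\mathbf{k}_j\mathbf{v}_j^{\top}$; the paper's exact summaries and indexing may differ. Note that every per-step operation is $O(d\,d_v)$, consistent with the stated cost.

```python
# Hedged sketch of streaming AHLA under an ASSUMED causal convention:
# E_t = sum_{i<=t} k_i q_i^T A_i,  A_i = sum_{j<=i} k_j v_j^T.
import numpy as np

def ahla_stream(Q, K, V):
    n, d = Q.shape
    d_v = V.shape[1]
    A = np.zeros((d, d_v))   # running K^T V
    E = np.zeros((d, d_v))   # cascaded summary
    out = np.empty((n, d_v))
    for t in range(n):
        A += np.outer(K[t], V[t])        # O(d d_v)
        E += np.outer(K[t], Q[t] @ A)    # O(d d_v): k_t (q_t^T A_t)
        out[t] = Q[t] @ E                # o_t = q_t^T E_t
    return out
```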

Experimental Conclusions

This paper mainly focuses on the algorithmic structure and theoretical derivations, and does not provide specific experimental results or performance comparisons with other models.

Summary

This paper presents a complete, scalable attention framework: Higher-order Linear Attention (HLA). Its main contributions are higher-order, data-dependent interactions computed from compact prefix summaries in linear time, exact strict causal masking via extended summaries without materializing any $n \times n$ matrix, exact parallel training via associative scans, and the complementary asymmetric variant AHLA.

In short, HLA, as a building block that can directly replace standard attention, cleverly combines the data-dependent weighting properties of attention with the high efficiency of modern recurrent architectures, providing a powerful and principled tool for building scalable long-context language models.