1. Introduction
Transformer models [1] have significantly advanced neural networks across various domains, including natural language processing [2,3] and computer vision [4,5]. A key contributor to their success is the scaled dot-product attention mechanism, which enables Transformers to capture contextual information within data, leading to state-of-the-art performance in numerous tasks. However, this attention mechanism has quadratic space and time complexity with respect to the sequence length due to its softmax function, which limits the context length that can be computed [6,7,8,9,10].
To address this limitation, several methods have been developed to improve the efficiency of the attention mechanism, often by linearizing the original attention using kernel methods [6,10] or low-rank approximations [9]. These efficient attention mechanisms achieve linear theoretical complexity with respect to sequence length and exhibit performance comparable to the original attention on long-sequence tasks, such as the Long-Range Arena benchmark [11] and language modeling [12]. Despite these improvements, prior research on linear attention mechanisms has primarily focused on efficiency in handling long sequences while neglecting other critical properties of neural networks, such as systematic generalization.
Systematic generalization refers to the ability to generalize to unseen data by combining familiar concepts in novel ways [13,14,15,16,17,18]. This capability has been extensively studied in Transformers, particularly in enabling them to generalize to out-of-distribution (OOD) datasets. For instance, Mittal et al. [19] introduced a compositional attention mechanism to enable flexible and dynamic operations among attention heads. Meanwhile, Csordás et al. [20] examined the limitations of Transformers on systematic generalization benchmarks, showing that simple techniques, such as scaling embeddings, can effectively enhance their ability to learn systematic generalization tasks. However, compared to standard Transformers, the systematic generalization capabilities of efficient attention mechanisms remain underexplored.
In this work, we investigate the systematic generalization capabilities of efficient attention mechanisms, with a focus on the Linear Transformer [6], and explore various methods to enhance their performance. In our preliminary experiments on systematic generalization tasks, we identify two major issues in the attention components (Queries, Keys, and Values) that contribute to unstable generalization during training: (i) unconstrained norms of Queries and Keys, and (ii) high correlation among Values across the sequence. The linear attention mechanism computes attention from Queries and Keys accumulated over the sequence. As noted in previous research [21,22], this design can lead to instability during training when the norms of Queries and Keys increase dramatically. In our preliminary experiments, we observe that these unconstrained norms negatively affect the systematic generalization performance of linear attention mechanisms. Additionally, in non-causal settings, the linear attention mechanism often fails to learn distinct features across the sequence, resulting in highly correlated Values.
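To make the source of this instability concrete, the sketch below shows the standard non-causal linear attention computation with the feature map φ(x) = elu(x) + 1 used in [6]; the tensor names and shapes are illustrative only, not taken from our implementation. Because φ(Q) and φ(K) enter both the numerator and the normalizer without any constraint, unusually large Query or Key norms propagate directly into the output.

```python
import torch
import torch.nn.functional as F

def linear_attention(q, k, v, eps=1e-6):
    """Minimal sketch of non-causal linear attention [6].

    q, k, v: (batch, seq_len, dim). The softmax is replaced by the
    kernel feature map phi(x) = elu(x) + 1, so the full
    seq_len x seq_len attention matrix is never materialized.
    """
    phi_q = F.elu(q) + 1
    phi_k = F.elu(k) + 1
    kv = torch.einsum("bnd,bne->bde", phi_k, v)               # phi(K)^T V accumulated over the sequence
    z = torch.einsum("bnd,bd->bn", phi_q, phi_k.sum(dim=1))   # per-token normalizer
    return torch.einsum("bnd,bde->bne", phi_q, kv) / (z.unsqueeze(-1) + eps)
```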
To address the instability in the systematic generalization of the Linear Transformer, we propose two simple yet effective techniques, a normalization term for the attention components and an orthogonality loss, both motivated by the issues identified in our preliminary experiments. First, we apply a normalization term to the Queries and Keys (as in [21]) to prevent excessively large norm values. We explore several normalization strategies, including L1, L2, and RMS layer normalization [23]. Additionally, we introduce an orthogonality loss, an auxiliary loss that encourages the Values within a sequence to be orthogonal during training. This loss reduces the correlation among Values generated from distinct input features, thereby improving generalization performance.
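The sketch below illustrates both techniques in simplified form. The exact formulations are assumptions for illustration: the RMS variant is written without a learned gain, and the orthogonality penalty is written as the mean squared pairwise cosine similarity among the Values of each sequence.

```python
import torch
import torch.nn.functional as F

def normalize_qk(q, k, mode="l2", eps=1e-6):
    """Normalize Queries and Keys over the feature dimension (sketch).
    mode: 'l1', 'l2', or 'rms' (shown here without a learned gain)."""
    if mode == "l1":
        norm = lambda x: x / (x.abs().sum(dim=-1, keepdim=True) + eps)
    elif mode == "l2":
        norm = lambda x: F.normalize(x, dim=-1, eps=eps)
    elif mode == "rms":
        norm = lambda x: x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    else:
        raise ValueError(f"unknown mode: {mode}")
    return norm(q), norm(k)

def orthogonality_loss(v):
    """Auxiliary loss penalizing pairwise cosine similarity among the
    Values of each sequence. v: (batch, seq_len, dim); returns a scalar."""
    v_hat = F.normalize(v, dim=-1)
    gram = torch.einsum("bnd,bmd->bnm", v_hat, v_hat)         # cosine similarities
    gram = gram - torch.eye(gram.size(-1), device=v.device)   # zero the diagonal
    return gram.pow(2).mean()
```

During training, this auxiliary term is added to the task loss, scaled by a regularization coefficient.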
In summary, the main contributions of this work are as follows:
We investigate the systematic generalization capabilities of linear attention mechanisms and identify key limitations, including unconstrained norms of Queries and Keys and high correlation among Values.
We propose normalization techniques and auxiliary loss functions to address these limitations and improve the stability and generalization performance of linear attention mechanisms.
We evaluate our proposed methods on various systematic generalization tasks, including the sort-of-CLEVR [24], SCAN [16], Mathematics dataset [25], and PCFG [18] tasks. Our experimental results demonstrate that the proposed methods enhance the training stability and systematic generalization capabilities of the Linear Transformer.
3. Preliminary Experiments and the Design of Proposed Methods
In this section, we present preliminary experiments aimed at evaluating the systematic generalization capabilities of existing linear attention mechanisms, identifying their potential limitations, and using these findings to motivate the design of our proposed methods.
We employ the sort-of-CLEVR task [24] for the preliminary experiments; it evaluates systematic generalization in models through visual question answering. The task includes different types of questions: Unary (questions about the properties of a single visual object), and Binary and Ternary (questions about relationships among multiple visual objects). We conduct several trials, each with a different training seed, using the Linear Transformer [6] on the sort-of-CLEVR task, following the setup of prior work [19]. Detailed experimental settings can be found in Section 5.
Table 1 and Figure 1 present the results of the vanilla Transformer and the Linear Transformer across five trials. The results indicate that the Linear Transformer exhibits worse systematic generalization performance than the vanilla Transformer and shows instability in generalization depending on the training seed. For example, while the Linear Transformer achieves accuracy comparable to the vanilla Transformer on the Unary and Ternary question types in trials 4 and 5, it fails to achieve similar performance in trials 1, 2, and 3.
To further investigate the source of this instability, we perform additional analyses comparing the attention components from a successful trial (trial 5) and a poorly performing trial (trial 1). These analyses are informed by prior research [21,22] suggesting that unconstrained attention components can lead to unstable performance in linear attention. Through this investigation, we identify two key flaws in the attention components (Queries, Keys, and Values) that may contribute to the instability of the linear attention mechanism.
First, we examine the norm distributions of the Queries and Keys. The results, shown in Figure 2, clearly indicate that the Keys’ norms in the poorly performing trial (trial 1) are significantly higher than those in the successful trial (trial 5) across all attention heads. Similarly, the Queries’ norms in the poor trial are higher than in the successful trial for all heads except head index 4. These findings suggest that the lack of constraints on the Queries and Keys may lead to unstable generalization performance. Based on this observation, we hypothesize that the relatively high norm distribution contributes to the instability of the linear attention mechanism and investigate normalization methods applied to the Queries and Keys as a potential solution.
Next, we analyze the representational quality of the Values, which is closely linked to the overall performance of the attention mechanism, through the lens of similarity. Figure 3 shows cosine similarity heatmaps of the Values across the sequence in the non-causal setting. In the poorly performing trial (Figure 3a), the Values exhibit high correlation, whereas in the successful trial (Figure 3b), the Values show significantly lower correlation. This high correlation may suggest that the attention mechanism fails to transfer distinct information about individual tokens to subsequent layers, ultimately leading to poor generalization performance. To address this issue, we explore the use of auxiliary objective functions to encourage the model to learn more distinct representations across the sequence during training.
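For reference, a heatmap of this kind can be computed roughly as follows. This is a sketch of the diagnostic only: it assumes access to the Values of a single example and attention head (e.g., extracted via a forward hook), and the plotting details are illustrative.

```python
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

def value_similarity_heatmap(v, ax=None):
    """Plot the cosine-similarity matrix among the Values of one sequence.
    v: (seq_len, dim) for a single example and attention head."""
    v_hat = F.normalize(v, dim=-1)
    sim = (v_hat @ v_hat.T).detach().cpu().numpy()   # (seq_len, seq_len)
    ax = ax or plt.gca()
    im = ax.imshow(sim, vmin=-1.0, vmax=1.0, cmap="coolwarm")
    ax.set_xlabel("token index")
    ax.set_ylabel("token index")
    return im
```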
In the following section, we build upon these preliminary findings and propose representation learning methods to enhance the systematic generalization performance of the linear attention mechanism.
6. Experimental Results
In this section, we demonstrate the effectiveness of our proposed methods in enhancing the systematic generalization performance of the Linear Transformer across the tasks introduced in Section 5. We organize the experiments into two categories based on model architecture: (i) encoder-only and (ii) encoder–decoder.
6.1. Experiments with Encoder-Only Architecture
We first evaluate the proposed methods on the sort-of-CLEVR task, as in the preliminary experiments.
Table 2 presents the mean accuracy of our proposed methods over five trials for each question type. The results indicate that the proposed normalization layers and orthogonality loss function significantly improve both the generalization performance and stability of the Linear Transformer across all question types.
First, applying the proposed normalization layers to Queries and Keys proves effective in enhancing the generalization performance of the Linear Transformer, regardless of the specific normalization method used. While the vanilla Linear Transformer achieves an accuracy of 77.4% on Unary questions, 77.0% on Binary questions, and 58.2% on Ternary questions, the L2 normalization layer improves these results to 99.1%, 83.7%, and 67.8%, respectively. Other normalization methods also show performance gains, with the RMS layer normalization method achieving accuracies of 98.7%, 83.5%, and 66.0%, and the L1 normalization layer reaching 98.8%, 81.9%, and 66.8%.
Second, the proposed orthogonality loss function (with a suitably chosen regularization coefficient) also enhances the vanilla Linear Transformer’s performance, achieving accuracies of 98.7%, 81.7%, and 66.5%. Furthermore, combining the orthogonality loss function with the normalization layers results in a significant performance boost with reduced variance. Notably, the combined method with the L2 normalization layer achieves accuracies of 99.0%, 87.6%, and 69.0%, surpassing the systematic generalization performance of the vanilla Transformer.
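For clarity, the role of the regularization coefficient in training can be summarized as follows, where λ denotes the coefficient, $\mathcal{L}_{\text{task}}$ is the task loss, and the auxiliary term is assumed to take the pairwise cosine-similarity form sketched earlier (the exact definition may differ in detail):

$$
\mathcal{L}_{\text{total}} = \mathcal{L}_{\text{task}} + \lambda\,\mathcal{L}_{\text{orth}},
\qquad
\mathcal{L}_{\text{orth}} = \frac{1}{N^{2}} \sum_{i \neq j}
\left( \frac{v_i^{\top} v_j}{\lVert v_i \rVert\,\lVert v_j \rVert} \right)^{2},
$$

where N is the sequence length and the $v_i$ are the Values of a single sequence.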
Next, we extend our comparison to include the Compositional Transformer [19], which is specifically designed to enhance systematic generalization in Transformers. As shown in Table 2, our proposed methods, particularly the combination of L2 normalization and the orthogonality loss, achieve performance comparable to the Compositional Transformer. Additionally, we compare our methods to other efficient attention mechanisms, including the Performer [35] and Cosformer [10]. As shown in Table 2, while both the Performer and Cosformer outperform the Linear Transformer, their generalization performance falls short of the vanilla Transformer. Applying our proposed methods to the Linear Transformer yields better performance than the Performer and Cosformer, demonstrating the effectiveness of our approach.
Ablation Study and Analysis
We perform an ablation study on the orthogonality loss function combined with the L2 normalization layer, varying the regularization coefficient over a range of values. As shown in Figure 5, the results demonstrate that the proposed orthogonality loss function consistently improves the performance of the Linear Transformer across all coefficient values on the sort-of-CLEVR task. However, the results also suggest that an optimal regularization coefficient is required to achieve the best performance.
Additionally, we investigate the effect of the proposed methods under different hyperparameter settings of the Linear Transformer, specifically the model dimension and the number of attention heads. As shown in Table 3, the proposed methods improve the generalization performance of the Linear Transformer across all hyperparameter settings. These results demonstrate the efficacy of the proposed methods under a wide range of conditions.
Next, we analyze the effect of the orthogonality loss function on the Values. Similar to our approach in Section 3, we examine the correlation among the Values. As shown in Figure 6, the proposed methods effectively reduce this correlation. These results confirm that the orthogonality loss function operates as intended by lowering the correlation among Values, thereby enhancing generalization performance.
6.2. Experiments with Encoder–Decoder Architecture
Next, we evaluate the proposed methods in an encoder–decoder architecture on the SCAN, Mathematics dataset, and PCFG tasks.
6.2.1. SCAN Task
Table 4 presents the mean accuracy of the proposed methods on the SCAN task over five trials for both the IID and OOD settings. Interestingly, unlike on the sort-of-CLEVR task, the vanilla Linear Transformer outperforms the vanilla Transformer in the OOD setting, achieving an accuracy of 46.9%. This suggests that the Linear Transformer may have potential in some systematic generalization tasks. Our proposed methods further improve the performance of the Linear Transformer.
First, applying normalization layers improves OOD generalization performance across all normalization strategies, with accuracies of 48.4% for RMS, 49.5% for L1, and 50.6% for L2. Second, the orthogonality loss function also enhances the baseline model, resulting in an OOD accuracy of 48.4%. Finally, combining both methods further improves generalization performance (except for the L1 case), with OOD accuracies of 50.4% for RMS, 44.9% for L1, and 56.3% for L2. Notably, as with the sort-of-CLEVR task, the best performance on the SCAN task is achieved by combining the orthogonality loss function with the L2 normalization layer.
6.2.2. “Add_or_Sub” Problem of the Mathematics Dataset
We also evaluate the proposed methods on the “add_or_sub” problem of the Mathematics dataset. As shown in Table 5, the performance improvement depends on the combination of the proposed methods.
First, applying normalization layers alone degrades the Linear Transformer’s generalization performance, with OOD accuracies of 65.2% for RMS, 65.6% for L1, and 66.0% for L2. Similarly, the orthogonality loss function alone also reduces the baseline model’s performance, yielding an OOD accuracy of 64.3%. However, when both methods are applied together, the generalization performance improves (except in the L1 case), with OOD accuracies of 67.1% for RMS, 65.5% for L1, and 67.2% for L2. These results suggest that the proposed methods, while designed to address issues observed in the preliminary experiments, may not be effective for all systematic generalization tasks. Furthermore, they highlight the importance of identifying the right combination of the proposed methods to properly enhance the Linear Transformer’s performance.
6.2.3. “Place_Value” Problem of the Mathematics Dataset
Next, we evaluate the proposed methods on the “place_value” problem of the Mathematics dataset. As shown in Table 6, the vanilla Linear Transformer struggles with training instability and fails to learn the “place_value” problem, where the target output length is extremely short (1 in this case). While RMS normalization does not resolve this instability, the L1 and L2 normalization layers effectively address the issue by preventing the norms of the attention components from becoming excessively large, achieving OOD accuracies of 21.6% for L1 and 18.0% for L2.
In contrast, applying the orthogonality loss function alone does not resolve the instability of the “place_value” problem. Furthermore, as with the “add_or_sub” problem, performance varies depending on the combination of methods. When both normalization layers and the orthogonality loss are applied together, OOD accuracies of 17.9% for L1 and 20.5% for L2 are achieved. These results indicate that normalization strategies without additional trainable parameters, such as L1 and L2, are effective for addressing training instability.
6.2.4. PCFG Task
Finally, we evaluate the proposed methods on the PCFG task. As shown in Table 7, the vanilla Linear Transformer suffers from training instability, and neither RMS normalization nor the orthogonality loss resolves this issue, similar to the “place_value” problem. On the other hand, applying the L1 and L2 normalization layers effectively mitigates the instability, achieving accuracies of 56.9% and 45.4%, respectively. Furthermore, combining the orthogonality loss function with the normalization layers leads to further improvements, achieving 58.4% for L1 and 47.4% for L2.
7. Discussion
In this section, we discuss the limitations of the proposed methods.
First, although the proposed methods effectively improve the systematic generalization of the Linear Transformer, the degree of improvement varies depending on the combination of normalization strategies and orthogonality loss used for each task. While the combination of L2 normalization and orthogonality loss generally performs better than others, the underlying reasons for this are still unclear. Moreover, other combinations sometimes outperform L2 normalization with orthogonality loss on specific tasks, such as the “place_value” problem of the Mathematics dataset and the PCFG task. This need for heuristic tuning to find the optimal configuration can impose an additional burden on researchers wishing to utilize these methods. In future work, we will investigate the relationship between normalization strategies and orthogonality loss.
Second, applying the proposed methods to the Linear Transformer introduces additional computational overhead. In our study, adding normalization layers slightly increased training time, whereas applying the orthogonality loss resulted in, on average, a 1.5 times longer training time across all tasks. Furthermore, although the orthogonality loss function can improve the Linear Transformer’s generalization performance, its computational overhead may become more severe as the sequence length increases.
Third, as shown in Figure 6, while the orthogonality loss function effectively reduces the correlation among Values across the sequence, some features still exhibit relatively high correlations. An over-penalizing configuration (i.e., an excessively large regularization coefficient) in the orthogonality loss function can negatively impact features that should remain similar, thereby degrading the model’s generalization performance, as shown in Figure 5. One possible solution is to incorporate the similarity between input features to implement context-aware regularization methods. We leave this for future work.