1. Introduction
Deep convolutional neural networks (CNNs) were fundamental in revolutionizing the field of computer vision. Similarly, the pioneering introduction of the transformer [1] architecture in natural language processing [2] has resulted in the AI revolution, with large language models (LLMs) such as ChatGPT [3], Bard [4], and Llama [5,6], among others, yielding impressive performance. The transformer uses a simple similarity computation in the form of an inner product on the learned, positionally encoded embeddings of a sequence of $n$ input words. If the matrices $Q$ and $K$ contain rows representing the embedding of each word, then $QK^{T}$ is referred to as the “attention”, which contains the dot-product similarity of each input word with every other word in the input sequence. If there are $n$ words being input, referred to as the context, then $Q, K \in \mathbb{R}^{n \times d}$, and $QK^{T} \in \mathbb{R}^{n \times n}$.
Like parallel feature maps in a CNN, each layer in the transformer divides the attention calculation into parallel heads. The output from a transformer layer has the same dimensionality as the input and is obtained by a simple matrix computation of $\mathrm{softmax}(QK^{T}/\sqrt{d})\,V$, where $V$ is similar to $K$ and contains rows of learned, position-encoded embeddings of the input words. For language models, where text generation is carried out based on a given context, the attention matrix is masked triangularly so that future tokens are not visible during training. Multiple layers of transformer blocks are used before feeding the result of the last layer to a classification head. Because the attention computation in each head is $O(n^2)$, this becomes a computational bottleneck for long contexts. Many approaches have been proposed in the past years to reduce the quadratic time complexity of attention to either linear or sub-quadratic complexity. Some of the notable works include Transformer-XL [
7], Linformer [
8], Longformer [
9], Reformer [
10], Performer [
11], Perceiver-AR [
12], LaMemo [
13], and ∞-former [
14] among others. We provide a brief background to the above-mentioned approaches used in reducing the attention complexity. Then we elaborate on the Transformer LS [
15] that we further enhance in this work.
To handle long sequences efficiently, transformer models employ approximations like sparse, low-rank, and linear attention, which reduce computational cost but can compromise accuracy. Sparse attention [
16,
17] approximates full attention by computing attention only between a subset of token pairs, which drastically cuts computation costs. However, this might lead to missing crucial long-range dependencies, as it assumes the discarded token pairs are insignificant (e.g., Longformer). Therefore, the selection of the attention pattern is critical to balancing efficiency and context capture.
Low-rank attention approximates the full attention matrix using smaller matrices, reducing the complexity to $O(nk)$ from $O(n^2)$. Here, the performance depends on how well the context is captured by the lower rank $k$ (e.g., Transformer LS [
15], Linformer [
8], and Reformer [
10]). This method assumes the full attention matrix’s essential information can be captured by a lower-dimensional representation, balancing computational efficiency with potential information loss.
Linear attention achieves $O(n)$ attention complexity via kernel tricks, but its transformations limit accuracy on long sequences, as this relies on the assumption that the simplified computation retains the essential information (e.g., Performer [
11], linear transformers [
18]). This trade-off between speed and expressiveness necessitates careful selection of the approximation method to maintain performance on long-range tasks. Hence, as described above, the trade-off between computation and performance remains a key challenge in transformer architectures.
One of the effective designs towards overcoming the above challenge was proposed in Transformer LS. Although Transformer LS performs efficient compression on the input sequence, this compression results in segment fragmentation that leads to these key shortcomings:
It reduces the dimensionality of the input sequence resulting in a loss of context;
The input sequence is divided into smaller, potentially non-overlapping segments. Since information is isolated within segments, it disrupts the natural flow and continuity of information.
These two factors make it more challenging for the language model to capture the overall context and the relationships between distant elements in the sequence. Our CacheFormer architecture makes the following contributions:
We develop an innovative attention mechanism where the highly attentive segments are dynamically cached and retrieved in an uncompressed form. This enhances the model’s performance by retrieving the most relevant information;
Long attention uses chunked segments in many existing models; however, this results in loss of information due to segment fragmentation. We improve on this shortcoming by implementing an overlapping attention mechanism via projections of segments that have an overlap with the adjacent segments;
We effectively combine the short attention, segment-based compressed long attention, highly attentive dynamically cached attention, and the overlapping segment attention. This results in an architecture that can efficiently handle long sequences and leads to an improved performance in language modeling.
Long-context language models excel in tasks that require deep understanding of contextual information, such as document summarization, complex question answering, and code analysis. These models enable more accurate and nuanced analysis of lengthy documents in legal, scientific, and medical fields. Further, they serve well in creative content generation and advanced customer support by preserving crucial context and improving reasoning. Our work further enhances the long context handling for the above applications.
2. Background and Related Work
An important earlier work in handling long contexts was presented by Transformer-XL. The authors divided the context into segments and used segment-level recurrence and a corresponding positional encoding to allow it to handle longer contexts. It achieved impressive results on the perplexity and BPC at that time. Linformer [
8] accomplished $O(n)$ complexity through linear self-attention. The authors demonstrate that the attention matrix is typically low rank and can thus be approximated by a low-rank matrix. Here, from the original $n \times d$ query, key, and value matrices, the $K$ and $V$ matrices are projected along the sequence dimension to lower-dimension matrices $K' = EK$ and $V' = FV$, where the learned projections $E, F \in \mathbb{R}^{k \times n}$ and $k < n$. Thus, the attention becomes $\mathrm{softmax}(QK'^{T}/\sqrt{d}) \in \mathbb{R}^{n \times k}$. The output is $\mathrm{softmax}(QK'^{T}/\sqrt{d})\,V' \in \mathbb{R}^{n \times d}$, i.e., the same shape as in the original transformer. Since $k$ is fixed, the attention complexity is $O(n)$.
Although Linformer [
8] reduced the attention complexity significantly, especially if $k \ll n$, note that it cannot be effectively used in autoregressive training and generation, as the projection of $K$ and $V$ compresses the information along the context dimension, making the masking of attention for future tokens invalid. However, for classification problems where masking of attention is not needed, their architecture is effective in reducing complexity.
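As a rough illustration of this low-rank projection idea (a minimal sketch with illustrative shapes, not Linformer's actual implementation; the names E, Fp, and the sizes are assumptions), the following shows how projecting K and V from length n to a fixed length k shrinks the attention matrix from n × n to n × k:

```python
import torch
import torch.nn.functional as F

n, d, k = 1024, 64, 256                 # sequence length, head dim, projected length (illustrative)
Q = torch.randn(n, d)
K = torch.randn(n, d)
V = torch.randn(n, d)

# Learned projections of shape (k, n) compress the sequence dimension
# (random here for illustration; Fp avoids clashing with torch.nn.functional as F).
E = torch.randn(k, n) / n ** 0.5
Fp = torch.randn(k, n) / n ** 0.5

K_proj = E @ K                          # (k, d): keys compressed along the sequence length
V_proj = Fp @ V                         # (k, d): values compressed along the sequence length

attn = F.softmax(Q @ K_proj.T / d ** 0.5, dim=-1)   # (n, k) instead of (n, n)
out = attn @ V_proj                     # (n, d): same output shape as full attention
print(attn.shape, out.shape)            # torch.Size([1024, 256]) torch.Size([1024, 64])
```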
Another approach introduced by Longformer [
9] used sparse attention patterns instead of the full dense attention. The authors proposed sliding window attention, where tokens attended only to the nearby past, a dilated sliding window, and a mix of global and sliding window attention where some tokens attend to all tokens while others only attend to nearby tokens. For autoregressive modeling Longformer [
9] used dilated sliding window attention. Another notable work in reducing the attention complexity was performed by Reformer [
10]. The authors’ key idea was to use locality-sensitive hashing, which reduces the attention complexity to $O(n \log n)$. Note that because of the hashing process, the architecture is not suited for autoregressive modeling.
A different approach to reduce the attention complexity was taken by Performer [
11], where the attention is decomposed as a product of non-linear functions of original query and key matrices referred to as random features. This allows the attention to be encoded more efficiently via the transformer query and key matrices. Further efficient handling of long contexts accomplished by Perceiver AR [
12] divided the input sequence into smaller key/value and query components. These components underwent cross attention in the first layer with a latent of size $l$, where $l$ is the size chosen when splitting the input sequence into the query part. The remaining layers operate on the $l \times l$ size instead of the usual $n \times n$ size as in a standard transformer. Although this cross attention on the partitioned input sequence results in efficient handling of long sequences, because of the reduced query size, the equivalent effect is more like a sliding window attention.
More recently, a different approach to handling long contexts was proposed via structured state space models. The structured state space sequence model (SSM) [
19] proposed an architecture that was based on a new parameterization that can be computed much more efficiently. A variation of the state space approach proposed by Mega [
20] uses a single-head gated attention mechanism equipped with exponential moving average to incorporate inductive bias of position-aware local dependencies into the position-agnostic attention mechanism. They also present its variation with linear time complexity for handling long sequences. Further progression on the state space models yielded better results as demonstrated in Hungry hungry hippos [
21] and Mamba [
22], which achieved very low perplexity scores. Most recently, xLSTM [
23] introduced exponential gating and parallelization in LSTMs to achieve extended memory. Some of these models consist of several billion parameters. We outperform the smaller versions of these models, which have a similar size to ours, on the perplexity metric, as shown in
Table 1 and
Table 2.
An interesting concept in handling long sequences was presented by Transformer LS [
15]. Here, a sliding window approach is used to handle near-term attention, while a set of compressed segments covering the entire past context is used as long-term attention. Both short and long attention are combined in the overall attention. The slight drawback of the approach is that the longer context is effectively used in compressed form and thus may lose some key contextual information when generating the output in an autoregressive setting.
The challenges in long-range sequence modeling are primarily related to the computational cost of self-attention, the difficulty of learning and maintaining long-range dependencies, and the challenge of evaluating long-range understanding. This makes the learning diluted or less precise, as the model is unable to capture the global context, which limits its ability to handle longer sequences. The current approaches still have limitations when it comes to handling global context, as explained above.
We address this problem by further augmenting the long–short attention with uncompressed, highly attentive segments. Since long–short attention divides the context into equal-size segments before projecting each segment to a smaller size, there is a potential loss of information due to segment fragmentation. We also improve this aspect by using overlapping segments and augment the existing long–short model with them. Thus, our enhanced long–short architecture involves four components in the overall attention: a sliding window attention, long attention based on compressed segments, long attention based on overlapping segments, and uncompressed segmented attention for a few highly attentive segments beyond the sliding window part. We describe the details of our design in
Section 5. For completeness, we first summarize the composition of a transformer, followed by the ideas of the long–short transformer that we build upon in our work.
A recent architecture, termed the bi-directional transformer architecture (BiXT) [
24], uses cross attention between different segments of the long sequence, in both directions. This enables a more comprehensive understanding of the entire sequence to capture long-range dependencies. However, this cross-attention mechanism is complex, which adds to the computational overhead and still contributes to some information loss at the boundaries between segments.
LongVQ [
25] presents the potential to improve computational efficiency and long-range dependency handling. It uses vector quantization (VQ) that inherently results in some information loss during the quantization process. LongVQ also uses segmentation and therefore results in fragmentation that could affect performance on long-range tasks requiring fine-grained details.
In another recent development (attention tensorization [
26]), the key innovation is that the input is a higher-dimensional tensor transformation of a simple sequence, enabling the capture of more complex relationships and dependencies among the sequences. However, working with high-dimensional tensors is computationally expensive, and the model’s increased complexity could lead to overfitting.
Capturing long context using innovative attention methods is a very active area of research in language modeling. In the latest research, Deepseek [
27] demonstrates a remarkable similarity to our work: they combine three attention mechanisms to enhance long-range performance and implement a ‘top-k’ technique to discard the less relevant segments. Moreover, our ‘top-k’ retrieval implementation is superior to that of the Deepseek model, as we retrieve the most similar segments in uncompressed form.
3. Canonical Transformer
In normal multi-headed attention, if $Q_i$, $K_i$, and $V_i$ are the query, key, and value transformations of the input embeddings with a sequence length of $n$ and embedding dimension of $d$, then the scaled dot-product attention in the $i$-th head $H_i$ is given as follows:

$$H_i = \mathrm{softmax}\!\left(\frac{Q_i K_i^{T}}{\sqrt{d_k}}\right) V_i$$

where $d_k = d/h$ is the dimension of each head for $h$ heads. The output of each transformer layer is obtained by concatenating the outputs of all heads and transforming the result further via a projection matrix $W^{O}$:

$$\mathrm{MultiHead}(Q, K, V) = \mathrm{Concat}\left(H_1, \ldots, H_h\right) W^{O}$$

After feeding the embeddings of a sequence of one-hot encoded words $x$ (with position encoding $PE$ added) through $p$ transformer layers, a classification layer at the output of the last layer decides the output produced by the transformer. For autoregressive text generation, the size of the classification layer's final output is equal to the size of the dictionary of unique words in the corpus.
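For illustration, the following is a minimal, self-contained PyTorch sketch of causal multi-head scaled dot-product attention (a generic textbook implementation with illustrative names and shapes, not the code used in this work):

```python
import torch
import torch.nn.functional as F

def multi_head_attention(x, Wq, Wk, Wv, Wo, num_heads):
    """x: (n, d). Wq/Wk/Wv/Wo: (d, d). Returns (n, d)."""
    n, d = x.shape
    dk = d // num_heads
    # Project and split into heads: (num_heads, n, dk)
    q = (x @ Wq).view(n, num_heads, dk).transpose(0, 1)
    k = (x @ Wk).view(n, num_heads, dk).transpose(0, 1)
    v = (x @ Wv).view(n, num_heads, dk).transpose(0, 1)
    # Scaled dot-product attention per head, with a causal (triangular) mask
    scores = q @ k.transpose(-2, -1) / dk ** 0.5            # (num_heads, n, n)
    mask = torch.triu(torch.ones(n, n, dtype=torch.bool), diagonal=1)
    scores = scores.masked_fill(mask, float("-inf"))
    heads = F.softmax(scores, dim=-1) @ v                   # (num_heads, n, dk)
    # Concatenate heads and apply the output projection
    return heads.transpose(0, 1).reshape(n, d) @ Wo

n, d, h = 8, 16, 4
x = torch.randn(n, d)
params = [torch.randn(d, d) / d ** 0.5 for _ in range(4)]   # Wq, Wk, Wv, Wo
print(multi_head_attention(x, *params, num_heads=h).shape)  # torch.Size([8, 16])
```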
4. Long–Short Transformer
Transformer LS [
15] aggregated the local attention around a smaller window (sliding window), with a projection of the full sequence attention to a smaller size, so that we can efficiently handle long sequences without the quadratic attention complexity. For short attention, the approach here is to use a segment-level sliding window attention, where the input sequence is divided into disjoint segments with length
w (e.g.,
w = 128, and sequence length is 1024). For non-autoregressive applications, all tokens within a segment attend to all tokens within its home segment, as well as w/2 consecutive tokens on the left and right side of its home segment (zero-padding when necessary), resulting in an attention span over a total of 2
w key-value pairs. This is depicted in
Figure 1.
For each query $Q_t$ at position $t$ within the $i$-th head, the $2w$ key-value pairs within its window are $\tilde{K}_t, \tilde{V}_t \in \mathbb{R}^{2w \times d_k}$. The short attention is then given by the following equation:

$$A^{short}_t = \mathrm{softmax}\!\left(\frac{Q_t \tilde{K}_t^{T}}{\sqrt{d_k}}\right)$$
Execution-wise, the segment-level sliding window attention (referred to as short attention) is more time efficient than the per-token sliding window attention where each token attends to itself and w tokens to its left and right, and its memory consumption scales linearly with sequence length. For auto-regressive applications, the future tokens in the current segment are masked, and only the previous segment is used.
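A minimal sketch of the autoregressive segment-level sliding window attention described above (our own illustration with assumed function and variable names, not the Transformer LS implementation) might look as follows:

```python
import torch
import torch.nn.functional as F

def segment_sliding_attention(Q, K, V, w):
    """Autoregressive segment-level sliding window attention (illustrative sketch).
    Q, K, V: (n, d) with n divisible by w. Each w-length segment attends to its
    home segment (causally masked) and to the previous segment (2w keys total)."""
    n, d = Q.shape
    num_seg = n // w
    # Zero-pad one segment on the left so segment 0 also sees 2w keys
    # (the padded keys are zeros, which is harmless for this sketch).
    K_pad = F.pad(K, (0, 0, w, 0))            # (n + w, d)
    V_pad = F.pad(V, (0, 0, w, 0))
    outputs = []
    for s in range(num_seg):
        q = Q[s * w:(s + 1) * w]              # (w, d) queries of the home segment
        k = K_pad[s * w:s * w + 2 * w]        # (2w, d): previous + home segment keys
        v = V_pad[s * w:s * w + 2 * w]
        scores = q @ k.T / d ** 0.5           # (w, 2w)
        # Causal mask: query t may attend to all previous-segment keys and to
        # home-segment keys up to position t.
        mask = torch.triu(torch.ones(w, w, dtype=torch.bool), diagonal=1)
        scores[:, w:] = scores[:, w:].masked_fill(mask, float("-inf"))
        outputs.append(F.softmax(scores, dim=-1) @ v)
    return torch.cat(outputs, dim=0)          # (n, d)

Q = K = V = torch.randn(256, 32)
print(segment_sliding_attention(Q, K, V, w=128).shape)   # torch.Size([256, 32])
```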
The compression is performed on the feature dimension initially through a projection with dimensions $d \to r$, where $d$ is the embedding dimensionality and $r$ is the target length. The dynamic projection matrix $P$ is computed by multiplying the learned weights $W^{P} \in \mathbb{R}^{d \times r}$ with the $n$-length key $K \in \mathbb{R}^{n \times d}$, i.e., $P = K W^{P}$. This product results in $P$ with dimensions $n \times r$. The transpose of this projection matrix, $P^{T} \in \mathbb{R}^{r \times n}$, is applied to the key $K \in \mathbb{R}^{n \times d}$. This product results in a modified key $\bar{K} = P^{T} K$ with dimensionality $r \times d$, thereby compressing its sequence length. Similar compression is performed for the value vector as well. This is a standard dimensionality-reduction technique illustrated in
Figure 2 and is used in popular models like Performer and Transformer LS.
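The compression step can be sketched as follows (illustrative shapes and names only; the actual Transformer LS implementation applies the dynamic projection per head and segment-wise):

```python
import torch

n, d, r = 1024, 64, 256                    # sequence length, embedding dim, target length (illustrative)
K = torch.randn(n, d)
V = torch.randn(n, d)

W_p = torch.randn(d, r) / d ** 0.5         # learnable projection weights (d -> r), random here
P = K @ W_p                                # dynamic projection matrix, shape (n, r)

K_bar = P.T @ K                            # compressed keys, shape (r, d)
V_bar = P.T @ V                            # compressed values, shape (r, d)
print(P.shape, K_bar.shape, V_bar.shape)   # (1024, 256) (256, 64) (256, 64)
```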
For long attention, the key and value transformations of the input sequence are first divided into segments of fixed size $s$ and then projected to a smaller total length $r$, where the projection compresses each segment of size $s$ to size $s \cdot r / n$.
Mathematically, the long attention $A^{long}_i$ (in each head $i$), as used by the long–short transformer, can be described as follows:

$$A^{long}_i = \mathrm{softmax}\!\left(\frac{Q_i \bar{K}_i^{T}}{\sqrt{d_k}}\right)$$

The output of the long attention in the $i$-th head is the following:

$$H^{long}_i = A^{long}_i \bar{V}_i$$

where $\bar{K}_i = P_i^{T} K_i$ and $\bar{V}_i = P_i^{T} V_i$ are the segment-wise compressed keys and values.
Note that the long attention is effectively performed on a compressed form of $K$ and $V$, as the projection causes the input sequence of size $n$ to be compressed to size $r$. This results in the full attention being replaced with the implicit product of two low-rank matrices (of sizes $n \times r$ and $r \times n$), and thus the computational complexity of long attention is reduced from $O(n^2)$ to $O(nr)$.
The long–short transformer [
15] integrates the short and long attentions into a single attention. While the short attention can attend to the most recent input, the long attention is in compressed form. Further, the long attention is based on segmentation of the input sequence, which may suffer from segment fragmentation, as the information in each segment is compressed via the projection mechanism.
5. CacheFormer: Enhanced Long Attention Transformer
The long-term attention in the existing long–short transformer is performed at a compressed level (projection to r causes an effective compression of the input context). Therefore, one of our enhancements is to augment the long attention with an attention that is based on a subset of highly attentive uncompressed segments.
5.1. CacheFormer Long Attention with Segment Caching
The subset of segments that are selected for attention at the uncompressed level is completely dynamic and is obtained from the vector magnitudes of the compressed segment-wise attention. In simple words, we examine the segment-wise long attention as given by Equation (6). Since the long attention has dimensions $n \times r$, and if there are $n/s$ segments, then each row in it contains a set of row vectors of size $1 \times (r\,s/n)$, one per segment, as denoted by the segmented attention in Equation (8). Here, rows represent the original queries, and columns represent the target compressed-length keys. The resulting attention learns the similarity between the two.
The magnitude of each such vector in Equation (8) indicates the attention of a given word to the corresponding compressed segment in the long attention. This phenomenon is also explained in
Appendix A.1.
For execution efficiency, we average the segment attention vectors in $p$ consecutive rows, starting from the first $p$ rows, then the next $p$ rows, and so on until the last row is reached. This results in a segment attention matrix with $m$ rows, where $m = n/p$. Then we choose the top-$k$ segments by the magnitude of each vector in each row of the segment attention matrix. Note that each entry in the segment attention matrix indicates the segment number that has high attention to the corresponding sequence of $p$ words (positions $i \cdot p$ to $(i+1)\cdot p - 1$) in the input context, as shown in Equation (9). Rather than using these attentive segments in compressed form, we extract them from the segmented $K$ and $V$ matrices before performing any compression on them. The example in
Appendix A.2 can be accessed for further explanation.
As in cache memory design (in computer architecture), in the case of a cache miss, we not only retrieve the needed data from the RAM but also bring in a few consecutive following words, as there is a high probability that these may be needed in the near future. Similarly, for the segments that we determine to be most attentive (by the top-$k$ order), we also retrieve $u$ consecutive segments (the attentive segment together with its $u-1$ neighbors).
To clarify our approach, if the sequence length is $n = 1024$ and the long attention segment size is 16, then there will be 64 segments in the uncompressed $K$ and $V$ matrices. If the projection size is $r = 256$ (a compression ratio of 1024/256 = 4), then each segment of size 16 will be compressed to a size of 4, resulting in a long attention matrix of size 1024 × (64 × 4), i.e., 1024 × 256. If we choose to average $p = 32$ consecutive rows and take the magnitude of each of the 1 × 4 vectors in each row (corresponding to the 64 segments), then the segment attention matrix will be 32 × 64. Taking the indices of the top-$k$ entries in each row gives us the indices of the $k$ segments most attentive to the corresponding set of 32 words in the input sequence. If $k = 5$ is chosen in top-$k$ and $u = 3$, indicating that one segment before and one segment after each attentive segment are also used, then assembling these top-$k$ attentive segments and their neighbors results in 15 segments per row. Thus, the cache $K$ and $V$ matrices (e.g., 32 × (15 × 16) = 32 × 240 in this case) contain the 15 most-attentive segments in uncompressed form. Note that we stack each of these rows $p$ times to match the dimensionality of $Q$. From the most attentive $k \cdot u$ segments, we can obtain the cache attention by applying the scaled dot-product attention between $Q$ and the cached (uncompressed) $K$ and $V$ matrices.
Further pictorial representation is available in
Appendix A.3.
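The segment-selection step of our cache attention can be sketched as follows (an illustrative outline with assumed names; the causality constraint and neighbor retrieval discussed later are omitted for brevity):

```python
import torch

def select_cache_segments(A_long, num_segments, p, k):
    """A_long: (n, r) compressed long attention for one head.
    Returns (n // p, k) indices of the most attentive segments per block of p rows."""
    n, r = A_long.shape
    seg_dim = r // num_segments                       # compressed width per segment (e.g., 4)
    # Magnitude of each 1 x seg_dim attention vector -> (n, num_segments)
    seg_mag = A_long.view(n, num_segments, seg_dim).norm(dim=-1)
    # Average over p consecutive rows -> segment attention matrix (n // p, num_segments)
    seg_attn = seg_mag.view(n // p, p, num_segments).mean(dim=1)
    # Indices of the top-k most attentive segments for each block of p words
    return seg_attn.topk(k, dim=-1).indices

A_long = torch.rand(1024, 256)                        # n = 1024, r = 256 (64 segments of width 4)
top_segments = select_cache_segments(A_long, num_segments=64, p=32, k=5)
print(top_segments.shape)                             # torch.Size([32, 5])
# The uncompressed K and V rows of these segments (plus their u - 1 neighbors)
# would then be gathered to form the cache keys and values.
```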
5.2. CacheFormer Long Attention with Overlapping Segments
In addition to the original long attention in the long–short transformer that uses the projections on each segment, we augment the existing long attention by using overlapping segments (with 50% overlap in augmented long attention), as shown in
Figure 3. The motivation behind the overlap is to provide context continuity and reduce the effect of segment fragmentation that occurs in long attention. Zero-padding at the beginning is added to ensure the same dimensionality for the overlapped long-segment attention. Here, all the keys are transformed through a projection applied to each overlapping segment, where each overlapping segment consists of the second half of the previous segment and the first half of the following segment. The overlapped long-segment attention, computed similarly to Equation (5) on these overlapping segments, is further explained in Appendix A.4.
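A small sketch of constructing the 50%-overlapping segments (zero-padded at the beginning) before applying the per-segment projection; the function name and shapes are illustrative assumptions:

```python
import torch
import torch.nn.functional as F

def overlapping_segments(K, s):
    """K: (n, d) with n divisible by s. Returns (n // s, s, d) segments shifted by
    s // 2, i.e., each consists of the second half of the previous segment and the
    first half of the following segment (zero-padded at the beginning)."""
    n, d = K.shape
    K_shift = F.pad(K, (0, 0, s // 2, 0))[:n]       # prepend s/2 zero rows, drop the tail
    return K_shift.view(n // s, s, d)

K = torch.randn(1024, 64)
segs = overlapping_segments(K, s=16)
print(segs.shape)                                   # torch.Size([64, 16, 64])
```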
5.3. Aggregated Attention in CacheFormer
The final attention in our enhanced architecture is obtained by aggregating the four attentions discussed earlier:
- (1)
The segment-based compressed long attention, as proposed in Transformer LS;
- (2)
The short attention, which uses the segment-wise sliding window of Transformer LS;
- (3)
Our cache attention, based on dynamic retrieval of uncompressed high-attention segments;
- (4)
Our overlapping segment-based compressed attention.
We add the two similar-sized long and overlapping attentions, while the different-sized short and cache attentions are concatenated with this sum. Thus, the final enhanced attention is expressed as follows:

$$A = \big[\, A^{short} \;\big\Vert\; \big(A^{long} + A^{overlap}\big) \;\big\Vert\; A^{cache} \,\big] \in \mathbb{R}^{\,n \times (2w + r + k\,u\,s)}$$

where:
w is the window size in the short, i.e., sliding window, attention;
r is the compressed projection target size in the long attention;
k is the number of top attentive segments retrieved;
u − 1 is the number of neighboring segments to retrieve for cache attention;
s is the segment size in long attention.
For example, with k = 5 and u = 3; a segment size in short attention of w = 128; a segment size in long attention of s = 16; and a compression target length of r = 256, the size of our combined attention matrix for an input sequence length of 2048 is 2048 × 752. The time complexity of the different components in CacheFormer's attention is as follows:
For the short attention → $O(n \cdot 2w)$, where $w$ is the sliding window size;
For both the long and overlapping long attentions → $O(n \cdot r)$, where $r$ is the compressed output size obtained from the long segments;
For the cached attention → $O(n \cdot k\,u\,s)$, where $k$ is the number of top attentive segments and $s$ is the long attention segment size.
Since the dominant term in the above four components is the long attention, the overall time complexity of our enhanced attention is $O(n \cdot r)$. Effectively, this is very close to the sliding window attention. To further elaborate on our attention computation in Equation (14), note that the dimensionality of the short sliding attention in the LS Transformer is $n \times 2w$ and the dimensionality of its compressed long attention is $n \times r$. During our caching mechanism, we augment the long–short attention with the overlapping attention and the cache attention, with dimensionalities $n \times r$ and $n \times k\,u\,s$, respectively. Since the long attention and the overlapping attention deal with sequence lengths compressed to similar dimensions, they have similar shapes. Therefore, we can sum up these two attention matrices along the similar dimensions to conserve size and overall attention complexity, whereas our caching attention has a different shape; hence, it cannot be summed up, and concatenation is the only choice. One can refer to
Figure A5 and
Figure A6 in
Appendix A.5 for pictorial representation.
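As a quick sanity check of the attention widths discussed above (our own arithmetic, not the authors' code), the combined attention width 2w + r + kus reproduces both worked examples in the paper:

```python
def combined_attention_width(w, r, k, u, s):
    """Columns of the aggregated attention: short (2w) + summed long/overlap (r) + cache (k*u*s)."""
    return 2 * w + r + k * u * s

print(combined_attention_width(w=128, r=256, k=5, u=3, s=16))   # 752 (matches the 2048 x 752 example)
print(combined_attention_width(w=128, r=256, k=7, u=1, s=16))   # 624 (matches the 1024 x 624 example)
```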
6. Results
Perplexity is a key metric in natural language processing (NLP) that measures how well a model predicts text. It is calculated as the exponential of the average negative log-likelihood per token. Lower perplexity indicates better predictions. Instead of focusing on the absolute best results for perplexity and BPC, which are often achieved through extremely refined training schedules and large model sizes, we focus on the improvements over the baseline, i.e., Transformer LS. Therefore, the results we show are a more accurate reflection of the architectural improvements of our design. The baseline architecture is also programmed by us, and the enhancements we propose are implemented in the same codebase and can be selectively turned on or off to see the contribution of each enhancement. We also use similar training schedules for the different architectures being compared.
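For reference, the standard definition of perplexity over a token sequence $x_1, \ldots, x_N$ with model probabilities $p_\theta$ is

$$\mathrm{PPL} = \exp\!\left(-\frac{1}{N}\sum_{i=1}^{N}\log p_\theta\!\left(x_i \mid x_{<i}\right)\right)$$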
Table 1 shows the perplexity results on the WikiText-103 dataset. It uses a sequence length of 1024, a short attention segment size of 128, a long attention segment size of 16, compression of the long sequence by a factor of 4, i.e., r = 256, different values of k in the top-k cache attention, and a neighboring segments' retrieval u of 1 or 3, where u = 3 indicates that the segment before and the segment after the attentive segment are also retrieved.
Table 1.
CacheFormer outperforms Transformer LS across all configurations.
Model | Model Size | Perplexity |
---|---|---|
Long-Short Baseline | 122.52 million | 23.74 |
CacheFormer (k = 3, u = 1) | 122.52 million | 23.31 |
CacheFormer (k = 5, u = 1) | 122.52 million | 22.75 |
CacheFormer (k = 7, u = 1) | 122.52 million | 21.32 |
CacheFormer (k = 5, u = 3) | 122.52 million | 21.26 |
Note that our enhanced architecture does not cause any increase in the number of model parameters over the baseline long–short transformer. The models used for results in
Table 1 have 12 layers, 12 heads, and an embedding size of 768 (for all architectural variations). For a sequence length of 1024 (which is the same as used in GPT-2), using seven segments (k = 7, u = 1) yielded a considerable improvement in perplexity. Increasing k beyond 7 did not seem to reduce perplexity considerably further.
Since we have two major enhancements of cache attention and overlapping segment-based attention over the baseline,
Table 3 shows an ablation study of the effects of each architectural improvement. An ablation study uses controlled experiments to systematically remove or modify components of a proposed model, thereby isolating and quantifying their individual contributions to overall performance.
Figure 4 depicts the attention vectors (from the compressed long attention after averaging
p = 256 rows) corresponding to the 64 segments at the beginning of training. The highest top-
k magnitude vectors then determine the segment to use in uncompressed form for our cache attention. The darker red color depicts higher magnitude attention vectors, while the blue color indicates lower magnitude vectors.
Table 4 shows the BPC results on the enwik-8 dataset, a benchmark utilized for character-level language modeling, where the models predict the next character in a sequence rather than the next word. Bits per character (BPC) quantifies a language model’s predictive accuracy by measuring the average number of bits needed to encode each character in a text sequence, with lower values indicating more precise predictions. In
Table 4, the 23 million parameter model uses eight layers, eight heads, and an embedding size of 512. The 34.88 million parameter model uses 12 layers.
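For reference, the standard definition of bits per character over a character sequence $c_1, \ldots, c_N$ with model probabilities $p_\theta$ is

$$\mathrm{BPC} = -\frac{1}{N}\sum_{i=1}^{N}\log_2 p_\theta\!\left(c_i \mid c_{<i}\right)$$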
It is interesting to note that the relative improvement in BPC from our enhanced architecture is less pronounced than the perplexity improvements. This could be attributed to the fact that the majority of the improvements come from cache attention, which uses a few highly attentive uncompressed segments in long attention. This benefits perplexity, which measures the model's prediction capability, more than BPC, which is more of a compression-efficiency measure of the model.
Table 2.
CacheFormer outperforms several modern language models with comparable model size on perplexity.
Architecture | Model Size (Millions) | Perplexity |
---|---|---|
Long–Short (Baseline) | 122.52 | 23.74 |
Transformer-XL (Standard) | 151 | 24 |
∞-former | 160 | 24.22 |
LaMemo | 151 | 23.77 |
H3 (Hungry Hungry Hippos) | 125 | 23.7 |
Llama | 125 | 23.16 |
Mamba | 125 | 22.49 |
xLSTM [7:1] | 125 | 21.47 |
CacheFormer with Overlapping Segments and Enhanced Caching (k = 7, u = 1) | 122.52 | 21.32 |
Our models for CacheFormer were implemented on a computer with an NVIDIA RTX 4090 GPU. The 122.52 million parameter models used an embedding dimension of 768 with 12 layers and 12 attention heads in each layer. These model parameters were chosen so that we can have a fair comparison with equivalent-size models in the reported literature. Increasing the embedding dimension and/or the number of layers results in a bigger model with better language modeling capabilities if trained on enough training data. The Adam optimizer was used with an initial learning rate of 1 × 10−4. We used a batch size of eight and trained our models for 500,000 iterations. Our GitHub repository provides all the necessary code to reproduce the results reported in the paper.
We use the perplexity metric on the popular WikiText-103 dataset that is designed to train and assess large language models in tasks that require capturing long-term dependencies. It consists of over 100 million tokens extracted from verified “good” and “featured” articles on Wikipedia.
Table 3.
Ablation study of CacheFormer’s architectural enhancements.
| | Long–Short baseline | CacheFormer with overlapping segments only | CacheFormer with cache attention only (k = 7, u = 1) | CacheFormer with overlapping segments and cache attention (k = 7, u = 1) |
|---|---|---|---|---|
| Model Size | 122.52 million | 122.52 million | 122.52 million | 122.52 million |
| Perplexity | 23.74 | 23.47 | 21.67 | 21.32 |
Table 4.
Comparison of BPC on the enwik-8 Benchmark.
Model | Model Size | BPC |
---|---|---|
Long-Short Baseline | 23 million | 1.192 |
CacheFormer (k = 7, u = 1) | 23 million | 1.188 |
Long-Short Baseline | 34.88 million | 1.173 |
CacheFormer (k = 7, u = 1) | 34.88 million | 1.167 |
7. Discussion
Since the uncompressed segments to be used in our cache attention design are dynamically decided based on the input sequence, the execution time increases as more segments (i.e., higher k) are used. When we use a sequence length of 1024, compression r = 256, k = 7, u = 1, and a short attention segment size of 128, the size of the aggregated attention (short, long, cache, overlapping) is 1024 × 624. Since our cache attention mechanism is completely dynamic and uses the most attentive segments in uncompressed form, we average the attention vectors over p rows (to improve execution efficiency), as given by Equation (9).
If we use a sequence length of 1024 and an average over 256 rows, then the segments determined by our cache attention mechanism part way through the training of the model appear as shown in
Table 5. Note that to implement the autoregressive behavior, the input sequence cannot attend to a future segment. Our implementation guarantees that the input sequence can only attend to a previous segment. For example, when attending to words 768–1023 in the input sequence, the maximum segment that the cache attention can use is 47 (if the long segment size is 16, then there are 64 segments in the 1024 size sequence).
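The causal constraint described above can be sketched with a small helper (our illustration; the name and the block-start convention are assumptions):

```python
def max_allowed_segment(block_start, segment_size=16):
    """Highest segment index the cache attention may use for the block of queries
    beginning at `block_start`, so that no current or future segment is attended to."""
    return block_start // segment_size - 1

# With p = 256, the query block covering words 768-1023 starts at 768 (segment 48),
# so the largest segment it may cache is 47, as described above.
print(max_allowed_segment(768))   # 47
```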
“Lost in the middle” [
28], one of the important recent papers in handling long contexts, has indicated that current language models do not robustly make use of information in long input contexts. They studied different models and concluded that “performance is often highest when relevant information occurs at the beginning or at the end of the input context and significantly degrades when models must access relevant information in the middle of long contexts”.
Note that our cache attention model addresses this aspect nicely in the sense that it uses attentive segments dynamically regardless of whether they are needed at the beginning or the middle of input context. For example, the last row in
Table 5 indicates the highest-attention segments that are used. Segments 32, 35, and 37 are relatively in the middle of the input context. When we determine the most attentive segments to use in our cache attention, if the neighboring segment parameter u > 1, then including the next or previous segment index may produce a duplicate, as that segment may already be one of the highly attentive segments. Similarly, if a selected segment would be a future segment, we replace it with one of the allowed segments. To avoid fragmenting information, the replacement segment we select is one that is contiguous to an existing high-attention segment.
8. Conclusions
Handling long contexts in an efficient manner without loss of performance is an important area of research in language models. Although many approaches have been recently proposed to address this problem, we present a new innovative solution that is motivated by the cache and virtual memory concepts in computer architecture. In such designs, if there is a cache or page miss, the needed data are retrieved from the disk or RAM. We handle long contexts by dividing them into small segments. From the magnitude of the compressed attention vectors, we determine the most attentive segments and then use these in uncompressed form.
Like the cache memory design, we also use consecutive segments near the high-attention segments to improve the language model’s predictive performance. Our results concerning perplexity indicate significant improvement over the baseline architecture that uses short and long compressed attention.
For BPC, the cache attention mechanism does not show a remarkable improvement over the baseline. We conjecture that BPC, which favors compression capability, does not benefit as much from the relevant-segment usage that our model provides, since this mainly helps the model's prediction capability. Another advantage of our approach is that the use of high-attention segments is dynamic and depends on the input sequence. Thus, if the model needs to use information in the middle or anywhere else in the input context, it is provided in uncompressed form via the high-attention determination of the compressed segments.
As demonstrated in
Table 2, CacheFormer outperforms several similar-sized SOTA language models by 8.5% on average. In our future work, we plan to work on enhancing the efficiency of our implementation that aggregates the four different attention mechanisms. This will enable us to train on larger models and work on diverse task-specific applications of our design.
9. Limitations
The only shortcoming of our approach, we feel, is that the dynamic segment attention is relatively slow during training. We partially overcome this by initially pretraining the model without dynamic attention and then fine-tuning it with our dynamic cached attention. Our future work involves applying the cache attention to reduce the model complexity of large language models. Further, we are in the process of creating a hierarchical cache design so that very long contexts can be efficiently handled.
Further, our model sizes and datasets were constrained by the computational resources available to us. We used an RTX 4090 GPU and therefore could not use larger datasets, such as PG-19, or run larger models with a larger embedding size, more layers, and more heads.