Article

Autocorrelation Matrix Knowledge Distillation: A Task-Specific Distillation Method for BERT Models

Department of Electronic Information Engineering, School of Physics and Electronic Information, Henan Polytechnic University, Wenyuan Street, Jiaozuo 454099, China
* Author to whom correspondence should be addressed.
These authors contributed equally to this work.
Appl. Sci. 2024, 14(20), 9180; https://doi.org/10.3390/app14209180
Submission received: 10 September 2024 / Revised: 5 October 2024 / Accepted: 7 October 2024 / Published: 10 October 2024

Abstract
Pre-trained language models perform well in various natural language processing tasks. However, their large number of parameters poses significant challenges for edge devices with limited resources, greatly limiting their application in practical deployment. This paper introduces a simple and efficient method called Autocorrelation Matrix Knowledge Distillation (AMKD), aimed at improving the performance of smaller BERT models for specific tasks and making them more applicable in practical deployment scenarios. The AMKD method effectively captures the relationships between features using the autocorrelation matrix, enabling the student model to learn not only the performance of individual features from the teacher model but also the correlations among these features. Additionally, it addresses the issue of dimensional mismatch between the hidden states of the student and teacher models. Even in cases where the dimensions are smaller, AMKD retains the essential features from the teacher model, thereby minimizing information loss. Experimental results demonstrate that BERTTINY-AMKD outperforms traditional distillation methods and baseline models, achieving an average score of 83.6% on GLUE tasks. This represents a 4.1% improvement over BERTTINY-KD and exceeds the performance of BERT4-PKD and DistilBERT4 by 2.6% and 3.9%, respectively. Moreover, despite having only 13.3% of the parameters of BERTBASE, the BERTTINY-AMKD model retains over 96.3% of the performance of the teacher model, BERTBASE.

1. Introduction

Pre-trained language models based on Transformer architecture have consistently achieved exceptional performance across a wide range of natural language processing (NLP) tasks in recent years. Examples of these models include GPT [1], BERT [2], ET-BERT [3], BFCN [4], and PAL-BERT [5]. Despite excellent performance, these models are limited in their applicability to edge devices due to their large parameter count and high computational demands [6,7]. Consequently, reducing computational costs and model size while maintaining original performance has become a vital area of research [8,9].
Pre-trained language models perform well across many tasks, but further training and fine-tuning are usually required to reach strong performance in specific domains or tasks, which often consumes significant computing resources and time. To address this, researchers have developed various model compression and acceleration techniques for efficient deployment on edge devices, including pruning [10], quantization [11], knowledge distillation (KD) [12], and low-rank decomposition [13]. Among these, KD significantly reduces computing time and resource consumption by transferring the knowledge of a complex teacher model to a simplified student model. Although the student model may perform slightly worse than the teacher model, it can still maintain sufficient accuracy in resource-limited environments [14]. The key to efficient deployment is balancing computing requirements and model performance [15,16].
Task-specific knowledge distillation has shown significant advantages in improving model performance [17,18]. Tang et al. successfully compressed the BERT model by distilling task-specific knowledge into a lightweight BiLSTM model [19]. BERT-PKD guides the training of student models by extracting knowledge from the intermediate layers of BERT [20]. In addition, Aguilar et al. enabled the student model to more comprehensively capture the abstract knowledge of the teacher model by matching the internal representations of BERT [21]. DynaBERT achieves dynamic compression of the BERT model by adaptively adjusting the model width and depth [22]. BERT-of-Theseus reduces model complexity by gradually replacing BERT layers [23]. TinyBERT employs a two-stage distillation method to significantly reduce model size and accelerate inference [24]. However, these methods have certain limitations. Previous methods primarily focused on learning the features of the teacher model while overlooking the interrelationships between these features, which are crucial in many complex natural language processing tasks [25]. In addition, when calculating the hidden states loss, the hidden states dimension of the student model must match that of the teacher model; otherwise, additional computation is required to align the two, which limits the flexibility of the student model structure.
The main research question of this paper is how to enable the student model to more effectively capture and learn the characteristics of the teacher model during knowledge distillation, while fundamentally resolving the hidden states dimension mismatch between the teacher and student models. The goal is to significantly reduce the number of model parameters while maintaining high performance. We propose the Autocorrelation Matrix Knowledge Distillation (AMKD) approach to address the limitations of traditional KD methods. Unlike existing techniques such as BERT-PKD, DistilBERT, or TinyBERT, AMKD excels at capturing complex feature interactions while directly resolving hidden states dimension mismatches without additional projection layers.
AMKD not only allows the student model to learn the features of the teacher model but also captures the complex relationships among these features. The relationships between features play a critical role in NLP tasks. By minimizing the differences between the autocorrelation matrices of the student and teacher models, the student model learns not only individual feature behaviors but also the relationships between features in the teacher model. This approach effectively captures high-order feature interactions, enabling the student model to better understand and learn the complex features of the teacher model. More importantly, AMKD effectively addresses the issue of hidden states dimension mismatch between student and teacher models without requiring additional projection layers for dimension alignment. AMKD retains the essential information from the teacher model, minimizing information loss during training. Compared to traditional methods, AMKD significantly enhances the performance of the student model, demonstrating superior flexibility and robustness.
We evaluated AMKD on multiple NLP tasks and demonstrated that it significantly enhances the performance of small BERT models in specific tasks. AMKD addresses the limitations of traditional knowledge distillation in learning complex feature relationships and overcoming hidden states dimension mismatches. The resulting student models exhibit greater flexibility and robustness, offering new approaches for model compression and acceleration in resource-constrained environments.

2. Preliminaries

This section begins with an introduction to the Transformer architecture and the core components of this framework [26], followed by a discussion of KD techniques [14]. Our proposed AMKD method is developed based on these foundational concepts.

2.1. Transformer Architecture

The Transformer architecture, introduced by Vaswani et al. in 2017 [26], is a widely adopted deep learning model, particularly in natural language processing and machine translation. Unlike traditional models such as RNN [27] and CNN [28], the Transformer captures long-range dependencies through self-attention. The model consists of an encoder, which converts input sequences into hidden representations, and a decoder, which generates outputs from these representations. The key components include multi-head attention, feed-forward networks [29], residual connections [30], and layer normalization [31].
Multi-Head Attention Mechanism. The multi-head attention mechanism is a fundamental component of the Transformer architecture, enabling the model to attend to multiple segments of the input sequence concurrently. The computation of attention involves several steps, starting with the calculation of scaled dot-product attention. Given a query matrix Q , a key matrix K , and a value matrix V , the attention scores are derived using the following equation:
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^T}{\sqrt{d_k}}\right)V,
where QK^T denotes the dot product between the query and key matrices, reflecting their similarity, and d_k represents the dimensionality of the key; the dot product is scaled by √d_k to prevent gradient instability. The softmax function is subsequently applied to transform these similarities into weights, which are then multiplied by the value matrix V to yield the weighted output.
In the multi-head attention mechanism, the query, key, and value matrices Q, K, and V are each projected into h different subspaces through separate linear projections. These subspaces allow the model to perform attention computations in parallel across multiple different representation spaces. Specifically, for each head i, the input matrices are transformed into the subspaces using different projection matrices W_i^Q, W_i^K, and W_i^V, and the attention computation is performed as follows:
\mathrm{head}_i = \mathrm{Attention}(QW_i^Q, KW_i^K, VW_i^V),
where W_i^Q ∈ R^{d_model × d_k}, W_i^K ∈ R^{d_model × d_k}, and W_i^V ∈ R^{d_model × d_v} are the linear transformation matrices that project the queries, keys, and values into different subspaces. The outputs of all the heads are concatenated and then processed through a linear transformation matrix W^O to produce the final output:
\mathrm{MultiHead}(Q, K, V) = \mathrm{Concat}(\mathrm{head}_1, \ldots, \mathrm{head}_h)\, W^O,
where W^O ∈ R^{h d_v × d_model} is the linear transformation matrix that maps the concatenated multi-head outputs back to the original dimensionality. This process allows the multi-head attention mechanism to capture different aspects of the input sequence from multiple subspaces, enhancing the model’s representational capability.
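For concreteness, the following is a minimal PyTorch sketch of the scaled dot-product and multi-head attention computations defined above; class and variable names are illustrative and are not taken from any particular library.

import torch
import torch.nn.functional as F

def scaled_dot_product_attention(Q, K, V):
    # Q, K: (batch, heads, seq_len, d_k); V: (batch, heads, seq_len, d_v)
    d_k = Q.size(-1)
    scores = Q @ K.transpose(-2, -1) / d_k ** 0.5   # similarity scaled by sqrt(d_k)
    weights = F.softmax(scores, dim=-1)             # attention weights over key positions
    return weights @ V                              # weighted sum of the values

class MultiHeadAttention(torch.nn.Module):
    def __init__(self, d_model, h):
        super().__init__()
        assert d_model % h == 0
        self.h, self.d_k = h, d_model // h
        self.W_q = torch.nn.Linear(d_model, d_model)  # packs the h projections W_i^Q
        self.W_k = torch.nn.Linear(d_model, d_model)  # packs the h projections W_i^K
        self.W_v = torch.nn.Linear(d_model, d_model)  # packs the h projections W_i^V
        self.W_o = torch.nn.Linear(d_model, d_model)  # output projection W^O

    def forward(self, Q, K, V):
        B, L, _ = Q.shape
        def split(x):  # (B, L, d_model) -> (B, h, L, d_k)
            return x.view(B, L, self.h, self.d_k).transpose(1, 2)
        heads = scaled_dot_product_attention(split(self.W_q(Q)), split(self.W_k(K)), split(self.W_v(V)))
        concat = heads.transpose(1, 2).contiguous().view(B, L, self.h * self.d_k)
        return self.W_o(concat)                      # Concat(head_1, ..., head_h) W^O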
Feed-Forward Networks. In each encoder and decoder layer of the Transformer, there is also an independent feed-forward network applied to every position in the sequence. This network consists of two linear transformations with a ReLU activation function in between:
\mathrm{FFN}(x) = \max(0,\; xW_1 + b_1)\, W_2 + b_2
Residual Connections and Layer Normalization. Following the feed-forward network, the output from each sub-layer undergoes processing through residual connections and layer normalization, as defined below:
\mathrm{LayerNorm}(x + \mathrm{Sublayer}(x))
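The following minimal PyTorch sketch ties the position-wise feed-forward network, the residual connection, and layer normalization together as one encoder sub-layer (names are illustrative):

import torch

class FeedForwardSublayer(torch.nn.Module):
    def __init__(self, d_model, d_ff):
        super().__init__()
        # FFN(x) = max(0, x W1 + b1) W2 + b2
        self.ffn = torch.nn.Sequential(
            torch.nn.Linear(d_model, d_ff),
            torch.nn.ReLU(),
            torch.nn.Linear(d_ff, d_model),
        )
        self.norm = torch.nn.LayerNorm(d_model)

    def forward(self, x):
        # residual connection followed by layer normalization: LayerNorm(x + Sublayer(x))
        return self.norm(x + self.ffn(x))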

2.2. Knowledge Distillation

KD is a technique that extracts knowledge from a complex deep model (teacher model) and transfers it to a smaller model (student model) [14]. By imitating the output distribution of the teacher model, the student model can reduce the complexity of the model and the demand for computing resources while maintaining high performance. This method is often used to pursue a balance between performance and efficiency when deploying models. The goal of KD is to make the output of the student model as close as possible to the output of the teacher model by optimizing a loss function. This loss function can be expressed as
L_{\mathrm{KD}} = \sum_{z \in Z} L\big(f^S(z),\, f^T(z)\big),
where L(·) is the loss function that measures the difference between the student model and the teacher model, f^S(z) and f^T(z) represent the output results of the student model and the teacher model, respectively, z is the input data, and Z is the training data set. By minimizing this loss function, the student model can effectively learn the knowledge of the teacher model.
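As a minimal illustration of this objective, the sketch below accumulates the distillation loss over a training set; MSE is used purely as a placeholder for the divergence L(·), and the function names are hypothetical.

import torch.nn.functional as F

def kd_loss(student, teacher, dataset):
    # Accumulate L(f_S(z), f_T(z)) over the training set Z; the teacher output is detached (frozen).
    total = 0.0
    for z in dataset:
        total = total + F.mse_loss(student(z), teacher(z).detach())
    return total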

3. Method

This section presents a novel distillation approach, AMKD. AMKD effectively improves the ability of the student model to capture and understand semantic relationships from the teacher model, resolving the dimensional mismatch issues encountered during the distillation process.

3.1. Overview of AMKD

The core idea of AMKD is to enable the small student model S to learn from the knowledge of the large teacher model T , thereby improving the performance of the student model on specific tasks under the guidance of the teacher model. Both the student and teacher models are composed of an embedding layer, several Transformer layers, and a prediction layer. Each Transformer layer consists of an attention layer and a hidden layer.
AMKD consists of four distillation components: prediction layer distillation, hidden states distillation, attention matrix distillation, and embedding layer distillation. Assuming that the teacher model contains M Transformer layers and the student model contains N Transformer layers, Figure 1 visually presents the overall distillation framework. H_i^S and H_j^T represent the hidden states of the i-th layer of the student model and the j-th layer of the teacher model, respectively, D′ and D are the hidden states dimensions of the student and teacher models, l represents the input sequence length, and T denotes the matrix transpose operation. h represents the number of attention heads, and A_i^S and A_j^T represent the attention matrices of the student model and the teacher model, respectively.
In contrast to previous approaches, we leverage the autocorrelation matrix of hidden states to capture complex feature relationships, rather than focusing solely on the individual features of the teacher model. The student model learns not only the behavior of each feature from the teacher model but also the interactions between features. Furthermore, AMKD effectively resolves the issue of hidden states dimensional mismatches between the teacher and student models, allowing for greater flexibility in the configuration of the student model while reducing computational costs. Next, we will provide a detailed description of the four distillation strategies in AMKD.

3.2. Prediction Layer Distillation

To enable the student model to mimic the output from the prediction layer of the teacher model, we employ the KD technique introduced by Hinton et al. [14]. In particular, we compute the Kullback–Leibler (KL) divergence between the logits of the student and teacher models to more closely align the outputs of the student model with those of the teacher. The calculation is presented as follows:
L_{\mathrm{soft}} = \sum_i \mathrm{softmax}\!\left(\frac{z_i^T}{t}\right) \log \frac{\mathrm{softmax}\!\left(z_i^T / t\right)}{\mathrm{softmax}\!\left(z_i^S / t\right)},
where z^T and z^S represent the logits output by the teacher and student models, respectively. Logits refer to the raw scores produced by a neural network in a classification task, representing un-normalized predictions before the application of the softmax function. The parameter t serves as a temperature value to smooth the probability distributions. The index i is used to denote the different categories in the classification task. Additionally, we consider the hard label loss for the student model, L_hard, which is calculated as the cross-entropy between the true labels and the predicted probabilities.
The total loss of the student model consists of two main components: the hard label loss and the soft label loss. The total loss is defined as follows:
L_{\mathrm{pred}} = (1 - \alpha) \cdot L_{\mathrm{hard}} + \alpha \cdot L_{\mathrm{soft}},
where α is a weighting factor, chosen from the range [0, 1], that balances the hard label loss and the soft label loss.
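A minimal PyTorch sketch of this prediction-layer loss is given below; the temperature t and weight α follow the equations above, and the function name is illustrative.

import torch.nn.functional as F

def prediction_layer_loss(student_logits, teacher_logits, labels, t=4.0, alpha=0.8):
    # Soft loss: KL divergence between temperature-smoothed teacher and student distributions.
    log_p_student = F.log_softmax(student_logits / t, dim=-1)
    p_teacher = F.softmax(teacher_logits / t, dim=-1)
    l_soft = F.kl_div(log_p_student, p_teacher, reduction="batchmean")
    # Hard loss: cross-entropy against the ground-truth labels.
    l_hard = F.cross_entropy(student_logits, labels)
    return (1 - alpha) * l_hard + alpha * l_soft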

3.3. Attention Matrix Distillation

Attention matrix distillation aims to transfer the attention weight information from the Transformer structure in the teacher model to the student model. These weights capture rich linguistic information, which is crucial for natural language understanding [32]. By applying attention matrix distillation, the student model can better inherit the language comprehension capabilities of the teacher model. The loss function for attention matrix distillation is defined as follows:
L_{\mathrm{atten}} = \frac{1}{h} \sum_{i=1}^{N} \sum_{k=1}^{h} \mathrm{MSE}\big(A_{ik}^S, A_{jk}^T\big),
where N represents the number of Transformer layers in the student model, h indicates the number of attention heads, A_{ik}^S ∈ R^{l×l} represents the attention matrix for the k-th head in the i-th layer of the student model, l signifies the length of the input, and MSE(·) is the mean squared error loss. The i-th layer of the student model learns from the corresponding j-th layer of the teacher model.
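The sketch below illustrates this loss in PyTorch, assuming per-layer attention tensors of shape (batch, heads, l, l) and a layer_map that gives, for each student layer i, the index j of the teacher layer it learns from; the names are illustrative.

import torch.nn.functional as F

def attention_distillation_loss(student_attentions, teacher_attentions, layer_map):
    # student_attentions[i], teacher_attentions[j]: (batch, heads, l, l)
    loss = 0.0
    for i, j in enumerate(layer_map):
        # F.mse_loss averages over all elements, which includes the 1/h average over heads.
        loss = loss + F.mse_loss(student_attentions[i], teacher_attentions[j])
    return loss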

3.4. Hidden States Distillation

In the Transformer model, the hidden states are an important part of the intermediate representation of each layer, capturing the semantic information and features of the input sequence. In traditional knowledge distillation methods, the student model usually needs a projection layer to resolve the mismatch between its hidden states dimension and that of the teacher model. The hidden states matrix H_i^S ∈ R^{l×D′} of the student model and the hidden states matrix H_j^T ∈ R^{l×D} of the teacher model cannot be compared directly due to their different dimensions. Here, l represents the length of the input sequence, D and D′ represent the hidden states dimensions of the teacher model and the student model, respectively, and D′ is usually smaller than D. To solve this problem, traditional methods introduce a projection layer P ∈ R^{D′×D} to project the hidden states of the student model to a dimension that matches the teacher model:
H_i^S P \in \mathbb{R}^{l \times D}
However, this method has two major limitations: first, the projection operation adds additional computational complexity, especially in large-scale model scenarios, where the computational overhead increases significantly; second, the dimension conversion process may lead to information loss, meaning that the student model is unable to fully capture the characteristic performance of the teacher model.
To overcome these problems, AMKD takes a different approach: by calculating the autocorrelation matrix of the hidden states, explicit dimensional projection is avoided. Specifically, AMKD calculates the autocorrelation matrix of the hidden states matrix of each layer, converting it from the (length, hidden_size) shape to a unified (length, length) shape; that is,
C_i^S = H_i^S (H_i^S)^T \in \mathbb{R}^{l \times l}, \quad C_j^T = H_j^T (H_j^T)^T \in \mathbb{R}^{l \times l},
where C represents the autocorrelation matrix. By calculating the autocorrelation matrix, AMKD effectively solves the dimension mismatch problem. No matter how different the hidden states dimensions of the student model and the teacher model are, after autocorrelation matrix conversion, they can be compared on a unified l × l dimension.
The autocorrelation matrix captures not only the behavior of individual features but also the high-order relationships between features. The formula is as follows:
C_{uv}^T = \sum_{k=1}^{D} H_{uk}^T H_{vk}^T, \quad C_{uv}^S = \sum_{k=1}^{D'} H_{uk}^S H_{vk}^S,
where C_{uv} represents the inner product between the u-th row and the v-th row of the hidden states matrix H. As shown in Figure 2, the left figure shows the autocorrelation matrix of the teacher model, and the right figure shows the autocorrelation matrix of the student model. By comparing these two matrices, we can see how the student model learns the complex feature relationships in the teacher model through AMKD. The color depth represents the correlation between features: the darker the color, the stronger the correlation.
In AMKD, we use mean squared error (MSE) to measure the difference between the autocorrelation matrices of the student model and the teacher model, and the hidden states distillation loss is defined as
L_{\mathrm{hidden}} = \sum_{i=1}^{N} \mathrm{MSE}\big(C_i^S, C_j^T\big).
By minimizing the MSE of the autocorrelation matrix, AMKD ensures that the student model can not only learn the individual features of the teacher model, but also capture the complex relationships between the features. This significantly improves the performance of the student model in complex tasks.
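The following PyTorch sketch illustrates the autocorrelation-based hidden states loss: each layer’s hidden states are first mapped to an l × l autocorrelation matrix, so the MSE can be computed even when D′ ≠ D and no projection layer is needed. Names are illustrative; layer_map again maps each student layer i to its teacher layer j.

import torch.nn.functional as F

def autocorrelation(hidden):
    # hidden: (batch, l, hidden_size) -> (batch, l, l), i.e. H @ H^T per example
    return hidden @ hidden.transpose(-2, -1)

def hidden_states_distillation_loss(student_hiddens, teacher_hiddens, layer_map):
    loss = 0.0
    for i, j in enumerate(layer_map):
        c_s = autocorrelation(student_hiddens[i])   # C_i^S, shape (batch, l, l)
        c_t = autocorrelation(teacher_hiddens[j])   # C_j^T, shape (batch, l, l)
        loss = loss + F.mse_loss(c_s, c_t)
    return loss

The same inner-product construction is reused for the embedding layer loss described in Section 3.5.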

3.5. Embedding Layer Distillation

We also performed embedding layer distillation to enable the student model to learn the semantic information from the embedding layer of the teacher model. This process is similar to computing the hidden states loss: by calculating the inner product of the embedding layer outputs from both the student model and the teacher model, the embedding layer distillation loss can be derived. The formula is as follows:
L_{\mathrm{embed}} = \mathrm{MSE}\big(E^S (E^S)^T,\; E^T (E^T)^T\big),
where E^S and E^T represent the embedding layer outputs of the student model and the teacher model, respectively, and T represents the matrix transpose operation. This approach enables the student model to effectively learn the key features of the embedding layer in the teacher model while also capturing the complex relationships between embedding layer features.

3.6. Overall Loss Function

We combined prediction layer distillation, attention matrix distillation, hidden states distillation, and embedding layer distillation. The total distillation loss of AMKD can be expressed as follows:
L_{\mathrm{AMKD}} = L_{\mathrm{pred}} + L_{\mathrm{atten}} + L_{\mathrm{hidden}} + L_{\mathrm{embed}} = (1 - \alpha) \cdot L_{\mathrm{hard}} + \alpha \cdot L_{\mathrm{soft}} + \big(L_{\mathrm{atten}} + L_{\mathrm{hidden}} + L_{\mathrm{embed}}\big),
where α is a weighting factor, chosen from the range [0, 1], that balances the hard label loss and the soft label loss.
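Combining the four terms, a sketch of the total loss might look as follows. It reuses the loss functions sketched in the previous subsections and assumes HuggingFace-style model outputs (logits, attentions, and hidden_states, with the embedding output at index 0); this is an assumption about the surrounding training code, not the authors’ implementation.

def amkd_loss(student_out, teacher_out, labels, layer_map, t=4.0, alpha=0.8):
    l_pred = prediction_layer_loss(student_out.logits, teacher_out.logits, labels, t, alpha)
    l_atten = attention_distillation_loss(student_out.attentions, teacher_out.attentions, layer_map)
    l_hidden = hidden_states_distillation_loss(student_out.hidden_states[1:],
                                               teacher_out.hidden_states[1:], layer_map)
    # Embedding layer distillation: the same autocorrelation MSE applied to hidden_states[0].
    l_embed = hidden_states_distillation_loss([student_out.hidden_states[0]],
                                              [teacher_out.hidden_states[0]], [0])
    return l_pred + l_atten + l_hidden + l_embed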

3.7. Data Augmentation

To improve the generalization and robustness of the student model, we adopted a simple yet effective data augmentation technique. The specific steps are detailed in Algorithm 1. By randomly shuffling the positions of tokens in the input sequence, we increased data diversity and reduced the dependence of the model on specific sequences, thereby enhancing the adaptability and stability of the student model. This method is easy to implement and has a low computational overhead.
Algorithm 1 Short Disorder Data Augmentation
Require: A list of tokens tokens, probability distribution p of 5 disorder types
Ensure: A new disordered list of tokens
 1: outputs ← tokens[:]
 2: l ← len(tokens)
 3: for i = 0 to l − 2 do
 4:     perm ← randomly choose from {0, 1, 2, 3, 4} based on p
 5:     if perm == 1 then
 6:         Swap outputs[i] and outputs[i+1]
 7:         i ← i + 1
 8:     else if perm == 2 and i < l − 2 then
 9:         Swap outputs[i] and outputs[i+2]
10:         i ← i + 2
11:     else if perm == 3 and i < l − 2 then
12:         Rotate outputs[i], outputs[i+1], outputs[i+2] to the left
13:         i ← i + 2
14:     else if perm == 4 and i < l − 2 then
15:         Rotate outputs[i], outputs[i+2], outputs[i+1] to the right
16:         i ← i + 2
17:     end if
18: end for
19: return outputs
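A direct Python rendering of Algorithm 1 is sketched below; the probability vector p over the five disorder types is a hyperparameter, and the right-rotation step is one possible reading of line 15 of the pseudocode.

import random

def short_disorder(tokens, p):
    outputs = tokens[:]
    l = len(tokens)
    i = 0
    while i <= l - 2:  # mirrors "for i = 0 to l - 2" with the in-loop skips made explicit
        perm = random.choices([0, 1, 2, 3, 4], weights=p)[0]
        if perm == 1:
            outputs[i], outputs[i + 1] = outputs[i + 1], outputs[i]
            i += 1
        elif perm == 2 and i < l - 2:
            outputs[i], outputs[i + 2] = outputs[i + 2], outputs[i]
            i += 2
        elif perm == 3 and i < l - 2:
            # rotate the three positions to the left
            outputs[i], outputs[i + 1], outputs[i + 2] = outputs[i + 1], outputs[i + 2], outputs[i]
            i += 2
        elif perm == 4 and i < l - 2:
            # interpreted here as a right rotation of the three positions (an assumption)
            outputs[i], outputs[i + 1], outputs[i + 2] = outputs[i + 2], outputs[i], outputs[i + 1]
            i += 2
        i += 1
    return outputs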

3.8. Skip-Layer Distillation

In the skip-layer distillation method, the student model acquires knowledge by selectively learning from different depth levels of the teacher model. Assuming that the teacher model has M layers and the student model has N layers, we first compute the average skip-layer interval by dividing M by N. When M is not evenly divisible by N, the skip-layer interval may be a non-integer. To address this, we adopt a rounding-down strategy, taking the integer part of the result as the skip-layer interval. For the remaining layers, additional learning is performed based on their importance to ensure that key information is not ignored. This skip-layer selection strategy acquires knowledge from different levels of the teacher model more comprehensively, thereby effectively improving the overall performance of the student model.
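A small sketch of one possible skip-layer mapping is given below; mapping student layer i to teacher layer (i + 1) · k is our assumption of a natural choice, and the sketch omits the importance-based handling of the remaining layers.

def skip_layer_map(M, N):
    # M teacher layers, N student layers; skip interval k = floor(M / N)
    k = M // N
    return [(i + 1) * k - 1 for i in range(N)]  # 0-indexed teacher layer indices

print(skip_layer_map(12, 4))  # [2, 5, 8, 11], i.e. teacher layers 3, 6, 9, 12 for k = 3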

4. Experiments

In this section, we evaluate the proposed AMKD method across various NLP tasks.

4.1. Datasets

We evaluated the proposed method using the GLUE dataset. The GLUE benchmark includes a variety of natural language understanding tasks, designed to comprehensively evaluate how well a model can handle different linguistic phenomena [33]. These tasks include Multi-Genre Natural Language Inference (MNLI), which involves determining the relationship between a premise and a hypothesis [34]; Quora Question Pairs (QQP), which evaluates the semantic equivalence between question pairs [35]; Question Answering Natural Language Inference (QNLI), which determines whether a given context contains the answer to a question; Stanford Sentiment Treebank (SST-2), which involves the sentiment classification of movie review sentences [36]; Microsoft Research Paraphrase Corpus (MRPC), which assesses whether sentence pairs are paraphrases [37]; and Recognizing Textual Entailment (RTE), which involves a determination of whether one sentence entails another [38].
For the machine reading comprehension task, we used the Stanford Question Answering Dataset (SQuAD v1.1) for evaluation [39]. This dataset, created by Rajpurkar et al. in 2016, contains 100,000 question–answer pairs collected through crowdsourcing, with the task of finding the text snippet within a Wikipedia passage that answers the question. SQuAD v2.0 builds on this by introducing cases where no clear answer is available, making the task more realistic and requiring the model to determine whether an answer exists and appropriately handle cases with no answer [40].

4.2. Training Details and Baselines

In our experiment, we used a BERTBASE model fine-tuned for specific tasks as the teacher model [2]. The BERTBASE model has 12 layers (M = 12), a hidden size of 768 (D = 768), a feed-forward size of 3072 (D_i = 3072), and 12 attention heads (h = 12), totaling 109 million parameters. The student model was a pre-trained language model, BERTTINY, with a smaller parameter size. The BERTTINY model has 4 layers (N = 4), a hidden size of 312 (D′ = 312), a feed-forward size of 1200 (D_i = 1200), and 12 attention heads (h = 12), totaling 14.5 million parameters.
The distillation temperature t was chosen from the set {1, 4, 8}, while the value of α was chosen from {0.2, 0.5, 0.8, 1}. The learning rate was chosen from {1 × 10⁻⁵, 2 × 10⁻⁵, 5 × 10⁻⁵}, and the batch size was chosen from {16, 32}. For the fine-tuning task, the number of epochs was set to 5 without data augmentation, whereas the distillation task was conducted over 25 epochs with data augmentation. The sequence length for task-specific distillation was consistently set to 128. During the skip learning process, we set k = 3, meaning that BERTTINY learned from BERTBASE at every third layer. The best results were selected from each experiment.
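For reference, the hyperparameter search space described above can be enumerated as follows (an illustrative sketch; the best configuration is selected separately for each task).

from itertools import product

grid = {
    "temperature": [1, 4, 8],
    "alpha": [0.2, 0.5, 0.8, 1.0],
    "learning_rate": [1e-5, 2e-5, 5e-5],
    "batch_size": [16, 32],
}

configs = [dict(zip(grid, values)) for values in product(*grid.values())]
print(len(configs))  # 72 candidate configurations per task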
We compared our method, BERTTINY-AMKD, with several baseline methods, including BERT-KD [14], which represents the traditional prediction layer distillation method, BERT-PKD, and DistilBERT. We fine-tuned the published pre-trained models BERT4-PKD [20] and DistilBERT4 [19] on each specific task using the recommended hyperparameters. Table 1 provides a structural comparison of the models used in our experiments. BERTSMALL is one of the 24 smaller BERT models available from the official Google repository [41]. Table 2 offers a detailed comparison of the distillation positions employed by the baseline methods and our proposed AMKD method. Our AMKD approach enables the student model to learn knowledge from the teacher model more comprehensively.

4.3. Experimental Results on GLUE

The experimental results are presented in Table 3. BERTBASE (Teacher) represents our implementation of the BERT teacher model [2]. BERTTINY-FT shows the results of directly fine-tuning the pre-trained BERTTINY model on each task. BERTTINY-KD represents the results of applying the prediction distillation method to the BERTTINY model [14]. BERTTINY-AMKD is our proposed method. The results of TinyBERT4 are obtained by directly calculating the MSE of the hidden states. Other baseline data are obtained by directly fine-tuning the publicly available pre-trained models [19,20]. Accuracy is used as the evaluation metric for all tasks except for the MRPC task, which uses F1 score. The number of training samples for each dataset is provided below the dataset name.
The results from the four-layer student models suggest that a substantial reduction in model size results in a significant performance gap between BERTTINY-FT (or BERTSMALL) and BERTBASE.
BERTTINY-AMKD outperforms BERTTINY-FT in all GLUE tasks, with an average improvement of 4.7%, indicating that the proposed AMKD method can significantly enhance the performance of small models in various downstream tasks. In addition, BERTTINY-AMKD achieved an average score of 83.6% in the GLUE task, which is significantly higher than the 79.5% of the traditional knowledge distillation method BERTTINY-KD.
BERTTINY-AMKD achieves 96.3% of the performance of the BERTBASE teacher model while using only 13.3% of the parameters. This demonstrates its ability to maintain high accuracy while significantly reducing computational resources. BERTTINY-AMKD outperforms the four-layer KD baseline models BERT4-PKD and DistilBERT4 by 2.6% and 3.9%, respectively, while utilizing only 28% of their parameters.
Compared to TinyBERT4, BERTTINY-AMKD offers clear advantages in both performance and flexibility. TinyBERT4 applies the MSE loss directly to the hidden states, whereas AMKD not only learns individual features but also captures high-order dependencies through the autocorrelation matrix. As a result, BERTTINY-AMKD outperforms TinyBERT4 across the GLUE tasks by 1.7% on average. Figure 3 shows the average performance of the different models on the GLUE tasks, where BERTTINY-AMKD clearly outperforms the other models.
In the MNLI task, BERTTINY-AMKD achieved accuracy rates of 82.1% on MNLI-m and 81.7% on MNLI-mm, significantly higher than the 76.5% and 76.1% achieved by BERTTINY-KD. Furthermore, it outperformed the 79.9% and 79.3% achieved by BERT4-PKD, as well as the 78.9% and 78.0% achieved by DistilBERT4. On the CoLA (Corpus of Linguistic Acceptability) dataset, despite the significant performance gap between all four-layer models and the teacher model, BERTTINY-AMKD still achieved notable improvements. To further analyze the performance of different methods on GLUE tasks, we provide a detailed comparison of several methods in Figure 4. AMKD outperforms the other models in every task, with performance close to that of the teacher model.
BERT4-PKD and DistilBERT4 initialize the student model from specific layers of BERTBASE during training, which requires the student model to match the hidden states dimensions of the teacher model. In contrast, our proposed AMKD method allows for greater flexibility in the configuration of the student model. AMKD not only allows the student model to learn the features of the teacher model, but also captures the complex relationships among these features.

4.4. Experimental Results on SQuAD

We further validated the effectiveness of AMKD in question answering (QA) tasks using the SQuAD v1.1 and SQuAD v2.0 datasets. Unlike the GLUE tasks, QA tasks require more nuanced knowledge to identify the correct answer, which makes the learning process more complex. We did not use data augmentation in the experiments. The experimental results are shown in Table 4.
BERTTINY-AMKD outperformed the two four-layer baseline models, BERT4-PKD and DistilBERT4, on both the SQuAD v1.1 and SQuAD v2.0 datasets. These results once again demonstrate the effectiveness of the AMKD method in capturing and learning teacher model features. Even in challenging tasks like question answering, the performance of the student model is significantly improved.

4.5. Performance of AMKD-Last vs. AMKD-Skip

We compared two techniques: BERTTINY(AMKD-Last) and BERTTINY(AMKD-Skip). AMKD-Last means that the student model only learns from the last few layers of the teacher model, while AMKD-Skip lets the student model extract knowledge from the teacher model every k layers. Table 5 summarizes the experimental results of these two AMKD methods.
Although both methods outperform the KD baseline, BERTTINY(AMKD-Skip) shows slightly better performance than BERTTINY(AMKD-Last). This performance advantage is likely due to the ability of AMKD-Skip to extract information every k layers, capturing a broader range of semantic representations from lower to higher layers. In contrast, AMKD-Last focuses only on the final layers, resulting in less comprehensive semantic information.

4.6. Ablation Studies

This section explores the impact of different parts of the Transformer architecture on the effectiveness of distillation through ablation experiments.
The ablation experiment results in Table 6 demonstrate that removing different distillation objectives significantly affects the learning performance of AMKD. The removal of attention distillation (w/o Atten) has the most substantial impact on overall model performance, particularly on the MNLI task. Similarly, removing hidden layer distillation (w/o Hidden) also results in a significant performance reduction. In contrast, removing embedding layer distillation (w/o Embed) causes a smaller decrease in performance.

5. Conclusions

BERTTINY-AMKD significantly outperforms traditional distillation methods and other four-layer baseline models. In the GLUE task, BERTTINY-AMKD achieves an average score of 83.6%, which is 4.1% higher than BERTTINY-KD, 2.6% higher than BERT4-PKD, and 3.9% higher than DistilBERT4. In the SQuAD benchmark test, the EM and F1 scores of BERTTINY-AMKD in SQuAD v1.1 increased to 72.1% and 81.8%, respectively, also surpassing other four-layer baseline models. Ablation experiments show that attention matrix distillation and hidden states distillation are crucial to performance improvement, while the layer-skipping distillation strategy can more comprehensively obtain different levels of knowledge from the teacher model. These experimental results demonstrate that the AMKD method can significantly improve the accuracy of small models in natural language processing tasks.
By introducing the autocorrelation matrix, AMKD not only effectively captures the complex relationships between the features of the teacher model, but also skillfully handles the dimensional differences between the student model and the teacher model, reducing information loss. This method is suitable not only for classification tasks but also for a range of applications, including reading comprehension and regression tasks. AMKD significantly improves the performance of small BERT models in NLP tasks, providing an efficient and flexible solution for deploying pre-trained language models in resource-constrained environments. The AMKD method we proposed is mainly used for the distillation of specific tasks. If conditions permit, distillation can be performed in the pre-training stage in the future to generate a general small BERT model suitable for a wider range of tasks. In addition, it is difficult to adapt the fixed temperature parameter t in the distillation process to different samples. In the future, it will be possible to consider dynamically adjusting the temperature to better adapt to the needs of different samples.

Author Contributions

Conceptualization, K.Z.; methodology, J.L.; software, J.L.; validation, J.L. and B.W.; formal analysis, J.L.; investigation, H.M.; resources, K.Z.; data curation, B.W.; writing—original draft preparation, J.L.; writing—review and editing, K.Z.; visualization, B.W.; supervision, K.Z.; project administration, K.Z.; funding acquisition, K.Z. All authors have read and agreed to the published version of the manuscript.

Funding

This research was funded by the Henan Provincial Science and Technology Research Project, grant number 232102211005.

Institutional Review Board Statement

Not applicable.

Informed Consent Statement

Not applicable.

Data Availability Statement

The original contributions presented in the study are included in the article; further inquiries can be directed to the corresponding author.

Conflicts of Interest

The authors declare no conflicts of interest.

Abbreviations

The following abbreviations are used in this manuscript:
AMKD    Autocorrelation Matrix Knowledge Distillation
NLP     Natural Language Processing
KD      Knowledge Distillation

References

  1. Achiam, O.J.; Adler, S.; Agarwal, S.; Ahmad, L.; Akkaya, I.; Aleman, F.L.; Almeida, D.; Altenschmidt, J.; Altman, S.; Anadkat, S.; et al. GPT-4 Technical Report; OpenAI: San Francisco, CA, USA, 2023. [Google Scholar]
  2. Devlin, J.; Chang, M.W.; Lee, K.; Toutanova, K. Bert: Pre-training of deep bidirectional transformers for language understanding. arXiv 2018, arXiv:1810.04805. [Google Scholar]
  3. Lin, X.; Xiong, G.; Gou, G.; Li, Z.; Shi, J.; Yu, J. ET-BERT: A Contextualized Datagram Representation with Pre-training Transformers for Encrypted Traffic Classification. In Proceedings of the ACM Web Conference 2022, Lyon, France, 25–29 April 2022. [Google Scholar]
  4. Shi, Z.; Luktarhan, N.; Song, Y.; Tian, G. BFCN: A Novel Classification Method of Encrypted Traffic Based on BERT and CNN. Electronics 2023, 12, 516. [Google Scholar] [CrossRef]
  5. Zheng, W.; Lu, S.; Cai, Z.; Wang, R.; Wang, L.; Yin, L. PAL-BERT: An Improved Question Answering Model. Comput. Model. Eng. Sci. 2023, 139, 2729–2745. [Google Scholar] [CrossRef]
  6. Wu, T.; Hou, C.; Zhao, Z.; Lao, S.; Li, J.; Wong, N.; Yang, Y. Weight-Inherited Distillation for Task-Agnostic BERT Compression. arXiv 2023, arXiv:2305.09098. [Google Scholar]
  7. Piao, T.; Cho, I.; Kang, U. SensiMix: Sensitivity-Aware 8-bit index & 1-bit value mixed precision quantization for BERT compression. PLoS ONE 2022, 17, e0265621. [Google Scholar]
  8. Liu, Y.; Lin, Z.; Yuan, F. ROSITA: Refined BERT cOmpreSsion with InTegrAted techniques. arXiv 2021, arXiv:2103.11367. [Google Scholar] [CrossRef]
  9. Lin, Y.J.; Chen, K.Y.; Kao, H.Y. LAD: Layer-Wise Adaptive Distillation for BERT Model Compression. Sensors 2023, 23, 1483. [Google Scholar] [CrossRef]
  10. Hoefler, T.; Alistarh, D.; Ben-Nun, T.; Dryden, N.; Peste, A. Sparsity in Deep Learning: Pruning and growth for efficient inference and training in neural networks. arXiv 2021, arXiv:2102.00554. [Google Scholar]
  11. Zhang, J.; Zhou, Y.; Saab, R. Post-training Quantization for Neural Networks with Provable Guarantees. SIAM J. Math. Data Sci. 2022, 5, 373–399. [Google Scholar] [CrossRef]
  12. Muksimova, S.; Umirzakova, S.; Mardieva, S.; Cho, Y.I. Enhancing Medical Image Denoising with Innovative Teacher–Student Model-Based Approaches for Precision Diagnostics. Sensors 2023, 23, 9502. [Google Scholar] [CrossRef]
  13. Kaushal, A.; Vaidhya, T.; Rish, I. LORD: Low Rank Decomposition Of Monolingual Code LLMs For One-Shot Compression. arXiv 2023, arXiv:2309.14021. [Google Scholar]
  14. Hinton, G.; Vinyals, O.; Dean, J. Distilling the knowledge in a neural network. arXiv 2015, arXiv:1503.02531. [Google Scholar]
  15. Qi, P.; Zhou, X.; Ding, Y.; Zhang, Z.; Zheng, S.; Li, Z. FedBKD: Heterogenous Federated Learning via Bidirectional Knowledge Distillation for Modulation Classification in IoT-Edge System. IEEE J. Sel. Top. Signal Process. 2023, 17, 189–204. [Google Scholar] [CrossRef]
  16. Jiao, X.; Yin, Y.; Shang, L.; Jiang, X.; Chen, X.; Li, L.; Wang, F.; Liu, Q. LightMBERT: A Simple Yet Effective Method for Multilingual BERT Distillation. arXiv 2021, arXiv:2103.06418. [Google Scholar]
  17. Jiang, M.; Lin, J.; Wang, Z.J. ShuffleCount: Task-Specific Knowledge Distillation for Crowd Counting. In Proceedings of the 2021 IEEE International Conference on Image Processing (ICIP), Anchorage, AK, USA, 19–22 September 2021; pp. 999–1003. [Google Scholar]
  18. Wu, Y.; Chanda, S.; Hosseinzadeh, M.; Liu, Z.; Wang, Y. Few-Shot Learning of Compact Models via Task-Specific Meta Distillation. In Proceedings of the 2023 IEEE/CVF Winter Conference on Applications of Computer Vision (WACV), Waikoloa, HI, USA, 3–7 January 2022; pp. 6254–6263. [Google Scholar]
  19. Tang, R.; Lu, Y.; Liu, L.; Mou, L.; Vechtomova, O.; Lin, J. Distilling task-specific knowledge from bert into simple neural networks. arXiv 2019, arXiv:1903.12136. [Google Scholar]
  20. Sun, S.; Cheng, Y.; Gan, Z.; Liu, J. Patient knowledge distillation for bert model compression. arXiv 2019, arXiv:1908.09355. [Google Scholar]
  21. Aguilar, G.; Ling, Y.; Zhang, Y.; Yao, B.; Fan, X.; Guo, E. Knowledge Distillation from Internal Representations. arXiv 2019, arXiv:1910.03723. [Google Scholar] [CrossRef]
  22. Hou, L.; Huang, Z.; Shang, L.; Jiang, X.; Chen, X.; Liu, Q. Dynabert: Dynamic bert with adaptive width and depth. Adv. Neural Inf. Process. Syst. 2020, 33, 9782–9793. [Google Scholar]
  23. Xu, C.; Zhou, W.; Ge, T.; Wei, F.; Zhou, M. Bert-of-theseus: Compressing bert by progressive module replacing. arXiv 2020, arXiv:2002.02925. [Google Scholar]
  24. Jiao, X.; Yin, Y.; Shang, L.; Jiang, X.; Chen, X.; Li, L.; Wang, F.; Liu, Q. Tinybert: Distilling bert for natural language understanding. arXiv 2019, arXiv:1909.10351. [Google Scholar]
  25. Sanh, V.; Debut, L.; Chaumond, J.; Wolf, T. DistilBERT, a distilled version of BERT: Smaller, faster, cheaper and lighter. arXiv 2019, arXiv:1910.01108. [Google Scholar]
  26. Vaswani, A.; Shazeer, N.; Parmar, N.; Uszkoreit, J.; Jones, L.; Gomez, A.N.; Kaiser, Ł.; Polosukhin, I. Attention is all you need. In Proceedings of the Advances in Neural Information Processing Systems, Long Beach, CA, USA, 4–9 December 2017; Volume 30.
  27. Luo, Y.; Yu, J. Music Source Separation With Band-Split RNN. IEEE/ACM Trans. Audio Speech Lang. Process. 2022, 31, 1893–1901. [Google Scholar] [CrossRef]
  28. Alzubaidi, L.; Zhang, J.; Humaidi, A.J.; Al-dujaili, A.; Duan, Y.; Al-Shamma, O.; Santamaría, J.I.; Fadhel, M.A.; Al-Amidie, M.; Farhan, L. Review of deep learning: Concepts, CNN architectures, challenges, applications, future directions. J. Big Data 2021, 8, 53. [Google Scholar] [CrossRef]
  29. Sonkar, S.; Baraniuk, R. Investigating the Role of Feed-Forward Networks in Transformers Using Parallel Attention and Feed-Forward Net Design. arXiv 2023, arXiv:2305.13297. [Google Scholar]
  30. Biçici, E.; Kanburoglu, A.B.; Türksoy, R.T. Residual Connections Improve Prediction Performance. In Proceedings of the 2023 4th International Informatics and Software Engineering Conference (IISEC), Ankara, Turkiye, 21–22 December 2023; pp. 1–5. [Google Scholar]
  31. Cui, Y.; Xu, Y.; Peng, R.; Wu, D. Layer Normalization for TSK Fuzzy System Optimization in Regression Problems. IEEE Trans. Fuzzy Syst. 2023, 31, 254–264. [Google Scholar] [CrossRef]
  32. Sáenz, C.A.C.; Becker, K. Understanding stance classification of BERT models: An attention-based framework. Knowl. Inf. Syst. 2023, 66, 419–451. [Google Scholar] [CrossRef]
  33. Wang, A.; Singh, A.; Michael, J.; Hill, F.; Levy, O.; Bowman, S.R. GLUE: A Multi-Task Benchmark and Analysis Platform for Natural Language Understanding. arXiv 2018, arXiv:1804.07461. [Google Scholar]
  34. Williams, A.; Nangia, N.; Bowman, S.R. A broad-coverage challenge corpus for sentence understanding through inference. arXiv 2017, arXiv:1704.05426. [Google Scholar]
  35. Chen, Z.; Zhang, H.; Zhang, X.; Zhao, L. Quora Question Pairs. 2018. Online Resource. Available online: https://api.semanticscholar.org/CorpusID:233225749 (accessed on 8 October 2024).
  36. Socher, R.; Perelygin, A.; Wu, J.; Chuang, J.; Manning, C.D.; Ng, A.Y.; Potts, C. Recursive deep models for semantic compositionality over a sentiment treebank. In Proceedings of the 2013 Conference on Empirical Methods in Natural Language Processing, Seattle, WA, USA, 18–21 October 2013; pp. 1631–1642. [Google Scholar]
  37. Dolan, B.; Brockett, C. Automatically constructing a corpus of sentential paraphrases. In Proceedings of the Third international workshop on paraphrasing (IWP2005), Jeju Island, Republic of Korea, 14 October 2005. [Google Scholar]
  38. Bentivogli, L.; Clark, P.; Dagan, I.; Giampiccolo, D. The Fifth PASCAL Recognizing Textual Entailment Challenge. TAC 2009, 7, 1. [Google Scholar]
  39. Rajpurkar, P.; Zhang, J.; Lopyrev, K.; Liang, P. Squad: 100,000+ questions for machine comprehension of text. arXiv 2016, arXiv:1606.05250. [Google Scholar]
  40. Rajpurkar, P.; Jia, R.; Liang, P. Know what you don’t know: Unanswerable questions for SQuAD. arXiv 2018, arXiv:1806.03822. [Google Scholar]
  41. Turc, I.; Chang, M.W.; Lee, K.; Toutanova, K. Well-read students learn better: On the importance of pre-training compact models. arXiv 2019, arXiv:1908.08962. [Google Scholar]
Figure 1. An overview of the AMKD method, showing how to distill the knowledge of an M-layer Transformer teacher model T to an N-layer Transformer student model S.
Figure 2. Comparison of the autocorrelation matrix between the teacher model and the student model.
Figure 3. Comparison of average performance of AMKD and other methods on GLUE tasks.
Figure 4. Radar chart comparison of performance of AMKD and other methods on GLUE tasks.
Table 1. Comparison of architecture and parameter size between the BERT teacher model and student models.

Model         #Layer   Hidden Size   Feed-Forward   Speedup   #Params   Relative
BERTBASE      12       768           3072           1.0×      109M      100%
BERTTINY      4        312           1200           9.4×      14.5M     13.3%
BERTSMALL     4        512           2048           5.7×      29.2M     26.8%
BERT4-PKD     4        768           3072           3.0×      52.2M     47.9%
DistilBERT4   4        768           3072           3.0×      52.2M     47.9%
Table 2. Comparison of distillation components in different approaches.

Model        Teacher Model   Prediction Layer   Embedding Layer   Hidden States   Attention Matrix
BERT-KD      BERTBASE
DistilBERT   BERTBASE
BERT-PKD     BERTBASE
BERT-AMKD    BERTBASE
Table 3. Comprehensive evaluation results of various models on the GLUE benchmark. The best result for each task is shown in bold. The number of training samples for each dataset is given below the dataset name.

Model                MNLI-m   MNLI-mm   QQP      QNLI     SST-2    MRPC     RTE      Avg
                     (393K)   (393K)    (364K)   (105K)   (67K)    (3.7K)   (2.5K)
BERTBASE (Google)    84.6     83.4      89.2     90.5     93.5     88.9     66.4     85.2
BERTBASE (Teacher)   84.3     83.2      91.3     91.6     93.7     89.8     73.6     86.8
BERTTINY-FT          75.4     74.9      83.5     84.8     87.6     83.2     62.6     78.9
BERTSMALL            77.6     77.0      87.0     86.4     89.7     83.4     61.8     80.4
BERT4-PKD            79.9     79.3      88.3     85.1     89.4     82.6     62.3     81.0
DistilBERT4          78.9     78.0      87.6     85.2     91.4     82.4     54.1     79.7
TinyBERT4            80.5     81.0      88.5     85.7     90.7     83.3     63.5     81.9
BERTTINY-KD          76.5     76.1      84.6     85.1     88.2     83.5     62.8     79.5
BERTTINY-AMKD        82.1     81.7      89.4     87.9     92.1     86.9     65.4     83.6
Table 4. Comprehensive evaluation results of baseline models and BERTTINY-AMKD on SQuAD. Evaluation metrics: EM (Exact Match) and F1 (F1 Score).

Model                SQuAD v1.1        SQuAD v2.0
                     EM       F1       EM       F1
BERTBASE (Teacher)   79.8     87.9     73.2     76.4
BERT4-PKD            68.7     78.1     59.6     63.9
DistilBERT4          70.2     79.8     59.4     63.2
BERTTINY-AMKD        72.1     81.8     65.2     68.4
Table 5. Performance comparison between AMKD-Last and AMKD-Skip on the GLUE benchmark.

Model                  MNLI-m   MNLI-mm   QQP    QNLI   SST-2   MRPC   RTE    Avg
BERTTINY (AMKD-Last)   81.1     80.5      88.2   86.3   91.8    87.0   65.0   82.8
BERTTINY (AMKD-Skip)   82.1     81.7      89.4   87.9   92.1    86.9   65.4   83.6
Table 6. Ablation studies on distillation components in AMKD learning.

Model           MNLI-m   MNLI-mm   SST-2   MRPC   Avg
BERTTINY-AMKD   82.1     81.7      92.1    86.9   85.7
w/o Embed       81.4     81.3      91.2    86.2   85.0
w/o Atten       79.3     78.7      89.5    83.6   82.8
w/o Hidden      79.5     78.9      91.0    84.2   83.4
Disclaimer/Publisher’s Note: The statements, opinions and data contained in all publications are solely those of the individual author(s) and contributor(s) and not of MDPI and/or the editor(s). MDPI and/or the editor(s) disclaim responsibility for any injury to people or property resulting from any ideas, methods, instructions or products referred to in the content.
