3.3.2. Adaptive Graph Module
One of the pivotal components of our framework is the Adaptive Graph Module (AGM). This module takes the feature embeddings of a support set and a query set as input, and adjusts the features of the support set based on information from the query set. The output is the feature embeddings of the support and query sets in a new feature space. Essentially, the proposed AGM re-projects the features, enabling the model to compute the distance or similarity between the query samples and the class prototypes within the channel feature space.
For simplicity, let $X^{l} = \{x_{1}^{l}, x_{2}^{l}, \dots, x_{T}^{l}\}$ represent the feature representation of all samples from the support and query sets after the $l$-th layer of the AGM. These samples are processed through the next layer of the AGM as follows:

$$X^{l+1} = \mathrm{FFN}\big(\sigma(A^{l} X^{l} W^{l})\big)$$

Here, $A^{l} X^{l} W^{l}$ denotes the aggregated sample information, $A^{l}$ is an adjacency matrix representing the relationships between samples at that layer, $W^{l}$ is a projection matrix, $\sigma(\cdot)$ denotes a non-linear activation function, and $\mathrm{FFN}$ stands for Feed-Forward Network. The AGM can be dissected into two segments: (1) Graph Construction: this part details how we express samples from the support and query sets as a graph; (2) Message Passing: we delve into the mechanism of message passing within Graph Neural Networks in the context of our problem. Specifically, to obtain adaptive prototypes, we direct the flow of information from the query-set samples towards the support-set samples, which allows for dynamic adjustments based on the query.
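For concreteness, the following is a minimal PyTorch sketch of one AGM layer in the matrix form above. It assumes ReLU for $\sigma$ and a two-layer FFN; the identifiers (`AGMLayer`, `dim`, `hidden`) are illustrative choices, not names from our implementation. The construction of the adjacency matrix $A$, along with the masking and residual details, is covered in the two parts below.

```python
import torch
import torch.nn as nn

class AGMLayer(nn.Module):
    """One AGM layer: aggregate neighbor features via the adjacency
    matrix A, project with W, apply a non-linearity, then an FFN."""

    def __init__(self, dim: int, hidden: int = 256):
        super().__init__()
        self.W = nn.Linear(dim, dim, bias=False)  # projection matrix W
        self.act = nn.ReLU()                      # non-linear activation sigma (assumed)
        self.ffn = nn.Sequential(                 # feed-forward network (assumed 2 layers)
            nn.Linear(dim, hidden), nn.ReLU(), nn.Linear(hidden, dim)
        )

    def forward(self, X: torch.Tensor, A: torch.Tensor) -> torch.Tensor:
        # X: (T, dim) features of all support + query samples
        # A: (T, T) row-normalized adjacency matrix (see Graph Construction)
        aggregated = A @ self.W(X)                # aggregate sample information A X W
        return self.ffn(self.act(aggregated))
```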
- (1). Graph Construction
For simplicity, we denote $T$ as the total number of samples encompassing both the support and query sets. Let $G = (V, E)$ represent a graph constituted by nodes from these sets, where $V$ is the collection of nodes, defined as $V = \{v_{i}^{l}\}_{i=1}^{T}$. Here, $v_{i}^{l}$ symbolizes the $i$-th malware sample at the $l$-th AGM layer. For samples drawn from the support set, $i$ ranges from $1$ to $K$, and for those from the query set, $i$ ranges from $K+1$ to $T$. Given that the discussion pertains to samples from a single layer of the AGM, we simplify our notation by omitting the superscript $l$, thus reducing it to $v_{i}$. $E$ is the set of edges, where for any pair $(i, j)$, $e_{ij} \in E$ delineates the relationship between the $i$-th and $j$-th samples.
However, the semantic relationships between samples remain somewhat nebulous. The subsequent discussion aims to elaborate on the precise construction process of these relationships, enhancing the understanding and analytical capabilities within the graph model.
We define $e_{ij}$ as a measure of the relationship between the $i$-th and $j$-th samples. The underlying premise is that the closer or more relevant the samples are to each other, the larger the value of $e_{ij}$ should be. This conceptualization allows us to capture the intrinsic structure and relationship dynamics within the data. The formula for this relationship measure can be represented as follows:

$$e_{ij} = f_{\theta}(x_{i}, x_{j})$$

Here, $f_{\theta}(\cdot, \cdot)$ represents the distance function between different samples, with $\theta$ denoting the parameters that can be learned. This function can take various forms, including both parametric and non-parametric expressions. For non-parametric forms, $f$ could be: (1) Euclidean Distance: $f(x_{i}, x_{j}) = \lVert x_{i} - x_{j} \rVert_{2}$; (2) Cosine Similarity: $f(x_{i}, x_{j}) = \frac{x_{i}^{\top} x_{j}}{\lVert x_{i} \rVert \, \lVert x_{j} \rVert}$. For parametric forms, $f_{\theta}$ could be: (3) Sum: $f_{\theta}(x_{i}, x_{j}) = \mathrm{MLP}_{\theta}(x_{i} + x_{j})$; (4) Dot Product: $f_{\theta}(x_{i}, x_{j}) = (W_{1} x_{i})^{\top} (W_{2} x_{j})$, with $\theta = \{W_{1}, W_{2}\}$.
In our constructed sample graph, to enable a more flexible framework that can capture a wider variety of relationships and allow the model to learn a more nuanced and accurate representation, we employed a parametric distance function $f_{\theta}$. Furthermore, to maintain the symmetry of distance, ensuring $f_{\theta}(x_{i}, x_{j}) = f_{\theta}(x_{j}, x_{i})$, we incorporated a methodology delineated in [24], which stacks a Multilayer Perceptron (MLP) on the absolute difference between the two vectors. This can be mathematically represented as:

$$f_{\theta}(x_{i}, x_{j}) = \mathrm{MLP}_{\theta}\big(\lvert x_{i} - x_{j} \rvert\big)$$

Here, $\theta$ represents the learnable parameters. With this architecture, the symmetry of the distance function is satisfied by construction, and the identity property $f_{\theta}(x_{i}, x_{i}) = 0$ is easily learned, since $\lvert x_{i} - x_{i} \rvert = \mathbf{0}$ for every sample. For illustration, we visualize the MLP architecture in Figure 6.
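A minimal sketch of this symmetric distance follows, assuming a two-layer MLP; the class name and hidden width are illustrative and not prescribed by the text.

```python
import torch
import torch.nn as nn

class SymmetricDistance(nn.Module):
    """Learnable distance f_theta(x_i, x_j) = MLP(|x_i - x_j|).
    Symmetry holds by construction, since |x_i - x_j| = |x_j - x_i|."""

    def __init__(self, dim: int, hidden: int = 64):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(dim, hidden), nn.ReLU(), nn.Linear(hidden, 1)
        )

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        # X: (T, dim) sample features -> (T, T) pairwise distance matrix
        diff = (X.unsqueeze(1) - X.unsqueeze(0)).abs()  # (T, T, dim)
        return self.mlp(diff).squeeze(-1)               # (T, T)
```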
Having established the pairwise distances between samples, we can readily construct the adjacency matrix $A$, in which the distance between two samples inversely correlates with the strength of their connection: the smaller the learned distance, the larger the entry. This relationship is mathematically represented as:

$$A_{ij} = -f_{\theta}(x_{i}, x_{j})$$

Subsequently, we normalize each row of the adjacency matrix. For this process, we adopt the normalization technique proposed in the Graph Attention Network (GAT) [32], which applies a LeakyReLU activation function to the adjacency matrix, followed by a row-wise softmax, to determine the attention coefficients $\alpha_{ij}$ for the edges. The formula for this computation is as follows:

$$\alpha_{ij} = \frac{\exp\big(\mathrm{LeakyReLU}(A_{ij})\big)}{\sum_{k=1}^{T} \exp\big(\mathrm{LeakyReLU}(A_{ik})\big)}$$
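This normalization reduces to a few tensor operations. The sketch below assumes the negation convention of the equation above and GAT's customary LeakyReLU slope of 0.2, which the text does not state explicitly.

```python
import torch
import torch.nn.functional as F

def normalized_adjacency(dist: torch.Tensor) -> torch.Tensor:
    """GAT-style normalization: negate the learned distances so that
    closer pairs receive larger scores, apply LeakyReLU, then take a
    row-wise softmax so each row of attention coefficients sums to one."""
    A = -dist  # dist: (T, T) pairwise f_theta values; smaller distance -> larger score
    return F.softmax(F.leaky_relu(A, negative_slope=0.2), dim=-1)
```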
With the graph constructed, defining its nodes as the collective set of samples from both the support and query sets, and its edges as the relationships between each pair of samples, we now proceed to the message-passing phase. Specifically, we aim to transmit the query set’s information to the support set. This process allows for the acquisition of adaptive features that are refined in response to the query set’s characteristics.
- (2). Message Passing
Unlike in traditional Graph Neural Networks, message passing in the fully connected graph constructed in the previous section is not bidirectional. In particular, information sharing within the query set is not feasible. This is understandable, since each sample in the query set is independent, with no inherent relations between the samples.
Our objective was to allow the support set to encompass the information from the query-set samples, thereby acquiring adaptive prototypes. This was inspired by [29], which maintained the invariant features of the support set and directed the flow of information towards the query set. However, our goal differed, in that we sought prototypes adapted to the query set, meaning the support set should contain information from the query set. As illustrated in Figure 3, the blue lines between samples represent mutual information transfer, while the yellow lines indicate information flow solely from the query to the support set.
To achieve the directional requirements of the aforementioned information transfer, while ensuring that the sum of each row in the normalized adjacency matrix equaled one, we introduced a mask matrix $M$. The visualization of the mask matrix is shown in Figure 7. With a total sample count of $T$ for the support and query sets, the mask and adjacency matrices were of dimension $T \times T$. The first $K$ rows of the mask were zero, and the entries of the subsequent rows were set to $-\infty$ (to become zero after the softmax calculation). The entire process can be articulated as:

$$\tilde{A} = \mathrm{softmax}\big(\mathrm{LeakyReLU}(A) + M\big)$$
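A sketch of the masked normalization is given below. Note one loud assumption: to keep every softmax row well defined, the sketch leaves each query node's diagonal entry unmasked (a self-loop), a detail the text does not spell out.

```python
import torch
import torch.nn.functional as F

def masked_attention(A: torch.Tensor, K: int) -> torch.Tensor:
    """Directional normalization: the first K rows (support samples)
    attend to all T samples, so support features absorb query
    information; the remaining rows (query samples) are masked so no
    information flows back into the query set."""
    T = A.size(0)
    mask = torch.zeros(T, T)
    mask[K:, :] = float('-inf')     # block incoming edges to query nodes
    idx = torch.arange(K, T)
    mask[idx, idx] = 0.0            # ASSUMED self-loops so query softmax rows stay defined
    return F.softmax(F.leaky_relu(A, negative_slope=0.2) + mask, dim=-1)
```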
With the modified adjacency matrix in place, we then proceeded to the standard graph-convolution aggregation and update processes. The aggregation process involved gathering the information of neighbors: first obtaining the projection $W^{l} x_{j}^{l}$ of each sample, and then aggregating the information according to the adjacency matrix, as shown in Formula (2). The update process involved merging the node's own information with the aggregated neighbor information, where we simply employed a residual connection:

$$x_{i}^{l+1} = x_{i}^{l} + \sigma\Big(\sum_{j=1}^{T} \tilde{A}_{ij} W^{l} x_{j}^{l}\Big)$$
We also experimented with more complex update mechanisms, such as the multi-head attention update used in GAT [32]:

$$x_{i}^{l+1} = \big\Vert_{h=1}^{H} \, \sigma\Big(\sum_{j=1}^{T} \tilde{A}_{ij}^{h} W^{l,h} x_{j}^{l}\Big)$$

where $\Vert$ denotes concatenation over the $H$ attention heads.
However, the empirical results suggested that the other update mechanisms did not offer a clear advantage. Instead, the simple residual connection model demonstrated more stable convergence.
As described in Formula (1), after updating the features of each sample, we enhanced the model's expressiveness and obtained a more flexible feature-dimension representation by adding a Feed-Forward Network (FFN) after the graph-convolution operation, yielding the feature embedding of each sample in the new feature space.
3.3.3. Proto Aggregation Layer
In the preceding section, we explored the transformation of the support- and query-set sample features through an AGM layer into $X^{l+1}$. Assuming our model comprised $L$ AGM layers, in addition to the initial features extracted via a CNN4 architecture, we obtained $L+1$ sets of sample features for both the support and query sets, denoted as $\{X^{0}, X^{1}, \dots, X^{L}\}$, where $X^{0}$ is the CNN4 output. Each set, $X^{l}$, contained features from the support and query sets.
The proto aggregate layer focuses on computing the prototype of each class within the support-set sample features. We adhered to the most classical prototype calculation method [20], utilizing the mean of all features within the same class in the support set as the prototype for that class. The computation proceeds as follows:
Prototype Calculation: For each class $c$ within the support set at layer $l$, the prototype $p_{c}^{l}$ is computed as the mean of all feature vectors belonging to that class. Mathematically, this is represented as:

$$p_{c}^{l} = \frac{1}{\lvert S_{c}^{l} \rvert} \sum_{x_{i}^{l} \in S_{c}^{l}} x_{i}^{l}$$

Here, $p_{c}^{l}$ is the prototype for class $c$ at layer $l$; $\lvert S_{c}^{l} \rvert$ is the number of samples in class $c$; and $S_{c}^{l}$ represents the set of all feature vectors in the support set at layer $l$ that belong to class $c$. This formula ensures that the prototype is the centroid of the features in the class, effectively summarizing the class's overall position in the feature space.
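This centroid computation is a few lines in practice; the sketch below is a minimal version assuming integer class labels and at least one support sample per class.

```python
import torch

def class_prototypes(support: torch.Tensor, labels: torch.Tensor, C: int) -> torch.Tensor:
    """Mean of the support features per class, i.e., the class centroids.
    Assumes every class id in [0, C) appears at least once in `labels`."""
    # support: (K, dim) support features at layer l; labels: (K,) class ids
    return torch.stack([support[labels == c].mean(dim=0) for c in range(C)])  # (C, dim)
```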
For each of these sets of sample features, we computed prototypes using the aforementioned method, ultimately obtaining $L+1$ sets of prototypes, denoted as $P = \{P^{0}, P^{1}, \dots, P^{L}\}$.
3.3.4. Attention-Based Dynamic Proto-Layer
In this section, we introduce an innovative Attention-Based Dynamic Proto-Layer (ADPL), which marries the principles of prototype networks with the dynamic-convolution methodology. Upon traversing the proto aggregate layer, we acquire $L+1$ prototype sets, denoted by $P$. Inspired by the dynamic-convolution concept proposed by [33], the ADPL leverages an attention mechanism to adaptively distribute weights across these prototype sets, conditioned on the query-set embeddings.
As shown in Figure 4, the ADPL mainly consists of three parts:
Distance Computation: For each prototype set $P^{l}$, containing $C$ class prototypes denoted as $\{p_{1}^{l}, p_{2}^{l}, \dots, p_{C}^{l}\}$, we calculate the Euclidean distance between the query embedding $q^{l}$ and each class's prototype. This metric, denoted as $d^{l}$, assesses how closely each prototype corresponds to the query instance. It forms the foundational metric for the subsequent attention-weight allocation, influencing the model's decision-making process in classification or retrieval tasks. The distance for the $l$-th prototype set is computed as follows:

$$d^{l} = \big[\, \lVert q^{l} - p_{1}^{l} \rVert_{2}, \; \lVert q^{l} - p_{2}^{l} \rVert_{2}, \; \dots, \; \lVert q^{l} - p_{C}^{l} \rVert_{2} \,\big]$$

Here, $d^{l}$ represents the vector of Euclidean distances from the $l$-th prototype set to the query sample; $q^{l}$ is the query-embedding vector at layer $l$; $p_{c}^{l}$ is the $c$-th class prototype in the $l$-th set; and $C$ is the total number of classes.
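In PyTorch this step is a single call; the sketch below assumes batched query embeddings, with `prototype_distances` as a hypothetical helper name.

```python
import torch

def prototype_distances(q: torch.Tensor, protos: torch.Tensor) -> torch.Tensor:
    """Euclidean distance d^l between each query embedding and every
    class prototype of one prototype set."""
    # q: (B, dim) query embeddings at layer l; protos: (C, dim) prototypes p_c^l
    return torch.cdist(q, protos)  # (B, C), entry [b, c] = ||q_b - p_c||_2
```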
Attention Mechanism: We introduce an attention mechanism whose input is the query embedding from the first feature space (the $0$-th space, i.e., the CNN4 output). This mechanism comprises a Fully Connected (FC) layer followed by Batch Normalization (BN) and a Rectified Linear Unit (ReLU) activation. A subsequent FC layer projects the dimensions to match the number of prototype sets, and a softmax function is applied to yield the attention weight $\alpha_{l}$ for each prototype set. To prevent the softmax output from approximating a one-hot encoding, that is, to keep the model from focusing excessively on prototypes specific to certain groups, we use a large temperature $\tau$ in the softmax to flatten the attention, as follows:

$$\alpha_{l} = \frac{\exp(z_{l} / \tau)}{\sum_{k=0}^{L} \exp(z_{k} / \tau)}$$
where $z = [z_{0}, z_{1}, \dots, z_{L}]$ is the output of the last FC layer in the attention branch (see Figure 4), and $\tau$ is the temperature. The attention weights reflect the importance or relevance of each prototype set in relation to the query.
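The attention branch can be sketched as follows, assuming the temperature $\tau$ is a fixed hyperparameter; the default of 10.0 is a placeholder, not the paper's setting, and the class name is illustrative.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class PrototypeAttention(nn.Module):
    """FC -> BN -> ReLU -> FC branch producing one weight per prototype
    set; a large temperature tau flattens the softmax so no single
    prototype set dominates."""

    def __init__(self, dim: int, num_sets: int, tau: float = 10.0):
        super().__init__()
        self.tau = tau                      # assumed fixed temperature hyperparameter
        self.fc1 = nn.Linear(dim, dim)
        self.bn = nn.BatchNorm1d(dim)
        self.fc2 = nn.Linear(dim, num_sets)  # num_sets = L + 1 prototype sets

    def forward(self, q0: torch.Tensor) -> torch.Tensor:
        # q0: (B, dim) query embeddings from the initial (0-th) feature space
        z = self.fc2(F.relu(self.bn(self.fc1(q0))))  # (B, num_sets)
        return F.softmax(z / self.tau, dim=-1)       # flattened attention weights
```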
Weighted Summation: In the final step, a weighted summation is performed, in which the Euclidean distances are multiplied by their corresponding attention weights. Since a smaller distance indicates that the query sample is closer to a class, the predicted probability should be higher; therefore, the logits need to be negated. The formula for the weighted summation is expressed as:

$$\mathrm{logits} = -\sum_{l=0}^{L} \alpha_{l} \, d^{l}$$
The logits are the aggregated result of this process and represent the final predictive output of the model.
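Tying the two previous steps together, a minimal sketch of the final combination is shown below, with batched tensors and a hypothetical helper name.

```python
import torch

def adpl_logits(dists: torch.Tensor, attn: torch.Tensor) -> torch.Tensor:
    """Weighted summation: negate the attention-weighted distances so a
    smaller distance yields a larger logit (higher class probability)."""
    # dists: (B, L+1, C) Euclidean distances per prototype set
    # attn:  (B, L+1) attention weight per prototype set
    return -(attn.unsqueeze(-1) * dists).sum(dim=1)  # (B, C) class logits
```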
By integrating the attention mechanism, the ADPL allows our model to dynamically emphasize the most pertinent prototypes and provides a more nuanced and context-aware method for prototype weighting. This ensures that our model remains robust and efficient, even when faced with the challenge of learning from a limited number of samples.