To address the aforementioned challenges in practical federated learning, this section introduces a personalized federated learning framework with edge-side pruning. Through pruning on the edge, each client obtains a local model structure aligned with the features of its local data, thereby meeting the personalization requirements. Additionally, the decisions of the personalized models assist the central server in selecting the more important features extracted by the convolutional networks, thereby mitigating the performance degradation caused by Non-IID data.
3.2. Adaptive Pruning for Local Training
In the federated learning process, the global model is aggregated from the individual client models, and the global objective follows the principle outlined in the following formula:

$$F(w) = \sum_{n=1}^{N} \frac{|D_n|}{|D|} F_n(w), \quad (1)$$

where $N$ is the number of clients, $F_n$ is the local loss of client $n$, and $|D_n|$ is the size of its local dataset.
The performance of the global model is closely related to the performance of the client models. According to Equation (1), when the loss of each client decreases, the aggregated global loss also decreases. Model pruning can be employed to accelerate model convergence, which requires a suitable criterion for deciding what to prune. One commonly used method performs a Taylor expansion on the loss function $L(w)$:

$$\Delta L = L(w + \Delta w) - L(w) = \nabla L(w)^{\top} \Delta w + \frac{1}{2}\, \Delta w^{\top} H\, \Delta w + O\!\left(\|\Delta w\|^{3}\right). \quad (2)$$
In deep learning, it is impractical to compute the Hessian matrix $H$ in each training cycle due to the computational requirements ($O(N^2)$ complexity, where $N$ refers to the parameter size). To simplify computation, the focus is placed primarily on the first term in Equation (2). During training with gradient descent, the update $\Delta w$ over all the weights in the model can be expressed using the following formula:

$$\Delta w = -\eta\, \nabla L(w), \quad (3)$$

where $\eta$ is the learning rate.
Therefore, it can be deduced that

$$\Delta L \approx -\eta \sum_{i} g_i^{2}, \qquad g_i = \frac{\partial L}{\partial w_i}. \quad (4)$$

We can conclude that the rate of loss reduction is primarily determined by the gradient information $g_i$: the larger the retained $\sum_i g_i^2$, the more quickly the loss decreases. This establishes the criterion in model pruning for determining whether a particular weight should be pruned in the current pruning iteration.
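As a minimal illustration of this criterion, the per-weight importance $g_i^2$ can be collected after a backward pass. The sketch below assumes a PyTorch model; the function name is ours:

```python
# Sketch (PyTorch assumed): collect the first-order saliency g_i^2 per weight
# after loss.backward(); larger retained sums imply faster loss reduction.
import torch

def squared_gradient_saliency(model: torch.nn.Module) -> dict[str, torch.Tensor]:
    return {name: p.grad.detach() ** 2
            for name, p in model.named_parameters()
            if p.grad is not None}
```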
According to the lottery ticket hypothesis [15], every model contains a sparse subnetwork that can be trained to comparable accuracy in comparable time to the complete model. To identify the winning ticket quickly, initial pruning is performed on the client side during the initialization stage. By horizontally comparing the masks uploaded by the various clients to the server, the “winning tickets” can be filtered out not only more accurately but also more quickly. The initialization stage workflow is shown in Figure 2.
During the initialization phase, each client performs initial pruning based on its local data. This step immediately removes redundant structures from the model and compresses its size. Starting training with a lightweight model from the beginning effectively reduces the time and computational cost of training. Because this process relies entirely on user data held on the client, the model is personalized from the start. After initial pruning, each client sends its pruned weights and mask to the server, which then decides the lottery ticket network and sends it back to the clients.
Various methods can be employed to decide the global sparse network. One simple method sets a vote threshold on parameters: a particular weight is retained if the number of clients retaining that weight reaches the threshold. The calculation is shown in the following formula:

$$M_g[i] = \mathbb{1}\!\left[\sum_{n=1}^{N} M_n[i] \geq \theta\right], \quad (5)$$

where $M_n$ is the binary mask uploaded by client $n$ and $\theta$ is the vote threshold.
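A minimal sketch of this vote-threshold rule, assuming binary PyTorch masks; the function name and signature are ours:

```python
# Sketch of the vote-threshold rule in Equation (5) (PyTorch assumed; the
# function name and integer threshold `theta` are our choices).
import torch

def vote_mask(client_masks: list[torch.Tensor], theta: int) -> torch.Tensor:
    """Retain a weight iff at least `theta` of the N clients retained it."""
    votes = torch.stack(client_masks).sum(dim=0)   # per-weight retention count
    return (votes >= theta).to(client_masks[0].dtype)
```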
In addition to this method, an approach based on the distinctive characteristics of each layer can be employed, with a different merging rule for selecting the lottery ticket network at each layer. The process is illustrated in Algorithm 1. This method not only safeguards specific weights, such as those in the input and output layers, but also accelerates the identification of the lottery tickets, typically within three pruning rounds. The layerwise pruning algorithm prunes aggressively: if it loops for more than five iterations, the model becomes too sparse to train, so it is appropriate to set the target density $d_{target}$ close to the density the model reaches after two or three iterations.
Algorithm 1 Layerwise Pruning
1: initialize the global mask $M$ and the target density $d_{target}$
2: while not terminate do
3:   clients do local pruning using Algorithm 2
4:   clients send masks $M^n$ to the server
5:   for each layer $l$ in the model do
6:     if $l$ is the input layer or the output layer then
7:       $M_l \leftarrow \bigcup_{n=1}^{N} M_l^n$
8:     else
9:       $M_l \leftarrow M_l^1$
10:      for n = 2 to N do
11:        $M_l \leftarrow M_l \cap M_l^n$
12:      end for
13:    end if
14:  end for
15:  $d \leftarrow$ remained nodes/all nodes
16:  if $d \leq d_{target}$ then terminate
17: end while
18: return $M$
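Under our reading of Algorithm 1, the per-layer merge could be sketched as follows (PyTorch assumed; union for the protected input/output layers, intersection elsewhere; all names are illustrative):

```python
# Sketch of Algorithm 1's per-layer mask merge (our reading): input/output
# layers keep the union of client masks, hidden layers the intersection.
import torch

def merge_layer_masks(layer_name: str,
                      client_masks: list[torch.Tensor],
                      protected_layers: set[str]) -> torch.Tensor:
    stacked = torch.stack(client_masks)      # shape: (N, *layer_shape)
    if layer_name in protected_layers:
        return stacked.amax(dim=0)           # union: kept by any client
    return stacked.amin(dim=0)               # intersection: kept by all

def density(masks: dict[str, torch.Tensor]) -> float:
    """Fraction of remaining nodes, used for the test d <= d_target."""
    kept = sum(m.sum().item() for m in masks.values())
    return kept / sum(m.numel() for m in masks.values())
```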
Upon entering the adaptive pruning stage, a small sparse model has been obtained on each client. For clarity, in the pruning process, we refer to the finest granularity of pruning as “nodes.” The deployed local pruning method combines magnitude pruning with first-order pruning. From the previous derivation (Equations (1)–(4)), it can be deduced that $\Delta L \approx -\eta \sum_i g_i^2$: the sum of squared gradients $\sum_i g_i^2$ directly governs how rapidly the loss decreases. Therefore, a three-step local pruning method is introduced, as shown in Figure 3. First, nodes whose absolute weight values fall below a threshold $\alpha$ are pruned, so nodes with larger absolute weights are initially protected. Second, the formerly pruned nodes and the newly pruned nodes together form the selection set. Third, nodes in the selection set with relatively high squared gradients are recovered. In this way, a balance is achieved between weight magnitude and squared gradients. The local adaptive pruning workflow is outlined in Algorithm 2.
Algorithm 2 Local Pruning
1: $P_{new} \leftarrow \{\, w_i \mid m_i = 1,\ |w_i| < \alpha \,\}$ (magnitude pruning)
2: $m_i \leftarrow 0$ for all $w_i \in P_{new}$
3: Selection Set $S = \{\, w_i \mid w_i$ in $P_{old} \cup P_{new} \,\}$
4: $G \leftarrow \{\, g_i^2 \mid w_i \in S \,\}$
5: $\tau \leftarrow \mathrm{mean}(G)$
6: sort $S$ by $g_i^2$ in descending order
7: for $w_i$ in $S$ do
8:   if $g_i^2 > \tau$ then
9:     $m_i \leftarrow 1$ (recover the node)
10:    $G \leftarrow G \cup \{g_i^2\}$
11:    $\tau \leftarrow \mathrm{mean}(G)$
12:  else
13:    $m_i \leftarrow 0$ (keep the node pruned)
14:  end if
15: end for
16: $w \leftarrow w \odot m$
17: return $w$, $m$
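A minimal sketch of Algorithm 2 under the reading above (PyTorch assumed; `alpha` is the magnitude threshold, and the flattening and function name are our choices):

```python
# Sketch of the three-step local pruning (Algorithm 2, our reading):
# (1) magnitude-prune weights below alpha, (2) pool newly and previously
# pruned nodes into the selection set, (3) recover those whose squared
# gradient beats the adaptive running-mean threshold.
import torch

def local_prune(weight: torch.Tensor, grad: torch.Tensor,
                mask: torch.Tensor, alpha: float) -> torch.Tensor:
    w, g = weight.flatten(), grad.flatten()
    m = mask.flatten().clone().float()
    # Step 1: magnitude-prune currently alive nodes below alpha.
    m = m * (w.abs() >= alpha).float()
    # Step 2: the selection set pools formerly and newly pruned nodes.
    sel = (m == 0).nonzero(as_tuple=True)[0]
    if sel.numel() == 0:
        return m.reshape(mask.shape)
    g2 = g[sel] ** 2
    # Step 3: recover nodes whose g^2 beats the running-mean threshold.
    order = torch.argsort(g2, descending=True)
    pool = g2.tolist()                      # G starts as every g^2 in the set
    for i in order.tolist():
        if g2[i].item() > sum(pool) / len(pool):
            m[sel[i]] = 1.0                 # recover the node
            pool.append(g2[i].item())       # re-including g^2 raises the mean
        else:
            break                           # descending scan: critical point
    return m.reshape(mask.shape)
```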
Therefore, in each subsequent pruning iteration, the adaptive pruning algorithm either removes or reintroduces nodes based on their importance. The selection is made among nodes with small absolute weights and nodes that have already been pruned. In this process, to avoid mistakenly discarding important nodes, pruned nodes with higher $g_i^2$ may be reintroduced during the iteration. This makes it easier to recover winning tickets that were mistakenly pruned along the way, mitigating any potential impact on model performance.
Because the gradient distribution of the model is unpredictable, it is difficult to fix a threshold for filtering squared gradients beforehand. It is therefore advantageous to employ an algorithm that determines a suitable threshold autonomously. In the local pruning function, the set $G$ initially equals the union of all the squared gradients of the entire selection set. As values of $g_i^2$ above the current mean are repeatedly included back into $G$, the average of $G$ becomes larger; meanwhile, the candidate $g_i^2$ values, traversed in descending order, gradually decrease until the two converge at a critical point. Through this approach, there is no need to set the threshold manually, as the function determines it automatically.
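A toy trace of this self-terminating threshold, with made-up squared gradients, shows the two sequences crossing:

```python
# Toy trace (made-up numbers): G starts as all squared gradients of the
# selection set; candidates are scanned in descending order.
pool = [9.0, 4.0, 1.0, 0.25]                 # initial mean = 3.5625
for g2 in sorted(pool, reverse=True):
    if g2 > sum(pool) / len(pool):
        print(f"recover node with g^2 = {g2}")
        pool.append(g2)                      # mean rises to 4.65
    else:
        print(f"stop at g^2 = {g2}")         # the critical point
        break
# Output: the node with g^2 = 9.0 is recovered; the scan stops at 4.0.
```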
By combining the initial pruning described above with subsequent adaptive pruning methods, the model’s size can be significantly reduced, its training speed can be accelerated, and a dynamic balance in the pruning process can be achieved after several rounds.
3.3. Federated Learning Method Based on Adaptive Pruning
In traditional federated learning approaches, the server and client models share the same structure. While this facilitates model aggregation, it sacrifices personalization. To address the need for personalized models and to cope with diverse client data, individualized pruning is applied at each client during the federated learning process, and each client submits personalized sparse parameters to the server, as illustrated in the workflow in Figure 4.
The method involves the following steps:
Step 1: After receiving the models uploaded by clients, the server aggregates the personalized parameters from each model and sends the aggregated result, the global model, to all clients. The aggregation process can be represented by Algorithm 3. Generally, $\theta$ is chosen to be around $N/2$;
Step 2: Clients receive the global model and load it into local caches. Because personalized pruning strategies are adopted, the sparse structure of each local model differs from that of the global model. To merge parameters with different sparse structures, a set of dense tensors is cached locally; these tensors are used to receive and merge the local parameters with the global model parameters, as illustrated in the following formula:

$$\tilde{w}_n = (M_n \cap M_g) \odot w_g + (M_n \setminus M_g) \odot w_n, \quad (6)$$

where $w_g$, $M_g$ are the global weights and mask and $w_n$, $M_n$ are the local ones;
With this arrangement, the client uses the updated global model for the weights included in both the global and local masks (i.e., in $M_n \cap M_g$), while retaining the previous local model for the remaining locally significant weights (in $M_n \setminus M_g$), as sketched in the code after Step 3;
Step 3: After receiving the global parameters from the server, local training is conducted using the merged tensors $\tilde{w}_n$ to obtain the new model parameters $w_n$. In a pruning round, the local pruning algorithm from Algorithm 2 is executed after local training. Usually, the pruning frequency is set to once per 50 or 100 local training epochs. After obtaining the resulting new mask $M_n$, it is multiplied with the model parameters to obtain $w_n \odot M_n$, which is then sent to the server.
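A minimal sketch of the Step 2 merge in Equation (6), assuming binary float masks in PyTorch (the function name is ours):

```python
# Sketch of the Step 2 merge (Equation (6)): the dense cached tensor takes
# global values inside the global mask and keeps the previous local values
# for locally retained weights outside it.
import torch

def merge_global_into_local(w_local: torch.Tensor, m_local: torch.Tensor,
                            w_global: torch.Tensor, m_global: torch.Tensor
                            ) -> torch.Tensor:
    use_global = m_local * m_global        # in both masks: take global value
    keep_local = m_local * (1 - m_global)  # locally significant only: keep old
    return use_global * w_global + keep_local * w_local
```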
Algorithm 3 Global Merge Algorithm
1: for n = 1 to N do
2:   $V \leftarrow V + M_n$ (per-weight retention votes; $V$, $W$ start at zero)
3:   $W \leftarrow W + M_n \odot w_n$ (accumulate retained weights)
4: end for
5: $M_g \leftarrow \mathbb{1}[V \geq \theta]$
6: $W \leftarrow W / \max(V, 1)$ (element-wise average over retaining clients)
7: $w_g \leftarrow W \odot M_g$
8: return $w_g$, $M_g$
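Under our reading of Algorithm 3, a minimal PyTorch-style sketch of the server-side merge (the function name and float-mask convention are ours):

```python
# Sketch of Algorithm 3 (our reading): vote on each weight, average it over
# the clients that retained it, and threshold the votes.
import torch

def global_merge(client_weights: list[torch.Tensor],
                 client_masks: list[torch.Tensor],
                 theta: int) -> tuple[torch.Tensor, torch.Tensor]:
    votes = torch.stack(client_masks).float().sum(dim=0)
    summed = torch.stack([m * w for m, w in
                          zip(client_masks, client_weights)]).sum(dim=0)
    m_global = (votes >= theta).float()
    w_global = summed / votes.clamp(min=1.0)   # mean over retaining clients
    return w_global * m_global, m_global
```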
This approach enables personalized pruning in federated learning, with a different sparse model structure on each client. While it achieves the flexibility to aggregate model parameters of varying sparsity patterns, it also increases the computational requirements on local devices.