Article

Generalized Federated Learning via Gradient Norm-Aware Minimization and Control Variables

College of Systems Engineering, National University of Defense Technology, Changsha 410073, China
* Author to whom correspondence should be addressed.
Mathematics 2024, 12(17), 2644; https://doi.org/10.3390/math12172644
Submission received: 24 July 2024 / Revised: 22 August 2024 / Accepted: 23 August 2024 / Published: 26 August 2024

Abstract
Federated Learning (FL) is a promising distributed machine learning framework that emphasizes privacy protection. However, inconsistencies between local optimization objectives and the global objective, commonly referred to as client drift, arise primarily from non-independently and identically distributed (Non-IID) data, multiple local training steps, and partial client participation in training. Most current research tackling this challenge is based on the empirical risk minimization (ERM) principle and gives little consideration to the connection between the global loss landscape and generalization capability. This study proposes FedGAM, an innovative FL algorithm that incorporates Gradient Norm-Aware Minimization (GAM) to efficiently search for a locally flat landscape. FedGAM modifies the client model training objective to simultaneously minimize the loss value and first-order flatness, thereby seeking flat minima. To directly smooth the global loss landscape, we further propose FedGAM-CV, which employs control variables to correct local updates, guiding each client to train its model in a globally flat direction. Experiments on three datasets (CIFAR-10, MNIST, and FashionMNIST) demonstrate that the proposed algorithms outperform existing FL baselines, effectively finding flat minima and addressing the client drift problem.

1. Introduction

An extensive amount of valuable data is generated daily by wearable sensors, smartphones, autonomous vehicles, and an increasing number of Internet of Things (IoT) devices [1,2,3,4,5]. Service providers in domains such as computer vision, natural language processing, and speech recognition benefit from such large-scale data, as deep learning techniques, which require substantial amounts of data, can be employed to enhance AI-based products and user experiences [6,7,8]. However, centralizing all user data on a single server is strongly discouraged and limited due to growing concerns about data privacy [9,10]. In response, a decentralized machine learning framework called Federated Learning [11] (FL) has emerged, allowing multiple training participants (clients) to collaboratively train a global model while preserving privacy. The core idea of FL is to enable collaborative training among multiple clients without sharing the original data of clients by exchanging model parameters. This model training method, predicated on privacy protection, paves the way for the widespread adoption and practical application of Artificial Intelligence (AI).
However, FL still faces the problem of client drift during its application [12,13,14,15]. This issue manifests as a divergence between the local updates implemented on the client and the global update, indicating a lack of consistency between the local objectives and the global objective. There are several main causes contributing to this problem. First, the distribution of data among clients is not independent and identically distributed (Non-IID). Second, an excessive number of local epochs leads to overfitting of the local model. Third, unstable communication results in only a limited number of clients participating in a training round. In order to address the issue of client drift, current prevalent approaches can be classified into the following two categories: (1) approaches that rely on local target regularization [16,17,18,19,20] and (2) methods that rely on model aggregation optimization [21,22,23,24,25,26]. Nevertheless, the aforementioned methods are founded upon the empirical risk minimization (ERM) principle [27], which is a fundamental principle in machine learning aiming to find the optimal model parameters by minimizing the average loss on the training data. In FL, the ERM principle is also widely applicable, and the majority of current FL algorithms are based on ERM. These algorithms focus solely on minimizing the loss value during model training, indicating a need for further research and improvement.
The correlation between a model’s generalization capability and flat minima in the loss landscape has been demonstrated from both theoretical and empirical perspectives [28,29,30,31]. However, current FL algorithms seldom investigate the correlation between these two factors. The Gradient Norm-Aware Minimization [32] (GAM) algorithm proposed in recent years aims to find flat minima. GAM defines the maximum gradient norm within the neighborhood of model parameters as first-order flatness. By simultaneously minimizing the loss value and first-order flatness, GAM flattens the loss surface, thereby reducing the model’s generalization error. Therefore, incorporating GAM into FL can enhance the generalization capability of FL models.
To address the issue of client drift, we propose FedGAM, which shifts the objective of client model training to simultaneously minimize the loss value and first-order flatness. FedGAM enhances the flatness of the client model’s loss landscape, thereby improving the performance of the global model by enhancing the generalization capability of local models.
Although FedGAM has the potential to enhance the generalization capability of the global model to some extent, it does not directly contribute to the flatness of the loss landscape of the global model. The mean of the client’s flat minima may still be located in sharp regions. To steer the entire FL training process towards a globally flat direction, we propose a supplementary algorithm, FedGAM-CV, based on FedGAM. FedGAM-CV introduces control variables during the client training process to correct the local optimization objectives, ensuring that each client model updates in the direction of the global optimization objective. This approach enhances the generalization capability of the global model by improving its flatness.
Our main contributions are summarized as follows:
  • We address one of the most troublesome FL challenges, i.e., client drift. To enhance the generalization capability of the global model, we first propose FedGAM, which shifts the objective of client model training to simultaneously minimize the loss value and first-order flatness, aiming to find flat minima.
  • To achieve direct smoothing of the global model, we propose FedGAM-CV based on FedGAM. FedGAM-CV leverages control variable techniques to better align the updates from each client, guiding them towards a common global flat region.
  • We conduct extensive experiments to demonstrate the superiority of FedGAM and FedGAM-CV. The results of these experiments demonstrate that FedGAM and FedGAM-CV not only achieve stronger generalization performance than baseline algorithms but also exhibit robustness against causes of client drift in various settings.
The rest of this paper is organized as follows. We review the related literature in Section 2. Next, in Section 3, we introduce the preliminary knowledge and provide a detailed explanation of client drift. Subsequently, in Section 4, we propose FedGAM and FedGAM-CV. We compare the proposed algorithms with several FL baseline algorithms through extensive experiments in Section 5. Finally, we conclude the article in Section 6.

2. Related Works

This section provides an overview of research related to client drift in FL, focusing on the following three main approaches: regularization of the local target, model aggregation, and flatness-based FL.

2.1. Regularization of the Local Target

The core idea behind regularization strategies is to reduce the deviation between client objectives and the global objective, thereby enhancing the performance of the global model. FedProx [16] incorporates an additional L2 regularization term into the objective function of the client model to limit the discrepancy between the client and global models. FedDyn [18] introduces a strategy for dynamically updating the regularization term to align local objectives as closely as possible with the global objective. SCAFFOLD [19] introduces control variables on both the server and client sides to estimate the optimization directions of the global and client models, respectively, and uses the difference between them to correct local updates. MOON [20] adopts the concept of contrastive learning, using the similarity between model representations to correct the model training of each client. This method not only reduces the gap between the global model and local model output vectors but also increases the difference in client model outputs across consecutive training rounds, further optimizing the model’s generalization capability. DRAG [17] introduces a “divergence” metric to quantify the deviation between local updates and historical global updates and dynamically adjusts strategies to “drag” local updates towards the global optimization objective.

2.2. Model Aggregation

Methods based on model aggregation optimization innovate and improve upon the traditional simple averaging approach for local updates. FedNova [26] normalizes and scales local updates based on the number of local iterations of each client before updating the global model. FedBE [23] employs Bayesian model ensembling for model aggregation, achieving more accurate model precision in a single round of communication than traditional weighted averaging methods. FedLAW [21] suggests that the sum of the aggregation weights of local models should be less than 1. This design induces a global weight shrinkage effect, which can improve the model’s generalization capability by effectively regularizing the aggregated model, thus preventing overfitting to any particular client’s data. Fed-CBS [22] introduces a heterogeneity-aware client sampling mechanism that effectively mitigates the issue of class imbalance among sampled clients. By considering the heterogeneity in the data distribution across clients, this mechanism ensures a more representative and balanced selection of clients for training, which can lead to improved model performance and fairness.

2.3. Flatness-Based Federated Learning

Numerous studies have demonstrated a close relationship between a model’s generalization capability and the flatness of the loss landscape [33,34,35,36]. Foret et al. proposed the Sharpness-Aware Minimization [28] (SAM) algorithm, which introduces sharpness-aware optimization techniques to find flat minima in the model’s loss landscape, thereby enhancing the model’s generalization capability. FedSAM [37,38] replaces the local optimizer with SAM, making the client’s optimization objective to find flat minima, thereby improving the generalization capability of the local model, which indirectly enhances the performance of the global model. Building on FedSAM, FedSpeed [39] and FedSMOO [40] point out that FedSAM cannot directly optimize the global model and introduce dynamic regularization terms to further mitigate the client drift problem. However, since the FedSAM algorithm ignores the model’s maximum gradient norm, it still exhibits significant generalization errors. Zhang et al. [32] found that SAM has certain limitations when multiple minima exist within the parameter neighborhood. To overcome this issue, they proposed GAM, which defines the maximum gradient norm within the perturbation radius as first-order flatness, thereby achieving superior generalization performance. Despite the significant potential of flatness in enhancing model generalization capability, research on integrating it with FL model training remains relatively scarce.

3. Preliminaries

3.1. General Federated Learning

An FL system is composed of N clients and a server, where each client $i \in [N]$ possesses local training data $(X_i, Y_i) \in D_i$. In the t-th round of communication, a subset $S_t$ of clients with $|S_t| = K$ is formed by randomly selecting K clients from the N clients. Consider the global model $w \in \mathbb{R}^d$, the loss function $L(\cdot\,;\cdot)$, the size $n_i$ of the random sample taken from client i's training data, and the aggregation weight $p_i$ for client i, which must satisfy Equation (1). The main symbols used in this paper are defined in Table 1.
$\sum_{i=1}^{N} p_i = 1.$  (1)
In FL, the objective is to train a global model w through the collaboration of multiple clients, without the need to centralize data on a single server. Each client i possesses a local dataset $D_i$ and updates its local model based on ERM. The empirical risk for client i is defined as
$f_i(w) \triangleq \frac{1}{n_i} \sum_{j=1}^{n_i} L\left(w; x_j, y_j\right).$  (2)
The objective of the global model is to minimize the weighted sum of the empirical risks across all clients, which is expressed as
$\min_{w} F(w) = \frac{N}{K} \sum_{i \in S_t} p_i f_i(w).$  (3)
Currently, almost all FL algorithms conform to a two-stage training framework. In this framework [41,42,43], local clients are responsible for training their respective models, while the central server employs an aggregation algorithm to combine the models of the clients and generate an updated global model.
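To make this two-stage framework concrete, the following is a minimal PyTorch-style sketch of one communication round. It is our own illustration rather than the paper's implementation: the function names, the `local_train` callback, and the plain parameter averaging are assumptions.

```python
import copy
import random
import torch

def fl_round(global_model, client_loaders, K, local_train):
    """One communication round of the generic two-stage FL framework:
    stage 1 trains local copies on K sampled clients, stage 2 averages them."""
    sampled = random.sample(client_loaders, K)          # partial participation
    local_models = []
    for loader in sampled:
        local = copy.deepcopy(global_model)             # client starts from the global model
        local_train(local, loader)                      # any local optimizer (ERM, GAM, ...)
        local_models.append(local)
    with torch.no_grad():                               # server-side aggregation
        for idx, p in enumerate(global_model.parameters()):
            stacked = torch.stack([list(m.parameters())[idx].detach()
                                   for m in local_models])
            p.copy_(stacked.mean(dim=0))
    return global_model
```

The algorithms proposed later in this paper differ only in what `local_train` does on the client and in whether additional correction terms are exchanged alongside the model.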

3.2. Analysis of the Causes of Client Drift

In FL, local minima and global minima are defined as follows:
  • Local minima: For each client i, the local minimum $w_i^*$ satisfies the following condition:
    $\nabla f_i(w_i^*) = 0, \quad \nabla^2 f_i(w_i^*) \succeq 0.$  (4)
  • Global minima: The global minimum $w^*$ satisfies the following condition:
    $\nabla F(w^*) = 0, \quad \nabla^2 F(w^*) \succeq 0.$  (5)
Client drift refers to the phenomenon in FL where the global model update direction deviates due to inconsistencies between the optimization objectives of individual clients and the global objective. The main causes include the following:
  • The Non-IID nature of client data distributions: Different clients have varying data distributions, resulting in each client’s loss function having a different minimum ( w i * ), which leads to inconsistent local update directions.
  • Multiple local training steps: Before each round of global aggregation, clients perform multiple local training steps, causing the local model ( w i ) to overfit the local data and deviate from the global optimal solution ( w * ).
  • Partial client participation in training: Only a subset of clients participates in each training round, leading to a lack of representativeness in the global model updates, which further exacerbates client drift.

3.3. Zeroth-Order Flatness and First-Order Flatness

In machine learning and optimization, flat minima and sharp minima are concepts used to describe the shape of the loss function near its minimum. A flat minimum refers to a region near the minimum of the loss function where the loss value changes little over a larger range. Mathematically, flatness can be described by the maximum eigenvalue of the Hessian matrix, $\lambda_{\max}(\nabla^2 F(w^*))$. If $\lambda_{\max}(\nabla^2 F(w^*))$ is small near the minimum $w^*$, the minimum is considered flat; otherwise, it is sharp. The primary distinction between the two is that models corresponding to flat minima are less sensitive to small perturbations in the input data and exhibit better generalization capability, performing well on unseen data. To enhance the generalization capability of models, the concepts of zeroth-order flatness and first-order flatness have been proposed in recent years.
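As an aside, this flat/sharp criterion can be probed numerically. The sketch below is our own illustration (not part of the paper's method) of estimating $\lambda_{\max}$ by power iteration on Hessian-vector products in PyTorch; `loss` is assumed to have been computed from `params` with the autograd graph intact.

```python
import torch

def hessian_max_eigenvalue(loss, params, iters=20):
    """Estimate the largest Hessian eigenvalue of `loss` w.r.t. `params`
    via power iteration on Hessian-vector products (illustrative only)."""
    grads = torch.autograd.grad(loss, params, create_graph=True)
    v = [torch.randn_like(p) for p in params]            # random starting direction
    eig = torch.tensor(0.0)
    for _ in range(iters):
        norm = torch.sqrt(sum((vi ** 2).sum() for vi in v))
        v = [vi / norm for vi in v]
        hv = torch.autograd.grad(grads, params, grad_outputs=v,
                                 retain_graph=True)       # Hessian-vector product H v
        eig = sum((hvi * vi).sum() for hvi, vi in zip(hv, v))  # Rayleigh quotient v^T H v
        v = [hvi.detach() for hvi in hv]
    return float(eig)
```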

3.3.1. Zeroth-Order Flatness

The most widely accepted mathematical definitions of flatness consider the maximum loss value within a given radius, referred to as zeroth-order flatness. In this study, we adopt the loss function proposed in SAM [28]:
$f^{SAM}(w) = f(w) + \max_{w' \in B(w, \rho)} f(w') - f(w),$  (6)
where $\rho$ represents the perturbation radius, which determines the extent of the neighborhood; $B(w, \rho)$ represents a high-dimensional open ball centered at $w$ with a radius of $\rho$; and the calculation method for $f(\cdot)$ is shown in Equation (2). The second term on the right-hand side of Equation (6) can be interpreted as a measure of zeroth-order flatness.
Definition 1
(Zeroth-order flatness). $B(w, \rho)$ represents a high-dimensional open ball centered at $w$ with a radius of $\rho$. For any perturbation radius $\rho > 0$, the zeroth-order flatness $G_\rho^{(0)}(w)$ of loss function $f(w)$ at a point $w$ is defined as follows:
$G_\rho^{(0)}(w) \triangleq \max_{w' \in B(w, \rho)} f(w') - f(w).$  (7)
From Equation (6), it is evident that incorporating zeroth-order flatness ensures that the gap between the maximum loss value $f(w')$ in the neighborhood and the loss value at the current point $f(w)$ is minimized. To some extent, this prevents the occurrence of sharp minima and enhances the flatness of the loss landscape.

3.3.2. First-Order Flatness

A recent study [32] that proposed GAM analyzed the limitations of zeroth-order flatness and introduced first-order flatness, which represents the maximum gradient norm in the vicinity of the current point.
Definition 2
(First-order flatness). $B(w, \rho)$ represents a high-dimensional open ball centered at $w$ with a radius of $\rho$. For any perturbation radius $\rho > 0$, the first-order flatness $G_\rho^{(1)}(w)$ of loss function $f(w)$ at a point $w$ is defined as follows:
$G_\rho^{(1)}(w) \triangleq \rho \cdot \max_{w' \in B(w, \rho)} \left\| \nabla f(w') \right\|.$  (8)
The objective function of GAM is expressed as follows:
$\min_{w} f^{GAM}(w) = f(w) + \alpha \cdot G_\rho^{(1)}(w).$  (9)
Intuitively, GAM helps to find a minimum with a smaller maximum gradient norm within a parameter neighborhood by simultaneously minimizing the loss value and first-order flatness. Since the maximum gradient norm represents the extent of change in the loss value within the parameter neighborhood, constraining this value keeps the loss function $f(w)$ relatively constant within the neighborhood.

3.3.3. Comparison with Zeroth-Order Flatness

We compare first-order flatness with zeroth-order flatness and demonstrate that zeroth-order flatness $G_\rho^{(0)}(w)$ is bounded by first-order flatness $G_\rho^{(1)}(w)$. Let $\epsilon^* = \arg\max_{\epsilon \in B(0, \rho)} f(w + \epsilon)$, where $0 \in \mathbb{R}^d$ denotes the zero vector, so that $G_\rho^{(0)}(w) = f(w + \epsilon^*) - f(w)$. According to the mean value theorem, there exists a constant $0 \le c \le 1$ such that
$f(w + \epsilon^*) - f(w) = \nabla f(w + c \cdot \epsilon^*)^\top \epsilon^*.$  (10)
As a result, according to the Cauchy–Schwarz inequality,
$G_\rho^{(0)}(w) = f(w + \epsilon^*) - f(w) = \nabla f(w + c \cdot \epsilon^*)^\top \epsilon^* \le \left\| \nabla f(w + c \cdot \epsilon^*) \right\| \cdot \left\| \epsilon^* \right\| \le \max_{\epsilon \in B(0, \rho)} \left\| \nabla f(w + \epsilon) \right\| \cdot \rho = G_\rho^{(1)}(w).$  (11)
Therefore, a smaller $G_\rho^{(1)}(w)$ results in a smaller $G_\rho^{(0)}(w)$, indicating that $G_\rho^{(1)}(w)$ is a stronger measure of flatness than $G_\rho^{(0)}(w)$. This also indicates that first-order flatness has a broader range of applications compared to zeroth-order flatness.
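As a small worked example of this bound (our own illustration, not taken from [32]), consider a one-dimensional quadratic loss:

```latex
f(w) = \tfrac{\lambda}{2} w^2, \qquad w^* = 0, \qquad \lambda > 0,\\[4pt]
G_\rho^{(0)}(w^*) = \max_{|\epsilon| \le \rho} f(\epsilon) - f(0) = \tfrac{1}{2}\lambda\rho^2, \qquad
G_\rho^{(1)}(w^*) = \rho \cdot \max_{|\epsilon| \le \rho} |f'(\epsilon)| = \lambda\rho^2,\\[4pt]
\text{so that } G_\rho^{(0)}(w^*) = \tfrac{1}{2}\, G_\rho^{(1)}(w^*) \le G_\rho^{(1)}(w^*).
```

Both quantities grow linearly with the curvature $\lambda$, and driving the first-order measure down also controls the zeroth-order one, consistent with Equation (11).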

4. The Proposed Algorithms

In this section, we first introduce FedGAM and derive the gradient update formula. Then, we propose FedGAM-CV and analyze the effectiveness of the control variate technique.

4.1. FedGAM: Federated Learning Based on Gradient Norm-Aware Minimization

To better enhance the generalization capability of client models and thereby improve the performance of the global model, we propose FedGAM. Unlike traditional training methods based on ERM, FedGAM focuses on simultaneously minimizing the loss value and first-order flatness. This ensures that the loss function $f_i(w)$ remains relatively stable within the neighborhood of w, resulting in a flatter loss landscape and improved generalization capability of the final trained model. The detailed algorithmic process of FedGAM is illustrated in Algorithm 1, and the objective function of FedGAM is expressed as follows:
$\min_{w} F(w) = \frac{N}{K} \sum_{i \in S_t} p_i f_i^{GAM}(w),$  (12)
$f_i^{GAM}(w) = f_i(w) + \alpha \cdot G_\rho(w),$  (13)
where $G_\rho(w) = G_\rho^{(1)}(w) = \rho \cdot \max_{w' \in B(w, \rho)} \left\| \nabla f_i(w') \right\|$, the calculation of $f_i(w)$ is shown in Equation (2), and $\alpha$ is the trade-off coefficient.
Algorithm 1 FedGAM
Input: Initial server model w; learning rate η; perturbation radius ρ; trade-off coefficient α
Output: Updated global model w
1:  for each round t = 0 to T − 1 do
2:      Sample a subset S_t ⊆ [N] of clients
3:      Communicate w to all clients i ∈ S_t
4:      for each client i ∈ S_t in parallel do
5:          w_{i,0} = w
6:          for local epoch k = 0 to E − 1 do
7:              Sample a minibatch ξ_{j,k}
8:              w_{i,k+1} = w_{i,k} − η ∇f_i^{GAM}(w_{i,k})    ▹ compute ∇f_i^{GAM}(w_{i,k}) as in Equation (19)
9:          end for
10:         Communicate w_{i,E} to the server
11:     end for
12:     w = (1/K) Σ_{i ∈ S_t} w_{i,E}
13: end for
During the model training optimization process, it is necessary to frequently compute the gradient of the loss function $f_i^{GAM}(w)$, which also requires the calculation of the gradient of the first-order flatness $G_\rho(w)$. However, directly computing $\nabla G_\rho(w)$ is extremely challenging. To compute the gradient of $f_i^{GAM}(w)$, referencing the estimation method in SAM [28], we have the following:
$\nabla G_\rho(w) = \rho \cdot \nabla_w \max_{\mu \in B(0, \rho)} \left\| \nabla f(w + \mu) \right\|.$  (14)
By applying a first-order Taylor expansion to $f(w + \mu)$, we obtain the following:
$\mu^*(w) = \arg\max_{\mu \in B(0, \rho)} f(w + \mu) \approx \arg\max_{\mu \in B(0, \rho)} \left[ f(w) + (\nabla f(w))^\top \mu \right] = \arg\max_{\mu \in B(0, \rho)} (\nabla f(w))^\top \mu = \rho \cdot \frac{G}{\| G \|},$  (15)
where $G = \nabla f(w)$. Letting
$w^{gam} = w + \mu^*(w),$  (16)
we obtain
$\nabla G_\rho(w) \approx \rho \cdot \nabla_w \left\| \nabla f(w + \mu^*(w)) \right\| = \rho \cdot \left( \nabla \left\| \nabla f(w^{gam}) \right\| + \frac{d \mu^*(w)}{d w} \cdot \nabla \left\| \nabla f(w^{gam}) \right\| \right).$  (17)
Upon eliminating the higher-order terms, we derive the following:
$\nabla G_\rho(w) \approx \rho \cdot \nabla \left\| \nabla f(w^{gam}) \right\|.$  (18)
Ultimately, when computing the gradient of $f_i^{GAM}(w)$, we have the following:
$\nabla f_i^{GAM}(w) = \nabla f_i(w) + \alpha \cdot \nabla G_\rho(w) \approx \nabla f_i(w) + \alpha \cdot \rho \cdot \nabla \left\| \nabla f_i\!\left( w + \rho \cdot \frac{\nabla f_i(w)}{\| \nabla f_i(w) \|} \right) \right\|.$  (19)
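For concreteness, the following PyTorch sketch shows one way Equation (19) could be evaluated on a client minibatch. It is our own illustration under stated assumptions (the function name `fedgam_grad` and its arguments are ours), not the paper's released implementation.

```python
import torch

def fedgam_grad(model, loss_fn, x, y, rho=0.02, alpha=0.2):
    """Approximate the FedGAM gradient of Eq. (19):
    grad f_i(w) + alpha * rho * grad || grad f_i(w + rho * g / ||g||) ||."""
    params = [p for p in model.parameters() if p.requires_grad]

    # 1) plain gradient g = grad f_i(w) at the current point w
    loss = loss_fn(model(x), y)
    g = torch.autograd.grad(loss, params)
    g_norm = torch.sqrt(sum((gi ** 2).sum() for gi in g)) + 1e-12

    # 2) perturb to w_gam = w + mu*(w) = w + rho * g / ||g||  (Eqs. (15)-(16))
    with torch.no_grad():
        for p, gi in zip(params, g):
            p.add_(rho * gi / g_norm)

    # 3) gradient of the gradient norm at w_gam (second-order graph), per Eq. (18)
    loss_gam = loss_fn(model(x), y)
    g_gam = torch.autograd.grad(loss_gam, params, create_graph=True)
    gam_norm = torch.sqrt(sum((gi ** 2).sum() for gi in g_gam) + 1e-12)
    flat_grad = torch.autograd.grad(gam_norm, params)

    # 4) restore w and combine the two terms of Eq. (19)
    with torch.no_grad():
        for p, gi in zip(params, g):
            p.sub_(rho * gi / g_norm)

    return [gi + alpha * rho * fi for gi, fi in zip(g, flat_grad)]
```

The client step in Algorithm 1 (line 8) would then subtract η times this combined gradient from the current parameters.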

4.2. FedGAM-CV: Federated Learning Based on Gradient Norm-Aware Minimization and Control Variables

FedGAM improves the flatness of the client model’s loss landscape, thereby boosting the local model’s generalization ability. However, FedGAM fails to take into consideration the flatness of the global model. In cases of high data heterogeneity, the average of local models’ flat minima may still fall into a sharp loss landscape, manifesting as a deviation between client models and the global model. To address this issue, we propose FedGAM-CV. FedGAM-CV adjusts the optimization objective of each client through the use of control variables, guiding the local optimization process towards directions more aligned with the global objective, thus steering the entire training process towards global flatness.

4.2.1. Method

FedGAM-CV sets control variables c and $c_i$ for the server and client i, respectively, to estimate the optimization directions of the global model w and the client model $w_i$. These control variables serve to correct the optimization directions of both the client and the server. At the beginning of training, the control variables c and $c_i$ are initialized to $0 \in \mathbb{R}^d$.
In each communication round, as shown in Algorithm 2 and Figure 1, the server sends the global model w and the control variable c to the selected clients, and the clients perform local training based on GAM. Unlike FedGAM, which updates the model with $w_{i,k} - \eta \nabla f_i^{GAM}(w_{i,k})$, FedGAM-CV adds a control term $(c - c_i)$ to align local objectives with the global objective, updating the model parameters as shown in Equation (20), thereby adjusting the optimization direction of the client model.
$w_{i,k+1} = w_{i,k} - \eta \left( \nabla f_i^{GAM}(w_{i,k}) - c_i + c \right).$  (20)
Algorithm 2 FedGAM-CV
Input: Initial server model w; learning rate η; perturbation radius ρ; trade-off coefficient α; initial control variables c, c_i
Output: Updated global model w
1:  for each round t = 0 to T − 1 do
2:      Sample a subset S_t ⊆ [N] of clients
3:      Communicate (w, c) to all clients i ∈ S_t
4:      for each client i ∈ S_t in parallel do
5:          w_{i,0} = w
6:          for local epoch k = 0 to E − 1 do
7:              Sample a minibatch ξ_{j,k}
8:              w_{i,k+1} = w_{i,k} − η (∇f_i^{GAM}(w_{i,k}) − c_i + c)    ▹ compute ∇f_i^{GAM}(w_{i,k}) as in Equation (19)
9:          end for
10:         c_i = c_i − c + (1/(ηE)) (w_{i,0} − w_{i,E})
11:         Δc_i = (1/(ηE)) (w_{i,0} − w_{i,E}) − c
12:         Communicate (w_{i,E}, Δc_i) to the server
13:     end for
14:     c = c + (1/K) Σ_{i ∈ S_t} Δc_i
15:     w = (1/K) Σ_{i ∈ S_t} w_{i,E}
16: end for
After completing E steps of local training, the client model parameters are updated, and the corresponding client control variable $c_i$ is also updated according to Equation (21). The client then uploads the updated parameters $w_{i,E}$ and the change $\Delta c_i$ in the client control variable to the server, which uses the model averaging method to update the global model w and updates the global control variable c using Equation (22).
$c_i = c_i - c + \frac{1}{\eta E} \left( w_{i,0} - w_{i,E} \right),$  (21)
$c = c + \frac{1}{K} \sum_{i \in S_t} \Delta c_i.$  (22)
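The following sketch (ours; the function names and the per-step minibatch handling are assumptions) shows how Equations (20)–(22) could fit together on the client and server sides, reusing the `fedgam_grad` routine sketched in Section 4.1. The control variables c and c_i are represented as lists of tensors shaped like the model parameters.

```python
import copy
import torch

def fedgamcv_client_round(model, c, c_i, loader, loss_fn,
                          eta=0.01, rho=0.02, alpha=0.2, E=15):
    """One FedGAM-CV client round: corrected local updates (Eq. (20)),
    then the control-variable update (Eq. (21)) and its change delta_c_i."""
    w0 = [p.detach().clone() for p in model.parameters()]
    for _, (x, y) in zip(range(E), loader):               # E local steps, one minibatch each
        grads = fedgam_grad(model, loss_fn, x, y, rho, alpha)
        with torch.no_grad():
            for p, g, ci_j, c_j in zip(model.parameters(), grads, c_i, c):
                p.sub_(eta * (g - ci_j + c_j))             # Eq. (20)
    with torch.no_grad():
        drift = [(p0 - p.detach()) / (eta * E)
                 for p0, p in zip(w0, model.parameters())]
        new_c_i = [ci_j - c_j + d_j                        # Eq. (21)
                   for ci_j, c_j, d_j in zip(c_i, c, drift)]
        delta_c_i = [d_j - c_j for d_j, c_j in zip(drift, c)]
    return model, new_c_i, delta_c_i                       # upload w_{i,E} and delta_c_i

def fedgamcv_server_update(client_models, client_deltas, c):
    """Server side: average the K client models and apply Eq. (22) to c."""
    K = len(client_models)
    global_model = copy.deepcopy(client_models[0])
    with torch.no_grad():
        for idx, p in enumerate(global_model.parameters()):
            p.copy_(sum(list(m.parameters())[idx].detach()
                        for m in client_models) / K)
        new_c = [c_j + sum(d[j] for d in client_deltas) / K
                 for j, c_j in enumerate(c)]
    return global_model, new_c
```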

4.2.2. Validity Analyses of Control Variables

In order to enhance the generalization capabilities of a client model on the global dataset, it is recommended that client i incorporate gradient information from all clients [44] throughout the gradient descent process based on GAM. Hence, in an ideal scenario, the model update method for client i should be the following:
$w_i = w_i - \eta \cdot \frac{1}{K} \sum_{m \in S_t} \nabla f_m^{GAM}(w_m).$  (23)
The update method described above can theoretically achieve the training effects of centralized training or training under independent and identically distributed data. However, this would require clients to exchange messages with all other clients in each communication round, leading to high communication costs. FedGAM-CV uses control variables to reduce the differences between client models. Based on Equations (20)–(22), we can derive the following conclusion:
$c_i \approx \nabla f_i^{GAM}(w_i),$  (24)
$c \approx \frac{1}{K} \sum_{m \in S_t} \nabla f_m^{GAM}(w_m).$  (25)
Therefore, the formula for updating client model parameters in FedGAM-CV can be approximated as follows:
$w_{i,k} - \eta \left( \nabla f_i^{GAM}(w_{i,k}) - c_i + c \right) \approx w_{i,k} - \eta \cdot \frac{1}{K} \sum_{m \in S_t} \nabla f_m^{GAM}(w_{m,k}).$  (26)
Indeed, by employing control variable techniques, the local update process of client i in FedGAM-CV can be considered an operation that integrates the gradient information from all clients participating in the training. As the proportion of clients participating in training increases, the aggregated gradient information from the participating clients approximates the gradient information of the global dataset, which represents the global optimization objective. Consequently, FedGAM-CV enables each client to train and update its model towards the direction of the global objective.

5. Experiments

We conduct extensive quantitative experiments to demonstrate the effectiveness of the proposed FedGAM and FedGAM-CV. Initially, we detail the experimental setup, followed by a performance comparison of the proposed algorithms against seven baseline algorithms on the FashionMNIST [45], CIFAR-10 [46], and MNIST [47] datasets. Secondly, we delve deeper into our study, finding that FedGAM-CV outperforms the baseline algorithms across three distinct client drift scenarios. Thirdly, we visualize the loss landscapes across different algorithms. Moreover, we conduct hyperparameter sensitivity experiments to identify suitable parameters. The code is available at https://github.com/xyccjd/FedGAM/tree/master, accessed on 18 March 2024.

5.1. Experimental Setup

(1) Datasets: To evaluate the performance of FedGAM and FedGAM-CV, we conduct experiments on the FashionMNIST, MNIST, and CIFAR-10 datasets. We utilize IID partitioning and Dirichlet partitioning as data partitioning methods. IID indicates that training samples are uniformly distributed across all clients. Similar to existing studies [14,44], we partition the data based on the Dirichlet distribution of label ratios (Dir( α ) [48], where a smaller α indicates stronger data heterogeneity); a minimal partitioning sketch is given at the end of this subsection. To demonstrate the performance of the proposed algorithms in heterogeneous data environments, we establish the following three data partitioning scenarios: IID, Dir(0.3), and Dir(0.7). Examples of client data partitioning results are illustrated in Figure 2, where, at α = 0.7, most of a client's category cells are nearly white, with only two or three categories shown in green. This indicates that the client possesses a large amount of data for two or three categories while having very little data for the other categories.
(2) Baselines: We compare FedGAM and FedGAM-CV against seven FL algorithms. Table 2 summarizes the main contributions of these seven algorithms.
(3) Metrics: Following the settings of previous studies [14,37,38], we use accuracy as the metric on both the federated training dataset and the global test dataset. Test accuracy demonstrates the generalization ability of the learned global model on the global test dataset. The gap between training and test accuracy represents the generalization gap of each algorithm, and a larger gap suggests a higher degree of overfitting in the algorithm’s models. For a fair comparison, all algorithms are run for the same number of rounds (set to 100 by default unless specified otherwise).
(4) Other settings: We implement FL algorithms using the PyTorch framework and conduct experiments on a GeForce RTX 4090. For FashionMNIST and MNIST, we utilize a CNN architecture comprising two 5 × 5 convolutional layers followed by 2 × 2 max pooling layers and two fully connected layers with ReLU activation, and for CIFAR-10, we employ ResNet-18 [49]. The number of clients is set to 100, with the local epoch set to 15, and the default clients’ participation rate is 1.0. The learning rate for all algorithms is set to 0.01. The regularization term weight for FedProx is 0.1. The temperature parameter for the contrastive loss in MOON is set to 0.5. The perturbation radius for FedSAM is 0.1. The parameter settings for FedSMOO are the same as those in [40]. For FedGAM and FedGAM-CV, the perturbation radius is set to 0.02, with a trade-off coefficient ( α ) of 0.2.
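For reference, a PyTorch sketch of a CNN matching the description in (4) is given below; the channel widths (32, 64) and the hidden size (512) are our assumptions, since the text does not specify them.

```python
import torch.nn as nn

class SmallCNN(nn.Module):
    """Two 5x5 conv layers, each followed by 2x2 max pooling, then two
    fully connected layers with ReLU, for 28x28 grayscale inputs."""
    def __init__(self, num_classes=10):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=5, padding=2), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(32, 64, kernel_size=5, padding=2), nn.ReLU(), nn.MaxPool2d(2),
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(64 * 7 * 7, 512), nn.ReLU(),   # 28 -> 14 -> 7 after two poolings
            nn.Linear(512, num_classes),
        )

    def forward(self, x):
        return self.classifier(self.features(x))
```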
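The label-based Dirichlet partitioning mentioned in (1) can be sketched as follows; this is our own illustration in the spirit of [48], and the function and variable names are ours.

```python
import numpy as np

def dirichlet_partition(labels, num_clients=100, alpha=0.3, seed=0):
    """Split sample indices across clients with per-class Dirichlet proportions;
    a smaller alpha gives more heterogeneous (Non-IID) client label distributions."""
    rng = np.random.default_rng(seed)
    labels = np.asarray(labels)
    client_parts = [[] for _ in range(num_clients)]
    for cls in np.unique(labels):
        cls_idx = rng.permutation(np.where(labels == cls)[0])
        props = rng.dirichlet(alpha * np.ones(num_clients))   # class share per client
        cuts = (np.cumsum(props)[:-1] * len(cls_idx)).astype(int)
        for cid, part in enumerate(np.split(cls_idx, cuts)):
            client_parts[cid].append(part)
    return [np.concatenate(parts) for parts in client_parts]  # index set per client
```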

5.2. Overall Performance Comparison

In Table 3, we report a comparison of training and test accuracy between seven baseline algorithms (FedAvg, FedProx, FedNova, MOON, SCAFFOLD, FedSAM, and FedSMOO) and our proposed FedGAM and FedGAM-CV across three datasets (MNIST, FashionMNIST, and CIFAR-10) under three data partitioning strategies (IID, Dir(0.3), and Dir(0.7)). To visually demonstrate the training process of each algorithm, we plot the test accuracy curves of different algorithms in Figure 3.
From the perspective of test accuracy, our proposed FedGAM-CV outperforms the other algorithms. Under different data partitioning strategies, FedGAM-CV achieves the highest test accuracy across all three datasets, indicating that the enhanced generalization capability brought by FedGAM-CV is independent of the dataset and data partitioning strategy. Additionally, FedSMOO demonstrates superior test accuracy on MNIST compared to other algorithms, except for FedGAM-CV. This improvement is attributed to the use of zeroth-order flatness and dynamic regularization techniques, which significantly enhance its performance. However, FedSMOO's test accuracy on CIFAR-10 is considerably lower than that of MOON, FedGAM, and FedGAM-CV, suggesting that FedSMOO's performance is unstable. This instability may be related to the complexity of the dataset or the network model. On the FashionMNIST dataset, compared to FedAvg, which is widely used as a benchmark, FedGAM's test accuracy improves by 4.76%, 4.21%, and 2.31% under the Dir(0.3), Dir(0.7), and IID data partitioning methods, respectively, while FedGAM-CV improves by 5.56%, 5.11%, and 3.64%, respectively. Similar results are observed on MNIST and CIFAR-10. Notably, as data heterogeneity increases, the accuracy gap between FedAvg and our proposed algorithms widens, demonstrating the stronger robustness of FedGAM and FedGAM-CV in heterogeneous data environments. Through these quantitative comparisons, the effectiveness of our proposed algorithms (FedGAM and FedGAM-CV) in enhancing the generalization capability of the global model is proven.
The gap between training and test accuracy reveals whether an algorithm is overfitting the training data. Due to the use of contrastive loss, MOON achieves the highest training accuracy in most cases on FashionMNIST, but the large gap between training and test accuracy suggests that MOON's global model may be overfitting the training dataset. On the FashionMNIST dataset, the gaps between training and test accuracy for MOON under the three data partitioning methods (Dir(0.3), Dir(0.7), and IID) are 2.17%, 2.95%, and 3.92%, respectively. In contrast, the gaps for FedGAM are 0.74%, 1.64%, and 1.79%, respectively. For FedGAM-CV, the gaps are 0.61%, 1.12%, and 0.92%, respectively, the lowest among all algorithms. Similar results are observed on the MNIST and CIFAR-10 datasets. The small gap between training and test accuracy indicates that FedGAM and FedGAM-CV do not overfit the training dataset and achieve improved generalization.
FedGAM outperforms most baseline algorithms in terms of test accuracy and the gap between training and test accuracy. However, FedGAM-CV surpasses FedGAM. The difference between FedGAM and FedGAM-CV lies in the incorporation of control variable techniques, which indicates that the addition of control variables enhances the model’s global generalization capability and mitigates overfitting, thereby demonstrating the effectiveness of the control variable technique.

5.3. In-Depth Experiments

We carry out additional experiments in the following three areas to further analyze and investigate the performance of the proposed algorithms: (1) the impact of client drift, (2) loss landscape visualization, and (3) the hyperparameter sensitivity of FedGAM and FedGAM-CV.

5.3.1. Robustness to Client Drift

To investigate how client drift affects various FL algorithms, we conduct separate ablation studies investigating the following three primary factors: the degree of data heterogeneity (Non-IID level), the number of local epochs, and the participation rate of clients. Table 3 and Figure 3 illustrate the impact of different levels of data heterogeneity on each FL algorithm, with local epoch = 15 and a client participation rate of 1.0. Table 4 presents the impact of local epochs on test accuracy under Dir(0.3) with a client participation rate of 1.0. Table 5 shows the effect of the client participation rate on test accuracy under Dir(0.3) with local epoch = 15.
(1) Non-IID level: From Figure 3, it can be observed that as the degree of data heterogeneity increases, the overall performance of the global FL model declines. For instance, when the data partitioning method shifts from IID to Dir(0.3), the test accuracies of FedAvg, FedProx, MOON, SCAFFOLD, FedNova, FedSAM, FedSMOO, FedGAM, and FedGAM-CV on the FashionMNIST dataset decrease by 2.97%, 5.81%, 1.07%, 2.82%, 2.4%, 2.64%, 1.63%, 0.52%, and 1.05%, respectively. Our proposed FedGAM and FedGAM-CV exhibit the first and second smallest declines, respectively. Notably, when the data partitioning method changes from Dir(0.7) to Dir(0.3), FedSAM's test accuracy on CIFAR-10 decreases by only 0.43% (the smallest observed decline) due to FedSAM's utilization of sharpness-aware techniques. However, the test accuracy of FedSAM remains significantly lower than that of FedGAM and FedGAM-CV. The above analysis indicates that our method demonstrates a certain robustness to variations in the Non-IID level while achieving higher generalization.
(2) Local epochs: As illustrated in Table 4, the test accuracies of FedGAM and FedGAM-CV increase with the number of local epochs. However, when the number of local epochs changes from 15 to 20, the test accuracies of MOON and FedNova decrease, indicating that their local models overfit at E = 20, while our proposed algorithms do not exhibit local overfitting issues. When the number of local epochs decreases from 20 to 5, the test accuracies of FedAvg, FedProx, MOON, SCAFFOLD, FedNova, FedSAM, FedSMOO, FedGAM, and FedGAM-CV decrease by 10.25%, 7.55%, 2.59%, 9.46%, 9.08%, 10.27%, 3.96%, 3.77%, and 3.01%, respectively. Excluding the overfitting MOON model, our proposed algorithms, i.e., FedGAM and FedGAM-CV, show the smallest declines. Moreover, disregarding MOON, as the number of local epochs changes, FedGAM and FedGAM-CV exhibit lower variance in test accuracy than other algorithms, indicating that our methods are highly robust to changes in the number of local epochs.
(3) Participation rate: A higher participation rate indicates that more clients are involved in the training process. As shown in Table 5, the test accuracy of different FL algorithms improves with an increase in participation rate. Notably, our proposed FedGAM-CV and FedGAM achieve higher test accuracies across various client participation rates, outperforming all baseline algorithms. Additionally, compared to other baseline algorithms, FedGAM and FedGAM-CV exhibit smaller variances in test accuracy. This suggests that FedGAM and FedGAM-CV are more robust to changes in the client participation rate than the compared methods.

5.3.2. Loss Surface Visualization

To visualize the loss landscape at the point of minimum loss and compare the flatness, we select FedAvg, FedSAM, and FedSMOO as comparison objects. FedAvg is a widely used baseline algorithm, while FedSAM and FedSMOO introduce zeroth-order flatness to enhance model generalization. Following the method described in [50], we plot the loss surfaces shown in Figure 4.
From Figure 4, it can be observed that FedGAM and FedGAM-CV exhibit lower loss values and flatter loss landscapes, indicating stronger generalization abilities. Notably, the flatness of the loss landscape for FedGAM-CV is not significantly different from that of FedGAM (FedGAM-CV is slightly flatter). This is because both FedGAM and FedGAM-CV utilize first-order flatness, which demonstrates the effectiveness of first-order flatness in enhancing the flatness of the loss landscape. Additionally, the minimum loss value of FedGAM-CV is lower than that of FedGAM. This is because FedGAM directly averages local models to obtain the global model, which may still result in reduced global model performance due to client drift. In contrast, FedGAM-CV employs the control variate technique, steering the entire training process towards global flatness. The analysis indicates that our proposed FedGAM-CV effectively enhances the flatness of the global model’s loss landscape, thereby improving the generalization ability of the global model.

5.3.3. Hyperparameter Sensitivity

The perturbation radius ( ρ ) measures how much of the region we can explore to achieve a flat minimum, while the trade-off coefficient ( α ) represents the proportion of the maximum gradient norm in the objective function during local model training. To investigate the impact of ρ and α on FedGAM and FedGAM-CV, we conduct parameter sensitivity experiments on the FashionMNIST dataset, with the results shown in Figure 5.
From Figure 5, it can be observed that the trade-off coefficient ( α ) should not be too large. When α is set to 0.5, the performance of FedGAM and FedGAM-CV is even worse than that of classical federated learning (when the trade-off coefficient is 0, FedGAM degrades to classical ERM-based federated learning). For FedGAM, a trade-off coefficient of 0.2 is more appropriate, while for FedGAM-CV, a trade-off coefficient of 0.4 is more suitable. Additionally, FedGAM and FedGAM-CV are relatively robust to changes in the perturbation radius, with ρ set to 0.1 being the most appropriate.

6. Conclusions

In this paper, we improve FL models’ generalization capacity to tackle the problem of client drift. To achieve this, we propose the following two algorithms: FedGAM and FedGAM-CV. FedGAM introduces the Gradient Norm-Aware Minimization (GAM) algorithm, modifying the client model training objective to simultaneously minimize the loss value and first-order flatness, ultimately aiming to find local flat minima and, thereby, improve the generalization capability of local models. To directly flatten the global model’s loss landscape and enhance its generalization capability, we build upon FedGAM by employing the control variate technique. This involves introducing control variates on both the server and client sides to represent their update directions, thereby correcting and guiding clients to train their models in the direction of a globally flat landscape. Experiments conducted on CIFAR-10, MNIST, and FashionMNIST datasets demonstrate that FedGAM significantly outperforms the classic FedAvg and approaches the performance of the latest and most effective algorithm, FedSMOO. FedGAM-CV consistently achieves higher test accuracy than baseline algorithms (FedAvg, FedProx, FedNova, MOON, SCAFFOLD, FedSAM, and FedSMOO) and FedGAM under various client drift conditions. The models trained with FedGAM-CV also exhibit lower loss values and flatter loss landscapes, indicating stronger generalization capabilities. The experimental results demonstrate that our proposed methods greatly improve model generalization and effectively address the client drift problem, validating the efficacy of the control variate technique.
In the future, we plan to prove the convergence of our proposed algorithms and analytically derive their generalization bounds to assess their performance further. Furthermore, although our algorithms have produced excellent results under stable communication, this work does not account for the fact that FL is often trained in unstable communication environments. Consequently, we will continue to investigate how unstable communication conditions affect our algorithms in order to achieve further improvements.

Author Contributions

Conceptualization, Y.X.; Funding acquisition, Y.W.; Methodology, Y.X. and W.M.; Project administration, C.D.; Supervision, C.D. and W.M.; Validation, Y.X. and H.Z.; Visualization, Y.X.; Writing—original draft, Y.X.; Writing—review and editing, W.M. and H.Z. All authors have read and agreed to the published version of the manuscript.

Funding

This work was supported by the General Program of the National Natural Science Foundation of China, grant number 61871388.

Data Availability Statement

All data generated or analyzed during this study are available from the corresponding author upon reasonable request.

Acknowledgments

The authors acknowledge the National University of Defense Technology.

Conflicts of Interest

The authors declare no conflicts of interest.

References

  1. Saranya, T.; Deisy, C.; Sridevi, S.; Anbananthen, K.S.M. A comparative study of deep learning and Internet of Things for precision agriculture. Eng. Appl. Artif. Intell. 2023, 122, 106034. [Google Scholar] [CrossRef]
  2. Subhashini, R.; Khang, A. The role of Internet of Things (IoT) in smart city framework. In Smart Cities; CRC Press: Boca Raton, FL, USA, 2023; pp. 31–56. [Google Scholar]
  3. Aminizadeh, S.; Heidari, A.; Toumaj, S.; Darbandi, M.; Navimipour, N.J.; Rezaei, M.; Talebi, S.; Azad, P.; Unal, M. The applications of machine learning techniques in medical data processing based on distributed computing and the Internet of Things. Comput. Methods Programs Biomed. 2023, 241, 107745. [Google Scholar] [CrossRef]
  4. Rajak, P.; Ganguly, A.; Adhikary, S.; Bhattacharya, S. Internet of Things and smart sensors in agriculture: Scopes and challenges. J. Agric. Food Res. 2023, 14, 100776. [Google Scholar] [CrossRef]
  5. Ng, D.T.K.; Lee, M.; Tan, R.J.Y.; Hu, X.; Downie, J.S.; Chu, S.K.W. A review of AI teaching and learning from 2000 to 2020. Educ. Inf. Technol. 2023, 28, 8445–8501. [Google Scholar] [CrossRef]
  6. Cetinic, E.; She, J. Understanding and creating art with AI: Review and outlook. ACM Trans. Multimed. Comput. Commun. Appl. (TOMM) 2022, 18, 1–22. [Google Scholar] [CrossRef]
  7. Cao, L. Ai in finance: Challenges, techniques, and opportunities. ACM Comput. Surv. (CSUR) 2022, 55, 1–38. [Google Scholar]
  8. Wu, X.; Xiao, L.; Sun, Y.; Zhang, J.; Ma, T.; He, L. A survey of human-in-the-loop for machine learning. Future Gener. Comput. Syst. 2022, 135, 364–381. [Google Scholar] [CrossRef]
  9. Yang, P.; Xiong, N.; Ren, J. Data security and privacy protection for cloud storage: A survey. IEEE Access 2020, 8, 131723–131740. [Google Scholar] [CrossRef]
  10. Li, Q.; Wen, Z.; Wu, Z.; Hu, S.; Wang, N.; Li, Y.; Liu, X.; He, B. A survey on federated learning systems: Vision, hype and reality for data privacy and protection. IEEE Trans. Knowl. Data Eng. 2021, 35, 3347–3366. [Google Scholar] [CrossRef]
  11. McMahan, H.B.; Moore, E.; Ramage, D.; y Arcas, B.A. Federated learning of deep networks using model averaging. arXiv 2016, arXiv:1602.05629. [Google Scholar]
  12. Guendouzi, B.S.; Ouchani, S.; Assaad, H.E.; Zaher, M.E. A systematic review of federated learning: Challenges, aggregation methods, and development tools. J. Netw. Comput. Appl. 2023, 220, 103714. [Google Scholar] [CrossRef]
  13. Zhang, C.; Xie, Y.; Bai, H.; Yu, B.; Li, W.; Gao, Y. A survey on federated learning. Knowl.-Based Syst. 2021, 216, 106775. [Google Scholar] [CrossRef]
  14. Kairouz, P.; McMahan, H.B.; Avent, B.; Bellet, A.; Bennis, M.; Bhagoji, A.N.; Bonawitz, K.; Charles, Z.; Cormode, G.; Cummings, R.; et al. Advances and open problems in federated learning. Found. Trends® Mach. Learn. 2021, 14, 1–210. [Google Scholar] [CrossRef]
  15. Li, Q.; Diao, Y.; Chen, Q.; He, B. Federated learning on non-iid data silos: An experimental study. In Proceedings of the 2022 IEEE 38th International Conference on Data Engineering (ICDE), Kuala Lumpur, Malaysia, 9–12 May 2022; pp. 965–978. [Google Scholar]
  16. Li, T.; Sahu, A.K.; Zaheer, M.; Sanjabi, M.; Talwalkar, A.; Smith, V. Federated optimization in heterogeneous networks. Proc. Mach. Learn. Syst. 2020, 2, 429–450. [Google Scholar]
  17. Zhu, F.; Zhang, J.; Liu, S.; Wang, X. DRAG: Divergence-based Adaptive Aggregation in Federated learning on Non-IID Data. arXiv 2023, arXiv:2309.01779. [Google Scholar]
  18. Acar, D.A.E.; Zhao, Y.; Navarro, R.M.; Mattina, M.; Whatmough, P.N.; Saligrama, V. Federated learning based on dynamic regularization. arXiv 2021, arXiv:2111.04263. [Google Scholar]
  19. Karimireddy, S.P.; Kale, S.; Mohri, M.; Reddi, S.; Stich, S.; Suresh, A.T. Scaffold: Stochastic controlled averaging for federated learning. In Proceedings of the International Conference on Machine Learning, Online, 13–18 July 2020; pp. 5132–5143. [Google Scholar]
  20. Li, Q.; He, B.; Song, D. Model-contrastive federated learning. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, Nashville, TN, USA, 20–25 June 2021; pp. 10713–10722. [Google Scholar]
  21. Li, Z.; Lin, T.; Shang, X.; Wu, C. Revisiting weighted aggregation in federated learning with neural networks. In Proceedings of the International Conference on Machine Learning, Honolulu, HI, USA, 23–29 July 2023; pp. 19767–19788. [Google Scholar]
  22. Zhang, J.; Li, A.; Tang, M.; Sun, J.; Chen, X.; Zhang, F.; Chen, C.; Chen, Y.; Li, H. Fed-cbs: A heterogeneity-aware client sampling mechanism for federated learning via class-imbalance reduction. In Proceedings of the International Conference on Machine Learning, Honolulu, HI, USA, 23–29 July 2023; pp. 41354–41381. [Google Scholar]
  23. Chen, H.; Chao, W. Fedbe: Making bayesian model ensemble applicable to federated learning. arXiv 2020, arXiv:2009.01974. [Google Scholar]
  24. Park, S.; Suh, Y.; Lee, J. FedPSO: Federated learning using particle swarm optimization to reduce communication costs. Sensors 2021, 21, 600. [Google Scholar] [CrossRef]
  25. Wang, H.; Yurochkin, M.; Sun, Y.; Papailiopoulos, D.; Khazaeni, Y. Federated learning with matched averaging. arXiv 2020, arXiv:2002.06440. [Google Scholar]
  26. Wang, J.; Liu, Q.; Liang, H.; Joshi, G.; Poor, H.V. Tackling the objective inconsistency problem in heterogeneous federated optimization. Adv. Neural Inf. Process. Syst. 2020, 33, 7611–7623. [Google Scholar]
  27. Chaudhari, P.; Choromanska, A.; Soatto, S.; LeCun, Y.; Baldassi, C.; Borgs, C.; Chayes, J.; Sagun, L.; Zecchina, R. Entropy-sgd: Biasing gradient descent into wide valleys. J. Stat. Mech. Theory Exp. 2019, 2019, 124018. [Google Scholar] [CrossRef]
  28. Foret, P.; Kleiner, A.; Mobahi, H.; Neyshabur, B. Sharpness-aware Minimization for Efficiently Improving Generalization. In Proceedings of the International Conference on Learning Representations, Virtual, 3–7 May 2021. [Google Scholar]
  29. Keskar, N.S.; Socher, R. Improving generalization performance by switching from adam to sgd. arXiv 2017, arXiv:1712.07628. [Google Scholar]
  30. Dziugaite, G.K.; Roy, D.M. Computing nonvacuous generalization bounds for deep (stochastic) neural networks with many more parameters than training data. arXiv 2017, arXiv:1703.11008. [Google Scholar]
  31. Jiang, Y.; Neyshabur, B.; Mobahi, H.; Krishnan, D.; Bengio, S. Fantastic generalization measures and where to find them. arXiv 2019, arXiv:1912.02178. [Google Scholar]
  32. Zhang, X.; Xu, R.; Yu, H.; Zou, H.; Cui, P. Gradient norm aware minimization seeks first-order flatness and improves generalization. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, Vancouver, BC, Canada, 17–24 June 2023; pp. 20247–20257. [Google Scholar]
  33. Jia, Z.; Su, H. Information-theoretic local minima characterization and regularization. In Proceedings of the International Conference on Machine Learning, Virtual, 13–18 July 2020; pp. 4773–4783. [Google Scholar]
  34. Kaur, S.; Cohen, J.; Lipton, Z.C. On the maximum hessian eigenvalue and generalization. In Proceedings of the Proceedings on “I Can’t Believe It’s Not Better!—Understanding Deep Learning Through Empirical Falsification” at NeurIPS 2022 Workshops, New Orleans, LA, USA, 3 December 2022; pp. 51–65. [Google Scholar]
  35. Keskar, N.S.; Mudigere, D.; Nocedal, J.; Smelyanskiy, M.; Tang, P.T.P. On large-batch training for deep learning: Generalization gap and sharp minima. arXiv 2016, arXiv:1609.04836. [Google Scholar]
  36. Zhuang, J.; Gong, B.; Yuan, L.; Cui, Y.; Adam, H.; Dvornek, N.; Tatikonda, S.; Duncan, J.; Liu, T. Surrogate gap minimization improves sharpness-aware training. arXiv 2022, arXiv:2203.08065. [Google Scholar]
  37. Qu, Z.; Li, X.; Duan, R.; Liu, Y.; Tang, B.; Lu, Z. Generalized federated learning via sharpness aware minimization. In Proceedings of the International Conference on Machine Learning, Baltimore, MD, USA, 17–23 July 2022; pp. 18250–18280. [Google Scholar]
  38. Caldarola, D.; Caputo, B.; Ciccone, M. Improving generalization in federated learning by seeking flat minima. In Proceedings of the European Conference on Computer Vision, Tel Aviv, Israel, 23–27 October 2022; pp. 654–672. [Google Scholar]
  39. Sun, Y.; Shen, L.; Huang, T.; Ding, L.; Tao, D. Fedspeed: Larger local interval, less communication round, and higher generalization accuracy. arXiv 2023, arXiv:2302.10429. [Google Scholar]
  40. Sun, Y.; Shen, L.; Chen, S.; Ding, L.; Tao, D. Dynamic regularized sharpness aware minimization in federated learning: Approaching global consistency and smooth landscape. In Proceedings of the International Conference on Machine Learning, Honolulu, HI, USA, 23–29 July 2023; pp. 32991–33013. [Google Scholar]
  41. Panchal, K.; Choudhary, S.; Mitra, S.; Mukherjee, K.; Sarkhel, S.; Mitra, S.; Guan, H. Flash: Concept drift adaptation in federated learning. In Proceedings of the International Conference on Machine Learning, Honolulu, HI, USA, 23–29 July 2023; pp. 26931–26962. [Google Scholar]
  42. Li, Q.; He, B.; Song, D. Adversarial collaborative learning on non-iid features. In Proceedings of the International Conference on Machine Learning, Honolulu, HI, USA, 23–29 July 2023; pp. 19504–19526. [Google Scholar]
  43. Xie, L.; Liu, J.; Lu, S.; Chang, T.H.; Shi, Q. An efficient learning framework for federated XGBoost using secret sharing and distributed optimization. ACM Trans. Intell. Syst. Technol. (TIST) 2022, 13, 1–28. [Google Scholar] [CrossRef]
  44. Dai, R.; Yang, X.; Sun, Y.; Shen, L.; Tian, X.; Wang, M.; Zhang, Y. Fedgamma: Federated learning with global sharpness-aware minimization. IEEE Trans. Neural Netw. Learn. Syst. 2023, 1–14. [Google Scholar] [CrossRef]
  45. Xiao, H.; Rasul, K.; Vollgraf, R. Fashion-mnist: A novel image dataset for benchmarking machine learning algorithms. arXiv 2017, arXiv:1708.07747. [Google Scholar]
  46. Krizhevsky, A.; Hinton, G. Learning multiple layers of features from tiny images. Handb. Syst. Autoimmune Dis. 2009, 1, 4. [Google Scholar]
  47. Deng, L. The mnist database of handwritten digit images for machine learning research [best of the web]. IEEE Signal Process. Mag. 2012, 29, 141–142. [Google Scholar] [CrossRef]
  48. Hsu, T.M.H.; Qi, H.; Brown, M. Measuring the effects of non-identical data distribution for federated visual classification. arXiv 2019, arXiv:1909.06335. [Google Scholar]
  49. He, K.; Zhang, X.; Ren, S.; Sun, J. Deep residual learning for image recognition. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, Las Vegas, NV, USA, 27–30 June 2016; pp. 770–778. [Google Scholar]
  50. Li, H.; Xu, Z.; Taylor, G.; Studer, C.; Goldstein, T. Visualizing the loss landscape of neural nets. arXiv 2018, arXiv:1712.09913. [Google Scholar]
Figure 1. Diagram of FedGAM-CV. In each communication round, the server sends the global model (w) and the global control variable (c) to the clients. On the client side, each client first performs a local flatness search based on GAM. Then, it uses the difference between the local variable ($c_i$) and the global variable (c), i.e., $c - c_i$, to correct the optimization direction. This correction mimics the expected global flatness search direction and can be seen as bringing all clients together to search for global flatness.
Figure 2. Visualization of the Non-IID property in federated learning (FL). (a–c) Dirichlet (Dir( α )) label-based imbalanced partitioning on 100 clients with the CIFAR-10 (Dir(0.3)), CIFAR-10 (Dir(0.7)), and CIFAR-10 (IID) datasets, respectively. Each row represents a client, and the color of each rectangle indicates the number of data samples of a particular class belonging to a client. Dark green signifies a high quantity, while light green indicates a low quantity.
Figure 3. Test accuracy of different federated learning algorithms under three data partitioning methods on the MNIST, FashionMNIST, and CIFAR-10 datasets. (a–c) Results on the MNIST dataset under three heterogeneous data environments: IID, Dir(0.3), and Dir(0.7). (d–f) Results on the FashionMNIST dataset under three heterogeneous data environments: IID, Dir(0.3), and Dir(0.7). (g–i) Results on the CIFAR-10 dataset under three heterogeneous data environments: IID, Dir(0.3), and Dir(0.7).
Figure 4. Loss surface of FedAvg, FedSAM, FedSMOO, FedGAM, and FedGAM-CV with ResNet-18 on CIFAR-10 with the Dir(0.3) partition method, a 60 % participation rate, 100 communication rounds, and 15 local epochs. The red surface represents the loss landscape of FedGAM-CV, while the blue surface represents the loss landscapes of FedAvg (a), FedSAM (b), FedSMOO (c), and FedGAM (d). The x and y axes represent the weights of two randomly sampled orthogonal Gaussian perturbations, and the z axis represents the loss value. The minimum loss values for FedAvg, FedSAM, FedSMOO, FedGAM, and FedGAM-CV are 0.403, 0.402, 0.398, 0.392, and 0.389, respectively.
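The surfaces in Figure 4 are obtained by evaluating the loss on a two-dimensional slice of weight space spanned by two random Gaussian directions, in the spirit of Li et al. [50]. The sketch below shows one simple way to compute such a slice; it omits the filter-wise normalization of [50], assumes the loader yields (inputs, targets) batches, and all names are illustrative.

```python
import numpy as np
import torch

def loss_surface(model, loss_fn, loader, steps=21, span=1.0):
    """Evaluate the loss on a 2D slice of weight space, as in Figure 4.

    Two random Gaussian directions (d1, d2) are drawn and orthogonalized;
    the loss is evaluated at w + x*d1 + y*d2 over a grid.
    """
    base = [p.detach().clone() for p in model.parameters()]
    d1 = [torch.randn_like(p) for p in base]
    d2 = [torch.randn_like(p) for p in base]
    # Gram-Schmidt: make d2 orthogonal to d1 (treating all weights as one vector).
    dot = sum((a * b).sum() for a, b in zip(d1, d2))
    nrm = sum((a * a).sum() for a in d1)
    d2 = [b - (dot / nrm) * a for a, b in zip(d1, d2)]

    xs = np.linspace(-span, span, steps)
    surface = np.zeros((steps, steps))
    x_batch, y_batch = next(iter(loader))  # a fixed batch keeps the sketch cheap
    with torch.no_grad():
        for i, x in enumerate(xs):
            for j, y in enumerate(xs):
                for p, w0, a, b in zip(model.parameters(), base, d1, d2):
                    p.copy_(w0 + x * a + y * b)
                surface[i, j] = loss_fn(model(x_batch), y_batch).item()
        # Restore the original weights.
        for p, w0 in zip(model.parameters(), base):
            p.copy_(w0)
    return xs, surface
```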
Figure 5. Test accuracy of FedGAM and FedGAM-CV on FashionMNIST under different perturbation radii and trade-off coefficients. (a) Test accuracy for different values of the perturbation radius ρ. (b) Test accuracy for different values of the trade-off coefficient α.
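As a reading aid for Figure 5, the two hyperparameters enter the local objective roughly as in the generic GAM formulation below; this is a hedged restatement rather than the paper's exact objective, which is given in the method section.

\[
\min_{w}\ L(w) + \alpha\, R_{\rho}(w),
\qquad
R_{\rho}(w) = \rho \cdot \max_{\|\delta\|_2 \le \rho} \left\| \nabla L(w + \delta) \right\|_2 ,
\]

where ρ is the perturbation radius defining the neighborhood in which the gradient norm is measured, and α trades off the empirical loss against this first-order flatness term.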
Table 1. Commonly used symbols.
Symbol     Description
N          Total number of clients
w, w_i     Global model; local model of client i
T          Communication rounds
E          Local epochs
L(·;·)     Loss function
∇          Gradient operator
α          Trade-off coefficient
ρ          Perturbation radius
p_i        Aggregation weight of client i
η          Learning rate
K          Number of clients participating in training
c, c_i     Global control variable; control variable of client i
S_t        Set of clients participating in the t-th communication round
n_i        Size of the random sample taken from client i's training data
Table 2. Baseline algorithms.
Algorithm        Major contributions
FedAvg [11]      First proposed the federated framework that combines partial client participation with multiple local training steps.
FedProx [16]     Regularizes the local objective to keep local updates close to the global model (see the sketch after this table).
FedNova [26]     Normalizes and scales each client's local update according to its number of local steps before updating the global model.
SCAFFOLD [19]    Employs control variates to mitigate the client drift issue.
MOON [20]        Introduces a contrastive loss that reduces the distance between the representations learned by the local model and those learned by the global model, enhancing generalization.
FedSAM [37]      Incorporates zeroth-order flatness (sharpness-aware minimization) into the client objective to seek flat minima and improve generalization.
FedSMOO [40]     Builds on FedSAM and employs a dynamic regularizer to align the local objectives with the global objective.
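To make the local-regularization idea represented by FedProx concrete, below is a minimal sketch of a proximal local loss. The coefficient mu and the helper names are illustrative assumptions and are not part of this paper's notation or implementation.

```python
import torch

def fedprox_local_loss(model, global_params, x, y, loss_fn, mu=0.01):
    """Local FedProx-style objective: the task loss plus a proximal term that
    penalizes drift of the local weights away from the current global model."""
    task_loss = loss_fn(model(x), y)
    prox = sum(((p - g.detach()) ** 2).sum()
               for p, g in zip(model.parameters(), global_params))
    return task_loss + 0.5 * mu * prox
```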
Table 3. Test accuracy and training accuracy of the proposed algorithms (FedGAM-CV and FedGAM) vs. seven federated learning baseline algorithms on three datasets (MNIST, FashionMNIST, and CIFAR-10) under three data partitioning strategy conditions (Dir(0.3), Dir(0.7), and IID).
Dataset        Algorithm    Dir(0.3)          Dir(0.7)          IID
                            Test     Train    Test     Train    Test     Train
MNIST          FedAvg       95.83    96.82    97.19    97.84    98.00    98.69
               FedProx      95.65    96.40    97.13    97.99    97.95    98.17
               MOON         98.56    98.77    98.57    99.18    98.96    99.29
               SCAFFOLD     97.62    98.13    97.84    98.67    98.31    98.68
               FedNova      96.95    97.86    97.21    97.93    98.07    98.93
               FedSAM       96.20    97.56    97.20    97.98    98.07    98.61
               FedSMOO      98.89    99.42    98.93    99.57    98.90    99.60
               FedGAM       98.69    98.85    98.74    98.96    98.89    99.13
               FedGAM-CV    98.98    99.48    99.11    99.51    99.15    99.53
FashionMNIST   FedAvg       83.05    85.03    84.09    85.26    86.02    87.09
               FedProx      77.86    79.90    80.41    82.22    83.67    85.47
               MOON         87.17    89.34    87.81    90.76    88.24    92.16
               SCAFFOLD     83.72    84.48    85.03    86.12    86.54    87.68
               FedNova      83.47    85.19    84.15    86.28    85.87    88.65
               FedSAM       83.40    85.22    84.54    86.32    86.04    88.22
               FedSMOO      87.62    89.32    88.66    90.46    89.25    91.57
               FedGAM       87.81    88.55    88.30    89.94    88.33    90.12
               FedGAM-CV    88.61    89.22    89.20    90.32    89.66    90.58
CIFAR-10       FedAvg       72.98    74.55    76.57    78.89    79.61    83.40
               FedProx      72.50    74.06    76.11    78.51    79.15    82.94
               MOON         83.12    85.54    86.28    89.67    87.33    92.87
               SCAFFOLD     74.96    76.41    79.64    80.90    84.36    87.37
               FedNova      73.93    75.32    76.82    79.22    79.47    83.18
               FedSAM       76.53    79.19    76.96    79.31    80.17    83.79
               FedSMOO      78.21    81.85    79.52    82.62    81.13    83.22
               FedGAM       84.88    86.07    86.62    88.88    87.53    89.35
               FedGAM-CV    88.20    89.24    90.05    91.70    92.40    93.83
Table 4. Effects of local epochs on the FashionMNIST dataset under the Dir(0.3) partition method with a 1.0 participation rate. σ ( E ) represents the variance in test accuracy across different local epochs.
Algorithm     E = 5    E = 10   E = 15   E = 20   σ(E)
FedAvg        73.80    80.42    83.05    84.05    1.6 × 10⁻³
FedProx       71.56    75.98    77.86    79.11    8.2 × 10⁻⁴
FedSAM        74.19    81.00    83.40    84.46    1.6 × 10⁻³
SCAFFOLD      75.29    82.02    83.72    84.75    1.36 × 10⁻²
FedNova       74.32    81.23    83.47    81.40    1.19 × 10⁻³
MOON          84.41    87.13    87.17    86.00    1.3 × 10⁻⁴
FedSMOO       84.43    86.97    87.62    88.39    2.3 × 10⁻⁴
FedGAM        84.40    86.99    87.81    88.17    2.2 × 10⁻⁴
FedGAM-CV     85.90    88.20    88.61    88.91    1.4 × 10⁻⁴
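The σ(E) column appears consistent with the population variance of the test accuracies (expressed as fractions) across the four epoch settings; the snippet below checks this for the FedAvg row under that assumption, which may not match every row exactly.

```python
import numpy as np

# Reproducing sigma(E) for the FedAvg row of Table 4, assuming accuracies are
# first converted to fractions and the population variance is taken.
acc_percent = np.array([73.80, 80.42, 83.05, 84.05])  # E = 5, 10, 15, 20
sigma_E = np.var(acc_percent / 100.0)                 # population variance (ddof=0)
print(f"{sigma_E:.1e}")                               # ~1.6e-03
```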
Table 5. Effects of participation rate on the FashionMNIST dataset under the Dir(0.3) partition method with 15 local epochs. σ ( C ) represents the variance in test accuracy across different participation rates.
Algorithm     C = 0.2   C = 0.4   C = 0.6   C = 0.8   C = 1.0   σ(C)
FedAvg        80.43     81.62     82.02     82.80     83.05     8.7 × 10⁻⁵
FedProx       75.04     76.19     76.60     76.73     77.86     8.3 × 10⁻⁵
FedSAM        80.70     83.25     83.28     83.36     83.40     3.3 × 10⁻⁵
SCAFFOLD      81.90     82.95     83.10     83.64     83.72     4.2 × 10⁻⁵
FedNova       81.87     82.73     83.26     83.34     83.47     3.4 × 10⁻⁵
MOON          85.22     86.38     86.48     86.94     87.17     4.6 × 10⁻⁵
FedSMOO       84.85     85.16     86.02     86.95     87.62     1.1 × 10⁻⁴
FedGAM        86.22     87.09     87.38     87.73     87.81     3.2 × 10⁻⁵
FedGAM-CV     87.29     88.38     88.43     88.54     88.61     2.4 × 10⁻⁵
Disclaimer/Publisher’s Note: The statements, opinions and data contained in all publications are solely those of the individual author(s) and contributor(s) and not of MDPI and/or the editor(s). MDPI and/or the editor(s) disclaim responsibility for any injury to people or property resulting from any ideas, methods, instructions or products referred to in the content.
