Next Article in Journal
Feature Importance in Gradient Boosting Trees with Cross-Validation Feature Selection
Previous Article in Journal
Correction: Brochet et al. A Quantitative Comparison between Shannon and Tsallis–Havrda–Charvat Entropies Applied to Cancer Outcome Prediction. Entropy 2022, 24, 436
 
 
Font Type:
Arial Georgia Verdana
Font Size:
Aa Aa Aa
Line Spacing:
Column Width:
Background:
Article

Robust Aggregation for Federated Learning by Minimum γ-Divergence Estimation

1
Institute of Statistical Science, Academia Sinica, Taipei City 11529, Taiwan
2
Data Science Degree Program, National Taiwan University, Taipei City 10617, Taiwan
3
Institute of Epidemiology and Preventive Medicine, National Taiwan University, Taipei City 10055, Taiwan
*
Author to whom correspondence should be addressed.
Entropy 2022, 24(5), 686; https://doi.org/10.3390/e24050686
Submission received: 30 March 2022 / Revised: 10 May 2022 / Accepted: 11 May 2022 / Published: 13 May 2022
(This article belongs to the Section Information Theory, Probability and Statistics)

Abstract

:
Federated learning is a framework for multiple devices or institutions, called local clients, to collaboratively train a global model without sharing their data. For federated learning with a central server, an aggregation algorithm integrates model information sent from local clients to update the parameters for a global model. Sample mean is the simplest and most commonly used aggregation method. However, it is not robust for data with outliers or under the Byzantine problem, where Byzantine clients send malicious messages to interfere with the learning process. Some robust aggregation methods were introduced in literature including marginal median, geometric median and trimmed-mean. In this article, we propose an alternative robust aggregation method, named γ -mean, which is the minimum divergence estimation based on a robust density power divergence. This γ -mean aggregation mitigates the influence of Byzantine clients by assigning fewer weights. This weighting scheme is data-driven and controlled by the γ value. Robustness from the viewpoint of the influence function is discussed and some numerical results are presented.

1. Introduction

Federated learning (FL), a distributed learning paradigm, was proposed by Google in 2016 [1] for distributing the optimization of model parameters over multiple machines. The distributed framework allows the data to be stored in local devices without sharing with each other. These machines, called clients, collaboratively train a global model. In this article, we consider the case that there is a central server, which works as a coordinator and aggregates the gradient information sent by the local clients to update the global model through an iterative process [2]. In real-world applications, clients, such as hospitals and clinical institutions, hold highly sensitive and private data such as electronic healthcare records, medical images, etc. [3]. However, FL still encounters several challenges in a real world setting. In its training process, some unreliable clients may produce outlying data information or even send malicious values [4], which leads to a biased result or even a failure of the training process. These arbitrary clients are called Byzantine clients. With a properly designed aggregation scheme, the effect by Byzantine clients can be reduced or excluded. Several papers have made significant contributions to developing Byzantine-resilience methods with specific use of robust aggregators, including marginal median [5], geometric median [6], and trimmed-mean [7]. Another aggregation framework is the Byzantine-robust stochastic aggregation method [8], which adds a regularization term in the objective function to handle the Byzantine and heterogeneous data problem.
In this article, we consider the case of FL, which consists of a central server and m clients. Assume that there is an α -fraction of Byzantine clients. Our goal is to minimize the following objective function:
L ( θ ) = i = 1 m π i E η i D [ l η i ( θ ) ] ,
where θ R p is the global model parameter, l η i ( θ ) is the loss function of client i with local data η i having an unknown distribution D , and { π i } i = 1 m are the prior weights assigned to local clients. At the t th round of FL iterative updates, the server broadcasts the current parameter value θ t R p to local clients. Normal clients faithfully compute an estimate of the local gradients E η i D l η i ( θ t ) and send the gradient information back to the server. On the contrary, Byzantine clients send arbitrary erroneous messages to obstruct the optimization process. The central server aggregates the gradient information from clients and updates the model parameter from θ t to θ t + 1 . At the end, the server should output an estimate of the optimal θ , which aims to minimize the objective loss (1). In this work we consider uniform client weights, i.e., π 1 = = π m = 1 . The main reason is that, under the Byzantine problem, the local sample sizes claimed by Byzantine clients might not be reliable.
Our main contributions of this work are listed below.
  • We propose the γ -mean as a robust aggregator in federated learning. This robust aggregator mitigates the influence of Byzantine clients by assigning fewer weights. The weighting scheme is data-driven and controlled by the tuning parameter γ .
  • We have a discussion on robustness from the influence function point of view. Benefits of adopting γ -mean can be seen from its influence function in comparison to other robust alternatives such as marginal median, geometric median and trimmed mean.
  • The robustness of γ -mean is then verified through simulation study and real data experiments. The γ -mean based aggregation outperforms other robust aggregators such as marginal median, geometric median and trimmed mean.
The rest of the article is organized as follows. In Section 2, we review some related work in robust federated learning. In Section 3, we propose our main method for robust FL aggregation and provide some theoretical viewpoints by the influence function. In Section 4, we conduct an extensive simulation study. In Section 5, we provide some further numerical examples using MNIST, fashion MNIST and chest X-ray images (pneumonia). In Section 6, we make some concluding remarks.

2. Related Work

McMahan et al. introduced the FedAvg algorithm [9], which used the mean as its aggregator. The sample mean aggregation is a very popular FL framework [9,10]. However, the sample mean is known to be vulnerable to outliers and heavy-tailed distribution. For instance, Byzantine clients [7,8,11] may send extreme values to strongly influence the aggregation of sample mean. There are some robust alternatives to the mean aggregator, such as marginal median, geometric median and trimmed mean. In contrast to mean, marginal median is a relatively robust aggregator by computing coordinate-wise medians. Another Byzantine-resilient aggregator is geometric median, which computes the geometric median of { x i } i = 1 m given by (see Weiszfeld’s algorithm [12] for solution computation).
GeoMed ( { x i } i = 1 m ) = arg min x R p i = 1 m x x i 2 .
The trimmed mean aggregator computes coordinate-wise trimmed means by removing the β -fraction of data points from each of the two tails [11]. Both the marginal median and geometric median, though are robust to outliers, but are not efficient, as they only use the most centrally-lying data for inference. As for trimmed mean, first it needs a proper selection of β . Second, it often does not work well for outliers lying on one side of a coordinate instead of on two sides. To have a better balance between robustness and efficiency, we will propose another robust aggregator based on a robust density power divergence, namely, the γ -divergence.

3. Proposed Aggregator and Its Robustness

In this section, we introduce an aggregator, called “ γ -mean”, which is based on the minimum γ -divergence estimation. Two versions of implementation algorithms are given and some robustness viewpoints based on the influence function are provided.

3.1. Minimum γ -Divergence Estimation

The γ -divergence [13,14], which is also known as the density power divergence of type zero [15], is a robust divergence against outliers with a tuning parameter γ > 0 . The γ -divergence between the data distribution with probability density function g and the model distribution f τ (indexed by τ ) is defined by
D γ ( g , f τ ) = g γ + 1 γ ( γ + 1 ) 1 f τ ( x ) f τ γ + 1 γ g ( x ) g γ + 1 d x ,
where f τ γ + 1 = ( f τ γ + 1 ( x ) d x ) 1 / ( γ + 1 ) . In the limiting case, it reduces to the Kullback-Leibler divergence, i.e., lim γ 0 D γ ( g , f τ ) = ln ( g ( x ) f τ ( x ) ) g ( x ) d x . In the population level, the minimum γ -divergence estimation of τ is given by
τ γ = arg min τ D γ ( g , f τ ) = arg max τ f τ ( x ) f τ γ + 1 γ g ( x ) d x .
In the sample level with empirical data { x i } i = 1 m , the data distribution will be replaced by the empirical distribution to get the estimate τ ^ γ :
τ ^ γ = arg min τ D γ ( g , f τ ) = arg max τ 1 m i = 1 m f τ ( x i ) f τ γ + 1 γ .
In this article, we use a Gaussian working model, i.e., f τ N ( μ , Σ ) . Suppose g takes a contaminated form: g ( x ) = ( 1 ϵ ) f τ ( x ) + ϵ h ( x ) , where h ( x ) is the probability density function of contamination distribution. It is assumed that f τ γ ( x ) h ( x ) is small for certain γ > 0 , so that the contamination has little effect in the learning process. By taking derivatives with respect to τ = ( μ , Σ ) and setting them to zero, the estimates ( μ ^ , Σ ^ ) have to satisfy the following estimating equations:
1 m i = 1 m d i ( γ ) ( μ ^ , Σ ^ ) ( x i μ ^ ) = 0 , 1 m i = 1 m d i ( γ ) ( μ ^ , Σ ^ ) 1 γ + 1 Σ ^ ( x i μ ^ ) ( x i μ ^ ) = 0 ,
where d i ( γ ) ( μ ^ , Σ ^ ) = exp γ 2 ( x i μ ^ ) Σ ^ 1 ( x i μ ^ ) . The solution pair, ( μ ^ , Σ ^ ) , has to satisfy the following stationary equations:
μ ^ = i = 1 m d i ( γ ) ( μ ^ , Σ ^ ) x i i = 1 m d i ( γ ) ( μ ^ , Σ ^ ) ,
Σ ^ = ( γ + 1 ) i = 1 m d i ( γ ) ( μ ^ , Σ ^ ) ( x i μ ^ ) ( x i μ ^ ) i = 1 m d i ( γ ) ( μ ^ , Σ ^ ) .

3.2. Robust Aggregation by γ -Mean

In view of the stationary Equations (4) and (5), we use a fixed-point iteration. Two algorithms are provided below. Algorithm 1 is the usual γ -mean, which uses N ( μ , Σ ) as the working model, and Algorithm 2 is the simple γ -mean, which adopts N ( μ , σ 2 I ) as the working model. Note that, in the former case when the sample size is not sufficiently large enough to have a stable estimate of the covariance inverse, we will use { diag ( Σ ^ ) } 1 instead, where Σ ^ is the minimum γ -divergence estimator for the covariance and diag ( Σ ^ ) denotes the diagonal matrix consisting of diagonal elements of Σ ^ . Also note that, in the latter case of simple γ -mean, σ 2 can be merged into the hyperparameter γ . Thus, we will simply use the standard Gaussian as the working model.
Algorithm 1  γ -mean with a Gaussian working model.
Input: Gradient information X = [ x 1 , x 2 , x m ] and the maximum number of iterations S
Output: μ γ = μ ^
  Start with initials μ ^ = μ ^ i n i and Σ ^ = Σ ^ i n i .
  for  s = 1 , 2 , , S (while s S and iterations not yet converge) do
     for  i = 1 , 2 , , m  do
       Calculate d i ( γ ) ( μ ^ , Σ ^ ) exp γ 2 ( x i μ ^ ) Σ ^ 1 ( x i μ ^ ) at the ith local client.
     end for
     Denote w i ( γ ) = d i ( γ ) ( μ ^ , Σ ^ ) i = 1 m d i ( γ ) ( μ ^ , Σ ^ ) .
      μ ^ i = 1 m w i ( γ ) · x i ,
      Σ ^ 1 + γ · i = 1 m w i ( γ ) · ( x i μ ^ ) ( x i μ ^ ) .
end for
Algorithm 2 simple γ -mean with the standard Gaussian as the working model.
Input: Gradient information X = [ x 1 , x 2 , x m ] and the maximum number of iterations S
Output: μ γ = μ ^
  Start with initial μ ^ = μ ^ i n i .
  for  s = 1 , 2 , , S (while s S and iterations not yet converge) do
     for  i = 1 , 2 , , m  do
       Calculate d i ( γ ) ( μ ^ ) exp γ 2 ( x i μ ^ ) ( x i μ ^ ) at the i th local client.
     end for
     Denote w i ( γ ) = d i ( γ ) ( μ ^ ) i = 1 m d i ( γ ) ( μ ^ ) .
      μ ^ i = 1 m w i ( γ ) · x i .
end for
For the extremely large dimension p, such as in a deep neural network model, the simple γ -mean algorithm will be more feasible than the usual γ -mean for better numerical stability.

3.3. Robustness

3.3.1. Influence Function

The influence function is a tool to evaluate the change of an estimator by a small perturbation to the data distribution. An estimator with a lower influence provides better resistance against outliers. Therefore, we analyze the robustness of different estimators by showing their influence functions. Let { x i } i = 1 m R p be sampled from G and T be the statistical functional for estimation. The robustness of T ( G ) can be evaluated by the influence function of T:
IF T ( x 0 ; G ) = ϵ T { ( 1 ϵ ) G + ϵ δ x 0 } ϵ = 0 ,
where δ x 0 is the point mass at x 0 . In other words, the influence function of T is the Gâteaux derivative of T ( G ) with respect to G along the direction δ x 0 G . In this paper, we consider evaluating the influence function at the Gaussian working model F τ with τ = ( μ , Σ ) and seeing the deviant effect when a point perturbation x 0 is added. The influence function of M-estimators have been well-studied in [16], which is related to the inverse of Hessian matrix and the first order derivative. The estimator based on γ -divergence is also an M-estimate. Its influence function derivation can be found in [14] and is given by
IF μ γ ( x 0 ; G ) | G = F τ = ( γ + 1 ) p + 2 2 exp γ 2 ( x 0 μ γ ) Σ γ 1 ( x 0 μ γ ) ( x 0 μ γ ) ,
where μ γ and Σ γ 1 are the minimum γ -divergence estimators for the location and covariance matrix of the working model F τ , respectively. The influence function of the γ -mean, IF μ γ ( x 0 ; F τ ) , is similar to the influence function of the sample mean, which is given by ( x 0 μ ) , except for an additional multiplicative weight d x 0 ( γ ) ( μ γ , Σ γ ) . This weight focuses on the Mahalanobis distance between the perturbing point x 0 and μ γ . As a result, when x 0 is an outlier away from the target mean of data distribution, this weighting factor produces a down-weighting effect and prevents the estimator from biased estimation caused by the outlier.

3.3.2. Comparison with Other Aggregators

In this subsection, we present influence functions for several robust aggregation methods including marginal median, geometric median and trimmed-mean.
Geometric median μ geo : The influence function of μ geo is given by [17]
IF μ geo ( x 0 ; G ) = E X X μ geo 2 1 I p X μ geo 2 2 ( X μ geo ) ( X μ geo ) 1 x 0 μ geo x 0 μ geo 2 ,
where the expectation E X is taken with respect to the data distribution G.
Trimmed-mean μ trm : The trimmed-mean is an L-esimator, and its influence function is derived in [18]. For the influence function of the marginal trimmed-mean, we derive it coordinate-wise since the coordinates are independent to each other. Therefore, let G i 1 be the quantile function of marginal data distribution on i-th component, and the influence function of μ trm is given by
IF μ trm ( x 0 ; G ) = 1 1 2 β w 1 w p ,
where
w i = G i 1 ( β ) W i , x 0 i < G i 1 ( β ) , x 0 W i , G i 1 ( β ) < x 0 i < G i 1 ( 1 β ) , G i 1 ( 1 β ) W i , x 0 i > G i 1 ( 1 β ) ,
x 0 = ( x 01 , , x 0 p ) , and
W i = ( 1 2 β ) G i 1 ( β ) G i 1 ( 1 β ) t d G i ( t ) + β ( G i 1 ( 1 β ) + G i 1 ( β ) ) .
It is obvious that an influential effect, x 0 W i , still remains if x 0 is outside the trimmed range.
Marginal median μ ( 0.5 ) : The influence function of the median is also derived in [18], and we use the same techniques as used in the trimmed-mean to derive the influence function of the marginal median. The influence function is given by
IF μ ( 0.5 ) ( x 0 ; G ) = x 01 μ 1 ( 0.5 ) 2 g 1 ( μ 1 ( 0.5 ) ) | x 01 μ 1 ( 0.5 ) | x 0 p μ p ( 0.5 ) 2 g p ( μ p ( 0.5 ) ) | x 0 p μ p ( 0.5 ) | ,
where μ ( 0.5 ) = ( μ 1 ( 0.5 ) , , μ p ( 0.5 ) ) is the marginal median and g i is the density function of the marginal distribution G i .
The influence functions given above reveal that different robust aggregators have different resistance ability against outliers. The trimmed mean can still be influenced by outliers outside the trimmed range. While the marginal median and the geometric median are fairly robust, they may lose too much efficiency due to the use of only the most centrally-lying data point. The γ -mean provides an adjustable control between the robustness and efficiency by the hyper-parameter γ .

4. Simulation Study

In this section, we evaluate our robust γ -mean aggregator and compare it with existing aggregators, including mean, marginal median, geometric median and trimmed mean, through simulation. For computing the geometric median, Weiszfeld’s algorithm [12] is used. In this simulation study, we investigate the behavior and inspect the robustness of the γ -mean and other aggregators as an estimator of the location parameter under multivariate Gaussian distribution and multivariate t-distribution with and without Byzantine interference. We build several simulation scenarios, including testing aggregators across the growth of dimension p, different fractions (i.e., α values) of Byzantine attacks and for various settings of γ (which controls the degree of robustness). All our simulation experiments were run on an Nvidia DGX A100 server. We used one A100 GPU card, 16G RAM, and 16 cores of AMD-EPYC-7742 CPUs. However, a personal computer with a moderate performance-efficient GPU card and CPU can also carry out the simulation experiments, but with longer run time. For the real data examples in Section 5, a personal computer will not be sufficient for carrying out the computing job.

4.1. Simulation Settings

  • Scenario 1. We focus on the behavior of aggregators for increasing p from 20 to 1000. Other experimental setting is as follows: the number of clients m = 200 , the fraction of Byzantine attacks α = 0 and α = 0.1 , and the hyper-parameter for controlling the robustness γ = 2 / p .
  • Scenario 2. We focus on the behavior of aggregators for increasing contamination fraction α from 0 to 0.5 . Other experimental setting is as follows: m = 200 , p = 1000 and γ = 2 / p .
  • Scenario 3. We focus on the effect of γ values and set γ = c / p for various constants c ranging from 0.5 to 4. Other experimental setting is as follows: m = 200 , α = 0.1 and p ranges from 1 to 1000.
  • Scenario 4. After comparison between the γ -mean and other aggregators, we focus on the comparison between two versions of our proposal, the γ -mean versus the simple γ -mean. Other experimental setting is as follows: m = 200 , α = 0.1 , γ = 2 / p and p ranges from 1 to 1000.
For regular clients, gradient vectors are generated from the standard Gaussian distribution and t-distribution with 5 degrees of freedom. For Byzantine clients, gradient vectors are generated from the same Gaussian and t distributions, except with a location shift μ = 100 × 1 p , where 1 p is a vector with one in all entries. Each experimental scenario is implemented with 100 replicate runs.

4.2. Results

To compare the performance of different aggregators, the mean squared error (MSE) to the true target value is used as a performance metric. MSE can be further decomposed into the squared bias and variance.

4.2.1. Scenario 1

The results are shown in Figure 1. Without contamination (Figure 1a), all aggregators have MSE close to zero. For the Gaussian case, the MSE curves of the mean, geometric median and γ -mean are almost collapsed together. These 3 curves have the lowest MSE values followed by the trimmed-mean and marginal median. For the t-distribution, the γ -mean has the lowest MSE, followed by the geometric median, trimmed-mean, mean and then marginal median. With 10 % of contamination, the mean-aggregator fails to perform well and has large values of MSE, verifying that it is not a robust aggregator. Due its large MSE value, the result of the mean-aggregator is removed from Figure 1b. From Figure 1b we can see that γ -mean performs the best among the five aggregators in terms of MSE. In addition, γ -mean remains low and stable in variance under both Gaussian and t distributions.

4.2.2. Scenario 2

As Figure 2 shows, values of squared bias are close to zero when α = 0 and increase as α grows for all aggregators. However, the increasing velocity of aggregators are quite different. Mean-aggregator surges with the highest velocity, while γ -mean progresses with the lowest velocity. Trimmed-mean slowly grows when α 0.1 ; yet, its increasing pace is about the same as the mean after α > 0.1 due to the fixed 10 % trimming percentage from both tails. The trends in MSE for different aggregators have similar patterns, and the only difference is in the scale with the γ -mean having the lowest MSE values. Also note that the variance of the marginal median increases much faster than other aggregators as α increases, followed by the geometric median, γ -mean, trimmed mean and then the mean. However, the trimmed mean and the mean have dominant bias leading to high MSE, even though they have small variance. As the values of squared bias are large for the mean and trimmed-mean, the bias curves for the γ -mean, geometric median and marginal median in the 2nd row of Figure 2 look collapsed together. To have a better view of these three bias curves, zoomed-in views are provided in the 3rd row of Figure 2, and we can clearly see that the γ -mean has the lowest bias values.

4.2.3. Scenario 3

We conduct further experiments on γ -mean to see the effect of different γ values by setting γ = c / p with c = 0.5 , 1 , 2 , 3 , 4 , where α = 0.1 and p [ 1 , 1000 ] . Results are presented in Figure 3. The squared bias is ignorable, indicating that the γ -mean is quite robust with different γ -values. The main source of MSE comes from the variance. Larger γ values lead to larger stochastic variances, indicating that larger γ values result in lower estimation efficiency.

4.2.4. Scenario 4

For p [ 100 , 1000 ] and under the Gaussian case, the MSE, bias and variance curves for the γ -mean and simple γ -mean are collapsed together (Figure 4, left panel). For p [ 100 , 1000 ] and under the t-distribution, these curves are nearly indistinguishable (Figure 4, right panel). When p is small ( p 100 ), the MSE values of the simple γ -mean are significantly larger than those of γ -mean. However, we did not display the results for p < 100 , since the relative large MSE values for small p will make the MSE curves of γ -mean and simple γ -mean look lying on the x-axis for p 100 .

5. Real Data Examples

5.1. Datasets

  • MNIST [19]. The MNIST database of handwritten digits has a training set of 60,000 examples, and a test set of 10,000 examples. The digits have been size-normalized and centered in a fixed-size, 28 × 28 grayscale, images.
  • Fashion MNIST [20]. Fashion-MNIST is a dataset of Zalando’s article images consisting of a training set of 60,000 examples and a test set of 10,000 examples. Each example is a 28 × 28 grayscale image, associated with a label from 10 types of clothing, such as shoes, t-shirts, dresses, sandals, sneakers and more.
  • Chest X-ray images (pneumonia) [21]. The dataset contains 5856 X-ray images and 2 classes (pneumonia and normal). The 5,856 images consist of 5232 training images (which we further split into 90% for model training and 10% for model validation to implement early stopping) and 624 testing images. Chest X-ray images (anterior-posterior) were selected from retrospective cohorts of pediatric patients of one to five years old from Guangzhou Women and Children’s Medical Center, Guangzhou.

5.2. Experimental Setting

All the data examples were run on an Nvidia DGX A100 server. We used one A100 GPU card, 64 G RAM, and 16 cores of AMD-EPYC-7742 CPUs. The experimental setting is described below.
  • For MNIST and fashion MNIST, we set m = 20 and the number of Byzantine clients is two. For chest X-ray images, we set m = 9 and the number of Byzantine clients is one. Byzantine clients return random values from Gaussian (5, 1).
  • γ is set to 0.5. This setting is different from the setting γ = c / p in simulation. The main reason is that there is a certain complicated relationship between the γ value and the neural network model adopted, such as the dimensionality, gradient size, learning rate, etc. We have not yet fully understood this relationship, which might govern the selection of γ . We will leave it as a future study.
  • We set β = 0.1 for the trimmed mean.
  • To allow for the imbalanced size of clients, we obtain the sample size of each client by sampling from the following steps [11].
    • Sampling a vector a from Lognormal ( 1.5 , 3 . 45 2 ) .
    • Sampling η from Dirichlet ( a ) . The sum of vector η will be 1 due to the property of Dirichlet distribution.
    • Obtain the sample sizes of clients from multinomial ( n m × n o , η ) , where n is total sample size and n o is the minimum sample size guaranteed for each client. We set n o = 512 .
In each round of parameter update, the server broadcasts the current parameter values to m clients. Each normal client computes the local gradient and local parameter update, and then returns the change of parameters after iterating k epochs using local data ( k = 1 for MNIST and fashion MNIST, and k = 20 for chest X-ray images). Byzantine clients return random values from a Gaussian distribution N ( 5 , 1 ) . Some further implementation details are given below.
  • We run 1000 rounds of FL for MNIST and fashion MNIST, and 100 rounds for chest X-ray images.
  • We apply stochastic gradient descent (SGD) with a cosine decay learning rate (decay over rounds), where the decay step is 1000 for MNIST and fashion MNIST, and 100 for chest X-ray images. In each epoch of local clients on MNIST and fashion MNIST, the SGD will go through only 10% of local data to save computing time. This implementation leads to some fluctuations in the early stage of training but the training process will be much faster than going through all local data.
  • The initial learning rates are 0.1, 0.5, and 10 4 for MNIST, fashion MNIST, and chest X-ray images, respectively. In each epoch in local iterates, 10 5 is set as the decay constant.
  • We apply gradient clipping to avoid exploding gradients on MNIST and fashion MNIST. If the 2-norm of aggregated gradient is larger than 1, the vector will be scaled to a new vector with norm 1.
  • To handle the imbalanced class size, we use weighted cross-entropy as loss function in the chest X-ray example (pneumonia: 0.35, normal: 1.0), where the chest X-ray training dataset contains 3883 pneumonia cases and 1349 normal cases. In addition to the classification accuracy, we also use ‘accuracy’, ‘sensitivity’ (also known as ‘recall’ and ‘true positive rate’) and ‘precision’ as our evaluation metrics. In particular, correctly predicting pneumonia is more important than predicting the normal case.

5.3. Models

We use a simple CNN model (Figure A1 in Appendix A) for MNIST and fashion MNIST. For chest X-ray images (pneumonia), we use a pretrained ResNet50 (Figure A2 in Appendix A) with input image size 150 × 150 × 3 . We further connect the model to 3 dense-block layers.

5.4. Results

Results for the trimmed mean are not reported here, as they do not perform well due to its reliance on the specification of β and its non-robustness to non-symmetric outliers (note that the trimming procedure is symmetric in two tails of each coordinate).

5.4.1. MNIST

Figure 5 shows the results of FL using different aggregators. In the case of no attack, the simple γ -mean performs as good as the mean-aggregator with 96% testing data classification accuracy. In the Byzantine case, simple γ -mean performs the best with testing accuracy 96% followed by the two median-based methods, geometric median and marginal median, while the mean-aggregator has failed. In addition, the marginal median has some fluctuations in the early stage of training.

5.4.2. Fashion MNIST

Figure 6 shows the results of fashion MNIST. The conclusion is similar to MNIST, that the simple γ -mean performs the best. Note that the fluctuations shown in the training process of marginal median look worse than those shown in Figure 5. Although we have applied gradient clipping, still there are some sudden surges in loss function values and they lead to sudden drops in classification accuracy.

5.4.3. Chest X-ray Images (Pneumonia)

We train the model on the training dataset without any partitions (i.e., single machine with full training data) as a baseline. In this single machine case, the maximum number of training epochs is 1000 and the training process will stop early if the evaluation metric (loss + accuracy + precision + sensitivity on validation data) does not get better in 10 consecutive epochs. For the FL case, the number of FL rounds is 100 and the maximum training epochs of each client are 20 at every FL round. At each round, the clients will stop early if the evaluation metric (loss + accuracy + precision + sensitivity on validation data) does not get better in 5 consecutive epochs. The simple γ -mean algorithm has adopted the marginal median as initial. All other settings are the same as those listed in Section 5.2. The results are shown in Figure 7.
We further report the following testing data results in Table 1: true positive (TP, which is predicted positive and actually positive), true negative (TN, which is predicted negative and actually negative), false positive (FP, which is predicted positive but actually negative), false negative (FN, which is predicted negative but actually positive), precision (Prec, which is TP TP + FP ), sensitivity (Sens, which is TP TP + FN as well as one minus type-II error) and classification accuracy (Acc). The aggregators by mean and geometric median cannot tolerate the Byzantine (Byz) attack and both methods were crushed during model training. Thus, there are no results reported for these two methods.

6. Concluding Remarks

Our major contribution in this article is to propose the γ -mean aggregation in federated learning, which is robust against outliers and Byzantine attacks. We have provided some theoretical discussions on influence functions and carried out numerical studies to justify our proposal. In both simulation and real data experiments, γ -mean based aggregation in general outperforms the existing robust aggregators such as marginal median, geometric median and trimmed mean. In addition, the simple γ -mean does not require complex computation. It can be easily computed by fixed point iteration and works well for extremely high dimensional cases such as in a deep neural network model. The γ -mean (as well as the simple γ -mean) takes a form of weighted average, where γ controls the trade-off between robustness and estimation efficiency. The selection of γ value is important, and a data-driven selection procedure is needed, which will be pursued in our future study.
In our numerical examples (simulation and real data), we have tested γ -mean’s ability to guard against the Byzantine attacks. However, when it comes to federated learning in real-world applications, adversarial attacks are also an important issue for model robustness. These attacks create adversarial fake examples, which even humans might not be able to distinguish from genuine ones. Moreover, adversarial examples aim to confuse the model and result in wrong prediction. How to modify the baseline γ -mean algorithm for adversarial attacks will be an interesting and yet challenging problem. We will defer it to a future study.

Author Contributions

Conceptualization, S.-Y.H.; methodology, S.-Y.H.; C.-J.L. and H.H.; software, C.-J.L. and P.-H.H.; formal analysis, C.-J.L.; P.-H.H.; Y.-T.M.; writing—original draft preparation, C.-J.L.; P.-H.H.; Y.-T.M.; writing—review and editing, C.-J.L.; P.-H.H.; Y.-T.M.; H.H. and S.-Y.H. All authors have read and agreed to the published version of the manuscript.

Funding

This research received no external funding.

Institutional Review Board Statement

Not applicable.

Data Availability Statement

Benchmark datasets are available from the following links: “MNIST” at http://yann.lecun.com/exdb/mnist/ accessed on 31 August 2021, “fashion MNIST” at https://research.zalando.com/project/fashion_mnist/fashion_mnist/ accessed on 31 August 2021, and “Chest X-ray images (pneumonia)” at https://www.kaggle.com/datasets/paultimothymooney/chest-xray-pneumonia accessed on 1 October 2021.

Acknowledgments

Su-Yun Huang thanks the Institute of Statistical Science, Academia Sinica, for a seed grant funding.

Conflicts of Interest

The authors declare no conflict of interest.

Abbreviations

The following abbreviations are used in this manuscript:
Accaccuracy
ByzByzantine
FLFederated learning
FNfalse negative, predicted negative but actually positive
FPfalse positive, predicted positive but actually negative
GeoMedGeometric median
IFInfluence function
Precprecision, TP TP + FP
Senssensitivity, TP TP + FN or 1 type II error
TNtrue negative, predicted negative and actually negative
TPtrue positive, predicted positive and actually positive

Appendix A

Figure A1. Model used in MNIST and fashion MNIST.
Figure A1. Model used in MNIST and fashion MNIST.
Entropy 24 00686 g0a1
For MNIST and fashion MNIST, we use a model similar to LeNet-5. All filters in the convolution layers are 5 × 5 maps without padding. The first layer in fully connection is a dense layer with 500 neurons with ReLU activation. The last layer is an output layer with 10 neurons with softmax activation. The number of parameters is 931,080.
Figure A2. Model used in chest X-ray images.
Figure A2. Model used in chest X-ray images.
Entropy 24 00686 g0a2
For chest X-ray images, we use a pretrained ResNet50 with input image size is 150 × 150 × 3 and freeze all ResNet50 parameters. The output dimensionality is 5 × 5 × 2048 . A global average pooling is adopted and it leads to the dimensionality 2048. Next, we connect it to 3 dense-block layers. Dense block: dense layer → batch normalization → drop out. Number of neurons are 256, 64, and 32, respectively. The drop out probabilities are 0.7, 0.5, and 0.3, respectively. The number of trainable parameters is 543,809.

References

  1. Konečný, J.; McMahan, H.B.; Yu, F.X.; Richtarik, P.; Suresh, A.T.; Bacon, D. Federated Learning: Strategies for Improving Communication Efficiency. In Proceedings of the NeurIPS Workshop on Private Multi-Party Machine Learning; 2016. Available online: https://nips.cc/Conferences/2016/ScheduleMultitrack?event=6250 (accessed on 29 March 2022).
  2. So, J.; Güler, B.; Avestimehr, A.S. Byzantine-resilient secure federated learning. IEEE J. Sel. Areas Commun. 2020, 39, 2168–2181. [Google Scholar] [CrossRef]
  3. Xu, J.; Glicksberg, B.S.; Su, C.; Walker, P.; Bian, J.; Wang, F. Federated learning for healthcare informatics. J. Healthc. Inform. Res. 2021, 5, 1–19. [Google Scholar] [CrossRef] [PubMed]
  4. Alistarh, D.; Allen-Zhu, Z.; Li, J. Byzantine stochastic gradient descent. In Advances in Neural Information Processing Systems; Curran Associates, Inc.: Red Hook, NY, USA, 2018; p. 31. [Google Scholar]
  5. Chen, X.; Chen, T.; Sun, H.; Wu, S.Z.; Hong, M. Distributed training with heterogeneous data: Bridging median- and mean-based algorithms. In Advances in Neural Information Processing Systems; Curran Associates, Inc.: Red Hook, NY, USA, 2020; p. 33. [Google Scholar]
  6. Chen, Y.; Su, L.; X, J. Distributed statistical machine learning in adversarial settings: Byzantine gradient descent. Proc. Acm Meas. Anal. Comput. Syst. 2017, 1, 1–25. [Google Scholar] [CrossRef]
  7. Xie, C.; Koyejo, O.; Gupta, I. Generalized Byzantine-tolerant SGD. arXiv 2018, arXiv:1802.10116. [Google Scholar]
  8. Li, L.; Xu, W.; Chen, T.; Giannakis, G.B.; Ling, Q. RSA: Byzantine-robust stochastic aggregation methods for distributed learning from heterogeneous datasets. Proc. Aaai Conf. Artif. Intell. 2019, 33, 1544–1551. [Google Scholar] [CrossRef]
  9. McMahan, B.; Moore, E.; Ramage, D.; Hampson, S.; y Arcas, B.A. Communication-efficient learning of deep networks from decentralized data. In Proceedings of the 20th International Conference on Artificial Intelligence and Statistics, Fort Lauderdale, FL, USA, 10 April 2017; Volume 54, pp. 1273–1282. [Google Scholar]
  10. Dayan, I.; Roth, H.R.; Zhong, A.; Harouni, A.; Gentili, A.; Abidin, A.Z.; Li, Q. Federated learning for predicting clinical outcomes in patients with covid-19. Nat. Med. 2021, 27, 1735–1743. [Google Scholar] [CrossRef] [PubMed]
  11. Portnoy, A.; Tirosh, Y.; Hendler, D. Towards Federated Learning with Byzantine-Robust Client Weighting. In Proceedings of the International Workshop on Federated Learning for User Privacy and Data Confidentiality in Conjunction with ICML; 2021. Available online: https://federated-learning.org/fl-icml-2021/ (accessed on 29 March 2022).
  12. Weiszfeld, E.; Plastria, F. On the point for which the sum of the distances to n given points is minimum. Ann. Oper. Res. 2009, 167, 7–41. [Google Scholar] [CrossRef]
  13. Fujisawa, H.; Eguchi, S. Robust parameter estimation with a small bias against heavy contamination. J. Multivar. Anal. 2008, 99, 2053–2081. [Google Scholar] [CrossRef] [Green Version]
  14. Hung, H. A robust removing unwanted variation–testing procedure via γ-divergence. Biometrics 2019, 75, 650–662. [Google Scholar] [CrossRef] [PubMed]
  15. Jones, M.C.; Hjort, N.L.; Harris, I.R.; Basu, A. A comparison of related density-based minimum divergence estimators. Biometrika 2001, 88, 865–873. [Google Scholar] [CrossRef]
  16. Huber, P.J. Robust Statistics; John Wiley & Sons: Hoboken, NJ, USA, 2004. [Google Scholar]
  17. Chaudhuri, P. On a geometric notion of quantiles for multivariate data. J. Am. Stat. Assoc. 1996, 91, 862–872. [Google Scholar] [CrossRef]
  18. van der Vaart, A.W. Asymptotic Statistics; Cambridge Series in Statistical and Probabilistic Mathematics; Cambridge University Press: Cambridge, UK, 1998. [Google Scholar]
  19. Deng, L. The MNIST database of handwritten digit images for machine learning research. IEEE Signal Process. Mag. 2012, 29, 141–142. [Google Scholar] [CrossRef]
  20. Xiao, H.; Rasul, K.; Vollgraf, R. Fashion-MNIST a novel image dataset for benchmarking machine learning algorithms. arXiv 2017, arXiv:1708.07747. [Google Scholar]
  21. Kermany, D.S.; Goldbaum, M.; Cai, W.; Valentim, C.; Liang, H.; Baxter, S.L.; McKeown, A.; Yang, G.; Wu, X.; Yan, F.; et al. Identifying medical diagnoses and treatable diseases by image-based deep learning. Cell 2018, 172, 1122–1131.e9. [Google Scholar] [CrossRef] [PubMed]
Figure 1. Comparison of different aggregators across different dimensions p. (a) Case α = 0 (no Byzantine client). (b) Case α = 0.1 (10% Byzantine clients).
Figure 1. Comparison of different aggregators across different dimensions p. (a) Case α = 0 (no Byzantine client). (b) Case α = 0.1 (10% Byzantine clients).
Entropy 24 00686 g001
Figure 2. Comparison of different aggregators across different α values with fixed p = 1000 . The two subplots in the 3rd row are the zoomed-in views of subplots in the 2nd row.
Figure 2. Comparison of different aggregators across different α values with fixed p = 1000 . The two subplots in the 3rd row are the zoomed-in views of subplots in the 2nd row.
Entropy 24 00686 g002
Figure 3. Effect of different γ values across different dimensions p [ 1 , 1000 ] with α = 0.1 .
Figure 3. Effect of different γ values across different dimensions p [ 1 , 1000 ] with α = 0.1 .
Entropy 24 00686 g003
Figure 4. Comparison of γ -mean versus simple γ -mean with α = 0.1 and p [ 100 , 1000 ] .
Figure 4. Comparison of γ -mean versus simple γ -mean with α = 0.1 and p [ 100 , 1000 ] .
Entropy 24 00686 g004
Figure 5. Testing process and comparison of testing accuracy for different aggregators on MNIST.
Figure 5. Testing process and comparison of testing accuracy for different aggregators on MNIST.
Entropy 24 00686 g005
Figure 6. Testing process and comparison of testing accuracy for different aggregators on fashion MNIST.
Figure 6. Testing process and comparison of testing accuracy for different aggregators on fashion MNIST.
Entropy 24 00686 g006
Figure 7. Testing process and comparison of testing accuracy for different aggregators on chest x ray. The aggregators by mean and geometric median cannot tolerate the Byzantine attack and both methods were crushed during model training. Thus, there are no results reported for these two methods.
Figure 7. Testing process and comparison of testing accuracy for different aggregators on chest x ray. The aggregators by mean and geometric median cannot tolerate the Byzantine attack and both methods were crushed during model training. Thus, there are no results reported for these two methods.
Entropy 24 00686 g007
Table 1. Pneumonia prediction on test data.
Table 1. Pneumonia prediction on test data.
ByzAggregatorTNFNFPTPPrecSens (Type II Error)Acc
single machine15623783670.82470.9410 (0.0590)0.8381
Nomean212103222870.92880.7359 (0.2661)0.7997
marginal median19063443270.88140.8385 (0.1615)0.8285
simple γ -mean12681083820.77960.9795 (0.0205)0.8141
GeoMed17730573600.86330.9231 (0.0769)0.8606
Yesmean
marginal median22827161190.95200.3051 (0.6949)0.5561
simple γ -mean14011943790.80130.9718 (0.0282)0.8317
GeoMed
† The symbol “–” indicates model crushed during training.
Publisher’s Note: MDPI stays neutral with regard to jurisdictional claims in published maps and institutional affiliations.

Share and Cite

MDPI and ACS Style

Li, C.-J.; Huang, P.-H.; Ma, Y.-T.; Hung, H.; Huang, S.-Y. Robust Aggregation for Federated Learning by Minimum γ-Divergence Estimation. Entropy 2022, 24, 686. https://doi.org/10.3390/e24050686

AMA Style

Li C-J, Huang P-H, Ma Y-T, Hung H, Huang S-Y. Robust Aggregation for Federated Learning by Minimum γ-Divergence Estimation. Entropy. 2022; 24(5):686. https://doi.org/10.3390/e24050686

Chicago/Turabian Style

Li, Cen-Jhih, Pin-Han Huang, Yi-Ting Ma, Hung Hung, and Su-Yun Huang. 2022. "Robust Aggregation for Federated Learning by Minimum γ-Divergence Estimation" Entropy 24, no. 5: 686. https://doi.org/10.3390/e24050686

Note that from the first issue of 2016, this journal uses article numbers instead of page numbers. See further details here.

Article Metrics

Back to TopTop