Robust Aggregation for Federated Learning by Minimum γ-Divergence Estimation
Abstract
1. Introduction
- We propose the γ-mean as a robust aggregator in federated learning. This aggregator mitigates the influence of Byzantine clients by assigning them smaller weights. The weighting scheme is data-driven and controlled by the tuning parameter γ.
- We discuss robustness from the influence function point of view. The benefits of adopting the γ-mean can be seen from its influence function, compared with those of other robust alternatives such as the marginal median, geometric median and trimmed mean.
- The robustness of the γ-mean is then verified through a simulation study and real data experiments. The γ-mean-based aggregation outperforms other robust aggregators such as the marginal median, geometric median and trimmed mean.
2. Related Work
3. Proposed Aggregator and Its Robustness
3.1. Minimum γ-Divergence Estimation
3.2. Robust Aggregation by γ-Mean
Algorithm 1 γ-mean with a Gaussian working model.
Input: Gradient information … and the maximum number of iterations S
Output: =
Start with initials and .
for (while and iterations not yet converge) do
  for do
    Calculate at the ith local client.
  end for
  Denote . , .
end for
Algorithm 2 simple γ-mean with the standard Gaussian as the working model.
Input: Gradient information … and the maximum number of iterations S
Output: =
Start with initial .
for (while and iterations not yet converge) do
  for do
    Calculate at the local client.
  end for
  Denote . .
end for
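As an illustration of the kind of iteration Algorithm 2 describes, the following is a minimal sketch of a fixed-point scheme for a γ-divergence location estimate under a standard Gaussian working model. The weight form exp(−γ‖x_i − μ‖²/2) is the usual estimating-equation form for this working model [13] and is assumed here rather than taken from the algorithm listing. The function name `simple_gamma_mean`, the coordinate-wise median starting value, and the stopping tolerance are all hypothetical choices made for the example.

```python
import numpy as np


def simple_gamma_mean(grads, gamma, max_iter=100, tol=1e-8):
    """Sketch of a gamma-mean with a standard Gaussian working model.

    grads : (m, p) array, one row of gradient information per client.
    gamma : robustness tuning parameter (larger values downweight more).
    Returns the aggregated vector and the final normalized client weights.
    """
    grads = np.asarray(grads, dtype=float)
    m = grads.shape[0]
    # Assumption: start from the coordinate-wise median (a robust initial value).
    mu = np.median(grads, axis=0)
    w = np.full(m, 1.0 / m)
    for _ in range(max_iter):
        # Squared distance of each client's vector from the current estimate.
        sq_dist = np.sum((grads - mu) ** 2, axis=1)
        # Assumed weight form: w_i proportional to exp(-gamma/2 * ||x_i - mu||^2).
        log_w = -0.5 * gamma * sq_dist
        w = np.exp(log_w - log_w.max())  # subtract the max for numerical stability
        w /= w.sum()
        new_mu = w @ grads  # weighted mean of the client vectors
        converged = np.linalg.norm(new_mu - mu) <= tol
        mu = new_mu
        if converged:
            break
    return mu, w


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    honest = 0.1 * rng.normal(size=(18, 10))      # hypothetical honest clients
    byz = rng.normal(5.0, 1.0, size=(2, 10))      # hypothetical Byzantine clients
    mu, w = simple_gamma_mean(np.vstack([honest, byz]), gamma=0.2)
    print("weights of the two Byzantine clients:", w[-2:])
```

Clients far from the bulk receive exponentially small weights, so the aggregate behaves like a mean over the clients that agree with each other. Because the weights depend on squared distances that grow with the dimension, the value of γ has to be chosen with the dimension in mind.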
3.3. Robustness
3.3.1. Influence Function
3.3.2. Comparison with Other Aggregators
4. Simulation Study
4.1. Simulation Settings
- Scenario 1. We focus on the behavior of the aggregators as p increases from 20 to 1000. The other experimental settings are as follows: the number of clients , the fraction of Byzantine attacks and , and the hyper-parameter for controlling the robustness .
- Scenario 2. We focus on the behavior of the aggregators as the contamination fraction increases from 0 to . The other experimental settings are as follows: , and .
- Scenario 3. We focus on the effect of γ values and set for various constants c ranging from 0.5 to 4. The other experimental settings are as follows: , and p ranges from 1 to 1000.
- Scenario 4. After comparing the γ-mean with other aggregators, we focus on the comparison between the two versions of our proposal, the γ-mean versus the simple γ-mean. The other experimental settings are as follows: , , and p ranges from 1 to 1000.
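To make the comparison in these scenarios concrete, the sketch below implements the three baseline aggregators named in the paper (marginal median, trimmed mean, and geometric median via Weiszfeld iterations [12]) and evaluates them on one contaminated sample. The data model (honest vectors from N(0, I), Byzantine vectors from N(5, 1) per coordinate), the sizes m = 50 and p = 20, and the trimming fraction 0.2 are hypothetical values chosen for the example, not the settings of the simulation study.

```python
import numpy as np


def marginal_median(x):
    """Coordinate-wise median over clients (rows)."""
    return np.median(x, axis=0)


def trimmed_mean(x, trim_frac):
    """Coordinate-wise mean after dropping the k smallest and k largest values."""
    m = x.shape[0]
    k = int(np.floor(trim_frac * m))
    if 2 * k >= m:
        raise ValueError("trim_frac removes every client")
    xs = np.sort(x, axis=0)
    return xs[k:m - k].mean(axis=0)


def geometric_median(x, max_iter=200, tol=1e-8, eps=1e-12):
    """Weiszfeld iterations: reweight each point by 1 / distance to the estimate."""
    y = x.mean(axis=0)
    for _ in range(max_iter):
        dist = np.maximum(np.linalg.norm(x - y, axis=1), eps)  # avoid division by 0
        w = 1.0 / dist
        y_new = (w[:, None] * x).sum(axis=0) / w.sum()
        if np.linalg.norm(y_new - y) <= tol:
            return y_new
        y = y_new
    return y


def contaminated_gradients(m, p, frac_byz, rng):
    """Hypothetical data model: honest ~ N(0, I), Byzantine ~ N(5, 1) per coordinate."""
    n_byz = int(round(frac_byz * m))
    honest = rng.normal(0.0, 1.0, size=(m - n_byz, p))
    byz = rng.normal(5.0, 1.0, size=(n_byz, p))
    return np.vstack([honest, byz])


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    x = contaminated_gradients(m=50, p=20, frac_byz=0.2, rng=rng)
    aggregators = {
        "mean": lambda a: a.mean(axis=0),
        "marginal median": marginal_median,
        "trimmed mean": lambda a: trimmed_mean(a, 0.2),
        "GeoMed": geometric_median,
    }
    for name, agg in aggregators.items():
        # The true center of the honest clients is 0, so the norm is the error.
        print(f"{name:16s} error = {np.linalg.norm(agg(x)):.3f}")
```

Under this hypothetical contamination the plain mean is pulled toward the Byzantine clients, while the three robust baselines stay closer to the honest center; none of them removes the bias entirely.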
4.2. Results
4.2.1. Scenario 1
4.2.2. Scenario 2
4.2.3. Scenario 3
4.2.4. Scenario 4
5. Real Data Examples
5.1. Datasets
- MNIST [19]. The MNIST database of handwritten digits has a training set of 60,000 examples and a test set of 10,000 examples. The digits have been size-normalized and centered in a fixed-size grayscale image.
- Fashion MNIST [20]. Fashion-MNIST is a dataset of Zalando’s article images consisting of a training set of 60,000 examples and a test set of 10,000 examples. Each example is a grayscale image, associated with a label from 10 types of clothing, such as shoes, t-shirts, dresses, sandals, sneakers and more.
- Chest X-ray images (pneumonia) [21]. The dataset contains 5856 X-ray images in 2 classes (pneumonia and normal). The 5856 images consist of 5232 training images (which we further split into 90% for model training and 10% for model validation to implement early stopping) and 624 testing images. The chest X-ray images (anterior-posterior) were selected from retrospective cohorts of pediatric patients aged one to five years from Guangzhou Women and Children’s Medical Center, Guangzhou.
5.2. Experimental Setting
- For MNIST and Fashion MNIST, we set and the number of Byzantine clients is two. For chest X-ray images, we set and the number of Byzantine clients is one. Byzantine clients return random values drawn from Gaussian (5, 1).
- γ is set to 0.5. This setting differs from the one used in the simulation study. The main reason is that the value of γ interacts in a complicated way with the neural network model adopted, for example through the dimensionality, gradient size, and learning rate. We do not yet fully understand this relationship, which might govern the selection of γ, and leave it for future study.
- We set for the trimmed mean.
- To allow for imbalanced client sizes, we obtain the sample size of each client by the following sampling steps [11].
- Sample a vector from .
- Sample from . The entries of the vector sum to 1 by the properties of the Dirichlet distribution.
- Obtain the sample sizes of the clients from a multinomial , where n is the total sample size and is the minimum sample size guaranteed for each client. We set .
- We run 1000 rounds of FL for MNIST and Fashion MNIST, and 100 rounds for chest X-ray images.
- We apply stochastic gradient descent (SGD) with a cosine decay learning rate (decaying over rounds), where the decay step is 1000 for MNIST and Fashion MNIST, and 100 for chest X-ray images. In each local epoch on MNIST and Fashion MNIST, SGD goes through only 10% of the local data to save computing time. This causes some fluctuations in the early stage of training, but training is much faster than going through all local data.
- The initial learning rates are 0.1, 0.5, and for MNIST, Fashion MNIST, and chest X-ray images, respectively. In each epoch in local iterates, is set as the decay constant.
- We apply gradient clipping to avoid exploding gradients on MNIST and Fashion MNIST. If the 2-norm of the aggregated gradient is larger than 1, the gradient is rescaled to have norm 1.
- To handle the imbalanced class sizes, we use weighted cross-entropy as the loss function in the chest X-ray example (pneumonia: 0.35, normal: 1.0), where the chest X-ray training dataset contains 3883 pneumonia cases and 1349 normal cases. In addition to classification accuracy, we also use ‘sensitivity’ (also known as ‘recall’ and ‘true positive rate’) and ‘precision’ as our evaluation metrics, because correctly predicting pneumonia is more important than correctly predicting the normal case.
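Three of the training ingredients listed above (norm clipping of the aggregated gradient, a cosine decay learning rate, and class-weighted cross-entropy) can be sketched in a few lines. The function names are hypothetical. The cosine schedule uses the common form lr0 · (1 + cos(π t / T)) / 2 with no lower floor, and the weighted loss is normalized by the sum of the sample weights; both are assumptions, since the exact forms are not specified in the text. The class order (0 = normal, 1 = pneumonia) is also an assumption.

```python
import math

import numpy as np


def clip_by_norm(grad, max_norm=1.0):
    """Rescale grad to 2-norm max_norm if its 2-norm exceeds max_norm."""
    grad = np.asarray(grad, dtype=float)
    norm = np.linalg.norm(grad)
    return grad * (max_norm / norm) if norm > max_norm else grad


def cosine_decay_lr(initial_lr, step, decay_steps):
    """Assumed schedule: half-cosine from initial_lr down to 0 over decay_steps."""
    step = min(step, decay_steps)
    return initial_lr * 0.5 * (1.0 + math.cos(math.pi * step / decay_steps))


def weighted_cross_entropy(probs, labels, class_weights):
    """Class-weighted cross-entropy, normalized by the total sample weight.

    probs : (n, k) predicted class probabilities, labels : (n,) integer classes.
    """
    probs = np.asarray(probs, dtype=float)
    labels = np.asarray(labels, dtype=int)
    w = np.asarray(class_weights, dtype=float)[labels]
    picked = probs[np.arange(len(labels)), labels]  # probability of the true class
    return float(np.sum(-w * np.log(picked)) / np.sum(w))


if __name__ == "__main__":
    print(clip_by_norm([3.0, 4.0]))            # rescaled to unit 2-norm
    print(cosine_decay_lr(0.1, 500, 1000))     # halfway through the schedule
    # Assumed class order: index 0 = normal (weight 1.0), 1 = pneumonia (0.35).
    print(weighted_cross_entropy([[0.8, 0.2], [0.3, 0.7]], [0, 1], [1.0, 0.35]))
```

The pneumonia weight of 0.35 is close to the class ratio 1349/3883 ≈ 0.347, so the two classes contribute roughly equally to the loss despite the imbalance.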
5.3. Models
5.4. Results
5.4.1. MNIST
5.4.2. Fashion MNIST
5.4.3. Chest X-ray Images (Pneumonia)
6. Concluding Remarks
Author Contributions
Funding
Institutional Review Board Statement
Data Availability Statement
Acknowledgments
Conflicts of Interest
Abbreviations
Acc | accuracy |
Byz | Byzantine |
FL | Federated learning |
FN | false negative, predicted negative but actually positive |
FP | false positive, predicted positive but actually negative |
GeoMed | Geometric median |
IF | Influence function |
Prec | precision, TP/(TP + FP) |
Sens | sensitivity, TP/(TP + FN), or 1 − type II error |
TN | true negative, predicted negative and actually negative |
TP | true positive, predicted positive and actually positive |
Appendix A
References
- Konečný, J.; McMahan, H.B.; Yu, F.X.; Richtarik, P.; Suresh, A.T.; Bacon, D. Federated Learning: Strategies for Improving Communication Efficiency. In Proceedings of the NeurIPS Workshop on Private Multi-Party Machine Learning; 2016. Available online: https://nips.cc/Conferences/2016/ScheduleMultitrack?event=6250 (accessed on 29 March 2022).
- So, J.; Güler, B.; Avestimehr, A.S. Byzantine-resilient secure federated learning. IEEE J. Sel. Areas Commun. 2020, 39, 2168–2181.
- Xu, J.; Glicksberg, B.S.; Su, C.; Walker, P.; Bian, J.; Wang, F. Federated learning for healthcare informatics. J. Healthc. Inform. Res. 2021, 5, 1–19.
- Alistarh, D.; Allen-Zhu, Z.; Li, J. Byzantine stochastic gradient descent. In Advances in Neural Information Processing Systems; Curran Associates, Inc.: Red Hook, NY, USA, 2018; p. 31.
- Chen, X.; Chen, T.; Sun, H.; Wu, S.Z.; Hong, M. Distributed training with heterogeneous data: Bridging median- and mean-based algorithms. In Advances in Neural Information Processing Systems; Curran Associates, Inc.: Red Hook, NY, USA, 2020; p. 33.
- Chen, Y.; Su, L.; Xu, J. Distributed statistical machine learning in adversarial settings: Byzantine gradient descent. Proc. ACM Meas. Anal. Comput. Syst. 2017, 1, 1–25.
- Xie, C.; Koyejo, O.; Gupta, I. Generalized Byzantine-tolerant SGD. arXiv 2018, arXiv:1802.10116.
- Li, L.; Xu, W.; Chen, T.; Giannakis, G.B.; Ling, Q. RSA: Byzantine-robust stochastic aggregation methods for distributed learning from heterogeneous datasets. Proc. AAAI Conf. Artif. Intell. 2019, 33, 1544–1551.
- McMahan, B.; Moore, E.; Ramage, D.; Hampson, S.; y Arcas, B.A. Communication-efficient learning of deep networks from decentralized data. In Proceedings of the 20th International Conference on Artificial Intelligence and Statistics, Fort Lauderdale, FL, USA, 10 April 2017; Volume 54, pp. 1273–1282.
- Dayan, I.; Roth, H.R.; Zhong, A.; Harouni, A.; Gentili, A.; Abidin, A.Z.; Li, Q. Federated learning for predicting clinical outcomes in patients with COVID-19. Nat. Med. 2021, 27, 1735–1743.
- Portnoy, A.; Tirosh, Y.; Hendler, D. Towards Federated Learning with Byzantine-Robust Client Weighting. In Proceedings of the International Workshop on Federated Learning for User Privacy and Data Confidentiality in Conjunction with ICML; 2021. Available online: https://federated-learning.org/fl-icml-2021/ (accessed on 29 March 2022).
- Weiszfeld, E.; Plastria, F. On the point for which the sum of the distances to n given points is minimum. Ann. Oper. Res. 2009, 167, 7–41.
- Fujisawa, H.; Eguchi, S. Robust parameter estimation with a small bias against heavy contamination. J. Multivar. Anal. 2008, 99, 2053–2081.
- Hung, H. A robust removing unwanted variation–testing procedure via γ-divergence. Biometrics 2019, 75, 650–662.
- Jones, M.C.; Hjort, N.L.; Harris, I.R.; Basu, A. A comparison of related density-based minimum divergence estimators. Biometrika 2001, 88, 865–873.
- Huber, P.J. Robust Statistics; John Wiley & Sons: Hoboken, NJ, USA, 2004.
- Chaudhuri, P. On a geometric notion of quantiles for multivariate data. J. Am. Stat. Assoc. 1996, 91, 862–872.
- van der Vaart, A.W. Asymptotic Statistics; Cambridge Series in Statistical and Probabilistic Mathematics; Cambridge University Press: Cambridge, UK, 1998.
- Deng, L. The MNIST database of handwritten digit images for machine learning research. IEEE Signal Process. Mag. 2012, 29, 141–142.
- Xiao, H.; Rasul, K.; Vollgraf, R. Fashion-MNIST a novel image dataset for benchmarking machine learning algorithms. arXiv 2017, arXiv:1708.07747.
- Kermany, D.S.; Goldbaum, M.; Cai, W.; Valentim, C.; Liang, H.; Baxter, S.L.; McKeown, A.; Yang, G.; Wu, X.; Yan, F.; et al. Identifying medical diagnoses and treatable diseases by image-based deep learning. Cell 2018, 172, 1122–1131.e9.
Byz | Aggregator | TN | FN | FP | TP | Prec | Sens (Type II Error) | Acc |
---|---|---|---|---|---|---|---|---|
– | single machine | 156 | 23 | 78 | 367 | 0.8247 | 0.9410 (0.0590) | 0.8381 |
No | mean | 212 | 103 | 22 | 287 | 0.9288 | 0.7359 (0.2641) | 0.7997 |
No | marginal median | 190 | 63 | 44 | 327 | 0.8814 | 0.8385 (0.1615) | 0.8285 |
No | simple γ-mean | 126 | 8 | 108 | 382 | 0.7796 | 0.9795 (0.0205) | 0.8141 |
No | GeoMed | 177 | 30 | 57 | 360 | 0.8633 | 0.9231 (0.0769) | 0.8606 |
Yes | mean | – | – | – | – | – | – | – |
Yes | marginal median | 228 | 271 | 6 | 119 | 0.9520 | 0.3051 (0.6949) | 0.5561 |
Yes | simple γ-mean | 140 | 11 | 94 | 379 | 0.8013 | 0.9718 (0.0282) | 0.8317 |
Yes | GeoMed | – | – | – | – | – | – | – |
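The derived columns of the table follow from the four confusion-matrix counts using the standard definitions of precision, sensitivity and accuracy. The short sketch below recomputes them; the function name `classification_metrics` is hypothetical, and the example counts are those of the single-machine row.

```python
def classification_metrics(tn, fn, fp, tp):
    """Precision, sensitivity, type II error and accuracy from confusion counts.

    Assumes tp + fp > 0 and tp + fn > 0 (no guard against empty denominators).
    """
    precision = tp / (tp + fp)      # share of predicted positives that are correct
    sensitivity = tp / (tp + fn)    # share of actual positives that are detected
    accuracy = (tp + tn) / (tn + fn + fp + tp)
    return {
        "precision": precision,
        "sensitivity": sensitivity,
        "type_ii_error": 1.0 - sensitivity,
        "accuracy": accuracy,
    }


if __name__ == "__main__":
    # Counts from the single-machine row: TN = 156, FN = 23, FP = 78, TP = 367.
    metrics = classification_metrics(tn=156, fn=23, fp=78, tp=367)
    for name, value in metrics.items():
        print(f"{name}: {value:.4f}")
```

For the single-machine row this gives precision 0.8247, sensitivity 0.9410 and accuracy 0.8381, matching the tabulated values; the counts sum to the 624 test images.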
© 2022 by the authors. Licensee MDPI, Basel, Switzerland. This article is an open access article distributed under the terms and conditions of the Creative Commons Attribution (CC BY) license (https://creativecommons.org/licenses/by/4.0/).
Li, C.-J.; Huang, P.-H.; Ma, Y.-T.; Hung, H.; Huang, S.-Y. Robust Aggregation for Federated Learning by Minimum γ-Divergence Estimation. Entropy 2022, 24, 686. https://doi.org/10.3390/e24050686