TLDR: This research introduces FedSWA and FedMoSWA, two new federated learning algorithms designed to improve model generalization, especially when data across clients is highly heterogeneous. FedSWA uses Stochastic Weight Averaging (SWA) with cyclical learning rates, while FedMoSWA further enhances this with a momentum-based stochastic control mechanism to better align local and global models. Both algorithms demonstrate superior performance and generalization capabilities compared to existing methods like FedSAM, particularly in challenging non-IID data settings, as validated by extensive experiments and theoretical analysis.
Federated Learning (FL) has emerged as a powerful approach to machine learning, allowing multiple participants to collaboratively train a shared model without sharing their raw data. This is particularly beneficial for applications where data privacy and security are paramount, such as healthcare, finance, and personalized mobile services. However, FL faces a significant challenge: data heterogeneity. This occurs when the data held by different clients are not identically distributed (non-IID), which makes it difficult to train a global model that generalizes well to unseen data.
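A model's ability to generalize, that is, to perform well on new, unseen data, is crucial for real-world applications. In centralized deep learning, converging to ‘flat minima’, regions of the parameter space where the loss changes little when the weights are slightly perturbed, has been associated with better generalization. Two prominent methods for finding such minima are Stochastic Weight Averaging (SWA), which averages weights collected along the training trajectory, and Sharpness-Aware Minimization (SAM), which minimizes the worst-case loss within a small neighborhood of the current weights.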
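The averaging at the heart of SWA is simple. Below is a minimal sketch that assumes model weights are stored as NumPy arrays and that a snapshot is saved at the end of each learning-rate cycle. It covers only the averaging step, not the SGD training that produces the snapshots, and the function name `swa_average` is illustrative.

```python
import numpy as np

def swa_average(snapshots):
    """Equal-weight running average of weight snapshots, as used by SWA.

    Each snapshot is a NumPy array of model weights, e.g. the weights saved
    at the end of each learning-rate cycle during training.
    """
    w_swa = None
    for n, w in enumerate(snapshots):
        w = np.asarray(w, dtype=float)
        if w_swa is None:
            w_swa = w.copy()
        else:
            # Incremental mean: combine the average of the n previous snapshots with the new one.
            w_swa = (n * w_swa + w) / (n + 1)
    return w_swa

snapshots = [np.array([1.0, -1.0]), np.array([-1.0, 1.0]), np.array([0.5, 0.5])]
print(swa_average(snapshots))  # same as np.mean(snapshots, axis=0)
```

The running form lets training keep only one extra copy of the weights instead of storing every snapshot.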
In federated learning, previous attempts to improve generalization often adapted SAM for local client optimization, leading to algorithms such as FedSAM. Although FedSAM aims to find flat minima, prior research has shown that it often performs poorly when data are highly heterogeneous. The reason is that each client's SAM step flattens only its own local loss, so FedSAM tends to find locally flat minima rather than a flat minimum of the global objective, which is what a robust global model in a federated setting needs. Moreover, SAM-based methods are computationally expensive: each step needs an extra forward and backward pass to compute the weight perturbation, roughly doubling the cost of a plain SGD step.
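To make the cost argument concrete, here is a minimal sketch of a single SAM step in its standard formulation (perturb the weights along the normalized gradient, then descend using the gradient at the perturbed point) on a toy NumPy problem. The names `sam_step` and `radius` are illustrative; `radius` is SAM's perturbation size, often written ρ in the SAM literature.

```python
import numpy as np

def sam_step(w, grad_fn, lr=0.1, radius=0.05):
    """One Sharpness-Aware Minimization (SAM) update.

    grad_fn(w) returns the loss gradient at w. SAM evaluates it twice per
    step; in a neural network each evaluation is a forward and backward pass.
    """
    g = grad_fn(w)                                  # 1st pass: gradient at w
    eps = radius * g / (np.linalg.norm(g) + 1e-12)  # first-order worst-case perturbation
    g_adv = grad_fn(w + eps)                        # 2nd pass: gradient at perturbed weights
    return w - lr * g_adv                           # descend with the perturbed gradient

# Count gradient evaluations on the toy loss ||w||^2.
calls = {"n": 0}

def grad_quadratic(w):
    calls["n"] += 1
    return 2.0 * w

w_new = sam_step(np.array([1.0, -2.0]), grad_quadratic)
print(calls["n"])  # 2 gradient evaluations for a single SAM step
```

A plain SGD step would call `grad_fn` once, which is where the extra cost of SAM-type federated methods comes from.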
To address these limitations, the authors propose a new federated learning algorithm, Federated Stochastic Weight Averaging (FedSWA). Inspired by the effectiveness of SWA at finding flatter minima, FedSWA integrates SWA into the FL framework. Its key ingredients are a cyclical learning rate for local training and an exponential moving average (EMA) for aggregating model weights on the server. Unlike FedSAM, which uses a constant learning rate and simple averaging, this combination helps the model escape suboptimal local minima and converge to better, flatter global minima, especially when client data are highly heterogeneous.
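A minimal sketch of FedSWA-style rounds, written under several assumptions that are not taken from the paper: the cyclical schedule is a linear decay from `lr_max` to `lr_min` over each round's local steps (restarting every round), all clients participate, and the server applies an EMA with a hypothetical mixing weight `beta`. Toy quadratic clients with different minimizers stand in for non-IID data. All names are illustrative, and the paper's exact schedule and server rule may differ.

```python
import numpy as np

def cyclical_lr(step, num_steps, lr_max=0.1, lr_min=0.01):
    """Linearly decay from lr_max to lr_min over one round's local steps.

    The schedule restarts at lr_max every round, which makes it cyclical.
    """
    if num_steps <= 1:
        return lr_max
    frac = step / (num_steps - 1)
    return lr_max + frac * (lr_min - lr_max)

def local_train(x, grad_fn, num_steps=10, lr_max=0.1, lr_min=0.01):
    """Run local gradient steps from the global model x with the cyclical schedule."""
    y = x.copy()
    for k in range(num_steps):
        y -= cyclical_lr(k, num_steps, lr_max, lr_min) * grad_fn(y)
    return y

def server_ema(x, client_models, beta=0.5):
    """EMA aggregation: move the global model part-way toward the client average."""
    avg = np.mean(client_models, axis=0)
    return (1.0 - beta) * x + beta * avg

# Toy heterogeneous clients: client i minimizes 0.5 * ||w - center_i||^2.
centers = [np.array([3.0, 0.0]), np.array([-1.0, 2.0]), np.array([0.0, -4.0])]
grad_fns = [lambda w, center=center: w - center for center in centers]

x = np.zeros(2)
for _ in range(200):
    client_models = [local_train(x, g) for g in grad_fns]
    x = server_ema(x, client_models)
print(x)  # approaches the mean of the centers, about [0.667, -0.667]
```

Restarting the schedule each round gives clients large early steps for exploration and small late steps for settling, while the EMA damps round-to-round swings in the global model.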
Building on FedSWA, the authors introduce Federated Learning via Momentum-Based Stochastic Controlled Weight Averaging (FedMoSWA), which adds a momentum-based variance reduction mechanism. This mechanism aligns the local models trained by individual clients more closely with the global model, mitigating the impact of data heterogeneity. As a result, local updates contribute more consistently toward a minimum that is both flat and low in loss for the global objective, which is crucial for generalization. FedMoSWA is also computationally cheaper than SAM-type algorithms because it avoids the extra forward and backward passes needed to compute perturbations.
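One way to picture how control variates align local and global models is a SCAFFOLD-style drift correction whose per-client control variates are updated with momentum (an exponential moving average). The sketch below assumes exactly that; it is an illustration, not FedMoSWA's actual update rule. The momentum weight is named `gamma` by analogy with the paper's momentum parameter, though its exact role in FedMoSWA may differ, and a constant local learning rate is used to isolate the effect of the correction. All function names are hypothetical.

```python
import numpy as np

def corrected_local_train(x, grad_fn, c_i, c, num_steps=10, lr=0.05):
    """Local gradient steps with a SCAFFOLD-style drift correction.

    The client follows grad_fn(y) - c_i + c instead of grad_fn(y), where c_i
    estimates this client's gradient and c the average client gradient.
    """
    y = x.copy()
    for _ in range(num_steps):
        y -= lr * (grad_fn(y) - c_i + c)
    return y

def run(centers, rounds=200, gamma=0.5, num_steps=10, lr=0.05):
    """Simulate full-participation rounds; return the final global model and client drift."""
    grad_fns = [lambda w, center=center: w - center for center in centers]  # 0.5*||w - center||^2
    x = np.zeros(2)
    controls = [np.zeros(2) for _ in centers]  # per-client control variates c_i
    drift = 0.0
    for _ in range(rounds):
        c = np.mean(controls, axis=0)  # global control variate: average of the c_i
        ys = [corrected_local_train(x, g, c_i, c, num_steps, lr)
              for g, c_i in zip(grad_fns, controls)]
        drift = max(np.linalg.norm(y - x) for y in ys)  # how far clients wander from x
        # Momentum (EMA) update of each c_i toward the client's gradient at x.
        controls = [(1 - gamma) * c_i + gamma * g(x) for g, c_i in zip(grad_fns, controls)]
        x = np.mean(ys, axis=0)  # plain averaging on the server
    return x, drift

centers = [np.array([3.0, 0.0]), np.array([-1.0, 2.0]), np.array([0.0, -4.0])]
x_mom, drift_mom = run(centers, gamma=0.5)      # momentum-smoothed correction
x_plain, drift_plain = run(centers, gamma=0.0)  # gamma=0 keeps every c_i at 0: no correction
print(drift_mom, drift_plain)  # drift is near zero with correction, above 1 without
```

In this toy problem both runs reach the same global model, because the corrections cancel when the server averages. What the correction removes is client drift: without it, each local model moves toward its own optimum during every round.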
The authors also analyze FedSWA and FedMoSWA theoretically, deriving convergence guarantees and generalization bounds for both. These results indicate that FedMoSWA achieves smaller optimization and generalization errors than prior methods, including FedSAM and its variants. In particular, FedMoSWA substantially reduces the effect of data heterogeneity on the generalization error.
Empirical evaluations on standard datasets such as CIFAR-10, CIFAR-100, and Tiny ImageNet, using architectures including LeNet-5, VGG-11, ResNet-18, and Vision Transformer (ViT-Base), demonstrate the strong performance of the proposed algorithms. FedMoSWA consistently outperforms other state-of-the-art FL methods, especially under high data heterogeneity. For instance, on CIFAR-100 with highly heterogeneous client data (a Dirichlet partition with concentration parameter 0.1, denoted Dirichlet-0.1), FedMoSWA achieved significantly higher test accuracy and needed fewer communication rounds than MoFedSAM.
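Settings like "Dirichlet-0.1" usually refer to splitting a labeled dataset across clients with a Dirichlet distribution. Below is a minimal sketch of one widely used label-based variant, assuming a symmetric Dirichlet with concentration 0.1 per class; the paper's exact partitioning procedure may differ, and `dirichlet_partition` and `concentration` are illustrative names.

```python
import numpy as np

def dirichlet_partition(labels, num_clients, concentration=0.1, seed=0):
    """Split sample indices across clients with label skew drawn from a Dirichlet.

    For every class, a proportion vector over clients is sampled from a
    symmetric Dirichlet; smaller concentration gives more skewed (more
    heterogeneous) client label distributions.
    """
    rng = np.random.default_rng(seed)
    labels = np.asarray(labels)
    client_indices = [[] for _ in range(num_clients)]
    for cls in np.unique(labels):
        idx = np.flatnonzero(labels == cls)
        rng.shuffle(idx)
        proportions = rng.dirichlet(concentration * np.ones(num_clients))
        # Cumulative cut points turn the proportions into contiguous index splits.
        cuts = (np.cumsum(proportions)[:-1] * len(idx)).astype(int)
        for client, part in enumerate(np.split(idx, cuts)):
            client_indices[client].extend(part.tolist())
    return client_indices

labels = np.repeat(np.arange(10), 100)  # toy dataset: 10 classes x 100 samples
parts = dirichlet_partition(labels, num_clients=10, concentration=0.1)
for k, idx in enumerate(parts[:3]):
    print(k, np.bincount(labels[idx], minlength=10))  # per-client class counts
```

With a concentration of 0.1, most clients end up dominated by a few classes, and some may receive very few samples; larger values approach an IID split.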
The study also examines the impact of key hyperparameters. Increasing the ‘alpha’ parameter in FedMoSWA initially improves performance, with the optimal value found around 1.5. Similarly, the ‘gamma’ parameter, which influences the momentum-based update, performs best around 0.2. For the ‘rho’ parameter, which controls the decay of the local learning rate, a faster decay (smaller rho) generally leads to better performance. Ablation studies further confirm the effectiveness of both the cyclical learning rate and the momentum-based variance reduction.
In conclusion, this study offers significant advances in federated learning, particularly for settings with highly heterogeneous data. The proposed FedSWA and FedMoSWA algorithms improve model generalization and reduce optimization error. While FedMoSWA substantially mitigates the impact of data heterogeneity, it does not eliminate it entirely, and future work aims to do so. For more details, see the full research paper: Improving Generalization in Federated Learning with Highly Heterogeneous Data via Momentum-Based Stochastic Controlled Weight Averaging.


