1. Introduction
In recent years, deep neural networks have advanced machine learning: over-parameterization increases fitting capability but expands a complex solution space, risking overfitting [1]. It also reduces saddle points and increases the number of local minima [2], each with varying generalization ability [3]. Current studies have shown a positive correlation between the flatness of a loss basin and generalization [4,5,6,7,8,9]. Flatness can be described using the Hessian matrix, which captures the second-order derivatives of the loss [10]. However, computing the Hessian during gradient descent in a large network incurs significant computational cost. To keep gradient descent efficient, machine learning typically relies on first-order optimizers, such as stochastic gradient descent (SGD), which, however, lack awareness of the flatness direction [11].
Reconstructing the flatness perception of SGD is a key challenge for improving optimization. The Sharpness-Aware Minimization (SAM) strategy [9] perturbs the gradient to approximate second-order curvature information [12]. Specifically, SAM minimizes the worst-case loss within a perturbation region of radius ρ around the current solution. However, due to the diverse shapes of the loss landscapes across different networks and datasets [13], selecting an appropriate ρ is challenging. A larger ρ may make loss basins imperceptible, since the loss landscape may contain no basin of the required radius, leading to divergence and higher training losses.
To address this, ASAM (Algorithm 1) [14] normalizes the solution space before setting ρ, aiming to stabilize optimization. GSAM [15] further normalizes gradients to mitigate uncertainty, allowing ρ to be set within a controlled range (typically between 0 and 1). However, manual tuning remains necessary. Inspired by the concept of linear mode connectivity [16], we propose an approach that models ρ dynamically using two points, A and B, serving as the endpoints of a linearly connected path in the solution space. This eliminates the need for manual tuning, making SAM-based optimization more adaptive and robust.
Algorithm 1 ASAM [14]
Input: Dataset S, loss function L, batch size b, step size η, neighborhood size ρ, initial model θ₀
Output: Model θ trained with ASAM
Start at θ = θ₀
repeat
    Sample batch B = {(x₁, y₁), …, (x_b, y_b)}
    Compute the gradient ∇L_B(θ) of the batch's training loss
    Compute ε̂(θ) = ρ T_θ² ∇L_B(θ) / ‖T_θ ∇L_B(θ)‖₂, where T_θ is the element-wise normalization operator
    Compute the gradient approximation for the SAM objective: g = ∇L_B(θ)|_{θ + ε̂(θ)}
    Update weights: θ = θ − η g
until converged
return θ
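The flattened pseudocode above can be sketched numerically. Below is a minimal, hedged NumPy illustration of repeated ASAM updates on a toy quadratic loss; the element-wise operator T_θ = |θ| follows the simplest normalization choice, while the toy loss, initial point, and step sizes are our own assumptions, not the paper's setup.

```python
import numpy as np

TARGET = np.array([1.0, 2.0])  # minimum of the toy quadratic loss

def loss(theta):
    # Toy quadratic standing in for the batch training loss L_B.
    return 0.5 * float(np.sum((theta - TARGET) ** 2))

def grad(theta):
    return theta - TARGET

def asam_step(theta, rho=0.05, eta=0.1):
    g = grad(theta)
    t = np.abs(theta)  # element-wise normalization operator T_theta
    eps = rho * (t ** 2 * g) / (np.linalg.norm(t * g) + 1e-12)  # adaptive perturbation
    g_sam = grad(theta + eps)  # gradient at the perturbed point
    return theta - eta * g_sam  # descent step on the ASAM objective

theta = np.array([4.0, -1.0])
for _ in range(200):
    theta = asam_step(theta)
print(np.round(theta, 2))  # ends close to the minimum (1, 2)
```

Because the perturbation is rescaled by |θ|, the effective radius adapts to the parameter magnitudes, which is the stabilizing effect ASAM aims for.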
In summary, our heuristic strategy gradually reduces the distance between two points in the solution space, guiding the search toward flatter regions that improve generalization. By incorporating connectivity awareness into sharpness-aware minimization, we enable the optimization process to target these flatter regions, making the search more effective. Linear mode connectivity improves our understanding of the relationships between minima, allowing the algorithm to converge to minima connected along linear modes, thereby enhancing the optimization. Experimental results confirm that our method achieves improved convergence and better generalization.
We propose a heuristic training strategy aimed at generating linearly connected mode endpoints [16]. The radius ρ of the SAM method is modeled using two points, A and B, in the solution space, avoiding the need to set the hyperparameter ρ manually. Specifically, we redefine the SAM open ball using A and B: the midpoint between A and B serves as the center, and the radius ρ is defined as half the distance between them (see Figure 1). In each training iteration, our heuristic algorithm optimizes the models corresponding to points A and B, while adding a distance regularization constraint that simulates the SAM optimization process. To ensure connectivity, we employ a subspace learning strategy [17]. Specifically, our algorithm initializes the optimization process at the midpoint of A and B in each iteration and applies subspace optimization techniques to maintain a linearly connected path [17,18]. The goal of our strategy is to gradually reduce the distance between A and B, guiding them to converge toward the same flat local minimum basin, or toward neighboring basins separated by low loss barriers, as illustrated in Figure 1. This process is akin to a binary star system: A and B behave like two stars that mutually influence each other's gravitational pull, reducing the distance between them over time.
Our contributions are summarized as follows:
We explore how to achieve linear mode connectivity in sharpness-aware optimization and propose that connectivity can be leveraged to dynamically adjust the radius ρ in SAM optimization, removing the need for manual tuning.
We developed a heuristic search algorithm that combines flatness optimization with linear connectivity. This algorithm simulates the gradual convergence of two points in the solution space, ensuring that when the search stops, the two solutions exhibit both linear connectivity and flatness within a geometric range that includes their neighbors.
We conducted comprehensive validation experiments, including testing across multiple network architectures and visualization studies, confirming the effectiveness of our approach.
3. Preliminary
3.1. Sharpness-Aware Minimization
The classification model operates within the joint probability distribution space 𝒳 × 𝒴, where 𝒳 represents the sample space and 𝒴 denotes the label space. The training set, denoted as S = {(x₁, y₁), …, (x_n, y_n)}, adheres to the distribution 𝒟 and consists of n samples (x_i, y_i) independently sampled from 𝒟. The parameters of the model, denoted by θ, are used to model the region around θ with an open ball B(θ, ρ), where ρ is the radius and θ is the center of the open ball B in Euclidean space, i.e., B(θ, ρ) = {θ′ : ‖θ′ − θ‖₂ < ρ}.
The training loss, L_S(θ) = (1/n) Σᵢ ℓ(θ; x_i, y_i), is the average of the individual losses over the training set. The generalization gap, which compares the expected loss L_𝒟(θ) to the training loss L_S(θ), measures the ability of the model to generalize to unseen data.
Sharpness-Aware Minimization (SAM) aims to minimize the following PAC-Bayesian generalization error upper bound:

L_𝒟(θ) ≤ max_{‖ε‖₂ ≤ ρ} L_S(θ + ε) + h(‖θ‖₂² / ρ²).

Taking the minimum over θ yields the empirical risk minimization objective

min_θ max_{‖ε‖₂ ≤ ρ} L_S(θ + ε) + h(‖θ‖₂² / ρ²),

where h is a strictly increasing function that can be replaced by a norm-based regularization term. Sharpness-Aware Minimization (SAM) [9] employs a one-step backward approach to obtain the gradient approximation

ε̂(θ) = ρ ∇L_S(θ) / ‖∇L_S(θ)‖₂.

Then, SAM computes the gradient with respect to the perturbed model θ + ε̂(θ) for the second-step update:

∇L_S(θ)|_{θ + ε̂(θ)}.
The detailed SAM is presented in Algorithm 2.
Algorithm 2 SAM [9]
Input: Dataset S, loss function L, batch size b, step size η, neighborhood size ρ, initial model θ₀
Output: Model θ trained with SAM
Start at θ = θ₀
repeat
    Sample batch B = {(x₁, y₁), …, (x_b, y_b)}
    Compute the gradient ∇L_B(θ) of the batch's training loss
    Compute ε̂(θ) = ρ ∇L_B(θ) / ‖∇L_B(θ)‖₂
    Compute the gradient approximation for the SAM objective: g = ∇L_B(θ)|_{θ + ε̂(θ)}
    Update weights: θ = θ − η g
until converged
return θ
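Algorithm 2's two gradient evaluations per update translate directly into code. A minimal sketch on a toy quadratic loss (the loss, starting point, and hyperparameters are illustrative assumptions, not the paper's configuration):

```python
import numpy as np

TARGET = np.array([1.0, 2.0])  # minimum of the toy quadratic loss

def grad(theta):
    # Gradient of the toy quadratic standing in for the batch loss L_B.
    return theta - TARGET

def sam_step(theta, rho=0.05, eta=0.1):
    g = grad(theta)                               # first pass: gradient at theta
    eps = rho * g / (np.linalg.norm(g) + 1e-12)   # one-step backward perturbation
    g_sam = grad(theta + eps)                     # second pass: gradient at theta + eps
    return theta - eta * g_sam                    # SAM weight update

theta = np.array([5.0, -3.0])
for _ in range(300):
    theta = sam_step(theta)
print(np.round(theta, 2))  # ends near the minimum (1, 2)
```

Note that each update costs two forward-backward passes, which is the overhead SAM accepts to perceive flatness.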
3.2. Analysis of Sharpness-Aware Minimization
An optimal perturbation radius ρ is crucial for Sharpness-Aware Minimization (SAM) convergence. The ideal ρ should balance the local curvature and the geometry of the flat region, but an explicit search is challenging because flat regions are not hyperspherical and vary across networks [13]. In non-convex optimization, this geometric uncertainty forces a reliance on trial and error for selecting ρ.
To address the challenge of hyperparameter selection, we propose a principled method for selecting ρ from the perspective of linear mode connectivity, enabling more efficient screening. Through observation, we identified that flat regions in the loss landscape inherently satisfy linear connectivity. Based on this, we present the following theory.
Theorem 1. The optimal ρ* is the maximal diameter of the linearly mode connected regions [16], defined as ρ* = max{‖θ_A − θ_B‖₂ : θ_A and θ_B are linearly mode connected within a flat region}.

Theorem 1 ensures optimization within flat regions, but non-convexity complicates direct implementation. We instead use backward reasoning, approximating ρ* via nonlinear mode connectivity.
Theorem 2. For any pair of parameters θ_A and θ_B: if θ_A and θ_B exhibit a loss barrier (i.e., they are not linearly mode connected), they cannot coexist in a ball of any radius ρ ≤ ρ*, including ρ* itself.
Although Theorem 2 cannot directly lead us to an appropriate solution, it can help us eliminate distracting solutions.
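Theorem 2 suggests a cheap screening test: evaluate the loss along the linear path between two candidate solutions and flag a barrier whenever the path rises noticeably above the endpoint losses. A minimal sketch on a toy double-well loss (the loss, tolerance, and grid size are illustrative assumptions):

```python
import numpy as np

def loss(theta):
    # Toy 1-D landscape with two basins (at -1 and +1) separated by a barrier at 0.
    return (theta ** 2 - 1.0) ** 2

def has_loss_barrier(theta_a, theta_b, loss_fn, tol=0.1, steps=51):
    """Return True if the linear path between theta_a and theta_b rises
    noticeably above the endpoint losses, i.e., the points are not linearly
    mode connected and, by Theorem 2, cannot share a SAM ball."""
    alphas = np.linspace(0.0, 1.0, steps)
    path = [loss_fn((1 - a) * theta_a + a * theta_b) for a in alphas]
    return max(path) > max(loss_fn(theta_a), loss_fn(theta_b)) + tol

print(has_loss_barrier(-1.0, 1.0, loss))  # True: the barrier at 0 separates the basins
print(has_loss_barrier(0.9, 1.1, loss))   # False: both points lie in the same basin
```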
Building on both insights, we propose an implicit ρ-search algorithm that operationalizes linear connectivity as a dynamic selection heuristic. First, we reformulate the optimization of SAM by modeling the endpoints of its hyperspherical perturbation through linear mode connectivity (Section 4). Subsequently, we introduce the Twin Stars Entwined (TSE) algorithm to co-optimize these paired models, steering the gradient search toward flatter regions (Section 5).
5. Proposed Algorithm
To better locate flatter local minimum basins, we combine SAM (Sharpness-Aware Minimization) with subspace learning; the key components are illustrated in Figure 2.
To provide an intuitive description of our motivation, we present the following facts:
- 1.
Ideal SAM guarantees connectivity between the convergent solution and the entire hyperball of radius ρ centered at it.
- 2.
A larger ρ attempts convergence to flatter regions. However, if ρ is too large, the algorithm may oscillate due to the absence of basins in the solution space that meet the current radius requirement.
- 3.
Based on Theorem 2, the connectivity of two points A and B in the solution space determines whether they lie within the hyperball of ideal SAM.
- 4.
If SAM optimization with radius ρ is applied to endpoints A and B separately, then under ideal conditions A and B will converge, and if the distance between A and B is at most 2ρ, A and B are linearly connected.
- 5.
When Fact 4 is satisfied, the flat regions guaranteed by A and B are adjacent. The combined flat region guaranteed by A and B is larger than either of them individually, and thus can be regarded as a flatter region.
- 6.
If the line connecting A and B is treated as the diameter of a virtual ball, and Fact 4 is satisfied, then the virtual ball formed by A and B may also conform to a larger ideal SAM hyperball with radius ‖A − B‖/2.
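Facts 4 and 6 rest on a simple geometric test: the ρ-balls around A and B intersect exactly when ‖A − B‖ ≤ 2ρ. A minimal sketch of that check (the vectors and ρ values are illustrative):

```python
import numpy as np

def balls_overlap(theta_a, theta_b, rho):
    # The rho-balls around A and B intersect iff the centers are within 2*rho.
    return np.linalg.norm(theta_a - theta_b) <= 2 * rho

a, b = np.array([0.0, 0.0]), np.array([3.0, 4.0])  # distance 5 (3-4-5 triangle)
print(balls_overlap(a, b, rho=3.0))  # True: 5 <= 6, the flat regions are adjacent
print(balls_overlap(a, b, rho=2.0))  # False: 5 > 4, the balls are disjoint
```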
Summarizing the above facts, we propose a new optimization approach to implicitly obtain a larger flat region. We define an optimization process in which the line connecting A and B is modeled as the diameter of a virtual ball with radius ρ = ‖θ_A − θ_B‖₂/2. Assuming A and B are optimized using ρ-SAM and exhibit nonlinear connectivity upon convergence, we use distance regularization to bring A and B closer and restart training until linear connectivity is achieved, at which point we obtain a larger flat region. We further model the process of A and B approaching each other as a process with T iterative steps, as detailed in Section 6. During optimization, the target radius is determined by the iterates θ_A^t and θ_B^t at step t, defined as

ρ_t = ‖θ_A^t − θ_B^t‖₂ / 2.

As optimization progresses, θ_A^t and θ_B^t move closer, until they enter the same or adjacent loss basins and satisfy the linear mode connectivity condition, i.e., the loss barrier along the segment between them vanishes.
Our algorithm has the potential to converge whenever ρ-SAM converges: if a manually set ρ makes the SAM algorithm converge, there exists at least one region in the solution space where both points A and B can converge simultaneously. The extreme case is when A and B collapse into a single point as they approach each other.
Our algorithm offers two key advantages. First, it eliminates the laborious manual tuning of ρ: instead of exhaustively searching for an optimal value, it only requires an initial ρ that ensures SAM convergence, then automatically expands the search for flatter regions. Second, it inherently avoids poor solutions by design: supported by the convergence analysis and Facts 5 and 6, the method stabilizes SAM while providing a provable lower-bound guarantee on solution quality.
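Under the simplifying assumptions that the per-endpoint optimizer is plain SAM and the distance regularizer is a quadratic penalty (the full method additionally uses subspace learning and a Fisher-weighted regularizer), one TSE-style loop can be sketched as:

```python
import numpy as np

TARGET = np.array([1.0, 2.0])  # minimum of a toy quadratic loss

def grad(theta):
    return theta - TARGET  # gradient of the toy loss

def sam_grad(theta, rho):
    # Two-pass SAM gradient with perturbation radius rho.
    g = grad(theta)
    eps = rho * g / (np.linalg.norm(g) + 1e-12)
    return grad(theta + eps)

def tse_iterate(a, b, steps=100, eta=0.1, lam=0.05):
    """Optimize endpoints A and B with rho_t-SAM while a quadratic distance
    regularizer pulls them together; rho_t is half their current distance."""
    for _ in range(steps):
        rho_t = np.linalg.norm(a - b) / 2.0      # dynamic radius from the endpoints
        ga = sam_grad(a, rho_t) + lam * (a - b)  # pull A toward B
        gb = sam_grad(b, rho_t) + lam * (b - a)  # pull B toward A
        a, b = a - eta * ga, b - eta * gb
    return a, b

a0, b0 = np.array([5.0, -2.0]), np.array([-3.0, 6.0])
a, b = tse_iterate(a0, b0)
print(np.linalg.norm(a - b) < np.linalg.norm(a0 - b0))  # True: the endpoints moved closer
```

The shrinking ρ_t mirrors the bounding strategy: the assumed flat region contracts until the two endpoints can actually share it.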
9. Discussion and Limitations
9.1. Failure Case Analysis
The main challenge with the algorithm arises from obtaining the model in the initial step. Influencing factors include incorrectly set hyperparameters, such as ρ and the learning rate. If the initial models A and B are difficult to converge, the iterative training process may produce jumps. To avoid excessive consumption of computational resources, the training process can be further analyzed for biases through a visual analysis at each step, as shown in Figure 10.
Visualizing the weights obtained at each step helps with error analysis. First, compare the weights obtained by the different optimizers and their performance metrics. In
Figure 10, the linear interpolations among SGD, SAM, and ASAM exhibit higher loss barriers, suggesting that these optimization methods converge to distinct solutions. Meanwhile, there was no significant difference in performance, suggesting that the current training hyperparameter settings were reliable. Second, analyzing the performance connectivity between the current and previous steps helps determine whether the search iteration has introduced any biases, such as checking whether the distance regularization term produced the expected effect. Third, a comprehensive analysis of the connectivity between the solutions at each step can help determine if the process is converging to a flat region. In a good optimization process, the linear interpolation loss barriers between the models at later steps show a decreasing trend, indicating that the search direction is likely becoming flatter. Otherwise, it may be necessary to return and adjust the training hyperparameters.
9.2. Consumption Analysis
Our algorithm only adds a regularization term to each training step, resulting in no significant increase in memory consumption apart from the overhead introduced by the regularization term, which requires storing the Fisher matrix of the network's learnable weights.
The computational time of our algorithm is difficult to estimate directly. Therefore, we provide a table listing the number of forward and backward propagations during training. Assuming that the computational cost of forward and backward propagation for one epoch of stochastic gradient descent (SGD) is taken as 1 unit, and one full dataset pass consists of n epochs, we define the maximum number of iterations of our algorithm as I. The table below shows the worst-case computational cost of our algorithm:
Table 6 describes the worst-case scenario, where the number of training epochs per iteration is the same as in standalone training. To accelerate the process, an early stopping strategy can be applied during iterations, together with a smaller number of epochs. The two additional passes added in each iteration are required to generate the Fisher matrix.
9.3. Stop Condition Analysis
The loss variance between two endpoints is a robust criterion for terminating iterations, ensuring consistent performance.
Table 7 shows the loss variance of linear interpolation models between the endpoint models at each step. As iterations increased, this variance decreased monotonically. Optimal results typically occurred near the minimum variance step, except for ResNet-50 on CIFAR-100. Due to shared hyperparameters, PyramidNet-110-270 required more steps to reduce variance. We proposed a threshold of 0.1 as a stopping criterion. Additionally, visualizing the loss curves of the interpolation models at the conclusion of each iteration step provides significant insights, as shown in
Figure 5 and
Figure 6. Such visualizations enable the assessment of endpoint model convergence and facilitate a comparative analysis of the performance of the two endpoints, thereby informing decisions on whether to proceed with further iterations.
Note that we prefer to set a threshold where iterations are terminated once the variance falls below this threshold, thereby avoiding excessive iterations. One of the motivations behind our algorithm was to model the endpoints in search of flatter regions. Excessive iteration steps may cause the endpoints to gradually converge and eventually collapse into a single point.
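The stopping rule above can be sketched directly: evaluate the losses of linear interpolation models between the endpoints and stop once their variance falls below the threshold (0.1 in our experiments). The toy loss and endpoint vectors here are illustrative:

```python
import numpy as np

def loss(theta):
    # Toy quadratic loss standing in for the validation loss of a model.
    return 0.5 * float(np.sum((theta - 1.0) ** 2))

def interpolation_variance(theta_a, theta_b, loss_fn, steps=11):
    """Variance of losses along the linear path between the endpoint models."""
    alphas = np.linspace(0.0, 1.0, steps)
    losses = [loss_fn((1 - a) * theta_a + a * theta_b) for a in alphas]
    return float(np.var(losses))

def should_stop(theta_a, theta_b, loss_fn, threshold=0.1):
    # Terminate the TSE iterations once the interpolation losses are stable.
    return interpolation_variance(theta_a, theta_b, loss_fn) < threshold

# Far-apart endpoints: large variance along the path, keep iterating.
print(should_stop(np.array([4.0, 4.0]), np.array([-2.0, -2.0]), loss))  # False
# Nearby endpoints in the same basin: small variance, stop.
print(should_stop(np.array([1.1, 1.0]), np.array([0.9, 1.0]), loss))    # True
```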
9.4. Meta and Heuristic Learning
Our approach primarily emphasizes regularization and fine-tuning, aligning more closely with greedy algorithms within the heuristic optimization paradigm. However, the discussion of heuristic methods and meta-learning extends beyond the core scope of this paper. Nevertheless, our algorithm can be regarded as a foundation for further applications of heuristic algorithms, which, in turn, have the potential to enhance our method.
First, from the perspective of meta-learning, our method (TSE) lacks a memory mechanism. Meta-learning focuses on learning to learn, where the optimization process may carry more generalizable information about the loss landscape geometry [
45,
46,
47]. In contrast, our TSE relies solely on the model state from the previous step for updates. A potential direction for improvement would be to incorporate multi-step historical information to enhance the optimization performance. Furthermore, integrating more comprehensive metrics could facilitate a more accurate assessment of whether the current search direction is appropriate.
Second, heuristic algorithms offer advantages in refining regularization terms and selecting search starting points. For instance, combinatorial optimization strategies such as evolutionary algorithms [
48] and genetic algorithms [
49] can optimize the selection of iteration starting points, while providing greater flexibility in adjusting the weights of regularization terms. The incorporation of these methods has the potential to further improve the efficiency and robustness of model optimization.
9.5. Future Directions
Our algorithm and its visualizations provide a more intuitive way to search for flat regions, offering an adaptive approach to flat region detection. The method follows a bounding strategy: it begins by assuming a sufficiently large flat region, and if the current loss landscape does not satisfy this assumption, the region is reduced and the search is repeated. Compared to the SAM cosine method, which sets a radius ρ, our approach is more purposeful: SAM lacks intuitive feedback to guide further hyperparameter tuning. As shown in Figure 4, the stop distances achieved for different network architectures indicate that SAM requires a specific ρ setting for each architecture, which involves extensive tuning and hyperparameter adjustment. Our method can be seen as a tool for intuitively understanding the flatness properties of the loss landscape in different network solution spaces.
Our method has led to two interesting observations, providing a foundation for further research. Assuming the results represent the flattest loss regions, future research may focus on the following:
- 1.
Model Architecture and Flat Region Size. As shown in
Figure 4, the same network architecture corresponds to different flattest regions depending on the dataset. Similarly, for the same dataset, different network architectures exhibit distinct flattest regions. A normalization framework could possibly unify this phenomenon. Furthermore, the size of the flattest region could guide the design of model architectures.
- 2.
Euclidean Distance vs. Cosine Distance. As the Euclidean distance between two points decreases, the learned parameter vectors of the two models tend to become more orthogonal. This suggests that, although the Euclidean distance diminishes, model diversity increases, an observation that benefits ensemble learning. When ensemble learning is used to obtain a set of models, orthogonal regularization is often applied without considering the Euclidean distance between the models. A potential issue arises: the generalization ability of an ensemble of orthogonally regularized models is related to whether these models reside in the same loss basin. Our TSE algorithm provides a controllable tool to support further research in this area.
- 3.
Combining heuristic algorithms and meta-learning. A meta-learning algorithm memorizes gradient trajectory information, which is beneficial for designing better regularization and optimization iteration steps. Heuristic algorithms provide more diverse and efficient combinatorial forms, facilitating the selection of starting points and anchor points for the optimization iteration algorithm.
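The two metrics in Observation 2 are cheap to compute on flattened weight vectors. A minimal sketch with made-up vectors, showing that a small Euclidean distance can coexist with orthogonal parameter directions:

```python
import numpy as np

def euclidean_distance(u, v):
    return float(np.linalg.norm(u - v))

def cosine_similarity(u, v):
    # A value near 0 means the parameter vectors are close to orthogonal.
    return float(np.dot(u, v) / (np.linalg.norm(u) * np.linalg.norm(v)))

# Two small, orthogonal weight vectors that are nevertheless close in Euclidean terms.
a = np.array([0.1, 0.0])
b = np.array([0.0, 0.1])
print(euclidean_distance(a, b))  # ~0.141: small Euclidean distance
print(cosine_similarity(a, b))   # 0.0: orthogonal directions
```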