Align-SAM: Seeking Flatter Minima for Better Cross-Subset Alignment
Abstract
Sharpness-Aware Minimization (SAM) has proven effective for training deep neural networks by simultaneously minimizing the training loss and the sharpness of the loss landscape, thereby guiding models toward flatter minima that are empirically linked to improved generalization. From another perspective, generalization can be seen as a model's ability to remain stable under distributional variability. In particular, effective learning requires that updates derived from different subsets or resamplings of the same data distribution remain consistent. In this work, we investigate the connection between the flatness induced by SAM and the alignment of gradients across random subsets of the data distribution, and propose \textit{Align-SAM}, a novel strategy to further enhance model generalization. Align-SAM extends the core principle of SAM by seeking flatter minima on a primary subset (the training set) while simultaneously enforcing low loss on an auxiliary subset drawn from the same distribution. This dual-objective approach yields solutions that are resilient not only to local parameter perturbations but also to the distributional variation between subsets sampled at each training iteration. Empirical evaluations demonstrate that Align-SAM consistently improves generalization across diverse datasets and challenging settings, including scenarios with noisy labels and limited training data.
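To make the dual objective concrete, below is a minimal PyTorch-style sketch of one possible Align-SAM training step. The perturbation radius `rho`, the auxiliary-loss weight `lam`, and the choice to evaluate the auxiliary loss at the perturbed weights are illustrative assumptions for this sketch, not details specified in the abstract.

```python
import torch

def align_sam_step(model, loss_fn, primary_batch, aux_batch,
                   base_optimizer, rho=0.05, lam=1.0):
    """One Align-SAM-style update (a sketch, not the paper's exact method):
    a SAM ascent/descent pass on the primary batch, plus a plain loss term
    on an auxiliary batch drawn from the same distribution."""
    xp, yp = primary_batch
    xa, ya = aux_batch

    # --- SAM ascent step: gradient of the primary loss at the current weights w ---
    loss_p = loss_fn(model(xp), yp)
    loss_p.backward()
    grad_norm = torch.norm(
        torch.stack([p.grad.norm(p=2)
                     for p in model.parameters() if p.grad is not None]),
        p=2,
    )
    eps = []
    with torch.no_grad():
        for p in model.parameters():
            if p.grad is None:
                eps.append(None)
                continue
            e = rho * p.grad / (grad_norm + 1e-12)
            p.add_(e)          # move to the adversarial point w + eps
            eps.append(e)
    model.zero_grad()

    # --- descent gradient: perturbed primary loss plus the auxiliary loss ---
    # (evaluating the auxiliary loss at w + eps is an assumption of this sketch)
    loss = loss_fn(model(xp), yp) + lam * loss_fn(model(xa), ya)
    loss.backward()

    # restore the original weights before applying the update at w
    with torch.no_grad():
        for p, e in zip(model.parameters(), eps):
            if e is not None:
                p.sub_(e)
    base_optimizer.step()
    model.zero_grad()
    return loss.item()
```

In this sketch the auxiliary batch acts as a proxy for a resampling of the same distribution: the descent gradient must simultaneously reduce the sharpness-penalized primary loss and the auxiliary loss, which is one way to encourage the cross-subset consistency the abstract describes.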