MASAM: Multimodal Adaptive Sharpness-Aware Minimization for Heterogeneous Data Fusion
Abstract
Multimodal learning requires integrating heterogeneous modalities such as structured records, visual imagery, and temporal signals. Prior work has shown that this heterogeneity causes modality encoders to converge at different rates, resulting in imbalanced multimodal learning. We empirically observe that this imbalance is related to the sharpness of the learned solution: encoders that converge faster can be dragged into sharp regions of the loss landscape by inter-modal interference, degrading the generalization of the unimodal features they learn. Sharpness-Aware Minimization (SAM) improves generalization by seeking solutions in flat regions. However, applying it in multimodal scenarios is challenging: 1) SAM pays excessive attention to the dominant modality, exacerbating modality imbalance, and 2) the perturbation gradient of each modality is affected by interference from the others. To address these issues, we propose Multimodal Adaptive Sharpness-Aware Minimization (MASAM), which optimizes each modality according to its dominance. We design an Adaptive Perturbation Score (APS), based on convergence speed and gradient alignment, to identify the dominant modalities to which SAM should be applied. Our Modality-Decoupled Perturbation Scaling (MDPS) then reduces inter-modal interference during optimization, better aligning each modality with shared information. Extensive empirical evaluations on five multimodal datasets and six downstream tasks demonstrate that MASAM consistently attains flatter solutions, achieves balanced multimodal learning, and thereby surpasses state-of-the-art methods.
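For intuition, the sketch below shows a SAM-style two-step update with per-modality perturbation radii. It is a minimal illustration, assuming PyTorch, toy encoders, and a crude gradient-norm proxy for dominance in place of the paper's actual APS and MDPS definitions; all names and shapes here are hypothetical.

import torch
import torch.nn as nn

torch.manual_seed(0)

# Two toy unimodal encoders and a fusion head (hypothetical shapes/names).
enc_a = nn.Linear(8, 4)    # e.g., structured-record branch
enc_b = nn.Linear(16, 4)   # e.g., visual-feature branch
head = nn.Linear(8, 2)
modal_params = {"a": list(enc_a.parameters()), "b": list(enc_b.parameters())}
opt = torch.optim.SGD(
    [*enc_a.parameters(), *enc_b.parameters(), *head.parameters()], lr=1e-2
)

def forward_loss(xa, xb, y):
    fused = torch.cat([enc_a(xa), enc_b(xb)], dim=-1)
    return nn.functional.cross_entropy(head(fused), y)

xa, xb = torch.randn(32, 8), torch.randn(32, 16)  # synthetic batch
y = torch.randint(0, 2, (32,))

# First pass: gradients at the current weights.
forward_loss(xa, xb, y).backward()

# Crude stand-in for APS: treat the modality with the larger gradient norm
# as dominant and perturb only it. The real score also uses convergence
# speed and gradient alignment (see the paper).
norms = {
    m: float(torch.sqrt(sum(p.grad.norm() ** 2 for p in ps)))
    for m, ps in modal_params.items()
}
rho = {m: 0.05 if norms[m] == max(norms.values()) else 0.0 for m in norms}

# Ascent step: move each perturbed modality by rho * g / ||g||.
eps = {}
with torch.no_grad():
    for m, ps in modal_params.items():
        scale = rho[m] / (norms[m] + 1e-12)
        eps[m] = [p.grad * scale for p in ps]
        for p, e in zip(ps, eps[m]):
            p.add_(e)

# Second pass: gradients at the perturbed point, then undo the perturbation
# and take the descent step with those gradients.
opt.zero_grad()
forward_loss(xa, xb, y).backward()
with torch.no_grad():
    for m, ps in modal_params.items():
        for p, e in zip(ps, eps[m]):
            p.sub_(e)
opt.step()
opt.zero_grad()

Computing separate radii per modality is what lets the perturbation be switched off for lagging modalities, which is the intuition behind avoiding SAM's excessive focus on the dominant one; how MASAM actually scores dominance and decouples the perturbation gradients is specified by APS and MDPS in the paper.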