STAIRS-Former: Spatio-Temporal Attention with Interleaved Recursive Structure TransFormer for Offline Multi-task Multi-agent Reinforcement Learning
Abstract
Offline multi-agent reinforcement learning (MARL) with multi-task (MT) datasets poses unique challenges, as input structures vary across tasks with the number of agents. Prior works have adopted transformers and hierarchical skill learning to facilitate coordination, but these methods underutilize the transformer’s attention mechanism, focusing instead on extracting transferable skills. Moreover, existing transformer-based approaches compress the entire history into a single token and feed that token back in at the next time step, reducing history processing to a simple recurrent neural network (RNN) over history tokens. As a result, such models rely primarily on current and near-past observations while neglecting long-range historical information, even though the partially observable nature of MARL makes history critical. In this paper, we propose STAIRS-Former, a transformer architecture augmented with spatial and temporal hierarchies that enables the model to attend to critical tokens while effectively leveraging long histories. To further enhance robustness across varying token counts, we incorporate token dropout, which improves generalization to diverse agent populations. Experiments on the StarCraft Multi-Agent Challenge (SMAC) benchmark with diverse multi-task datasets show that STAIRS-Former consistently outperforms prior algorithms, achieving new state-of-the-art performance.
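To make the token dropout idea concrete, the following is a minimal illustrative sketch (not the paper's implementation); it assumes per-agent observation embeddings stacked along a token dimension, and all names and shapes are hypothetical:

```python
import torch


def token_dropout(tokens: torch.Tensor, drop_prob: float = 0.1, training: bool = True) -> torch.Tensor:
    """Randomly drop whole tokens (e.g., per-agent observation embeddings).

    tokens: tensor of shape (batch, num_tokens, dim).
    Zeroing entire tokens during training exposes the attention layers to
    effectively varying token counts, which is one plausible way to improve
    robustness when the number of agents differs across tasks.
    """
    if not training or drop_prob == 0.0:
        return tokens
    # Bernoulli keep-mask per token, broadcast across the feature dimension.
    keep = (torch.rand(tokens.shape[:2], device=tokens.device) > drop_prob).float()
    return tokens * keep.unsqueeze(-1)
```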