TRIM: Hybrid Inference via Targeted Stepwise Routing in Multi-Step Reasoning Tasks
Vansh Kapoor · Aman Gupta · Hao Chen · Anurag Beniwal · Jing Huang · Aviral Kumar
Abstract
Multi-step reasoning tasks like mathematical problem solving are vulnerable to cascading failures, where a single incorrect step leads to complete solution breakdown. Current LLM routing methods assign entire queries to a single model, treating all reasoning steps as equally difficult. We propose TRIM (Targeted Routing in Multi-step reasoning tasks), which routes only critical steps to larger models while letting smaller models handle routine continuations. Our key insight is that targeted step-level interventions can fundamentally improve inference efficiency by confining expensive calls to precisely those steps where a stronger model prevents cascading errors. TRIM operates at step-level granularity, using process reward models to identify erroneous steps and making routing decisions based on step-level uncertainty and budget constraints. We develop four routing strategies: a simple thresholding policy, two RL-trained policies (one using full sequential features, the other using aggregated statistics), and a POMDP-based approach that handles uncertainty in step-level correctness estimates. On MATH-500, the thresholding policy already surpasses contemporary routing methods with $6.51\times$ higher cost efficiency, while the RL-trained and POMDP-based policies match the strong, expensive model's performance using $80$% fewer expensive-model tokens. All methods generalize effectively across mathematical reasoning datasets, demonstrating that step-level difficulty is a fundamental characteristic of multi-step reasoning.
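To make the step-level routing idea concrete, the sketch below illustrates only the simplest of the four strategies, the thresholding policy, under assumptions not taken from the paper: `small_step`, `large_step`, and `prm_score` are hypothetical callables standing in for a small-model step generator, a large-model step generator, and a process reward model that scores the latest step of a partial solution in [0, 1]. The threshold `tau` and the expensive-call budget are illustrative parameters, not values reported by the authors.

```python
from dataclasses import dataclass
from typing import Callable, List


@dataclass
class RoutingResult:
    steps: List[str]       # assembled reasoning steps
    expensive_calls: int   # number of steps routed to the large model


def threshold_route(
    problem: str,
    small_step: Callable[[str, List[str]], str],   # hypothetical small-model step generator
    large_step: Callable[[str, List[str]], str],   # hypothetical large-model step generator
    prm_score: Callable[[str, List[str]], float],  # hypothetical PRM score for the latest step
    tau: float = 0.5,      # PRM threshold below which a step is treated as likely erroneous
    budget: int = 4,       # maximum number of expensive (large-model) step calls
    max_steps: int = 20,
    stop_token: str = "<END>",
) -> RoutingResult:
    """Minimal sketch of a step-level thresholding router (not the authors' implementation)."""
    steps: List[str] = []
    expensive_calls = 0
    for _ in range(max_steps):
        # Default: let the small model propose the next step.
        step = small_step(problem, steps)
        # Route only critical steps: if the PRM flags the proposed step as
        # likely erroneous and budget remains, regenerate it with the large model.
        if prm_score(problem, steps + [step]) < tau and expensive_calls < budget:
            step = large_step(problem, steps)
            expensive_calls += 1
        steps.append(step)
        if stop_token in step:
            break
    return RoutingResult(steps=steps, expensive_calls=expensive_calls)
```

The RL-trained and POMDP-based policies described in the abstract would replace the fixed threshold test with a learned routing decision over step-level uncertainty and remaining budget; this sketch only shows the baseline policy's control flow.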