Invited Talk 2: Transformers learn in-context by implementing gradient descent
Suvrit Sra
2024 Invited Talk
in
Workshop: Bridging the Gap Between Practice and Theory in Deep Learning
Abstract
We study the theory of in-context learning, investigating how Transformers can implement learning algorithms in their forward pass. We show that a linear-attention Transformer naturally learns to implement gradient descent, which enables it to learn linear functions in-context. More generally, we show that a Transformer with non-linear attention can implement functional gradient descent with respect to an RKHS metric, which allows it to learn a broad class of non-linear functions in-context. The RKHS metric is determined by the choice of attention activation, and the optimal choice of attention activation depends in a natural way on the class of functions to be learned.
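As a rough illustration of the linear-attention claim, the following minimal NumPy sketch (not code from the talk; the toy data, the step size eta, and the exp activation are assumptions made here for illustration) checks that one unnormalized linear-attention read-out over the context coincides with the prediction of one gradient-descent step on the in-context least-squares loss started from zero, and that replacing the linear score with an activation sigma turns the same read-out into one step of functional gradient descent in the RKHS with kernel k(x, x') = sigma(<x, x'>).

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy in-context regression task (illustrative assumptions, not from the talk):
# N context pairs (x_i, y_i) with y_i = <w_true, x_i>, plus one query x_q.
d, N = 5, 20
w_true = rng.normal(size=d)
X = rng.normal(size=(N, d))      # context inputs, one per row
y = X @ w_true                   # context targets
x_q = rng.normal(size=d)         # query input
eta = 0.1                        # assumed step size

# One gradient-descent step on the in-context least-squares loss
#   L(w) = 1/(2N) * sum_i (<w, x_i> - y_i)^2,  starting from w_0 = 0:
#   w_1 = w_0 - eta * grad L(w_0) = (eta / N) * sum_i y_i * x_i
w_1 = (eta / N) * (X.T @ y)
gd_pred = w_1 @ x_q

# Unnormalized linear-attention read-out at the query token, with keys x_i,
# values y_i, and a fixed (eta / N) output scaling:
#   attn_pred = (eta / N) * sum_i <x_q, x_i> * y_i
attn_pred = (eta / N) * (X @ x_q) @ y
assert np.allclose(gd_pred, attn_pred)

# With a non-linear activation sigma applied to the attention scores, the same
# read-out equals one step of functional gradient descent (from f_0 = 0) in the
# RKHS whose kernel is k(x, x') = sigma(<x, x'>); sigma = exp is an assumption.
sigma = np.exp
f_1 = lambda x: (eta / N) * sum(y_i * sigma(x @ x_i) for x_i, y_i in zip(X, y))
kernel_attn_pred = (eta / N) * sigma(X @ x_q) @ y
assert np.allclose(f_1(x_q), kernel_attn_pred)
```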
Video