

Nebius

Expo Talk Panel

Kvax: Fast and easy-to-use Flash Attention implementation for JAX

Sergei Skvortsov


Abstract:

Kvax is a custom FlashAttention implementation for JAX, optimised for long-context training with efficient document mask computation and context parallelism. This talk covers the key ideas behind its implementation, in particular how it makes document masks fast to compute and how it supports context parallelism.
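Kvax's own masking code is not reproduced here. As a rough illustration of why document masks can be cheap to evaluate, the following is a minimal JAX sketch. It assumes packed documents identified by a non-decreasing array of segment ids and square tiles of equal size. It derives a per-tile mask from each block's minimum and maximum segment id, so fully masked tiles can be skipped without materialising the dense n × n mask. The function names `dense_document_mask` and `block_document_mask` are hypothetical and are not part of Kvax's API. Kvax's actual approach may differ.

```python
import jax.numpy as jnp

# Hypothetical helpers for illustration only; not Kvax's API.


def dense_document_mask(segment_ids):
    """Reference dense causal document mask: query i may attend key j
    only if both tokens belong to the same document and j <= i.
    Uses O(n^2) memory, so it is only useful as a correctness check."""
    n = segment_ids.shape[0]
    same_doc = segment_ids[:, None] == segment_ids[None, :]
    causal = jnp.tril(jnp.ones((n, n), dtype=bool))
    return same_doc & causal


def block_document_mask(segment_ids, block_size):
    """Per-tile mask for blockwise attention, built from per-block
    segment-id ranges instead of the dense n x n mask.

    Assumes documents are packed contiguously (segment_ids is
    non-decreasing) and len(segment_ids) is a multiple of block_size.
    Returns a (num_blocks, num_blocks) boolean array indexed as
    [query_block, key_block]: True means the tile contains unmasked
    entries and must be computed; False means it can be skipped.
    """
    blocks = segment_ids.reshape(-1, block_size)
    seg_min = blocks.min(axis=1)  # first document id in each block
    seg_max = blocks.max(axis=1)  # last document id in each block
    num_blocks = blocks.shape[0]
    starts = jnp.arange(num_blocks) * block_size  # first position per block
    ends = starts + block_size - 1  # last position per block

    # Query block q (rows) and key block k (columns) can only share a
    # document if their segment-id ranges overlap.
    overlap = (seg_min[None, :] <= seg_max[:, None]) & (
        seg_max[None, :] >= seg_min[:, None]
    )
    # Causality: key block k must start no later than query block q ends.
    causal = starts[None, :] <= ends[:, None]
    return overlap & causal
```

For two four-token documents, `segment_ids = jnp.array([0, 0, 0, 0, 1, 1, 1, 1])` with `block_size=2` yields the tile pattern `[[T, F, F, F], [T, T, F, F], [F, F, T, F], [F, F, T, T]]`, so only 6 of the 16 tiles need computing. With equal square tiles and sorted segment ids this per-tile test agrees exactly with reducing the dense mask over each tile, but it needs only O(n / block_size) segment statistics plus the small tile grid.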
