Memorization Through the Lens of Sample Gradients
Deepak Ravikumar · Efstathia Soufleri · Abolfazl Hashemi · Kaushik Roy
Abstract
Deep neural networks are known to often memorize underrepresented, hard examples, with implications for generalization and privacy. Feldman & Zhang (2020) defined a rigorous notion of memorization. However, it is prohibitively expensive to compute at scale, since calculating the memorization score requires training models both with and without the data point of interest. We observe that less-memorized samples tend to be learned earlier in training, whereas highly memorized samples are learned later. Motivated by this observation, we introduce Cumulative Sample Gradient (CSG), a computationally efficient proxy for memorization. CSG is the gradient of the loss with respect to the input sample, accumulated over the course of training. The advantage of using input gradients is that per-sample gradients can be obtained with negligible overhead during training. The accumulation over training also reduces per-epoch variance and enables a formal link to memorization. Theoretically, we show that CSG is bounded by memorization and by learning time. Tracking these gradients during training reveals a characteristic rise–peak–decline trajectory whose timing is mirrored by the model’s weight norm. This yields an early-stopping criterion that requires no validation set: stop at the peak of the weight norm. This early stopping makes CSG up to five orders of magnitude more efficient to compute than the memorization score of Feldman & Zhang (2020), and approximately 140$\times$ and 10$\times$ faster than the prior state-of-the-art memorization proxies, input curvature and cumulative sample loss, while remaining highly correlated with the memorization score. In addition, we develop Sample Gradient Assisted Loss (SGAL), a proxy that further improves alignment with memorization while remaining highly efficient to compute. Finally, we show that CSG attains state-of-the-art performance on practical dataset diagnostics such as mislabeled-sample detection and enables bias discovery, providing a theoretically grounded toolbox for studying memorization in deep networks.
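To make the CSG idea concrete, the sketch below shows one way to accumulate per-sample input-gradient norms during an ordinary training loop. This is a minimal illustrative PyTorch sketch, not the authors' implementation; names such as `train_with_csg`, `model`, `train_loader`, and the assumption that the loader yields sample indices are hypothetical.

```python
import torch
import torch.nn.functional as F

# Illustrative sketch (not the authors' code): accumulate per-sample
# input-gradient norms across training as a CSG-style proxy.
# Assumes `train_loader` yields (inputs, targets, idx), where `idx` gives
# each sample's position in the dataset.

def train_with_csg(model, train_loader, optimizer, num_epochs, num_samples, device="cpu"):
    csg = torch.zeros(num_samples)  # cumulative input-gradient norm per sample
    model.to(device)
    for epoch in range(num_epochs):
        model.train()
        for inputs, targets, idx in train_loader:
            inputs = inputs.to(device).requires_grad_(True)
            targets = targets.to(device)

            optimizer.zero_grad()
            loss = F.cross_entropy(model(inputs), targets)
            loss.backward()  # populates weight gradients and inputs.grad in one pass

            # Per-sample input-gradient norm, obtained at negligible extra cost
            # since inputs.grad is already computed by the backward pass.
            per_sample = inputs.grad.detach().flatten(1).norm(dim=1).cpu()
            csg[idx] += per_sample  # accumulate over the course of training

            optimizer.step()
    return csg  # higher accumulated values are expected for more memorized samples
```

Because the input gradients fall out of the same backward pass used to update the weights, the accumulation adds only a norm computation and an indexed add per batch; the paper's validation-free early-stopping rule (stop at the peak of the weight norm) would bound how many epochs contribute to the sum.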