Distributionally Robust Neural Networks
The results suggest that regularization is critical for worst-group performance in the overparameterized regime, even if it is not needed for average performance. The paper also introduces a stochastic optimizer for this group DRO setting and provides convergence guarantees for it.
Abstract
Overparameterized neural networks trained to minimize average loss can be highly accurate on average on an i.i.d. test set, yet consistently fail on atypical groups of the data (e.g., by learning spurious correlations that do not hold at test time). Distributionally robust optimization (DRO) provides an approach for learning models that instead minimize worst-case training loss over a set of pre-defined groups. We find, however, that naively applying DRO to overparameterized neural networks fails: these models can perfectly fit the training data, and any model with vanishing average training loss already has vanishing worst-case training loss. Instead, the poor worst-case performance of these models arises from poor generalization on some groups. As a solution, we show that increased regularization (e.g., stronger-than-typical weight decay or early stopping) allows DRO models to achieve substantially higher worst-group accuracies, with improvements of 10 to 40 percentage points over standard models on a natural language inference task and two image tasks, while maintaining high average accuracies. Our results suggest that regularization is critical for worst-group performance in the overparameterized regime, even if it is not needed for average performance. Finally, we introduce and provide convergence guarantees for a stochastic optimizer for this group DRO setting, underpinning the empirical study above.
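To make the distinction between average loss and worst-case group loss concrete, the following is a minimal sketch, not the paper's implementation. The function names, the group labels, and the per-example loss values are all hypothetical and chosen only for illustration. The sketch computes both quantities from per-example losses and checks the observation that, when every training loss is zero, the worst-group loss is zero as well.

```python
from collections import defaultdict


def group_losses(losses, groups):
    """Return the mean loss of each group.

    `losses` and `groups` are parallel sequences: groups[i] is the
    pre-defined group label of example i (hypothetical labels here).
    """
    totals = defaultdict(float)
    counts = defaultdict(int)
    for loss, group in zip(losses, groups):
        totals[group] += loss
        counts[group] += 1
    return {group: totals[group] / counts[group] for group in totals}


def average_loss(losses):
    """Standard objective: mean loss over all examples."""
    return sum(losses) / len(losses)


def worst_group_loss(losses, groups):
    """Group DRO objective: the largest per-group mean loss."""
    return max(group_losses(losses, groups).values())


# Illustrative (made-up) per-example losses: the large group is fit well,
# the small group is not, so the average hides the failure.
losses = [0.25] * 8 + [2.0] * 2
groups = ["majority"] * 8 + ["minority"] * 2

avg = average_loss(losses)
worst = worst_group_loss(losses, groups)
print(f"average loss: {avg}, worst-group loss: {worst}")

# If the model fits the training data perfectly, both objectives vanish,
# so the worst-group training loss no longer distinguishes models.
perfect_fit = [0.0] * 10
print(worst_group_loss(perfect_fit, groups))
```

In this toy setting the average loss is dominated by the larger group, while the worst-group loss reports the poorly fit group directly. Once training losses reach zero the two objectives coincide, which is why the abstract attributes the remaining gap to generalization rather than to the training objective.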