Gradient-based optimization methods have proven successful in learning complex, overparameterized neural networks from non-convex objectives. Yet the precise theoretical relationship between the optimization method, the induced training dynamics, and generalization in deep neural networks remains unclear. In this work, we investigate the training dynamics of overparameterized neural networks under natural gradient descent. Taking a function-space view of the training dynamics, we give an exact analytic solution to the dynamics on the training points. We derive a bound on the discrepancy between the distributions over functions at the global optimum of natural gradient descent and the analytic solution to the natural gradient descent training dynamics linearized around the parameters at initialization, and we validate our theoretical results empirically. In particular, we show that the discrepancy between the functions obtained from linearized and non-linearized natural gradient descent is provably smaller than under standard gradient descent, and we demonstrate empirically that the discrepancy is small for overparameterized neural networks without appealing to a limit argument about the width of the network layers, as was done in previous work. Finally, we show on a set of regression benchmark problems that the empirical discrepancy between linearized and non-linearized natural gradient descent is consistent with our theoretical results and remains small.
Tim G. J. Rudner, Florian Wenzel, Yee Whye Teh, Yarin Gal
Contributed talk, Workshop on Bayesian Deep Learning, NeurIPS 2019
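The comparison at the heart of the abstract can be sketched numerically. The following is a minimal NumPy sketch, not the paper's actual experimental setup: a hypothetical one-hidden-layer over-parameterized network trained on a few regression points with damped natural gradient descent (Gauss-Newton curvature), once with the Jacobian recomputed every step (non-linearized) and once with the Jacobian frozen at initialization (linearized dynamics). All names, sizes, and hyperparameters here are illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical setup: tiny over-parameterized one-hidden-layer tanh network,
# 3 * H = 768 parameters, n = 5 training points.
H, n = 256, 5
x = np.linspace(-1.0, 1.0, n)
y = np.sin(2.0 * x)

W = rng.normal(size=H)               # input-to-hidden weights
b = rng.normal(size=H)               # hidden biases
a = rng.normal(size=H) / np.sqrt(H)  # output weights (NTK-style scaling)


def forward_and_jacobian(W, b, a):
    """Network outputs on the training points and the (n, P) Jacobian df/dtheta."""
    z = np.outer(x, W) + b           # (n, H) pre-activations
    h = np.tanh(z)
    f = h @ a
    dz = 1.0 - h**2                  # tanh'
    J = np.concatenate([dz * a * x[:, None],  # df/dW
                        dz * a,               # df/db
                        h], axis=1)           # df/da
    return f, J


def ngd_step(theta, J, r, lr=0.5, damping=1e-3):
    """Damped natural-gradient (Gauss-Newton) step, using the identity
    (J^T J + damp I)^{-1} J^T r = J^T (J J^T + damp I)^{-1} r
    so only an (n, n) system is solved."""
    return theta - lr * J.T @ np.linalg.solve(J @ J.T + damping * np.eye(len(r)), r)


theta0 = np.concatenate([W, b, a])
f0, J0 = forward_and_jacobian(W, b, a)

# Non-linearized NGD: recompute the Jacobian at every step.
theta = theta0.copy()
for _ in range(20):
    Wc, bc, ac = np.split(theta, 3)
    f, J = forward_and_jacobian(Wc, bc, ac)
    theta = ngd_step(theta, J, f - y)
Wc, bc, ac = np.split(theta, 3)
f_full, _ = forward_and_jacobian(Wc, bc, ac)

# Linearized NGD: freeze the Jacobian at initialization and train the
# first-order model f_lin(theta) = f0 + J0 (theta - theta0).
theta_lin = theta0.copy()
for _ in range(20):
    f_lin = f0 + J0 @ (theta_lin - theta0)
    theta_lin = ngd_step(theta_lin, J0, f_lin - y)
f_lin = f0 + J0 @ (theta_lin - theta0)

# Discrepancy between linearized and non-linearized solutions on training points.
discrepancy = np.max(np.abs(f_full - f_lin))
```

Because a natural-gradient step moves the function outputs directly toward the targets (the residual contracts at a rate set by the learning rate, nearly independently of the kernel), both runs fit the training data and, for a sufficiently over-parameterized network, the parameters barely move, so `discrepancy` comes out small at finite width.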