Research Scientist at Google Brain
Research interests: machine learning; probability theory; random matrices; high-dimensional statistics; evolutionary dynamics
I’m a Research Scientist at Google Brain working to understand deep learning and apply it to basic-science problems. I joined Google in 2018 as an AI Resident, and before that I was a PhD student in applied math at Harvard, where I used techniques from probability theory and stochastic processes to study evolutionary dynamics and random matrices.
Deep learning theory
Beginning with computer vision and speech recognition, deep learning has become the preeminent approach to most machine learning tasks. This success has prompted theoretical work that aims to identify which properties of neural networks are responsible for it and to explain the many perplexing empirical results. Deep learning also appears to be part of a long-standing trend of jointly scaling up models and datasets.
My research explores this joint, high-dimensional limit. Specifically, I use techniques from random matrix theory to calculate exactly the generalization properties of random feature methods. Random features share some properties with neural networks but are statistically much simpler. This makes them a good null model for neural networks: it lets me separate the phenomena that are unique to neural networks from those that are merely a consequence of high-dimensional statistics. So far, my work has uncovered more similarities than differences: random features benefit from overparameterization, exhibit multiple descents, share similar sources of variance, and behave similarly under distribution shift.
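Some of these phenomena can be seen in a short simulation. The sketch below is an empirical illustration, not the exact random-matrix calculation: it fits only the second layer of a random ReLU feature model by minimum-norm least squares and reports test error as the number of features N varies around the number of training points n = 100. The linear teacher, noise level, dimensions, seeds, and the function name `random_features_test_error` are illustrative assumptions. It turns on 64-bit floats globally so that the pseudoinverse stays accurate near the interpolation threshold N = n.

```python
import jax
import jax.numpy as jnp

# float64 keeps the pseudoinverse accurate when the feature matrix is nearly square.
jax.config.update("jax_enable_x64", True)


def random_features_test_error(key, num_features, n_train=100, n_test=1000,
                               dim=20, noise=0.5):
    """Test MSE of min-norm least squares on random ReLU features.

    Illustrative assumptions: a linear teacher y = x . beta / sqrt(dim) + noise,
    and features phi(x) = relu(W x / sqrt(dim)) with a fixed Gaussian W.
    """
    k_beta, k_w, k_x, k_xt, k_eps = jax.random.split(key, 5)
    beta = jax.random.normal(k_beta, (dim,))
    w = jax.random.normal(k_w, (num_features, dim))        # frozen first layer
    x = jax.random.normal(k_x, (n_train, dim))
    x_test = jax.random.normal(k_xt, (n_test, dim))
    y = x @ beta / jnp.sqrt(dim) + noise * jax.random.normal(k_eps, (n_train,))

    def featurize(z):
        return jax.nn.relu(z @ w.T / jnp.sqrt(dim))

    # Ridgeless fit: the pseudoinverse gives the minimum-norm least-squares solution.
    a = jnp.linalg.pinv(featurize(x)) @ y
    pred = featurize(x_test) @ a
    # Error is measured against the noiseless teacher on fresh inputs.
    return jnp.mean((pred - x_test @ beta / jnp.sqrt(dim)) ** 2)


widths = [10, 50, 100, 400, 2000]
keys = jax.random.split(jax.random.PRNGKey(0), 5)          # average over 5 draws
errors = {
    n: float(jnp.mean(jnp.stack([random_features_test_error(k, n) for k in keys])))
    for n in widths
}
for n, e in errors.items():
    print(f"N={n:5d}  test MSE={e:.3f}")
```

With these settings one would expect the error to spike near N = n (the double-descent peak, where the near-square feature matrix amplifies label noise) and to fall again as N grows; the exact values depend on the seeds.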
Applications in basic science
The rush to deploy deep learning has brought with it a host of new hardware and software. I am interested in how these tools might be repurposed to improve and change computing in basic science. These domains often demand more of their models than low risk alone (for example, robustness to distribution shift or calibration), which calls for a more theoretically grounded approach that relies less on good intuition and trial and error.
For example, I wrote a simulator for the SIR model in structured populations using JAX. Because the simulator runs efficiently on GPUs and TPUs, it enabled significantly faster and larger-scale experiments. Moreover, automatic differentiation makes it possible to select model hyperparameters by first-order optimization so that the model serves a particular goal. I am interested in extending this approach to other Markov chain models from biology and statistical physics.
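A minimal sketch of the idea, and not the simulator described above: it assumes a deterministic (mean-field) SIR model on two equal-sized groups that mix through a hypothetical contact matrix. The time loop is written with `jax.lax.scan`, so it compiles into a single XLA program that can run on CPU, GPU, or TPU, and `jax.value_and_grad` differentiates through the whole trajectory to tune the transmission rate `beta` toward a target attack rate. The contact matrix, rates, step size, and target are made up for illustration.

```python
import jax
import jax.numpy as jnp

# Hypothetical two-group contact matrix: CONTACT[k, l] is the rate at which
# members of group k meet members of group l (groups assumed equal in size).
CONTACT = jnp.array([[0.8, 0.2],
                     [0.2, 0.5]])
I0 = jnp.array([0.01, 0.0])  # initial infected fraction in each group


def simulate_sir(beta, gamma, contact, i0, num_steps, dt=0.1):
    """Forward-Euler integration of a deterministic metapopulation SIR model.

    Tracks susceptible, infected and recovered fractions (s, i, r) per group;
    the three fractions of each group sum to one.
    """
    state0 = (1.0 - i0, i0, jnp.zeros_like(i0))

    def step(state, _):
        s, i, r = state
        force = beta * (contact @ i)   # force of infection felt by each group
        new_inf = dt * force * s       # S -> I flow during this step
        new_rec = dt * gamma * i       # I -> R flow during this step
        return (s - new_inf, i + new_inf - new_rec, r + new_rec), i.mean()

    # lax.scan turns the Python-level loop into one compiled XLA loop.
    final_state, prevalence = jax.lax.scan(step, state0, None, length=num_steps)
    return final_state, prevalence


def attack_rate(beta, gamma=0.2, num_steps=2000):
    """Fraction of the population ever infected by the end of the run."""
    (s, i, r), _ = simulate_sir(beta, gamma, CONTACT, I0, num_steps)
    return jnp.mean(i + r)


def loss(beta, target=0.5):
    return (attack_rate(beta) - target) ** 2


# Autodiff through the entire simulation gives d(loss)/d(beta).
loss_and_grad = jax.jit(jax.value_and_grad(loss))

beta = 0.5
initial_loss, _ = loss_and_grad(beta)
for _ in range(200):
    _, g = loss_and_grad(beta)
    beta = beta - 0.02 * g             # plain gradient descent on beta
final_loss, _ = loss_and_grad(beta)
print(f"beta={float(beta):.3f}  loss {float(initial_loss):.4f} -> {float(final_loss):.4f}")
```

The mean-field form is used here because every step is a smooth function of `beta`; a simulator that samples discrete random events would need additional gradient-estimation machinery to be tuned the same way.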
As another example, theoretical work on infinite-width neural networks has uncovered a new class of high-performance kernels for Gaussian processes. Gaussian process inference is already known to be challenging for large datasets, and these kernels add the difficulty that they are themselves expensive to compute. Using parallelization and JAX, I have built infrastructure to use these Gaussian processes on much larger datasets. Next steps are to benchmark these models on large, highly structured data, such as molecular or sequence data.
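To give a sense of what such a kernel involves, the sketch below computes one example, the NNGP kernel of an infinitely wide, fully connected ReLU network, using the arc-cosine recursion. It assembles the kernel matrix in row blocks and uses it for exact GP regression via a Cholesky solve. The depth, weight and bias variances, noise level, block size, and helper names are illustrative assumptions, and the blocks run sequentially. This is not the infrastructure described above, only an indication of where the cost lies (one kernel recursion per layer per entry, then an O(n^3) solve) and where work could be split across devices.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import cho_solve


def nngp_relu_kernel(x1, x2, depth=3, w_var=2.0, b_var=0.1):
    """NNGP kernel of an infinitely wide fully connected ReLU network.

    Each hidden layer maps the pre-activation covariance to
    E[relu(u) relu(v)] for jointly Gaussian (u, v) (arc-cosine recursion).
    """
    d = x1.shape[-1]
    k12 = b_var + w_var * (x1 @ x2.T) / d               # first-layer covariance
    k11 = b_var + w_var * jnp.sum(x1**2, axis=-1) / d   # variances of x1 rows
    k22 = b_var + w_var * jnp.sum(x2**2, axis=-1) / d   # variances of x2 rows
    for _ in range(depth):
        norm = jnp.sqrt(k11[:, None] * k22[None, :])
        cos_t = jnp.clip(k12 / norm, -1.0, 1.0)
        theta = jnp.arccos(cos_t)
        k12 = b_var + w_var / (2 * jnp.pi) * norm * (
            jnp.sin(theta) + (jnp.pi - theta) * cos_t)
        # On the diagonal theta = 0, so the update reduces to k -> b + w/2 * k.
        k11 = b_var + w_var / 2 * k11
        k22 = b_var + w_var / 2 * k22
    return k12


kernel_fn = jax.jit(nngp_relu_kernel, static_argnames="depth")


def kernel_in_blocks(x1, x2, block_size=128, **kw):
    """Compute K(x1, x2) one row block at a time to bound peak memory.

    Blocks are independent, so they could be sent to different devices;
    here they simply run one after another.
    """
    blocks = [kernel_fn(x1[i:i + block_size], x2, **kw)
              for i in range(0, x1.shape[0], block_size)]
    return jnp.concatenate(blocks, axis=0)


def gp_posterior_mean(x_train, y_train, x_test, noise=1e-2, **kw):
    """Exact GP regression posterior mean with the NNGP kernel."""
    k_train = kernel_in_blocks(x_train, x_train, **kw)
    k_cross = kernel_in_blocks(x_test, x_train, **kw)
    chol = jnp.linalg.cholesky(k_train + noise * jnp.eye(x_train.shape[0]))
    alpha = cho_solve((chol, True), y_train)            # (K + noise I)^{-1} y
    return k_cross @ alpha


key_x, key_t = jax.random.split(jax.random.PRNGKey(0))
x_train = jax.random.normal(key_x, (300, 5))
y_train = jnp.sin(x_train.sum(axis=-1))                 # toy regression target
x_test = jax.random.normal(key_t, (50, 5))
y_test = jnp.sin(x_test.sum(axis=-1))
mean = gp_posterior_mean(x_train, y_train, x_test, block_size=128)
print("test RMSE:", float(jnp.sqrt(jnp.mean((mean - y_test) ** 2))))
```

Splitting by rows keeps each block's intermediate arrays proportional to the block size rather than to the full n-by-n matrix; the Cholesky solve, however, still needs the whole training kernel at once.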