Featuring Ananda Theertha Suresh, Google Research
Abstract: Autoregressive sampling from large language models has led to state-of-the-art results in several natural language tasks. However, autoregressive sampling generates tokens one at a time, which makes it slow and even prohibitive for certain tasks. One way to speed up sampling is speculative decoding (Leviathan et al., 2022): use a small model to sample a draft (a block or sequence of tokens), and then score all tokens in the draft with the large language model in parallel. Some of the draft tokens are accepted (and the rest rejected) using a statistical procedure that guarantees the final output follows the distribution of the large model.
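To make the accept/reject step concrete, here is a minimal sketch of the standard verification rule from Leviathan et al. (2022) for a single draft block. It assumes the draft-model and large-model next-token distributions are already available as NumPy arrays (the large-model rows coming from one parallel scoring pass). The function name `verify_block` and the array layout are illustrative assumptions, not the speaker's implementation. Each draft token x is accepted with probability min(1, p(x)/q(x)); at the first rejection, a replacement token is drawn from the normalized residual max(0, p - q) and the block ends.

```python
import numpy as np

# Sketch of the Leviathan et al. (2022) verification rule.
# verify_block and its argument layout are hypothetical, chosen for illustration.
def verify_block(draft_tokens, q_probs, p_probs, rng):
    """Verify k draft tokens against the large model.

    draft_tokens: k token ids sampled from the small (draft) model.
    q_probs: shape (k, V), draft-model distribution at each draft position.
    p_probs: shape (k + 1, V), large-model distribution at each position,
             scored in parallel; the last row follows all k draft tokens.
    Returns the accepted prefix plus one corrected or bonus token.
    """
    out = []
    for i, x in enumerate(draft_tokens):
        p, q = p_probs[i], q_probs[i]
        # Accept the draft token with probability min(1, p(x) / q(x)).
        if rng.random() < min(1.0, p[x] / q[x]):
            out.append(int(x))
            continue
        # Rejected: resample from the residual distribution max(0, p - q),
        # normalized. It is nonzero whenever a rejection is possible.
        residual = np.maximum(p - q, 0.0)
        residual /= residual.sum()
        out.append(int(rng.choice(len(p), p=residual)))
        return out
    # Every draft token was accepted: take one extra token from the large model.
    out.append(int(rng.choice(p_probs.shape[1], p=p_probs[-1])))
    return out
```

If the two models agree exactly (p = q), every draft token is accepted and each call emits k + 1 tokens. Otherwise, the first emitted token is still distributed exactly according to p, which is the distributional guarantee the abstract refers to.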
In this talk, we provide a principled understanding of speculative decoding through the lens of distribution coupling and optimal transport theory. This new formulation enables us to improve upon speculative decoding in three ways. First, we propose an optimal draft acceptance algorithm that provides additional wall-clock speedup without incurring additional computation cost. Next, we ask whether latency can be improved further with extra parallel computation. We answer this question affirmatively by showing that if we have multiple drafts from the small model, we can use them to increase the speedup further, at the cost of extra parallel computation. Finally, we consider algorithms for scenarios that require consistent output for a fixed random seed, regardless of which small model is used for speculation. We provide theoretical guarantees for the proposed algorithms and demonstrate their practicality on standard datasets.
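The coupling view can be illustrated in the simplest setting: one draft, one draft token, one output token. Under the standard rule, the pair (draft token X ~ q, output token Y ~ p) has a joint distribution that places mass min(p(x), q(x)) on X = Y. The acceptance probability is therefore sum_x min(p(x), q(x)) = 1 - TV(p, q), and by the coupling inequality no coupling of q and p can make X = Y more likely. The sketch below builds that joint distribution and checks the identity numerically. It assumes small NumPy probability vectors, and the helper names `maximal_coupling`, `acceptance_probability`, and `tv_distance` are hypothetical. It does not implement the block-level optimal acceptance, multi-draft, or fixed-seed algorithms from the talk.

```python
import numpy as np

def tv_distance(p, q):
    """Total variation distance between two discrete distributions."""
    return 0.5 * float(np.abs(p - q).sum())

def acceptance_probability(p, q):
    """P(accept) for one draft token X ~ q under the rule min(1, p(x)/q(x))."""
    mask = q > 0  # tokens the draft model never proposes contribute nothing
    return float(np.sum(q[mask] * np.minimum(1.0, p[mask] / q[mask])))

def maximal_coupling(p, q):
    """Joint J[x, y] = P(X = x, Y = y) with X ~ q (draft) and Y ~ p (target).

    This is the joint distribution induced by the standard accept/resample rule.
    """
    overlap = np.minimum(p, q)
    joint = np.diag(overlap)            # accepted: Y equals the draft token
    alpha = 1.0 - overlap.sum()         # rejection probability = TV(p, q)
    if alpha > 1e-12:
        # Rejected: draft excess (q - overlap) paired with the resampled
        # target excess (p - overlap) / alpha.
        joint += np.outer(q - overlap, p - overlap) / alpha
    return joint

p = np.array([0.5, 0.3, 0.2])  # hypothetical target (large model) distribution
q = np.array([0.2, 0.3, 0.5])  # hypothetical draft (small model) distribution
J = maximal_coupling(p, q)
print(J.sum(axis=1), J.sum(axis=0))  # marginals: q (draft) and p (target)
print(np.trace(J), acceptance_probability(p, q), 1 - tv_distance(p, q))
```

For these example vectors, the diagonal mass, the acceptance probability, and 1 - TV(p, q) all equal 0.7 up to floating-point rounding. The remaining 0.3 is the probability of a rejection.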
Bio: Ananda Theertha Suresh is a research scientist at Google Research, New York. He received his PhD from the University of California, San Diego, where he was advised by Prof. Alon Orlitsky. His research focuses on theoretical and algorithmic aspects of machine learning. He is a recipient of the 2017 Paul Baran Marconi Young Scholar Award and a co-recipient of best paper awards at NeurIPS 2015, ALT 2020, and CCS 2021.