In this post, we consider how to make language models better, not just faster. We draw on several papers mentioned below, but the trigger was this April 2024 paper from Meta, which:
· Steps forward from speculative decoding into pre-training
· Provides new, orthogonal insights into model capability vs. size
We are huge fans of leveraging more information in language model pre-training and of augmenting the transformer architecture, as demonstrated in several papers mentioned herein:
· Adding heads to the transformer architecture is useful in several ways.
· Multiple heads increase text generation performance significantly.
Training language models only to predict the word that comes immediately after the preceding text is seemingly naïve and yet surprisingly effective. It seems naïve because it doesn't seem to teach the model as much about the meaning of words or sentences as training it to predict a word or phrase from the text on both sides of it. And yet it is surprisingly effective at teaching a model to generate impressive text, as OpenAI first demonstrated in 2018.
By mid-2020, the 175-billion-parameter GPT-3 model demonstrated not only outstanding text generation ability but also an apparent ability to learn from examples given in the prefix context. Just as Google's invention of the transformer architecture pulled the field away from recurrent neural networks (RNNs), OpenAI pulled the field away from alternatives to its simple architecture and training regimen.
Next-word prediction views the text coming before a word as, in some sense, "causing" the word that comes after it. The language model's task is to "decode" the next word given the preceding text. Consequently, GPT-like language models are known as "causal decoders".
When generating text, such a model predicts the next word by, in effect, estimating the probability that each entry in its vocabulary comes next. The vocabulary typically contains many tens of thousands of entries (pieces of text, such as words or parts of words, called tokens). Text is generated one token at a time. The time taken to generate a sentence or paragraph is therefore roughly proportional to the number of times the hardware is asked to compute the probability of every vocabulary token coming next. [1]
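The one-token-per-step loop can be sketched in a few lines. This is a minimal illustration, not any real model's API: `toy_model` is a hypothetical bigram table standing in for a language model, and `greedy_generate` (also hypothetical) always takes the most probable token and counts how many times the model is called.

```python
from typing import Callable, Dict, List, Tuple

# Hypothetical model interface: maps a token prefix to a probability
# for each candidate next token.
Model = Callable[[List[str]], Dict[str, float]]

BIGRAMS: Dict[str, Dict[str, float]] = {
    "the": {"cat": 0.6, "dog": 0.3, "<eos>": 0.1},
    "cat": {"sat": 0.7, "ran": 0.2, "<eos>": 0.1},
    "sat": {"down": 0.8, "<eos>": 0.2},
    "down": {"<eos>": 1.0},
}

def toy_model(tokens: List[str]) -> Dict[str, float]:
    """Toy bigram table: the next-token distribution depends only on the last token."""
    return BIGRAMS.get(tokens[-1], {"<eos>": 1.0})

def greedy_generate(model: Model, prefix: List[str], max_new_tokens: int,
                    eos: str = "<eos>") -> Tuple[List[str], int]:
    """Append one token per model call, always taking the most probable one."""
    tokens = list(prefix)
    forward_passes = 0
    for _ in range(max_new_tokens):
        probs = model(tokens)  # one full evaluation of the model
        forward_passes += 1
        next_token = max(probs, key=probs.get)
        tokens.append(next_token)
        if next_token == eos:
            break
    return tokens, forward_passes

tokens, passes = greedy_generate(toy_model, ["the"], max_new_tokens=10)
print(tokens, passes)  # ['the', 'cat', 'sat', 'down', '<eos>'] 4
```

Four new tokens cost four model calls; that one-call-per-token cost is exactly what the techniques below try to reduce.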
Speculative decoding aims to accelerate text generation by generating more than one token per step. For example, if 2 tokens are appended to the prefix at each step, only half as many steps are needed, roughly doubling the speed at which text is generated.
Speculative decoding has typically employed a smaller, faster model to predict a few tokens ahead and then used the larger, slower model to evaluate those predictions in parallel. For example, a model with a few hundred million parameters might speculate much faster than one with a few billion. A single batched prediction by the larger model then accepts 0 or more of the speculated tokens. This accelerates decoding unless none of the speculated tokens are accepted. [2]
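A minimal sketch of the draft-then-verify loop, assuming greedy decoding and toy bigram tables in place of real models (`TARGET`, `DRAFT`, and `speculative_generate` are all hypothetical). The draft guesses up to `k` tokens; the target scores every prefix of those guesses, which real systems do in one parallel pass and which is counted here as one target step. Guesses are kept while they match the target's own choice, so the output is the same as greedy decoding with the target alone.

```python
from typing import Callable, Dict, List, Tuple

Model = Callable[[List[str]], Dict[str, float]]  # hypothetical model interface

def argmax(probs: Dict[str, float]) -> str:
    return max(probs, key=probs.get)

def speculative_generate(target: Model, draft: Model, prefix: List[str],
                         max_new_tokens: int, k: int = 3,
                         eos: str = "<eos>") -> Tuple[List[str], int]:
    tokens = list(prefix)
    target_steps = 0
    while len(tokens) - len(prefix) < max_new_tokens and tokens[-1] != eos:
        # 1. The cheap draft model speculates up to k tokens, one at a time.
        guesses: List[str] = []
        for _ in range(k):
            guesses.append(argmax(draft(tokens + guesses)))
            if guesses[-1] == eos:
                break
        # 2. The target scores prefix + guesses[:i] for every i (one batch).
        checks = [argmax(target(tokens + guesses[:i])) for i in range(len(guesses) + 1)]
        target_steps += 1
        # 3. Keep guesses while they match the target; the target then adds
        #    one token of its own (unless an accepted guess was <eos>).
        n = 0
        while n < len(guesses) and guesses[n] == checks[n]:
            n += 1
        new_tokens = guesses[:n]
        if not new_tokens or new_tokens[-1] != eos:
            new_tokens.append(checks[n])
        budget = max_new_tokens - (len(tokens) - len(prefix))
        tokens.extend(new_tokens[:budget])
    return tokens, target_steps

TARGET = {"the": {"cat": 0.6, "dog": 0.4}, "cat": {"sat": 0.7, "ran": 0.3},
          "sat": {"on": 0.9, "<eos>": 0.1}, "on": {"a": 0.8, "the": 0.2},
          "a": {"mat": 0.9, "<eos>": 0.1}, "mat": {"<eos>": 1.0}}
# The draft agrees with the target except after "sat" (made-up numbers).
DRAFT = {**TARGET, "sat": {"<eos>": 0.6, "on": 0.4}}

def target(tokens: List[str]) -> Dict[str, float]:
    return TARGET.get(tokens[-1], {"<eos>": 1.0})

def draft(tokens: List[str]) -> Dict[str, float]:
    return DRAFT.get(tokens[-1], {"<eos>": 1.0})

print(speculative_generate(target, draft, ["the"], max_new_tokens=20))
# (['the', 'cat', 'sat', 'on', 'a', 'mat', '<eos>'], 2)
```

Six new tokens take two target steps instead of six; even the mismatch after "sat" still yields three tokens in the first step, because the target's own prediction replaces the rejected guess.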
The more closely the speculated tokens align with the larger model's predictions, the greater the acceleration. In practice, such speculation accelerates decoding several-fold, provided the speculating model is well aligned with the predicting (generating) model.
Increasing decoding performance by a factor of two or more has enormous practical, economic, and environmental implications. [3] Nonetheless, speculative decoding has only recently begun "catching on"; the difficulty of aligning speculation with prediction has held it back.
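How alignment drives the gain can be sketched under a simplifying assumption that is not from this post: each of `gamma` speculated tokens is accepted independently with probability `alpha`, and accepted tokens must form an unbroken run. Each target step then yields the accepted run plus one token from the target, so the expected count is the geometric sum 1 + alpha + ... + alpha^gamma = (1 - alpha^(gamma+1)) / (1 - alpha). The function name is hypothetical, and the draft model's own cost is ignored.

```python
def expected_tokens_per_step(alpha: float, gamma: int) -> float:
    """Expected tokens per target step, assuming each of gamma guesses is
    accepted independently with probability alpha (a simplification)."""
    if alpha == 1.0:
        return gamma + 1.0  # every guess accepted, plus the target's own token
    # Geometric sum 1 + alpha + ... + alpha**gamma
    return (1 - alpha ** (gamma + 1)) / (1 - alpha)

for alpha in (0.0, 0.5, 0.8, 0.9, 1.0):
    print(f"alpha={alpha}: {expected_tokens_per_step(alpha, gamma=4):.2f} tokens per step")
```

With four guesses per step this gives 1.00, 1.94, 3.36, 4.10, and 5.00 tokens per step for the five values of `alpha`: a poorly aligned draft buys almost nothing, while a well-aligned one approaches the maximum.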
More recently, within the last quarter, "self"-speculative models have used various approaches to avoid having to align a fast speculating model with a more capable one. Self-speculation adds relatively little computation to the large model itself instead of relying on a second model.
Folks have looked at self-speculation using the earlier layers of the full model, but the number of layers that must be computed for the speculation to be sufficiently aligned limits the potential acceleration. For example, if most of the layers must be computed to speculate productively, generation performance cannot be doubled. [4]
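The arithmetic behind that limit can be sketched as a best-case bound. Assumptions (ours, not the post's): drafting one token with the first layers costs a fraction `draft_fraction` of a full forward pass, verification costs one full pass, and every one of the `gamma` guesses is accepted. The function name is hypothetical.

```python
def best_case_speedup(draft_fraction: float, gamma: int) -> float:
    """Upper bound on self-speculative speedup when each drafted token costs
    draft_fraction of a full pass, verification costs one full pass, and
    every guess is accepted."""
    tokens_per_step = gamma + 1  # gamma accepted guesses + 1 token from verification
    cost_per_step = gamma * draft_fraction + 1.0
    return tokens_per_step / cost_per_step

for f in (0.25, 0.5, 0.75):
    print(f"draft uses {f:.0%} of layers: "
          f"{best_case_speedup(f, gamma=4):.2f}x (gamma=4), "
          f"{best_case_speedup(f, gamma=100):.2f}x (gamma=100)")
```

As `gamma` grows, the bound approaches 1 / `draft_fraction`, so when drafting needs half or more of the layers the speedup stays below 2x even with perfect acceptance.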
Medusa and Meta's work take a different approach: they simply add heads to the transformer architecture. The transformer's existing head is trained to predict the next token; the additional heads are trained to speculate the several tokens that follow it.
Such a stunningly simple adaptation is remarkably effective: it straightforwardly accelerates decoding by a factor of two or more. The additional heads can be trained with or without fine-tuning the layers of the transformer.
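A minimal sketch of the multi-head idea: every head reads the same final hidden state from the shared transformer layers, and head i guesses the token i positions beyond the next one. Each head is modeled here as a plain linear map purely for brevity (the published heads are richer), and the hidden state, weights, and the function `multi_head_guesses` are all made up for illustration.

```python
from typing import List

Vector = List[float]
Matrix = List[List[float]]  # one row per vocabulary entry, one column per hidden unit

def matvec(m: Matrix, v: Vector) -> Vector:
    return [sum(w * x for w, x in zip(row, v)) for row in m]

def multi_head_guesses(hidden: Vector, heads: List[Matrix], vocab: List[str]) -> List[str]:
    """Head 0 predicts the next token; head i guesses the token i positions
    after it. All heads share the same hidden state."""
    guesses = []
    for head in heads:
        logits = matvec(head, hidden)
        best = max(range(len(vocab)), key=lambda j: logits[j])
        guesses.append(vocab[best])
    return guesses

vocab = ["cat", "sat", "down"]
hidden = [1.0, 0.5]  # stand-in for the shared layers' output at the last position
heads = [
    [[2.0, 0.0], [0.0, 1.0], [0.1, 0.1]],  # next token
    [[0.0, 0.0], [1.0, 1.0], [0.0, 0.5]],  # token after next
    [[0.0, 0.0], [0.0, 0.0], [1.0, 1.0]],  # third token
]
print(multi_head_guesses(hidden, heads, vocab))  # ['cat', 'sat', 'down']
```

One pass through the shared layers yields several guesses at once; when used for speculation, the base model would still verify them before they are accepted.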
Meta takes the concept further, however, and trains all the heads and layers, end-to-end, in the course of pre-training. In effect, this allows the language model to learn more from less. It can learn more about what’s coming, not just what’s next. And it can learn from less because it gets more signal per training step.
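The "more signal per training step" point can be sketched as how training targets change. With n heads, each position is scored on the next n tokens rather than only the next one. This is a simplified illustration: the function names are hypothetical, summing the per-head cross-entropies is an assumption of this sketch, and positions too close to the end are simply skipped.

```python
import math
from typing import Dict, List, Tuple

def multi_token_targets(tokens: List[str], n_heads: int) -> List[Tuple[int, List[str]]]:
    """At position t, head i (0-based) is trained on token t+1+i.
    Positions without a full set of n_heads future tokens are skipped."""
    return [(t, tokens[t + 1 : t + 1 + n_heads]) for t in range(len(tokens) - n_heads)]

def multi_token_loss(head_probs: List[Dict[str, float]], targets: List[str]) -> float:
    """Sum of per-head cross-entropies at one position."""
    return -sum(math.log(probs[target]) for probs, target in zip(head_probs, targets))

print(multi_token_targets(["the", "cat", "sat", "down"], n_heads=2))
# [(0, ['cat', 'sat']), (1, ['sat', 'down'])]

head_probs = [{"cat": 0.5, "dog": 0.5}, {"sat": 0.25, "ran": 0.75}]  # made-up outputs
print(round(multi_token_loss(head_probs, ["cat", "sat"]), 3))  # 2.079 (= ln 8)
```

With `n_heads=1` this reduces to ordinary next-token training; larger values give each position several targets about what is coming soon.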
There are many fascinating aspects of the Meta paper, including the investigation into the need for sufficient scale (i.e., billions of parameters) and speculating many bytes forward (potentially eliminating tokenization, vocabulary issues, etc.). Some of the more interesting observations include:
1. the improvement on coding tasks resulting from pre-training on what comes soon rather than what comes next
2. improved fine-tuning results after pre-training on what comes soon
3. improved summarization results versus pre-training on what comes next
These are material improvements!
The paper also discusses a couple of tasks that did not benefit from such pre-training of 7-billion-parameter models, notably multiple-choice tasks. We agree with the authors, however, that such tasks evaluate aspects other than generation, such as memory and reasoning.
Overall, according to the paper, models trained on what comes soon are better at generation but appear "only as good" at multiple choice as models of similar size trained on what comes next. Remember, however, that the abilities of language models seem to "emerge" as scale increases. It remains to be seen whether models trained on what comes soon are better at in-context learning or reasoning in general. Apparently, they are not worse, while being better in other regards.
[1] There are important details here, such as the difference between "greedy" decoding, which blindly accepts the most probable next token, and strategies such as sampling or beam search, which consider other plausible next tokens (beam search to some depth). These often produce higher-quality text, but in many cases greedy decoding suffices in practice.
[2] Consider a batch containing the prefix with 0 speculated tokens and the prefix with 1 speculated token. If the speculated token is not the one the language model ranks most probable, the time spent speculating has been lost. Otherwise, the language model generates 2 tokens (the speculated one and the one after it) in the time to speculate plus the time to generate 1 token.
[3] Perhaps OpenAI has been employing speculative decoding, given a noticeable acceleration in its decoding performance a few quarters ago?
[4] Kangaroo tries to address this by augmenting a shallow layer with a trained adapter.