Welcome to The Nonlinear Library, where we use Text-to-Speech software to convert the best writing from the Rationalist and EA communities into audio.
This is: interpreting GPT: the logit lens, published by nostalgebraist on the AI Alignment Forum.
This post relates an observation I've made in my work with GPT-2, which I have not seen made elsewhere.
IMO, this observation sheds a good deal of light on how the GPT-2/3/etc models (hereafter just "GPT") work internally.
There is an accompanying Colab notebook which will let you interactively explore the phenomenon I describe here.
[Edit: updated with another section on comparing to the inputs, rather than the outputs. This arguably resolves some of my confusion at the end. Thanks to algon33 and Gurkenglas for relevant suggestions here.]
[Edit 5/17/21: I've recently written a new Colab notebook which extends this post in various ways:
trying the "lens" on various models from 125M to 2.7B parameters, including GPT-Neo and CTRL
exploring the contributions of the attention and MLP sub-blocks within transformer blocks/layers
trying out a variant of the "decoder" used in this post, which dramatically helps with interpreting some models]
overview
GPT's probabilistic predictions are a linear function of the activations in its final layer. If one applies the same function to the activations of intermediate GPT layers, the resulting distributions make intuitive sense.
This "logit lens" provides a simple (if partial) interpretability lens for GPT's internals.
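To make the idea concrete, here is a minimal sketch of the lens in Python. It assumes the HuggingFace transformers library and the small GPT-2 checkpoint; the prompt and the helper name logit_lens are mine, not taken from the accompanying notebook. The key move is to reuse GPT-2's own final decoder (the final layer norm ln_f followed by the unembedding matrix lm_head) on every intermediate layer's activations:

```python
# A minimal sketch of the "logit lens", assuming the HuggingFace
# transformers library and the small GPT-2 checkpoint.
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")
model.eval()

text = "The capital of France is"
inputs = tokenizer(text, return_tensors="pt")

with torch.no_grad():
    # output_hidden_states=True returns the residual stream after the
    # embeddings and after each of the 12 transformer blocks.
    out = model(**inputs, output_hidden_states=True)

# The same decoding GPT applies to its *final* layer: the final
# layer norm followed by the unembedding (lm_head) matrix.
def logit_lens(hidden_state):
    return model.lm_head(model.transformer.ln_f(hidden_state))

# Apply that decoder to every layer and look at the top guess
# for the next token, at the last position in the prompt.
for i, h in enumerate(out.hidden_states):
    logits = logit_lens(h[0, -1])  # activations at the last position
    top = logits.argmax().item()
    print(f"layer {i:2d}: {tokenizer.decode([top])!r}")
```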
Other work on interpreting transformer internals has focused mostly on what the attention is looking at. The logit lens focuses on what GPT "believes" after each step of processing, rather than how it updates that belief inside the step.
These distributions gradually converge to the final distribution over the layers of the network, often getting close to that distribution long before the end.
At some point in the middle, GPT will have formed a "pretty good guess" as to the next token, and the later layers seem to be refining these guesses in light of one another.
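One simple way to quantify this convergence, continuing the sketch above (it reuses the logit_lens helper and the out variable), is to measure the KL divergence between each layer's decoded distribution and the final layer's distribution. The direction of the divergence is a choice; this sketch uses KL(final || layer):

```python
import torch.nn.functional as F

# KL divergence from the model's final distribution to each layer's
# logit-lens distribution, at the last token position.
final_logits = logit_lens(out.hidden_states[-1][0, -1])
final_logprobs = F.log_softmax(final_logits, dim=-1)

for i, h in enumerate(out.hidden_states):
    layer_logprobs = F.log_softmax(logit_lens(h[0, -1]), dim=-1)
    # KL(final || layer), in nats; shrinks as the layer's guess
    # approaches the model's final output distribution.
    kl = F.kl_div(layer_logprobs, final_logprobs,
                  log_target=True, reduction="sum")
    print(f"layer {i:2d}: KL to final = {kl.item():.3f}")
```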
The general trend, as one moves from earlier to later layers, is
"nonsense / not interpretable" (sometimes, in very early layers) -->
"shallow guesses (words that are the right part of speech / register / etc)" -->
"better guesses"
...though some of those phases are sometimes absent.
On the other hand, the only activations that look like the input tokens are the inputs themselves.
In the logit lens, the early layers sometimes look like nonsense, and sometimes look like very simple guesses about the output. They almost never look like the input.
Apparently, the model does not "keep the inputs around" for a while, gradually processing them into some intermediate representation and then into a prediction.
Instead, the inputs are immediately converted to a very different representation, which is smoothly refined into the final prediction.
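A rough way to check this, again continuing the same sketch: at each layer, compare where the input token ranks under the decoded distribution against where the model's final top prediction ranks. If the network "kept the inputs around", the input token would dominate in the early layers; the observation here is that the final prediction takes over almost immediately.

```python
# At each layer, rank of the *input* token at the last position vs.
# rank of the model's *final* top prediction for that position.
pos = inputs["input_ids"].shape[1] - 1        # last input position
input_token = inputs["input_ids"][0, pos].item()
final_top = logit_lens(out.hidden_states[-1][0, pos]).argmax().item()

for i, h in enumerate(out.hidden_states):
    logits = logit_lens(h[0, pos])
    ranks = logits.argsort(descending=True)   # token ids, best first
    input_rank = (ranks == input_token).nonzero().item()
    final_rank = (ranks == final_top).nonzero().item()
    print(f"layer {i:2d}: input token rank {input_rank:5d}, "
          f"final prediction rank {final_rank:5d}")
```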
This immediate conversion into a predictive representation is reminiscent of the perspective in Universal Transformers, which views a transformer's layers as iteratively refining a guess.
However, Universal Transformers have both an encoder and a decoder, while GPT is a decoder only. This means GPT faces a tradeoff between keeping the input tokens around and producing the next tokens.
Eventually it has to spit out the next token, so the longer it spends (in depth terms) processing something that looks like token i, the less time it has to convert it into token i+1. GPT has a deadline, and the clock is ticking.
More speculatively, this suggests that GPT mostly "thinks in predictive space," immediately converting inputs to predicted outputs, then refining guesses in light of other guesses that are themselves being refined.
I think this might suggest there is some fundamentally better way to do sampling from GPT models? I'm having trouble writing out the intuition clearly, so I'll leave it for later posts.
Caveat: I call this a "lens" because it is one way of extracting information from GPT's internal activations. I imagine there is other information...