Link to original article
Welcome to The Nonlinear Library, where we use Text-to-Speech software to convert the best writing from the Rationalist and EA communities into audio. This is: How To Do Patching Fast, published by Joseph Miller on May 14, 2024 on LessWrong.
This post outlines an efficient implementation of Edge Patching that massively outperforms common hook-based implementations. This implementation is available to use in my new library, AutoCircuit, and was first introduced by Li et al. (2023).
What is activation patching?
I introduce new terminology to clarify the distinction between different types of activation patching.
Node Patching
Node Patching (a.k.a. "normal" activation patching) is when some activation in a neural network is altered from the value computed by the network to some other value. For example, we could run two different prompts through a language model and replace the output of Attn 1 when the model is given some input 1 with the output of the head when the model is given some other input 2.
We will use the running example of a tiny, 1-layer transformer, but this approach generalizes to any transformer and any residual network.
All the nodes downstream of Attn 1 will be affected by the patch.
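As a rough illustration, a hook-based Node Patching sketch in PyTorch might look like the following. The names model, attn (standing in for Attn 1), clean_input and corrupt_input are assumptions for this example, not AutoCircuit's API.

```python
import torch

# Step 1: record Attn 1's output when the model is run on the corrupt input.
stored = {}

def save_hook(module, inputs, output):
    stored["attn_out"] = output.detach()

handle = attn.register_forward_hook(save_hook)
model(corrupt_input)
handle.remove()

# Step 2: overwrite Attn 1's output with the stored activation on the clean input.
def patch_hook(module, inputs, output):
    return stored["attn_out"]

handle = attn.register_forward_hook(patch_hook)
patched_logits = model(clean_input)
handle.remove()
```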
Edge Patching
If we want to make a more precise intervention, we can think about the transformer differently, to isolate the interactions between components.
Now we can patch the edge Attn 1 -> MLP and only nodes downstream of MLP will be affected (e.g. Attn 1 -> Output is unchanged). Edge Patching has not been explicitly named in any prior work.
Path Patching
Path Patching refers to the intervention where an input to a path is replaced in the 'treeified' view of the model. The treeified view is a third way of thinking about the model where we separate each path from input to output. We can implement an equivalent intervention to the previous diagram as follows:
In the IOI paper, 'Path Patching' the edge Component 1 -> Component 2 means Path Patching all paths from Component 1 to Component 2 in which every component between Component 1 and Component 2 is an MLP[1]. However, it can be easy to confuse Edge Patching and Path Patching, because if we instead patch all paths that use the direct edge from Component 1 to Component 2, this is equivalent to Edge Patching the edge Component 1 -> Component 2.
Edge Patching all of the edges which have some node as source is equivalent to Node Patching that node. AutoCircuit does not implement Path Patching, which is much more expensive in general. However, as explained in the appendix, Path Patching is sometimes equivalent to Edge Patching.
Fast Edge Patching
We perform two steps.
First we gather the activations that we want to patch into the model. There are many ways to do this, depending on what type of patching you want to do. If we just want to do zero ablation, then we don't even need to run the model. But let's assume we want to patch in activations from a different, corrupt input. We create a tensor, Patch Activations, to store the output of the source of each edge, and we write to the tensor during the forward pass. Each source component has a row in the tensor, so the shape is [n_sources, batch, seq, d_model].[2]
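A minimal sketch of this gathering step, assuming a hypothetical list source_modules containing each edge-source module (illustrative names, not AutoCircuit's actual API):

```python
import torch

def gather_source_activations(model, source_modules, corrupt_input):
    """Run the corrupt input once and stack every edge source's output into a
    single Patch Activations tensor of shape [n_sources, batch, seq, d_model]."""
    acts = [None] * len(source_modules)
    handles = []
    for i, module in enumerate(source_modules):
        def save_hook(module, inputs, output, idx=i):
            acts[idx] = output.detach()
        handles.append(module.register_forward_hook(save_hook))
    model(corrupt_input)
    for handle in handles:
        handle.remove()
    return torch.stack(acts)
```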
Now we run the forward pass in which we actually do the patching. We write the outputs of each edge source to a different tensor, Current Activations, of the same shape as Patch Activations. When we get to the input of the destination component of the edge we want to patch, we add the difference between the rows of Patch Activations and Current Activations corresponding to the edge's source component output.
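For a single edge, this intervention can be sketched as a forward pre-hook on the destination component. Here current_acts is assumed to be filled in-place by hooks on the source components earlier in the same forward pass; the names are illustrative, not AutoCircuit's API.

```python
def edge_patch_pre_hook(patch_acts, current_acts, src_idx):
    """Forward pre-hook on the edge's destination: shift its input by the
    difference between the patch and current outputs of the edge's source."""
    def hook(module, inputs):
        (x,) = inputs
        return (x + patch_acts[src_idx] - current_acts[src_idx],)
    return hook

# e.g. mlp.register_forward_pre_hook(edge_patch_pre_hook(patch_acts, current_acts, src_idx))
```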
This works because the difference in input to the edge destination is equal to the difference in output of the source component.[3] Now it's straightforward to extend this to patching multiple edges at once by subtracting the entire Current Activations tensor from the entire Patch Activations tensor and multiplying by a Mask tensor of shape [n_sources] that has a single value for each input edge.
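A sketch of this masked, multi-edge version for one destination component (again with illustrative names rather than AutoCircuit's API):

```python
import torch

def masked_edge_patch(dest_input, patch_acts, current_acts, mask):
    """Apply every patched edge into one destination component at once.

    dest_input:   [batch, seq, d_model]             original input to the destination
    patch_acts:   [n_sources, batch, seq, d_model]  Patch Activations
    current_acts: [n_sources, batch, seq, d_model]  Current Activations
    mask:         [n_sources]                       1.0 for patched edges, 0.0 otherwise
    """
    diff = patch_acts - current_acts
    # Weight each source's difference by its mask value and sum over sources.
    delta = torch.einsum("s,sbld->bld", mask, diff)
    return dest_input + delta
```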
By creating a Mask tensor for each destination node w...