June 27, 2024

Gautam Reddy

Introduction 

A striking feature of large language models is their capacity for in-context learning. In-context learning (ICL) is the ability to predict the response to a query based on illustrative examples presented in the context, without any additional weight updates. Consider this interaction with ChatGPT:

Chat GPT example

The model certainly has never seen this particular input during training. But, as this simple example illustrates, it presumably “knows” the concept of “marries”, understands that the relation is symmetric, and correctly parses the relationships between the entities to produce the appropriate response.

In this post, I will focus on one version of ICL where the network learns new relationships from exemplars in the input sequence, which differs from situations where the prompt merely helps the network recognize which task to perform. Efforts to quantify and understand this form of ICL have mainly focused on training small networks to solve well-defined synthetic tasks. The network is presented with a sequence of N item-label pairs x1, ℓ1, x2, ℓ2, …, xN, ℓN and is asked to predict the label of a target item xq. For example, in an in-context linear regression task, the items are D-dimensional random vectors and the labels in a particular sequence are produced as ℓi = ηᵀxi for randomly drawn regression coefficients η. When trained over many sequences, each with different regression vectors, attention-based networks learn an operation that emulates gradient descent over the examples presented in the sequence (see, for example, [1]). In other words, the networks learn a generic algorithm for in-context linear regression. Moreover, the ICL solution in such tasks is often learned abruptly: a phase of slowly increasing performance is followed by a sharp jump to near-perfect accuracy (Figure 1).
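To make the setup concrete, here is a minimal numpy sketch of how such training sequences could be generated (the dimension D, context length N, and noiseless labels are illustrative choices, not taken from any particular paper):

```python
import numpy as np

def regression_sequence(D=8, N=16, seed=0):
    """One training sequence for in-context linear regression (illustrative sizes)."""
    rng = np.random.default_rng(seed)
    eta = rng.standard_normal(D)            # sequence-specific regression coefficients
    X = rng.standard_normal((N, D))         # items x_1, ..., x_N
    labels = X @ eta                        # labels l_i = eta^T x_i
    x_q = rng.standard_normal(D)            # query item
    # The network sees x_1, l_1, ..., x_N, l_N, x_q and must predict eta^T x_q.
    tokens = [t for pair in zip(X, labels) for t in pair] + [x_q]
    return tokens, x_q @ eta
```

Because η is redrawn for every sequence, memorizing item-label associations is useless; the network is pushed toward a generic in-context regression procedure.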

 

IC accuracy during training diagram

Figure 1: ICL accuracy during training on an in-context classification task.

These observations raise several questions: 1) Are attention-based networks particularly well-suited for ICL? 2) What factors determine whether the network learns a general ICL solution instead of memorizing the training data? 3) Why is the acquisition of the ICL solution abrupt? My recent work [2] examines the latter two questions in a simplified in-context classification task. Here, I will focus on an argument (detailed in the paper) for why we should generically expect to see abrupt learning when we train attention-based networks on ICL tasks.

High-level summary 

The gist of the argument is that the ICL solution is implemented by a specific sequence of read-write operations, carried out by successive layers of the attention-based network (for simplicity, let’s consider attention-only networks without MLPs at each layer). Since each layer applies a soft-max operation, a sequence of operations across layers involves nested exponentials. These nested exponentials create sharp “cliffs” in the loss landscape. In certain cases (but not always), a weak intrinsic curriculum may gradually guide the network through the flat regions of the landscape. Abrupt learning occurs once the network reaches the edge of a cliff.

In the rest of this post, I will outline this argument using a paradigmatic in-context copying task, where a phenomenological model provides some quantitative insight into ICL learning dynamics.

In-context copying in a two-layer attention network

Consider the simplest version of an in-context copying task. For example, a particular sequence may contain two item-label pairs, (B, 2) and (A, 1), and the target item B. The sequence presented to the network is B, 2, A, 1, B and the expected response is 2. When trained over many such sequences, a two-layer, single-head, attention-only network learns what has been termed an induction head (at least two layers are needed for this to work). To understand what an induction head does, it is useful to imagine each token’s embedding as having three orthogonal components: one that encodes its content, one that encodes its position in the sequence, and a buffer that is initially empty.
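As a concrete (hypothetical) data-generation sketch, each training sequence can be built by pairing a few items with freshly reassigned labels, so that the correct answer can only be read off from the context:

```python
import numpy as np

def copy_sequence(items=("A", "B", "C", "D"), labels=("1", "2", "3", "4"),
                  n_pairs=2, seed=0):
    """One in-context copying sequence, e.g. B, 2, A, 1, B with answer 2."""
    rng = np.random.default_rng(seed)
    chosen_items = rng.choice(items, size=n_pairs, replace=False)
    chosen_labels = rng.choice(labels, size=n_pairs, replace=False)  # fresh pairing each time
    pairs = list(zip(chosen_items, chosen_labels))
    target_item, correct_label = pairs[rng.integers(n_pairs)]
    sequence = [tok for pair in pairs for tok in pair] + [target_item]
    return sequence, correct_label
```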

An induction head involves two read-write operations followed by classification (illustrated in Figure 2). In the first layer, each token uses its positional information to pay attention to the token immediately before it. For example, the 2 pays attention to B and A pays attention to 2. The content of the attended token is then written into the buffer of the attending token: 2’s buffer now contains the content of B.

In the second layer, each token uses its content to pay attention to the buffers of the tokens that come before it (due to a causal mask). Here, the target B (the second B in the sequence) pays attention to 2, since 2’s buffer contains the content of the first B in the sequence. The content of 2 is then written into the buffer of the target B. The target B, equipped with this information, is read by a standard classifier, which finds the 2 in its buffer and predicts the label 2, the correct response.

 

Figure illustrating the operations involved in two-layer induction head

Figure 2: An illustration of the operations involved in a two-layer induction head.
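To make the two read-write steps concrete, here is a minimal numpy sketch of a hand-wired induction head acting on the sequence B, 2, A, 1, B. The explicit content/buffer slots and the hard (one-hot) attention patterns are simplifying assumptions for illustration; a trained two-layer network implements soft versions of these operations.

```python
import numpy as np

vocab = ["A", "B", "1", "2"]            # possible token contents
seq   = ["B", "2", "A", "1", "B"]       # context pairs (B, 2), (A, 1); target B
V, T  = len(vocab), len(seq)

def one_hot(i, n):
    v = np.zeros(n)
    v[i] = 1.0
    return v

content = np.stack([one_hot(vocab.index(tok), V) for tok in seq])   # (T, V)
buffer_ = np.zeros((T, V))                                          # buffers start empty

# Layer 1: attention to the immediately preceding token
# (hard-coded here; in a trained network it arises from position-position interactions).
attn1 = np.zeros((T, T))
for i in range(1, T):
    attn1[i, i - 1] = 1.0
buffer_ = buffer_ + attn1 @ content     # write: "2"'s buffer now holds the content of "B"

# Layer 2: content-buffer attention under a causal mask -- the target attends
# to earlier tokens whose buffer matches its own content.
scores = content @ buffer_.T                        # (T, T): content_i . buffer_j
causal = np.tril(np.ones((T, T)), k=-1)             # only strictly earlier tokens are visible
attn2  = scores * causal
attn2  = attn2 / np.maximum(attn2.sum(-1, keepdims=True), 1e-9)
written = attn2 @ content                           # what gets written into each buffer

# Readout: a classifier reads the target's buffer and reports the label it finds there.
print(vocab[int(np.argmax(written[-1]))])           # -> "2", the correct response
```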

 

Three soft-max operations 

The two steps in an induction head involve two soft-max operations of the form:

Aij = e^(Qi · Kj) / ∑k e^(Qi · Kk),

where Qi and Kj are query and key vectors respectively. In our simplification of each token into three components, this dot product will involve a total of nine possible interactions between the two tokens. In the first layer, only one of these is relevant, that is, the position-position interaction. We can parameterize the “strength” of this interaction as β and ignore the other eight terms. The attention paid by a token at position N + 1 to the token at N has the form 

y = e^β / (e^β + N).

In the second step, the target item uses its content to pay attention to the buffers of the other tokens. Again, of the nine possible interactions, only the content-buffer interaction is relevant, and we ignore the rest. Suppose there is only one copy of the target in the sequence. If there are N item-label pairs, the attention paid by the target item to its copy in the sequence is:

z = e^(αy) / (e^(αy) + 2N),

where α parameterizes the strength of this interaction and y enters because it scales the magnitude of what is written into the buffer in the first layer. The attention paid to each of the other tokens is z′ = 1/(2N + e^(αy)). After the contents of these tokens are written into its buffer, the target item’s buffer contains a linear combination of the contents of the other 2N tokens, weighted by the attention paid to each one: z ℓt + z′ ∑i≠t ℓi + other irrelevant terms, where t is the index of the correct label.

The target item’s buffer is read by the classifier, which performs soft-max regression to predict the correct label amongst, say, L possible labels. The classifier learns a set of regression vectors that map each label vector to its corresponding label; this involves a third soft-max operation. The key parameter is the overlap of a regression vector γi with its corresponding label vector ℓi, relative to the other label vectors: ζij = γiᵀℓi − γiᵀℓj for j ≠ i. If the labels are statistically identical and balanced, we can assume ζij = ξ, independent of i and j.
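To connect these quantities to the loss below, here is a small worked sketch of the classifier logits produced by the target’s buffer (orthonormal label and regression vectors, and the particular values of N, L, α, y, and ξ, are illustrative assumptions):

```python
import numpy as np

N, L = 4, 16                  # context pairs and total number of labels (illustrative)
alpha, y, xi = 2.0, 0.8, 3.0  # interaction strengths and first-layer attention (assumed values)

# Second-layer attention to the relevant token vs. each of the other 2N tokens
z  = np.exp(alpha * y) / (np.exp(alpha * y) + 2 * N)
zp = 1.0 / (np.exp(alpha * y) + 2 * N)

labels = np.eye(L)            # orthonormal label vectors (simplifying assumption)
gammas = xi * np.eye(L)       # regression vectors with overlap xi on the correct label

t = 0                                             # index of the correct label
others = np.arange(1, N)                          # the other labels present in the context
buf = z * labels[t] + zp * labels[others].sum(0)  # target's buffer (irrelevant terms dropped)

logits = gammas @ buf
print(logits[t] - logits[1])   # margin over an in-context label: xi * (z - z')
print(logits[t] - logits[N])   # margin over an absent label:     xi * z
```

These two margins are exactly the quantities u and u′ that appear in the loss expression below.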

 

Figure displaying loss curves and the loss landscape

Figure 3: Loss curves and the loss landscape of the phenomenological model.

The loss landscape

The network is trained using a cross-entropy loss. After a few simplifications, we obtain an expression for the loss in terms of the three parameters β, α and ξ:

ℒ = log(1 + (N − 1) e^(−u) + (L − N) e^(−u′)),

where u = ξ(e^(αy) − 1)/(e^(αy) + 2N), u′ = ξ e^(αy)/(e^(αy) + 2N), and L ≥ N. This is a phenomenological characterization of an induction head’s loss landscape, which captures the important features but certainly ignores some finer details. The key point is that the loss contains three nested exponentials, which create a “cliff” in the loss landscape, as shown in Figure 3. Moreover, examining the loss at the origin shows that there is a weak gradient in ξ which guides the network towards the cliff. This is because the network can do slightly better than chance (accuracy of 1/L) by randomly picking one of the labels present in the input sequence (accuracy of 1/N). This sub-optimal strategy acts as an intrinsic curriculum, as the network slowly aligns the label vectors with the regression vectors (thereby increasing ξ). Removing this gradient by simply setting L = N abolishes learning even in the full model!
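As a numerical check of this expression (my own sketch; N, L, and the fixed first-layer attention y are illustrative choices), the weak gradient in ξ at the origin and the steep drop at large α and ξ are easy to see directly:

```python
import numpy as np

def icl_loss(alpha, xi, y=1.0, N=8, L=32):
    """Phenomenological loss log(1 + (N-1)e^(-u) + (L-N)e^(-u')); sizes are illustrative."""
    e = np.exp(alpha * y)                 # y: first-layer attention, held fixed here
    u  = xi * (e - 1.0) / (e + 2 * N)
    up = xi * e / (e + 2 * N)
    return np.log(1.0 + (N - 1) * np.exp(-u) + (L - N) * np.exp(-up))

# Weak intrinsic curriculum: at the origin the loss still tilts gently downhill in xi,
# because guessing a label that appears in the context beats chance whenever L > N.
eps = 1e-4
print((icl_loss(0.0, eps) - icl_loss(0.0, 0.0)) / eps)     # small negative slope

# The drop only becomes steep once alpha and xi are both sizeable (the "cliff").
for a, x in [(0.0, 0.0), (2.0, 2.0), (4.0, 4.0), (8.0, 8.0)]:
    print(a, x, round(float(icl_loss(a, x)), 3))
```

Setting L = N in this sketch makes the slope in ξ at the origin vanish, consistent with the observation that learning is abolished in that case.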


References

[1] Johannes von Oswald, Eyvind Niklasson, Ettore Randazzo, João Sacramento, Alexander Mordvintsev, Andrey Zhmoginov, Max Vladymyrov, “Transformers learn in-context by gradient descent,” 2023.

[2] Gautam Reddy, “The mechanistic basis of data dependence and abrupt learning in an in-context classification task,” 2024.