Decoder-Only Transformers

A look into the architecture that powers most LLMs.


Note: This post is incomplete. I’m actively working on the remaining sections. For now, here are two short sections on self-attention and multi-head attention.

Self-Attention

At the core of the transformer is the attention mechanism. For each input token \(\mathbf{x}_n\), the attention mechanism aggregates information from surrounding tokens and produces a new representation \(\mathbf{y}_n\). This is known as self-attention, because the input sequence attends to itself to create a new representation of the same sequence. The simplest way to compute this is to take a linear combination of all the input tokens,

\[ \mathbf{y}_n=\sum_{m=1}^N a_{nm} \mathbf{x}_m \] A token \(\mathbf{x}_m\) that is closely related to \(\mathbf{x}_n\) should contribute a lot to the computation of \(\mathbf{y}_n\). Conversely, tokens that have very little to do with \(\mathbf{x}_n\) are unlikely to provide any interesting context. Intuitively then, \(a_{nm}\) should capture how similar the tokens \(\mathbf{x}_n\) and \(\mathbf{x}_m\) are: closely related tokens should have a large \(a_{nm}\), while unrelated tokens should contribute close to nothing, i.e., \(a_{nm} \approx 0\). This is easy to do with vectors: we can measure similarity with the dot product, make sure no weight is negative, and normalize the weights so they sum to one. The Softmax function satisfies these criteria,

\[ a_{nm} = \frac{\exp\left(\mathbf{x}_n^T \mathbf{x}_m\right)}{\sum_{m'=1}^N{\exp\left(\mathbf{x}_n^T \mathbf{x}_{m'}\right)}}. \] Then, the final representation for \(\mathbf{x}_n\) is computed as follows:

\[ \mathbf{y}_n=\sum_{m=1}^N a_{nm} \mathbf{x}_m, \quad \text{where} \ \ a_{nm}\geq0 \ \ \text{and} \ \ \sum_{m=1}^N a_{nm} = 1. \] So far, this attention mechanism has no learning capacity. We have gotten to this point simply by reasoning about how each input token can be better represented according to its similarity to other input tokens. However, the original representations \(X\) may not be optimal for determining the nuanced relationships between different tokens. We can afford this mechanism some flexibility by introducing a linear transformation \(W^{(q)} \in \mathbb{R}^{D\times D_q}\), mapping each token \(\mathbf{x}_n\) to \(W^{(q)T}\mathbf{x}_n\), in the hope that the transformed tokens’ dot products reveal more interesting relationships between them, \[ \begin{aligned} a_{nm} &= \frac{\exp\left((W^{(q)T}\mathbf{x}_n)^T W^{(q)T}\mathbf{x}_m \right)}{\sum_{m'=1}^N{\exp\left((W^{(q)T}\mathbf{x}_n)^T W^{(q)T}\mathbf{x}_{m'}\right)}} \\ &= \frac{\exp\left( \mathbf{x}_n^T W^{(q)} W^{(q)T} \mathbf{x}_m \right)}{\sum_{m'=1}^N{\exp\left(\mathbf{x}_n^T W^{(q)} W^{(q)T} \mathbf{x}_{m'}\right)}}. \end{aligned} \] The entries of such a matrix are of course not known in advance; they are parameters learned from the data. But there’s a subtle problem with using the same matrix \(W^{(q)}\) to transform both \(\mathbf{x}_n\) and \(\mathbf{x}_m\). The matrix product \(W^{(q)} W^{(q)T}\) is symmetric, so

\[ \mathbf{x}_n^T W^{(q)} W^{(q)T} \mathbf{x}_m = \mathbf{x}_m^T W^{(q)} W^{(q)T} \mathbf{x}_n, \]

and therefore \(a_{nm}=a_{mn}\). This is not ideal, because the relationship between tokens is not necessarily symmetric. We can introduce another matrix \(W^{(k)} \in \mathbb{R}^{D\times D_k}\) to project all other tokens \(\mathbf{x}_m\),

\[ a_{nm} = \frac{\exp\left((W^{(q)T}\mathbf{x}_n)^T W^{(k)T}\mathbf{x}_m \right)}{\sum_{m'=1}^N{\exp\left((W^{(q)T}\mathbf{x}_n)^T W^{(k)T}\mathbf{x}_{m'}\right)}}. \] The matrix \(W^{(q)}\) is known as the query matrix, and \(W^{(k)}\) is known as the key matrix. One way to think about them is like this: the query matrix distills from the token \(\mathbf{x}_n\) a question, or query, \(\mathbf{q}_n=W^{(q)T}\mathbf{x}_n\), that this token can ask about the other tokens, while the key matrix distills an answer, or key, \(\mathbf{k}_m=W^{(k)T}\mathbf{x}_m\), from each token \(\mathbf{x}_m\). The inner product between the query and the key reveals which tokens \(\mathbf{x}_m\) contain information relevant to building a richer representation for \(\mathbf{x}_n\).
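To make this concrete, here is a minimal sketch in PyTorch of the per-token computation. The sizes \(N\), \(D\), and \(D_q\), and all of the variable names, are toy choices for illustration, not anything prescribed above:

```python
import torch

torch.manual_seed(0)

N, D, D_q = 4, 8, 6              # toy sizes, chosen only for illustration
X = torch.randn(N, D)            # input tokens x_1, ..., x_N as rows
n = 0                            # build the weights a_n1, ..., a_nN for token x_n

# Parameter-free version: similarities are plain dot products x_n . x_m.
a_n = torch.softmax(X @ X[n], dim=0)
print(a_n, a_n.sum())            # non-negative, sums to 1

# Learned version: project with the query and key matrices first.
W_q = torch.randn(D, D_q)        # W^(q), so q_n = W^(q)T x_n
W_k = torch.randn(D, D_q)        # W^(k), with D_k = D_q so the dot product is defined
q_n = W_q.T @ X[n]               # query for token n
keys = X @ W_k                   # row m holds the key k_m = W^(k)T x_m
a_n = torch.softmax(keys @ q_n, dim=0)
y_n = a_n @ X                    # y_n as a weighted combination of the input tokens
print(a_n, y_n.shape)
```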

In practice, these projections are computed in parallel as the matrix products \(Q = XW^{(q)}\) and \(K = XW^{(k)}\). We can express the attention mechanism so that it computes all the attention weights at once using matrices, rather than obtaining the weights \(a_{nm}\) for each query-key pair individually. This is how it would be implemented, after all, as it takes advantage of the parallelism of a GPU.

Multiplying \(X\in\mathbb{R}^{N\times D}\) and \(W^{(q)}\in\mathbb{R}^{D\times D_q}\) produces the matrix \(Q\in\mathbb{R}^{N\times D_q}\). Every token \(\mathbf{x}_n\) (i.e., every row in \(X\)) is transformed into a new vector \(\mathbf{q}_n\) of length \(D_q\). We can view this transformation as taking a linear combination of the rows of \(W^{(q)}\), with the entries \(\{x_{n1}, x_{n2}, \dots, x_{nD}\}\) of the token as the coefficients.
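As a quick sanity check of this view (again with toy sizes in PyTorch, so purely illustrative), each row of \(Q\) can be rebuilt by hand as a weighted sum of the rows of \(W^{(q)}\):

```python
import torch

torch.manual_seed(0)

N, D, D_q = 4, 8, 6
X = torch.randn(N, D)
W_q = torch.randn(D, D_q)

Q = X @ W_q                      # (N, D) @ (D, D_q) -> (N, D_q)

# Row n of Q is a linear combination of the rows of W^(q),
# with the entries of token x_n as the coefficients.
n = 0
q_n_manual = sum(X[n, d] * W_q[d] for d in range(D))
print(torch.allclose(Q[n], q_n_manual))   # True (up to floating-point error)
```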

Now we have a matrix \(Q\), of dimensions \(N \times D_q\), containing the “queries” that each of the \(N\) tokens would like to ask every other token, and the matrix \(K\), of dimensions \(N \times D_k\), containing the “key” or “answer” to each of these queries. Notice that there is no communication between tokens yet: each token was independently transformed into a query and a key. We want to multiply the query and key matrices together to end up with an \(N \times N\) matrix (the first \(N\) representing the sequence of queries along the row dimension, and the second \(N\) representing the sequence of keys along the column dimension) whose \(nm\)th entry tells us how closely related the \(n\)th query is to the \(m\)th key. To do this, we choose \(D_k = D_q\), and then compute \(QK^T\).
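Sketched in PyTorch with the same toy sizes as before (all names and dimensions here are illustrative assumptions):

```python
import torch

torch.manual_seed(0)

N, D, D_q = 4, 8, 6              # D_k = D_q so that Q K^T is defined
X = torch.randn(N, D)
W_q = torch.randn(D, D_q)
W_k = torch.randn(D, D_q)

Q = X @ W_q                      # one query per row: (N, D_q)
K = X @ W_k                      # one key per row:   (N, D_q)

scores = Q @ K.T                 # (N, N); entry (n, m) is the dot product q_n . k_m
print(scores.shape)              # torch.Size([4, 4])
```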

The values in row \(n\) of \(QK^T\) correspond to how closely associated a query \(\mathbf{q}_n\) is to each of the keys \(\mathbf{k}_m\). As before, these values are normalized using a row-wise Softmax so that they are non-negative and sum to one:

\[ A = \text{Softmax}\left(\frac{QK^T}{\sqrt{D_q}}\right). \] The scaling factor \(\sqrt{D_q}\) is the standard deviation of the dot product between a query and a key, assuming the elements of these vectors are independent with zero mean and unit variance (the dot product is then a sum of \(D_q\) such terms, so its variance is \(D_q\)). Dividing by this scaling factor prevents the Softmax distribution from becoming too sharp.
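To see why this matters, here is a small illustrative experiment; the head size of 512 and the unit-variance random queries and keys are assumptions made just to exaggerate the effect:

```python
import torch

torch.manual_seed(0)

N, D_q = 4, 512                  # a larger D_q makes the effect easy to see
Q = torch.randn(N, D_q)          # elements with zero mean and unit variance,
K = torch.randn(N, D_q)          # so each dot product has standard deviation ~sqrt(D_q)

scores = Q @ K.T
A_unscaled = torch.softmax(scores, dim=-1)
A_scaled = torch.softmax(scores / D_q ** 0.5, dim=-1)

print(A_unscaled.max(dim=-1).values)   # rows are typically close to one-hot
print(A_scaled.max(dim=-1).values)     # much softer distributions
print(A_scaled.sum(dim=-1))            # each row still sums to 1
```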

The resulting matrix \(A\) is the attention matrix containing the attention weights \(a_{nm}\). To obtain the new representation \(Y\), we could use these weights to take linear combinations of the input tokens, \(Y=AX\). However, as before, rather than using the input tokens directly, we can allow the neural network some flexibility by introducing another weight matrix, \(W^{(v)}\), and linearly transforming the input tokens into a more suitable representation, \(V=XW^{(v)}\). The matrix \(V\) surfaces the relevant information, or values, from each input token, which are then combined via the attention weights. For this reason, \(W^{(v)}\) is known as the value matrix, and the final representation is computed as \(Y=AV\).

At this point, the tokens are finally communicating with one another. The resulting row vectors in \(Y\) are new representations of the input tokens \(X\), with each row vector \(\mathbf{y}_n\) being an aggregate of all tokens, weighted by the attention weights. If \(a_{nm}\) is small, the \(m^{\text{th}}\) token contributes little to \(\mathbf{y}_n\); if large, it plays a significant role.
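Putting the pieces together, one head of self-attention fits in a few lines. The following is only a sketch of the computation described above, with toy sizes, random weight matrices, and no batching:

```python
import torch

def self_attention(X, W_q, W_k, W_v):
    """A sketch of single-head self-attention as described above."""
    Q = X @ W_q                              # queries (N, D_q)
    K = X @ W_k                              # keys    (N, D_k), with D_k = D_q
    V = X @ W_v                              # values  (N, D_v)
    scores = Q @ K.T / Q.shape[-1] ** 0.5    # scaled dot products (N, N)
    A = torch.softmax(scores, dim=-1)        # attention weights, each row sums to 1
    return A @ V                             # Y: rows are attention-weighted mixes of values

torch.manual_seed(0)
N, D, D_q, D_v = 4, 8, 6, 5                  # toy sizes, chosen only for illustration
X = torch.randn(N, D)
Y = self_attention(X, torch.randn(D, D_q), torch.randn(D, D_q), torch.randn(D, D_v))
print(Y.shape)                               # torch.Size([4, 5])
```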

Multi-Head Attention

So far, we’ve used a single “head” of attention; that is, we only capture correlations between tokens once during the self-attention computation. To detect multiple types of relationships between inputs, we can introduce multi-head attention, which independently performs self-attention \(h\) times, where \(h\) is the number of heads.

To do this, the input tokens \(\mathbf{x}_n\) are projected from dimension \(D\) to \(D_{\text{head}}\times h\) using the query, key and value matrices, as opposed to the arbitrary dimensions \(D_q\), \(D_k\), or \(D_v\). The projected vectors now contain \(h\) heads, each of dimension \(D_{\text{head}}\).

When computing the query vectors, each row \(\mathbf{x}_n\) becomes \(\mathbf{q}_{n}\) as before, but this time the query contains \(h\) segments \(\mathbf{q}^{(i)}_n\), where \(i \in \{1, \dots, h\}\). The same process applies to the keys and values.
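Concretely (with toy sizes and PyTorch again, purely for illustration), each row of the projected query matrix can be read as \(h\) consecutive segments of length \(D_{\text{head}}\):

```python
import torch

torch.manual_seed(0)

N, D, h, D_head = 4, 8, 2, 3          # toy sizes, chosen only for illustration
X = torch.randn(N, D)
W_q = torch.randn(D, h * D_head)      # project from D to h * D_head

Q = X @ W_q                           # (N, h * D_head)

# Row n of Q holds the h segments q_n^(1), ..., q_n^(h), each of length D_head.
Q_heads = Q.reshape(N, h, D_head)
print(torch.allclose(Q_heads[0, 1], Q[0, D_head:2 * D_head]))   # True: second segment of the first token
```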

To process the \(h\) heads independently and in parallel, they are moved along their own dimension. So, the query, key and value matrices go from shape \(N \times (D_{\text{head}}h)\) to \(h\times N \times D_\text{head}\). From this point forward, the self-attention mechanism is identical to the single-head case, with the only difference being that the heads sit along a batch-like dimension (samples in a batch are independent from one another, and are typically processed in parallel on a GPU). We can then compute the attention matrix \(A^{(i)}\) for the \(i\)th head by multiplying \(Q^{(i)}\) and \(K^{(i)T}\). In practice, all \(h\) heads are processed in parallel with a batched matrix multiplication \(QK^T\), which results in an \(h \times N \times N\) tensor \(A\). Ideally, each of these heads will capture different relationships within the input. It’s possible that the entry \(a_{nm}^{(i)}\) (which describes how closely related the query \(\mathbf{q}_{n}^{(i)}\) is to the key \(\mathbf{k}_{m}^{(i)}\)) is large, but \(a_{nm}^{(j)}\) is small, because the \(j\)th head attends to different qualities in the input than the \(i\)th head.
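Here is how that batch-like treatment might look in code. As before, this is a sketch with toy sizes rather than a complete implementation, and it stops where this section does, before the heads are recombined into a single output:

```python
import torch

torch.manual_seed(0)

N, D, h, D_head = 4, 8, 2, 3                   # toy sizes, chosen only for illustration
X = torch.randn(N, D)
W_q = torch.randn(D, h * D_head)
W_k = torch.randn(D, h * D_head)
W_v = torch.randn(D, h * D_head)

# Project, then move the heads onto their own batch-like dimension:
# (N, h * D_head) -> (N, h, D_head) -> (h, N, D_head)
Q = (X @ W_q).reshape(N, h, D_head).transpose(0, 1)
K = (X @ W_k).reshape(N, h, D_head).transpose(0, 1)
V = (X @ W_v).reshape(N, h, D_head).transpose(0, 1)

# Batched matrix multiplication: one N x N attention matrix per head.
A = torch.softmax(Q @ K.transpose(-2, -1) / D_head ** 0.5, dim=-1)
print(A.shape)                                 # torch.Size([2, 4, 4])

# Each head then aggregates its own values, exactly as in the single-head case.
Y = A @ V
print(Y.shape)                                 # torch.Size([2, 4, 3])
```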