Top 3 Attention Mechanisms in Large Language Models (LLMs)

Transformers have changed the way Natural Language Processing (NLP) tasks are performed over the last few years. The key to this success is the Self-Attention mechanism, which models relationships within a sequence without any recurrence. Self-attention is the foundational block of the Transformer architecture and builds on the attention mechanism introduced in the paper by Bahdanau et al. It quantifies how strongly each word (or subword) in an input sequence relates to every other word in the same sequence. It is called “self” attention because the sequence attends to itself.

Self-Attention

Self-attention is calculated from three components, Query, Key and Value, which are defined as follows.

  • Query (Q) – The representation of the current word under consideration.
  • Key (K) – The representations of all the words in the input sequence (including the current one) that the query is compared against.
  • Value (V) – The meaningful content or information carried by each word in the input sequence.
  • d_k – The dimension of the key (and query) vectors.

The similarity between the query and each key is measured by the dot product of the current query with the keys of all the words in the sequence. These scores are divided by the scaling factor (the square root of d_k) and normalized with the softmax function into weights that sum to 1, which are then used to compute a weighted sum of the values.

\text{Self-Attention} = \text{softmax}\left(\frac{Q * K^T}{\sqrt{d_k}}\right) * V
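The formula can be sketched in plain Python. This is a minimal illustration, assuming small matrices stored as lists of rows; the helper names `matmul`, `transpose`, `softmax` and `attention` are our own, and real models use tensor libraries instead. The softmax subtracts the row maximum before exponentiating, a standard trick that avoids overflow without changing the result.

```python
import math


def matmul(a, b):
    """Multiply an n x m matrix by an m x p matrix (both stored as lists of rows)."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def transpose(m):
    return [list(col) for col in zip(*m)]


def softmax(row):
    """Softmax over one row; subtracting the max keeps exp() from overflowing."""
    top = max(row)
    exps = [math.exp(x - top) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def attention(Q, K, V):
    """softmax(Q * K^T / sqrt(d_k)) * V; also returns the attention weights."""
    d_k = len(K[0])
    scores = matmul(Q, transpose(K))
    weights = [softmax([s / math.sqrt(d_k) for s in row]) for row in scores]
    return matmul(weights, V), weights


# Toy example: 3 tokens with d_k = 2 (numbers chosen only for illustration)
Q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
K = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
V = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
output, weights = attention(Q, K, V)
print(len(output), len(output[0]))  # 3 2: one d_k-sized vector per token
```

Each row of `weights` sums to 1, so every output row is a weighted average of the rows of V.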

A transformer has two main blocks, namely the encoder and the decoder. In the encoder, self-attention is calculated over the entire input sequence. In the decoder, each word can attend only to itself and the words before it, so that the model cannot look at future words while generating text.
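The decoder's restriction is commonly implemented as a causal mask: the scores for future positions are set to minus infinity before the softmax, so they receive zero weight. Below is a minimal plain-Python sketch, assuming Q and K are lists of rows; `causal_attention_weights` is a hypothetical helper name.

```python
import math


def softmax(row):
    top = max(row)
    exps = [math.exp(x - top) for x in row]  # exp(-inf) evaluates to 0.0
    total = sum(exps)
    return [e / total for e in exps]


def causal_attention_weights(Q, K):
    """Decoder-style weights: the token at position i may attend only to positions 0..i."""
    d_k = len(K[0])
    weights = []
    for i, q in enumerate(Q):
        scores = []
        for j, k in enumerate(K):
            if j <= i:
                scores.append(sum(a * b for a, b in zip(q, k)) / math.sqrt(d_k))
            else:
                scores.append(float("-inf"))  # mask out future tokens
        weights.append(softmax(scores))
    return weights


X = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # toy Q = K for 3 tokens
weights = causal_attention_weights(X, X)
print(weights[0])  # [1.0, 0.0, 0.0]: the first token can only see itself
```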

Why learn Attention Techniques?

There are different ways to implement self-attention so that it captures context as efficiently as possible in terms of memory and speed. In this blog, we discuss Multi-Head Attention (MHA), Multi-Query Attention (MQA) and Grouped-Query Attention (GQA), which are widely used in today's large language models.

Why is it important to understand these techniques? Let us say you want to train a language model for your native language from scratch. With a huge training dataset, using memory and compute efficiently becomes essential. If you know the available options, you can choose the one that fits your requirements and improve upon it.

Multi-head Attention (MHA)

Multi-head self-attention projects the Query (Q), Key (K) and Value (V) matrices into “h” different heads. Each head can learn to focus on a different aspect of the language, such as semantic nuances, syntactic structure or positional relationships.

Let us understand this with the example sentence “Transformers are the best for NLP” and assume the number of heads to be 3 (h = 3). Each word in the input sequence has an initial embedding, which is added to its positional embedding before being sent to the transformer. The result is the input matrix X with the dimension [number of tokens x embedding size], which here is 6 x 4.

Note: In practice, the heads are not computed one at a time with separate matrices. The query, key, and value vectors of all heads are stored in multi-dimensional tensors that are processed in parallel. However, the underlying calculation is the same as described below.

To obtain the query (Q), key (K) and value (V) matrices, we multiply the input X with the weight matrices W_q, W_k and W_v respectively. Since we have three heads, there are three different sets of weights, one for each head, as shown below (only two are shown here). Each of the weight matrices has the dimension 4 x 2, so each Q, K and V matrix is 6 x 2.

Q_1 = X * W_{1q} \\
K_1 = X * W_{1k} \\
V_1 = X * W_{1v} \\
\\
Q_2 = X * W_{2q} \\
K_2 = X * W_{2k} \\
V_2 = X * W_{2v} \\
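The per-head projections can be sketched in plain Python with the example's shapes (6 tokens, embedding size 4, h = 3, 4 x 2 weights). Random numbers stand in for the learned weights and for the embeddings of the example sentence, and only the queries are shown because keys and values follow the same pattern. The second half illustrates the note about tensors: placing the per-head weight matrices side by side and doing a single multiplication gives the same per-head results as three separate multiplications.

```python
import random


def matmul(a, b):
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def random_matrix(rows, cols, rng):
    return [[rng.uniform(-1, 1) for _ in range(cols)] for _ in range(rows)]


rng = random.Random(0)  # fixed seed so the sketch is reproducible
num_tokens, d_model, h, d_k = 6, 4, 3, 2

# Stand-in for the token embeddings plus positional embeddings
X = random_matrix(num_tokens, d_model, rng)

# Separate query weights per head: W_1q, W_2q, W_3q, each 4 x 2
W_qs = [random_matrix(d_model, d_k, rng) for _ in range(h)]
Q_heads = [matmul(X, W_q) for W_q in W_qs]  # Q_1, Q_2, Q_3, each 6 x 2

# Fused version: stack the per-head weights side by side into one 4 x 6 matrix,
# multiply once, then slice the 6 x 6 result back into 2-column heads
W_q_fused = [sum((W_q[r] for W_q in W_qs), []) for r in range(d_model)]
Q_fused = matmul(X, W_q_fused)
Q_split = [[row[i * d_k:(i + 1) * d_k] for row in Q_fused] for i in range(h)]

print(len(Q_heads[0]), len(Q_heads[0][0]))  # 6 2
```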

Next, the attention weights are calculated for each head as follows, with d_k = 2 (the per-head dimension). The weights obtained from the softmax show how much each word in the sequence attends to every word, including itself. Thus, each of these matrices has the dimension 6 x 6.

\text{Attention\_Scores\_1} = \frac{Q_1 * K_1^T}{\sqrt{d_k}} \\
\text{Attention\_Weights\_1} = \text{softmax}(\text{Attention\_Scores\_1}) \\
\text{Attention\_Scores\_2} = \frac{Q_2 * K_2^T}{\sqrt{d_k}} \\
\text{Attention\_Weights\_2} = \text{softmax}(\text{Attention\_Scores\_2}) \\

These attention weights are multiplied with the value matrices to complete the attention calculation as follows. 

\text{Head\_1\_Output} = \text{Attention\_Weights\_1} * V_1 \\ 
\text{Head\_2\_Output} = \text{Attention\_Weights\_2} * V_2 \\ 

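The final output is calculated by concatenating the outputs of the three heads (a 6 x 6 matrix, since each head contributes 2 columns) and multiplying the result with the output weight matrix W_o (6 x 4), which maps it back to the embedding size.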

\text{Multi\_Head\_Output} = [\text{Head\_1\_Output}, \text{Head\_2\_Output}, \text{Head\_3\_Output}] * W_o

Since each head has its own weights and does not depend on the other heads, all the above calculations can happen simultaneously to speed up the operation.
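Putting the steps together, a plain-Python sketch of the full MHA forward pass for the running example (h = 3, d_k = 2) might look like this. Random matrices stand in for learned weights, W_o is assumed to be 6 x 4 so the output returns to the embedding size, and the helper names are our own. The heads are computed in a loop for readability; real implementations process them in parallel as tensors.

```python
import math
import random


def matmul(a, b):
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def transpose(m):
    return [list(col) for col in zip(*m)]


def softmax(row):
    top = max(row)
    exps = [math.exp(x - top) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def attention(Q, K, V):
    d_k = len(K[0])
    scores = matmul(Q, transpose(K))
    weights = [softmax([s / math.sqrt(d_k) for s in row]) for row in scores]
    return matmul(weights, V)


def random_matrix(rows, cols, rng):
    return [[rng.uniform(-1, 1) for _ in range(cols)] for _ in range(rows)]


def multi_head_attention(X, head_weights, W_o):
    """head_weights: one (W_q, W_k, W_v) tuple per head, so every head has its own K and V."""
    outputs = [attention(matmul(X, W_q), matmul(X, W_k), matmul(X, W_v))
               for W_q, W_k, W_v in head_weights]
    # Concatenate head outputs column-wise: row i = head_1[i] + head_2[i] + head_3[i]
    concat = [sum((out[i] for out in outputs), []) for i in range(len(X))]
    return matmul(concat, W_o)


rng = random.Random(0)
num_tokens, d_model, h, d_k = 6, 4, 3, 2
X = random_matrix(num_tokens, d_model, rng)
head_weights = [tuple(random_matrix(d_model, d_k, rng) for _ in range(3)) for _ in range(h)]
W_o = random_matrix(h * d_k, d_model, rng)  # 6 x 4
Y = multi_head_attention(X, head_weights, W_o)
print(len(Y), len(Y[0]))  # 6 4
```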

Though MHA captures context effectively, memory overhead and inference speed become a bottleneck for long sequences. During autoregressive generation, the keys and values of all previous tokens must be kept in memory for every head (the so-called KV cache) and read again at each step, so the memory cost grows with both the number of heads and the sequence length.

Multi-Query Attention (MQA)

MQA aims to fix the memory and speed problems of MHA at the cost of a small decrease in accuracy. Each head still has its own query matrix, but a single Key matrix and a single Value matrix are shared by all the heads.

The MQA calculation is shown below. For simplicity, we show it only for two heads. There is only a single set of key and value weights (W_k, W_v) and hence a single Key and Value matrix (K_1, V_1) shared by both heads.

Q_1 = X * W_{1q} \\
Q_2 = X * W_{2q} \\
K_1 = X * W_k \\
V_1 = X * W_v \\

\text{Attention\_Scores\_1} = \frac{Q_1 * K_1^T}{\sqrt{d_k}} \\
\text{Attention\_Weights\_1} = \text{softmax}(\text{Attention\_Scores\_1}) \\
\text{Attention\_Scores\_2} = \frac{Q_2 * K_1^T}{\sqrt{d_k}} \\
\text{Attention\_Weights\_2} = \text{softmax}(\text{Attention\_Scores\_2}) \\

\text{Head\_1\_Output} = \text{Attention\_Weights\_1} * V_1 \\ 
\text{Head\_2\_Output} = \text{Attention\_Weights\_2} * V_1 \\ 
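In code, the only change from MHA is that K and V are computed once and reused by every head. A minimal plain-Python sketch with two heads follows; random matrices stand in for learned weights and the helper names are our own.

```python
import math
import random


def matmul(a, b):
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def transpose(m):
    return [list(col) for col in zip(*m)]


def softmax(row):
    top = max(row)
    exps = [math.exp(x - top) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def attention(Q, K, V):
    d_k = len(K[0])
    scores = matmul(Q, transpose(K))
    weights = [softmax([s / math.sqrt(d_k) for s in row]) for row in scores]
    return matmul(weights, V)


def random_matrix(rows, cols, rng):
    return [[rng.uniform(-1, 1) for _ in range(cols)] for _ in range(rows)]


def multi_query_attention(X, W_qs, W_k, W_v):
    """Each head has its own query weights; all heads share one W_k and one W_v."""
    K_1 = matmul(X, W_k)  # computed once and shared by all heads
    V_1 = matmul(X, W_v)
    return [attention(matmul(X, W_q), K_1, V_1) for W_q in W_qs]


rng = random.Random(0)
X = random_matrix(6, 4, rng)                         # 6 tokens, embedding size 4
W_qs = [random_matrix(4, 2, rng) for _ in range(2)]  # W_1q, W_2q
W_k, W_v = random_matrix(4, 2, rng), random_matrix(4, 2, rng)
head_outputs = multi_query_attention(X, W_qs, W_k, W_v)
print(len(head_outputs), len(head_outputs[0]), len(head_outputs[0][0]))  # 2 6 2
```

Because all heads read the same K_1 and V_1, only one key matrix and one value matrix need to be stored, instead of one pair per head.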

This technique can be used when resources are constrained during training and inference and a small decrease in accuracy is acceptable. Google’s PaLM (Pathways Language Model) and the Falcon language models use Multi-Query Attention.

Grouped-Query Attention (GQA)

GQA is a middle ground between MHA and MQA: it performs almost as well as MHA with less memory overhead. Instead of all the query heads sharing a single key and value matrix, the query heads are divided into G groups, and the query heads within each group share the same key and value matrices.
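Which key/value group each query head uses can be sketched with a small helper. This assumes the common convention of contiguous, equal-sized groups (for example, with 6 heads and G = 3, heads 1 and 2 form group 1), and `kv_group_for_head` is a hypothetical name.

```python
def kv_group_for_head(head, num_heads, num_groups):
    """Return the (0-based) K/V group used by a (0-based) query head."""
    if num_heads % num_groups != 0:
        raise ValueError("num_heads must be divisible by num_groups")
    heads_per_group = num_heads // num_groups
    return head // heads_per_group


num_heads, num_groups = 6, 3
for head in range(num_heads):
    g = kv_group_for_head(head, num_heads, num_groups) + 1
    print(f"Q{head + 1} uses K{g} and V{g}")
# Q1, Q2 -> K1, V1;  Q3, Q4 -> K2, V2;  Q5, Q6 -> K3, V3
```

With num_groups = 1 every head maps to the same group, and with num_groups = num_heads every head gets its own group.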

In this case, let us assume there are 6 attention heads in total, so MHA would have 6 different sets of Query, Key, and Value matrices. With G = 3, each group contains two query heads that share one set of Key and Value matrices. For Group 1, the outputs are calculated as follows. Note that K_1 and V_1 are used as the keys and values for both Q_1 and Q_2.

Q_1 = X * W_{1q} \\
K_1 = X * W_{1k} \\
V_1 = X * W_{1v} \\

Q_2 = X * W_{2q} \\

\text{Attention\_Scores\_1} = \frac{Q_1 * K_1^T}{\sqrt{d_k}} \\
\text{Attention\_Weights\_1} = \text{softmax}(\text{Attention\_Scores\_1}) \\
\text{Attention\_Scores\_2} = \frac{Q_2 * K_1^T}{\sqrt{d_k}} \\
\text{Attention\_Weights\_2} = \text{softmax}(\text{Attention\_Scores\_2}) \\

\text{Head\_1\_Output} = \text{Attention\_Weights\_1} * V_1 \\ 
\text{Head\_2\_Output} = \text{Attention\_Weights\_2} * V_1 \\ 

Similarly, the second group has the queries Q_3 and Q_4 with one set of Key (K_2) and Value (V_2) matrices, and its outputs are calculated as follows.

Q_3 = X * W_{3q} \\
K_2 = X * W_{2k} \\
V_2 = X * W_{2v} \\

Q_4 = X * W_{4q} \\

\text{Attention\_Scores\_3} = \frac{Q_3 * K_2^T}{\sqrt{d_k}} \\
\text{Attention\_Weights\_3} = \text{softmax}(\text{Attention\_Scores\_3}) \\
\text{Attention\_Scores\_4} = \frac{Q_4 * K_2^T}{\sqrt{d_k}} \\
\text{Attention\_Weights\_4} = \text{softmax}(\text{Attention\_Scores\_4}) \\

\text{Head\_3\_Output} = \text{Attention\_Weights\_3} * V_2 \\ 
\text{Head\_4\_Output} = \text{Attention\_Weights\_4} * V_2 \\ 

The final output is calculated as in the above techniques by multiplying the concatenation of all the 6 outputs with the output weight matrix W_o. 

\text{Multi\_Head\_Output} = [\text{Head\_1\_Output}, \text{Head\_2\_Output}, \ldots, \text{Head\_6\_Output}] * W_o
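The whole GQA forward pass can be sketched in plain Python for the running example (6 heads, G = 3). Random matrices stand in for learned weights; following the earlier MHA example, each head is assumed to have d_k = 2 and the embedding size is 4, so W_o is 12 x 4. The helper names are our own, and contiguous equal-sized groups are assumed.

```python
import math
import random


def matmul(a, b):
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def transpose(m):
    return [list(col) for col in zip(*m)]


def softmax(row):
    top = max(row)
    exps = [math.exp(x - top) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def attention(Q, K, V):
    d_k = len(K[0])
    scores = matmul(Q, transpose(K))
    weights = [softmax([s / math.sqrt(d_k) for s in row]) for row in scores]
    return matmul(weights, V)


def random_matrix(rows, cols, rng):
    return [[rng.uniform(-1, 1) for _ in range(cols)] for _ in range(rows)]


def grouped_query_attention(X, W_qs, W_ks, W_vs, W_o):
    """W_qs has one entry per head; W_ks and W_vs have one entry per group (G)."""
    num_heads, num_groups = len(W_qs), len(W_ks)
    heads_per_group = num_heads // num_groups  # assumes equal, contiguous groups
    Ks = [matmul(X, W_k) for W_k in W_ks]  # K_1 .. K_G
    Vs = [matmul(X, W_v) for W_v in W_vs]  # V_1 .. V_G
    outputs = []
    for head, W_q in enumerate(W_qs):
        g = head // heads_per_group  # the group this query head belongs to
        outputs.append(attention(matmul(X, W_q), Ks[g], Vs[g]))
    # Concatenate the 6 head outputs column-wise, then apply W_o
    concat = [sum((out[i] for out in outputs), []) for i in range(len(X))]
    return matmul(concat, W_o)


rng = random.Random(0)
num_tokens, d_model, h, G, d_k = 6, 4, 6, 3, 2
X = random_matrix(num_tokens, d_model, rng)
W_qs = [random_matrix(d_model, d_k, rng) for _ in range(h)]
W_ks = [random_matrix(d_model, d_k, rng) for _ in range(G)]
W_vs = [random_matrix(d_model, d_k, rng) for _ in range(G)]
W_o = random_matrix(h * d_k, d_model, rng)  # 12 x 4
Y = grouped_query_attention(X, W_qs, W_ks, W_vs, W_o)
print(len(Y), len(Y[0]))  # 6 4
```

Passing a single key and value weight matrix (G = 1) turns the same function into MQA, and passing one per head (G = 6) turns it into MHA.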

GQA generalizes both MHA and MQA. When G = 1, there is only one group, so all queries share the same Key and Value matrices, which makes it MQA. When G equals the number of heads, every group has only one query with its own keys and values, which is equivalent to Multi-Head Attention. In the experiments of the GQA paper, 8 groups gave a good balance between quality and speed; this is a trade-off rather than a universally ideal value.

This reduces memory overhead compared to MHA with little loss in accuracy, while keeping inference speed close to that of MQA. Meta's Llama 2 uses Grouped-Query Attention in its larger (34B and 70B) models.
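The memory saving can be made concrete by counting how many key and value entries must be stored for one layer. This back-of-the-envelope sketch uses the running example's sizes (6 tokens, 6 heads, d_k = 2) and ignores everything else a real model keeps in memory; `kv_elements` is a hypothetical helper.

```python
def kv_elements(seq_len, num_kv_heads, d_k):
    """Key + value entries stored for one layer: K and V are each seq_len x d_k per K/V head."""
    return 2 * seq_len * num_kv_heads * d_k


seq_len, num_heads, d_k = 6, 6, 2
for name, kv_heads in [("MQA (G = 1)", 1), ("GQA (G = 3)", 3), ("MHA (G = h = 6)", num_heads)]:
    print(name, kv_elements(seq_len, kv_heads, d_k))
# MQA (G = 1) 24
# GQA (G = 3) 72
# MHA (G = h = 6) 144
```

The count grows linearly with the number of key/value heads, so here GQA with G = 3 stores half as much as MHA, and MQA one sixth.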

Intuitive Understanding

We have been through a lot of formulae. Now, let us make this fun by telling a story. Let us say you are the manager of a software engineering team and have been assigned a team of interns whose job is to build software applications. A software application has different components such as the UI, database management, backend code, API, testing and deployment. You train each of the 6 interns to develop one of the components. This is like training 6 different sets of keys and values, one for each of the 6 components (queries) of the application, as in MHA.

Continuing our story, let us say you have only one intern at your disposal to build the application. So, all 6 components need to be built by one person. That is, the same key and value matrices are used for all 6 queries, just like in Multi-Query Attention (MQA). This increases the workload of that one person, but it has the advantage that you have to train only one person.

As a compromise between the two options above, let us say you have 3 interns, each with more than one competency. For example, people who work with databases can often also handle the backend code. So, the first intern handles UI and API development, the second handles databases and backend code, and the third takes care of testing and deployment. Think of the interns as keys and values, where the components (queries) are grouped into 3 sets that each share one key and value, like in Grouped-Query Attention (GQA).

Conclusion

In this post, we have briefly discussed the self-attention mechanisms used to build popular large language models and developed an intuitive understanding of how they differ. In upcoming posts, we will discuss more techniques such as Sparse Attention and Flash Attention.
