LLM: From Zero to Hero

Transformer Architecture

When I was new to AI and first encountered the Transformer Architecture, I felt overwhelmed by the sheer number of concepts and tutorials available to understand it. I noticed that some videos or articles assumed prior knowledge of natural language processing (NLP) concepts, while others were too lengthy and difficult to comprehend. To grasp the Transformer Architecture, I had to read numerous articles and watch several videos, and there were times when I got stuck due to a lack of basic knowledge that prevented me from fully understanding the topic, including:

- Why can an input embedding represent a word?
- What is matrix multiplication, and how do the shapes work out?
- What does softmax do?
- Once the model is trained, where is it stored?
- ...

With all the complexities involved, I have found it challenging to completely master this architecture. Therefore, I urge you to be patient and make sure to watch the video clips first, as the concepts are relatively easy to understand once explained. If we proceed step-by-step, starting from the basics and gradually moving towards more advanced topics, I am confident that you will soon be on the same page as me.

If you are new to LLMs, be sure to check out my previous articles introducing LLMs, which explain the basics of LLMs and how they work.

High Level Overview

The Transformer model was first introduced in the 2017 paper "Attention Is All You Need". The architecture was originally intended for training language-translation models. However, the team at OpenAI discovered that it was also a crucial solution for predicting the next character (token) in a text. Once trained on the entirety of internet data, the model could potentially comprehend the context of any text and coherently complete any sentence, much like a human.

The original model consists of two parts: the encoder and the decoder. In general, encoder-only architectures are proficient at extracting information from text for tasks like classification and regression, whereas decoder-only models specialize in generating text. For instance, GPT, which focuses on text generation, falls under the category of decoder-only models.

Let’s run through the key ideas of the architecture when training a model.

I drew a diagram to illustrate the training process of the GPT-like decoder-only transformer architecture:

Decoder only transformer
  1. First, we require a sequence of input characters as training data. These inputs are converted into a vector embedding format.

  2. Next, we add a positional encoding to the vector embeddings to capture each character's position within the sequence.

  3. Subsequently, the model processes these input embeddings via a series of computational operations, ultimately generating a probability distribution over possible next characters for the given input text.

  4. The model compares its prediction against the actual next character from the training dataset and adjusts its parameters, or "weights", accordingly.

  5. Finally, the model iteratively refines this process, updating its parameters continually to enhance the precision of future predictions.
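Steps 3 and 4 can be made concrete with a toy example. The sketch below is only an illustration under stated assumptions: the four raw scores (logits) are made up, the vocabulary has just 4 entries, and `softmax` and `cross_entropy` are hypothetical helper names rather than functions from any library used later in this article. It shows how raw scores become a probability distribution, and how the loss is small when the true next character was given a high probability.

```python
import math

def softmax(logits):
    """Turn raw scores into probabilities that sum to 1."""
    m = max(logits)                            # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def cross_entropy(probs, target_index):
    """Negative log of the probability assigned to the true next token."""
    return -math.log(probs[target_index])

# Made-up scores for a toy vocabulary of 4 tokens (an assumption for illustration).
logits = [2.0, 1.0, 0.1, -1.0]
probs = softmax(logits)

loss_if_target_is_0 = cross_entropy(probs, 0)  # the token the model favoured
loss_if_target_is_3 = cross_entropy(probs, 3)  # the token the model thought unlikely

print(probs)
print(loss_if_target_is_0, loss_if_target_is_3)
```

Training nudges the weights so that this loss goes down on average, which is what step 5 repeats over and over.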

Let's dig into the details of each step.

Step 1 : Tokenization

Tokenization is the first step of a transformer model. It translates an input sentence into a numeric representation.

Tokenization is the process of dividing text into smaller units called tokens, which can be words, sub-words, phrases, or characters. Breaking text into smaller pieces helps the model identify the underlying structure of the text and process it more efficiently.

For example:

Chapter 1: Building Rapport and Capturing

The above sentence can be cut into:

Chapter, , 1, :, , Building, , Rap, port, , and, , Capturing

which is tokenized as a sequence of 13 numbers:

[26072, 220, 16, 25, 17283, 23097, 403, 220, 323, 220, 17013, 220, 1711]

As you can see, the number 220 is used to represent the whitespace character. There are many ways to tokenize characters into integers. For our sample dataset, we will use the tiktoken library.

For demonstration purposes, I'll use a small textbook dataset (from Hugging Face) that contains 460k characters for our training.

  1. File size: 450Kb

  2. Vocab size: 3,771 (the number of unique words/sub-words)

Our training data has a vocabulary of 3,771 distinct tokens. The largest token number used to tokenize our textbook dataset is 100069, which maps to the sub-word Clar.

Once we have the tokenized mapping, we can find the corresponding integer index for each character in the dataset. We will utilize these assigned integer indices as tokens moving forward, rather than using entire words when interfacing with the model.
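The same bookkeeping (sequence length, number of distinct tokens, largest token number) can be sketched on the short example from earlier instead of the full dataset. This is a minimal sketch in plain Python; the variable names are my own, and the token numbers are simply copied from the tokenized example above rather than produced by calling a tokenizer here.

```python
# Token numbers copied from the tokenized "Chapter 1: ..." example above.
token_ids = [26072, 220, 16, 25, 17283, 23097, 403, 220, 323, 220, 17013, 220, 1711]

num_tokens = len(token_ids)             # how long the tokenized sequence is
distinct_ids = sorted(set(token_ids))   # the unique tokens that actually occur
vocab_size = len(distinct_ids)          # "vocab size" of this tiny sample
max_token_id = max(token_ids)           # largest token number in the sample

print(num_tokens, vocab_size, max_token_id)  # 13 10 26072
```

Applied to the whole textbook dataset, the same two quantities are the 3,771 and 100069 quoted above. The largest token number matters in the next step because it determines how many rows the embedding table needs.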

Step 2 : Word Embeddings

To begin, let us construct a look-up table containing all the characters from our vocabulary. In essence, this table consists of a matrix filled with randomly initialized numbers.

Given that the largest token number we have is 100069, and taking a dimension of 64 (the original paper uses 512 dimensions, denoted d_model), the resulting lookup table is a 100,069 × 64 matrix. This is called the Token Embedding Look-up Table, represented as follows:

Token Embedding Look-Up Table:
0 1 2 3 4 5 6 7 8 9 ... 54 55 56 57 58 59 60 61 62 63
0 0.625765 0.025510 0.954514 0.064349 -0.502401 -0.202555 -1.567081 -1.097956 0.235958 -0.239778 ... 0.420812 0.277596 0.778898 1.533269 1.609736 -0.403228 -0.274928 1.473840 0.068826 1.332708
1 -0.497006 0.465756 -0.257259 -1.067259 0.835319 -1.956048 -0.800265 -0.504499 -1.426664 0.905942 ... 0.008287 -0.252325 -0.657626 0.318449 -0.549586 -1.464924 -0.557690 -0.693927 -0.325247 1.243933
2 1.347121 1.690980 -0.124446 -1.682366 1.134614 -0.082384 0.289316 0.835773 0.306655 -0.747233 ... 0.543340 -0.843840 -0.687481 2.138219 0.511412 1.219090 0.097527 -0.978587 -0.432050 -1.493750
3 1.078523 -0.614952 -0.458853 0.567482 0.095883 -1.569957 0.373957 -0.142067 -1.242306 -0.961821 ... -0.882441 0.638720 1.119174 -1.907924 -0.527563 1.080655 -2.215207 0.203201 -1.115814 -1.258691
4 0.814849 -0.064297 1.423653 0.261726 -0.133177 0.211893 1.449790 3.055426 -1.783010 -0.832339 ... 0.665415 0.723436 -1.318454 0.785860 -1.150111 1.313207 -0.334949 0.149743 1.306531 -0.046524
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
100064 -0.898191 -1.906910 -0.906910 1.838532 2.121814 -1.654444 0.082778 0.064536 0.345121 0.262247 ... 0.438956 0.163314 0.491996 1.721039 -0.124316 1.228242 0.368963 1.058280 0.406413 -0.326223
100065 1.354992 -1.203096 -2.184551 -1.745679 -0.005853 -0.860506 1.010784 0.355051 -1.489120 -1.936192 ... 1.354665 -1.338872 -0.263905 0.284906 0.202743 -0.487176 -0.421959 0.490739 -1.056457 2.636806
100066 -0.436116 0.450023 -1.381522 0.625508 0.415576 0.628877 -0.595811 -1.074244 -1.512645 -2.027422 ... 0.436522 0.068974 1.305852 0.005790 -0.583766 -0.797004 0.144952 -0.279772 1.522029 -0.629672
100067 0.147102 0.578953 -0.668165 -0.011443 0.236621 0.348374 -0.706088 1.368070 -1.428709 -0.620189 ... 1.130942 -0.739860 -1.546209 -1.475937 -0.145684 -1.744829 0.637790 -1.064455 1.290440 -1.110520
100068 0.415268 -0.345575 0.441546 -0.579085 1.110969 -1.303691 0.143943 -0.714082 -1.426512 1.646982 ... -2.502535 1.409418 0.159812 -0.911323 0.856282 -0.404213 -0.012741 1.333426 0.372255 0.722526
[100,069 rows x 64 columns]

where each row represents a character (indexed by its token number) and each column represents a dimension.

For now, you can consider a 'dimension' as a characteristic or aspect of a character. In our case, we have designated 64 dimensions, which means we will be able to understand a character's textual meaning in 64 distinct ways, such as by categorizing it as a noun, verb, adjective, and so on.

Say we now have an example training input with a context_length of 16:

" . By mastering the art of identifying underlying motivations and desires, we equip ourselves with "

Tokenizing this input gives the 16 integer indices below. We then retrieve the embedding vector for each tokenized character (or word) by using its integer index to look up the corresponding row of the embedding table:

[ 627, 1383, 88861, 279, 1989, 315, 25607, 16940, 65931, 323, 32097, 11, 584, 26458, 13520, 449]

In the transformer architecture, multiple input sequences are processed simultaneously in parallel; together they are referred to as a batch. Let's set our batch_size to 4. We will therefore process four randomly selected sentences as our inputs at once.

Input Sequence Batch:
0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
0 627 1383 88861 279 1989 315 25607 16940 65931 323 32097 11 584 26458 13520 449
1 15749 311 9615 3619 872 6444 6 3966 11 10742 11 323 32097 13 3296 22815
2 13189 315 1701 5557 304 6763 374 88861 7528 10758 7526 13 4314 7526 2997 2613
3 323 6376 2867 26470 1603 16661 264 49148 627 18 13 81745 48023 75311 7246 66044
[4 rows x 16 columns]

Each row represents a sentence; each column is the character at the 0th to 15th position of that sentence.

As a result, we now have a matrix representing a batch of 4 inputs of 16 characters each. The shape of this matrix is (batch_size, context_length) = [4, 16].
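Here is a minimal sketch of how such a batch could be drawn, assuming NumPy and assuming the tokenized dataset is already available as one long array of integers. The `tokens` array below is random stand-in data, not the real textbook, and the names `x_batch` and `y_batch` are hypothetical. I also build the matching targets (each window shifted by one position), since training compares the prediction with the actual next character.

```python
import numpy as np

context_length = 16
batch_size = 4

rng = np.random.default_rng(seed=0)

# Stand-in for the tokenized dataset: random token numbers, NOT real data.
tokens = rng.integers(low=0, high=100069, size=10_000)

# Pick batch_size random start offsets, leaving room for a full window
# plus one extra token for the shifted targets.
starts = rng.integers(low=0, high=len(tokens) - context_length, size=batch_size)

# Inputs: windows of context_length tokens. Targets: the same windows shifted by one.
x_batch = np.stack([tokens[s : s + context_length] for s in starts])
y_batch = np.stack([tokens[s + 1 : s + context_length + 1] for s in starts])

print(x_batch.shape, y_batch.shape)  # (4, 16) (4, 16)
```

Because the targets are just the inputs shifted by one, position t of `y_batch` is the "correct answer" for the prediction made after seeing positions 0 to t of `x_batch`.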

To recap, we defined an input embedding lookup table as a matrix of size 100,069 × 64 earlier. The next step is to take our input sequences matrix and map it onto this embedding matrix to obtain our Input Embedding.

Here, we will focus on breaking down each individual row of our input sequence matrix, starting with the first row. First, we reshape this opening row from its original dimensions of (1, context_length) = [1, 16] to a fresh format of (context_length, 1) = [16, 1]. Subsequently, we overlay this restructured row upon our previously established embedding matrix sized (vocab_size, d_model) = [100069, 64], thereby substituting the matching embedding vector for every single character present inside the given context window. The resulting output is a matrix with a shape of (context_length, d_model) = [16, 64].

First row of input sequence batch:

Input Embedding:
0 1 2 3 4 5 6 7 8 9 ... 54 55 56 57 58 59 60 61 62 63
0 1.051807 -0.704369 -0.913199 -1.151564 0.582201 -0.898582 0.984299 -0.075260 -0.004821 -0.743642 ... 1.151378 0.119595 0.601200 -0.940352 0.289960 0.579749 0.428623 0.263096 -0.773865 -0.734220
1 -0.293959 -1.278850 -0.050731 0.862562 0.200148 -1.732625 0.374076 -1.128507 0.281203 -1.073113 ... -0.062417 -0.440599 0.800283 0.783043 1.602350 -0.676059 -0.246531 1.005652 -1.018667 0.604092
2 -0.292196 0.109248 -0.131576 -0.700536 0.326451 -1.885801 -0.150834 0.348330 -0.777281 0.986769 ... 0.382480 1.315575 -0.144037 1.280103 1.112829 0.438884 -0.275823 -2.226698 0.108984 0.701881
3 0.427942 0.878749 -0.176951 0.548772 0.226408 -0.070323 -1.865235 1.473364 1.032885 0.696173 ... 1.270187 1.028823 -0.872329 -0.147387 -0.083287 0.142618 -0.375903 -0.101887 0.989520 -0.062560
4 -1.064934 -0.131570 0.514266 -0.759037 0.294044 0.957125 0.976445 -1.477583 -1.376966 -1.171344 ... 0.231112 1.278687 0.254688 0.516287 0.621753 0.219179 1.345463 -0.927867 0.510172 0.656851
5 2.514588 -1.001251 0.391298 -0.845712 0.046932 -0.036732 1.396451 0.934358 -0.876228 -0.024440 ... 0.089804 0.646096 -0.206935 0.187104 -1.288239 -1.068143 0.696718 -0.373597 -0.334495 -0.462218
6 0.498423 -0.349237 -1.061968 -0.093099 1.374657 -0.512061 -1.238927 -1.342982 -1.611635 2.071445 ... 0.025505 0.638072 0.104059 -0.600942 -0.367796 -0.472189 0.843934 0.706170 -1.676522 -0.266379
7 1.684027 -0.651413 -0.768050 0.599159 -0.381595 0.928799 2.188572 1.579998 -0.122685 -1.026440 ... -0.313672 1.276962 -1.142109 -0.145139 1.207923 -0.058557 -0.352806 1.506868 -2.296642 1.378678
8 -0.041210 -0.834533 -1.243622 -0.675754 -1.776586 0.038765 -2.713090 2.423366 -1.711815 0.621387 ... -1.063758 1.525688 -1.762023 0.161098 0.026806 0.462347 0.732975 0.479750 0.942445 -1.050575
9 0.708754 1.058510 0.297560 0.210548 0.460551 1.016141 2.554897 0.254032 0.935956 -0.250423 ... -0.552835 0.084124 0.437348 0.596228 0.512168 0.289721 -0.028321 -0.932675 -0.411235 1.035754
10 -0.584553 1.395676 0.727354 0.641352 0.693481 -2.113973 -0.786199 -0.327758 1.278788 -0.156118 ... 1.204587 -0.131655 -0.595295 -0.433438 -0.863684 3.272247 0.101591 0.619058 -0.982174 -1.174125
11 -0.753828 0.098016 -0.945322 0.708373 -1.493744 0.394732 0.075629 -0.049392 -1.005564 0.356353 ... 2.452891 -0.233571 0.398788 -1.597272 -1.919085 -0.405561 -0.266644 1.237022 1.079494 -2.292414
12 -0.611864 0.006810 1.989711 -0.446170 -0.670108 0.045619 -0.092834 1.226774 -1.407549 -0.096695 ... 1.181310 -0.407162 -0.086341 -0.530628 0.042921 1.369478 0.823999 -0.312957 0.591755 0.516314
13 -0.584553 1.395676 0.727354 0.641352 0.693481 -2.113973 -0.786199 -0.327758 1.278788 -0.156118 ... 1.204587 -0.131655 -0.595295 -0.433438 -0.863684 3.272247 0.101591 0.619058 -0.982174 -1.174125
14 -1.174090 0.096075 -0.749195 0.395859 -0.622460 -1.291126 0.094431 0.680156 -0.480742 0.709318 ... 0.786663 0.237733 1.513797 0.296696 0.069533 -0.236719 1.098030 -0.442940 -0.583177 1.151497
15 0.401740 -0.529587 3.016675 -1.134723 -0.256546 -0.219896 0.637936 2.000511 -0.418684 -0.242720 ... -0.442287 -1.519394 -1.007496 -0.517480 0.307449 -0.316039 -0.880636 -1.424680 -1.901644 1.968463
[16 rows x 64 columns]

matrix shows one of the four rows after mapping

We do the same for the remaining 3 rows; in the end we have 4 sets of [16 rows x 64 columns].

This results in the Input Embedding matrix with shape of (batch_size, context_length, d_model) = [4, 16, 64].
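The whole lookup can be sketched in a few lines, assuming NumPy. The table here is filled with fresh random numbers, so its values will not match the ones printed above; the token numbers are the four input rows from the Input Sequence Batch. Indexing the table with an integer array picks one row per token, which is all the "mapping" amounts to.

```python
import numpy as np

table_rows = 100069   # number of rows in the look-up table, as in the text
d_model = 64

rng = np.random.default_rng(seed=0)

# Randomly initialised Token Embedding Look-up Table (values are illustrative only).
embedding_table = rng.standard_normal((table_rows, d_model), dtype=np.float32)

# The Input Sequence Batch from above: shape (batch_size, context_length) = (4, 16).
x_batch = np.array([
    [627, 1383, 88861, 279, 1989, 315, 25607, 16940, 65931, 323, 32097, 11, 584, 26458, 13520, 449],
    [15749, 311, 9615, 3619, 872, 6444, 6, 3966, 11, 10742, 11, 323, 32097, 13, 3296, 22815],
    [13189, 315, 1701, 5557, 304, 6763, 374, 88861, 7528, 10758, 7526, 13, 4314, 7526, 2997, 2613],
    [323, 6376, 2867, 26470, 1603, 16661, 264, 49148, 627, 18, 13, 81745, 48023, 75311, 7246, 66044],
])

# Integer-array indexing replaces every token number with its 64-dimensional row.
input_embedding = embedding_table[x_batch]

print(input_embedding.shape)  # (4, 16, 64)
```

One consequence worth noticing: the same token always gets the same vector, wherever it appears. Token 323 sits at position 9 of the first row and position 0 of the last row, and both positions receive an identical embedding. This is exactly why positional information has to be added in the next step.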

In essence, providing each word with a unique embedding allows the model to accommodate variations in language and manage words that have multiple meanings or forms.

Let us move forward with the understanding that our input embedding matrix serves as the expected input format for our model, even if we do not yet fully grasp the underlying mathematical principles at play.

Step 3 : Positional Encoding

In my opinion, positional encoding is the most challenging concept to grasp in the transformer architecture.

To summarize what positional encoding solves:

  1. We want each word to carry some information about its position in the sentence.

  2. We want the model to treat words that appear close to each other as “close” and words that are distant as “distant”.

  3. We want the positional encoding to represent a pattern that can be learned by the model.

The position encoding describes the location or position of an entity in a sequence so that each position is assigned a unique representation.

The position encoding is another vector of numbers that is added to the input embedding of each tokenized character. It is built from sine and cosine waves whose frequency varies with the embedding dimension, evaluated at the position of the tokenized character.

In the original paper, the method introduced to calculate the position encoding is:

PE(pos, 2i) = sin(pos / 10000^(2i/d_model))
PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))

where pos is the position, and i runs from 0 to d_model/2 - 1. d_model is the model dimension we defined when training the model (in our case it's 64; the original paper uses 512).

In fact, this positional encoding matrix is created just once and reused for every input sequence.

Let's take a look at the positional encoding matrix:

Position Embedding Look-Up Table:
0 1 2 3 4 5 6 7 8 9 ... 54 55 56 57 58 59 60 61 62 63
0 0.000000 1.000000 0.000000 1.000000 0.000000 1.000000 0.000000 1.000000 0.000000 1.000000 ... 0.000000 1.000000 0.000000 1.000000 0.000000 1.000000 0.000000 1.000000 0.000000 1.000000
1 0.841471 0.540302 0.681561 0.731761 0.533168 0.846009 0.409309 0.912396 0.310984 0.950415 ... 0.000422 1.000000 0.000316 1.000000 0.000237 1.000000 0.000178 1.000000 0.000133 1.000000
2 0.909297 -0.416147 0.997480 0.070948 0.902131 0.431463 0.746904 0.664932 0.591127 0.806578 ... 0.000843 1.000000 0.000632 1.000000 0.000474 1.000000 0.000356 1.000000 0.000267 1.000000
3 0.141120 -0.989992 0.778273 -0.627927 0.993253 -0.115966 0.953635 0.300967 0.812649 0.582754 ... 0.001265 0.999999 0.000949 1.000000 0.000711 1.000000 0.000533 1.000000 0.000400 1.000000
4 -0.756802 -0.653644 0.141539 -0.989933 0.778472 -0.627680 0.993281 -0.115730 0.953581 0.301137 ... 0.001687 0.999999 0.001265 0.999999 0.000949 1.000000 0.000711 1.000000 0.000533 1.000000
5 -0.958924 0.283662 -0.571127 -0.820862 0.323935 -0.946079 0.858896 -0.512150 0.999947 -0.010342 ... 0.002108 0.999998 0.001581 0.999999 0.001186 0.999999 0.000889 1.000000 0.000667 1.000000
6 -0.279415 0.960170 -0.977396 -0.211416 -0.230368 -0.973104 0.574026 -0.818837 0.947148 -0.320796 ... 0.002530 0.999997 0.001897 0.999998 0.001423 0.999999 0.001067 0.999999 0.000800 1.000000
7 0.656987 0.753902 -0.859313 0.511449 -0.713721 -0.700430 0.188581 -0.982058 0.800422 -0.599437 ... 0.002952 0.999996 0.002214 0.999998 0.001660 0.999999 0.001245 0.999999 0.000933 1.000000
8 0.989358 -0.145500 -0.280228 0.959933 -0.977262 -0.212036 -0.229904 -0.973213 0.574318 -0.818632 ... 0.003374 0.999994 0.002530 0.999997 0.001897 0.999998 0.001423 0.999999 0.001067 0.999999
9 0.412118 -0.911130 0.449194 0.893434 -0.939824 0.341660 -0.608108 -0.793854 0.291259 -0.956644 ... 0.003795 0.999993 0.002846 0.999996 0.002134 0.999998 0.001600 0.999999 0.001200 0.999999
10 -0.544021 -0.839072 0.937633 0.347628 -0.612937 0.790132 -0.879767 -0.475405 -0.020684 -0.999786 ... 0.004217 0.999991 0.003162 0.999995 0.002371 0.999997 0.001778 0.999998 0.001334 0.999999
11 -0.999990 0.004426 0.923052 -0.384674 -0.097276 0.995257 -0.997283 -0.073661 -0.330575 -0.943780 ... 0.004639 0.999989 0.003478 0.999994 0.002609 0.999997 0.001956 0.999998 0.001467 0.999999
12 -0.536573 0.843854 0.413275 -0.910606 0.448343 0.893862 -0.940067 0.340989 -0.607683 -0.794179 ... 0.005060 0.999987 0.003795 0.999993 0.002846 0.999996 0.002134 0.999998 0.001600 0.999999
13 0.420167 0.907447 -0.318216 -0.948018 0.855881 0.517173 -0.718144 0.695895 -0.824528 -0.565821 ... 0.005482 0.999985 0.004111 0.999992 0.003083 0.999995 0.002312 0.999997 0.001734 0.999998
14 0.990607 0.136737 -0.878990 -0.476839 0.999823 -0.018796 -0.370395 0.928874 -0.959605 -0.281349 ... 0.005904 0.999983 0.004427 0.999990 0.003320 0.999995 0.002490 0.999997 0.001867 0.999998
15 0.650288 -0.759688 -0.968206 0.250154 0.835838 -0.548975 0.042249 0.999107 -0.999519 0.031022 ... 0.006325 0.999980 0.004743 0.999989 0.003557 0.999994 0.002667 0.999996 0.002000 0.999998
[16 rows x 64 columns]
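The table above can be reproduced directly from the two formulas. Below is a minimal sketch, assuming NumPy; the function name `sinusoidal_positional_encoding` is my own choice for illustration. Even columns receive the sine term and odd columns the cosine term, with each sine/cosine pair sharing one frequency.

```python
import numpy as np

def sinusoidal_positional_encoding(context_length, d_model):
    """Build the (context_length, d_model) table from the sine/cosine formulas."""
    positions = np.arange(context_length)[:, None]      # pos, shape (context_length, 1)
    even_dims = np.arange(0, d_model, 2)[None, :]        # 2i = 0, 2, 4, ..., d_model - 2
    angles = positions / np.power(10000.0, even_dims / d_model)

    pe = np.zeros((context_length, d_model))
    pe[:, 0::2] = np.sin(angles)   # PE(pos, 2i)
    pe[:, 1::2] = np.cos(angles)   # PE(pos, 2i+1)
    return pe

pe = sinusoidal_positional_encoding(context_length=16, d_model=64)

print(pe.shape)                  # (16, 64)
print(np.round(pe[1, :4], 6))    # first four values at position 1
```

Position 0 comes out as alternating 0 and 1 (sin 0 and cos 0), and position 1 starts with sin(1) ≈ 0.841471 and cos(1) ≈ 0.540302, matching the first two rows of the table.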

Let's talk a bit more about the positional encoding tricks.

As far as I understand, the positional values depend only on where a character sits within a sequence. Because every input sentence has the same context length, we can reuse the identical positional encoding across all inputs. The values also have to be chosen carefully: they must stay small enough not to overwhelm the input embeddings, while adjacent positions differ only slightly and distant positions differ more.

By combining sine and cosine vectors, the model sees a positional encoding vector that is independent of the word embedding and does not get confused with the input embedding (semantic) information. It's hard to imagine how this works inside the neural network, but it works.

We can visualize our position embedding numbers and look at the patterns.

Pe 64dim

Each column is one of our dimensions, from 0 to 63; each row represents a character position. The values are between -1 and 1 since they come from sine and cosine functions. Darker colors indicate values closer to -1, and brighter colors are closer to 1. Green colors indicate values in between.

Let's go back to our positional encoding matrix. As you can see, this positional encoding table has the same shape as each batch entry in the input embedding table [4, 16, 64]: both are (context_length, d_model) = [16, 64].

Since two matrices with the same shape can be summed together, we can add the positional information to each of the four input embedding matrices to get the Final Input Embedding matrix.

batch 0:
0 1 2 3 4 5 6 7 8 9 ... 54 55 56 57 58 59 60 61 62 63
0 1.051807 0.295631 -0.913199 -0.151564 0.582201 0.101418 0.984299 0.924740 -0.004821 0.256358 ... 1.151378 1.119595 0.601200 0.059648 0.289960 1.579749 0.428623 1.263096 -0.773865 0.265780
1 0.547512 -0.738548 0.630830 1.594323 0.733316 -0.886616 0.783385 -0.216111 0.592187 -0.122698 ... -0.061995 0.559401 0.800599 1.783043 1.602587 0.323941 -0.246353 2.005651 -1.018534 1.604092
2 0.617101 -0.306899 0.865904 -0.629588 1.228581 -1.454339 0.596070 1.013263 -0.186154 1.793348 ... 0.383324 2.315575 -0.143404 2.280102 1.113303 1.438884 -0.275467 -1.226698 0.109251 1.701881
3 0.569062 -0.111243 0.601322 -0.079154 1.219661 -0.186289 -0.911600 1.774332 1.845533 1.278927 ... 1.271452 2.028822 -0.871380 0.852612 -0.082575 1.142617 -0.375369 0.898113 0.989920 0.937440
4 -1.821736 -0.785214 0.655805 -1.748969 1.072516 0.329445 1.969725 -1.593312 -0.423386 -0.870206 ... 0.232799 2.278685 0.255953 1.516287 0.622701 1.219178 1.346175 0.072133 0.510705 1.656851
5 1.555663 -0.717588 -0.179829 -1.666574 0.370867 -0.982811 2.255347 0.422208 0.123719 -0.034782 ... 0.091912 1.646094 -0.205354 1.187103 -1.287054 -0.068144 0.697607 0.626403 -0.333828 0.537782
6 0.219007 0.610934 -2.039364 -0.304516 1.144289 -1.485164 -0.664902 -2.161820 -0.664487 1.750649 ... 0.028036 1.638068 0.105957 0.399056 -0.366373 0.527810 0.845001 1.706170 -1.675722 0.733621
7 2.341013 0.102489 -1.627363 1.110608 -1.095316 0.228369 2.377153 0.597940 0.677737 -1.625878 ... -0.310720 2.276958 -1.139895 0.854859 1.209583 0.941441 -0.351562 2.506867 -2.295708 2.378678
8 0.948148 -0.980033 -1.523850 0.284180 -2.753848 -0.173272 -2.942995 1.450153 -1.137498 -0.197246 ... -1.060385 2.525683 -1.759494 1.161095 0.028703 1.462346 0.734397 1.479749 0.943511 -0.050575
9 1.120872 0.147380 0.746753 1.103982 -0.479273 1.357801 1.946789 -0.539822 1.227215 -1.207067 ... -0.549040 1.084117 0.440194 1.596224 0.514303 1.289719 -0.026721 0.067324 -0.410035 2.035753
10 -1.128574 0.556604 1.664986 0.988980 0.080544 -1.323841 -1.665967 -0.803163 1.258105 -1.155904 ... 1.208804 0.868336 -0.592132 0.566557 -0.861313 4.272244 0.103369 1.619057 -0.980840 -0.174126
11 -1.753818 0.102441 -0.022270 0.323699 -1.591020 1.389990 -0.921654 -0.123053 -1.336139 -0.587427 ... 2.457530 0.766419 0.402266 -0.597278 -1.916476 0.594436 -0.264688 2.237020 1.080961 -1.292415
12 -1.148437 0.850664 2.402985 -1.356776 -0.221765 0.939481 -1.032902 1.567763 -2.015232 -0.890874 ... 1.186370 0.592825 -0.082546 0.469365 0.045767 2.369474 0.826133 0.687041 0.593355 1.516313
13 -0.164386 2.303123 0.409138 -0.306666 1.549362 -1.596800 -1.504343 0.368137 0.454260 -0.721938 ... 1.210069 0.868330 -0.591184 0.566554 -0.860601 4.272243 0.103903 1.619056 -0.980440 -0.174127
14 -0.183482 0.232812 -1.628186 -0.080981 0.377364 -1.309922 -0.275964 1.609030 -1.440347 0.427969 ... 0.792566 1.237715 1.518224 1.296686 0.072853 0.763276 1.100520 0.557057 -0.581310 2.151496
15 1.052028 -1.289275 2.048469 -0.884570 0.579293 -0.768871 0.680185 2.999618 -1.418203 -0.211697 ... -0.435962 -0.519414 -1.002752 0.482508 0.311006 0.683955 -0.877969 -0.424683 -1.899643 2.968462
[16 rows x 64 columns]

batch 1:
0 1 2 3 4 5 6 7 8 9 ... 54 55 56 57 58 59 60 61 62 63
0 -0.264236 0.965681 1.909974 -0.338721 -0.554196 0.254583 -0.576111 1.766522 -0.652587 0.455450 ... -1.016426 0.458762 -0.513290 0.618411 0.877229 2.526591 0.614551 0.662366 -1.246907 1.128066
1 1.732205 -0.858178 0.324008 1.022650 -1.172865 0.513133 -0.121611 2.630085 0.072425 2.332296 ... 0.737660 1.988225 2.544661 1.995471 0.447863 3.174428 0.444989 0.860426 2.137797 1.537580
2 -1.348308 -1.080221 1.753394 0.156193 0.440652 1.015287 -0.790644 1.215537 2.037030 0.476560 ... 0.296941 1.100837 -0.153194 1.329375 -0.188958 1.229344 -1.301919 0.938138 -0.860689 -0.860137
3 0.601103 -0.156419 0.850114 -0.324190 -0.311584 -2.232454 -0.903112 0.242687 0.801908 2.502464 ... -0.397007 1.150545 -0.473907 0.318961 -1.970126 1.967961 -0.186831 0.131873 0.947445 -0.281573
4 -1.821736 -0.785214 0.655805 -1.748969 1.072516 0.329445 1.969725 -1.593312 -0.423386 -0.870206 ... 0.232799 2.278685 0.255953 1.516287 0.622701 1.219178 1.346175 0.072133 0.510705 1.656851
5 1.555663 -0.717588 -0.179829 -1.666574 0.370867 -0.982811 2.255347 0.422208 0.123719 -0.034782 ... 0.091912 1.646094 -0.205354 1.187103 -1.287054 -0.068144 0.697607 0.626403 -0.333828 0.537782
6 0.599841 0.943214 -1.397184 -0.607349 -0.333995 -1.222589 -0.731189 -0.997706 1.848611 0.254238 ... 0.340986 1.383113 1.674592 2.229903 -0.157415 0.362868 -0.493762 1.904136 0.027903 1.196017
7 0.072234 1.386670 -0.985962 -1.184486 0.958293 -0.295773 -1.529277 -0.727844 1.510503 1.268154 ... -0.356459 0.382331 0.138104 -0.360916 -0.638448 1.305404 -0.756442 0.299150 0.154600 -0.466154
8 -0.008645 -1.066763 -0.716555 2.148885 -0.709739 -0.137266 0.385401 0.699139 1.907906 -2.357567 ... 0.490190 -1.215412 1.216459 0.659227 -0.282908 -0.912266 0.595569 1.210701 0.737407 0.801672
9 -0.006332 -0.949928 0.192689 3.158421 -1.292153 -0.830248 0.966141 -2.056514 0.042364 1.485927 ... 0.480763 -0.318554 0.005837 3.031636 -0.448117 1.059403 0.598106 0.871427 0.327321 1.090921
10 -1.152681 -0.710162 -0.456591 -0.468090 -0.292566 0.747535 -0.149907 -0.395523 0.170872 -2.372754 ... -1.267461 0.043283 -0.114980 1.083042 -0.288776 1.442318 0.775591 0.728716 -0.576776 -0.727257
11 -0.955986 -0.277475 0.946888 -0.242687 1.257744 0.369994 0.460073 0.728078 -0.165204 -0.761762 ... -0.307983 2.078995 -1.067792 1.805637 0.608968 1.722982 -0.371174 -0.603182 0.285387 1.112932
12 -0.844347 0.883224 1.222388 -0.811387 -0.593557 0.157268 -0.650315 1.289236 -1.472027 -0.447092 ... -0.536433 2.465097 -0.822905 1.272786 0.703664 2.687270 -0.924388 0.596134 -0.367138 0.812242
13 0.776470 1.549248 -0.239693 0.133783 0.767255 1.996130 -0.436228 -0.327975 -0.650743 0.507769 ... -0.821793 1.387792 -1.052105 2.123603 1.421092 2.066746 -0.747766 0.627081 -1.749071 -0.679443
14 1.277579 0.653945 0.045632 -0.409790 0.829708 0.249433 -0.682051 0.601958 -1.932014 -2.077397 ... 0.160611 1.037856 0.656832 0.992817 -0.684056 1.031199 -0.180866 4.579140 -1.123555 0.181580
15 0.356328 -2.038538 -1.018938 1.112716 1.035987 -2.281600 0.416325 -0.129400 -0.718316 -1.042091 ... -0.056092 0.559381 0.805026 1.783032 1.605907 0.323934 -0.243863 2.005648 -1.016667 1.604090
[16 rows x 64 columns]

batch 2:
0 1 2 3 4 5 6 7 8 9 ... 54 55 56 57 58 59 60 61 62 63
0 0.645854 1.291073 -1.588931 1.814376 -0.185270 0.846816 -1.686862 0.982995 -0.973108 1.297203 ... 0.852600 1.533231 0.692729 2.437029 -0.178137 0.493413 0.597484 1.909155 1.257821 2.644325
1 1.732205 -0.858178 0.324008 1.022650 -1.172865 0.513133 -0.121611 2.630085 0.072425 2.332296 ... 0.737660 1.988225 2.544661 1.995471 0.447863 3.174428 0.444989 0.860426 2.137797 1.537580
2 3.298391 -0.363908 0.376535 -0.276692 1.262433 -0.595659 1.694541 0.542514 -0.464756 0.368460 ... -0.169474 1.420809 0.304488 1.689731 -1.128037 -0.024476 -1.356808 2.160992 -2.110703 -0.472404
3 0.626955 -2.988524 0.915578 1.123503 0.635983 0.078006 0.466728 -0.930765 2.189286 1.505499 ... 2.496649 1.691578 0.642664 2.089205 1.926187 1.185045 -0.969952 0.666007 -0.030641 0.667574
4 0.396447 -2.116415 0.384262 -1.632779 0.859029 -0.726599 2.121946 -1.314046 0.744388 -0.227106 ... -1.937352 2.378620 0.029220 1.215336 -0.405487 -0.834419 -1.219825 0.000676 -0.821293 0.340797
5 -2.133014 0.379737 -1.320323 -0.425003 -0.298524 -2.237205 0.953327 0.168006 0.519205 0.698976 ... 0.788771 1.237731 1.515378 1.296695 0.070718 0.763281 1.098920 0.557059 -0.582510 2.151497
6 -0.390918 0.634039 -1.350461 0.032129 0.106428 0.370410 1.292387 0.986316 -0.095396 0.555067 ... -1.792372 -0.357599 0.912276 0.088746 0.866950 0.927208 -0.381643 2.532119 0.464615 -1.044299
7 -0.407947 0.622332 -0.345048 -0.247587 -0.419677 0.256695 1.165026 -2.459640 -0.576545 -1.770781 ... 0.234064 2.278682 0.256901 1.516285 0.623413 1.219177 1.346708 0.072133 0.511105 1.656851
8 3.503946 -1.146751 0.111070 0.114221 -0.930330 -0.248769 1.166547 -0.038856 -0.301910 -0.843072 ... 0.093177 1.646091 -0.204405 1.187101 -1.286342 -0.068145 0.698141 0.626402 -0.333428 0.537781
9 -1.946920 -0.443788 0.560103 3.584257 -0.134643 -1.538940 -1.059084 -0.128679 2.503847 -2.244587 ... -0.643552 1.608934 -0.488734 -0.291253 1.633294 -0.018763 0.696360 -0.657761 0.692395 1.741288
10 0.376520 0.583786 -0.705047 0.855548 0.471473 0.687240 -0.605646 0.463047 1.619052 -1.894214 ... -0.688652 1.974150 -1.399412 2.567682 -0.050040 1.782055 -0.297912 2.366196 -1.888527 0.635260
11 -0.109256 -1.394054 0.565499 -0.093785 -1.803309 0.662382 -1.528203 1.644028 -0.569133 0.438101 ... 0.741877 1.988214 2.547823 1.995465 0.450234 3.174424 0.446768 0.860424 2.139130 1.537579
12 -1.553993 -0.983421 0.392842 -1.473186 1.530387 1.894017 -0.732786 -1.601045 -0.740344 0.245303 ... -0.328828 3.013883 1.178296 1.263333 0.284824 0.791874 2.402131 -0.231270 -1.025411 0.178748
13 -0.757965 1.771306 0.805440 -0.509121 1.212250 0.388750 -0.606959 2.352489 -2.445346 -0.103223 ... 0.425556 1.783019 0.698336 1.871530 2.314023 0.424368 -1.002745 0.983784 -0.090133 0.905337
14 -0.183482 0.232812 -1.628186 -0.080981 0.377364 -1.309922 -0.275964 1.609030 -1.440347 0.427969 ... 0.792566 1.237715 1.518224 1.296686 0.072853 0.763276 1.100520 0.557057 -0.581310 2.151496
15 -0.151101 -0.257150 -0.478131 -1.170082 1.318685 -0.188166 0.146375 2.895475 -0.918949 -0.305261 ... 1.623350 1.656103 -0.600456 1.039260 -1.944202 0.894911 1.409396 1.722673 -0.172070 2.265543
[16 rows x 64 columns]

batch 3:
0 1 2 3 4 5 6 7 8 9 ... 54 55 56 57 58 59 60 61 62 63
0 0.377847 -0.380613 1.958640 0.224087 -0.420293 0.915635 -1.077748 1.255988 -0.223147 0.977568 ... -1.290532 1.460963 1.365088 -2.037483 -2.213841 1.039091 -2.129649 0.108403 -0.356996 2.239356
1 0.527961 0.342787 0.096746 0.885016 0.706699 2.873656 0.139732 0.497379 -0.009022 -0.147825 ... -0.409913 0.785146 -0.138166 2.041000 0.277500 1.578947 -1.535113 0.912230 -0.312735 0.540365
2 1.054965 -0.134411 2.155045 -0.188724 0.651576 -0.265663 -0.777263 0.571080 1.508661 1.021718 ... 0.762458 2.297400 -0.624743 -0.979212 2.024008 1.295633 0.208825 0.953138 -2.962624 1.586901
3 -1.032970 -0.893918 0.029077 -0.232068 0.370793 -1.407092 1.048066 0.981123 0.331907 1.292072 ... 0.787928 1.237732 1.514746 1.296695 0.070244 0.763281 1.098564 0.557060 -0.582777 2.151497
4 -0.980037 -1.014605 1.875135 -2.459635 0.486067 -0.941092 1.205490 1.248531 1.801383 0.576983 ... 0.192097 1.784109 -0.201023 0.405095 0.982041 1.927637 0.008535 1.063376 -1.439787 2.967185
5 -0.369996 -1.151058 -0.126222 0.768431 0.107524 -0.481010 2.056029 -0.872815 1.522675 -0.440916 ... 0.246007 -1.032684 0.572565 0.944744 0.790383 -0.034063 -1.704374 -0.053319 1.739537 2.381506
6 -0.555136 -0.284736 -0.162689 -1.542923 -1.619371 -2.014224 0.957231 -0.338164 1.353500 -2.048436 ... 0.180549 -0.598603 0.427175 1.845072 0.924364 -0.013093 -0.054108 -0.082885 -0.719218 0.960552
7 0.548834 1.130444 1.207497 0.565839 -1.814344 -0.111523 0.480270 -1.741823 1.451116 -0.977640 ... 1.692325 -0.708754 -0.747591 1.373189 -0.224415 -0.074035 -0.323435 2.001849 -1.102584 1.644658
8 0.117209 -0.905490 0.272336 0.994848 0.648951 0.354459 -0.731171 -1.641071 -0.966286 -0.837498 ... 0.294006 1.008774 1.376944 2.969555 0.997452 2.076708 0.631358 1.080600 0.075384 1.819302
9 0.557786 -0.629395 1.606758 0.633762 -1.190379 -0.355466 -2.132275 -0.887707 1.208793 -0.741505 ... 0.765410 2.297393 -0.622529 -0.979216 2.025668 1.295631 0.210070 0.953136 -2.961691 1.586900
10 1.107697 -2.050459 1.399869 1.271179 -1.391529 1.103020 -0.910370 -0.398901 -0.803458 -2.081302 ... 1.462017 -0.115730 0.171052 0.594118 0.514388 1.593223 0.064085 -0.029184 -0.044621 1.206415
11 -1.771933 0.469475 0.961730 0.002798 1.386089 0.250342 -0.062900 -0.569053 -2.149857 -0.519952 ... -0.725692 -0.727693 -0.178683 1.675822 -0.401712 1.109331 0.980627 -0.357667 -0.484853 0.208340
12 -1.518213 1.899549 -0.320427 -0.929415 -0.701020 0.727833 -2.764498 0.612756 0.041370 -1.599998 ... -0.136314 1.068995 0.635501 0.765369 0.270007 0.319588 -0.652992 1.322658 1.724227 2.343042
13 0.094923 0.575470 -0.852224 -2.098593 0.998579 0.347285 -0.467688 0.773722 -1.664829 -0.412623 ... -1.274262 0.454381 -1.142107 1.853844 -1.912537 0.544311 0.667555 -1.187468 1.291108 2.275956
14 -0.183482 0.232812 -1.628186 -0.080981 0.377364 -1.309922 -0.275964 1.609030 -1.440347 0.427969 ... 0.792566 1.237715 1.518224 1.296686 0.072853 0.763276 1.100520 0.557057 -0.581310 2.151496
15 2.053710 -2.769740 -0.148796 0.983717 -0.038190 -0.655360 1.826909 -0.332533 -1.036128 -1.001430 ... 0.674310 0.695848 -0.181635 1.051397 -0.884897 1.590696 -1.375117 0.596254 -0.651398 0.797715
[16 rows x 64 columns]

These are the final input embeddings that will be fed to the Transformer decoder blocks for training.

This final result matrix is called the Position-wise Input Embedding, and it has a shape of (batch_size, context_length, d_model) = [4, 16, 64].
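As a rough sketch of how a tensor of this shape could be produced, here is the sinusoidal encoding from the original paper added to a stand-in embedding. The random `token_embeddings` and all variable names are assumptions for illustration, not the exact code behind the numbers printed above.

```python
import math
import torch

batch_size, context_length, d_model = 4, 16, 64

# Stand-in for the token embeddings from the previous step (random here).
token_embeddings = torch.randn(batch_size, context_length, d_model)

# Sinusoidal positional encoding: even columns use sine, odd columns use cosine.
position = torch.arange(context_length, dtype=torch.float32).unsqueeze(1)  # [16, 1]
div_term = torch.exp(
    torch.arange(0, d_model, 2, dtype=torch.float32) * (-math.log(10000.0) / d_model)
)  # [32]
positional_encoding = torch.zeros(context_length, d_model)
positional_encoding[:, 0::2] = torch.sin(position * div_term)
positional_encoding[:, 1::2] = torch.cos(position * div_term)

# Broadcasting adds the same [16, 64] encoding to each of the 4 batches.
x = token_embeddings + positional_encoding
print(x.shape)  # torch.Size([4, 16, 64])
```

The encoding depends only on the position, so every batch receives exactly the same [16, 64] matrix on top of its own token embeddings.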

That is all there is to Positional Encoding.

But why use sine and cosine functions to encode position? Why not just a random number? Why can summing the two numbers capture both a token's meaning and its position? Well, I wondered exactly the same thing at the beginning. However, I discovered that it isn't necessary to fully comprehend the underlying mathematics in order to train a model. Therefore, if you are eager for detailed explanations, please refer to the separate section or watch my video clips. I will leave it at that and proceed to the next step.


So far we've covered the input embedding and positional encoding parts of the model. Let's move on to the Transformer Block.

Step 4 : Transformer Block

Transformer block

A Transformer block is a stack of three kinds of layers: a masked multi-head attention mechanism, two normalization layers, and a feed-forward network.

The masked multi-head attention is a set of self-attentions, each of which is called a head. So let's first look at the self-attention mechanism.

4.1 Multi-Head Attention Overview

Transformers get their strength from something called self-attention. With self-attention, the model pays close attention to the most relevant parts of the input. Each self-attention unit is called a head.

This is how a head works: it passes the input through three separate linear layers to produce queries (Q), keys (K), and values (V). It compares Q with K, scales the results, and turns them into a set of scores showing which tokens matter to each other. These scores are then used to weigh the information in V, giving more attention to the important parts. The head learns by adjusting the weights of these Q, K, and V layers over time.

Multi-head attention simply consists of several individual heads stacked together. All heads receive the exact same input, although they employ their own specific sets of weights during computation. After processing the input, the outputs from all the heads are concatenated and then passed through a linear layer.

The following figure provides a visual representation of the processes within a head, and the details of the Multi-head Attention block.

Attention block step

In order to perform the attention calculation, let's bring in the formula from the original paper "Attention is all you need":

Attention(Q, K, V) = softmax(Q * K^T / sqrt(d_k)) * V

From the formula, we first require three matrices: Q (queries), K (keys), and V (values). To compute the attention scores, we need to perform the following steps:

  1. multiply Q by K transpose (denoted as K^T)

  2. divide by the square root of the dimension of K

  3. apply softmax function

  4. multiply by V

We'll go through them one by one.
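Before going through them, the four steps can be sketched as one small function. This is only a minimal single-head version without masking, written under my own assumptions: the function name `attention` and the random inputs are made up for illustration.

```python
import math
import torch

def attention(q, k, v):
    """Scaled dot-product attention without masking (steps 1 to 4)."""
    d_k = k.shape[-1]
    scores = q @ k.transpose(-2, -1)         # 1. multiply Q by K^T
    scores = scores / math.sqrt(d_k)         # 2. divide by sqrt of the dimension of K
    weights = torch.softmax(scores, dim=-1)  # 3. apply softmax
    return weights @ v                       # 4. multiply by V

# 16 tokens, each represented by 16 numbers (random stand-ins).
q = torch.randn(16, 16)
k = torch.randn(16, 16)
v = torch.randn(16, 16)

out = attention(q, k, v)
print(out.shape)  # torch.Size([16, 16])
```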

4.2 Prepare Q,K,V

The first step in calculating attention is to obtain the Q, K, and V matrices, representing query, key, and value. These three matrices will be used in our attention layer to calculate the attention probabilities (weights). They are obtained by passing our Position-wise Input Embedding matrix from the previous step, denoted as X, through three distinct linear layers labeled Wq, Wk, and Wv (all values are randomly initialized and learnable). The output of each linear layer is then split into a number of heads, denoted as num_heads; here we choose 4 heads.

Wq, Wk, and Wv are three matrices with dimensions (d_model, d_model) = [64, 64]. All values are randomly initialized. In a neural network, such a matrix is called a linear layer, and its values are trainable parameters: the values that the model learns and updates by itself during training.

To obtain our Q,K,V values, we do matrix multiplication between our input embedding matrix X and each of the three matrices Wq, Wk, Wv (once again, their initial values are randomly assigned).

  • Q = X*Wq

  • K = X*Wk

  • V = X*Wv

The logic of the matrix multiplication above is as follows:

X has shape (batch_size, context_length, d_model) = [4, 16, 64], so we can treat it as 4 sub-matrices of shape [16, 64]. Wq, Wk, and Wv have shape (d_model, d_model) = [64, 64]. We can then multiply each of the 4 sub-matrices of X by Wq, Wk, and Wv.

If you recall linear algebra, the multiplication of two matrices is only possible when the number of columns in the first matrix is equal to the number of rows in the second matrix. In our case, the number of columns in X is 64, and the number of rows in Wq, Wk, and Wv is also 64. Therefore, the multiplication is possible.

The matrix multiplication results in 4 sub-matrices of shape [16, 64], which can be expressed together as (batch_size, context_length, d_model) = [4, 16, 64].
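To make the shapes concrete, here is a small sketch with random stand-ins for X, Wq, Wk, and Wv (the real values would come from the previous step and from training). It also checks that multiplying the whole batch at once gives the same result as handling the 4 sub-matrices one at a time.

```python
import torch

torch.manual_seed(0)
batch_size, context_length, d_model = 4, 16, 64

X = torch.randn(batch_size, context_length, d_model)  # stand-in input embeddings
Wq = torch.randn(d_model, d_model)                    # randomly initialized weights
Wk = torch.randn(d_model, d_model)
Wv = torch.randn(d_model, d_model)

# The @ operator multiplies the last two dimensions and repeats over the batch.
Q = X @ Wq
K = X @ Wk
V = X @ Wv
print(Q.shape, K.shape, V.shape)  # each is torch.Size([4, 16, 64])

# The same thing done one [16, 64] sub-matrix at a time.
Q_by_hand = torch.stack([X[b] @ Wq for b in range(batch_size)])
print(torch.allclose(Q, Q_by_hand, atol=1e-3))
```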

Now we have our Q, K, V matrices in shape (batch_size, context_length, d_model) = [4, 16, 64]. Next we need to split them into multiple heads. This is why the transformer architecture calls it Multi-Head Attention.

Splitting heads simply means that we cut the 64 dimensions of our d_model into multiple heads, each containing a certain number of dimensions. Each head will be able to learn certain patterns or semantics of the input.

Let's say we set our num_heads to be 4 as well. This means we will split our Q, K, V matrices of shape [4, 16, 64] into multiple sub-matrices.

The actual splitting is done by reshaping the last dimension of 64 into 4 sub-dimensions of 16.

Each of the Q, K, V matrices is transformed from shape [4, 16, 64] to [4, 16, 4, 16]. The last two dimensions are the number of heads and the size of each head. In other words, the shape is transformed from:

[batch_size, context_length, d_model]

to:

[batch_size, context_length, num_heads, head_size]


To understand the Q, K, and V matrices with identical shapes [4, 16, 4, 16], consider the following perspective:

In the pipeline, there are four batches. Each batch consists of 16 tokens (words). For every token, there are 4 heads, and each head encodes 16 dimensions of semantic information.
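A minimal sketch of the split, using a random stand-in for Q (K and V are handled the same way). It shows that reshaping does not change any numbers: head 0 of a token is simply the first 16 of its 64 values, head 1 the next 16, and so on.

```python
import torch

batch_size, context_length, d_model, num_heads = 4, 16, 64, 4
head_size = d_model // num_heads  # 64 / 4 = 16

Q = torch.randn(batch_size, context_length, d_model)  # stand-in for the real Q

# Cut the last dimension of 64 into 4 heads of 16 dimensions each.
Q_heads = Q.reshape(batch_size, context_length, num_heads, head_size)
print(Q_heads.shape)  # torch.Size([4, 16, 4, 16])

# First token of the first batch: compare each head with a slice of the original.
print(torch.equal(Q_heads[0, 0, 0], Q[0, 0, 0:16]))   # head 0
print(torch.equal(Q_heads[0, 0, 1], Q[0, 0, 16:32]))  # head 1
```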

4.3 Calculate Q,K Attention

Attention formula 0

Now that we have all three matrices of Q, K and V, let’s start calculating single-head attention step by step.

According to the Transformer diagram, the Q and K matrices are multiplied first.

Now, if we set the batch_size dimension aside and keep only the last three dimensions, Q, K, and V each have shape [context_length, num_heads, head_size] = [16, 4, 16].

We need to transpose the first two dimensions, so that Q, K, and V each have shape [num_heads, context_length, head_size] = [4, 16, 16]. This is because the matrix multiplication must operate on the last two dimensions.

Q * K^T = [4, 16, 16] * [4, 16, 16] = [4, 16, 16], that is, [num_heads, context_length, head_size] * [num_heads, head_size, context_length] = [num_heads, context_length, context_length].

Why do we do this? The transposition here is done so that the matrix multiplication compares every token in the context with every other token. It's more straightforward to explain with a diagram. The last two dimensions, represented as [16, 16], can be visualized as follows:

Self attn diag

In this matrix, each row and each column represents a token (word) in the context of our example sentence. The matrix multiplication measures the similarity between each word and every other word in the context. The higher the value, the more similar they are.

Let me bring up one head of the attention score:

[ 0.2712, 0.5608, -0.4975, ..., -0.4172, -0.2944, 0.1899],
[-0.0456, 0.3352, -0.2611, ..., 0.0419, 1.0149, 0.2020],
[-0.0627, 0.1498, -0.3736, ..., -0.3537, 0.6299, 0.3374],
...,
[-0.4166, -0.3364, -0.0458, ..., -0.2498, -0.1401, -0.0726],
[ 0.4109, 1.3533, -0.9120, ..., 0.7061, -0.0945, 0.2296],
[-0.0602, 0.2428, -0.3014, ..., -0.0209, -0.6606, -0.3170]
[16 rows x 16 columns]

The numbers in this 16 by 16 matrix represent attention scores of our example sentence " . By mastering the art of identifying underlying motivations and desires, we equip ourselves with ".

It is easier to see as a plot:

Qk plot 1

The horizontal axis represents one of Q's heads, the vertical axis represents one of K's heads, and the colored squares represent the similarity score between each token and every other token in the context. The darker the color, the higher the similarity.

The similarities shown above of course don't make much sense now, because those are just from randomly assigned values. But after training, the similarity scores will be meaningful.

OK, now let's bring the batch dimension, batch_size, back into the Q*K attention scores. The final result will have the shape [batch_size, num_heads, context_length, context_length], which is [4, 4, 16, 16].

This is the Q*K Attention Score at current step.
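Here is a small sketch of this step with the batch dimension included, using random stand-ins for Q and K. The last line checks that one entry of the score matrix is indeed the dot product between one token's query and another token's key.

```python
import torch

batch_size, context_length, num_heads, head_size = 4, 16, 4, 16

# Q and K after splitting heads: [batch_size, context_length, num_heads, head_size]
Q = torch.randn(batch_size, context_length, num_heads, head_size)
K = torch.randn(batch_size, context_length, num_heads, head_size)

# Swap context_length and num_heads: [batch_size, num_heads, context_length, head_size]
Q = Q.transpose(1, 2)
K = K.transpose(1, 2)

# Multiply Q by K^T on the last two dimensions.
attention_score = Q @ K.transpose(-2, -1)
print(attention_score.shape)  # torch.Size([4, 4, 16, 16])

# Row 2, column 5 is the query of token 2 compared with the key of token 5.
single = torch.dot(Q[0, 0, 2], K[0, 0, 5])
print(torch.allclose(attention_score[0, 0, 2, 5], single, atol=1e-4))
```

In our example context_length and head_size both happen to be 16, so the shapes look the same, but the last two dimensions of the score are tokens by tokens, not tokens by head dimensions.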

4.4 Scale

Attention formula 1

The scale part is quite straightforward: we just need to divide the Q*K^T attention score by the square root of the dimension of K.

Here, our dimension of K equals the dimension of Q, which is d_model divided by num_heads: 64/4 = 16.

Then we take the square root of 16, which is 4, and divide the Q*K^T attention score by 4.

The reason for doing this is to prevent the Q*K^T attention score from becoming too large, which could cause the softmax function to saturate, which in turn could cause the gradient to vanish.
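To get a feeling for why the square root of 16 is a sensible divisor, here is a small experiment with random vectors (an illustration under my own assumptions, not part of the model code). The dot product of two random 16-dimensional vectors has a spread of about 4, and dividing by 4 brings it back to about 1.

```python
import math
import torch

torch.manual_seed(0)
head_size = 16

# Many random query and key vectors, each with 16 dimensions.
q = torch.randn(1000, head_size)
k = torch.randn(1000, head_size)

scores = q @ k.T                        # raw dot products
scaled = scores / math.sqrt(head_size)  # divide by sqrt(16) = 4

print(scores.std().item())  # roughly 4
print(scaled.std().item())  # roughly 1
```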

4.5 Mask

In the decoder-only transformer model, masked self-attention hides the future part of the sequence from each token.

The decoder can only view the previous tokens, not the future ones. So the future tokens are masked out and are not used when calculating the attention weights.

This is very easy to understand if we visualize the plot again:

Qk plot 2

The white squares are the masked-out positions, whose scores are 0.

The point of masking in the multi-headed attention layer is to prevent the decoder from “seeing into the future”. In our example sentence, the decoder is only allowed to see the current word and all the words that came before it.
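One common way to implement such a mask is sketched below. It assumes the masked positions are filled with negative infinity before softmax, so that their weights come out as exactly 0 afterwards; the random scores are stand-ins for one head of one batch.

```python
import torch

context_length = 16
attention_score = torch.randn(context_length, context_length)  # stand-in scores

# Lower-triangular matrix of True values: token i may look at tokens 0..i only.
mask = torch.tril(torch.ones(context_length, context_length)).bool()

# Every future position (above the diagonal) is replaced by -inf.
masked_score = attention_score.masked_fill(~mask, float("-inf"))

print(masked_score[:3, :4])  # the top-left corner shows the triangle shape
```

The same [16, 16] mask can be applied to the full [4, 4, 16, 16] score tensor, because PyTorch broadcasts it over the batch and head dimensions.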

4.6 Softmax

Attention formula 2

The softmax step turns a list of numbers into a list of positive values that add up to one. It amplifies the gaps between them, so high scores get most of the weight and low scores get very little, creating clear choices.

In a nutshell, the softmax function is used to convert the attention scores into a probability distribution.

In modern deep learning frameworks like PyTorch, softmax is a built-in function which is quite straightforward to use:

torch.softmax(attention_score, dim=-1)

This line of code applies softmax to all the attention scores we calculated in the previous step, and turns each row into a probability distribution with values between 0 and 1.

Let's also bring up the same head's attention scores after applying softmax:

[1.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
[0.4059, 0.5941, 0.0000, ..., 0.0000, 0.0000, 0.0000],
[0.3368, 0.4165, 0.2468, ..., 0.0000, 0.0000, 0.0000],
...,
[0.0463, 0.0501, 0.0670, ..., 0.0547, 0.0000, 0.0000],
[0.0769, 0.1974, 0.0205, ..., 0.1034, 0.0464, 0.0000],
[0.0684, 0.0926, 0.0537, ..., 0.0711, 0.0375, 0.0529]

All the probability scores are now non-negative, and each row adds up to 1.
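To see what softmax does under the hood, here is a tiny example with made-up scores for 3 tokens, where the masked positions are assumed to hold negative infinity. It compares the built-in function with the definition written out by hand: exponentiate every score, then divide by the row total.

```python
import torch

# Made-up masked scores for a context of 3 tokens.
scores = torch.tensor([
    [0.5, float("-inf"), float("-inf")],
    [0.2, 0.6, float("-inf")],
    [1.0, 2.0, 3.0],
])

weights = torch.softmax(scores, dim=-1)

# The same thing by hand: exp(-inf) is 0, so masked positions get weight 0.
exp_scores = torch.exp(scores)
by_hand = exp_scores / exp_scores.sum(dim=-1, keepdim=True)

print(weights)
print(weights.sum(dim=-1))  # every row adds up to 1
```

The first row has only one visible token, so all of its weight goes to that token, which matches the [1.0000, 0.0000, ...] first row in the output above.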

4.7 Calculate V Attention

Attention formula 3

The last step is to multiply the softmax output with the V matrix.

Remember, our V matrix was also split into multiple heads, in shape (batch_size, num_heads, context_length, head_size) = [4, 4, 16, 16].

The output of the previous softmax step has shape (batch_size, num_heads, context_length, context_length) = [4, 4, 16, 16].

Here we perform another matrix multiplication on the last two dimensions of both matrices.

softmax_output * V = [4, 4, 16, 16] * [4, 4, 16, 16] = [4, 4, 16, 16]

The result is in shape of [batch_size, num_heads, context_length, head_size] = [4, 4, 16, 16].

We call this result A.
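A minimal sketch of this step, with random stand-ins for the scores and for V, and assuming masked positions were filled with negative infinity before softmax. Because the first token can only see itself, its row of weights is [1, 0, 0, ...], so its output is exactly its own value vector.

```python
import torch

batch_size, num_heads, context_length, head_size = 4, 4, 16, 16

# Stand-in scores, masked and turned into weights as in the previous steps.
scores = torch.randn(batch_size, num_heads, context_length, context_length)
mask = torch.tril(torch.ones(context_length, context_length)).bool()
weights = torch.softmax(scores.masked_fill(~mask, float("-inf")), dim=-1)

V = torch.randn(batch_size, num_heads, context_length, head_size)  # stand-in V

# Each output row is a weighted average of the value vectors the token may see.
A = weights @ V
print(A.shape)  # torch.Size([4, 4, 16, 16])

print(torch.allclose(A[:, :, 0], V[:, :, 0], atol=1e-6))
```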


4.8 Concatenate and Output

The final step of our multi-head attention is to concatenate all the heads together and pass them through a linear layer.

The idea of concatenation is to combine the information from all heads together. Thus, we need to rearrange our A matrix from [batch_size, num_heads, context_length, head_size] = [4, 4, 16, 16] to [batch_size, context_length, num_heads, head_size] = [4, 16, 4, 16]. The reason is that we need to put num_heads and head_size next to each other as the last two dimensions, so that we can easily combine them (by reshaping) back to size d_model = 64.

This can be easily done through PyTorch's built-in function:

A = A.transpose(1, 2) # [4, 16, 4, 16] [batch_size, context_length, num_heads, head_size]

Next, we need to combine the last two dimensions [num_heads, head_size] = [4, 16] to [d_model] = [64].

A = A.reshape(batch_size, -1, d_model) # [4, 16, 64] [batch_size, context_length, d_model]

As you can see, after a series of calculations, our result matrix A is back to the same shape as our input embedding matrix X, which is [batch_size, context_length, d_model] = [4, 16, 64]. Since this output will be passed to the next layer as its input, keeping the input and output in the same shape is necessary.

But before passing it to the next layer, we need to perform another linear transformation on it. This is done by another matrix multiplication between the concatenated matrix A and Wo.


This Wo is randomly assigned with shape [d_model, d_model] and will be updated during training.

Output = A* Wo = [4, 16, 64] * [64, 64] = [4, 16, 64]

The output of the linear layer is the output of the multi-head attention, denoted as output.
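Putting this step together, here is a small sketch with random stand-ins for A and Wo. It also checks that the transpose plus reshape really is a concatenation: the result equals gluing the 4 heads side by side along the last dimension.

```python
import torch

batch_size, num_heads, context_length, head_size = 4, 4, 16, 16
d_model = num_heads * head_size  # 64

A = torch.randn(batch_size, num_heads, context_length, head_size)  # stand-in heads
Wo = torch.randn(d_model, d_model)                                 # random output weights

# Transpose then reshape: [4, 4, 16, 16] -> [4, 16, 4, 16] -> [4, 16, 64]
A_cat = A.transpose(1, 2).reshape(batch_size, -1, d_model)

# The same result obtained by concatenating the heads explicitly.
glued = torch.cat([A[:, h] for h in range(num_heads)], dim=-1)
print(torch.equal(A_cat, glued))

output = A_cat @ Wo
print(output.shape)  # torch.Size([4, 16, 64])
```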

Congratulations! Now we have finished the masked multi-head attention part! And let's start the rest of the transformer block. Those are quite straightforward, so I will go through them quickly.


Step 5 : Residual Connection and Layer Normalization

A residual connection, sometimes referred to as a skip connection, is a connection that lets the original input X bypass one or more layers.

This is simply a summation of the original input X and the output of the multi-head attention layer. Since they are in same shape, adding them up is straightforward.

output = output + X

After the residual connection, the process goes to layer normalization. LayerNorm is a technique that normalizes the outputs of a layer. For each token, it subtracts the mean and divides by the standard deviation of that token's d_model values. This prevents the outputs of a layer from becoming too large or too small, which can make the network unstable.

This is also a single line of code in PyTorch using the nn.LayerNorm module. We will see it in action in the Let's Code LLM section.

The residual connection and layer normalization are denoted as Add & Norm in the diagram from the original "Attention is All You Need" paper.
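As a quick illustration of Add & Norm, the sketch below uses random stand-ins for X and for the attention output, and compares nn.LayerNorm with the same calculation done by hand. A freshly created nn.LayerNorm starts with a scale of 1 and a shift of 0, which is why the two results match here.

```python
import torch
import torch.nn as nn

d_model = 64
X = torch.randn(4, 16, d_model)                 # stand-in block input
attention_output = torch.randn(4, 16, d_model)  # stand-in attention output

layer_norm = nn.LayerNorm(d_model)

# Add (residual connection), then Norm.
summed = attention_output + X
output = layer_norm(summed)

# By hand: normalize each token over its 64 values.
mean = summed.mean(dim=-1, keepdim=True)
var = summed.var(dim=-1, keepdim=True, unbiased=False)
by_hand = (summed - mean) / torch.sqrt(var + layer_norm.eps)

print(torch.allclose(output, by_hand, atol=1e-5))
print(output.shape)  # torch.Size([4, 16, 64])
```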

Step 6 : Feed-Forward Network

Once we have the normalized attention output, it is processed through a position-wise feed-forward network.

The feed-forward network (FFN) consists of two linear layers with a ReLU activation function between them. Let's see how the Python code implements this:

# Define Feed Forward Network
output = nn.Linear(d_model, d_model * 4)(output)
output = nn.ReLU()(output)
output = nn.Linear(d_model * 4, d_model)(output)

Here is ChatGPT's explanation of the above code:

  1. output = nn.Linear(d_model, d_model * 4)(output): This applies a linear transformation to the incoming data, i.e., y = xA^T + b. The input and output sizes are d_model and d_model * 4, respectively. This transformation increases the dimensionality of the input data.

  2. output = nn.ReLU()(output): This applies the Rectified Linear Unit (ReLU) function element-wise. It's used as an activation function that introduces non-linearity to the model, allowing it to learn more complex patterns.

  3. output = nn.Linear(d_model * 4, d_model)(output): This applies another linear transformation, reducing the dimensionality back to d_model. This kind of "expansion then contraction" is a common pattern in neural networks.

Someone new to machine learning or LLMs, like you and me, could get lost in these explanations. I had exactly the same feeling when I first encountered these terms.

But no worries, we can understand it like this: this feed-forward network is just a standard neural network module whose input and output are both the attention output for each token. Its purpose is to expand the dimensionality of each token's vector from 64 to 256, which makes the information more granular and enables the model to learn more complex knowledge structures. Then, it compresses the dimensions back down to 64, making it suitable for subsequent computations.
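In a real model, the two linear layers need to keep their weights between calls so they can be trained, so they are usually wrapped in a module. The sketch below is one possible way to do that; the class name `FeedForward` and the random input are my own choices for illustration.

```python
import torch
import torch.nn as nn

d_model = 64

class FeedForward(nn.Module):
    """Expand each token from d_model to 4 * d_model, then compress it back."""

    def __init__(self, d_model):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_model, d_model * 4),  # 64 -> 256
            nn.ReLU(),
            nn.Linear(d_model * 4, d_model),  # 256 -> 64
        )

    def forward(self, x):
        return self.net(x)

ffn = FeedForward(d_model)
x = torch.randn(4, 16, d_model)  # stand-in for the normalized attention output

hidden = ffn.net[0](x)  # shape right after the first linear layer
output = ffn(x)
print(hidden.shape)  # torch.Size([4, 16, 256])
print(output.shape)  # torch.Size([4, 16, 64])
```

"Position-wise" means every token goes through the same network on its own: feeding only the first token gives the same result as the first row of the full output.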

Step 7 : Repeat Step 4 to 6

Cool! We have finished the first transformer block. Now we need to repeat the same process for each of the remaining transformer blocks we want to have.

On the purpose of stacking blocks, I quote HuggingChat's AI response:

By having multiple blocks, the output is trained and being passed into next block as input X, so after iterations the model can learn more complex patterns and relationships between the words in the input sequence.
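To show how the output of one block becomes the input X of the next, here is a compact sketch of Steps 4 to 6 in one module, stacked a few times. Everything in it is an assumption for illustration: the class name `TransformerBlock`, the choice of two blocks, and the placement of normalization after each residual connection, as in the Add & Norm description above.

```python
import math
import torch
import torch.nn as nn

class TransformerBlock(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.head_size = d_model // num_heads
        self.Wq = nn.Linear(d_model, d_model, bias=False)
        self.Wk = nn.Linear(d_model, d_model, bias=False)
        self.Wv = nn.Linear(d_model, d_model, bias=False)
        self.Wo = nn.Linear(d_model, d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_model * 4), nn.ReLU(), nn.Linear(d_model * 4, d_model)
        )

    def split_heads(self, t):
        B, T, _ = t.shape
        return t.reshape(B, T, self.num_heads, self.head_size).transpose(1, 2)

    def forward(self, x):
        B, T, C = x.shape
        q = self.split_heads(self.Wq(x))
        k = self.split_heads(self.Wk(x))
        v = self.split_heads(self.Wv(x))
        # Step 4: masked, scaled attention followed by concatenation and Wo.
        scores = q @ k.transpose(-2, -1) / math.sqrt(self.head_size)
        mask = torch.tril(torch.ones(T, T, dtype=torch.bool, device=x.device))
        scores = scores.masked_fill(~mask, float("-inf"))
        a = torch.softmax(scores, dim=-1) @ v
        a = a.transpose(1, 2).reshape(B, T, C)
        # Step 5: residual connection and layer normalization.
        x = self.norm1(x + self.Wo(a))
        # Step 6: feed-forward network, again with Add & Norm.
        return self.norm2(x + self.ffn(x))

num_blocks = 2  # hypothetical value; real models use many more
blocks = nn.Sequential(*[TransformerBlock(64, 4) for _ in range(num_blocks)])

x = torch.randn(4, 16, 64)  # stand-in for the Position-wise Input Embedding
out = blocks(x)
print(out.shape)  # torch.Size([4, 16, 64])
```

Because every block returns the same shape it received, any number of blocks can be chained this way.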

Step 8 : Output Probabilities


During inference, you want to get the next predicted token from the model. For that we need a probability distribution over all the tokens in the vocabulary, while what we have so far is a vector of size d_model for each token. Do you still recall that our vocabulary size is 3,771 in the example above? So, in order to select the token with the highest probability, we apply a linear layer, which is a matrix of size d_model = 64 by vocab_size = 3,771. This step is no different in training than in inference.

# Apply the final linear layer to get the logits
logits = nn.Linear(d_model, vocab_size)(output)

We call the output after this linear layer the logits. The logits is a matrix of shape [batch_size, context_length, vocab_size] = [4, 16, 3771].

Then a final softmax function is used to convert the logits of the linear layer into a probability distribution.

logits = torch.softmax(logits, dim=-1)

Note: during training, we don't need to apply the softmax function here; instead we use the nn.CrossEntropyLoss function, since it has the softmax behaviour built in.

How do we look into the resulting logits of shape [4, 16, 3771]? Actually after all the calculations it's quite a simple idea:

We have 4 batch pipelines; each pipeline contains all 16 words of its input sequence, and each word is mapped to a probability for every word in the vocabulary.

Final probabilities

If the model is in training, we update the model's parameters so that these probabilities improve; if the model is in inference, we simply select the token with the highest probability. Then everything makes sense.
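The two cases can be sketched side by side as follows. The block output and the target tokens are random stand-ins, and the name `lm_head` is hypothetical. For inference, I assume we only look at the last position of each sequence, since that is the one predicting the next token.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

batch_size, context_length, d_model, vocab_size = 4, 16, 64, 3771

output = torch.randn(batch_size, context_length, d_model)  # stand-in block output
lm_head = nn.Linear(d_model, vocab_size)                    # final linear layer
logits = lm_head(output)                                    # [4, 16, 3771]

# Inference: turn the last position into probabilities and pick the best token.
probs = torch.softmax(logits[:, -1, :], dim=-1)  # [4, 3771]
next_token = torch.argmax(probs, dim=-1)         # [4], one token id per batch

# Training: cross entropy takes the raw logits, with softmax built in.
targets = torch.randint(0, vocab_size, (batch_size, context_length))  # stand-in labels
loss = F.cross_entropy(logits.reshape(-1, vocab_size), targets.reshape(-1))

print(logits.shape, next_token.shape, loss.item())
```

The loss is a single number; during training, its gradient is what tells every Wq, Wk, Wv, Wo, and feed-forward weight how to change.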

Conclusion

Typically, it can be challenging to grasp the complexities of the Transformer architecture entirely during the initial exposure. Personally, it required approximately a month for me to thoroughly comprehend every component within the system. Therefore, I would suggest reviewing additional resources listed on our reference page and gaining inspiration from other remarkable Trailblazers.

On a personal note, I find videos more effective as tutorials. To facilitate learning, I have prepared several video explanations covering concepts and hands-on coding experience related to this architecture. They shall be available shortly.

Once you feel confident about understanding the Transformer architecture, moving forward to implementing the code through a series of guided steps would be advisable. Exciting progress awaits!

Last modified: 22 February 2024