Understanding the Transformer Architecture (Part 1)
Recently there have been some extraordinary breakthroughs in the field of Artificial Intelligence, from ChatGPT to LLaMA to Midjourney. These AI tools have seen some of the fastest adoption of any technology in history. What these models have in common is that they rely, in whole or in part, on a neural network architecture called the Transformer.
In this two-part blog series, we will explore the Transformer architecture and its impact on AI. From natural language processing to computer vision, the Transformer has revolutionized many domains, enabling machines to understand, generate, and interpret complex data like never before. Before jumping into the Transformer itself, however, we first need to understand why such a model was needed.
<h1><center><strong>RNNs and LSTMs</strong></center></h1>
One of the first neural network models for processing sequential data was the <strong>RNN (Recurrent Neural Network)</strong>. Natural Language Processing tasks like sentence translation and question answering are sequential tasks. At each time step, an RNN produces a <b>hidden state</b> as one of its outputs, and this hidden state is fed back into the same network as input at the next time step. In theory, these hidden states allow an RNN to carry information about all the previous words in the sequence, giving it an "infinitely large memory" compared to a CNN, which only sees a fixed-size window of the input.
<img src="https://i.ibb.co/HnLLhQq/rnn.png" alt="rnn" border="0">
<center> h<sup>t</sup> = g<sub>1</sub>(W<sub>hh</sub>h<sup>t-1</sup> + W<sub>hx</sub>x<sup>t</sup> + b<sub>h</sub>)</center>
<center> y<sup>t</sup> = g<sub>2</sub>(W<sub>yh</sub>h<sup>t</sup> + b<sub>y</sub>) </center>
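Here, h<sup>t</sup> is the hidden state, x<sup>t</sup> is the input, and y<sup>t</sup> is the output at time t; g<sub>1</sub> and g<sub>2</sub> are activation functions (for example, tanh for g<sub>1</sub> and softmax for g<sub>2</sub>); and W<sub>hh</sub>, W<sub>hx</sub>, W<sub>yh</sub>, b<sub>h</sub>, and b<sub>y</sub> are the parameters that need to be learned. The same parameters are shared across all time steps.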
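To make the two equations concrete, here is a minimal NumPy sketch of an RNN forward pass. It assumes g<sub>1</sub> = tanh and g<sub>2</sub> = softmax; the sizes, the random weights, and the function name `rnn_forward` are arbitrary illustration choices, not trained parameters or a library API.

```python
import numpy as np

def rnn_forward(xs, W_hh, W_hx, W_yh, b_h, b_y):
    """Run a vanilla RNN over a sequence of input vectors.

    Assumes g1 = tanh and g2 = softmax (a common choice, not the only one).
    xs has shape (T, input_dim); returns (hidden_states, outputs).
    """
    h = np.zeros(W_hh.shape[0])                  # h^0: initial hidden state
    hs, ys = [], []
    for x in xs:                                 # one time step per input vector
        h = np.tanh(W_hh @ h + W_hx @ x + b_h)   # h^t = g1(W_hh h^{t-1} + W_hx x^t + b_h)
        logits = W_yh @ h + b_y
        y = np.exp(logits - logits.max())        # numerically stable softmax
        y /= y.sum()                             # y^t = g2(W_yh h^t + b_y)
        hs.append(h)
        ys.append(y)
    return np.array(hs), np.array(ys)

rng = np.random.default_rng(0)
input_dim, hidden_dim, output_dim, T = 4, 8, 3, 5
params = dict(
    W_hh=rng.normal(scale=0.1, size=(hidden_dim, hidden_dim)),
    W_hx=rng.normal(scale=0.1, size=(hidden_dim, input_dim)),
    W_yh=rng.normal(scale=0.1, size=(output_dim, hidden_dim)),
    b_h=np.zeros(hidden_dim),
    b_y=np.zeros(output_dim),
)
xs = rng.normal(size=(T, input_dim))             # a toy sequence of 5 input vectors
hs, ys = rnn_forward(xs, **params)
print(hs.shape, ys.shape)  # (5, 8) (5, 3)
```

The same weight matrices are reused at every iteration of the loop, and h<sup>t</sup> cannot be computed until h<sup>t-1</sup> is known.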
<img src="https://i.ibb.co/DGDW9vK/rnn2.png" alt="rnn2" border="0">
Although the idea is sound in theory, RNNs struggle with longer sentences in practice: they tend to forget information encountered early in the sentence. One reason is that an RNN tries to carry every word forward in its hidden state without any mechanism for filtering out what is unimportant. RNNs also suffer from vanishing and exploding gradients: during backpropagation through time, the gradient is multiplied by the recurrent weights once per time step, so over long sequences it can shrink toward zero or blow up.
<strong>LSTMs (Long Short-Term Memory networks)</strong> were introduced to overcome these problems. LSTMs add three gates to the recurrent cell: the <b>input gate, forget gate, and output gate</b>. These gates let an LSTM decide what to keep and what to discard, so it retains the important information instead of trying to "remember" everything from the start. The gates also mitigate the vanishing gradient problem: by selectively retaining important information, writing new information into the memory cell, and regulating the flow of information out of it, they keep the gradients from diminishing too quickly over time. This enables LSTMs to capture long-term dependencies and makes training more effective.
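The vanishing/exploding gradient problem can be illustrated numerically. The sketch below makes a simplifying assumption: it drops g<sub>1</sub> and looks at a linear recurrence h<sup>t</sup> = W<sub>hh</sub>h<sup>t-1</sup>, so the gradient of h<sup>T</sup> with respect to h<sup>0</sup> is just W<sub>hh</sub> multiplied by itself T times. The toy matrices 0.5·I and 1.5·I and the helper name `gradient_norm_through_time` are hypothetical, chosen only to make the effect easy to see.

```python
import numpy as np

def gradient_norm_through_time(W_hh, T):
    """Spectral norm of d h^T / d h^0 for a *linear* RNN h^t = W_hh h^{t-1}.

    Dropping the nonlinearity is a simplifying assumption; with tanh, each
    step's Jacobian is additionally scaled by tanh'(.) values, which are at most 1.
    """
    J = np.eye(W_hh.shape[0])
    for _ in range(T):
        J = W_hh @ J                     # chain rule: one Jacobian factor per time step
    return np.linalg.norm(J, 2)          # largest singular value

hidden_dim = 8
for scale in (0.5, 1.5):
    W = scale * np.eye(hidden_dim)       # toy recurrent weight matrix
    for steps in (5, 20, 50):
        print(f"scale={scale}, T={steps}: {gradient_norm_through_time(W, steps):.3e}")
```

With scale 0.5 the norm is 0.5<sup>T</sup> and with scale 1.5 it is 1.5<sup>T</sup>, so after 50 steps the gradient has either fallen below 10<sup>-15</sup> or grown beyond 10<sup>8</sup>.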
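A single LSTM step can be sketched in a few lines of NumPy, using the standard gate equations. The stacking order of the gate weights inside `W`, the function names, and the toy sizes are assumptions made for this illustration; libraries such as PyTorch use their own layouts.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(x, h_prev, c_prev, W, b):
    """One LSTM time step.

    W has shape (4*hidden, hidden + input) and stacks the weights for the
    forget, input and output gates and the candidate cell update (this
    stacking order is an arbitrary convention chosen for this sketch).
    """
    hidden = h_prev.shape[0]
    z = W @ np.concatenate([h_prev, x]) + b
    f = sigmoid(z[0 * hidden:1 * hidden])   # forget gate: how much of c_prev to keep
    i = sigmoid(z[1 * hidden:2 * hidden])   # input gate: how much new information to write
    o = sigmoid(z[2 * hidden:3 * hidden])   # output gate: how much of the cell to expose
    g = np.tanh(z[3 * hidden:4 * hidden])   # candidate values to write into the cell
    c = f * c_prev + i * g                  # additive memory-cell update
    h = o * np.tanh(c)                      # new hidden state
    return h, c

rng = np.random.default_rng(0)
input_dim, hidden_dim = 4, 8
W = rng.normal(scale=0.1, size=(4 * hidden_dim, hidden_dim + input_dim))
b = np.zeros(4 * hidden_dim)
h = np.zeros(hidden_dim)
c = np.zeros(hidden_dim)
for x in rng.normal(size=(6, input_dim)):   # a toy sequence of 6 input vectors
    h, c = lstm_step(x, h, c, W, b)
print(h.shape, c.shape)  # (8,) (8,)
```

The key design choice is the additive update `c = f * c_prev + i * g`: along the cell-state path, the gradient from c<sup>t</sup> back to c<sup>t-1</sup> is just the forget gate's values (element-wise), rather than a repeated multiplication by W<sub>hh</sub>, so a forget gate close to 1 lets information and gradients survive many steps.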
Seq2Seq (Sequence-to-Sequence) models are neural network architectures that map an input sequence of variable length to an output sequence whose length can differ from that of the input. This architecture can be used to build models for different NLP tasks like language translation, text summarization, image captioning, etc.
The Seq2Seq architecture has an <strong>encoder and a decoder</strong>. The encoder takes a variable-length input and returns a fixed-size output called a <b>state</b>, which represents the whole input sequence as a fixed-size vector. The state is then passed to the decoder, which generates the required output for the given input. <b>The encoder and decoder are usually LSTMs or GRUs (Gated Recurrent Units), a computationally cheaper variant of the LSTM</b>.
For example, to build an English-to-French translation model, the English sentence is the input. The encoder (an LSTM or GRU) reads the sentence one word at a time, updating its hidden state at each step, and its final hidden state becomes the fixed-size state. This state is passed to the decoder, which is also an RNN, LSTM, or GRU. At each step, the decoder's previous output is used as input for predicting the next word, together with the information carried over from the encoder.
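Why a GRU is cheaper can be seen with a quick parameter count. An LSTM cell has four weight blocks (forget, input and output gates plus the candidate update), while a GRU has three (update and reset gates plus the candidate). The sketch below assumes one bias vector per block and uses hypothetical sizes; some libraries, such as PyTorch, keep two bias vectors per block, which changes the totals slightly.

```python
def recurrent_param_count(input_dim, hidden_dim, num_blocks):
    """Weights + biases for a gated recurrent cell with `num_blocks` gate/candidate blocks.

    Each block has a (hidden_dim x (hidden_dim + input_dim)) weight matrix and one
    bias vector of size hidden_dim (single-bias convention assumed here).
    """
    return num_blocks * (hidden_dim * (hidden_dim + input_dim) + hidden_dim)

input_dim, hidden_dim = 256, 512   # hypothetical embedding and hidden sizes
lstm = recurrent_param_count(input_dim, hidden_dim, num_blocks=4)  # f, i, o gates + candidate
gru = recurrent_param_count(input_dim, hidden_dim, num_blocks=3)   # update, reset gates + candidate
print(lstm, gru, gru / lstm)  # 1574912 1181184 0.75
```

Under this convention a GRU always has three quarters of the parameters of an LSTM of the same size, and correspondingly fewer multiplications per time step.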
<img src="https://i.ibb.co/wgcrSPp/seq2seq.png" alt="seq2seq" border="0">
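The encoder/decoder flow described above can be sketched end to end with NumPy. To keep it short, this sketch makes several assumptions: plain tanh RNN cells stand in for the LSTM/GRU cells, the weights are random and untrained (so the generated token ids are meaningless), the vocabulary sizes and special-token ids are hypothetical, and the encoder state is used only to initialize the decoder's hidden state (some variants also feed it in at every decoding step).

```python
import numpy as np

rng = np.random.default_rng(0)
SRC_VOCAB, TGT_VOCAB, EMB, HID = 10, 12, 6, 16
SOS, EOS = 0, 1   # hypothetical start/end token ids in the target vocabulary

# Randomly initialised (untrained) parameters; tanh RNN cells stand in for LSTM/GRU.
E_src = rng.normal(scale=0.1, size=(SRC_VOCAB, EMB))   # source word embeddings
E_tgt = rng.normal(scale=0.1, size=(TGT_VOCAB, EMB))   # target word embeddings
enc_Wh = rng.normal(scale=0.1, size=(HID, HID))
enc_Wx = rng.normal(scale=0.1, size=(HID, EMB))
dec_Wh = rng.normal(scale=0.1, size=(HID, HID))
dec_Wx = rng.normal(scale=0.1, size=(HID, EMB))
W_out = rng.normal(scale=0.1, size=(TGT_VOCAB, HID))   # hidden state -> vocabulary scores

def encode(src_ids):
    """Fold a variable-length source sequence into one fixed-size state vector."""
    h = np.zeros(HID)
    for tok in src_ids:
        h = np.tanh(enc_Wh @ h + enc_Wx @ E_src[tok])
    return h  # shape (HID,) no matter how long src_ids is

def greedy_decode(state, max_len=8):
    """Generate target tokens one at a time, feeding each prediction back in."""
    h, prev, out = state, SOS, []
    for _ in range(max_len):
        h = np.tanh(dec_Wh @ h + dec_Wx @ E_tgt[prev])
        prev = int(np.argmax(W_out @ h))   # greedy choice of the next token
        if prev == EOS:
            break
        out.append(prev)
    return out

short_state = encode([3, 4, 5])
long_state = encode(list(rng.integers(2, SRC_VOCAB, size=40)))
print(short_state.shape, long_state.shape)  # (16,) (16,)
print(greedy_decode(short_state))           # arbitrary ids: the weights are untrained
```

A 3-token sentence and a 40-token sentence both end up as a vector of the same 16 numbers, which is exactly the fixed-size state the decoder has to work from.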
Seq2Seq models are trained with supervised learning. During training, we pass the input sequence to the encoder (an English sentence in the example above) and the corresponding target sequence to the decoder (its French translation). In an RNN or LSTM, the output from the previous step normally acts as input for the next step. With a technique called Teacher Forcing, however, the decoder is fed the correct token from the target sequence at each step instead of its own, possibly wrong, prediction. This usually makes training converge faster.
The encoder has to squeeze all the information from sentences of varying length into a fixed-size state. As a result, Seq2Seq models face an information bottleneck on longer sentences, which can cause them to lose information needed to generate the correct output (in the example above, the correct French translation). In addition, because LSTMs and RNNs work sequentially (the output at time <i>t-1</i> is an input at time <i>t</i>), the computation cannot be parallelized across time steps, which makes training slow and computationally expensive. This makes it impractical to build robust and efficient LLMs (Large Language Models) on top of the Seq2Seq architecture.
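Teacher forcing boils down to how the decoder's training inputs are built: the target sentence is shifted right by one position behind a start marker, so at every step the decoder sees the ground-truth previous word. The sketch below assumes hypothetical `<sos>`/`<eos>` markers and a toy French target; the helper name `teacher_forcing_pairs` is illustrative.

```python
SOS, EOS = "<sos>", "<eos>"   # hypothetical start/end markers

def teacher_forcing_pairs(target_tokens):
    """Decoder inputs and training targets under teacher forcing.

    At step t the decoder is fed the *ground-truth* token from step t-1
    (the target shifted right by one), not the token it predicted itself.
    """
    decoder_inputs = [SOS] + target_tokens
    decoder_targets = target_tokens + [EOS]
    return list(zip(decoder_inputs, decoder_targets))

for step, (fed_in, expected) in enumerate(teacher_forcing_pairs(["le", "chat", "noir"])):
    print(step, fed_in, "->", expected)
# 0 <sos> -> le
# 1 le -> chat
# 2 chat -> noir
# 3 noir -> <eos>
```

Even if the model wrongly predicts "chien" at step 0, step 1 is still fed "le" and asked for "chat". At inference time there is no target sentence, so the decoder must feed back its own predictions instead.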
The attention mechanism was introduced to tackle this information loss caused by compressing a variable-length input into a fixed-size state. Instead of relying only on a single fixed-size state, the decoder can look at all of the encoder's hidden states at every step and focus on the parts of the input sequence that are most relevant to the word it is currently generating.
In a sequence-to-sequence model with attention, an additional component learns to assign a weight to each part of the input sequence based on its relevance to the current step of generating the output sequence. These weights represent the attention scores. They are typically normalized with a softmax so that they sum to 1, and the resulting weighted sum of the encoder's hidden states forms a context vector that the decoder uses to predict the next word.
<img src="https://i.ibb.co/6DNqX5r/attention.png" alt="attention" border="0">
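One attention step can be sketched in NumPy as follows. Note a swapped-in choice: the relevance scores here are simple dot products between the decoder state and each encoder state (Luong-style scoring), whereas the original attention-based Seq2Seq model of Bahdanau et al. used a small learned feed-forward network (additive scoring). The toy sizes and random states are illustrative assumptions.

```python
import numpy as np

def softmax(z):
    z = z - z.max()          # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum()

def attention_context(decoder_state, encoder_states):
    """Dot-product attention over encoder hidden states.

    decoder_state: shape (H,), the decoder's current hidden state.
    encoder_states: shape (T, H), one hidden state per source token.
    Returns (weights, context): weights has shape (T,) and sums to 1;
    context is the weighted sum of encoder states, shape (H,).
    """
    scores = encoder_states @ decoder_state      # one relevance score per source position
    weights = softmax(scores)                    # normalized attention weights
    context = weights @ encoder_states           # weighted average of encoder states
    return weights, context

rng = np.random.default_rng(0)
T, H = 5, 8
encoder_states = rng.normal(size=(T, H))                        # toy encoder outputs
decoder_state = encoder_states[2] + 0.1 * rng.normal(size=H)    # toy query resembling position 2
weights, context = attention_context(decoder_state, encoder_states)
print(np.round(weights, 3), weights.argmax())
```

Because the weights and context vector are recomputed at every decoding step from all encoder states, the decoder is no longer limited to a single fixed-size summary of the input.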
In the next part of this series, I will describe the Transformer architecture and how it improves on the Seq2Seq architecture with attention.
- Ojas Srivastava, 04:44 PM, 14 May, 2023