The Mathematics Behind Transformers

The transformer architecture, introduced in the 2017 paper "Attention Is All You Need" (Vaswani et al.), changed how sequence models are built. Unlike recurrent networks, which process tokens one at a time, transformers process entire sequences in parallel using a mechanism called self-attention.

Input Representation

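Every input token is converted to a dense vector of dimension d_model (512 in the original paper; 768 is also common). Raw embeddings carry no positional information, and self-attention itself ignores order, so the model could not tell "cat sat" from "sat cat". The fix is a positional encoding that is added to each embedding, where pos is the token position and i indexes pairs of dimensions: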

PE(pos, 2i)   = sin(pos / 10000^(2i/d_model))
PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))

These sinusoidal functions give each position a distinct signature while preserving relative distance information: for any fixed offset k, the encoding of position pos + k is a linear function (a rotation of each sin/cos pair) of the encoding of position pos, and that function depends only on k.
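
The two formulas and the fixed-offset property can be sketched in a few lines of Python. This is a minimal illustration that assumes an even d_model; the function names positional_encoding and shift are hypothetical and not from any library.

```python
import math

def positional_encoding(pos, d_model):
    """Sinusoidal encoding vector for one position (d_model assumed even)."""
    pe = []
    for i in range(d_model // 2):
        angle = pos / (10000 ** (2 * i / d_model))
        pe.append(math.sin(angle))  # dimension 2i
        pe.append(math.cos(angle))  # dimension 2i+1
    return pe

def shift(pe, k, d_model):
    """Map PE(pos) to PE(pos + k) by rotating each (sin, cos) pair.

    The rotation angles depend only on the offset k, not on pos.
    """
    out = []
    for i in range(d_model // 2):
        w = k / (10000 ** (2 * i / d_model))
        s, c = pe[2 * i], pe[2 * i + 1]
        out.append(s * math.cos(w) + c * math.sin(w))  # sin(a + w)
        out.append(c * math.cos(w) - s * math.sin(w))  # cos(a + w)
    return out

# Shifting the encoding of position 5 by k = 3 reproduces the encoding of position 8.
shifted = shift(positional_encoding(5, 8), 3, 8)
direct = positional_encoding(8, 8)
```

Here shifted and direct agree up to floating-point error, which is the linear relationship described above.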

The Attention Mechanism

The core idea is elegant: given a sequence of vectors, each position should be able to "attend to" every other position, weighted by relevance. To do this, we project each input vector into three spaces:

  • Query (Q): what this position is looking for
  • Key (K): what this position offers
  • Value (V): what this position will contribute

These projections use learned weight matrices:

Q = X * W_Q,   K = X * W_K,   V = X * W_V

The attention scores between all pairs of positions are computed as dot products between queries and keys, scaled to prevent vanishing gradients in the softmax:

Attention(Q, K, V) = softmax(Q*K^T / sqrt(d_k)) * V

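The scaling factor sqrt(d_k) is crucial. If the components of a query and a key are independent with zero mean and unit variance, their dot product has variance d_k, so its typical magnitude grows with the dimension. Large scores push the softmax into saturated regions with near-zero gradients; dividing by sqrt(d_k) brings the variance back to about 1.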
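
The variance argument can be checked empirically. The sketch below assumes query and key components drawn independently from a standard normal distribution, which is an idealization rather than a property of trained models; dot_variance is a hypothetical helper name.

```python
import math
import random

def dot_variance(d_k, scale, trials=2000, seed=0):
    """Sample variance of q.k (optionally divided by sqrt(d_k)) for random q, k."""
    rng = random.Random(seed)
    dots = []
    for _ in range(trials):
        q = [rng.gauss(0, 1) for _ in range(d_k)]
        k = [rng.gauss(0, 1) for _ in range(d_k)]
        d = sum(a * b for a, b in zip(q, k))
        dots.append(d / math.sqrt(d_k) if scale else d)
    mean = sum(dots) / trials
    return sum((d - mean) ** 2 for d in dots) / trials

raw = dot_variance(64, scale=False)    # roughly d_k = 64
scaled = dot_variance(64, scale=True)  # roughly 1
```

With d_k = 64, the unscaled variance lands near 64 and the scaled variance near 1, so the scores fed to the softmax keep a similar spread regardless of dimension.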

Why Dot-Product Attention Works

Each entry of Q * K^T is the dot product of one query with one key, which measures how well the two vectors align. When a query aligns strongly with a key, that key's position gets a high weight for its value. The softmax is applied row by row, so for each query the attention weights over all key positions sum to 1: they form a probability distribution over which values to aggregate.
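
A minimal sketch of scaled dot-product attention on a toy input shows both effects: the weights for each query sum to 1, and the key most aligned with the query dominates the output. The matrices below are made-up values chosen for illustration, and the function names are hypothetical.

```python
import math

def softmax(scores):
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def attention(Q, K, V):
    """softmax(Q*K^T / sqrt(d_k)) * V, with matrices as lists of rows."""
    d_k = len(K[0])
    outputs, weights = [], []
    for q in Q:
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d_k) for k in K]
        w = softmax(scores)
        weights.append(w)
        # Weighted average of the value vectors.
        outputs.append([sum(wi * v[j] for wi, v in zip(w, V))
                        for j in range(len(V[0]))])
    return outputs, weights

Q = [[4.0, 0.0]]                            # one query
K = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]]   # aligned, orthogonal, opposed
V = [[10.0], [20.0], [30.0]]
out, weights = attention(Q, K, V)
```

The first key points in the same direction as the query, so it receives most of the weight and the output stays close to its value of 10.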

Multi-Head Attention

A single attention head produces only one attention distribution per position, so it tends to capture one kind of relationship at a time. Multi-head attention runs h attention operations in parallel, each with its own projection matrices:

MultiHead(Q, K, V) = Concat(head_1, ..., head_h) * W_O
where head_i = Attention(Q*W_Qi, K*W_Ki, V*W_Vi)

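With h = 8 heads and d_model = 512, each head works in a subspace of dimension d_k = d_model / h = 64, so the total cost is similar to that of one full-width head. Heads can specialize: one might track syntactic dependencies, another semantic similarity, another long-range coreference. The outputs are concatenated and projected back to d_model.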
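
The concatenate-and-project structure can be sketched as follows for self-attention, where queries, keys, and values all come from the same input X. The sizes (5 positions, d_model = 8, 2 heads) and the random weights are arbitrary assumptions for illustration, and every function name is hypothetical.

```python
import math
import random

def matmul(a, b):
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def softmax(scores):
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def attention(Q, K, V):
    scale = math.sqrt(len(K[0]))
    scores = matmul(Q, [list(col) for col in zip(*K)])  # Q * K^T
    weights = [softmax([s / scale for s in row]) for row in scores]
    return matmul(weights, V)

def multi_head(X, heads, W_O):
    """heads is a list of (W_Q, W_K, W_V) tuples, one per head."""
    outputs = [attention(matmul(X, W_Q), matmul(X, W_K), matmul(X, W_V))
               for W_Q, W_K, W_V in heads]
    # Concatenate the head outputs for each position along the feature axis.
    concat = [sum((head[t] for head in outputs), []) for t in range(len(X))]
    return matmul(concat, W_O)

def rand_matrix(rows, cols, rng):
    return [[rng.gauss(0, 0.1) for _ in range(cols)] for _ in range(rows)]

rng = random.Random(0)
n, d_model, h = 5, 8, 2
d_k = d_model // h  # per-head dimension
X = rand_matrix(n, d_model, rng)
heads = [tuple(rand_matrix(d_model, d_k, rng) for _ in range(3)) for _ in range(h)]
W_O = rand_matrix(h * d_k, d_model, rng)
Y = multi_head(X, heads, W_O)  # n rows, d_model columns
```

Each head sees only a d_k-dimensional projection of the input, and the output projection W_O mixes the heads back into a d_model-dimensional vector per position.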

Feed-Forward Sublayer

After attention, each position passes independently through a two-layer network:

FFN(x) = max(0, x*W_1 + b_1) * W_2 + b_2

The inner dimension is typically 4x the model dimension (2048 for d_model = 512). Interpretability research suggests that some factual associations can be localized to particular FFN layers, which behave somewhat like a key-value memory.
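
As a small sketch, the FFN formula for a single position can be written directly, along with the parameter count implied by the 4x inner dimension. The tiny weights below are made-up values, and ffn and ffn_param_count are hypothetical names.

```python
def ffn(x, W_1, b_1, W_2, b_2):
    """FFN(x) = max(0, x*W_1 + b_1) * W_2 + b_2 for one position x (a row vector)."""
    hidden = [max(0.0, sum(xi * w for xi, w in zip(x, col)) + b)
              for col, b in zip(zip(*W_1), b_1)]
    return [sum(hi * w for hi, w in zip(hidden, col)) + b
            for col, b in zip(zip(*W_2), b_2)]

def ffn_param_count(d_model, d_ff):
    """Weights and biases of both layers."""
    return d_model * d_ff + d_ff + d_ff * d_model + d_model

x = [1.0, -2.0]
W_1 = [[1.0, 0.0, -1.0],
       [0.0, 1.0, 1.0]]           # 2 x 3: expand
b_1 = [0.0, 0.0, 0.0]
W_2 = [[2.0, 0.0],
       [0.0, 2.0],
       [1.0, 1.0]]                # 3 x 2: project back
b_2 = [0.5, 0.5]
y = ffn(x, W_1, b_1, W_2, b_2)    # hidden pre-activations are [1, -2, -3]
params = ffn_param_count(512, 2048)
```

The ReLU zeroes the two negative hidden units, so only the first one contributes to the output. For d_model = 512 and an inner dimension of 2048, one FFN sublayer holds about 2.1 million parameters.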

Residual Connections and Layer Normalization

Each sublayer is wrapped in a residual connection followed by layer normalization:

x = LayerNorm(x + Sublayer(x))

Residual connections allow gradients to flow directly through many layers during backpropagation. Layer normalization stabilizes training by normalizing each token's activations across the feature dimension. Because its statistics do not depend on the other tokens or on the batch, it is better suited to variable-length sequences than batch normalization.
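
The wrapper x = LayerNorm(x + Sublayer(x)) can be sketched for a single token vector. This simplified version omits the learnable gain and bias that layer normalization usually includes, and the names layer_norm and residual_block are hypothetical.

```python
import math

def layer_norm(x, eps=1e-5):
    """Normalize one token's features to zero mean and unit variance.

    Simplification: no learnable gain or bias parameters.
    """
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]

def residual_block(x, sublayer):
    """Add the sublayer output to its input, then normalize."""
    return layer_norm([a + b for a, b in zip(x, sublayer(x))])

# A stand-in sublayer that halves its input, purely for illustration.
y = residual_block([1.0, 2.0, 3.0, 4.0], lambda v: [0.5 * t for t in v])
y_mean = sum(y) / len(y)
y_var = sum((v - y_mean) ** 2 for v in y) / len(y)
```

The statistics are computed from the one token's own features, so the result does not change with batch size or sequence length.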

The Complete Architecture

A transformer encoder stacks N identical layers (N = 6 in the original paper; 12 or more is common in later models), each with:

  1. Multi-head self-attention sublayer
  2. Feed-forward sublayer

The decoder masks its self-attention so that each position attends only to earlier positions, and adds a third sublayer, cross-attention, where queries come from the decoder and keys and values come from the encoder output. This lets the decoder attend to the full input at every generation step.

Why Transformers Win

RNNs must compress all history into a fixed-size hidden state. Transformers maintain direct connections between every pair of positions, so the path length between any two tokens is O(1) regardless of sequence length, compared with O(n) recurrent steps in an RNN. Shorter paths make long-range dependencies much easier to learn.

The trade-off is that the attention matrix costs O(n²) time and memory in the sequence length n. This is why efficient attention variants remain an active area of engineering and research: sparse and linear attention restrict or approximate the matrix, while FlashAttention keeps the computation exact but avoids materializing the full matrix in memory.
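
A back-of-the-envelope calculation makes the quadratic growth concrete. The sketch assumes 8 heads, 4-byte floats, one layer, and a single sequence, with the full attention matrix stored; attention_matrix_bytes is a hypothetical helper.

```python
def attention_matrix_bytes(seq_len, num_heads=8, bytes_per_value=4):
    """Bytes to store one layer's n x n attention weights for one sequence."""
    return num_heads * seq_len * seq_len * bytes_per_value

for n in (512, 2048, 8192):
    print(n, attention_matrix_bytes(n) / 2**20, "MiB")
```

Under these assumptions the matrix takes 8 MiB at 512 tokens and 2048 MiB at 8192 tokens: a 16x longer sequence costs 256x the memory.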

Glossary

Token — The basic unit a transformer processes. Depending on the system, a token might be a whole word, part of a word, or a single character. The sentence "Hello world" might become two or three tokens. Each token is converted into a vector before the model does any computation.

Vector — An ordered list of numbers. In machine learning, vectors represent meaning: a word like "king" might be encoded as a list of 512 numbers, each capturing some aspect of its meaning or context. The geometry of these numbers matters — related concepts end up as vectors that point in similar directions in this high-dimensional space.
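
"Pointing in similar directions" is usually measured with cosine similarity. The three-dimensional vectors below are invented for illustration and are far smaller than real embeddings; cosine_similarity is a hypothetical helper.

```python
import math

def cosine_similarity(u, v):
    """Cosine of the angle between u and v: 1 = same direction, 0 = orthogonal."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)

# Made-up toy vectors, not taken from any trained model.
king = [0.9, 0.8, 0.1]
queen = [0.85, 0.9, 0.15]
banana = [0.1, 0.05, 0.95]

related = cosine_similarity(king, queen)
unrelated = cosine_similarity(king, banana)
```

In this toy setup the related pair scores much closer to 1 than the unrelated pair.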

Embedding — A vector assigned to a token that encodes its meaning in a form the model can work with. Embeddings start as random numbers before training; through training, the model adjusts them so that tokens with related meanings end up with similar vectors.
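
In code, an embedding layer is essentially a lookup table with one row per token. The sketch below uses a hypothetical three-entry vocabulary and random initial values, as they would be before any training.

```python
import random

rng = random.Random(0)
vocab = {"hello": 0, "world": 1, "<unk>": 2}  # hypothetical toy vocabulary
d_model = 4

# One row per token id, randomly initialized (untrained).
embedding_table = [[rng.gauss(0, 1) for _ in range(d_model)] for _ in vocab]

def embed(tokens):
    """Look up each token's vector; unknown tokens map to the <unk> row."""
    return [embedding_table[vocab.get(t, vocab["<unk>"])] for t in tokens]

vectors = embed(["hello", "world", "hello"])
```

Both occurrences of "hello" map to the same row, which is why position has to be injected separately.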

Weight matrix — A grid of numbers the model learns during training. Multiplying a vector by a weight matrix produces a new vector — a transformed version of the original. The Q, K, and V projections are weight matrices the model learns to specialize for its task; so are the parameters in the feed-forward sublayer.

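Gradient — The derivative of the model's error (the loss) with respect to a particular weight: it says how much, and in which direction, the error changes when that weight changes slightly. During training, each weight is nudged a small step in the direction opposite to its gradient, which is the direction that reduces the error.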
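
A one-weight example shows the idea. The model here is a made-up one, prediction = w * x with squared error, and the learning rate 0.05 is an arbitrary choice.

```python
def loss(w, x, y):
    return (w * x - y) ** 2

def grad(w, x, y):
    """Derivative of the loss with respect to w."""
    return 2 * (w * x - y) * x

w, x, y, lr = 0.0, 2.0, 6.0, 0.05
history = []
for _ in range(20):
    history.append(loss(w, x, y))
    w -= lr * grad(w, x, y)  # step against the gradient
```

The loss shrinks at every step and w approaches 3, the value for which w * x equals y.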

Backpropagation — The algorithm that computes gradients for every weight in the network. It works by propagating the error signal backward from the output toward the input, using partial derivatives at each step. A partial derivative measures the effect of changing one specific weight while holding all others fixed — giving the model a precise answer to the question "if I adjust this weight slightly, how much does the overall error change?" By chaining these measurements layer by layer from output back to input, backpropagation efficiently assigns credit (or blame) to every weight in the network. This is why keeping a clear path for gradients to flow backward — via residual connections — matters so much in deep networks.
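
The chain-rule bookkeeping can be sketched on a tiny network with two scalar weights, then checked against numerical derivatives. The network (a tanh hidden unit followed by a linear output) and all input values are assumptions chosen for illustration.

```python
import math

def forward(w1, w2, x, y):
    h = math.tanh(w1 * x)   # hidden activation
    y_hat = w2 * h          # prediction
    return h, y_hat, (y_hat - y) ** 2

def backward(w1, w2, x, y):
    """Gradients of the loss, propagated from the output back to w1."""
    h, y_hat, _ = forward(w1, w2, x, y)
    d_yhat = 2 * (y_hat - y)        # dL/dy_hat
    d_w2 = d_yhat * h               # dL/dw2
    d_h = d_yhat * w2               # dL/dh
    d_w1 = d_h * (1 - h ** 2) * x   # tanh'(z) = 1 - tanh(z)^2
    return d_w1, d_w2

def numeric(w1, w2, x, y, eps=1e-6):
    """Central finite differences, used only as a check."""
    L = lambda a, b: forward(a, b, x, y)[2]
    return ((L(w1 + eps, w2) - L(w1 - eps, w2)) / (2 * eps),
            (L(w1, w2 + eps) - L(w1, w2 - eps)) / (2 * eps))

analytic = backward(0.5, -1.2, 0.8, 1.0)
approx = numeric(0.5, -1.2, 0.8, 1.0)
```

The analytic and numerical gradients agree closely, but backpropagation gets all of them from one forward and one backward pass instead of re-running the network for every weight.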

Softmax — A function that converts a set of raw scores into a probability distribution: positive numbers that sum to 1. In attention, the dot-product scores between queries and keys are passed through softmax to produce attention weights — determining how much each position contributes to the current output.

Recurrent neural network (RNN) — An earlier architecture for processing sequences. The "recurrent" refers to how it works: at each step, the network takes the current token as input along with its own output from the previous step — the hidden state — and produces a new hidden state. That output is then fed back in as input for the next token, and so on. This loop is what makes it recurrent. The hidden state acts as a running summary of everything seen so far, like reading a book word by word while rewriting a fixed set of notes after each word. The problem is that this summary has a fixed size, so as sequences grow longer, earlier information gets diluted or lost — a weakness transformers avoid by connecting every position directly to every other.
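
The recurrence loop can be sketched with a basic tanh update. The weights are made-up values, the hidden state has size 2, and bias terms are left out for brevity; rnn_step and encode are hypothetical names.

```python
import math

def rnn_step(h_prev, x, W_h, W_x):
    """New hidden state from the previous hidden state and the current input."""
    return [math.tanh(sum(wh * hp for wh, hp in zip(W_h[j], h_prev)) +
                      sum(wx * xi for wx, xi in zip(W_x[j], x)))
            for j in range(len(h_prev))]

W_h = [[0.5, -0.3], [0.2, 0.4]]  # hidden-to-hidden weights (illustrative)
W_x = [[1.0], [-1.0]]            # input-to-hidden weights (illustrative)

def encode(sequence):
    h = [0.0, 0.0]
    for x in sequence:           # tokens are consumed strictly one at a time
        h = rnn_step(h, x, W_h, W_x)
    return h

short = encode([[0.1], [0.7], [-0.4]])
long = encode([[0.1], [0.7], [-0.4]] * 100)
```

Whether the sequence has 3 tokens or 300, the summary is the same two numbers, which is the fixed-size bottleneck described above.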