Towards 100x Speedup: Full Stack Transformer Inference Optimization
University of Edinburgh | yao.fu@ed.ac.uk
Started writing in Sep 2023, released on Dec 11 2023, last updated on May 17 2024
Imagine two companies have equally powerful models. Company A can serve its model to 10 users with one GPU, but company B can serve 20 users with the same GPU. Who will win in the long run?
- Company B, because its cost is lower
Imagine a researcher has come up with a super smart decoding method: clever algorithm, solid math, but not compatible with FlashAttention. Can this method be used in production?
- Probably not, because FlashAttention is essential for large-scale model deployment
An in-depth understanding of transformer inference can be extremely beneficial for both research and production. Yet in the real world, large-scale production is usually not so close to cutting-edge research, so people who know algorithms may not know MLSys, and vice versa.
In this post, we discuss full-stack transformer inference optimization, from hardware specs like the A100 memory hierarchy, to MLSys methods like FlashAttention and vLLM, to model architectures like Mixture of Experts, to decoding algorithms like Speculative Decoding and its variants. We identify the most fundamental fact, that transformer inference is memory bound, and show that most of the optimizations, whether from MLSys or from modeling, are based on or exploit this fact. Like adding buffs in an RPG game, we see how transformer inference is scaled and sped up, step by step.
Table of Contents
- 1 - Hardware: inference on GPUs
  - 1.1 - Preliminary
    - 1.1.1 - GPU architecture
    - 1.1.2 - GPU programming basics
    - 1.1.3 - Compute v.s. memory bound
  - 1.2 - Transformer inference basics
    - 1.2.1 - Prefilling and decoding
    - 1.2.2 - Transformer inference is memory bound
    - 1.2.3 - Memory layout
    - 1.2.4 - Online and offline inference, throughput v.s. latency
    - 1.2.5 - Context length scaling
- 2 - MLSys: Flash attention and vLLM
  - 2.1 - vLLM and Paged attention
  - 2.2 - Flash attention
  - 2.3 - Flash Decoding
- 3 - Modeling: architecture and decoding algorithm
  - 3.1 - Classical methods
    - 3.1.1 - Distillation and sparse attention
    - 3.1.2 - Quantization
    - 3.1.3 - Multi-query attention and group-query attention
  - 3.2 - Advanced techniques
    - 3.2.1 - Mixture of experts
    - 3.2.2 - Early exit
    - 3.2.3 - Blockwise decoding
- 4 - What we have not yet covered
- 5 - Personal pick
- Conclusion
- References
1 - Hardware: inference on GPUs
We start with a discussion of GPU architecture, particularly its memory hierarchy. We identify two important patterns, compute bound and memory bound, and discuss why large transformer inference is memory bound. Most of the optimizations that follow build on this fundamental fact: because inference is memory bound, the GPU has spare compute, so as long as we improve FLOP utilization (or reduce memory access), we improve efficiency.
1.1 - Preliminary
1.1.1 - GPU architecture
Overall it looks like this
- Basics: DRAM, L2 cache, and SMs (streaming multiprocessors)
- Comparison to CPU
- SMs are similar to CPU cores, but with a significantly larger level of parallelism
- L2 cache and DRAM are similar to CPU L2 and DRAM
- In the FlashAttention paper, L1 (the on-chip memory inside each SM) is called SRAM
- A100 80GB SXM
- 108 SMs, 80GB DRAM, 40MB L2 cache
What’s inside an SM?
- L1 cache: instruction and data
- Tensor core: where matrix multiplication happens. Recall that neural network computation is basically gigantic batches of matrix multiplications
1.1.2 - GPU programming basics
When performing model.generate(prompt), we do:
- Memory access:
- load model weights from HBM (the GPU DRAM) through the L2 cache to the SMs
- Compute:
- perform matrix multiplications in the SMs, which hand them to the tensor cores
- Which operation takes more time? Actually it is the memory access.
- We will show that when running transformer inference, most of the time is spent waiting for model parameters and activations to move to and from memory, rather than on the actual computation
- A100:
- 108 SMs, 80GB DRAM, 40MB L2 cache
- bf16 tensor core: 312 trillion floating-point operations per second (TFLOPS)
- DRAM memory bandwidth: 2039 GB/s = 2.039 TB/s
- If the model is large, we split it across multiple GPUs, say two, connected by NVLink
- NVLink: 300 GB/s = 0.3 TB/s
- We roughly observe a speed hierarchy. Although these numbers are not directly comparable, the orders-of-magnitude gaps between them are the major places we need to optimize:
- 312 TFLOPS (SM compute) > 2.039 TB/s (DRAM memory access) > 0.3 TB/s = 300 GB/s (NVLink cross-device communication) > 60 GB/s (PCIe cross-device communication)
- This means that if we want things to be fast, we should try our best to
- fully utilize the SMs,
- reduce memory access within one GPU (because it is much slower than compute),
- and reduce communication between GPUs (because it is even slower than memory access)
1.1.3 - Compute v.s. memory bound
How do we determine whether we have fully utilized the SMs? We check whether we are compute bound or memory bound:
- define the GPU ops:byte ratio = peak FLOPS / memory bandwidth
- A100: 312 TFLOPS / 2.039 TB/s ≈ 153 FLOPs per byte
- define arithmetic intensity = FLOPs performed / bytes of memory accessed
- if arithmetic intensity is roughly equal to or larger than the ops:byte ratio, the operation is compute bound; if it is smaller, it is memory bound. The Nvidia blog lists the arithmetic intensity of typical neural network layers.
- Increasing batch size will change the behavior from memory bound to compute bound
- Kernel fusion: reduces memory access, because multiple operations are fused into one kernel and intermediate results no longer round-trip through DRAM
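To make the ops:byte comparison concrete, here is a minimal sketch that computes the arithmetic intensity of a bf16 matrix multiplication and compares it with the A100 ratio (312 / 2.039 ≈ 153). It assumes each operand crosses HBM exactly once and ignores caching; the 4096 × 4096 weight shape is an illustrative assumption.

```python
# Roofline check for a bf16 matmul Y[b, n] = X[b, k] @ W[k, n] on an A100.
A100_FLOPS = 312e12            # bf16 tensor-core peak, FLOP/s
A100_BW = 2.039e12             # HBM bandwidth, bytes/s
RIDGE = A100_FLOPS / A100_BW   # ops:byte ratio, about 153 FLOPs per byte


def matmul_intensity(b: int, k: int, n: int, bytes_per_elem: int = 2) -> float:
    """FLOPs per byte, assuming X, W and Y each cross HBM exactly once."""
    flops = 2 * b * k * n                                   # one multiply + one add per term
    bytes_moved = bytes_per_elem * (b * k + k * n + b * n)  # read X, read W, write Y
    return flops / bytes_moved


def bound(b: int, k: int = 4096, n: int = 4096) -> str:
    return "compute bound" if matmul_intensity(b, k, n) >= RIDGE else "memory bound"


for batch in (1, 16, 128, 256, 1024):
    print(f"batch={batch:5d}  intensity={matmul_intensity(batch, 4096, 4096):7.1f}  {bound(batch)}")
```

Under these assumptions the intensity grows roughly linearly with the batch size (it is about b when b is much smaller than the weight dimensions), which is why a batch of 1 sits deep in the memory-bound regime while a few hundred rows become compute bound.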
1.2 - Transformer inference basics
1.2.1 - Prefilling and decoding
There are two steps when calling model.generate(prompt):
- Prefilling:
- compute the KV cache for the prompt.
- Compute bound, because we process a sequence of tokens in parallel
- Decoding:
- sample the next tokens autoregressively.
- Memory bound, because we compute only one token per step, which does not fully utilize the SMs
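The split between the two phases can be sketched with a toy single-head attention over a KV cache in plain Python. The `KVCache` class and its method names are hypothetical, for illustration only: prefilling stores K/V for every prompt token at once, while each decoding step appends one K/V pair and then reads the entire cache to produce a single output, which is why decoding does so little compute per byte moved.

```python
import math
import random


def attend(q, keys, values):
    """Single-head softmax attention of one query vector over a list of keys/values."""
    scale = 1.0 / math.sqrt(len(q))
    scores = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in keys]
    m = max(scores)                                   # subtract max for numerical stability
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    return [sum(w * v[j] for w, v in zip(weights, values)) / total
            for j in range(len(values[0]))]


class KVCache:
    """Hypothetical toy KV cache for one attention head of one layer."""

    def __init__(self):
        self.keys, self.values = [], []

    def prefill(self, prompt_keys, prompt_values):
        # Prefilling: K/V of all prompt tokens come from one parallel pass (passed in here).
        self.keys.extend(prompt_keys)
        self.values.extend(prompt_values)

    def decode_step(self, q, k, v):
        # Decoding: store the new token's K/V, then read the WHOLE cache for one query.
        self.keys.append(k)
        self.values.append(v)
        return attend(q, self.keys, self.values)


def rand_vec(d=8):
    return [random.gauss(0.0, 1.0) for _ in range(d)]


random.seed(0)
cache = KVCache()
cache.prefill([rand_vec() for _ in range(5)], [rand_vec() for _ in range(5)])  # 5-token prompt
for _ in range(3):                                                            # 3 decoding steps
    out = cache.decode_step(rand_vec(), rand_vec(), rand_vec())
print(len(cache.keys), "cached tokens")
```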
1.2.2 - Transformer inference is memory bound
Increasing batch size can change the pattern from memory bound to compute bound, as illustrated by a figure in Kipply’s awesome blog.
- Small-batch decoding is memory bound, because decoding samples only one token per forward pass
- Increasing batch size improves hardware efficiency,
- because we compute multiple tokens at one time
- a large batch changes decoding from memory bound to FLOP bound.
- However, we cannot make the batch arbitrarily large, because GPU memory is limited (currently at most 80GB per device) and the KV cache grows with the batch size.
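A back-of-the-envelope roofline model shows the batching effect for decoding. This sketch assumes a 13B-parameter model in bf16 on one A100, weights read from HBM once per step and shared by the whole batch, and about 2 FLOPs per parameter per token; it ignores KV-cache traffic and memory capacity, so it is a lower bound, not a prediction.

```python
A100_FLOPS = 312e12   # bf16 peak FLOP/s
A100_BW = 2.039e12    # HBM bytes/s


def decode_step_time(n_params: float, batch: int, bytes_per_param: int = 2) -> float:
    """Lower bound on one decoding step: weights are read once and shared by the batch,
    while each sequence costs ~2 FLOPs per parameter. KV-cache reads are ignored."""
    memory_time = n_params * bytes_per_param / A100_BW
    compute_time = 2 * n_params * batch / A100_FLOPS
    return max(memory_time, compute_time)


for b in (1, 8, 64, 256):
    t = decode_step_time(13e9, b)
    print(f"batch={b:4d}  step={t * 1e3:6.2f} ms  throughput={b / t:8.0f} tok/s")
```

Under these assumptions the step time stays flat (about 12.8 ms) until the batch reaches roughly 150 sequences, so throughput grows almost for free; beyond that point the step becomes compute bound and per-token latency starts rising.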
Liqun also gives an awesome table dissecting transformer inference FLOPs and memory access.
Note how autoregressive decoding differs from the initial (prefill) computation in terms of arithmetic intensity.
1.2.3 - Memory layout
As we can see, to serve a 13B model in bf16 (26GB of weights), we only have about 10GB of memory left to store the KV cache. This means that
- we cannot have too large a batch (though we want the batch size to be large to improve efficiency),
- or too long a sequence (though we certainly want to serve 100k-token contexts)
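The KV-cache budget translates directly into a token budget. A minimal sketch, assuming a LLaMA-13B-like shape (40 layers, 40 attention heads of dimension 128, no key/value head sharing) and bf16 storage; the function names are hypothetical:

```python
def kv_cache_bytes_per_token(n_layers: int, n_kv_heads: int, head_dim: int,
                             bytes_per_elem: int = 2) -> int:
    # Factor 2: one key vector and one value vector per layer per K/V head (bf16 = 2 bytes).
    return 2 * n_layers * n_kv_heads * head_dim * bytes_per_elem


def max_cached_tokens(budget_bytes: float, **shape) -> int:
    """How many tokens (summed over the whole batch) fit into a KV-cache budget."""
    return int(budget_bytes // kv_cache_bytes_per_token(**shape))


shape = dict(n_layers=40, n_kv_heads=40, head_dim=128)   # assumed LLaMA-13B-like shape
per_token = kv_cache_bytes_per_token(**shape)
tokens = max_cached_tokens(10e9, **shape)
print(f"{per_token / 1e6:.2f} MB per token, {tokens} tokens in 10 GB, "
      f"i.e. {tokens // 4096} full sequences of 4k tokens")
```

With these numbers each token costs about 0.8 MB of cache, so 10 GB holds roughly 12k tokens in total: only two full 4k-token sequences at once.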
1.2.4 - Online and offline inference, throughput v.s. latency
- Offline: throughput optimization
- we care about this scenario because we may want to evaluate the model offline, say, run an intermediate pretrained checkpoint on 100 benchmarks to verify that pretraining is healthy
- Increasing batch size can help, but remember that 80GB is currently the maximum memory of a single device.
- Online: latency and throughput tradeoff
- When the batch size is large (assuming it still fits in memory), we become compute bound, and latency increases
- Generation speed should not fall below human reading speed, or users will complain, or switch to your competitor’s model
- But again, we do want the batch size to be large, to improve efficiency
1.2.5 - Context length scaling
So far we have implicitly assumed that our prompt is not long, say <4k tokens. Now we consider the situation where the prompt is longer than 100k tokens, which happens when we want the model to read multiple PDFs and then do document QA.
- Prefilling:
- this time, prefilling takes significantly longer, because the input is so long
- in this case, latency to the first generated token is important, because the user does not want to wait 10s to see the model speak
- Large KV cache:
- simply because the context length is long
- currently there does not seem to be much work improving inference in this space
Update: see Full Stack Transformer Inference Optimization Season 2: Deploying Long-Context Models for a detailed discussion for serving long-context models
2 - MLSys: Flash attention and vLLM
This section discusses how to fully exploit the GPU memory hierarchy. vLLM gives a way of doing GPU memory management, just like virtual memory for CPUs in an operating system; FlashAttention shows how to effectively reduce memory IO by keeping most of the operations on the SMs, largely reducing the memory access overhead.
2.1 - vLLM and Paged attention
We have limited GPU memory, so we want to use it wisely for the KV cache. Yet the GPU / PyTorch by itself does not automatically give you the best way to place the KV cache in memory, and its default strategy is actually quite bad. This motivates Paged Attention in vLLM for GPU memory management:
- Basically, it constructs a memory management system similar to CPU virtual memory, storing the KV cache in fixed-size blocks to reduce fragmentation, fully utilize memory, and thereby raise throughput
- It is now the go-to place for transformer inference
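The core bookkeeping idea behind Paged Attention can be sketched as a block allocator. This is a toy illustration under stated assumptions, not vLLM’s actual code or API: the class and method names are hypothetical, and it only tracks which physical block and offset each new token’s K/V would be written to.

```python
class PagedKVAllocator:
    """Hypothetical toy allocator in the spirit of PagedAttention: each sequence owns a
    block table mapping its logical blocks to physical blocks of fixed size."""

    def __init__(self, num_blocks: int, block_size: int):
        self.block_size = block_size
        self.free_blocks = list(range(num_blocks))
        self.block_tables = {}   # seq_id -> list of physical block ids
        self.lengths = {}        # seq_id -> number of tokens stored

    def append_token(self, seq_id: int) -> tuple:
        """Reserve a slot for one new token; returns (physical_block, offset)."""
        table = self.block_tables.setdefault(seq_id, [])
        n = self.lengths.get(seq_id, 0)
        if n % self.block_size == 0:          # current block is full (or there is none yet)
            if not self.free_blocks:
                raise MemoryError("out of KV-cache blocks")
            table.append(self.free_blocks.pop())
        self.lengths[seq_id] = n + 1
        return table[n // self.block_size], n % self.block_size

    def free(self, seq_id: int) -> None:
        """Return every block of a finished sequence to the free pool."""
        self.free_blocks.extend(self.block_tables.pop(seq_id, []))
        self.lengths.pop(seq_id, None)


alloc = PagedKVAllocator(num_blocks=4, block_size=16)
slots = [alloc.append_token(seq_id=0) for _ in range(20)]   # 20 tokens need 2 blocks
print(alloc.block_tables[0], len(alloc.free_blocks))
alloc.free(0)
```

Because a sequence grabs a new block only when its current one is full, at most one partially filled block per sequence is wasted, instead of reserving memory for the maximum possible length up front.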
2.2 - Flash attention
FlashAttention is a must for every single practitioner. Please simply memorize the full algorithm.
Key idea:
- Instead of storing the full attention matrix in HBM, compute the dot products blockwise, so that all the computation is performed in on-chip SRAM
Key advantage:
- significantly reduced memory usage, so that you can fit a 100k context length by brute force. Yes, there is no fancy algorithm for 100k, just brute force
- in the original paper, the authors only test up to 16k, but 100k is totally doable
- significantly improved throughput, particularly for small models, where a large portion of the FLOPs is in the dot-product operation
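The heart of FlashAttention is the online-softmax trick that lets attention be computed one K/V block at a time without materializing the full score matrix. The sketch below shows only that recurrence, for a single query in plain Python; the real algorithm also tiles queries, keeps tiles in on-chip SRAM, and fuses everything into one CUDA kernel, none of which is modeled here.

```python
import math
import random


def naive_attention(q, keys, values):
    """Reference: materialize every score, then softmax and weighted sum."""
    scale = 1.0 / math.sqrt(len(q))
    s = [scale * sum(a * b for a, b in zip(q, k)) for k in keys]
    m = max(s)
    p = [math.exp(x - m) for x in s]
    z = sum(p)
    return [sum(pi * v[j] for pi, v in zip(p, values)) / z for j in range(len(values[0]))]


def blockwise_attention(q, keys, values, block_size=4):
    """Online softmax: visit K/V one block at a time, keeping only a running max m,
    a running normalizer denom and an unnormalized output acc."""
    scale = 1.0 / math.sqrt(len(q))
    m, denom, acc = -math.inf, 0.0, [0.0] * len(values[0])
    for start in range(0, len(keys), block_size):
        kb = keys[start:start + block_size]
        vb = values[start:start + block_size]
        s = [scale * sum(a * b for a, b in zip(q, k)) for k in kb]
        m_new = max(m, max(s))
        correction = math.exp(m - m_new)       # rescale earlier blocks to the new max
        p = [math.exp(x - m_new) for x in s]
        denom = denom * correction + sum(p)
        acc = [a * correction + sum(pi * v[j] for pi, v in zip(p, vb))
               for j, a in enumerate(acc)]
        m = m_new
    return [a / denom for a in acc]


random.seed(1)
d, n = 16, 37
q = [random.gauss(0.0, 1.0) for _ in range(d)]
keys = [[random.gauss(0.0, 1.0) for _ in range(d)] for _ in range(n)]
values = [[random.gauss(0.0, 1.0) for _ in range(d)] for _ in range(n)]
err = max(abs(a - b) for a, b in zip(blockwise_attention(q, keys, values, 8),
                                      naive_attention(q, keys, values)))
print(f"max abs difference: {err:.1e}")
```

The rescaling factor `exp(m - m_new)` is what makes the blockwise result equal to ordinary softmax attention, up to floating-point rounding.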
2.3 - Flash Decoding
Key idea: instead of using one query to scan the full KV cache, duplicate the query so that different chunks of the KV cache can be scanned in parallel, then combine the partial results
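A minimal sketch of the split-KV idea, for a single query in plain Python. Each chunk returns its local max score, local softmax normalizer, and unnormalized output; a final reduction rescales them to a common max. On a GPU the chunks would run in parallel thread blocks; here they simply run in a loop, and the function names and `num_splits` parameter are hypothetical.

```python
import math
import random


def partial_attention(q, keys, values):
    """Attention over one KV chunk: (max score, sum of exp, unnormalized output)."""
    scale = 1.0 / math.sqrt(len(q))
    s = [scale * sum(a * b for a, b in zip(q, k)) for k in keys]
    m = max(s)
    p = [math.exp(x - m) for x in s]
    out = [sum(pi * v[j] for pi, v in zip(p, values)) for j in range(len(values[0]))]
    return m, sum(p), out


def split_kv_attention(q, keys, values, num_splits=4):
    """Process KV chunks independently, then merge them with a log-sum-exp rescaling."""
    chunk = math.ceil(len(keys) / num_splits)
    parts = [partial_attention(q, keys[i:i + chunk], values[i:i + chunk])
             for i in range(0, len(keys), chunk)]
    m_global = max(m for m, _, _ in parts)
    denom = sum(s * math.exp(m - m_global) for m, s, _ in parts)
    return [sum(o[j] * math.exp(m - m_global) for m, _, o in parts) / denom
            for j in range(len(values[0]))]


random.seed(2)
d, n = 8, 50
q = [random.gauss(0.0, 1.0) for _ in range(d)]
keys = [[random.gauss(0.0, 1.0) for _ in range(d)] for _ in range(n)]
values = [[random.gauss(0.0, 1.0) for _ in range(d)] for _ in range(n)]
print(split_kv_attention(q, keys, values, num_splits=4)[:3])
```

The merge step is cheap compared with scanning the cache, which is why splitting pays off when a single query would otherwise leave most SMs idle on a long KV cache.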
3 - Modeling: architecture and decoding algorithm
Now we enter the space that is more familiar to everyone. We start from standard well-known techniques like distillation and quantization, then dive deep into advanced topics like mixture of experts and speculative decoding.
3.1 - Classical methods
3.1.1 - Distillation and sparse attention
- Distillation: finetune a small model using a larger model’s outputs / logits
- Finetuning on outputs: nowadays everybody distills from GPT and you guys are very good at it, so I just skip this part
- Finetuning on logits / distributions is less explored; some results show faster convergence and better quality
- Sparse and local attention, particularly important for long context
- In the age of small models, this part is very well studied (c.f. Yi Tay’s awesome survey), but it is not clear whether these results hold at larger scale
- In the age of large models, there is little work in this space except Mistral’s sliding window attention
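For sliding window attention, the change is confined to the attention mask: each query sees only the most recent `window` keys, including itself. A minimal sketch of the mask (the function name is hypothetical):

```python
def sliding_window_mask(seq_len: int, window: int):
    """mask[i][j] is True when query i may attend to key j: causal (j <= i)
    and within the last `window` positions (i - j < window)."""
    return [[0 <= i - j < window for j in range(seq_len)] for i in range(seq_len)]


for row in sliding_window_mask(6, 3):
    print("".join("x" if ok else "." for ok in row))
# x.....
# xx....
# xxx...
# .xxx..
# ..xxx.
# ...xxx
```

Since no query ever looks further back than `window` positions, each layer’s KV cache never needs more than `window` entries per sequence, which is what makes this attractive for long context.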
3.1.2 - Quantization
bitsandbytes
- The most fundamental method: quantize the model weights to int8.
- Quantization is nowadays a must for deploying a large model. The good news is that it does not really harm performance
- Yi-34B chat with 4-bit quantization only requires 17GB of memory, with almost the same performance on benchmarks.
- Actually it’s quite fast; you can try it on Hugging Face
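To see what quantizing weights to int8 means mechanically, here is a simplified symmetric absmax scheme applied per weight row. It is an illustration only, not the bitsandbytes implementation (LLM.int8 additionally handles outlier features separately); the function names are hypothetical.

```python
def quantize_row_int8(row):
    """Symmetric absmax quantization of one weight row to integers in [-127, 127]."""
    scale = max(abs(x) for x in row) / 127 or 1.0   # fall back to 1.0 for an all-zero row
    return [round(x / scale) for x in row], scale


def dequantize_row(q_row, scale):
    return [v * scale for v in q_row]


w = [0.12, -0.5, 0.031, 0.49]
q, s = quantize_row_int8(w)
w_hat = dequantize_row(q, s)
print(q, s)
print("max abs error:", max(abs(a - b) for a, b in zip(w, w_hat)))
```

Each weight now takes 1 byte instead of 2 in bf16 (plus one scale per row), and the reconstruction error of every entry is bounded by half a quantization step.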
3.1.3 - Multi-query attention and group-query attention
- In general, multi-query attention significantly speeds up training and inference by simultaneously reducing memory and compute.
- Multi-query attention is also a great example where differences seen in small models disappear in large models: for small models like 7B, multi-query attention is worse than full attention, but when models become as large as 70B, multi-query attention performs basically the same as full attention. LLaMA2 7B uses full attention, and 70B uses GQA.
- Current state-of-the-art large models use multi-query (or group-query) attention by default
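In group-query attention, several query heads share one key/value head, so the KV cache shrinks by the ratio of query heads to KV heads; multi-query attention is the extreme case with a single KV head. A minimal sketch of the head mapping, assuming a LLaMA-2-70B-like configuration of 64 query heads and 8 KV heads (the function name is hypothetical):

```python
def kv_head_for(query_head: int, n_heads: int, n_kv_heads: int) -> int:
    """Index of the K/V head that `query_head` reads under group-query attention.
    n_kv_heads == n_heads is standard multi-head attention; n_kv_heads == 1 is MQA."""
    if n_heads % n_kv_heads:
        raise ValueError("n_heads must be a multiple of n_kv_heads")
    group_size = n_heads // n_kv_heads
    return query_head // group_size


n_heads, n_kv_heads = 64, 8          # assumed LLaMA-2-70B-like head counts
print([kv_head_for(h, n_heads, n_kv_heads) for h in (0, 7, 8, 63)])   # [0, 0, 1, 7]
print(f"KV cache is {n_heads // n_kv_heads}x smaller than with full multi-head attention")
```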
3.2 - Advanced techniques
3.2.1 - Mixture of experts
Pretraining from scratch
- Say we have 7B activated parameters and 34B parameters in total.
- Can we achieve:
- performance similar to 34B
- throughput better than 34B
- latency similar to 7B
- The above goal is made possible by the recently released Mistral MoE; see the discussion on X
As we can see from the above table:
- Performance: the 50B Mistral MoE with a 7B dense part is somewhere near the 34B Yi model and the 67B DeepSeek model
- Inference efficiency: the dense part of the MoE is 7B, with top-2 expert activation
Mistral MoE shows the possibility of achieving the performance of a much larger model at the cost of a much smaller one
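A toy top-2 router in plain Python illustrates why MoE inference cost tracks the activated parameters rather than the total. This is a sketch under assumptions: the router is a single linear layer, the gate weights are a softmax over the two selected logits (as in Mixtral-style routing), and experts are arbitrary callables; real implementations batch tokens per expert on the GPU.

```python
import math


def top2_moe(x, router_weights, experts):
    """Route one token vector x through a top-2 mixture of experts.

    router_weights: one gate vector per expert (an assumed linear router).
    experts: one callable per expert, each mapping a vector to a same-size vector.
    Only the two selected experts run, so this token's FLOPs and weight reads
    scale with 2 experts, not with the total number of experts.
    """
    logits = [sum(w * xi for w, xi in zip(wr, x)) for wr in router_weights]
    top2 = sorted(range(len(logits)), key=lambda e: logits[e], reverse=True)[:2]
    m = max(logits[e] for e in top2)
    gates = [math.exp(logits[e] - m) for e in top2]   # softmax over the two winners only
    z = sum(gates)
    out = [0.0] * len(x)
    for g, e in zip(gates, top2):
        y = experts[e](x)
        out = [o + (g / z) * yi for o, yi in zip(out, y)]
    return out, top2


def make_scaling_expert(c):
    # Hypothetical toy expert: multiplies its input by a constant.
    return lambda v: [c * vi for vi in v]


experts = [make_scaling_expert(c) for c in (1.0, 2.0, 3.0, 4.0)]
router = [[0.0, 0.0], [0.1, 0.0], [1.0, 0.0], [2.0, 0.0]]
y, chosen = top2_moe([1.0, 0.0], router, experts)
print(chosen, y)   # chosen == [3, 2]: experts 0 and 1 never run for this token
```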
But what is the actual compute pattern of MoEs? As indicated by the awesome threads from Dmytro Dzhulgakov:
- When concurrency is low, most of the time is spent loading the two activated experts from memory, which together are smaller than a dense layer, thus requiring less memory access time and giving lower latency
- in other words, for a single query, MoE has lower latency than dense because we need to read fewer parameters from memory
- When concurrency is high, we enter the FLOP-bound regime, but since MoE activates fewer parameters per token than dense, the throughput is higher
- in other words, for many concurrent queries, MoE has higher throughput because it needs fewer FLOPs per token, allowing a larger batch size
Number of queries | Single query | Many queries |
Bound | Memory bound | Compute bound |
MoE’s advantage | Shorter latency | Higher throughput |
Reason | Fewer activated parameters, thus less memory access | Fewer FLOPs per token, thus larger batch size and higher throughput |
Another natural question is: if I have already trained a large dense model, can I convert it into an MoE model?
- MoEfication: decompose a dense model to an MoE model, making it
- as efficient as a small model,
- but as strong as a large dense model.
- Ref:
- Zhang et al. 2021. MoEfication: Transformer Feed-forward Layers are Mixtures of Experts
- Zhang et al. 2023. Emergent Modularity in Pre-trained Transformers
3.2.2 - Early exit
The key idea here is that, for some easy tokens, we do not need to compute all the transformer layers; just some of the layers will be enough.
- for every token, use a gate to decide whether to exit early
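One simple choice of gate is a confidence threshold on the top predicted probability, a simplified variant of the softmax-based confidence measures studied in the CALM paper cited below. A minimal sketch, assuming every layer can be followed by a shared output head (`lm_head` here is a hypothetical callable) and ignoring how the KV cache of skipped layers is filled for later tokens, which real early-exit methods must handle:

```python
import math


def softmax(logits):
    m = max(logits)
    e = [math.exp(x - m) for x in logits]
    z = sum(e)
    return [x / z for x in e]


def early_exit_forward(h, layers, lm_head, threshold=0.9):
    """Apply layers in order; after each one, predict with the (hypothetical) shared
    lm_head and stop once the top probability reaches `threshold`.
    Returns (predicted token id, number of layers actually run)."""
    for depth, layer in enumerate(layers, start=1):
        h = layer(h)
        probs = softmax(lm_head(h))
        best = max(range(len(probs)), key=probs.__getitem__)
        if probs[best] >= threshold or depth == len(layers):
            return best, depth
    raise ValueError("need at least one layer")


def double(h):          # toy "layer"
    return [2.0 * x for x in h]


def identity_head(h):   # toy output head: the hidden state is used directly as logits
    return h


print(early_exit_forward([1.0, 0.0], [double] * 4, identity_head))   # (0, 2)
```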
Ref:
- Schuster et al. 2022. Confident Adaptive Language Modeling
- Bae et al. 2023. Fast and Robust Early-Exiting Framework for Autoregressive Language Models with Synchronized Parallel Decoding
3.2.3 - Blockwise decoding
Speculative decoding
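The key idea is that decoding one token at a time does not fully use the compute power of the SMs, so we want to decode multiple tokens at a time: a small draft model proposes several tokens, and the large model verifies them in one parallel forward pass. The following papers are worth paying attention to: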
- Leviathan et al. 2022. Fast Inference from Transformers via Speculative Decoding
- Chen et al. 2023. Accelerating Large Language Model Decoding with Speculative Sampling
- Liu et al. 2023. Online Speculative Decoding
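These methods build on the same verification rule: the large target model scores the draft tokens in one parallel forward pass, accepts each draft token x with probability min(1, p(x)/q(x)), where p and q are the target and draft probabilities, and on the first rejection resamples from the normalized residual max(0, p − q). This keeps the output distribution identical to sampling from the target model alone. A sketch of that rule in plain Python, with distributions passed in as lists (the function name and argument layout are hypothetical):

```python
import random


def speculative_accept(draft_tokens, p_draft, p_target, rng=random):
    """Verify draft tokens with speculative sampling.

    draft_tokens[i]: token the draft model proposed at position i.
    p_draft[i], p_target[i]: draft / target next-token distributions at position i.
    p_target has one extra entry (the position after the last draft token); all of
    p_target comes from ONE parallel forward pass of the target model.
    Returns the tokens emitted this round (always at least one).
    """
    out = []
    for i, tok in enumerate(draft_tokens):
        q, p = p_draft[i][tok], p_target[i][tok]
        if rng.random() < min(1.0, p / q):     # accept with probability min(1, p/q)
            out.append(tok)
            continue
        # First rejection: resample from the normalized residual max(0, p - q) and stop.
        residual = [max(0.0, pt - pd) for pt, pd in zip(p_target[i], p_draft[i])]
        out.append(rng.choices(range(len(residual)), weights=residual)[0])
        return out
    # Every draft accepted: sample one bonus token from the target's extra distribution.
    last = p_target[len(draft_tokens)]
    out.append(rng.choices(range(len(last)), weights=last)[0])
    return out


rng = random.Random(0)
p_small = [[0.6, 0.4], [0.5, 0.5]]              # draft distributions at 2 positions
p_large = [[0.6, 0.4], [0.5, 0.5], [0.9, 0.1]]  # target distributions, one extra position
print(speculative_accept([0, 1], p_small, p_large, rng))  # both drafts accepted (p == q), plus a bonus token
```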
Using a small draft model has the following disadvantages:
- it is weak, so the rejection rate may be high in some challenging domains
- we should consider the accuracy and overhead tradeoff: a smaller draft model is fast but inaccurate, a larger draft model is accurate but slow
- we need to fit two models on the GPU, but remember we are already running out of GPU memory
So we want to use the large model as its own proposal model, because
- it is strong, so the rejection rate can be reduced
- we only need to keep one model in memory
This is how we motivate Medusa
- use multiple heads to decode multiple tokens at a time
- use the large model itself as the draft model
4 - What we have not yet covered
- Deep water Hardware and MLSys
- I am new to MLSys. I constantly have to ask my friends whether my take on MLSys is correct. So inevitably my take in this space, although I try to cover the most important fundamentals, merely scratches the surface
- There are many more awesome articles covering hardware and MLSys, which I add to the references at the end of this article
- Distributed inference of super large models, c.f. Pope et al. Efficiently Scaling Transformer Inference
- key tech: model sharding, pipeline and tensor parallelism
- Strongly suggest reading the Nvidia blog on this part
- Continuous batching and libraries like DeepSpeed MII
- Linear architectures like Mamba
- Distillation and quantization for mixture of experts
- It seems that MoE experts in general have a higher level of intrinsic sparsity than dense FFNs, and thus may tolerate more aggressive quantization
- It is not so clear how distillation from / between MoE and dense models works, e.g., how to distill a gigantic MoE into a small dense model
- More advanced blockwise decoding
5 - Personal pick
Yes, unfortunately for the abundance of model architecture papers, most of the gain, the real heavy lifting, comes not from these papers but from MLSys. Many fancy papers are simply not realistic because they are not compatible with SOTA MLSys advances like model parallelism and FlashAttention. So, dear fellow researchers, it is time we get our hands dirty and look very carefully at real code.
Yet there are indeed performance gains from the modeling side; below are my favorite techniques, which you should always consider by default:
MLSys | Modeling |
Model parallelism | Multi-query attention |
Flash attention | Speculative decoding |
vLLM / Paged attention | |
Quantization | |
Conclusion
In this post, we review full-stack transformer inference optimization methods, from GPU architecture to MLSys methods, from model architecture to decoding algorithms. We can see that most of the performance gains come from exploiting a single principle: transformer inference is memory bound, so we have additional computation power / FLOPs to release. The optimizations then come either from optimizing memory access, like FlashAttention and Paged Attention, or from releasing the computation power, like Medusa and lookahead decoding.
We believe there is still plenty of room for improvement, from either the MLSys or the modeling perspective. In the upcoming 2024, larger models, longer context, the debut of more open-source MoEs, hardware with higher memory bandwidth and larger memory capacity, and mobile devices with larger DRAM and dedicated compute engines will combine to make more powerful AI accessible and hackable to everyone. A new era is coming.
References
- Nvidia official blog on transformer inference. This blog covers model parallelism, which we do not discuss
- PyTorch official support for transformer inference
- GPT-2 level inference, not so large though
- Deployment of PaLM 540B
- Deployment of Claude 52B
- Awesome GPU guide
- Single device inference. Compute and memory bound. Arithmetic intensity.
- Detailed analysis on the math of transformer inference
- Inference optimization algorithms
- Runtime of 7B / 13B models on 8 * 80G A100 with different parallelism strategies
- Utterance