Full Stack Transformer Inference Optimization Season 2: Deploying Long-Context Models
Yao Fu | Website | Blog | Twitter / X
University of Edinburgh | yao.fu@ed.ac.uk
Released on May 13 2024, Updated on Jun 28 2024
Update Jun 2024: we strongly recommend reading the Mooncake paper and Character AI’s blog for real-world long-context deployment solutions.
 
Why are long-context models important? Because they are the foundations for advanced AI applications such as hour-long video understanding, repository-level coding agents, and lifelong AI companions. Our research objective is to foster an AI-based application ecosystem. For this to happen, we have to reduce the deployment cost of long-context transformers.
 
This is the second season of our transformer inference optimization posts. Our first post discussed generic short-context inference optimization; this post focuses on long-context optimization. We aim to address an ambitious research challenge:
💡
How can we reduce the deployment cost of 1M-context production-level transformers to be as cheap as 4K?
To tackle a problem, we need to understand it first. To that end, we describe a concurrent programming framework for quantitatively analyzing the efficiency challenges in serving multiple long-context requests when GPU high-bandwidth memory (HBM) is limited. We give a detailed analysis of how all the additional computational cost, compared to 4K context, traces back to one single source: the large size of the KV cache. We use a 34B GPT-3.5-level model with 50K context on A100 NVLink as a running example, and describe how its large KV cache causes four types of deployment challenges:
  1. prefilling long inputs takes much more compute time and GPU memory than short inputs;
  2. after prefilling, the large KV cache residing on the GPU HBM substantially restricts the number of concurrent users being served;
  3. during decoding, repeatedly reading the KV cache from HBM to SM largely increases latency;
  4. when KV cache memory overflows, swapping it from HBM to DDR causes significant context switching latency.
We further analyze how existing efforts address the deployment challenges from these four perspectives and identify possibilities for combining them to build end-to-end efficient systems. We hope our work offers a foundational framework for analyzing long-context transformer deployment and identifies important directions towards reducing the inference cost of 1M context to be as cheap as 4K.
 
 

1 - A concurrent programming framework under limited GPU HBM size

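Consider a 30+B, 100K-context, GPT-3.5-quality open-source model like Qwen or Yi. The difference between the KV cache for 4K vs. 100K context is:
$$\text{KV cache} = 2 \times \text{layers} \times \text{kv heads} \times \text{head dim} \times \text{bytes} \times \text{tokens}$$
$$4\text{K}: \; 2 \times 60 \times 8 \times 128 \times 2 \times 4\text{K} \approx 0.91\ \text{GB}$$
$$100\text{K}: \; 2 \times 60 \times 8 \times 128 \times 2 \times 100\text{K} \approx 22.9\ \text{GB}$$
Here we use the Yi-34B 200K configuration (60 layers, 8 KV heads and a head dimension of 128). Suppose we use 2 × 80GB A100s with tensor parallelism to serve this model in bf16. Then we have 2 × 80 - 34 × 2 = 92GB of spare space for storing the KV cache. At first glance, we immediately see that under this setting we can serve 100+ concurrent users at 4K context, but only about 4 users at 100K context. In fact, when actually deploying the model for serving multiple users, the workflow is usually like this: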
Figure 1. A concurrent programming framework for serving multiple long-context user requests under limited GPU HBM size. Four key factors collectively determine the overall throughput of user interaction sessions:
(1) Concurrency is bounded by the HBM size. The number of concurrent users being served is bounded by the size of the GPU high-bandwidth memory (HBM).
(2) Prefilling is compute bound. The latency to the first generated token, i.e., the prompt prefilling time, is bounded by the floating point operations per second (flops) of the GPU.
(3) Decoding is memory bound (below the critical batch size). The latency (generated tokens per second) of autoregressive decoding is bounded by the bandwidth of the HBM.
(4) Context switching is PCIe bound. Offloading user 1's KV cache to CPU DDR and loading user 2's KV cache into HBM is bounded by the PCIe bandwidth.
All efficiency challenges from these four key factors eventually trace back to the size of the KV cache.
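The following is a minimal Python sketch of the KV cache arithmetic above. It assumes the Yi-34B shape quoted in the text: 60 layers, 8 KV heads, head dimension 128, and bf16. The helper names are hypothetical. We divide by 2^30 because that reproduces the post's 0.91 GB (4K) and 11 GB (50K) figures; this unit choice is itself an assumption.

```python
N_LAYERS = 60    # transformer layers (Yi-34B 200K config)
N_KV_HEADS = 8   # key/value heads under grouped-query attention
HEAD_DIM = 128   # dimension of each head
BYTES = 2        # bf16


def kv_bytes_per_token() -> int:
    # The leading 2 counts one key and one value vector per layer and KV head.
    return 2 * N_LAYERS * N_KV_HEADS * HEAD_DIM * BYTES


def kv_cache_gb(n_tokens: int) -> float:
    # Report in units of 2**30 bytes, matching the rounding used in the text.
    return n_tokens * kv_bytes_per_token() / 2**30


for n in (4_000, 50_000, 100_000):
    print(f"{n:>7,} tokens -> {kv_cache_gb(n):6.2f} GB")
```

This gives 245,760 bytes (about 0.23 MB) per token. That works out to roughly 0.92 GB at 4K, 11.4 GB at 50K and 22.9 GB at 100K. The cache grows linearly with context length, while the weights stay fixed at about 68 GB.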

1.1 - Concurrent user interaction sessions and preferences

In a typical interaction session, a user starts with a prompt consisting of a long document followed by a question, and feeds it to the model. The model receives the initial prompt and prefills it into the KV cache. The user waits during prefilling until the first token is generated, and prefers this wait to be short. After prefilling, the model starts autoregressive decoding. The user reads the output as it is decoded and prefers decoding to be faster than human reading speed. After the model finishes decoding, the user continues to read the response, thinks, maybe takes a sip of coffee, then starts to type the next question. Follow-up prompts are usually much shorter than the first prompt. The first prompt typically contains the long context (a book or video), while follow-up prompts are typically question-only. While the first user is reading the model's response and thinking about what to ask next, the model is essentially idle. If another user asks a question at that moment, the model can context switch: it offloads the first user's KV cache to CPU DDR memory to free HBM space for the second user's KV cache. The two users ask follow-up questions interactively until their sessions end.
 

1.2 - Session-based throughput as an end-to-end objective

We consider the situation where multiple users simultaneously interact with the model. Assume that, on average, a user session consists of a 50K-token document or video followed by 5 rounds of questions. After receiving the answer to each question, the user spends 1 minute reading the answer and thinking about the next question. Our objective is to maximize a session-based throughput defined as:
$$\text{session-based throughput} = \frac{\text{number of user interaction sessions completed}}{\text{time period}}$$
Note that this session-based throughput objective is different from existing token-based throughput (i.e., number of prefilled or decoded tokens in a given period). As we will discuss soon, token-based throughput is only part of the problem. Our session-based throughput, i.e., the number of concurrent user interactions in a given period, is an end-to-end objective, because we not only consider prefilling and decoding, but also consider memory restrictions and context switching.
 

1.3 - Compute v.s. memory boundedness, arithmetic intensity and critical batch size

One important observation about transformer inference is that prefilling is usually bounded by GPU compute power, i.e., the flops, while decoding is bounded by HBM bandwidth. We say an operator is compute bound if most of the time spent on it goes to computing on the GPU's streaming multiprocessors (SMs, where the GPU performs block-wise parallel computation). We say an operator is memory bound if most of the time goes to moving data from memory to the SMs rather than computing on them. Whether an operator is compute or memory bound depends on its arithmetic intensity: how many floating point operations (FLOP) are performed per byte of memory access (IO).
$$\text{arithmetic intensity} = \frac{\text{FLOP}}{\text{IO (bytes)}}$$
The higher the level of parallelism, the more FLOP per memory access. That makes the operator more likely to be compute bound, and the better we utilize the hardware. On a given GPU, the critical arithmetic intensity (i.e., the level of parallelism) at which an operator changes from memory bound to compute bound is the ratio of the GPU's peak flops to its memory bandwidth. For A100 it is:
$$\frac{312\ \text{TFLOPS (bf16)}}{2\ \text{TB/s}} = 156$$
For transformers, the level of parallelism is approximately how many tokens we feed into it, i.e., the batch size. This is to say, for A100, when our batch size is larger than 156 tokens, e.g., during prefilling the prompt has 50K tokens, we are compute bound and fully utilizing A100's compute power. When our batch size is smaller than 156 tokens, e.g., during decoding we only decode one token at a forward pass, we are memory bound and not fully utilizing A100's compute power.
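The Python sketch below makes the critical batch size concrete. It computes the arithmetic intensity of one bf16 matrix multiply y = x @ W, the workhorse operation of a transformer layer. The 7168 × 7168 weight shape is an assumption (Yi-34B's hidden size, not stated in this post). The IO count assumes x and W are each read from HBM once and y is written once.

```python
PEAK_FLOPS = 312e12   # A100 bf16 dense tensor-core peak, FLOP/s
HBM_BW = 2e12         # A100 HBM bandwidth, bytes/s (rounded)
CRITICAL_INTENSITY = PEAK_FLOPS / HBM_BW   # 156 FLOP per byte


def matmul_intensity(batch: int, d_in: int, d_out: int, bytes_per_el: int = 2) -> float:
    """FLOP per byte for y = x @ W with x: [batch, d_in], W: [d_in, d_out]."""
    flop = 2 * batch * d_in * d_out   # one multiply and one add per multiply-accumulate
    io = bytes_per_el * (batch * d_in + d_in * d_out + batch * d_out)  # read x, W; write y
    return flop / io


for batch in (1, 64, 1024):
    ai = matmul_intensity(batch, 7168, 7168)
    bound = "compute" if ai > CRITICAL_INTENSITY else "memory"
    print(f"batch={batch:5d}  intensity={ai:7.1f}  -> {bound} bound")
```

With a single decoded token, the intensity is about 1 FLOP per byte, far below 156, so decoding is memory bound. With a 1024-token prefill chunk, it is several hundred, so prefilling is compute bound. While the batch is much smaller than the hidden size, the intensity is roughly equal to the batch size. That is why the text reads the critical intensity as a critical batch size.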
 

1.4 - Prefilling

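Now we analyze exactly how long prefilling takes on A100. Prefilling is compute bound (the context length is longer than 156 on A100), so its theoretical peak latency is
$$\text{prefilling latency} = \frac{\text{FLOP to prefill the prompt}}{\text{GPU flops}}$$
For a prompt of 50K context length it is
$$\frac{\text{FLOP}(50\text{K})}{312\ \text{TFLOPS}} \approx 14.1\ \text{seconds}$$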
Since 14.1 seconds is the theoretical peak, in Fig.1 we round it up to 20 seconds to account for implementation overhead. This means the actual implementation may achieve 14.1 / 20 ≈ 70% of the theoretical peak performance, which is common for CUDA programming on A100.
If the context length is 4K instead of 50K, repeating the above computation gives a latency of 0.89 seconds. The difference is
$$14.1\ \text{s} - 0.89\ \text{s} \approx 13.2\ \text{s}$$
This 13-second gap, rooted in the additional FLOP from the long context, is what we eventually want to close.
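The post does not spell out its FLOP count, so the Python sketch below is one plausible reconstruction, not the author's exact formula. It counts about 2 FLOP per parameter per token for the dense layers, plus an attention term that is quadratic in the context length. Two inputs are assumptions: the hidden size 7168 (from Yi-34B's public config) and the coefficient of 1 on n² · d_model · layers. A strict causal count with 2 FLOP per multiply-add would put that coefficient closer to 2, which would raise the long-context estimate.

```python
N_PARAMS = 34e9      # model parameters
D_MODEL = 7168       # assumed Yi-34B hidden size
N_LAYERS = 60
PEAK_FLOPS = 312e12  # A100 bf16 dense peak, FLOP/s


def prefill_flop(n: int, attn_coeff: float = 1.0) -> float:
    dense = 2 * N_PARAMS * n                             # matmuls against the weights, linear in n
    attention = attn_coeff * n * n * D_MODEL * N_LAYERS  # token-to-token scores, quadratic in n
    return dense + attention


def prefill_latency_s(n: int, attn_coeff: float = 1.0) -> float:
    # Compute bound: latency = FLOP / peak flops (theoretical best case).
    return prefill_flop(n, attn_coeff) / PEAK_FLOPS


for n in (4_000, 50_000):
    print(f"{n:>6} tokens: {prefill_latency_s(n):5.2f} s")
print(f"gap: {prefill_latency_s(50_000) - prefill_latency_s(4_000):.1f} s")
```

Under these assumptions, this gives about 0.89 s at 4K and 14.3 s at 50K, close to the 0.89 s and 14.1 s above. The gap is about 13 s. Because of the quadratic term, going from 4K to 50K (12.5x more tokens) costs about 16x more prefill time.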
 

1.5 - Decoding

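Now we analyze exactly how long decoding takes. Decoding is memory bound (the batch size is less than 156 on A100), so its theoretical peak latency is
$$\text{decoding latency} = \frac{\text{bytes read from HBM}}{\text{HBM bandwidth}}$$
For decoding, one forward pass means reading the model weights and the whole KV cache from HBM once:
$$\frac{68\ \text{GB (weights)} + 11\ \text{GB (KV cache)}}{2\ \text{TB/s}} \approx 39\ \text{ms per token}$$
We assume that on average the model generates one screen of tokens, i.e., about 250 tokens; users typically prefer a generation length of about one screen. The peak latency is then
$$250 \times 39\ \text{ms} \approx 9.8\ \text{seconds}$$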
Since 9.8 seconds is the theoretical peak, in Fig.1 we round it to 12 seconds to account for implementation overhead. If the sequence length is 4K, its KV cache is only 0.91GB and the decoding latency drops to 8.5 seconds. Yet if the sequence length increases to 200K, the KV cache becomes 44GB and the latency increases to 14 seconds. The relative latency increase tracks the size of the KV cache relative to the model size, and we eventually want to close this gap.
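Here is a minimal Python sketch of the decoding estimate. As the text does, it assumes every generated token reads the full bf16 weights (68 GB) plus the whole KV cache from HBM once, at about 2 TB/s, and it ignores compute and kernel overhead. The KV sizes plugged in are the text's rounded figures.

```python
WEIGHTS_GB = 68.0        # 34B parameters x 2 bytes (bf16)
HBM_GB_PER_S = 2000.0    # A100 HBM bandwidth, about 2 TB/s
TOKENS_PER_ROUND = 250   # about one screen of output


def decode_latency_s(kv_gb: float, n_tokens: int = TOKENS_PER_ROUND) -> float:
    # Memory bound: each step streams weights + KV cache from HBM to the SMs.
    per_token_s = (WEIGHTS_GB + kv_gb) / HBM_GB_PER_S
    return per_token_s * n_tokens


for label, kv_gb in (("4K", 0.91), ("50K", 11.0), ("200K", 44.0)):
    print(f"{label:>4}: {decode_latency_s(kv_gb):5.2f} s per 250-token round")
```

This yields about 8.6 s, 9.9 s and 14 s, within a few tenths of a second of the text's 8.5 s, 9.8 s and 14 s; the exact values depend on how the bandwidth is rounded. The weights dominate the bytes moved, which is why decoding is the metric least sensitive to context length.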
 

1.6 - Concurrency control and context switching

 
Another important consideration is that when the KV cache becomes large, the number of concurrent users whose caches the GPU HBM can hold is
$$\text{concurrency} = \frac{\text{HBM size} - \text{model weights}}{\text{KV cache size per user}}$$
This means that concurrency is bounded by the size of the HBM. Continuing with our 34B 50K example, if we deploy it on one 80GB A100, we can serve only one user (Fig.1). But if the context is 4K, the KV cache is only about 1GB, and we can concurrently serve about 20 users.
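The concurrency bound described above can be sketched in Python as follows. The function name is hypothetical. The sketch simplifies by assuming that all HBM not taken by the weights is available for KV caches, ignoring activations and allocator fragmentation.

```python
import math


def max_concurrent_users(hbm_gb: float, weights_gb: float, kv_gb_per_user: float) -> int:
    # HBM left after loading the weights, split among per-user KV caches.
    spare_gb = hbm_gb - weights_gb
    return max(0, math.floor(spare_gb / kv_gb_per_user))


print(max_concurrent_users(80, 68, 11.0))    # one A100, 50K context   -> 1
print(max_concurrent_users(160, 68, 0.91))   # two A100s, 4K context   -> 101
print(max_concurrent_users(160, 68, 22.9))   # two A100s, 100K context -> 4
```

The per-user KV cache grows linearly with context length, so concurrency falls inversely with it. The same two GPUs that hold about 100 users at 4K hold only about 4 at 100K.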
When a second user comes to ask a question about a long document, we need to make room for their KV cache by context switching: offloading the first user's KV cache to CPU memory and loading the second user's KV cache (Fig.1). This induces the context switching overhead:
$$\text{context switching} = \frac{\text{KV cache offloaded} + \text{KV cache loaded}}{\text{PCIe bandwidth}}$$
That is, the context switching overhead is bounded by the PCIe bandwidth, i.e., how fast the GPU HBM is connected to the CPU DDR. Suppose we use PCIe gen 4 at 20GB per second. Then the context switching overhead for 2 users of 50K context is:
$$\frac{11\ \text{GB} + 11\ \text{GB}}{20\ \text{GB/s}} = 1.1\ \text{seconds}$$
In Fig.1 we round the 1.1 seconds to 2 seconds to account for engineering overhead. As mentioned earlier, in our setting we can serve 20 users at 4K context length without context switching, because the HBM is large enough to hold their KV caches. If we want to increase our 50K concurrency to 20, the overall context switching overhead also increases with concurrency:
$$20 \times 1.1\ \text{seconds} = 22\ \text{seconds}$$
This 22-second overhead does not exist in the 4K context regime but becomes problematic in the long-context regime.
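Here is a short Python sketch of the context-switching cost under the text's assumptions: a 20 GB/s PCIe gen 4 link, 11 GB of KV cache per 50K user, and one full swap (offload plus load) per switch. Counting one switch per additional user follows the text's 20 × 1.1 s estimate.

```python
PCIE_GB_PER_S = 20.0   # PCIe gen 4 bandwidth assumed in the text
KV_GB_50K = 11.0       # per-user KV cache at 50K context


def context_switch_s(kv_out_gb: float, kv_in_gb: float) -> float:
    # Offload the current user's cache to CPU DDR, then load the next user's into HBM.
    return (kv_out_gb + kv_in_gb) / PCIE_GB_PER_S


one_switch = context_switch_s(KV_GB_50K, KV_GB_50K)
print(f"one switch: {one_switch:.1f} s")          # 1.1 s
print(f"20 users:   {20 * one_switch:.1f} s")     # 22.0 s
```

The cost scales with the bytes swapped. It therefore grows linearly with both the context length and the number of users rotated through the GPU.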
 

1.7 - Summary so far

We have discussed most of the details of deploying long-context transformers, using the 34B model at 50K context as the running example. We see that the overall performance, measured by the number of user interaction sessions in a given period, decomposes into four important metrics:
  • prefilling latency bounded by the GPU flops;
  • decoding latency bounded by the HBM bandwidth;
  • level of concurrency bounded by the size of the HBM;
  • context switching overhead bounded by the GPU-CPU connection bandwidth, i.e., the PCIE.
 
Below is a table summarizing the difference between 50K and 4K:
| Metric | 50K | 4K | Gap |
|---|---|---|---|
| KV size | 11GB | 1GB | 10GB |
| Concurrency | 1 user | 20 users | 19 users |
| Prefilling | 14.1 seconds | 0.89 seconds | 13.2 seconds |
| Decoding, 1 round | 9.8 seconds | 8.5 seconds | 1.3 seconds |
| Decoding, 5 rounds | 49 seconds | 42.5 seconds | 6.5 seconds |
| Context switching, 1 user | 1.1 seconds | 0 | 1.1 seconds |
| Context switching, 20 users | 22 seconds | 0 | 22 seconds |
 
In the next sections, we will discuss how these metrics change with context length and hardware architecture, and show that the bottleneck eventually traces back to the size of the KV cache.
 

2 - Factors that strongly influence the performance metrics

We start from two basic factors: context length and hardware architecture. As the context length increases from 4K to 50K, the four metrics (prefilling, decoding, concurrency and context switching) change at different rates (linear, inverse, and quadratic). We further show that tensor parallelism improves concurrency, prefilling and decoding, but does not improve context switching. Overall, the performance looks like this:
Figure 2. First row: how context length changes the four key performance metrics. Increasing the context length from 4K to 50K inversely reduces concurrency, quadratically increases prefilling latency, linearly increases context switching overhead, and slightly (but also linearly) increases decoding latency. Second row: how different generations of hardware influence the key performance metrics. Concurrency is measured by the number of concurrent users; prefilling, decoding, and context switching latencies are measured in seconds.
 

2.1 - Context length

As shown in the first row of Fig.2, we compute the theoretical peak performance of the four metrics for the context length from 4K to 50K for our Yi 34B running example using the equations discussed in the previous section. We observe:
  • concurrency decreases inversely with longer context length;
  • prefilling latency increases quadratically with longer context length;
  • in comparison, decoding latency and context switching overhead increase only linearly with longer context length. Decoding latency is the least affected factor because the 50K KV cache is still much smaller than the model parameters (11GB vs. 68GB).
In general, concurrency and prefilling are the two most severely influenced factors.
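The growth rates above can be checked with the compact Python sketch below. It recomputes all four metrics at 4K and 50K and prints their ratios. Every constant is an assumption carried over from Section 1: the Yi-34B shape, one A100 with 12 GB left after the weights, 312 TFLOPS, 2 TB/s, and 20 GB/s PCIe. The attention FLOP term uses an assumed n² · d_model · layers approximation, with d_model = 7168 taken from Yi-34B's public config, as a rough stand-in for the post's prefilling formula.

```python
KV_GB_PER_TOKEN = 2 * 60 * 8 * 128 * 2 / 2**30   # Yi-34B shape, bf16
PEAK_FLOPS, HBM_GB_PER_S, PCIE_GB_PER_S = 312e12, 2000.0, 20.0
N_PARAMS, D_MODEL, N_LAYERS = 34e9, 7168, 60
WEIGHTS_GB, SPARE_GB = 68.0, 12.0                # one 80 GB A100


def metrics(n: int) -> dict:
    kv_gb = n * KV_GB_PER_TOKEN
    flop = 2 * N_PARAMS * n + n * n * D_MODEL * N_LAYERS
    return {
        "concurrency": SPARE_GB / kv_gb,                         # unrounded, ~1/n
        "prefill_s": flop / PEAK_FLOPS,                          # linear + quadratic in n
        "decode_s": 250 * (WEIGHTS_GB + kv_gb) / HBM_GB_PER_S,   # mild linear growth
        "switch_s": 2 * kv_gb / PCIE_GB_PER_S,                   # linear in n
    }


short, long_ctx = metrics(4_000), metrics(50_000)
for name in short:
    print(f"{name:12s} 50K / 4K = {long_ctx[name] / short[name]:6.2f}")
```

The ratios come out near 0.08 for concurrency (inverse), 16 for prefilling (super-linear, because of the quadratic term), 1.15 for decoding (weights dominate), and 12.5 for context switching (linear). This matches the ordering described above.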
 

2.2 - Hardware architecture

Can we improve performance by simply using more advanced hardware? In the second row of Fig.2, we show how performance improves with hardware advancements. We observe:
  • concurrency linearly increases with the size of the HBM;
  • prefilling latency inversely reduces with the increased flops when upgrading the device from 4090 to A100 to H100;
  • decoding latency inversely reduces with the increased memory bandwidth;
  • context switching overhead inversely reduces with the increased PCIE bandwidth.
Note that the numbers we use are based on the newest hardware as of May 2024. Even with the newest hardware, the cost gap between 50K and 4K is not closing. That is to say, we cannot count on hardware advances to reduce the cost of serving long-context models; we have to make algorithmic innovations.
 

2.3 - Using multiple GPUs by tensor parallelism

As discussed in Kipply’s blog, we can use multiple devices to accelerate inference with negligible communication overhead. In general:
  • Linearly increasing the number of devices to 2, 4, and 8 introduces more HBM space, thus linearly increasing concurrency.
  • Since we equally divide the computation on multiple devices, the prefilling and decoding latency also reduces inversely with the number of GPUs.
  • However, tensor parallelism cannot reduce the context switching overhead, because the PCIe bandwidth between DDR and HBM is shared by all devices.
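Here is a rough Python sketch of how the four metrics move with tensor parallelism. It rests on idealized assumptions that are ours, not measurements:
- 80 GB of HBM per GPU;
- the 68 GB of weights sharded with no duplication;
- perfect division of compute and HBM traffic;
- a single 20 GB/s PCIe path to CPU DDR, shared by all GPUs.

The single-GPU latencies are the 50K figures from Section 1.

```python
def tp_metrics(n_gpus: int, kv_gb: float = 11.0) -> dict:
    spare_gb = 80.0 * n_gpus - 68.0          # total HBM minus sharded bf16 weights
    return {
        "concurrency": int(spare_gb // kv_gb),  # more HBM -> more resident KV caches
        "prefill_s": 14.1 / n_gpus,             # compute split evenly across GPUs
        "decode_s": 9.8 / n_gpus,               # HBM reads split evenly across GPUs
        "switch_s": 2 * kv_gb / 20.0,           # shared PCIe link: no speedup
    }


for g in (1, 2, 4, 8):
    m = tp_metrics(g)
    print(f"{g} GPU(s): {m['concurrency']:2d} users, prefill {m['prefill_s']:5.2f} s, "
          f"decode {m['decode_s']:4.2f} s, switch {m['switch_s']:.1f} s")
```

Concurrency grows roughly linearly with the GPU count (1, 8, 22 and 52 users here). Prefilling and decoding shrink as 1/n, while the context-switching cost stays at 1.1 s.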
 

3 - Compressibility analysis and existing work

So far we have made the following important observations when comparing 50K to 4K:
  • To prefill the long input and produce the KV cache, the prefilling latency increases from 0.89 to 14.1 seconds;
  • Because the large KV cache resides in GPU memory, the concurrency drops from about 20 to 1;
  • During decoding, repeatedly loading the KV cache increases the latency from 8.5 to 9.8 seconds;
  • The large KV cache induces expensive context-switching overhead: for 20 users it takes about 22 additional seconds (Section 1.6).
These four factors collectively induce significant cost in terms of the end-to-end session-based throughput.
If we could losslessly reduce the prefilling time and compress the KV cache, we could significantly reduce the cost of serving long-context models. Our eventual goal is to make the deployment of 1M context as cheap as 4K, and 4K tokens take about 1GB of KV cache. Our observations then point to one key research problem:
💡
How can we efficiently compress a 1M-token KV cache to 1GB in a lossless way?
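As a quick check on the size of this target, the Python sketch below computes the uncompressed KV cache for 1M tokens and the compression ratio needed to fit the 4K budget. It reuses the Yi-34B KV shape from Section 1; this is an assumption, and other models differ.

```python
KV_BYTES_PER_TOKEN = 2 * 60 * 8 * 128 * 2   # K and V, 60 layers, 8 KV heads, dim 128, bf16


def kv_gb(n_tokens: int) -> float:
    return n_tokens * KV_BYTES_PER_TOKEN / 2**30


full_1m = kv_gb(1_000_000)    # uncompressed 1M-token cache
budget = kv_gb(4_000)         # what a 4K context costs today
print(f"1M-token KV cache: {full_1m:.0f} GB")             # 229 GB
print(f"compression needed: {full_1m / budget:.0f}x")     # 250x
```

The cache is linear in tokens, so the required ratio is simply 1M / 4K = 250x. Reaching it may require combining gains along several dimensions (layer, head, token, hidden).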
 
We first note that, without any compression, storing 1M tokens as bytes takes only about 3 - 10MB of disk storage, depending on the size of the tokenizer’s vocabulary. So 1GB is more than enough to store the full information of the input tokens. The problem is how to make their compressed representations usable by large transformers.
Practitioners usually test the model on a variety of long-context tasks to check whether the compression is lossy. Among these, the Needle-in-a-Haystack test serves as an entry barrier: it asks the model to precisely retrieve a given piece of information placed at an arbitrary location in a long context. If a model cannot pass this test, we do not trust it to do harder tasks.
Unfortunately, it seems that two important model families, state-space models (e.g., Mamba) and linear attention (e.g., LongT5), cannot pass the needle test, so we do not include them in our discussion. Our recent work shows that there exists a set of special attention heads responsible for retrieving important information from the context. This discovery indicates that, at least for some layers and some heads, full attention over most of the input tokens should be retained; these attention heads may not be very compressible. Below we discuss the compressibility of the KV cache along its four dimensions (layer, head, token and hidden) and how existing work improves long-context inference. C: concurrency, P: prefilling, D: decoding, S: context switching:
| Dimension | Desc. | Improves | Needle? |
|---|---|---|---|
| Layer | Early exit based on estimated confidence | C, P, D, S | ? |
| Layer | Conditionally reducing computation on some layers | C, P, D, S | ? |
| Layer | Skipping some layers then verifying | C, P, D, S | ? |
| Layer | Using only one global KV cache | C, P, D, S | |
| Head | Reusing KV cache for groups of heads | C, D, S | |
| Head | Removing non-retrieval heads | C, D, S | |
| Head | Using latent heads | C, P, D, S | |
| Token | Dropping insignificant tokens after prefilling | C, D, S | ? |
| Token | Identifying important tokens during prefilling | C, D, S | ? |
| Token | Dynamically merging tokens | C, P, D, S | ? |
| Token | Identifying important tokens based on user questions | D | |
| Token | Speculative decoding for long context | D | |
| Hidden | KV cache quantization | C, D, S | ? |
| Hidden | Weight and KV cache quantization | C, D, S | ? |

3.1 - Layer

For the layer dimension, the basic hypothesis is that some tasks may not require full-depth computation. Skipping some layers during prefilling could benefit all four metrics, because it simultaneously reduces prefilling flops and the size of the KV cache. In fact, judging from the results of existing works like Memorizing Transformers and YOCO, the layer dimension may be radically reduced. It might be possible to keep the KV cache of only one layer for long-context tasks, a 1/60 compression ratio.
 

3.2 - Head

For the head dimension, the basic hypothesis is that some heads are specialized for retrieval and long-context capabilities, so it may be possible to retain the retrieval heads and prune the others. Note that head pruning typically happens after prefilling. It therefore only improves decoding, concurrency and context switching, while prefilling remains as expensive as before. In general, there seems to be a very high level of sparsity along the head dimension, and the number of heads might be radically reduced to a very small number. For example, we show that there are fewer than 20 strongest retrieval heads, which could induce a 20 / 1024 compression ratio.
 

3.3 - Token

For the token dimension, the basic hypothesis is that if the information of a token can be inferred from its context, we can compress this token by either dropping it or merging it with its neighbors. Most compression along the token dimension does not improve prefilling much, but it typically improves concurrency, decoding and context switching. Currently, the token dimension seems less sparse than the layer and head dimensions, because most tokens have to be retained for precise retrieval. We have not yet seen any work showing the potential of more than 50% compression along the token dimension.
 

3.4 - Hidden

There is not much work on reducing the hidden dimension except for quantization, presumably because the hidden size per head is already 128, too small to be reduced further. Yet it may still be worth applying dimension reduction like LoRA to the KV cache. This is particularly promising given the recent progress from DeepSeek V2, which introduces a LoRA-like idea that effectively reduces the size of the KV head.
 
An important observation here is that many existing works may emphasize only one aspect of the problem. For example, TriForce [18] only considers decoding latency, using speculative decoding. It does not make the KV cache smaller, and it even trades off increased GPU memory consumption from the draft model. Many existing works are also orthogonal, so their advantages in different aspects may join forces. For example, if we could reduce the KV cache to only 1 layer or 10 heads and keep only 50% of the tokens, we would have about 1000x performance improvements. This naturally leads to a call for research:
💡
Can we integrate existing efforts into an end-to-end system and push full-stack optimization?
 

4 - Application: analyzing Gemini 1.5 Flash 1M context QA

In this section, we use a textbook QA example to directly estimate the prefilling, decoding and context switching latency through Google AI Studio. We show how our analytical framework identifies concurrency and context switching as the two major system bottlenecks. We use Gemini 1.5 Flash, the latest (May 15 2024) state-of-the-art long-context understanding language model, for our experiments.
 
I use the classic textbook Computer Architecture: A Quantitative Approach as an example. After tokenization, it is about 912,936 tokens.
Prefilling these 912K tokens takes about 64 seconds before the model starts to generate the response.
These 64 seconds are not satisfactory: the user basically has to sit idle. Bringing this latency down to within 5 seconds should be the next research direction for the Gemini team.
 
After the first round of conversation, there are 913,808 tokens in total. This means that decoding produced about 913,808 - 912,936 = 872 tokens. The final time is 141 seconds, so decoding takes 141 - 64 = 77 seconds in total, which is 872 / 77 ≈ 11 tokens per second. This rate is faster than my own reading speed.
Then I asked a follow-up question about GPUs. The model takes 56 seconds to start generating the response. This latency combines context switching and prefilling my follow-up question. Since prefilling this follow-up question is quite fast, most of the time is probably spent on context switching.
 
Then the model takes about another 93.3 - 56.6 = 36.7 seconds to decode the answer to the follow up question.
 
 
 
Then I realized that the model uses the phrase “delve into” instead of “dive into”, and I was not sure how the two phrases compare. So I asked the model to explain the difference between the two. The problem here is that it still takes the model 60 seconds to reload the previous KV cache, yet the model does not need the previous context to answer this question at all!
Here we see that whether the model should use the input information or its internal knowledge is not only a modeling problem, but also interplays with deployment efficiency. This reemphasizes the need for algorithm-system co-design, which is also the major focus of this post.
In summary, we list the comparison between the real-world Gemini 1.5 Flash performance (which is after sophisticated optimization) and our theoretical peak performance of Yi 34B 200K (but without any optimization).
| | Gemini 1.5 Flash (highly optimized), 1M context | Theoretical Yi 34B (no optimization at all), 200K context |
|---|---|---|
| Prefilling | 64 seconds | 99 seconds |
| Decoding | 77 seconds | 41 seconds |
| Context switching | 56 seconds | 45 seconds |
These numbers are not identical, but they are of the same order of magnitude, which supports our estimate. Gemini’s 64-second prefilling time for 1M tokens is still too long for human preference. Even so, it is a significant improvement over our 99-second estimate for 200K tokens.
For Gemini, decoding runs at about 11 tokens per second, already faster than human reading speed, so the directions for optimization should be prefilling latency and context switching. As a general user, I can understand that reading a book takes a while. So the prefilling latency, although over a minute (64 s), is acceptable, and I won't blame it too much. But a context switching latency of 56 seconds is definitely too long to tolerate: one can wait for the session to start, but not for every round of the conversation. So eventually, we identify concurrency and context switching as the current major system bottlenecks. Specifically:
  • Concurrency is more related to the overall deployment cost;
  • Context-switching is more related to user preference.

5 - Conclusion: towards full-stack optimization of an end-to-end system

In this post, we give a detailed analysis of the challenges in deploying long-context transformers. Our eventual objective is to reduce the serving cost of 1M context to be as cheap as 4K, so that we can democratize emerging AI applications like video understanding and generative agents. We describe a concurrent programming framework to illustrate the end-to-end throughput of user interaction sessions. We decompose it into four key performance metrics: concurrency, prefilling, decoding, and context switching. We discuss how common factors influence the four metrics and how existing works focus on different metrics. We believe there are great research opportunities to integrate existing efforts into a strong end-to-end long-context serving system. We hope this work can serve as a framework for full-stack long-context inference optimization.
 
Remarks: we are currently researching algorithm-system co-design for improving long-context deployment efficiency. We welcome comments, pointers to existing works, criticism of the limitations of our analysis, and discussion of all related topics. If you are interested in this direction, please get in touch!
 
