Why S4 is Good at Long Sequences: Remembering a Sequence with Online Function Approximation
 
Yao Fu
University of Edinburgh
yao.fu@ed.ac.uk
Feb 01 2022
 
The Structured State Space for Sequence Modeling (S4) model achieves impressive results on the Long Range Arena benchmark, beating previous methods by a substantial margin. However, it is written in the language of control theory, ordinary differential equations, function approximation, and matrix decomposition, which can be hard to follow for a large portion of researchers and engineers from a computer science background. This post aims to explain the math in an intuitive way, providing an approximate feeling/ intuition/ understanding of the S4 model: Efficiently Modeling Long Sequences with Structured State Spaces. ICLR 2022
 
We are interested in the core question:
💡
Why is S4 good at modeling long sequences, and where does its magic matrix $A$ come from?
 
Skipping all the math, the most straightforward answer is:
💡
Because it uses the HiPPO matrix $A$ to remember all the history of the sequence
 
By “remember all the history” (which is mentioned in multiple places in the paper), we actually mean:
💡
Because it encodes all the history of the sequence into a hidden state (say a 500-dimensional vector); with this vector, we can literally reconstruct the full sequence (even if the sequence is long, say a 16000-length array).
 
Or to be more concise, we say:
💡
With S4, we can use a 500-dimensional hidden state to reconstruct the 16000-length input sequence
 
Or similarly:
💡
With S4, we compress the 16000-length input sequence into a 500-dimensional hidden state
 
This is why S4 is good at modeling long sequences (because it remembers/ encodes/ compresses the full sequence). The core technique is online function approximation: dynamically approximating a function whose value is gradually revealed as time goes by. In more detail, we will go through the following story:
  1. We view the input sequence as a function of time (the time is continuous, obviously)
  2. Then we approximate this (probably complicated) input function with a linear combination of a set of pre-defined simple functions (i.e., a polynomial basis)
  3. We only store the coefficients (of the linear combination) as the representation of the input function
  4. These coefficients are obtained by solving a linear ODE whose parameters are the magic matrix $A$.
  5. When transforming the (continuous) time to discrete timesteps, the (continuous-time) linear ODE becomes a linear RNN (with discrete steps).
  6. Consequently, the coefficients of the function approximation become the hidden states of this RNN
  7. At any time step, given the hidden states (coefficients), we can use them to linearly combine the polynomial basis, which gives an approximation of the full input function
 

Preparations

We assume the reader has already read the S4 paper (probably with some confusion). Despite the confusion, it is important to differentiate between the two representations: the convolutional and the recurrent.
Specifically, we need to know:
  • The convolutional representation is primarily for computational efficiency, i.e., fast training
  • It is the recurrent representation that does the magic, by using the HiPPO matrix (section 2.2) as the transition matrix of a linear RNN
  • The major contribution of S4 is NOT the HiPPO matrix, and the authors do NOT really explain why the HiPPO matrix does the magic in the paper
    • The major contribution is how to compute S4 efficiently, assuming the HiPPO matrix is already at hand
    • The magic of the HiPPO matrix is explained in a previous paper: HiPPO: Recurrent Memory with Optimal Polynomial Projections. NeurIPS 2020
 
So technically, in this post we do not really explain S4; we are explaining the HiPPO NeurIPS 2020 paper instead.

Encoding a Function by Linear Approximation

(This section roughly explains Sec. 2.1 of the HiPPO 20 paper)
 
The HiPPO paper has its origin in function approximation: approximating a complicated function with a set of simpler functions. One example of function approximation is the Fourier series: approximating a complicated function with a linear combination of simple sinusoids:
 
Figure 1. Approximating the red curve with a linear combination of the blue sinusoids. The red curve (target function) is a weighted sum of the blue curves (sinusoids), where the weights (i.e., linear coefficients) are represented as the blue bars on the lower-right part of the figure.
In the HiPPO paper, instead of using the sinusoids, the authors use polynomials as the basis functions for the approximation.
 
The theory of function approximation is based on Functional Analysis. For now, we can just think of it as a (high-end) version of linear algebra defined on the space of functions (rather than the space of vectors). The theory is quite similar to linear algebra. In the linear algebra case, recall that any vector $x$ can be expressed as a linear combination of basis vectors:
$$x = \sum_{i=1}^{d} c_i e_i$$
In functional analysis, we treat a function as if it were a vector. So it can also be expressed as a linear combination of basis functions:
$$f(t) = \sum_{n=0}^{\infty} c_n g_n(t)$$
Pay attention to the similarity between the two equations above. Below is the comparison between the vector case and the function case:
  • $x$ is a vector, $f$ is a function (but we view it as if it were a vector)
  • the basis $e_1, \dots, e_d$ are vectors, the basis $g_0, g_1, \dots$ are functions
  • In the vector case, the dimension of the space (= the number of basis vectors) is $d$, which is finite
  • In the function case, the dimension of the function space, and thus the number of basis functions, is infinite
In practice we cannot use an infinite number of basis functions, so we truncate our approximation to a finite number $N$ of basis functions. The larger $N$ is, the smaller the approximation error (we skip the details of the approximation error for now).
 
So the question becomes: which polynomials should we use for approximating a function? The answer is to use the scaled Legendre polynomials (Section 3 in the HiPPO paper) because they have multiple advantages (skipped for now). The first six of them (recall they form an infinite family of functions) look like this:
Figure 2. Visualization of Legendre Polynomials. These are just the polynomial figures that we learned in high school.
Note the orange straight line P1 and the green parabola P2. As the order gets higher, there are more twists, which matches what we learned about polynomials in high school.
 
When using them to approximate a given function, the procedure looks like this:
Figure 3. Approximating the red curve with a linear combination of Legendre polynomials
where the red line is the target function $f$ and the blue lines are the Legendre polynomials $g_n$. The bars on the left represent the coefficients $c_n$.
 
Suppose $f$ is complicated and cannot be written down analytically, yet we have found a way of computing the coefficients $c_n$ (which turns out to be the method of least squares, exactly the same as in the linear algebra case). Then we can use the Legendre polynomial approximation
$$f(t) \approx \sum_{n=0}^{N-1} c_n g_n(t)$$
where the information of $f$ is encoded by the coefficients $c_0, \dots, c_{N-1}$. If we want to know the value of $f$ at any time $t_0$, we just plug it in: $f(t_0) \approx \sum_{n} c_n g_n(t_0)$.
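To make this concrete, here is a minimal numpy sketch (my own illustration, not code from the paper; the target function and the value of N are made up) of encoding a function as N Legendre coefficients and reconstructing it from them:

```python
# Approximate a function with a truncated Legendre series, store only the
# N coefficients, then reconstruct the function from those coefficients.
import numpy as np
from numpy.polynomial import legendre as L

N = 16                                      # number of basis functions to keep
t = np.linspace(-1.0, 1.0, 1000)            # Legendre polynomials live on [-1, 1]
f = np.sin(5 * t) + 0.5 * np.cos(17 * t)    # some "complicated" target function

# Least-squares fit of f onto the first N Legendre polynomials.
# c holds the N coefficients -- this is all we need to store.
c = L.legfit(t, f, deg=N - 1)

# Reconstruct f at any time by linearly combining the basis with c.
f_hat = L.legval(t, c)
print("max reconstruction error:", np.abs(f - f_hat).max())
```

Increasing N shrinks the reconstruction error, which is exactly the truncation trade-off mentioned above.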
 
We will see that these coefficients correspond to the hidden state of a special RNN and are dynamically updated by the special HiPPO matrix used in the paper.
 
 

Online Function Approximation

(This section roughly corresponds to Sec. 2.2 of the HiPPO paper)
 
Now we change the setting a little bit and assume that we do not observe the full $f$ defined on $[0, T]$ all at once. Instead, we gradually observe $f$ up to the current time $t$, as $t$ increases, say up to $T$. Then we want to gradually update our approximation of $f$ (as time increases), like the following:
Figure 4. Online function approximation. As time goes by, we observe more values of the target function, and dynamically adjust our approximation.
where $f_{\le t}$ denotes the target function revealed up to time $t$ and $\hat{f}_{\le t}$ is the corresponding estimate. We note:
  • At time $t_0$, we only observe the function (black line) up to $t_0$. So our approximation is $\hat{f}_{\le t_0}(x) = \sum_n c_n(t_0) g_n(x)$,
    • where the coefficients $c_n(t_0)$ also depend on $t_0$ because they are tailored for the chunk of $f$ on the interval $[0, t_0]$ (we usually assume the function starts at time $0$).
  • As time goes from $t_0$ to $t_1$, we observe more values of $f$. Our previous approximation at $t_0$ no longer fits well on the interval $[t_0, t_1]$, so we adjust the coefficients from $c_n(t_0)$ to $c_n(t_1)$ to account for both intervals, $[0, t_0]$ and $[t_0, t_1]$. This is where the name online comes from.
  • After the adjustment of the coefficients at $t_1$, the resulting approximation $\hat{f}_{\le t_1}$ deviates from the previous $\hat{f}_{\le t_0}$, but becomes more accurate on the whole interval $[0, t_1]$.
  • As time goes by, we observe more and more chunks of $f$, and the coefficients are updated dynamically/ online. At the end of the day (say time $T$), we have the approximation $\hat{f}_{\le T}(x) = \sum_n c_n(T) g_n(x)$,
    • for which we only store the coefficients $c_n(T)$.
  • Again, the reconstruction of $f$ at any time $x$ before $T$ can be achieved by evaluating the approximation: $f(x) \approx \sum_n c_n(T) g_n(x)$ (a small code sketch of this online re-fitting follows below).
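Here is a naive numpy sketch of this online setting (my own illustration, not the HiPPO algorithm: it simply re-fits the coefficients from scratch every time more of the function is revealed, whereas HiPPO's whole point is to update them cheaply with a recurrence):

```python
# As more of the target function is revealed, re-fit the N Legendre
# coefficients on everything observed so far and check the reconstruction.
import numpy as np
from numpy.polynomial import legendre as L

N = 16
T = np.linspace(0.0, 10.0, 2000)
f = np.sin(T) * np.exp(-0.1 * T)        # made-up target, revealed over time

for t_end in [2.5, 5.0, 7.5, 10.0]:     # observe longer and longer prefixes
    seen = T <= t_end
    # map the observed interval [0, t_end] onto [-1, 1], where the basis lives
    x = 2 * T[seen] / t_end - 1
    c = L.legfit(x, f[seen], deg=N - 1)     # coefficients c_n(t_end)
    f_hat = L.legval(x, c)                  # reconstruction of the whole prefix
    print(f"t = {t_end:4.1f}   max error on [0, t]: {np.abs(f[seen] - f_hat).max():.5f}")
```

Note the rescaling of the observed interval onto [-1, 1]: this is exactly the "dynamically scaled basis" idea discussed in the section after next.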

Finding Coefficients by Solving an ODE

(This section roughly corresponds to Sec. 2.1 - 2.2 of the HiPPO paper)
     
Now the question is: how do we compute the coefficients $c_n(t)$? The method is, again, similar to linear algebra: we project the target function onto the basis functions by computing inner products:
$$c_n(t) = \langle f_{\le t}, g_n \rangle_{\mu^{(t)}}$$
Note that here I rewrite $c_n$ as $c_n(t)$ to indicate that the coefficients are functions of the end time $t$. Be careful when I say the word “inner product”: as $f$ and $g_n$ are functions, we need to define what the inner product between two functions is (it is not the good old inner product between vectors). In functional analysis, the inner product of two functions $f$ and $g$ with regard to a measure $\mu$, denoted $\langle f, g \rangle_{\mu}$, is defined as the integral of the product of the two functions with regard to that measure:
$$\langle f, g \rangle_{\mu} = \int f(x)\, g(x)\, d\mu(x)$$
We will discuss the measure later. For now, just think of the integral as the normal integral that we have learned in elementary calculus (this is to say, do not be frightened by the symbol $d\mu(x)$; it will be turned back into elementary calculus). Differentiating the above expression with respect to $t$ for all $n$ and doing a little bit (a large amount) of math will give us an ODE whose solutions are the coefficients (see HiPPO paper Appendix D.3):
$$\frac{d}{dt} c(t) = A(t)\, c(t) + B(t)\, f(t)$$
where $c(t)$ collectively denotes the coefficient vector $c(t) = [c_0(t), \dots, c_{N-1}(t)]^\top$, and $A$ is the magic matrix in the S4 paper (note that $A$ and $B$ depend on the choice of the measure, here particularly the scaled Legendre measure which we discuss in the next section).
The sketch of the underlying math (primarily Appendix D.3) is as follows:
1. Differentiate $c_n(t) = \langle f_{\le t}, g_n \rangle_{\mu^{(t)}}$ with respect to time $t$, for all $n$ (Appendix C.3)
2. Eliminate the integral using the recurrence relations of the Legendre polynomials (Appendix B.1.1 + D.3)
3. Vectorize all the resulting differential equations to obtain the above ODE; the matrix $A$ emerges at this step (Appendix D.3)
This is where the magic matrix $A$ comes from.
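As a concrete (hedged) sketch of what this looks like, below is my reading of the HiPPO-LegS case: the matrix $A$ and vector $B$ as written in the papers, plus a crude forward-Euler integration of the ODE $\frac{d}{dt} c(t) = -\frac{1}{t} A c(t) + \frac{1}{t} B f(t)$; the input function and the sizes are made up:

```python
# Build the HiPPO-LegS matrices and integrate the coefficient ODE.
import numpy as np

def hippo_legs_matrices(N):
    """The HiPPO-LegS matrix A and input vector B (before the minus sign)."""
    A = np.zeros((N, N))
    B = np.sqrt(2 * np.arange(N) + 1.0)
    for n in range(N):
        for k in range(N):
            if n > k:
                A[n, k] = np.sqrt((2 * n + 1) * (2 * k + 1))
            elif n == k:
                A[n, k] = n + 1
    return A, B

N = 16
A, B = hippo_legs_matrices(N)
f = lambda t: np.sin(t)            # made-up input function
c = np.zeros(N)                    # coefficients, a.k.a. the future RNN state
dt = 1e-3
for step in range(1, 10001):       # crude forward-Euler integration of the ODE
    t = step * dt
    c = c + dt * (-(1 / t) * (A @ c) + (1 / t) * B * f(t))
```

Notice that with the $1/t$ factor, the Euler step at step $k$ has the form $c_{k+1} = (I - A/k)\, c_k + \frac{1}{k} B f_k$, which already looks like a linear RNN; the discretization section below makes this precise.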

Dynamically Scaling Polynomial Basis

(This section corresponds to Sec. 3 of the HiPPO paper)
     
Things seem to be going fine so far, yet there is a caveat we have missed: what if time goes on and on, beyond the domain where the basis functions are defined, like this:
Figure 5. The caveat in online function approximation. The target function may go beyond the interval where the basis functions are defined.
To address this issue, the HiPPO authors propose to dynamically scale the basis functions so that they cover the same domain as the target function, like this:
Figure 6. Scaling the basis functions dynamically as the target function goes on.
  • Initially, we observe the target function within the domain $[0, t_0]$, so our basis functions are defined on $[0, t_0]$
  • As time goes by, we observe the target function on the larger interval $[0, t_1]$, so we scale our basis functions to the larger interval $[0, t_1]$
  • This process goes along with every new observation of $f$: whatever interval we observe the target function on, we always scale our basis to that interval.
After this dynamic scaling, the basis functions depend on the end time $t$; we write them as $g_n^{(t)}$, where for the Legendre case $g_n^{(t)}(x) = \sqrt{2n+1}\, P_n\!\left(\frac{2x}{t} - 1\right)$.
     
In the HiPPO paper, the authors scale the basis functions by defining a new measure, the scaled Legendre measure
$$\mu^{(t)}(x) = \frac{1}{t}\, \mathbb{1}_{[0, t]}(x)$$
upon which these scaled polynomials can be induced. We will skip the math about how to derive the polynomial basis with regard to this measure (which does the scaling automatically), but the general workflow has three steps:
1. Define a measure (in this case, the scaled Legendre measure, which depends on the end time $t$)
2. This measure induces an inner product between two functions
3. Performing Gram-Schmidt orthogonalization with respect to the induced inner product (which depends on the measure and thus on the end time $t$), we obtain an (infinite) series of polynomial basis functions (which are scaled to the end time through their dependency on $t$)
The above process is detailed in Appendix B.1 of the HiPPO paper.
     
Plugging this measure, its induced inner product, and its induced polynomial basis into the ODE in the previous section, we obtain the magic matrix $A$ (details in Appendix D.3 of the HiPPO paper).
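As a quick numerical sanity check (my own sketch, not from the paper), the scaled basis $g_n^{(t)}(x) = \sqrt{2n+1}\, P_n(2x/t - 1)$ written above is indeed orthonormal under the scaled Legendre measure $\mu^{(t)}$:

```python
# Verify numerically that <g_n, g_k> under mu^{(t)} = (1/t) 1_{[0,t]} is
# approximately 1 when n == k and 0 otherwise.
import numpy as np
from numpy.polynomial import legendre as L

def g(n, t, x):
    """n-th scaled Legendre basis function on the interval [0, t]."""
    coef = np.zeros(n + 1)
    coef[n] = 1.0                              # selects the polynomial P_n
    return np.sqrt(2 * n + 1) * L.legval(2 * x / t - 1, coef)

t = 7.3                                        # arbitrary end time
x = np.linspace(0.0, t, 100001)
for n in range(3):
    for k in range(3):
        inner = np.trapz(g(n, t, x) * g(k, t, x), x) / t   # <g_n, g_k>_{mu^{(t)}}
        print(f"<g_{n}, g_{k}> = {inner:.3f}")
```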
     

Getting an RNN by Discretizing the ODE

(This section roughly corresponds to Sec. 2.4 of the HiPPO paper)
     
Now we finally come to the step of deriving the magic RNN that can remember all the history of the input sequence. This is achieved by discretizing the previous ODE whose solutions are the coefficients of the function approximation. We first quote from the paper how the procedure goes:
(i) consumes an input sequence $(f_k)$, (ii) implicitly defines a function $f(t)$ where $f(k \Delta t) = f_k$ for some step size $\Delta t$, (iii) produces a function $c(t)$ through the ODE dynamics, and (iv) discretizes back to an output sequence $c_k := c(k \Delta t)$.
Skipping all the math, we have
Before (continuous time): $\frac{d}{dt} c(t) = A\, c(t) + B\, f(t)$
After (discrete steps): $c_{k+1} = \bar{A}\, c_k + \bar{B}\, f_k$
Again, $A$ is the magic HiPPO matrix derived in Appendix D.3 by solving the online function approximation ODE, and $\bar{A}, \bar{B}$ are its discretized versions. A little bit of math derivation can recover the full S4 recurrence from here. Importantly, $c_k$ is viewed as the hidden state of the S4 RNN.
     
Additionally, $c_k$ is also the coefficient vector of our online function approximation at time step $k$. Recall that the history can be reconstructed by
$$f(x) \approx \sum_{n=0}^{N-1} c_{k,n}\, g_n^{(t_k)}(x)$$
where $g_n^{(t_k)}$ are the dynamically scaled polynomial basis functions up to time $t_k = k \Delta t$ discussed above. Plugging in any time $x$ before $t_k$, we approximately recover the history of $f$.
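Putting the pieces together, here is a minimal sketch of this discretization (assuming the time-invariant HiPPO matrix and the bilinear transform used in S4; the HiPPO-LegS recurrence itself is step-dependent, and the step size and input here are made up):

```python
# Discretize dc/dt = A c + B f into c_{k+1} = Abar c_k + Bbar f_k and run it
# as a linear RNN whose hidden state c_k is the coefficient vector.
import numpy as np

N, dt = 16, 1.0 / 16000
# HiPPO matrix A (as written in the S4 paper) and input vector B
A = np.zeros((N, N))
B = np.sqrt(2 * np.arange(N) + 1.0)
for n in range(N):
    for k in range(N):
        if n > k:
            A[n, k] = -np.sqrt((2 * n + 1) * (2 * k + 1))
        elif n == k:
            A[n, k] = -(n + 1)

I = np.eye(N)
Abar = np.linalg.solve(I - dt / 2 * A, I + dt / 2 * A)   # bilinear transform
Bbar = np.linalg.solve(I - dt / 2 * A, dt * B)

f = np.random.randn(16000)      # a made-up 16000-step input sequence
c = np.zeros(N)                 # hidden state = approximation coefficients
for f_k in f:                   # the linear RNN recurrence
    c = Abar @ c + Bbar * f_k
```

The sequential loop at the end is exactly the efficiency problem discussed in the summary below.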
     
Here is a table of the differences between continuous time and discrete timesteps:

Continuous time                                    | Discrete timesteps
Time $t$ is a continuous variable                  | Time step $k$ is a discrete index
The input $f(t)$ is a function of time             | The input $(f_k)$ is a discrete sequence
The coefficients $c(t)$ are functions of time      | The coefficients $(c_k)$ are discrete sequences
Coefficients are obtained by solving a linear ODE  | Coefficients are updated by running a linear RNN, and are viewed as RNN hidden states
     

Summary

In this post, we explain the math behind the S4 paper and where the magic matrix $A$ in the S4 paper comes from. The core theory is to approximate a function online and to scale the basis functions adaptively. The magic matrix $A$ is derived by solving the ODE whose solutions are the coefficients of the function approximation. Discretizing the continuous ODE into discrete timesteps gives us the RNN recurrence that is used in the S4 paper.
     
Now it is time to actually run this RNN and see if its hidden states, i.e., the coefficients of the function approximation, can remember all the input history. Yet we immediately encounter an efficiency problem: if the input sequence is super long, say 16000 timesteps, we need to run the matrix recurrence 16000 times sequentially, which is extremely inefficient.
     
    The method of speeding up the training of this magic RNN is another story, that is, exactly the story of the S4 paper. See The Annotated S4 for an explanation.
     
