Chunked Scan

rnn
Author

Volodymyr Kyrylov

Published

December 29, 2023

We have a linear recurrence:

\(x_t = \lambda_t \cdot x_{t-1} + u_t\)
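For reference, the recurrence can be evaluated sequentially in a few lines. This is a minimal Python sketch over scalar values. The function name `linear_recurrence` and the default initial state `x_init = 0` (the value before the first element) are assumptions for illustration and are not part of the post.

```python
def linear_recurrence(lam, u, x_init=0.0):
    """Sequential reference: x_t = lam[t] * x_{t-1} + u[t], starting from x_init."""
    xs = []
    x = x_init  # assumed state before the first element
    for lam_t, u_t in zip(lam, u):
        x = lam_t * x + u_t  # one step of the recurrence
        xs.append(x)
    return xs


print(linear_recurrence([0.5, 0.5, 0.5], [1.0, 1.0, 1.0]))  # → [1.0, 1.5, 1.75]
```

Every step depends on the previous one, which is what the chunked evaluation below tries to work around.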

We would like to evaluate it in chunks of \(T\) elements, with each chunk evaluated by its own processor. Once the chunks are ready, each processor can integrate the result of the preceding chunk into its own.
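One way to make "integrating results" concrete, assuming scalar values: seen from outside, a chunk is an affine map \(x_{\text{in}} \mapsto a\,x_{\text{in}} + b\) from the value entering it to its last element, where \(a\) is the product of the chunk's \(\lambda_t\) and \(b\) is the chunk's result when started from zero. Integrating two chunks amounts to composing their maps, and because composition is associative the chunk summaries can be combined in any bracketing; this pair formulation is the usual one for first-order recurrences (cf. [Ble93]). The helper names `chunk_summary` and `combine` and the sample data are hypothetical.

```python
from functools import reduce


def chunk_summary(lam, u):
    """Summarize a chunk as the affine map x_in -> a * x_in + b giving its last state."""
    a, b = 1.0, 0.0  # identity map: an empty chunk passes x_in through unchanged
    for lam_t, u_t in zip(lam, u):
        # apply one step x -> lam_t * x + u_t on top of the map built so far
        a, b = lam_t * a, lam_t * b + u_t
    return a, b


def combine(first, second):
    """Compose summaries: run `first`, then feed its output into `second`."""
    a1, b1 = first
    a2, b2 = second
    return a2 * a1, a2 * b1 + b2


lam = [0.5, 2.0, 0.25, 1.0, 0.5, 4.0]
u = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
T = 2
chunks = [(lam[i:i + T], u[i:i + T]) for i in range(0, len(lam), T)]
summaries = [chunk_summary(l, v) for l, v in chunks]  # independent per chunk
a, b = reduce(combine, summaries)  # left-to-right; any bracketing gives the same map
x_init = 0.0  # assumed state before the first chunk
print(a * x_init + b)  # → 42.0, the last element of the full recurrence
```

Running the six steps sequentially from zero gives 1, 4, 4, 8, 9, 42, so the composed summary agrees with the last element.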

Index the elements within a chunk as \(t = 0, \dots, T-1\), and let \(x_{-1}\) be the final element of the previous chunk. We’re ready to define the chunked recurrence inductively:

The base case:

\(x_0 = \lambda_0 x_{-1} + u_0\)

Here \(\bar{x}_0 := \lambda_0 x_{-1}\) can be seen as a port where we will connect the previous chunk. By linearity, we can treat \(x'_0 := u_0\) as the base case of the chunk on its own, so that \(x_0 = \bar{x}_0 + x'_0\).

Unrolling one more step (distributing \(\lambda_1\) over the sum), we have:

\(x_1 = \lambda_1 \lambda_0 x_{-1} + \lambda_1 u_0 + u_1\)

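In general, for \(1 \le t \le T-1\):

\(x_t = \left(\prod_{i=1}^{t} \lambda_i\right) \bar{x}_0 + \lambda_t x'_{t-1} + u_t\)

where \(x'_t := \lambda_t x'_{t-1} + u_t\) is the chunk-local recurrence started from \(x'_0 = u_0\). The last two terms add up to \(x'_t\), and since \(\bar{x}_0 = \lambda_0 x_{-1}\), the first term equals \(\left(\prod_{i=0}^{t} \lambda_i\right) x_{-1}\).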
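Writing \(\Lambda_t := \prod_{i=0}^{t}\lambda_i\) and \(x'_t := \lambda_t x'_{t-1} + u_t\) with \(x'_0 = u_0\), a short induction on \(t\) confirms that \(x_t = \Lambda_t x_{-1} + x'_t\) for every element of the chunk. The base case holds because \(\Lambda_0 = \lambda_0\) and \(x'_0 = u_0\). Assuming the claim for \(t-1\), the inductive step is:

\[
\begin{aligned}
x_t &= \lambda_t x_{t-1} + u_t \\
    &= \lambda_t \left(\Lambda_{t-1} x_{-1} + x'_{t-1}\right) + u_t \\
    &= \lambda_t \Lambda_{t-1} x_{-1} + \left(\lambda_t x'_{t-1} + u_t\right) \\
    &= \Lambda_t x_{-1} + x'_t
\end{aligned}
\]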

This leads us to the following algorithm:

  1. compute the chunk-local recurrence \(x'_t\) for every \(t\) in the chunk, starting from \(x'_0 = u_0\)
  2. compute the cumulative products \(\Lambda_t := \prod_{i=0}^{t}\lambda_i\) (e.g. with a prefix scan using multiplication as the operator [Ble93])
  3. receive the value of \(x_{-1}\) from the previous neighbor
  4. compute \(x_t = \Lambda_t x_{-1} + x'_t\)
  5. send the value of \(x_{T-1}\), the last element of the chunk, to the next neighbor, where it serves as that chunk's \(x_{-1}\)
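The five steps above can be sketched in Python as follows, with the processors simulated one after another in a single loop. Steps 1 and 2 need nothing from the neighbors, so in a parallel setting they would run concurrently; only the single value handed from chunk to chunk in steps 3 to 5 is sequential. The names `chunk_local` and `chunked_scan`, the use of `itertools.accumulate` for the cumulative products, and the initial state `x_init = 0` for the first chunk are assumptions for this sketch.

```python
import operator
from itertools import accumulate


def chunk_local(lam, u):
    """Steps 1-2 for one chunk: the local recurrence x'_t (started from x'_0 = u_0)
    and the cumulative products Lambda_t = lam_0 * ... * lam_t."""
    x_local = []
    x = 0.0  # starting from zero makes the first step give x'_0 = u_0
    for lam_t, u_t in zip(lam, u):
        x = lam_t * x + u_t  # x'_t = lam_t * x'_{t-1} + u_t
        x_local.append(x)
    big_lambda = list(accumulate(lam, operator.mul))  # prefix products
    return x_local, big_lambda


def chunked_scan(lam, u, T, x_init=0.0):
    """Evaluate x_t = lam_t * x_{t-1} + u_t in chunks of T elements.
    The per-chunk work is simulated sequentially here."""
    chunks = [(lam[i:i + T], u[i:i + T]) for i in range(0, len(lam), T)]
    local = [chunk_local(l, v) for l, v in chunks]  # steps 1-2, independent per chunk
    out = []
    x_prev = x_init  # x_{-1} for the first chunk (assumed initial state)
    for x_local, big_lambda in local:
        # step 3: x_prev is the x_{-1} received from the previous neighbor
        # step 4: x_t = Lambda_t * x_{-1} + x'_t
        xs = [L * x_prev + xl for L, xl in zip(big_lambda, x_local)]
        out.extend(xs)
        x_prev = xs[-1]  # step 5: hand the last element to the next chunk
    return out


lam = [0.5, 2.0, 0.25, 1.0, 0.5, 4.0]
u = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
print(chunked_scan(lam, u, T=2))  # → [1.0, 4.0, 4.0, 8.0, 9.0, 42.0]
```

With `T = 1` every chunk holds one element and the sketch degenerates to the sequential loop; with `T = len(lam)` there is a single chunk and no hand-off between neighbors.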

References:

[Ble93] Prefix Sums and Their Applications. Guy E. Blelloch https://www.cs.cmu.edu/~guyb/papers/Ble93.pdf