Flux and Time Series Batching

Following one of my previous posts, I recently received an email asking whether I could find the time to write about batching time series data in Flux. I am thrilled that this blog is helpful to other people, not just to me, so I am happy to give a brief discussion of the subject, with examples, in this post.

Once again, this post follows a purely applied approach. I will not discuss which type of batching is the right choice under which conditions, as I am not yet comfortable with that subject. Perhaps at a later date. In the meantime, I will only discuss ways of implementing different batching approaches in Julia.

Different batching methods

Let us consider some time series data \(\{x_t : t \in \mathbb{Z}\}\), with \(x_t \in \mathbb{R}^k \ \forall t\). Hence, unlike in the above-cited post, \(x_t\) is not necessarily univariate: it may also be multivariate. Furthermore, we are trying to predict some series \(\{y_t : t \in \mathbb{Z}\}\). The setup of the previous post corresponds to setting \(k=1\) and \(y_t \equiv x_{t+1}\). For the first part of this post, one can think of \(x_t\) as univariate. The multivariate aspect only matters for the Julia examples, since the univariate specification is just a special case of the multivariate one.

While we could feed the entire sequence \(x_t\) as input to our recurrent model, this is something we typically want to avoid. Instead, we want to split our input series into batches.

For instance, suppose we have a series \(x_t\) of length \(T=100\), but we only want to use sequences of length \(s=20\) as input to our recurrent network.

Non-overlapping batches

The first way we can split our data is to create partitions, i.e., non-overlapping batches. Thus, in the example mentioned above with \(s=20, T=100\), we would break \(\{x_1, x_2, \dots, x_{99}, x_{100}\}\) into \(b=5\) sequences of length \(s=20\) as follows:

\[\begin{align*} \{ &\{x_1, x_2, \dots, x_{19}, x_{20}\}, \\ &\{x_{21}, x_{22}, \dots, x_{39}, x_{40} \}, \\ &\{x_{41}, x_{42}, \dots, x_{59}, x_{60} \}, \\ &\{x_{61}, x_{62}, \dots, x_{79}, x_{80} \}, \\ &\{x_{81}, x_{82}, \dots, x_{99}, x_{100}\} \}. \end{align*}\]
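Outside of Flux, this partition of the time indices \(1, \dots, T\) can be reproduced with `Iterators.partition` from Julia's Base library. The sketch below only splits indices, with the values \(T=100\) and \(s=20\) from the example, and is meant to illustrate the picture above rather than to batch actual data.

```julia
T, s = 100, 20

# Split the time indices 1:T into consecutive, non-overlapping chunks of length s
chunks = collect(Iterators.partition(1:T, s))

length(chunks)   # number of batches b = T / s = 5
first(chunks)    # the indices 1:20
last(chunks)     # the indices 81:100
```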

Overlapping batches

Another, less common, way of splitting the data is to build overlapping batches. This approach amounts to sliding a window over our time series \(x_t\). For \(s=20\) and a window shifted by only \(r=1\) observation at a time, this yields the following sequences:

\[\begin{align*} \{ \{x_1, x_2, &\dots, x_{19}, x_{20}\}, \\ \{x_{2}, x_{3}, &\dots, x_{20}, x_{21} \}, \\ \{x_{3}, x_{4}, &\dots, x_{21}, x_{22} \}, \\ &\dots \\ \{x_{79}, x_{80}, &\dots, x_{97}, x_{98} \}, \\ \{x_{80}, x_{81}, &\dots, x_{98}, x_{99} \}, \\ \{x_{81}, x_{82}, &\dots, x_{99}, x_{100}\} \}. \end{align*}\]

We see that this approach produces many more batches than the partition method. Indeed, the non-overlapping procedure yields \(b=\frac{T}{s}\) sequences (rounded up or down to an integer, depending on whether we pad or discard observations), while the overlapping technique yields \(b=\frac{T-s+r}{r}\) sequences (again rounded up or down). Notice that the partitioning technique described above is the special case of this second approach with \(r = s\).
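The counting formula can be checked numerically. In the sketch below, `nbatches` and `windows` are illustrative helper names (not part of any library): `nbatches` implements \(b = \lfloor (T-s+r)/r \rfloor = \lfloor (T-s)/r \rfloor + 1\), i.e., the "discard" rounding, and `windows` enumerates the index ranges explicitly so the two can be compared.

```julia
# Number of length-s windows obtained by shifting r steps at a time over a
# series of length T, discarding observations that do not fit
nbatches(T, s, r) = fld(T - s, r) + 1

# Enumerate the windows explicitly: each window covers start:start+s-1
windows(T, s, r) = [start:start+s-1 for start in 1:r:T-s+1]

T, s = 100, 20
b_partition = nbatches(T, s, s)   # r = s: non-overlapping, T/s = 5
b_sliding   = nbatches(T, s, 1)   # r = 1: (T - s + 1)/1 = 81
```

With \(r = s\) the formula collapses to \(T/s\), which is exactly the partition case.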

So why is the non-overlapping method used more often than the overlapping one if it produces fewer sequences? As we can see, its sequences are less correlated with one another. Indeed, in the partition, every observation appears in exactly one sequence. In the overlapping approach, however, the first sequence shares \(s-r\) elements with the second sequence, which shares \(s-r\) elements with the third one, and so forth. This property is typically undesirable, but, as mentioned above, this post is about providing Julia examples, not discussing the theory.
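The overlap between consecutive windows can be counted directly. In this sketch, `window` and `shared` are hypothetical helpers for illustration only: `window(i, s, r)` returns the time indices covered by the \(i\)-th window, and `shared` counts how many indices the first two windows have in common, which should equal \(s - r\) for any \(r \le s\).

```julia
# Time indices covered by window i when the window is shifted r steps at a time
window(i, s, r) = (1 + (i - 1) * r):(s + (i - 1) * r)

# Number of time indices shared by two consecutive windows
shared(s, r) = length(intersect(window(1, s, r), window(2, s, r)))

shared(20, 1)    # s - r = 19
shared(20, 10)   # s - r = 10
shared(20, 20)   # 0: the partition shares nothing
```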

Batching time series in Julia

Now comes the fun part: some Julia examples!

Recall that in Flux, recurrent networks expect their input as a vector whose length equals the desired sequence length, where each element is a matrix whose rows are the features and whose columns are the different batches.
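To make this format concrete, the snippet below builds a dummy input of that shape using plain Julia arrays; no Flux call is involved, and the sizes `k`, `s` and `b` are arbitrary values chosen for illustration.

```julia
k, s, b = 3, 20, 5   # features, sequence length, number of batches

# One k×b matrix per time step: rows are features, columns are batches
X_rnn = [randn(Float32, k, b) for _ in 1:s]

length(X_rnn)     # s = 20 time steps
size(X_rnn[1])    # (k, b) = (3, 5)

# Column j of X_rnn[t] is time step t of batch j (here: step 4 of batch 2)
x_t_batch2 = X_rnn[4][:, 2]
```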

Let’s go ahead and build a function to batch tabular data into recurrent sequences for Flux. Our function will take as input time series data X, a desired sequence length s, and a shift r. As mentioned above, setting s=r yields non-overlapping batches. Additionally, setting s=size(X, 2), i.e., \(s=T\), produces the non-batched approach: the entire sequence is used as a single input, akin to my first post.

Importantly, this function discards the first \((T-s) \bmod r\) observations to ensure every sequence has the same length. If you don’t want to discard anything, you must pad your series so that \(T\) is a multiple of \(s\) (in the non-overlapping case) or \(T-s\) is a multiple of \(r\) (in the overlapping case).
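The sketch below computes how many leading observations get dropped and, as one possible way of avoiding it, left-pads the series with zero columns. Both `n_discarded` and `pad_front` are hypothetical helpers written for this illustration, not part of Flux, and padding with zeros is an assumption: whether it is sensible depends on your data.

```julia
# Leading observations dropped so that (T - s) becomes a multiple of r
n_discarded(T, s, r) = (T - s) % r

# Hypothetical helper: prepend zero columns to a k×T matrix so that nothing
# is discarded when batching with sequence length s and shift r
function pad_front(X::AbstractMatrix, s::Int, r::Int)
    T = size(X, 2)
    p = mod(s - T, r)                    # columns needed so (T + p - s) % r == 0
    return hcat(zeros(eltype(X), size(X, 1), p), X)
end

X = randn(Float32, 2, 100)               # k = 2 features, T = 100 observations
n_discarded(100, 20, 7)                  # 80 % 7 = 3 observations would be dropped
Xp = pad_front(X, 20, 7)                 # 2×104 after prepending 4 zero columns
n_discarded(size(Xp, 2), 20, 7)          # 0: nothing is dropped anymore
```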

The function expects your input to be a \(k \times T\) matrix, i.e., the observations are the columns and the features are the rows. Our data is often shaped the other way around, i.e., \(T \times k\); in that case, just apply permutedims(X) before calling the function. (A plain vector of length \(T\) is also accepted and treated as a \(1 \times T\) matrix.) Since Julia arrays are stored in column-major order, the \(k \times T\) layout keeps each observation contiguous in memory, which makes it the faster choice!
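As a quick sanity check of the layout, here is the transposition applied to random \(T \times k\) data (the sizes are arbitrary): after permutedims, column \(t\) of X holds all \(k\) features of observation \(t\).

```julia
T, k = 100, 3
data = randn(Float32, T, k)   # typical tabular layout: one row per observation

X = permutedims(data)         # k×T: one column per observation
size(X)                       # (3, 100)

# Observation 10 is now a contiguous column in memory (column-major order)
x_10 = X[:, 10]
```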

Sometimes, describing code in words is more complex than the code itself. The above paragraph is, without doubt, such a case. The following should help clarify:

# Create batches of a time series
function batch_timeseries(X, s::Int, r::Int)
    if isa(X, AbstractVector)       # If X is a length-T vector, reshape it to 1×T
        X = permutedims(X)
    end
    T = size(X, 2)
    @assert s ≤ T "s cannot be longer than the total series"
    X = X[:, ((T - s) % r)+1:end]   # Drop the first (T - s) % r observations so all sequences have equal length
    [X[:, t:r:end-s+t] for t ∈ 1:s] # Element t holds time step t of every batch (a k×b matrix)
end

Et voilà! We now have a simple function to split time series data into batches of length s. I highly recommend playing around with this function to better understand what is happening, as list comprehensions can be somewhat challenging to grasp (for me, at least!). Ideally, try it out with inputs such as

  • batch_timeseries(collect(1:100), 20, 20) (the non-overlapping case described above)
  • batch_timeseries(collect(1:100), 20, 1) (the overlapping case with shift size 1)
  • batch_timeseries(permutedims(hcat(1:100, -1:-1:-100)), 20, 10) (a multivariate example)
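For reference, here is what those three calls return. The function is repeated so that the snippet runs on its own; the comments describe the shapes you should see, with element t of each result holding time step t of every batch.

```julia
# batch_timeseries as defined above, repeated so this snippet is self-contained
function batch_timeseries(X, s::Int, r::Int)
    if isa(X, AbstractVector)
        X = permutedims(X)
    end
    T = size(X, 2)
    @assert s ≤ T "s cannot be longer than the total series"
    X = X[:, ((T - s) % r)+1:end]
    [X[:, t:r:end-s+t] for t ∈ 1:s]
end

nonoverlap = batch_timeseries(collect(1:100), 20, 20)
overlap    = batch_timeseries(collect(1:100), 20, 1)
multi      = batch_timeseries(permutedims(hcat(1:100, -1:-1:-100)), 20, 10)

length(nonoverlap)   # 20 time steps
nonoverlap[1]        # [1 21 41 61 81]: first step of each of the 5 batches
size(overlap[1])     # (1, 81): 81 overlapping batches
size(multi[1])       # (2, 9): 2 features, 9 batches
```

Reading `nonoverlap[1]` row-wise shows that the columns are the batches: batch 1 starts at 1, batch 2 at 21, and so on.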

Using this batching function with Flux

Finally, we can create a simple recurrent model in Flux to ensure this batching function produces an output that Flux recurrent models accept:

using Flux       # Assumes an older Flux release with stateful recurrent layers (Flux.reset!) and implicit parameters (Flux.params)
using Statistics
T, k = 100, 12 # 100 observations, 12 features
# Create a recurrent model
m = Chain(
    LSTM(k, 64), # Notice the k inputs
    LSTM(64, 1)
)
# Create a k×T matrix of input data, don't forget Float32!
X = randn(Float32, k, T) # k×T input
y = randn(Float32, 1, T) # Univariate output
# Reshape the data into Flux's recurrent input format
X_rnn = batch_timeseries(X, 20, 20) # Non-overlapping sequences of length 20
y_rnn = batch_timeseries(y, 20, 20)
# Train the LSTM model
θ = Flux.params(m) # Keep track of the parameters
opt = ADAM() # Optimizer
for epoch ∈ 1:10
    Flux.reset!(m) # Reset hidden state
    # Compute gradients
    ∇ = gradient(θ) do
        # Warm up model
        m(X_rnn[1])
        # Compute MSE loss on rest of sequence
        Flux.Losses.mse.([m(x) for x ∈ X_rnn[2:end]], y_rnn[2:end]) |> mean
    end
    Flux.update!(opt, θ, ∇)
end

Great! Everything works as expected. Notice that this model is a sequence-to-sequence model, i.e., it takes in a sequence \(\{x_1, x_2, \dots, x_s\}\) and produces an output sequence of the same length \(\{\hat{y}_1, \hat{y}_2, \dots, \hat{y}_s\}\).
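The shape bookkeeping of a sequence-to-sequence model can be checked without Flux at all. In the sketch below, `model` is a hypothetical stand-in (a fixed linear readout with no hidden state), not an LSTM; it only serves to show that mapping over the \(s\) input matrices yields \(s\) output matrices of size \(1 \times b\).

```julia
k, s, b = 12, 20, 5
X_rnn = [randn(Float32, k, b) for _ in 1:s]   # dummy input in Flux's recurrent format

# Hypothetical stand-in for a trained recurrent model: a fixed linear readout
# applied at every time step (no hidden state, purely to illustrate the shapes)
W = randn(Float32, 1, k)
model(x) = W * x

# Sequence-to-sequence: one prediction per input time step
y_hat = [model(x) for x in X_rnn]

length(y_hat)    # s = 20, matching the input sequence length
size(y_hat[1])   # (1, b) = (1, 5): one output per batch at each step
```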

I am currently writing another post on sequence-to-one modeling in Flux. However, this post might take more time to be published here due to an upcoming conference.

Recovering the output

We now have a method to correctly reformat tabular data for Flux recurrent networks. Unfortunately, the outputs of recurrent models come in the same format, which can be inconvenient, e.g., when plotting our results. Hence, let’s finish this post by transforming our outputs back into tabular format.

In the case of non-overlapping batches, this is straightforward: we can simply append the outputs of the recurrent model together. However, if we used overlapping sequences, our outputs will also overlap, and we must choose how to combine them back into a tabular format.

Thus, we will create an array of outputs for each sequence. In the non-overlapping case, we can then concatenate this array to retrieve the prediction on the full sequence.
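When batches overlap, one possible choice (among others, such as keeping only the most recent prediction) is to average all predictions made for the same time step. `average_overlapping` below is a hypothetical helper sketching this for univariate outputs; it assumes the windows were built as in `batch_timeseries`, with window \(j\) starting \(r(j-1)\) steps after the first kept observation.

```julia
# Hypothetical helper (not from Flux): average all overlapping predictions that
# refer to the same time index. `y` is a vector of s matrices of size 1×b, laid
# out like the output of batch_timeseries with shift r; index 1 is the first
# observation kept after batching.
function average_overlapping(y, r::Int)
    s, b = length(y), size(y[1], 2)
    T = (b - 1) * r + s                 # number of time indices covered
    total = zeros(Float64, T)
    n_preds = zeros(Int, T)
    for j in 1:b, t in 1:s
        idx = (j - 1) * r + t           # time index of step t in batch j
        total[idx] += y[t][1, j]
        n_preds[idx] += 1
    end
    return total ./ n_preds
end

# Toy check: windows over 1:10 with s = 4 and r = 2, where each "prediction"
# simply equals its own time index
s, r, T = 4, 2, 10
y = [permutedims(collect(t:r:T-s+t)) for t in 1:s]
average_overlapping(y, r)               # recovers 1.0, 2.0, ..., 10.0
```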

Univariate output case

If you are outputting a univariate variable, which will be the case most of the time, you can use vcat(y...), where y is your network’s output. This creates a matrix with \(s\) rows and \(b\) columns, i.e., each column is a different batch and each row is a different time step within that batch.
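A toy example makes the orientation of the resulting matrix clear. The values below are made up: \(s = 3\) time steps and \(b = 2\) batches, each time step being a \(1 \times b\) matrix as a univariate recurrent model would return it.

```julia
# Toy univariate output: s = 3 time steps, b = 2 batches, each step a 1×b matrix
y = [[1.0 4.0], [2.0 5.0], [3.0 6.0]]

Y = vcat(y...)    # s×b matrix: rows are time steps, columns are batches
size(Y)           # (3, 2)
Y[:, 2]           # [4.0, 5.0, 6.0]: the full prediction for batch 2
```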

This is much simpler than the multivariate case, and it is what I use in almost every case I encounter. Multivariate outputs are the only real complication.

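If we used non-overlapping batches, we can then use vec(vcat(y...)) to retrieve the output vector from this \(s \times b\) matrix: vec stacks the columns (i.e., the batches) one after another, which restores the chronological order.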
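To see that the chronological order really is restored, the snippet below builds by hand the same layout that batch_timeseries(collect(1:100), 20, 20) returns, treats it as a model output, and flattens it back.

```julia
# Output of a non-overlapping batching of 1:100 with s = 20: step t holds
# [t t+20 t+40 t+60 t+80], the same layout batch_timeseries produces
y = [permutedims(collect(t:20:100)) for t in 1:20]

Y = vcat(y...)    # 20×5: column j is batch j in time order
series = vec(Y)   # stacks the columns one after another (column-major)
series == 1:100   # true: the original chronological order is recovered
```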

Multivariate output case

For the generic case where y might be multivariate, we could instead use the following function:

# Creates an array where each element is the tabular version of each sequence
# output by the recurrent network
function retrieve_rnn_output(y)
    [
        vcat([permutedims(y[i][:, j]) for i ∈ 1:length(y)]...)   # s×m matrix for batch j
        for j ∈ 1:size(y[1], 2)                                  # one entry per batch
    ]
end

This function yields an array of length \(b\), where each element is a matrix of size \(s \times m\), with \(\hat{y}_t \in \mathbb{R}^m\).
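A small usage example, with made-up sizes \(s = 3\), \(m = 2\) and \(b = 4\). The function is repeated so the snippet runs on its own, and the toy values are chosen so that entry \((i, \cdot)\) of batch \(j\) is easy to trace back to y[i][:, j].

```julia
# retrieve_rnn_output as defined above, repeated so this snippet is self-contained
function retrieve_rnn_output(y)
    [
        vcat([permutedims(y[i][:, j]) for i ∈ 1:length(y)]...)
        for j ∈ 1:size(y[1], 2)
    ]
end

# Toy multivariate output: s = 3 steps, m = 2 output variables, b = 4 batches
s, m, b = 3, 2, 4
y = [reshape(Float64.(1:m*b), m, b) .+ 100t for t in 1:s]

out = retrieve_rnn_output(y)
length(out)      # b = 4
size(out[1])     # (s, m) = (3, 2)
out[1]           # [101.0 102.0; 201.0 202.0; 301.0 302.0]
```

In the univariate case (\(m = 1\)), each element is an \(s \times 1\) matrix, and hcat(out...) gives back the same \(s \times b\) matrix as vcat(y...).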