
# Working with Imbalanced Datasets

Say that you are working on a simple binary classification task. For instance, you have a dataset of credit card transactions, and you want to classify each transaction as either legitimate or fraudulent. As you might imagine, there will be far fewer fraudulent transactions than legitimate ones. In this post, I will discuss how such imbalanced data can hurt your model’s performance and what you can do to remedy the issue.

## Generating toy data

First, since I do not have a credit card transaction dataset, we will have to generate some imbalanced data ourselves. This choice is somewhat arbitrary, and we could also use other data-generating processes (DGP). For no specific reason, we will choose the following DGP:

$$
\begin{align}
y &= \mathbb{1}\{ | f(\mathbf{X}) | \geq 8 \}, \quad \text{with} \\
f(x_1, x_2, x_3) &= \frac{\sin(x_1) \cdot \sqrt{|x_2|}}{x_3} + \epsilon, \quad \text{with } \epsilon \sim \mathcal{N}(0, 1)
\end{align}
$$

This DGP might look awkward, but the main idea is to have a highly nonlinear DGP with a strong imbalance between classes. It gives us around 5% of fraudulent transactions (or $$y=1$$ labels); you can convince yourself of this with a Monte Carlo simulation. Many other DGPs could achieve the same result.
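A minimal sketch of such a Monte Carlo check, assuming the three features are drawn independently from a standard normal distribution (as in the batch generator used in this post). The names `score` and `positive_share`, the seed, and the sample size are illustrative choices, not part of the original code.

```julia
using Random, Statistics

# Noisy score from the DGP; `x` holds the three features (x₁, x₂, x₃)
score(x) = sin(x[1]) * sqrt(abs(x[2])) / x[3] + randn()

# Estimate P(y = 1) from `n` independent draws
function positive_share(n)
    X = randn(3, n)                              # one observation per column
    mean(abs(score(x)) ≥ 8 for x in eachcol(X))  # fraction of draws with |f(x)| ≥ 8
end

Random.seed!(1)                                  # arbitrary seed, for reproducibility only
share = positive_share(200_000)
println("Estimated share of y = 1: ", round(100 * share, digits = 2), "%")
```

The printed estimate should land close to the 5% quoted above, with small variations depending on the seed.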

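```julia
# Packages used throughout this post
using Flux, Random, Statistics, StatsBase

# Data generating process: f(x₁, x₂, x₃) = sin(x₁) ⋅ √|x₂| / x₃ + ε
f(x) = sin(x[1]) * sqrt(abs(x[2])) / x[3] + randn()

function generate_batch(batchsize)
    X = randn(3, batchsize)          # Generate random features (one observation per column)
    y = abs.(f.(eachcol(X))) .≥ 8    # Label is 1 (fraudulent) when |f(x)| ≥ 8
    X, y
end
```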

## Normal model training

Armed with our DGP, we now turn to the actual training of our model. We will use a straightforward deep feedforward neural network with an Adam optimizer and 250 training epochs, where each epoch draws a fresh batch of 512 observations.

```julia
# Set random seed
Random.seed!(72)
epochs = 250 # Train the model for 250 epochs

# Create our first model and initialize optimizer
m₁ = Chain(
    Dense(3 => 64, relu),
    Dense(64 => 64, relu),
    Dense(64 => 1, σ)
)
opt₁ = ADAM()
θ₁ = Flux.params(m₁)

# Training loop
for epoch ∈ 1:epochs
    # Generate random train batch of size 512
    X, y = generate_batch(512)
    ∇ = gradient(θ₁) do # Compute gradients
        Flux.Losses.binarycrossentropy(m₁(X) |> vec, y)
    end
    Flux.update!(opt₁, θ₁, ∇)
end

# Assessing model performance
X, y = generate_batch(100_000)
ŷ₁ = vec(m₁(X) .> .5)
acc₁ = 100mean(y .== ŷ₁) # 95.186
```

Our model achieves an accuracy of 95.186%, quite impressive! Or is it? Looking at the results in more detail, our model predicts every observation to be legitimate (i.e., $$y=0$$), making for a rather useless model.
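To see why accuracy alone is misleading here, note that a classifier that always answers "legitimate" scores an accuracy equal to the share of legitimate observations. The sketch below illustrates this with hypothetical labels drawn at a 5% positive rate (not the DGP above); the variable names and the seed are illustrative.

```julia
using Random, Statistics

Random.seed!(1)                              # arbitrary seed
y = rand(100_000) .< 0.05                    # hypothetical labels, roughly 5% positives
ŷ = falses(length(y))                        # constant prediction: always "legitimate"

accuracy    = 100 * mean(y .== ŷ)            # equals 100 * (1 - mean(y))
sensitivity = 100 * sum(y .& ŷ) / sum(y)     # share of positives detected: 0 for this baseline

println("accuracy = ", accuracy, "%, sensitivity = ", sensitivity, "%")
```

The accuracy is high only because negatives dominate the sample; the baseline never detects a single positive.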

## Improving our model

The above exercise raises the question: how can we improve our model? Unfortunately, as is often the case, the answer is: it depends.

There are multiple changes that we can attempt to make.

1. If we have a virtually infinite amount of data (as in this case), we could resample our batches and ensure the imbalance is not too significant. For instance, we could have around 30-50% of positive (fraudulent) observations in each training batch.
2. We could use another loss function or evaluation metric. We might want to use recall (sensitivity), specificity, and ROC curves to fine-tune our model. Perhaps accuracy is not what interests us the most; the above is a prime example of such a case.
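As one illustration of the second option, the loss can weight the rare class more heavily. The sketch below is a class-weighted binary cross-entropy in plain Julia; the function name `weighted_bce`, the keyword `w₁`, and its default of 19 (roughly the 95/5 class ratio) are assumptions made for this example, and this is not the loss used in the experiments of this post.

```julia
using Statistics

# Class-weighted binary cross-entropy (illustrative; names and defaults are made up).
# Errors on positive observations (y = 1) count `w₁` times as much as errors on negatives.
function weighted_bce(ŷ, y; w₁ = 19.0, ϵ = 1e-7)
    p = clamp.(ŷ, ϵ, 1 - ϵ)                          # keep probabilities away from 0 and 1
    mean(@. -(w₁ * y * log(p) + (1 - y) * log(1 - p)))
end

y = [1, 0, 0, 0]                                     # one positive, three negatives
ŷ = [0.2, 0.1, 0.1, 0.1]                             # the positive is badly underpredicted

plain    = weighted_bce(ŷ, y; w₁ = 1.0)              # ordinary binary cross-entropy
weighted = weighted_bce(ŷ, y)                        # the missed positive now dominates
println("plain = ", plain, ", weighted = ", weighted)
```

With `w₁ = 1` the function reduces to the usual binary cross-entropy, so the weight is the only thing that changes how much a missed fraudulent transaction costs.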

## An improved example

Let’s have a quick look at how the resampling idea might improve our training. Instead of training our model on the full batch of 512 observations, we now discard legitimate observations at random until the training batch contains 25% fraudulent and 75% legitimate observations. This is still imbalanced but much better than the original 5% of fraudulent observations.

```julia
# Set random seed
Random.seed!(72)

# Create our second model and initialize optimizer
m₂ = Chain(
    Dense(3 => 64, relu),
    Dense(64 => 64, relu),
    Dense(64 => 1, σ)
)
opt₂ = ADAM()
θ₂ = Flux.params(m₂)

# Training loop
for epoch ∈ 1:epochs
    # Generate random train batch of size 512
    X, y = generate_batch(512)
    idx₁ = findall(isequal(1), y) # Find all observations of class 'fraudulent' (1)
    idx₀ = sample(findall(isequal(0), y), 3length(idx₁), replace=false) # Select sample of class 'legitimate' (0)
    X = X[:, vcat(idx₀, idx₁)]
    y = y[vcat(idx₀, idx₁)]
    ∇ = gradient(θ₂) do # Compute gradients
        Flux.Losses.binarycrossentropy(m₂(X) |> vec, y)
    end
    Flux.update!(opt₂, θ₂, ∇)
end

# Assessing model performance
X, y = generate_batch(100_000)
ŷ₂ = vec(m₂(X) .> .5)
acc₂ = 100mean(y .== ŷ₂) # 95.712
```

Our accuracy is now 95.712%. So, was this all for a measly 0.5 percentage point increase in accuracy? Not exactly. As discussed above, there is more to our problem than accuracy. For instance, the numbers of true and false positives and negatives give better insight into the models’ performance:

| Model | True Positive | True Negative | False Positive | False Negative |
|---|---|---|---|---|
| Normal training | 0 | 95’186 | 0 | 4’814 |
| Resampled training | 4’661 | 91’051 | 4’191 | 97 |

The sensitivity (true positive rate) of the model trained using resampled batches is now 97.96%. It is hard to argue that such a model is not much more helpful than the first one when predicting fraudulent credit card transactions! The resampled training is still far from perfect, and there are still a few tuning options we could use to improve it, but that is beside the point of this post.
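The rates can be recomputed directly from the confusion counts. The sketch below takes the 97 as false negatives and the 4’191 as false positives for the resampled model, which is the reading consistent with the 97.96% sensitivity and 95.712% accuracy reported here; the helper `rates` is a hypothetical name introduced for this example.

```julia
# Standard rates derived from confusion counts (`rates` is an illustrative helper)
function rates(tp, tn, fp, fn)
    (sensitivity = tp / (tp + fn),                 # true positive rate (recall)
     specificity = tn / (tn + fp),                 # true negative rate
     precision   = tp / (tp + fp),                 # share of flagged transactions that are fraudulent
     accuracy    = (tp + tn) / (tp + tn + fp + fn))
end

# Resampled model: 4,661 TP, 91,051 TN, 4,191 FP, 97 FN
r = rates(4_661, 91_051, 4_191, 97)

for (name, value) in pairs(r)
    println(rpad(name, 12), round(100 * value, digits = 2), "%")
end
```

Under this reading, the gain in sensitivity comes at the cost of precision: only about half of the transactions flagged as fraudulent actually are.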

In all fairness, given an infinite amount of data, recovering the true DGP is not a challenging task. With enough epochs, even the first training procedure should reach a performance similar to the second. Nonetheless, this short example shows how one can drastically improve a model’s performance by being thoughtful about the underlying data.
