A fastai/Pytorch implementation of MixMatch

When trained on CIFAR10 with 250 labeled images, MixMatch outperforms the next best technique (Virtual Adversarial Training) by almost 25% on the error rate (11.08% vs 36.03%; for comparison, the fully supervised case on all 50k images has an error rate of 4.13%) [1]. These are far from incremental results, and the technique shows the potential to dramatically improve the state of semi-supervised learning.

Semi-supervised learning is largely a battle against overfitting; when the labeled set is small it doesn’t take a very large neural network to memorize the entire training set.

The general idea behind nearly all semi-supervised approaches is to leverage unlabeled data as a regularizer on the training of labeled data.

For a great overview of various semi-supervised learning methods, see this blog post by Sebastian Ruder.

Different techniques employ different forms of regularization, and the MixMatch paper divides these into three groups: entropy minimization, consistency regularization, and generic regularization.

As all three forms of regularization have proved effective, the MixMatch algorithm contains features from each.

MixMatch is a combination and improvement upon several techniques that have come out in recent years, including: Mean Teacher [2], Virtual Adversarial Training [3], and Mixup [4].

At a high level, the idea of MixMatch is to label the unlabeled data using predictions from the model and then apply heavy regularization in several forms.

The first is to perform data augmentation several times and take the average of the model’s predictions as the label.

These predictions are then ‘sharpened’ to reduce their entropy.

Finally, Mixup is performed on the labeled and unlabeled sets.

Diagram of MixMatch — Image taken from original paper [1]

I am aiming this post at those familiar with Pytorch, but not necessarily fastai.

For a Jupyter notebook version of this post containing the full code needed to reproduce all the results see this repository.

fastai

Before diving into the paper, I’ll briefly talk about fastai.

Fastai is a library, built on Pytorch, which makes writing machine learning applications much easier and simpler.

Fast.ai also offers a terrific online course covering both fastai and deep learning in general.

Compared to pure Pytorch, fastai dramatically reduces the amount of boilerplate code required to produce state of the art neural networks.

Here we’ll be using the data pipeline and training loop features of fastai.
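Imports

The notebook opens with an imports cell; the following is a minimal sketch of what the later code sketches in this post assume, not the exact cell from the repository:

```python
from fastai.vision import *      # data block API, transforms, Learner, callbacks
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
```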

Components

Let’s first describe the individual pieces needed to assemble MixMatch, and then at the end put them together to form the complete algorithm.

Following the paper, we’ll be using CIFAR10 and taking 500 randomly selected images as the labeled training set.

The standard 10000 image test set is used for all accuracy measurements.

Data Augmentation

Data augmentation is a widely used consistency regularization technique, with its biggest success (so far) found in the computer vision realm.

The idea is to alter the input data while preserving its semantic label.

For images, common augmentations include rotation, cropping, zooming, brightening, etc. — all transformations which do not change the underlying content of the image.

MixMatch takes this a step further by performing augmentation multiple times to produce multiple new images.

The predictions of the model on these images are then averaged to produce a target for the unlabeled data.

This makes the predictions more robust than using a single image.

The authors found that just two augments were sufficient to see this benefit.

Fastai has an efficient transformation system which we’ll utilize on the data.

However, as it’s designed to produce only one augmentation per image and we will need several, we will start by modifying the default LabelList to emit multiple augments.
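A minimal sketch of such a subclass follows. It assumes fastai v1 internals, in particular that a LabelList stores its transforms in self.tfms and self.tfmargs; the class name and the fixed K=2 are illustrative, not the exact repository code:

```python
from numbers import Integral
from fastai.vision import *

class MultiTransformLabelList(LabelList):
    "A LabelList that returns K independently augmented copies of each image (K=2 here)."
    K = 2

    def __getitem__(self, idxs):
        if isinstance(idxs, Integral):
            x, y = self.x[idxs], self.y[idxs]
            # apply_tfms re-randomizes the transform parameters on every call,
            # so each copy below is a different augmentation of the same image.
            xs = [x.apply_tfms(self.tfms, **self.tfmargs) for _ in range(self.K)]
            return xs, y
        return self.new(self.x[idxs], self.y[idxs])

# Note: a matching collate function is also needed so the K copies of each image
# are stacked into a single (batch, K, channels, height, width) tensor.
```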

Fastai’s data block api allows for flexibly loading, labeling, and collating nearly any form of data.

However, it doesn’t have a method to grab a subset of one folder and the entirety of another folder, which is required here.

Thus, we’ll subclass the ImageList class and add a custom method.
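A sketch of what that subclass might look like, assuming the CIFAR10 folder layout produced by fastai's untar_data(URLs.CIFAR) (train/ and test/ subfolders); the method name filter_train and the seed are made up for illustration:

```python
import random
from fastai.vision import *

class MixMatchImageList(ImageList):
    "ImageList with a helper that keeps a random subset of 'train' plus every 'test' image."
    def filter_train(self, num_labeled, seed=1234):
        train_paths = [p for p in self.items if 'train' in Path(p).parts]
        random.seed(seed)
        keep = set(random.sample(train_paths, num_labeled))
        # Keep every test image, and only the sampled subset of the training images.
        return self.filter_by_func(lambda p: 'train' not in Path(p).parts or p in keep)

# Hypothetical usage:
# path = untar_data(URLs.CIFAR)
# il = MixMatchImageList.from_folder(path).filter_train(500)
```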

We’ll use fastai’s get_transforms method with no arguments to use the default image transforms; these are flipping around the center y axis, rotation up to 10 degrees, zooming, lighting change, and warping.

Fastai's transform system automatically randomizes the exact parameters of each transform when applied.

Mixup

Mixup was first introduced by Zhang, Cisse, Dauphin, and Lopez-Paz [4] in 2018 and falls into the category of general or traditional regularization.

Instead of passing single images to the model, Mixup performs a linear interpolation between two separate training images and passes that to the model.

The one hot encoded labels of the images are also interpolated, using the same λ coefficient as the images.

That coefficient, λ, is randomly drawn from a beta distribution parameterized by α.

Typically, α needs to be tuned to the dataset.

At small values of α, the beta distribution has most of its weight in the tails, close to 0 or 1.

As α increases, the distribution becomes uniform and then increasingly spiked around 0.5.

Thus, α can be seen as controlling the intensity of the mixup; small values result in only a small amount of mixup, while larger values bias towards maximum mixup (50/50).

At the extremes, α=0 results in no mixup at all, and as α→∞ the beta distribution approaches a Dirac delta centered at 0.5.

The authors recommend starting with a value of 0.75, which as seen below still has most of the weight in the tails.

The paper makes one modification to the original method, which is to set λ to max(λ,1-λ); this biases the mixup towards the original image.
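In code, mixup with that modification might look like the following sketch; the tensor shapes and the single per-batch λ are assumptions, not the exact repository implementation:

```python
import numpy as np
import torch

def mixup(x1, y1, x2, y2, alpha=0.75):
    """Mix two batches: x* are image tensors of shape (B, C, H, W),
    y* are one-hot / soft label tensors of shape (B, num_classes)."""
    lam = float(np.random.beta(alpha, alpha))  # draw the mixing coefficient from Beta(alpha, alpha)
    lam = max(lam, 1 - lam)                    # the MixMatch tweak: bias towards the first batch
    x = lam * x1 + (1 - lam) * x2              # interpolate the inputs
    y = lam * y1 + (1 - lam) * y2              # interpolate the labels with the same coefficient
    return x, y
```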

The beta distribution at various values of α

Sharpening

The authors sharpen the model’s predictions on the unlabeled data as a form of entropy minimization, using Sharpen(p, T)_i = p_i^(1/T) / Σ_j p_j^(1/T), where p is the averaged prediction over the augments and T is the temperature.

If the temperature T < 1, the effect is to make the predictions more certain, and as T drops towards zero the predictions approach a one-hot distribution (see figure below).

This relatively simple step, which involves no learned parameters, turns out to be incredibly important to the algorithm.

In an ablation study, the paper reports an accuracy reduction of over 16% when removing the sharpening step (setting T to 1).
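In code, the sharpening step is just a temperature rescaling followed by renormalization; a minimal sketch:

```python
import torch

def sharpen(p, T=0.5):
    """Sharpen a batch of categorical distributions.
    p: probabilities of shape (B, num_classes); T < 1 lowers their entropy."""
    p = p ** (1 / T)                        # raise each probability to the power 1/T
    return p / p.sum(dim=1, keepdim=True)   # renormalize so each row sums to 1
```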

Sharpening a random distribution

The idea behind entropy minimization in semi-supervised learning is that the decision boundary of the classifier should not pass through high density regions of the data space.

If this were the case, the boundary would split data that are very close together.

In addition, small perturbations would result in large changes in predictions.

As predictions near the decision boundary are more uncertain, entropy minimization seeks to make the model more confident in its predictions thus moving the boundary away from the data.

While other approaches [3] add an entropy term to the loss, MixMatch directly lowers the entropy of the unlabeled targets via the equation above.

As an example of this technique, let’s try a classification problem that’s simpler and easier to visualize than CIFAR — MNIST.

We’ll still take 500 random examples as the labeled training set and reserve the rest as the unlabeled set.

The full images are used for training, but we’ll also reduce each image to two dimensions using tSNE for visualization.
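The dimensionality reduction itself is only a couple of lines with scikit-learn; `images` here is a hypothetical (N, 28, 28) array of the MNIST training images:

```python
from sklearn.manifold import TSNE

X = images.reshape(len(images), -1)                 # flatten each 28x28 image to 784 values
embedding = TSNE(n_components=2).fit_transform(X)   # (N, 2) coordinates, used only for plotting
```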

We train in a semi-supervised manner, following the same approach as MixMatch with regard to the unlabeled data: the model itself is used to generate pseudo-labels.

The model consists of just two convolution layers and a linear head.

No mixup or data augmentation is used, so we can isolate the effects of entropy minimization.

The loss function is also largely the same as MixMatch, using cross-entropy for the labeled data and mean squared error for the unlabeled data (see the loss section below for the rationale behind this).
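The post only specifies the model as "two convolution layers and a linear head"; one plausible sketch of such a network (the channel counts are assumptions):

```python
import torch.nn as nn

mnist_model = nn.Sequential(
    nn.Conv2d(1, 16, kernel_size=3, stride=2, padding=1), nn.ReLU(),   # 28x28 -> 14x14
    nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1), nn.ReLU(),  # 14x14 -> 7x7
    nn.Flatten(),
    nn.Linear(32 * 7 * 7, 10),                                         # linear head over 10 digits
)
```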

The upper image shows training without sharpening; in the lower image the pseudo-labels were sharpened with T=0.5.

Training each for ten epochs, the unsharpened model has a test accuracy of 80.1%, and the sharpened model has an accuracy of 90.7%.

In the images below, colors correspond to predicted class, and marker size is inversely proportional to prediction confidence (smaller markers are more confident).

As shown by the marker sizes, the unsharpened model has a lot of uncertainty, especially around the edges of the clusters, while the sharpened model is much more confident in its predictions.

The effect of sharpening on the semi-supervised training of MNIST.

Images in MNIST were reduced to two dimensions using tSNE.

Colors correspond to predicted class, and marker size is inversely proportional to prediction confidence (smaller markers are more confident).

The upper image depicts training with T=1, and the lower image with T=0.5.

The MixMatch Algorithm

Now with all the pieces in place, the full algorithm can be implemented.

Here are the steps for a single training iteration (a condensed code sketch follows the list):

1. Supply a batch of labeled data with its labels, and a batch of unlabeled data.
2. Augment the labeled batch to produce a new training batch.
3. Augment each image in the unlabeled batch K times, to produce a total of batch size * K new unlabeled examples.
4. For each original image in the unlabeled batch, pass the K augmented versions to the model. Average the model’s predictions across the augments to produce a single pseudo-label for the augmented images.
5. Sharpen the pseudo-labels.
6. The augmented labeled dataset and its labels form set X. The augmented unlabeled data and its (predicted) labels form set U.
7. Concatenate sets U and X into set W. Shuffle W.
8. Form set X’ by applying mixup to set X and |X| examples from W.
9. Form set U’ by applying mixup to set U and the examples in W that were not used in step 8.
10. Pass sets X’ (labeled mixup) and U’ (unlabeled mixup) to the model, and compute the loss using the corresponding mixed-up labels.
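Here is a condensed sketch of those steps as a single function, reusing the hypothetical sharpen and mixup helpers sketched earlier; the tensor layouts (in particular the (B, K, C, H, W) unlabeled batch) are assumptions, not the exact repository code:

```python
import torch
import torch.nn.functional as F
# `sharpen` and `mixup` are the helper sketches defined earlier in this post.

def mixmatch_batch(model, x_l, y_l, u_augs, T=0.5, alpha=0.75):
    """x_l: (B, C, H, W) augmented labeled images; y_l: (B,) integer labels;
    u_augs: (B, K, C, H, W) holding K augmentations of each unlabeled image."""
    B, K = u_augs.shape[:2]
    x_u = u_augs.reshape(B * K, *u_augs.shape[2:])

    # Steps 4-5: average the model's predictions over the K augments, then sharpen.
    with torch.no_grad():
        p = F.softmax(model(x_u), dim=1)
    num_classes = p.shape[1]
    y_u = sharpen(p.reshape(B, K, num_classes).mean(dim=1), T)
    y_u = y_u.repeat_interleave(K, dim=0)      # same pseudo-label for all K copies of an image
    y_l = F.one_hot(y_l, num_classes).float()

    # Steps 6-7: concatenate X and U into W, then shuffle.
    w_x = torch.cat([x_l, x_u])
    w_y = torch.cat([y_l, y_u])
    perm = torch.randperm(len(w_x))
    w_x, w_y = w_x[perm], w_y[perm]

    # Steps 8-9: mixup X with the first |X| entries of W, and U with the rest.
    x_mix, y_x_mix = mixup(x_l, y_l, w_x[:len(x_l)], w_y[:len(x_l)], alpha)
    u_mix, y_u_mix = mixup(x_u, y_u, w_x[len(x_l):], w_y[len(x_l):], alpha)
    return x_mix, y_x_mix, u_mix, y_u_mix
```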

Model

We will use a wide resnet model with 28 layers and a widening factor of 2 to match the paper.

We’ll use fastai’s included WRN implementation and match the architecture used in the paper.
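Assuming fastai v1's WideResNet constructor (the argument names below follow that implementation and are an assumption about the library version), a 28-layer, width-2 network is roughly:

```python
from fastai.vision.models.wrn import WideResNet

# Depth of a wide resnet is 6*N + 4, so N=4 gives 28 layers; k=2 is the widening factor.
model = WideResNet(num_groups=3, N=4, num_classes=10, k=2)
```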

Loss

With data and model in hand, we’ll now implement the final piece required for training — the loss function.

The loss function is the summation of two terms, the labeled and unlabeled losses.

The labeled loss uses standard cross entropy; however the unlabeled loss function is the l2 loss instead.

This is because the l2 loss is much less sensitive to very incorrect predictions.

Cross entropy loss is unbounded, and as the model’s predicted probability of the correct class goes to zero cross entropy goes to infinity.

However, with l2 loss, since we are working with probabilities, the worst case is that the model predicts 0 when the target is 1 or vice versa; this results in a loss of 1.

With the unlabeled targets coming from the model itself, the algorithm doesn’t want to penalize incorrect predictions too harshly.

The parameter λ (l in the code, since lambda is reserved in Python) controls the balance between the two terms.

We’ll make one slight departure from the paper by linearly ramping up the weight of the unlabeled loss over the first 3000 iterations (roughly 10 epochs).

Before applying this rampup, there were difficulties training the model; and we found the accuracy would increase very slowly in early epochs.

Since the predicted labels at the start of training are essentially random, it makes sense to delay the application of unlabeled loss.

By the time the weight of the unlabeled loss becomes significant, the model should be making reasonably good predictions.
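A sketch of the combined loss with this ramp-up, assuming the model's outputs for the labeled and unlabeled mixup batches arrive concatenated along the batch dimension; the unlabeled weight is left as an argument rather than asserting the value used in the official implementation:

```python
import torch.nn as nn
import torch.nn.functional as F

class MixMatchLoss(nn.Module):
    "Cross entropy on the labeled mixup plus a ramped, weighted L2 loss on the unlabeled mixup."
    def __init__(self, unlabeled_weight, rampup_iters=3000):
        super().__init__()
        self.l = unlabeled_weight                 # `l` because `lambda` is reserved in Python
        self.rampup_iters, self.iter = rampup_iters, 0

    def forward(self, preds, targets_l, targets_u=None):
        if targets_u is None:                     # validation: plain labels, plain cross entropy
            return F.cross_entropy(preds, targets_l)
        self.iter += 1
        ramp = min(1.0, self.iter / self.rampup_iters)    # linear ramp over the first 3000 iters
        n_l = targets_l.shape[0]
        preds_l, preds_u = preds[:n_l], preds[n_l:]       # model output for X' then U'
        # Labeled term: cross entropy against the soft, mixed-up labels.
        loss_l = -(targets_l * F.log_softmax(preds_l, dim=1)).sum(dim=1).mean()
        # Unlabeled term: mean squared error between probabilities and pseudo-labels.
        loss_u = F.mse_loss(F.softmax(preds_u, dim=1), targets_u)
        return loss_l + ramp * self.l * loss_u
```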

Training

Before training, let’s review the hyperparameters that have been introduced.

The hyperparameters introduced so far are T (the sharpening temperature), K (the number of augmentations per unlabeled image), α (the beta distribution parameter for mixup), and λ (the weight of the unlabeled loss). The authors of the paper claim that T and K should be relatively constant across most datasets, while α and λ need to be tuned per dataset.

We’ll use the same hyperparameter values as the paper’s official implementation.

One implementation detail: the paper mentions that instead of learning rate annealing, it maintains a second model updated with an exponential moving average of the training model’s parameters.

This is yet another form of regularization but is not essential to the algorithm.

For those interested, there is code for training with an EMA model in the repository.

However, there was not significant benefit over learning rate scheduling, and in the name of simplicity we’ll forgo EMA and use fastai’s implementation of the one cycle policy to schedule the learning and momentum rates.

We’ll use fastai’s callback system to write a method which handles most of the MixMatch steps.

This method takes in batches from the labeled and unlabeled sets, gets the predicted labels, and then performs mixup.
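A sketch of such a callback using fastai v1's LearnerCallback API; both the dict return from on_batch_begin and the mixmatch_batch helper are assumptions (about the fastai version and about the earlier sketch in this post, respectively):

```python
import torch
from fastai.basic_train import LearnerCallback
# `mixmatch_batch` is the sketch from the algorithm section above.

class MixMatchCallback(LearnerCallback):
    "On each training batch, pulls an unlabeled batch and swaps in the MixMatch mixup data."
    def __init__(self, learn, unlabeled_dl, T=0.5, alpha=0.75):
        super().__init__(learn)
        self.unlabeled_dl, self.T, self.alpha = unlabeled_dl, T, alpha

    def on_train_begin(self, **kwargs):
        self.u_iter = iter(self.unlabeled_dl)

    def on_batch_begin(self, train, last_input, last_target, **kwargs):
        if not train: return
        try:
            u_augs, _ = next(self.u_iter)          # (B, K, C, H, W); ignore the dummy labels
        except StopIteration:
            self.u_iter = iter(self.unlabeled_dl)
            u_augs, _ = next(self.u_iter)
        x_mix, y_x, u_mix, y_u = mixmatch_batch(self.learn.model, last_input, last_target,
                                                u_augs, self.T, self.alpha)
        # Feed both mixed batches to the model; MixMatchLoss splits the output back apart.
        return {'last_input': torch.cat([x_mix, u_mix]),
                'last_target': (y_x, y_u)}
```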

A fastai Learner object contains the dataloaders and the model and is responsible for executing the training loop.

It also has a lot of utility functions, such as learning rate finding and prediction interpretation.

An epoch in this implementation is one pass through the entire unlabeled dataset.
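Wiring everything together might look roughly like this; labeled_data, unlabeled_dl, and w_u are placeholders, and the epoch count and learning rate are illustrative rather than the settings used for the reported results:

```python
from functools import partial
from fastai.basic_train import Learner

learn = Learner(labeled_data, model,
                loss_func=MixMatchLoss(unlabeled_weight=w_u),  # w_u: value from the official implementation
                callback_fns=[partial(MixMatchCallback, unlabeled_dl=unlabeled_dl, T=0.5, alpha=0.75)])
learn.fit_one_cycle(50, max_lr=2e-3)   # one cycle policy schedules learning rate and momentum
```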

Results

For reference, these tests were run on a Google Compute Engine virtual machine with 16 CPUs and a single P100 GPU.

The first step is to establish some baselines so that MixMatch’s performance can be compared.

First, we’ll try the fully supervised case with all 50k training images.

Next we’ll train on just the 500 labeled images, with no unsupervised component.

Finally, we’ll train with MixMatch, using the learner defined in the previous section.

The full details of these runs can be found in the notebook in this post’s repository.

Conclusion

MixMatch clearly boasts impressive performance, but the downside is the additional time cost in training.

Compared to the fully supervised case, training MixMatch takes approximately 2.5x longer.

Some of this may be due to inefficiencies in the implementation, but generating multiple augmentations and then obtaining model predictions for labels has a significant cost, especially in the single-GPU case.

We trained using the official Tensorflow implementation for comparison and verified that MixMatch takes a long time to fully converge; over twelve hours of training resulted in an error rate several percent higher than the one reported in the paper.

It would take nearly 36 hours of training on the P100 setup to match their results fully.

However, a few hours of training will achieve the vast majority of accuracy improvement, with the final few percent taking most of the total training time.

While augmentation and sharpening are hugely beneficial, the paper’s ablation study shows that the single most important component, error wise, is MixUp.

This is also the most mysterious component in terms of why it works so well — why should enforcing linear behavior in predictions between images help the model? Certainly it reduces memorization of the training data, but so does data augmentation, and not to nearly the same effect in this case.

Even the original MixUp paper only provides informal arguments as to its efficacy; from that paper: “We argue that this linear behavior reduces the amount of undesirable oscillations when predicting outside the training examples. Also, linearity is a good inductive bias from the perspective of Occam’s razor, since it is one of the simplest possible behaviors.” [4]

Other researchers have expanded upon the idea, for example by mixing up intermediate states instead of the input [6], or by using a neural network instead of the beta distribution to generate the mixup coefficient [5].

However, I am unable to find a solid theoretical justification; this is yet another technique that falls into the ‘it just works’ category.

It would be difficult to draw a biological analogy — humans hardly learn a concept by blending it with an unrelated concept.

MixMatch is a hugely promising approach, with applicability to domains beyond computer vision.

It will be interesting to see it both understood and applied further.

References

[1]: Berthelot, David, Nicholas Carlini, Ian Goodfellow, Nicolas Papernot, Avital Oliver, and Colin Raffel. “MixMatch: A Holistic Approach to Semi-Supervised Learning.” arXiv:1905.02249 [cs, stat], May 6, 2019. http://arxiv.org/abs/1905.02249.

[2]: Tarvainen, Antti, and Harri Valpola. “Mean Teachers Are Better Role Models: Weight-Averaged Consistency Targets Improve Semi-Supervised Deep Learning Results.” arXiv:1703.01780 [cs, stat], March 6, 2017. http://arxiv.org/abs/1703.01780.

[3]: Miyato, Takeru, Shin-ichi Maeda, Masanori Koyama, and Shin Ishii. “Virtual Adversarial Training: A Regularization Method for Supervised and Semi-Supervised Learning.” arXiv:1704.03976 [cs, stat], April 12, 2017. http://arxiv.org/abs/1704.03976.

[4]: Zhang, Hongyi, Moustapha Cisse, Yann N. Dauphin, and David Lopez-Paz. “Mixup: Beyond Empirical Risk Minimization.” arXiv:1710.09412 [cs, stat], October 25, 2017. http://arxiv.org/abs/1710.09412.

[5]: Guo, Hongyu, Yongyi Mao, and Richong Zhang. “MixUp as Locally Linear Out-Of-Manifold Regularization,” n.d.

[6]: Verma, Vikas, Alex Lamb, Christopher Beckham, Amir Najafi, Ioannis Mitliagkas, Aaron Courville, David Lopez-Paz, and Yoshua Bengio. “Manifold Mixup: Better Representations by Interpolating Hidden States,” June 13, 2018. https://arxiv.org/abs/1806.05236v7.
