Advances in few-shot learning: reproducing results in PyTorch

For Omniglot it will have shape (n_support + n_query, 1, 28, 28). The math in the previous post is for a single query sample, but Matching Networks are in fact trained with a batch of query samples of size q_queries * k_way. I was unable to reproduce the results of this paper using cosine distance but was successful when using l2 distance. Seeing as the choice of distance is not key to the paper, and results are very good using l2 distance, I decided to spare myself that debugging effort.

Prototypical Networks

In Prototypical Networks, Snell et al. use a compelling inductive bias motivated by the theory of Bregman divergences to achieve impressive few-shot performance. The Prototypical Network algorithm can be summarised as follows:

1. Embed all query and support samples (line 36).
2. Calculate class prototypes by taking the mean of the embeddings of each class (line 48).
3. Predictions are a softmax over the distances between the query samples and the class prototypes (line 63).

I found this paper delightfully easy to reproduce, as the authors provided the full set of hyperparameters.

Model-Agnostic Meta-Learning (MAML)

This paper was the most difficult yet most rewarding to reproduce of the three in this article. The MAML algorithm can be summarised as follows:

1. For each n-shot task in a meta-batch of tasks, create a new model using the weights of the base model, AKA the meta-learner (line 79).
2. Update the weights of the new model using the loss from the samples in the task by stochastic gradient descent (lines 81–92).
3. Calculate the loss of the updated model on some more data from the same task (lines 94–97).
4. If performing 1st order MAML, update the meta-learner weights with the gradient of the loss from step 3. If performing 2nd order MAML, calculate the derivative of this loss with respect to the original weights (lines 110+).

The biggest appeal of PyTorch is its autograd system. I had to learn a bit more about this system in order to calculate and apply parameter updates to the meta-learner, which I will now share with you.

1st Order MAML — gradient swapping

Typically when training a model in PyTorch you create an Optimizer object tied to the parameters of a particular model.

```python
from torch.optim import Adam

opt = Adam(model.parameters(), lr=0.001)
```

When opt.step() is called, the optimiser reads the gradients on the model parameters and calculates an update to those parameters. However, in 1st order MAML we're going to calculate the gradients using one model (the fast weights) and apply the update to a different model, i.e. the meta-learner. This means that when opt.step() is called, the gradients of the fast model will be used to update the meta-learner weights as desired.

2nd Order MAML — autograd issues

When making my first attempt at implementing MAML, I instantiated a new model object (a subclass of torch.nn.Module) and set the values of its weights equal to the meta-learner's weights. However, this makes it impossible to perform 2nd order MAML, as the weights of the fast model are disconnected from the weights of the meta-learner in the eyes of torch.autograd. What this means is that when I call optimiser.step() (line 140 in the gist), the autograd graph for the meta-learner weights is empty and no update is performed.

```python
# This didn't work, meta_learner weights remain unchanged
meta_learner = ModelClass()
opt = Adam(meta_learner.parameters(), lr=0.001)
task_losses = []
for x, y in meta_batch:
    fast_model = ModelClass()  # torch.autograd loses reference here!
```

I was, however, retaining the autograd graph of the loss on the query samples (line 97), but this was insufficient to perform a 2nd order update as the unrolled training graph was not created.

Training time was quite long (over 24 hours for the 5-way, 5-shot miniImageNet experiment) but in the end I had fairly good success reproducing results.

I hope that you've learnt something useful from this technical deep dive.
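To make the Prototypical Networks steps above concrete, here is a minimal sketch of the prototype-and-softmax classification step. The function name and shapes are my own for illustration, and the embedding network is assumed to have already produced the support and query embeddings; this is not the code from the gist.

```python
import torch

def proto_net_predictions(support, query, k_way, n_shot):
    """Prototypical Networks classification step (sketch).

    support: (k_way * n_shot, d) embeddings, grouped by class
    query:   (q, d) embeddings
    Returns (q, k_way) class probabilities.
    """
    # Class prototypes: mean embedding of each class's support samples
    prototypes = support.reshape(k_way, n_shot, -1).mean(dim=1)  # (k_way, d)
    # Squared l2 distance from every query sample to every prototype
    distances = torch.cdist(query, prototypes) ** 2              # (q, k_way)
    # Predictions are a softmax over the negative distances
    return (-distances).softmax(dim=1)
```

A query embedding closest to a class's prototype receives that class's highest probability, which is exactly the nearest-prototype behaviour the algorithm summary describes.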
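The gradient swapping described in the 1st order MAML section can be sketched as follows. This is a simplified single-task, single-inner-step version rather than the implementation from the gist: make_model and the tiny linear network are stand-in names for the real embedding model, and the inner learning rate of 0.01 is an arbitrary choice for illustration.

```python
import torch
from torch import nn
from torch.optim import Adam

torch.manual_seed(0)

def make_model():
    # Stand-in for the real few-shot model
    return nn.Linear(4, 2)

meta_learner = make_model()
opt = Adam(meta_learner.parameters(), lr=0.001)

# Inner loop: create a fast model from the meta-learner's weights and
# take one SGD step on the task's support samples
fast_model = make_model()
fast_model.load_state_dict(meta_learner.state_dict())

x, y = torch.randn(8, 4), torch.randint(0, 2, (8,))
support_loss = nn.functional.cross_entropy(fast_model(x), y)
grads = torch.autograd.grad(support_loss, fast_model.parameters())
with torch.no_grad():
    for p, g in zip(fast_model.parameters(), grads):
        p -= 0.01 * g  # plain SGD inner update

# Outer loop: loss of the updated fast model on the task's query samples
x_q, y_q = torch.randn(8, 4), torch.randint(0, 2, (8,))
query_loss = nn.functional.cross_entropy(fast_model(x_q), y_q)
query_loss.backward()  # gradients accumulate on the *fast* model

# The swap: copy the fast model's gradients onto the meta-learner, so
# that opt.step() updates the meta-learner using the fast gradients
for meta_p, fast_p in zip(meta_learner.parameters(),
                          fast_model.parameters()):
    meta_p.grad = fast_p.grad.clone()
opt.step()
```

The key point is the final loop: the optimiser is tied to the meta-learner's parameters, but the gradients it reads were computed on the fast model, which is the 1st order MAML update.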
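The autograd issue behind the 2nd order failure can also be illustrated with a toy functional example: if the fast weights are computed from the meta-learner weights inside the graph (with create_graph=True), the query loss can backpropagate through the unrolled inner step back to the original weights. The single weight matrix here is a stand-in for a real model, not the paper's architecture.

```python
import torch

torch.manual_seed(0)

# Meta-learner weights for a toy functional "model": logits = x @ w.
# Keeping the model functional means the fast weights below stay
# connected to w in the autograd graph, unlike a fresh nn.Module.
w = torch.randn(4, 2, requires_grad=True)

x, y = torch.randn(8, 4), torch.randint(0, 2, (8,))
support_loss = torch.nn.functional.cross_entropy(x @ w, y)

# create_graph=True records the inner gradient step itself, building
# the unrolled training graph needed for a 2nd order update
(grad_w,) = torch.autograd.grad(support_loss, w, create_graph=True)
fast_w = w - 0.01 * grad_w  # fast weights, still a function of w

x_q, y_q = torch.randn(8, 4), torch.randint(0, 2, (8,))
query_loss = torch.nn.functional.cross_entropy(x_q @ fast_w, y_q)
query_loss.backward()  # 2nd order gradients now reach w.grad
```

Had fast_w been a freshly instantiated module with copied values, as in the failed attempt above, w.grad would remain empty after backward() and no meta-update would occur.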
