Few Shot Learning with LSTMs for Image Classification
Summary: Few shot imputation from images
Who: Justin Sanders (jsander8), Gus VanNewkirk (gvannewk), AM (amadoff)
Final Write-up
https://docs.google.com/document/d/1NVUx_CA6yRacuCqD_Ykr-3Tg5vDAHwHg8cAOWtWUVew/edit?usp=sharing
Introduction:
Humans are very good at learning from very few examples: show a child a single picture of an animal and they’ll be able to recognize it again in the future.
Deep learning models, on the other hand, require a glut of data. In order to recognize images, a standard CNN needs hundreds if not thousands of labeled training examples to reach reasonable performance.
The goal of our project is to explore this problem of learning a new concept from relatively few examples. In particular, we hope to train a network that can take in a few examples of a never-before-seen class of image along with a test image, and output whether the test image belongs to the same class as the example images. Effectively, our model will be using the prior knowledge that it’s learned about images in general to create a binary image classifier without the need to see hundreds of images of the new class.
Related Work:
The following blog post helped inspire the idea of tackling few-shot imputation, and gave a background introduction to what the task entailed: https://sorenbouma.github.io/blog/oneshot/
The following review paper helped us understand the task further: https://arxiv.org/pdf/1904.05046.pdf
Additionally, the following paper from Google DeepMind describes their approach to the problem of n-way k-shot learning: https://arxiv.org/pdf/1606.04080v2.pdf
Data:
We will potentially first perform a proof of concept using the MNIST dataset. To do this, we will train the network on eight of the ten digit classes (for example, 0-7), then supply k examples each of the two held-out digits (8 and 9) and see whether the network is able to recognize these digits based on the few examples it was shown.
After using the MNIST results to find areas of improvement in our model, we will move on to CIFAR-10 and CIFAR-100. For this task, we will train using all of the classes from CIFAR-100, and then supply k-shot examples from CIFAR-10. None of the class labels in CIFAR-10 appear in CIFAR-100, so these additional examples will represent classes that are new to the network. To evaluate the model, we will then test the network’s accuracy at recognizing these 10 never-before-seen classes.
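The held-out-class setup described above could be sketched as follows. This is a minimal illustration, not our actual preprocessing code: the function names `split_by_class` and `sample_support` are hypothetical, the labels are random stand-ins rather than real MNIST labels, and we assume the held-out digits are labeled 8 and 9.

```python
import numpy as np

def split_by_class(labels, novel_classes):
    """Return index arrays for base (training) and novel (held-out) examples."""
    labels = np.asarray(labels)
    novel_mask = np.isin(labels, list(novel_classes))
    return np.flatnonzero(~novel_mask), np.flatnonzero(novel_mask)

def sample_support(labels, cls, k, rng):
    """Pick k distinct example indices belonging to class `cls`."""
    labels = np.asarray(labels)
    candidates = np.flatnonzero(labels == cls)
    return rng.choice(candidates, size=k, replace=False)

rng = np.random.default_rng(0)
labels = rng.integers(0, 10, size=1000)  # stand-in for MNIST labels 0-9

# Assumption: digits 8 and 9 are the classes withheld from training.
base_idx, novel_idx = split_by_class(labels, novel_classes={8, 9})

# k = 5 examples of a held-out digit, shown to the model only at test time.
support_idx = sample_support(labels, cls=9, k=5, rng=rng)
```

The same split would apply to the CIFAR setup, with CIFAR-100 playing the role of the base classes and CIFAR-10 the novel ones.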
Methodology:
Our model first creates vector embeddings for images using either our own CNN (for MNIST) or a pretrained CNN (for CIFAR), and then combines the embeddings of multiple images of one class using an LSTM. It then compares this merged embedding to the embedding of a test image, using either a linear layer or a standard similarity measure (e.g., cosine similarity), to classify whether or not the test image is of the same class. To train this model, we will create mini-batches built around k example images of the same class, and feed the model batch_size random images with binary labels representing whether or not each image belongs to the class of the k examples.
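One way the training batches described above might be assembled is sketched here. The function name `sample_episode` is hypothetical, the labels are random stand-ins, and drawing the query images uniformly at random is our reading of "random images" rather than a settled design.

```python
import numpy as np

def sample_episode(labels, k, batch_size, rng):
    """Build one training episode.

    Returns indices of k support images from one class, indices of
    batch_size query images, and binary labels marking which queries
    share the support class.
    """
    labels = np.asarray(labels)
    target = rng.choice(np.unique(labels))
    positives = np.flatnonzero(labels == target)
    support = rng.choice(positives, size=k, replace=False)

    # Queries are drawn from every image that is not already a support image.
    pool = np.setdiff1d(np.arange(len(labels)), support)
    queries = rng.choice(pool, size=batch_size, replace=False)

    y = (labels[queries] == target).astype(np.float32)
    return support, queries, y

rng = np.random.default_rng(0)
labels = rng.integers(0, 8, size=1000)  # stand-in labels for 8 base classes
support, queries, y = sample_episode(labels, k=5, batch_size=32, rng=rng)
```

With uniform sampling most queries are negatives (about one in eight is positive here), so balancing positives and negatives may be worth considering when building real batches.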
We think this architecture makes sense for the problem, because ideally the LSTM will merge the embeddings from the k example images into a single embedding that represents only the most important/relevant features for identifying that class of image.
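To make the merging step concrete, here is a minimal NumPy sketch of the forward pass only: a single-layer LSTM reads the k example embeddings one at a time, and its final hidden state is compared to the test embedding with cosine similarity. The weights are random and untrained, the names `lstm_merge` and `cosine_similarity` are hypothetical, and we assume the LSTM hidden size equals the embedding size so the two vectors can be compared directly. A real implementation would use a deep learning framework and learn the weights.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def lstm_merge(embeddings, W, U, b):
    """Run a single-layer LSTM over k embeddings; return the final hidden state.

    embeddings: (k, d), W: (4h, d), U: (4h, h), b: (4h,)
    """
    hidden = U.shape[1]
    h = np.zeros(hidden)
    c = np.zeros(hidden)
    for x in embeddings:
        z = W @ x + U @ h + b
        i, f, g, o = np.split(z, 4)   # input, forget, candidate, output
        i, f, o = sigmoid(i), sigmoid(f), sigmoid(o)
        g = np.tanh(g)
        c = f * c + i * g             # update the cell state
        h = o * np.tanh(c)            # expose the gated hidden state
    return h

def cosine_similarity(a, b, eps=1e-8):
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b) + eps))

rng = np.random.default_rng(0)
d = hidden = 16                       # assumption: hidden size == embedding size
W = rng.normal(scale=0.1, size=(4 * hidden, d))
U = rng.normal(scale=0.1, size=(4 * hidden, hidden))
b = np.zeros(4 * hidden)

support_emb = rng.normal(size=(5, d)) # stand-in for CNN embeddings of k=5 examples
query_emb = rng.normal(size=d)        # stand-in for the test image embedding

merged = lstm_merge(support_emb, W, U, b)
score = cosine_similarity(merged, query_emb)
```

Note that an LSTM is sensitive to the order of its inputs, so the merged embedding can change if the same k examples are presented in a different order.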
Metrics:
This is very dataset dependent. For both MNIST and CIFAR-100, we expect accuracy better than that of a naive implementation; however, we expect the accuracy on MNIST to be much higher than that on CIFAR-100. Our baseline goal is to achieve accuracy on MNIST that is better than that of a naive approach to few-shot learning, such as nearest-neighbor classification based on Euclidean distance. Our target goal is to have our method outperform a one-shot technique that compares the embeddings of a single example and the test image based on cosine similarity, meaning that our model needs to actually combine the embeddings from all of the examples it’s given in a way that improves performance. Our stretch goal would be to achieve accuracy that approaches the accuracy achieved by state-of-the-art methods, albeit on a smaller dataset.
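The two comparison points mentioned above could look roughly like this. It is a sketch on toy 2-D vectors rather than real image embeddings; the function names and the similarity threshold of 0.5 are assumptions for illustration, not values we have tuned.

```python
import numpy as np

def nearest_neighbor_predict(example_vecs, example_labels, query_vec):
    """Naive baseline: return the label of the closest example (Euclidean)."""
    dists = np.linalg.norm(example_vecs - query_vec, axis=1)
    return example_labels[np.argmin(dists)]

def one_shot_cosine_predict(example_vec, query_vec, threshold=0.5, eps=1e-8):
    """One-shot baseline: same class if cosine similarity meets the threshold."""
    sim = example_vec @ query_vec / (
        np.linalg.norm(example_vec) * np.linalg.norm(query_vec) + eps
    )
    return bool(sim >= threshold)

# Toy data: two labeled examples and one query.
examples = np.array([[0.0, 0.0], [10.0, 10.0]])
example_labels = np.array([0, 1])
query = np.array([9.0, 8.0])

nn_label = nearest_neighbor_predict(examples, example_labels, query)

same = one_shot_cosine_predict(np.array([1.0, 0.0]), np.array([1.0, 0.1]))
different = one_shot_cosine_predict(np.array([1.0, 0.0]), np.array([0.0, 1.0]))
```

On this toy data the query lies closest to the second example, and the cosine check accepts the nearly parallel pair while rejecting the orthogonal one.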
Ethics:
One issue present in the problem space of recognizing images based on few examples is the potential application to data mining and surveillance. One could imagine models similar to ours being used to analyze social media posts and figure out where users are traveling, what they’re purchasing, and other information that people would prefer to keep private. The ability of a model to learn from a few examples makes it even easier for companies to deploy. Want to serve targeted ads to people who went to the Beyonce concert last night? Show a model a couple of images from the event and it can recognize other posts made from the venue. Without the requirement of large labeled datasets, anyone can quickly adapt a DL model to a niche classification task, further encroaching on people's online privacy.
Additionally, because it simplifies the classification problem and reuses other methods, our methodology would be a much more energy-efficient alternative to training a massive CNN. Because this methodology generalizes to any classification problem, few-shot learning could save enormous amounts of energy if applied to real-world classification problems. Of course, not every classification problem can usefully be reduced to binary classification; however, this method can be generalized to classify an image into one of “n” classes that have each been seen only “k” times. The flexibility of this approach makes it more likely to be easily implemented to solve a variety of real-world problems, and in doing so help the environment by using less energy than alternative methods.
Division of Labor:
As preprocessing will likely be the bulk of the work and is not super exciting, we think we should share it equally. For the most near-term tasks, we have the following division of labor:
- Change MNIST data into something that generates batches appropriate for our model: Gus
- Find pretrained CNN for creating embeddings for CIFAR dataset: Alex
- Start preprocessing CIFAR data: Justin
Link to Week 1 reflections:
https://docs.google.com/document/d/1pQovOJYoXx6LML2gR7a-A5lq6N_PB41gdRPzLxeWE9o/edit


