Inspiration

We were inspired by the Quick, Draw! dataset released by Google.

What it does

Our model is trained on a new dataset that is similar to the Quick, Draw! dataset, but with only 15 classes and far fewer data points. Given a sequence of strokes, it predicts what the user is trying to draw.
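Because our models are image CNNs, a stroke sequence has to be rendered to pixels before it can be classified. The following is a minimal sketch, not necessarily how our pipeline renders drawings: it assumes each stroke is a pair of x and y coordinate lists on a 256x256 canvas (modeled on the simplified Quick, Draw! format), and the names `rasterize` and `rasterize_prefixes` are hypothetical. Rendering every prefix of the stroke list gives the partially drawn images a model would see while the user is still drawing.

```python
from typing import List, Sequence, Tuple

# Assumed stroke format (modeled on the simplified Quick, Draw! data):
# each stroke is (xs, ys), two equal-length coordinate lists in [0, 255].
Stroke = Tuple[Sequence[int], Sequence[int]]
Image = List[List[int]]


def draw_line(img: Image, x0: int, y0: int, x1: int, y1: int) -> None:
    """Set every pixel on the segment (x0, y0)-(x1, y1) with Bresenham's algorithm."""
    dx, dy = abs(x1 - x0), -abs(y1 - y0)
    sx = 1 if x0 < x1 else -1
    sy = 1 if y0 < y1 else -1
    err = dx + dy
    while True:
        img[y0][x0] = 1
        if x0 == x1 and y0 == y1:
            break
        e2 = 2 * err
        if e2 >= dy:  # step in x
            err += dy
            x0 += sx
        if e2 <= dx:  # step in y
            err += dx
            y0 += sy


def rasterize(strokes: Sequence[Stroke], size: int = 64, canvas: int = 256) -> Image:
    """Render strokes onto a size x size binary grid (1 = ink, 0 = background)."""
    img = [[0] * size for _ in range(size)]
    scale = (size - 1) / (canvas - 1)  # map canvas coordinates onto the grid
    for xs, ys in strokes:
        pts = [(round(x * scale), round(y * scale)) for x, y in zip(xs, ys)]
        if len(pts) == 1:  # a single tap still leaves a dot
            x, y = pts[0]
            img[y][x] = 1
        for (x0, y0), (x1, y1) in zip(pts, pts[1:]):
            draw_line(img, x0, y0, x1, y1)
    return img


def rasterize_prefixes(strokes: Sequence[Stroke], size: int = 64) -> List[Image]:
    """One image per prefix: after stroke 1, after strokes 1-2, and so on."""
    return [rasterize(strokes[:k], size) for k in range(1, len(strokes) + 1)]
```

A real pipeline would more likely draw anti-aliased, thicker lines (for example with an imaging library) before resizing to the input size each pretrained network expects; the pure-Python version here only shows the idea.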

How we built it

We trained our models using PyTorch and Torchvision. We opted not to use a model pretrained on Quick, Draw!, instead fine-tuning models that were pretrained on the more general ImageNet dataset. Specifically, we use 4 different CNN architectures: RegNet, ResNeXt, ConvNeXt, and EfficientNetV2. These are moderately large models with 15,296,552, 83,455,272, 28,589,128, and 21,458,488 parameters, respectively. During evaluation, we run each of the 4 models on the input, average their predicted class probabilities, and select our final output from the averaged distribution.

To fine-tune such large models on our small dataset, we used a variety of data augmentation procedures such as image rotations, flipping, blurring, and deformations. Furthermore, since the guessing process stops as soon as we guess correctly, any class we have already guessed for an image must be wrong, so we prevent the model from predicting the same class twice for a given image. We also generate incomplete images from the series of strokes that make up each drawing in order to improve our model's "speed", i.e., its ability to predict the correct class in as few strokes as possible. Finally, we used early stopping based on performance on a held-out validation set to select the best model checkpoint and prevent overfitting.
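The ensembling and no-repeat rule can be sketched in a few lines of plain Python. This is an illustration under stated assumptions, not code from our project: each model's output is taken to be a list of softmax probabilities over the classes, the game is assumed to allow one guess per stroke, and the helpers `average_probs`, `next_guess`, and `guess_until_correct` are hypothetical names.

```python
from typing import List, Optional, Sequence, Set

# Assumption: each model returns a list of softmax probabilities, one per class.
Probs = Sequence[float]


def average_probs(per_model: Sequence[Probs]) -> List[float]:
    """Ensemble by averaging the class-probability vectors of all models."""
    n = len(per_model)
    return [sum(col) / n for col in zip(*per_model)]


def next_guess(per_model: Sequence[Probs], already_guessed: Set[int]) -> Optional[int]:
    """Most likely class under the ensemble, skipping classes already guessed."""
    avg = average_probs(per_model)
    candidates = [c for c in range(len(avg)) if c not in already_guessed]
    if not candidates:
        return None  # every class has already been tried
    return max(candidates, key=lambda c: avg[c])


def guess_until_correct(stream: Sequence[Sequence[Probs]], true_class: int) -> Optional[int]:
    """Hypothetical game loop: one guess per stroke. Returns the number of
    strokes needed for a correct guess, or None if it never happens."""
    guessed: Set[int] = set()
    for step, per_model in enumerate(stream, start=1):
        guess = next_guess(per_model, guessed)
        if guess is None:
            return None
        if guess == true_class:
            return step
        guessed.add(guess)  # a wrong guess is never repeated
    return None
```

For instance, if the averaged probabilities are [0.5, 0.3, 0.2] and class 0 has already been guessed wrong, `next_guess` returns class 1 instead of repeating class 0, even though the models' outputs have not changed.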

Challenges we ran into

We wanted to use a model pretrained on the Quick, Draw! dataset, but we could not find a good one implemented in PyTorch, and we had a lot of trouble getting TensorFlow/Keras to run on our machines. This is the main reason we chose models pretrained on ImageNet, although we also feel that ImageNet pretraining gives robust performance on a variety of downstream tasks. Second, we ran out of time before we could model the series of strokes with a recurrent architecture, as many previous works have done. There were many things we wanted to try with more time, but this was probably the most significant, since our model cannot capture the temporal relationship between different strokes.

Accomplishments that we're proud of

We are proud of how much we got done in 24 hours, and we think our data augmentation and ensembling improved performance considerably.

What we learned

We learned a lot about computer vision, transfer learning, and pretraining while working on this project. We also learned never to use TensorFlow/Keras.

What's next for Pictionary Plungers

Recurrent CNNs
