Inspiration

Recent advances in deep unsupervised learning make it possible to learn concise yet rich representations of images, audio, natural language, and more. Integrating these representations into sequential decision-making frameworks such as reinforcement learning is an essential step toward general-purpose agents that can robustly incorporate diverse, unstructured sources of data. We focus on the contextual bandit setting, a tractable special case of reinforcement learning that still has many real-world applications.

What it does

We build on recent work from Google Brain, Deep Bayesian Bandits Showdown. This paper (and its accompanying TensorFlow code) implements a simple MLP-based method that learns representations of hand-crafted context features from contextual bandit feedback.
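To make the decision step concrete, here is a minimal sketch of Thompson sampling with a per-arm Bayesian linear regression on top of fixed features (for example, the last hidden layer of a context-encoding MLP, or a pretrained embedding). It assumes a Gaussian prior on the weights and a known noise variance; the Showdown paper's linear methods instead use a normal-inverse-gamma prior so the noise variance is learned too. `LinearThompsonArm`, `choose_arm`, `prior_precision` and `noise_var` are hypothetical names, not part of the original code.

```python
import numpy as np


class LinearThompsonArm:
    """Bayesian linear regression for one arm.

    Assumes a Gaussian prior N(0, I / prior_precision) on the weights and a
    fixed, known observation noise variance (a simplifying assumption).
    """

    def __init__(self, dim, prior_precision=1.0, noise_var=0.25):
        self.noise_var = noise_var
        self.precision = prior_precision * np.eye(dim)  # posterior precision matrix
        self.b = np.zeros(dim)                          # running sum of x * r / noise_var

    def update(self, x, r):
        """Condition the posterior on one observed (context, reward) pair."""
        self.precision += np.outer(x, x) / self.noise_var
        self.b += x * r / self.noise_var

    def posterior_mean(self):
        return np.linalg.solve(self.precision, self.b)

    def sample_weights(self, rng):
        cov = np.linalg.inv(self.precision)
        cov = (cov + cov.T) / 2.0  # symmetrize to guard against round-off
        return rng.multivariate_normal(self.posterior_mean(), cov)


def choose_arm(arms, features, rng):
    """Thompson sampling: draw one weight vector per arm, then act greedily on the draws."""
    sampled_rewards = [features @ arm.sample_weights(rng) for arm in arms]
    return int(np.argmax(sampled_rewards))
```

Sampling from the posterior, rather than acting on its mean, is what drives exploration: arms with few observations have wide posteriors and occasionally produce the largest draw.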

Our contribution: we extended this work with a novel unsupervised representation learning step. Specifically, we pre-train an unsupervised model and use its learned embedding as the input to the context-encoding MLP. We re-implemented contextual bandit algorithms with deep Thompson sampling in PyTorch and tested our approach on several tasks, including the Mushroom dataset, MNIST, and polarized Yelp reviews.
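MNIST and Yelp are labelled classification datasets, so they must be turned into bandit problems before an agent can learn from them. A common construction, used for the classification benchmarks in Deep Bayesian Bandits Showdown and assumed here, treats each class as an arm and pays reward 1 only when the chosen arm matches the label. The sketch below wraps precomputed embeddings (e.g., VAE codes or averaged BERT vectors) this way; `ClassificationBandit` and its methods are hypothetical names.

```python
import numpy as np


class ClassificationBandit:
    """Wraps a labelled dataset of precomputed embeddings as a contextual bandit.

    Each class is an arm; pulling the arm that matches the label pays 1,
    any other arm pays 0.
    """

    def __init__(self, embeddings, labels, seed=0):
        self.embeddings = np.asarray(embeddings, dtype=float)
        self.labels = np.asarray(labels)
        self.num_arms = int(self.labels.max()) + 1
        # Present examples in a random order, one context per round.
        self.order = np.random.default_rng(seed).permutation(len(self.labels))
        self.t = 0

    def context(self):
        """Embedding for the current round (the agent never sees the label)."""
        return self.embeddings[self.order[self.t]]

    def pull(self, arm):
        """Return (reward, optimal_reward) for this round and advance to the next context."""
        label = self.labels[self.order[self.t]]
        self.t += 1
        return float(arm == label), 1.0
```

Because the optimal reward is always 1, the regret of each round is simply 1 minus the reward obtained.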

We validated our implementation on the Mushroom dataset. For MNIST (images), we trained a variational autoencoder; for Yelp (text), we used averaged BERT embeddings. We load and process our data using the PyTorch Hub and Torchvision Datasets APIs.
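"Averaged BERT embeddings" can be computed by mean-pooling BERT's per-token output vectors into one fixed-size vector per review. The write-up does not say which tokens were averaged; the sketch below assumes a masked mean that skips padding positions, operating on a `(batch, seq_len, hidden)` array of token vectors and a 0/1 attention mask. `mean_pool` is a hypothetical helper written in NumPy for portability; the same arithmetic applies unchanged to PyTorch tensors.

```python
import numpy as np


def mean_pool(token_embeddings, attention_mask):
    """Average token vectors per sequence, ignoring padding positions.

    token_embeddings: array of shape (batch, seq_len, hidden)
    attention_mask:   array of shape (batch, seq_len), 1 for real tokens, 0 for padding
    """
    mask = np.asarray(attention_mask, dtype=float)[..., None]   # (batch, seq_len, 1)
    summed = (np.asarray(token_embeddings, dtype=float) * mask).sum(axis=1)
    counts = np.clip(mask.sum(axis=1), 1e-9, None)              # avoid dividing by zero
    return summed / counts                                      # (batch, hidden)
```

Without the mask, padded positions would pull every short review's embedding toward the padding vector, so reviews of different lengths would not be comparable.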

How we built it

We ported a contextual bandit model from TensorFlow to PyTorch. Then we took real-world data sources (MNIST and polarized Yelp reviews) and built unsupervised models that compress each example into a low-dimensional context vector. Finally, we trained the contextual bandit model on the generated contexts and compared its regret against the baselines.
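Regret at each round is the gap between the best achievable reward and the reward the agent actually received; comparisons between methods typically plot its running total. A minimal sketch, assuming per-round optimal and obtained rewards have been logged (`cumulative_regret` is a hypothetical helper):

```python
import numpy as np


def cumulative_regret(optimal_rewards, obtained_rewards):
    """Running sum over rounds of (best achievable reward - reward actually obtained)."""
    gaps = np.asarray(optimal_rewards, dtype=float) - np.asarray(obtained_rewards, dtype=float)
    return np.cumsum(gaps)
```

A method that learns useful contexts should show a cumulative-regret curve that flattens over time, while a uniformly random baseline grows roughly linearly.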

Challenges we ran into

Accomplishments that we're proud of

What we learned

What's next for Unsupervised Representation Learning for Contextual Bandits

Built With
