Cerebro

A note

This README looks better on GitHub: https://github.com/stanleyjzheng/cerebro

Our project is live and deployed at https://sjz.ca (please try it out!)

Team

  • Vivi - UCalgary 2nd year CS
  • Stanley - UBC Sciences gap
  • Matt - UCalgary 2nd year CS

Inspiration

In the world of wildlife and climate change education, finding the correct clips and images is critical to one's storytelling. I like X-Men and have always wanted mutant powers... what would be cooler than channeling Charles Xavier's ability to look through the eyes of others and see everything instantly with his Cerebro?

How it works

TL;DR: We trained our own CLIP-style ViT model from scratch, first on a general dataset of 3M image-text pairs and then on 153k specialized wildlife photos (see "Training Scheme" below). To answer a query, we embed the user's text and return the wildlife photos whose embeddings are closest to it.

An embedding is a high-dimensional vector representation of something (a piece of text or an image); the closer two embeddings are, the more semantically similar the things they represent. Embeddings can even capture relationships: in classic word embeddings, "king" minus "man" plus "woman" lands close to "queen". This means we can use the distance between two embeddings as a measure of how similar they are.

CLIP has two neural networks: one for images (in this case, a vision transformer) and one for text (a transformer, like GPT).

Given an image, the image encoder produces an embedding representing its visual features. Given a text string, the text encoder generates a corresponding embedding that captures its semantic meaning.
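The README doesn't spell out the training objective. CLIP-style models are usually trained with a symmetric contrastive loss that pulls each image toward its own caption and away from every other caption in the batch (and vice versa); that is what puts both encoders' outputs into one shared space. Below is a minimal pure-Python sketch of that standard objective, assuming a fixed temperature of 0.07 (CLIP actually learns this value; 0.07 is its usual starting point). `clip_loss` and its helpers are illustrative names, not code from this project.

```python
import math
from typing import List

Vector = List[float]


def normalize(v: Vector) -> Vector:
    """Scale a vector to unit L2 length."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def dot(a: Vector, b: Vector) -> float:
    return sum(x * y for x, y in zip(a, b))


def cross_entropy(logits: List[float], target: int) -> float:
    """-log softmax(logits)[target], computed with log-sum-exp for stability."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(l - m) for l in logits))
    return log_z - logits[target]


def clip_loss(image_embs: List[Vector], text_embs: List[Vector], temperature: float = 0.07) -> float:
    """Symmetric contrastive loss; pair i is (image_embs[i], text_embs[i])."""
    imgs = [normalize(v) for v in image_embs]
    txts = [normalize(v) for v in text_embs]
    n = len(imgs)
    # logits[i][j]: cosine similarity of image i and text j, divided by the temperature
    logits = [[dot(imgs[i], txts[j]) / temperature for j in range(n)] for i in range(n)]
    # Image -> text: row i should "classify" its own caption (column i)
    loss_i2t = sum(cross_entropy(logits[i], i) for i in range(n)) / n
    # Text -> image: column j should "classify" its own image (row j)
    loss_t2i = sum(cross_entropy([logits[i][j] for i in range(n)], j) for j in range(n)) / n
    return (loss_i2t + loss_t2i) / 2
```

With correctly matched pairs the loss is close to zero; shuffling the captions so they no longer match their images makes it much larger.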

We scrape wildlife images, and for each image we precompute its embedding and store it (see "Similarity Search Scheme" below). When a user enters a text query, it is converted into an embedding by the text encoder. Then we calculate the distance between the text embedding and each image embedding.

The image with the smallest distance (that is, the highest similarity) is selected as the best match. Because both encoders map into the same shared embedding space, CLIP generalizes well to unseen image-text pairs, which makes it highly effective for zero-shot learning and retrieval tasks like ours.
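The query path described above can be sketched in a few lines. This is a brute-force illustration using plain Python lists, with `embed_image` and `embed_text` as hypothetical stand-ins for the trained encoders (not this project's actual API). Image embeddings are computed once up front; each query is embedded and matched to the image with the smallest Euclidean distance.

```python
import math
from typing import Callable, Dict, List

Vector = List[float]


def l2(a: Vector, b: Vector) -> float:
    """Euclidean distance between two equal-length vectors."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


def precompute(images: Dict[str, object], embed_image: Callable[[object], Vector]) -> Dict[str, Vector]:
    # Done once, offline: one embedding per image, keyed by image id.
    return {image_id: embed_image(img) for image_id, img in images.items()}


def search(query: str, image_embs: Dict[str, Vector], embed_text: Callable[[str], Vector]) -> str:
    # Done per query: embed the text, then return the id of the closest image.
    q = embed_text(query)
    return min(image_embs, key=lambda image_id: l2(q, image_embs[image_id]))
```

Keeping the expensive image encoding out of the query path is what makes per-keystroke search feasible: only the short text query has to go through a network at request time.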

Similarity Search Scheme

What is similarity search?

Given a set of vectors $x_i$ in dimension $d$, Faiss builds a data structure in RAM from them. Once the structure is built, given a new vector $x$ in dimension $d$, it efficiently performs the operation:

$$j = \operatorname{argmin}_i \lVert x - x_i \rVert$$

where $\lVert \cdot \rVert$ is the Euclidean ($L^2$) norm, so $\lVert x - x_i \rVert$ is the Euclidean distance between $x$ and $x_i$.
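Faiss's flat L2 index measures Euclidean distance, whereas CLIP scores image-text pairs by cosine similarity. Assuming the query and stored embeddings are L2-normalized (common practice with CLIP embeddings; the README doesn't say whether this project does it), the two produce the same ranking, because for unit vectors $\lVert x \rVert = \lVert x_i \rVert = 1$:

$$\lVert x - x_i \rVert^2 = \lVert x \rVert^2 + \lVert x_i \rVert^2 - 2\,x^\top x_i = 2 - 2\cos(x, x_i)$$

so minimizing the distance is the same as maximizing the cosine similarity: $\operatorname{argmin}_i \lVert x - x_i \rVert = \operatorname{argmax}_i \cos(x, x_i)$. Faiss also provides an inner-product index, `IndexFlatIP`, that can be used directly on normalized vectors.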

In Faiss terms, the data structure is an index, an object that has an add method to add $x_i$ vectors. Note that the $x_i$'s are assumed to be fixed.

Computing the argmin is the search operation on the index.

That is all Faiss is about. We use it to return the k nearest neighbours of a query, in real time, from a dataset of 152,982 images plus tens of thousands of video stills.
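In Faiss itself, the exact brute-force index is `faiss.IndexFlatL2(d)`: you call `index.add(xb)` with the $x_i$, then `D, I = index.search(xq, k)` returns, for each query, the squared L2 distances `D` and the ids `I` of its `k` nearest neighbours. The class below is a dependency-free toy that mimics that shape for a single query so the semantics are easy to see. `ToyFlatL2` is a hypothetical name; real Faiss works on batches of float32 NumPy arrays and is far faster.

```python
import heapq
from typing import List, Tuple

Vector = List[float]


class ToyFlatL2:
    """Exact k-nearest-neighbour index mimicking the add/search shape of faiss.IndexFlatL2."""

    def __init__(self, d: int) -> None:
        self.d = d
        self.vectors: List[Vector] = []  # stored x_i; each id is its insertion position

    @property
    def ntotal(self) -> int:
        return len(self.vectors)

    def add(self, xs: List[Vector]) -> None:
        for x in xs:
            if len(x) != self.d:
                raise ValueError(f"expected dimension {self.d}, got {len(x)}")
            self.vectors.append(list(x))

    def search(self, x: Vector, k: int) -> Tuple[List[float], List[int]]:
        # Squared L2 distance to every stored vector (brute force, like a flat index).
        dists = (
            (sum((a - b) ** 2 for a, b in zip(x, xi)), i)
            for i, xi in enumerate(self.vectors)
        )
        # Keep the k smallest; ties are broken by the lower id.
        nearest = heapq.nsmallest(k, dists)
        return [d for d, _ in nearest], [i for _, i in nearest]
```

Because the stored vectors are fixed once added, the same index can serve every query; Faiss's approximate indexes trade a little accuracy for much faster search on larger collections.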

Training Scheme

Checkpoints are in training/.

We trained a ViT-B-32 model (~151M parameters) on 4x NVIDIA RTX 4090 GPUs. The pipeline is carefully optimized for data parallelism, splitting each batch across the GPUs, and it can scale across multiple nodes for larger models (though those would take far longer than our 24-hour time limit to train).
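Data parallelism means every GPU holds a full copy of the model, processes its own slice of each batch, and the per-GPU gradients are averaged (an all-reduce) before the shared update. The dependency-free sketch below checks the core identity on a made-up one-parameter linear model with squared error: averaging the gradients of four equal-sized shards gives exactly the full-batch gradient. It is illustrative only (`grad_mse`, `shard`, and the data are hypothetical); a real setup would use something like PyTorch's `DistributedDataParallel`. One caveat specific to CLIP: its contrastive loss compares every image with every caption in the batch, so CLIP training code typically all-gathers embeddings across GPUs so that each shard still sees negatives from the whole global batch.

```python
from typing import List, Tuple

Pair = Tuple[float, float]  # (input x, target y)


def grad_mse(w: float, batch: List[Pair]) -> float:
    """Gradient d/dw of mean((w * x - y) ** 2) over the batch."""
    return sum(2 * (w * x - y) * x for x, y in batch) / len(batch)


def shard(batch: List[Pair], n_gpus: int) -> List[List[Pair]]:
    # Round-robin split, one shard per (simulated) GPU.
    return [batch[i::n_gpus] for i in range(n_gpus)]


def data_parallel_grad(w: float, batch: List[Pair], n_gpus: int) -> float:
    # Each "GPU" computes a local gradient on its own shard...
    local = [grad_mse(w, s) for s in shard(batch, n_gpus)]
    # ...then an all-reduce averages them so every replica applies the same update.
    # This equals the full-batch gradient only when all shards are the same size.
    return sum(local) / n_gpus
```

With a global batch of 8 examples split over 4 simulated GPUs, `data_parallel_grad` and `grad_mse` on the whole batch agree up to floating-point rounding.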

Our training scheme is as follows:

```mermaid
graph TD
    A["Initialize Model <br> (Random Weights, No Pretraining)"] --> B["Train on CC3M Dataset <br> (generalized dataset, 3M Image-Text Pairs)"]
    B --> C["Infer Pseudo Labels <br> on Wildlife Images (from 8 datasets)"]
    C --> D["Fine-Tune Model on <br> 153k Wildlife Images with Pseudo Labels"]
    D -.->|Iterate 3x: Generate New Pseudo Labels & Fine-Tune Again| C
```
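The README doesn't describe how the pseudo labels are produced. One common approach for a CLIP-style model, and purely an assumption here, is zero-shot labelling: score each unlabelled image against a set of candidate captions (for example, species names), keep the best caption only when the model is confident, and fine-tune on the resulting image-caption pairs. The sketch below shows that selection step plus the outer loop from the diagram. `pseudo_label`, `run_rounds`, the 0.9 confidence threshold, the logit scale of 100, and the callables passed in are all hypothetical.

```python
import math
from typing import Callable, List, Sequence, Tuple

Vector = List[float]


def dot(a: Sequence[float], b: Sequence[float]) -> float:
    return sum(x * y for x, y in zip(a, b))


def softmax(scores: Sequence[float]) -> List[float]:
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def pseudo_label(image_embs: List[Vector], caption_embs: List[Vector], captions: List[str],
                 threshold: float = 0.9, scale: float = 100.0) -> List[Tuple[int, str]]:
    """Return (image index, caption) pairs for images the model labels confidently.

    Assumes embeddings are already L2-normalized, so dot products are cosine similarities.
    `scale` plays the role of CLIP's logit scale (assumed value).
    """
    labelled = []
    for idx, img in enumerate(image_embs):
        probs = softmax([scale * dot(img, c) for c in caption_embs])
        best = max(range(len(captions)), key=lambda j: probs[j])
        if probs[best] >= threshold:  # drop low-confidence guesses
            labelled.append((idx, captions[best]))
    return labelled


def run_rounds(model, images: List[str], captions: List[str],
               embed_images: Callable, embed_captions: Callable, fine_tune: Callable,
               rounds: int = 3):
    """Outer loop of the diagram: relabel with the current model, then fine-tune on the new labels."""
    for _ in range(rounds):
        pairs = pseudo_label(embed_images(model, images), embed_captions(model, captions), captions)
        model = fine_tune(model, [(images[i], caption) for i, caption in pairs])
    return model
```

Re-embedding with the updated model each round is what lets later rounds label images the earlier model was unsure about.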

For more details on pseudo-labelling, see the talk I gave on it four years ago; it is fairly authoritative and covers far more than I can explain in a README.

CC3M was downloaded with img2dataset, which took ~1 hour. The initial CC3M training took ~6 hours on the 4x RTX 4090 GPUs. After that, each pseudo-labelling round took ~45 minutes on the 152,982 wildlife images.
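Adding up those figures, and assuming the diagram's "Iterate 3x" means three pseudo-labelling rounds in total, the end-to-end budget comes to roughly:

$$1\ \text{h (download)} + 6\ \text{h (CC3M)} + 3 \times 0.75\ \text{h (rounds)} = 9.25\ \text{h}$$

well inside the 24-hour limit. Each round works out to about $152982 / 2700 \approx 57$ images per second, counting both pseudo-labelling and fine-tuning.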

We used no externally pretrained models: the model was randomly initialized, trained from scratch on CC3M, and then fine-tuned on wildlife images.

How to run the frontend

It's deployed at sjz.ca but you can also run it locally on your own data.

  1. Install the requirements with `uv pip install -r pyproject.toml`.
  2. Point `populate_embeddings.py` at a directory of images (any format), e.g. `uv run populate_embeddings.py ../dataset`, to generate embeddings and insert them into the Faiss index.
  3. Start the API with `uv run uvicorn app:app --reload`.
  4. Install the interface's dependencies with `cd interface && npm i`.
  5. In a separate terminal, from the `interface` directory, start the interface with `npm run dev`.
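For readers curious what step 2 involves, here is a minimal sketch of a populate step, assuming it walks the directory, embeds each file, and adds the vectors to an index in order. `populate`, `embed_file`, and `add_vectors` are hypothetical stand-ins, not the actual contents of `populate_embeddings.py`. The design point worth copying is the side table from vector id to file path: basic Faiss indexes store only vectors and return integer ids, so the app needs its own mapping to turn search results back into images.

```python
import os
from typing import Callable, Dict, List

Vector = List[float]


def populate(root: str,
             embed_file: Callable[[str], Vector],
             add_vectors: Callable[[List[Vector]], None]) -> Dict[int, str]:
    """Embed every decodable file under `root`, add the vectors, and return an id -> path map."""
    paths: List[str] = []
    vectors: List[Vector] = []
    for dirpath, _, filenames in os.walk(root):
        for name in sorted(filenames):
            path = os.path.join(dirpath, name)
            try:
                vec = embed_file(path)
            except (OSError, ValueError):
                continue  # skip files the (hypothetical) encoder can't decode
            paths.append(path)
            vectors.append(vec)
    # A flat index assigns ids 0..n-1 in insertion order, matching enumerate() below.
    add_vectors(vectors)
    return dict(enumerate(paths))
```

In a real script the returned mapping would be saved next to the index so the API can load both at startup.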

How to run training

It's a really long process. Reach out to me!

Biggest challenge

We had to write the interface twice because Streamlit doesn't support streaming text changes, and I wanted inference to run instantly while the user was still typing. It was a pain.
