Learning from Temporal Graph Networks
This repo is what I worked on for Stanford's 2021 TreeHacks virtual hackathon. My goals were to revamp my graph machine learning skills and to practice implementing an interesting, timely AI/ML publication.
Interested in how Twitter uses graph machine learning methods to scale content moderation and other trust and safety applications on its platform, I dove into a paper written by ML engineers at Twitter about how they model temporal graphs. This blog post explains the general idea of the paper: use a "raw message store" to accumulate past information, which is then used to transform node messages into node states and eventually node embeddings for task-oriented prediction.
This repo contains my implementation of Twitter's TemporalGraphNetwork (TGN).
Running this code
To run this code, first install the necessary dependencies: pip install -r requirements.txt. The code has been tested for link prediction on the College Messages dataset available here, but feel free to adapt the code to use another temporal graph dataset as input or to predict on a different task. If you would like to train a self-supervised TGN on the College Messages dataset, download it to data/CollegeMessage.txt and run python3 main.py --input-graph data/CollegeMessage.txt --output result.pkl. This will run 4 different experiments performing a simple hyperparameter search over 2 batch sizes and 2 node state sizes. (You may add a --parallel flag to train the different experiment setups in separate processes.)
Experiment Setup
In training the TGN, I use an Adam optimizer with exponential learning-rate decay and binary cross-entropy loss, both to compute gradients and to evaluate model performance. The model is trained in a self-supervised way: it predicts each batch's interactions using information accumulated from all previous batches.
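As a minimal sketch of this setup (the helper names below are illustrative, not the actual functions in this repo): batches are formed in temporal order, and the loss and learning-rate schedule follow the standard formulas.

```python
import math

def time_ordered_batches(events, batch_size):
    # events: (src, dst, timestamp) interactions; batches must respect
    # time order so each batch is predicted from strictly earlier ones
    events = sorted(events, key=lambda e: e[2])
    for i in range(0, len(events), batch_size):
        yield events[i:i + batch_size]

def bce_loss(p, y):
    # binary cross-entropy for one predicted link probability p, label y
    eps = 1e-12
    return -(y * math.log(p + eps) + (1 - y) * math.log(1 - p + eps))

def decayed_lr(lr0, gamma, step):
    # exponential learning-rate decay: lr_t = lr0 * gamma**t
    return lr0 * gamma ** step
```

In the actual code the optimizer and scheduler are handled by the deep learning framework; the point here is only the temporal batching and the quantities being optimized.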
I ran a hyperparameter search over batch size (the number of interactions included in each forward pass of the TGN during training), node memory/state size, and learning rate. I found that smaller batch sizes generally achieve better performance, and that learning rates of 5e-2 or 1e-2 work well, but I did not notice performance swings corresponding to node memory/state size. While a model with well-chosen hyperparameters does indeed learn, its predictions are still quite noisy. The best training loss achieved is around 0.28, with an average final training loss (the loss on the last batch) in the 0.5-0.6 range; for comparison, the expected BCE loss of a uniform random predictor is around 0.7.
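The random-predictor baseline is just the cross-entropy of a predictor that always outputs probability 0.5, regardless of the true label:

```python
import math

# a predictor that always outputs 0.5 incurs BCE of -ln(0.5) on every
# example, whatever the label, giving an expected loss of ln 2
baseline = -math.log(0.5)
print(round(baseline, 3))  # 0.693
```

So the 0.5-0.6 final losses are a real, if modest, improvement over chance.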
Extending this code
The TGN framework aims to be modular so that it is easy to extend for different tasks. Dive into modules.py to find the simple implementations and learners I used for message aggregation, memory updates, and so forth. To architect your own TGN from your own modules, either write them in modules.py or include them in experiments.py, then wire them together as the layers argument to the TemporalGraphNetwork constructor (this is done in the BasicTGNExperiment constructor).
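For illustration, a custom set of layers might look like the following toy sketch. The class names and the dict-of-layers wiring here are hypothetical; the real constructor signature and wiring live in modules.py and the BasicTGNExperiment constructor.

```python
class MeanMessageAggregator:
    # toy aggregator: average all messages destined for a node
    def __call__(self, messages):
        return sum(messages) / len(messages)

class LatestMemoryUpdater:
    # toy updater: overwrite the old node state with the new message
    def __call__(self, state, message):
        return message

# wire the custom modules together as the layers of a TGN; in this repo
# the equivalent wiring happens in the BasicTGNExperiment constructor
layers = {
    "message_aggregator": MeanMessageAggregator(),
    "memory_updater": LatestMemoryUpdater(),
}
```

The value of this design is that any layer can be swapped for a learned module with the same call signature, without touching the rest of the network.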
Next steps
In the future, this TGN could be improved by incorporating better-designed modules. The modules I used in this project were essentially all either MLPs or identity mappings, but the TGN's performance might improve drastically if more complex modules were implemented and used for the various layers. For example, the Twitter paper had great success using attention for the graph embedding module. Alternatively, a dynamic embedding trajectory could be learned, as described in the Stanford SNAP group's JODIE paper.
Another point of interest is seeing how well the method performs with more training data. The College Messages dataset is relatively tiny, so it would be interesting to test this TGN learner on the much larger Wikipedia or Reddit datasets available here. These and other popular temporal interaction datasets have node and/or edge features, which would need to be incorporated into the TGN code for the best chance of success. The TGN framework proposed in Twitter's paper is designed to be versatile with respect to graph prediction tasks, node/edge feature availability, and layer learners, so adapting the TGN programmed in this repo for the Reddit and Wikipedia datasets should be an accessible and interesting next step.
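In the TGN paper, a raw message for an interaction concatenates the states of the two endpoints, the edge features, and a time encoding. A minimal sketch of what incorporating edge features into message construction could look like (vectors represented as plain Python lists for simplicity; the function name is hypothetical):

```python
def raw_message(src_state, dst_state, edge_feat, time_enc):
    # concatenate source state, destination state, edge features, and a
    # time encoding into one raw message vector for the message store
    return src_state + dst_state + edge_feat + time_enc
```

Datasets without edge features can pass an empty feature vector, which keeps the same code path working for both cases.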
Finally, while this TGN is trained using self-supervision and evaluated on the same set of (all) nodes of the College Messages graph, in practice it would be much better to withhold a validation set of nodes and evaluate the learned model on those nodes after training the TGN on the remaining ones. For the sake of learning, I prioritized implementing the details of the TGN model over architecting a more mature training loop. To improve this codebase, I could implement a few options for training the TGN, including self-supervised training with train/val/test splits and semi-supervised training with node labels, in order to evaluate the generalizability of the learned model as well as to make the training loop itself more adaptable to different prediction tasks.
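One simple way to hold out validation nodes (a sketch, not code from this repo: the function name and the rule that any interaction touching a held-out node goes to validation are my own choices here):

```python
import random

def node_split(events, val_frac=0.15, seed=0):
    # events: (src, dst, timestamp) tuples; hold out a fraction of nodes,
    # sending any interaction that touches a held-out node to validation
    nodes = sorted({n for s, d, _ in events for n in (s, d)})
    rng = random.Random(seed)
    held_out = set(rng.sample(nodes, max(1, int(len(nodes) * val_frac))))
    train = [e for e in events if e[0] not in held_out and e[1] not in held_out]
    val = [e for e in events if e[0] in held_out or e[1] in held_out]
    return train, val, held_out
```

A chronological split of interactions (train on the earliest, validate on the most recent) would be another natural option for temporal graphs, since shuffling events would leak future information into training.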