Using Graph Convolutional Networks for Single-Cell Classification

Madeline Hughes (mhughe10), Nitya Thakkar (nthakka3)

Introduction

The development of single-cell RNA sequencing (scRNA-seq) has allowed us to study heterogeneity across cells and characterize them based on transcriptomic markers. This type of data captures gene expression levels across the cells in a sample. Because each cell type carries a distinct transcriptional signature, scRNA-seq data can be used to identify unique cell types, which provides significant insight into their roles in biological processes and informs studies of gene regulatory and developmental mechanisms. Before computational classification methods, researchers relied on manual, qualitative approaches like microscopy and histology, which used visible cell surface markers to identify cell types. Transcriptomic markers in scRNA-seq data allow us to identify cell types quantitatively, a significantly more rigorous and precise approach that reduces human error. Previous supervised and unsupervised computational approaches to single-cell classification cluster cell samples by the similarity of key transcriptional signals and assign labels based on average expression values. This is insufficient, as it assumes that all cells in the same cluster are of the same type. Other deep learning models that classify individual cells use only gene expression data, failing to extract global, dynamic features from the data. We aimed to implement a model that classifies cells independently and incorporates information from gene interaction networks for more substantiated results; graphs are the most natural way to represent dynamic interaction data like this. Our model applies a graph convolutional network to the gene adjacency and gene expression data, in parallel with a feed-forward neural network applied only to the gene expression data. The two outputs are concatenated and fed into a fully connected layer that yields the most probable cell type for each input cell. Predictions are validated and tested against confirmed cell labels.

Methodology

(See method overview figure in link to paper)

The single-cell RNA sequencing data and the corresponding cell annotations were obtained from the NCBI Gene Expression Omnibus (GEO). The gene expression data is a matrix of shape [number of cells, number of genes] that gives the quantitative expression value of each gene in each cell, measured from single-cell RNA sequencing. We preprocess the gene expression matrix by removing genes with an expression value of 0 across all cells, applying a log transform and min-max normalization to each value, and isolating the top 1000 most variable genes (in terms of expression across cells). The data also comes with a corresponding set of cell type annotations of shape [number of cells], with 5 unique cell types; we convert the labels array to one-hot vectors.

Next, we obtain a gene interaction network from STRINGdb, which reports the interconnectivity between every pair of genes in the dataset as a confidence score. We use these scores to construct the adjacency matrix, of shape [number of genes, number of genes]. We set the diagonal to 0 to reflect that the nodes (genes) do not interact with themselves (i.e., no self-loops). Genes that are not included in the gene adjacency matrix are dropped from the gene expression matrix, as they are unlikely to be significant cell type signals.

We use the preprocessed gene expression matrix and gene adjacency matrix to construct the graph: the nodes represent genes embedded with expression values, and the edges represent connections between genes, obtained from the gene adjacency matrix. The graph is fed into an encoder composed of a graph convolution, max-pool, flatten, and dense layer, with an output size of 32. In the graph convolution layer, the adjacency matrix is converted into the normalized Laplacian matrix, from which we obtain the eigenvalues.
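As a rough sketch, the expression-matrix preprocessing described above might look like this in NumPy (the toy matrix here is made up for illustration, not our actual dataset):

```python
import numpy as np

# Toy expression matrix: rows = cells, columns = genes (hypothetical counts).
X = np.array([
    [0.0, 3.0, 0.0, 10.0],
    [0.0, 1.0, 5.0,  2.0],
    [0.0, 4.0, 9.0,  0.0],
])

# 1. Drop genes with zero expression across all cells.
keep = X.sum(axis=0) > 0
X = X[:, keep]

# 2. Log-scale, then min-max normalize each gene to [0, 1].
X = np.log1p(X)
mins, maxs = X.min(axis=0), X.max(axis=0)
X = (X - mins) / np.where(maxs - mins == 0, 1, maxs - mins)

# 3. Keep the top-k most variable genes (k = 2 here; 1000 in our pipeline).
k = 2
top = np.argsort(X.var(axis=0))[::-1][:k]
X = X[:, top]

print(X.shape)  # (3, 2): 3 cells, top 2 variable genes
```

In the real pipeline this runs on the full [number of cells, number of genes] matrix before the top-1000 selection.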
Then, we perform a Chebyshev polynomial expansion on the Laplacian matrix; the trainable variables in this layer are the Chebyshev coefficients. In parallel, the gene expression matrix is fed into a feed-forward neural network. The encoder-decoder loss is calculated using mean squared error. The output of the feed-forward network is concatenated with the output of the encoder and fed into a final dense layer, whose output is of shape [number of cells, number of cell type classes], where the number of classes is 5 because there are 5 cell types in this dataset. For each cell, the value in each of the 5 columns is the probability that the cell belongs to that class (i.e., that it is that cell type). We obtain the predicted label for each cell by taking the column index with the highest probability. Loss here is calculated using categorical cross-entropy, comparing the predicted labels to the actual labels, and accuracy is the number of correctly predicted labels divided by the number of cells. We train the model for 20 epochs using gene expression and interaction data from 3,804 cells and 838 genes. The preprocessed gene expression dataset was split so that 80% was used as training data, 10% as validation data, and 10% as test data. We used a learning rate of 0.001 with an Adam optimizer, a batch size of 256, and a max pool size of 8.
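The core of the graph convolution, the normalized Laplacian followed by a Chebyshev expansion, can be sketched in NumPy on a hypothetical 4-gene network (all values here are made up for illustration; in the model the coefficients theta are the trainable variables):

```python
import numpy as np

# Hypothetical 4-gene adjacency: symmetric confidence scores, zero diagonal.
A = np.array([
    [0.0, 0.8, 0.0, 0.2],
    [0.8, 0.0, 0.5, 0.0],
    [0.0, 0.5, 0.0, 0.9],
    [0.2, 0.0, 0.9, 0.0],
])

# Symmetrically normalized Laplacian: L = I - D^{-1/2} A D^{-1/2}.
d = A.sum(axis=1)
D_inv_sqrt = np.diag(1.0 / np.sqrt(d))
L = np.eye(len(A)) - D_inv_sqrt @ A @ D_inv_sqrt

# Rescale so eigenvalues lie in [-1, 1], as Chebyshev polynomials require.
lam_max = np.linalg.eigvalsh(L).max()
L_tilde = 2.0 * L / lam_max - np.eye(len(A))

def cheb_conv(L_tilde, x, theta):
    """Chebyshev filter: T_0(L)x = x, T_1(L)x = Lx,
    T_k(L)x = 2L T_{k-1}(L)x - T_{k-2}(L)x; output = sum_k theta_k T_k(L)x."""
    Tx = [x, L_tilde @ x]
    for _ in range(2, len(theta)):
        Tx.append(2.0 * L_tilde @ Tx[-1] - Tx[-2])
    return sum(t * tk for t, tk in zip(theta, Tx))

x = np.array([1.0, 0.5, 0.0, 2.0])  # expression values on the 4 gene nodes
theta = [0.5, 0.3, 0.2]             # trainable Chebyshev coefficients
out = cheb_conv(L_tilde, x, theta)
print(out.shape)  # (4,): one filtered value per gene node
```

This is the recurrence we implemented manually inside the custom Keras layer, applied per feature channel.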

Results

(See results figures in link to paper)

The first figure, on the left, visualizes the loss of our model during training. As shown in the graph, the average loss decreases as we train for more epochs, indicating that our model is indeed learning the trainable parameters. To quantify the accuracy of our model, we first computed the proportion of correct predictions out of total predictions. These results are highlighted in the figure on the right: we obtained a test accuracy of 0.71. This was with minimal hyperparameter tuning of our own, as we used the hyperparameters the paper reported, which were specific to their model; we may be able to boost our performance in the future by tuning the parameters further. The other accuracy metrics we used, as described in the paper, were precision, recall, and F1 score. Precision is the ratio of true positives to the total number of true and false positives. Recall is the ratio of true positives to the total number of true positives and false negatives. The F1 score is the harmonic mean of the precision and recall scores. Given these moderately high metrics, we conclude that our model is able to predict cell types from gene expression inputs, and with further tuning we hope to boost its performance further.
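These metrics follow directly from the predicted and true labels. A small NumPy sketch with hypothetical labels (we macro-average over classes here, which is one reasonable choice for multi-class data):

```python
import numpy as np

# Hypothetical predicted vs. true labels for 8 cells over 3 classes.
y_true = np.array([0, 0, 1, 1, 1, 2, 2, 2])
y_pred = np.array([0, 1, 1, 1, 2, 2, 2, 0])

# Accuracy: fraction of correctly predicted labels.
accuracy = (y_true == y_pred).mean()

# Per-class precision and recall, averaged with equal class weight (macro).
prec, rec = [], []
for c in np.unique(y_true):
    tp = np.sum((y_pred == c) & (y_true == c))
    fp = np.sum((y_pred == c) & (y_true != c))
    fn = np.sum((y_pred != c) & (y_true == c))
    prec.append(tp / (tp + fp) if tp + fp else 0.0)
    rec.append(tp / (tp + fn) if tp + fn else 0.0)

precision, recall = np.mean(prec), np.mean(rec)
f1 = 2 * precision * recall / (precision + recall)  # harmonic mean

print(round(accuracy, 3))  # 0.625
```

In practice one would use a library implementation (e.g. scikit-learn's metrics) rather than hand-rolling these.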

Discussion, Challenges, Future Work

The greatest challenge was constructing the graph convolution (GCN) layer. Keras has no built-in GCN layer, so we had to construct it manually. This involved rigorous math that was difficult to understand, and this part of the project took us the longest to debug. However, the process demystified the theory behind graph convolution, helping us understand the underlying mechanisms by which features are extracted from the data. We also improved our research skills by closely analyzing the inputs and outputs of each component of the layer to find solutions to our bugs, and by using external resources to identify errors. It was interesting and rewarding to explore an implementation of a model that we had not studied closely in class. However, our model only reaches about 71% accuracy, so in its current state it should not be used in other research studies.

We hope to further tune our model to match the paper's performance, keeping in mind that we are using both a different dataset and a different framework (TensorFlow rather than PyTorch). We hypothesize that the paper finely tuned their model to their own datasets, which may explain why the same hyperparameters perform worse on ours. We also hope to compare this model against a version with no encoder/decoder layers, i.e., just a simple feed-forward network; we hypothesize that it will not perform as well without the encoder output, but we want to test this to validate the effectiveness of the graph-based approach.
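A minimal sketch of such a feed-forward baseline in Keras, assuming the 838-gene input and 5 classes from our dataset (the hidden layer sizes here are hypothetical placeholders, not tuned values):

```python
import tensorflow as tf

# Hypothetical baseline: a plain feed-forward classifier on the expression
# matrix alone (838 genes in, 5 cell-type classes out), with no graph encoder.
baseline = tf.keras.Sequential([
    tf.keras.Input(shape=(838,)),
    tf.keras.layers.Dense(128, activation="relu"),
    tf.keras.layers.Dense(32, activation="relu"),   # mirrors the encoder's 32-dim output
    tf.keras.layers.Dense(5, activation="softmax"), # per-class probabilities
])

# Same training setup as the full model: Adam, lr 0.001, categorical cross-entropy.
baseline.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
    loss="categorical_crossentropy",
    metrics=["accuracy"],
)
```

Training this on the same 80/10/10 split would give a direct ablation of the graph encoder's contribution.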

An interesting addition to this model would be if it could incorporate information about cells’ spatial locations in a tissue. This could allow for interesting developments downstream with understanding intercellular interactions.

Reflection

Our project ultimately turned out much better than expected, given that we chose a complicated model to implement. We were hoping to get our model to run with an accuracy above 50%, and we accomplished that, yielding an accuracy of 71% on our final run. At first we tried to follow other graph convolution tutorials due to the complexity of the computations described in the paper, but we ran into many errors and the model wouldn't function properly, so we ultimately followed the paper's mathematical method. If we did the project over again, we would have outlined our understanding of the architecture before starting; we began coding directly from the paper's architecture diagram and likely lost a lot of time to bugs that stemmed from misconceptions. Our accuracy could definitely be improved, potentially by adding more layers and tuning hyperparameters like the learning rate, batch size, feature space, and number of epochs. The biggest takeaway from this project is how it demystified the process of computational research. We gained more insight into how researchers find projects in the literature that they want to work on and where they begin. We were always intimidated by the gap between academic coding projects and professional research endeavors, but completing this project gave us more confidence in our abilities and fueled our interest in a research career.

Link to final write-up

Link to reflection

Link to initial outline


Updates


Reflection 11/30 - Nitya Thakkar (nthakka3) and Madeline Hughes (mhughe10)

Introduction: This can be copied from the proposal. We are implementing an existing paper, and the goal is to identify cell types in single cell data using gene expression data. This is a classification problem, and they accomplish it using a Graph Convolutional Network (GCN). We chose this paper because we are both interested in computational biology and thought GCNs were really interesting (and not something we talked about in class).

Challenges: What has been the hardest part of the project you’ve encountered so far?

Preprocessing the data has been really difficult for us. It was hard to find a dataset that worked (that was different from the one used in the paper). After finding it, there were many steps we had to take:

- remove unlabeled cells and cells labelled as debris and doublets
- remove genes with zero expression values across all cells
- transform gene expression values into log scale and normalize each dataset by min–max scaling
- after calculating variances of the genes across all the cells, sort the variances in descending order and choose the top 1000 genes as the input of the classifiers
- construct the gene adjacency network from the selected genes

We have not fully finished this yet, and after we complete it we next need to construct the gene adjacency matrix as follows:

- choose the top N genes with the highest variances in expression values for training
- the matrix size is N x N (N = number of genes)
- elements in the matrix represent the confidence score between pairs of genes, extracted from the gene–gene interaction database
- normalize weights by row sums
- use this to build a weighted graph where nodes are genes, edges represent the connections between genes, and the normalized confidence scores are the edge weights

We anticipate this will be the most challenging part of our project.

Insights: Are there any concrete results you can show at this point? How is your model performing compared with expectations?
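The row-sum normalization step from the adjacency construction above can be sketched in NumPy (the confidence scores here are made up for illustration):

```python
import numpy as np

# Hypothetical STRING-style confidence scores for 3 genes (zero diagonal).
A = np.array([
    [0.0, 0.9, 0.4],
    [0.9, 0.0, 0.7],
    [0.4, 0.7, 0.0],
])

# Normalize each row by its sum so the edge weights leaving each gene sum to 1;
# isolated genes (zero row sum) are left as all-zero rows.
row_sums = A.sum(axis=1, keepdims=True)
A_norm = A / np.where(row_sums == 0, 1, row_sums)

print(A_norm.sum(axis=1))  # each row sums to 1
```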

We unfortunately haven’t gotten to this point yet since we’ve been stuck on data pre-processing, but hope to have results soon (our goal is by the end of this week).

Plan: Are you on track with your project? What do you need to dedicate more time to? What are you thinking of changing, if anything?

We are running a bit behind because data preprocessing is taking us so long. We are hopeful that once this is done, the rest of the modeling will go faster. Our goal is to finish data preprocessing this week so we can also create the model this week and have results by this weekend. We may have to change how we approach the model; we were hoping to modify it a bit if possible, but we may not have time to experiment with that.
