Explainable AI for Alzheimer’s Disease MRI Classification

Alzheimer’s disease is a complex neurodegenerative disorder that remains difficult to diagnose accurately in its early stages. I became interested in how machine learning could be used to identify subtle structural patterns in brain MRI scans that are not easily observable through visual inspection alone. As I explored existing deep learning approaches, I found that although many models achieve high accuracy, they often lack interpretability, making their predictions difficult to trust in medical contexts. This motivated me to focus not only on model performance but also on explainability, applying techniques that reveal how the model arrives at its predictions.

I am relatively new to machine learning, and this project served as an opportunity to learn how deep learning can be applied to real-world healthcare problems. The Hack4Health starter project also provided inspiration and guidance in structuring my work.

I developed an explainable deep learning pipeline to classify Alzheimer’s disease stages from structural MRI scans. The system takes MRI images as input and outputs one of four disease stage predictions. To improve transparency, I integrated Grad-CAM, which highlights the regions of the brain that contribute most to each prediction, making it possible to inspect what the model bases its decisions on.
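For readers unfamiliar with the technique, the sketch below shows the core Grad-CAM mechanism on a ResNet18 in PyTorch: hooks capture the activations and gradients of the last convolutional block, and their gradient-weighted combination becomes the heatmap. The layer choice and variable names are my illustrative assumptions, not the project’s exact code.

```python
# Minimal Grad-CAM sketch (PyTorch). Hooks capture the last conv block's
# activations and gradients; their weighted, ReLU'd sum becomes the heatmap.
import torch
import torch.nn.functional as F
from torchvision import models

model = models.resnet18(weights=None)                # load trained weights in practice
model.fc = torch.nn.Linear(model.fc.in_features, 4)  # four disease stages
model.eval()

activations, gradients = {}, {}

def save_activation(module, inputs, output):
    activations["value"] = output.detach()

def save_gradient(module, grad_input, grad_output):
    gradients["value"] = grad_output[0].detach()

target_layer = model.layer4[-1]                      # last residual block
target_layer.register_forward_hook(save_activation)
target_layer.register_full_backward_hook(save_gradient)

def grad_cam(image):
    """image: (1, 3, H, W) tensor -> (H, W) heatmap in [0, 1] and predicted class."""
    logits = model(image)
    cls = logits.argmax(dim=1).item()
    model.zero_grad()
    logits[0, cls].backward()                        # gradients w.r.t. the chosen class
    weights = gradients["value"].mean(dim=(2, 3), keepdim=True)  # avg-pooled gradients
    cam = F.relu((weights * activations["value"]).sum(dim=1, keepdim=True))
    cam = F.interpolate(cam, size=image.shape[2:], mode="bilinear", align_corners=False)
    cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)    # normalize to [0, 1]
    return cam.squeeze(), cls
```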

I began by performing exploratory data analysis to understand the dataset structure and class distribution. After preprocessing the MRI scans by normalizing pixel values and adapting grayscale images to three-channel input, I implemented a baseline convolutional neural network to validate the data pipeline. To improve performance, I applied transfer learning using a pretrained ResNet18 model, replacing the final classification layer to match the four-class task. The model was trained using stratified data splitting to preserve class proportions across training and validation sets. Grad-CAM was then applied to the trained model to generate heatmaps for both correct and incorrect predictions.
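To make these steps concrete, here is a minimal sketch of the setup in PyTorch and scikit-learn. The image size and normalization statistics are the standard ImageNet values a pretrained ResNet18 expects, and the class counts are hypothetical placeholders, not the actual dataset figures.

```python
# Sketch of the training setup: grayscale-to-3-channel preprocessing, a
# stratified split, and a pretrained ResNet18 with its head replaced.
import torch.nn as nn
from torchvision import models, transforms
from sklearn.model_selection import train_test_split

preprocess = transforms.Compose([
    transforms.Grayscale(num_output_channels=3),       # replicate the MRI channel x3
    transforms.Resize((224, 224)),                     # input size ResNet18 expects
    transforms.ToTensor(),                             # scales pixel values to [0, 1]
    transforms.Normalize(mean=[0.485, 0.456, 0.406],   # ImageNet statistics, matching
                         std=[0.229, 0.224, 0.225]),   # the pretrained weights
])

labels = [0] * 3000 + [1] * 2200 + [2] * 900 + [3] * 70   # hypothetical class counts
train_idx, val_idx = train_test_split(
    list(range(len(labels))), test_size=0.2,
    stratify=labels,                                   # preserve stage proportions
    random_state=42)

model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
model.fc = nn.Linear(model.fc.in_features, 4)          # new head for the 4-class task
```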

One of the main challenges was severe class imbalance, with one disease stage having far fewer samples than the others. Additionally, differences between adjacent Alzheimer’s stages were often subtle, making misclassifications difficult to avoid. Interpreting Grad-CAM outputs also required careful analysis, as attention maps can be noisy for uncertain predictions.

Through this project, I gained hands-on experience building machine learning models for medical applications. I learned how to apply convolutional neural networks, transfer learning, and explainable AI techniques, as well as the importance of reproducibility, evaluation under class imbalance, and the limitations of relying solely on accuracy metrics in medical machine learning.
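A small, made-up example shows why accuracy alone can mislead here: on an imbalanced validation set, a degenerate model that always predicts the majority class still looks accurate, while per-class metrics expose the failure.

```python
# Why accuracy alone misleads on imbalanced data: a model that always predicts
# the majority class scores 80% accuracy yet never detects the rarer stages.
import numpy as np
from sklearn.metrics import balanced_accuracy_score, classification_report

y_true = np.array([0] * 80 + [1] * 10 + [2] * 5 + [3] * 5)  # hypothetical labels
y_pred = np.zeros_like(y_true)                              # always predicts stage 0

print("accuracy:         ", (y_true == y_pred).mean())                # 0.80
print("balanced accuracy:", balanced_accuracy_score(y_true, y_pred))  # 0.25
print(classification_report(y_true, y_pred, zero_division=0))         # per-class recall
```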

The final model achieved approximately 89% validation accuracy, outperforming the baseline CNN. Grad-CAM visualizations showed focused activation patterns for correct predictions and more dispersed attention for misclassified cases, highlighting both the strengths and limitations of the model. This demonstrates how explainable deep learning can contribute to more transparent and trustworthy medical imaging research.
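For anyone wanting to reproduce this kind of qualitative inspection, a heatmap can be blended over the original slice with matplotlib. The snippet below builds on the grad_cam helper from the earlier sketch and assumes image is an already-preprocessed (1, 3, H, W) tensor.

```python
# Overlay a Grad-CAM heatmap on the MRI slice; names match the earlier sketch.
import matplotlib.pyplot as plt

cam, predicted_stage = grad_cam(image)   # grad_cam as defined in the sketch above
mri_slice = image[0, 0].numpy()          # any channel works; all three are identical

plt.imshow(mri_slice, cmap="gray")
plt.imshow(cam.numpy(), cmap="jet", alpha=0.4)   # translucent attention overlay
plt.title(f"Predicted stage: {predicted_stage}")
plt.axis("off")
plt.show()
```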

With additional time, future improvements could include experimenting with class-weighted loss functions, 3D CNN architectures, and multimodal data integration to further enhance performance and robustness.
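As a sketch of the first of these ideas, class weights inversely proportional to frequency can be passed directly to PyTorch's cross-entropy loss; the counts below are hypothetical.

```python
# Class-weighted cross-entropy: rarer stages contribute more per sample,
# counteracting the imbalance described above. Counts are hypothetical.
import torch
import torch.nn as nn

class_counts = torch.tensor([3000., 2200., 900., 70.])
weights = class_counts.sum() / (len(class_counts) * class_counts)  # inverse frequency
criterion = nn.CrossEntropyLoss(weight=weights)
```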

This project strengthened my interest in applying explainable machine learning to real-world healthcare problems and deepened my understanding of how artificial intelligence can support biomedical research.
