NeuroStageNet: An Explainable Deep Learning Framework for Alzheimer’s Disease Stage Classification from MRI Scans
Inspiration
Alzheimer’s disease progresses gradually, and MRI often contains subtle structural signals that correlate with severity. In real-world settings, manual staging can be time-consuming and may vary across readers—especially for adjacent stages. This motivated building NeuroStageNet, a system that can automatically classify Alzheimer’s stage from MRI and also provide visual explanations (heatmaps) to improve transparency and trust.
What it does
NeuroStageNet takes a brain MRI image as input and predicts one of four Alzheimer’s disease stages (classes 0–3).
It outputs:
- a predicted stage (softmax class),
- a confidence distribution (class probabilities),
- and Grad-CAM heatmaps that highlight regions most influential to the prediction.
The system is designed as an end-to-end pipeline: dataset loading → image decoding → training → evaluation → interpretability.
How I built it
Data (provided dataset)
The dataset was provided as Parquet files:
- train.parquet
- test.parquet
Each row contains:
- an image stored as binary bytes (sometimes wrapped in a dictionary structure),
- a label in one of four stages (0–3).
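The two row layouts described above (raw bytes, or bytes wrapped in a dictionary) can be handled by one small decoding helper. This is a minimal sketch, not the project's actual code: the function names are hypothetical, and the dictionary key `"bytes"` is an assumption about how the wrapper is laid out.

```python
import io

import numpy as np
from PIL import Image


def extract_image_bytes(cell, key="bytes"):
    """Return raw image bytes from a cell that is either bytes or a dict wrapping bytes.

    The key name "bytes" is an assumption; adjust it to the dataset's layout.
    """
    if isinstance(cell, dict):
        cell = cell[key]
    if isinstance(cell, (bytes, bytearray, memoryview)):
        return bytes(cell)
    raise TypeError(f"Unsupported image cell type: {type(cell).__name__}")


def decode_grayscale(cell):
    """Decode one image cell into a 2-D uint8 array shaped (H, W)."""
    raw = extract_image_bytes(cell)
    with Image.open(io.BytesIO(raw)) as img:
        # "L" is PIL's 8-bit single-channel (grayscale) mode
        return np.asarray(img.convert("L"), dtype=np.uint8)
```

Raising on unknown cell types, rather than returning an empty array, makes a malformed row visible during loading instead of during training.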
Preprocessing and input pipeline
- Byte decoding: Binary payloads are extracted robustly (handling dict-wrapped byte formats) and decoded using PIL into grayscale arrays.
- EDA checks: Training/testing shapes, class distribution, and image dimension statistics were examined to confirm dataset integrity and modeling feasibility.
- Train/validation split: An 80/20 stratified split preserves class proportions to ensure reliable validation metrics.
- Label encoding: Labels are converted to integers and then to one-hot vectors for categorical learning.
- tf.data pipeline: Images are stacked into tensors shaped (N, H, W, 1) and streamed using batching and prefetching for efficient training.
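The split, label encoding, and stacking steps above can be sketched with NumPy alone. This is an illustrative stand-in under stated assumptions: the helper names, the seed, and the rounding rule are hypothetical, and the project may well have used a library routine for the stratified split instead.

```python
import numpy as np


def stratified_split(labels, val_fraction=0.2, seed=42):
    """Return (train_idx, val_idx) with per-class proportions preserved."""
    labels = np.asarray(labels)
    rng = np.random.default_rng(seed)
    train_idx, val_idx = [], []
    for cls in np.unique(labels):
        # Shuffle the indices of this class, then cut off the validation share
        idx = rng.permutation(np.flatnonzero(labels == cls))
        n_val = int(round(len(idx) * val_fraction))
        val_idx.extend(idx[:n_val])
        train_idx.extend(idx[n_val:])
    return np.sort(np.array(train_idx, dtype=int)), np.sort(np.array(val_idx, dtype=int))


def one_hot(labels, num_classes=4):
    """Convert integer labels (N,) to one-hot vectors (N, num_classes)."""
    return np.eye(num_classes, dtype=np.float32)[np.asarray(labels, dtype=int)]


def stack_images(arrays):
    """Stack equally sized (H, W) arrays into a float32 tensor shaped (N, H, W, 1)."""
    return np.stack(arrays).astype(np.float32)[..., np.newaxis]
```

Splitting per class is what keeps a rare stage from ending up with zero validation samples by chance.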
Data augmentation
To reduce overfitting and improve generalization, augmentation is applied only during training:
- random horizontal flip,
- small rotations and zoom,
- brightness and contrast changes.
This simulates realistic imaging variability (alignment and intensity differences).
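As a rough illustration of training-only augmentation, the sketch below applies a random horizontal flip plus brightness and contrast jitter in NumPy. Rotation and zoom are left out to keep it dependency-free, and the function name and jitter ranges are assumptions rather than the settings used in this project.

```python
import numpy as np


def augment(image, rng, max_brightness=0.1, contrast_range=(0.9, 1.1)):
    """Randomly flip and jitter one image shaped (H, W, 1) with values in [0, 1]."""
    out = np.asarray(image, dtype=np.float32)
    if rng.random() < 0.5:
        out = out[:, ::-1, :]  # horizontal flip: reverse the width axis
    contrast = rng.uniform(*contrast_range)
    mean = out.mean()
    out = (out - mean) * contrast + mean  # scale deviations around the image mean
    out = out + rng.uniform(-max_brightness, max_brightness)  # uniform brightness shift
    return np.clip(out, 0.0, 1.0)  # keep intensities in the valid range
```

The function would be mapped over training batches only; validation and test images pass through unchanged so that metrics reflect the original data.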
Model architecture (core contribution)
NeuroStageNet is a transfer-learning model built around three components:
(1) Learnable grayscale → 3-channel projection
- Input MRI: (128, 128, 1)
- Conv2D(3, 3×3) maps grayscale to 3 channels
- Rationale: EfficientNet expects RGB; a learnable projection is more expressive than simple channel replication.
(2) EfficientNetB0 backbone (ImageNet pretrained) + partial fine-tuning
- EfficientNetB0(include_top=False, weights="imagenet")
- Fine-tuning begins at fine_tune_at=150 (early layers frozen, later layers trainable)
- Rationale: keep stable low-level features while adapting higher-level layers to MRI-specific staging patterns.
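Components (1) and (2) could be wired together in Keras roughly as follows. This is a sketch under assumptions, not the original code: the function name, the layer name, and `padding="same"` on the projection are illustrative choices, and `fine_tune_at` is interpreted here as an index into the backbone's layer list.

```python
from tensorflow import keras
from tensorflow.keras import layers


def build_feature_extractor(weights="imagenet", fine_tune_at=150, input_shape=(128, 128, 1)):
    """Grayscale input -> learnable 3-channel projection -> partially frozen EfficientNetB0."""
    inputs = keras.Input(shape=input_shape)
    # Learnable projection: a 3x3 convolution with 3 output channels
    x = layers.Conv2D(3, 3, padding="same", name="gray_to_rgb")(inputs)

    backbone = keras.applications.EfficientNetB0(
        include_top=False, weights=weights, input_shape=input_shape[:2] + (3,)
    )
    # Freeze every layer before index fine_tune_at; later layers stay trainable
    backbone.trainable = True
    for layer in backbone.layers[:fine_tune_at]:
        layer.trainable = False

    features = backbone(x)
    return keras.Model(inputs, features, name="feature_extractor"), backbone
```

Calling `build_feature_extractor()` with the default `weights="imagenet"` downloads the pretrained weights; passing `weights=None` builds the same graph with random initialization, which is handy for quick shape checks.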
(3) Residual refinement block with SE attention (filters=256)
- Two SeparableConv2D layers + BatchNorm + Swish
- Residual skip connection (with 1×1 conv if channel mismatch)
- SE attention (GlobalAveragePooling → Dense → sigmoid gating → channel reweighting)
- Rationale: acts as a lightweight “adapter” to refine backbone features; residual learning stabilizes training; SE improves channel selectivity for subtle medical cues.
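To make the SE gating step concrete, here is a NumPy sketch of the squeeze, gate, and reweight sequence on a batch of feature maps. It covers only the attention part of the block (the separable convolutions and skip connection are omitted), and the two-layer bottleneck with a ReLU follows the common SE design, which is an assumption since the write-up does not specify the reduction ratio or hidden activation.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def se_attention(features, w1, b1, w2, b2):
    """Squeeze-and-excitation on features shaped (N, H, W, C).

    w1: (C, C // r) and w2: (C // r, C) are the weights of the two dense layers;
    the reduction ratio r is a free choice in this sketch.
    """
    squeezed = features.mean(axis=(1, 2))         # squeeze: global average pool -> (N, C)
    hidden = np.maximum(squeezed @ w1 + b1, 0.0)  # bottleneck dense layer + ReLU
    gates = sigmoid(hidden @ w2 + b2)             # per-channel gates in (0, 1)
    return features * gates[:, None, None, :]     # reweight each channel
```

Because each gate lies strictly between 0 and 1, the block can only attenuate channels relative to one another; the residual connection around it is what lets the unmodified backbone signal pass through.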
Feature aggregation + classifier
- GlobalAveragePooling + GlobalMaxPooling concatenation: captures both distributed (global) and strongest localized evidence.
- Dense head:
  - Dense(1024) + BN + Dropout(0.5)
  - Dense(512) + BN + Dropout(0.3) + L2
  - Output Dense(4) + softmax
- Rationale: two-stage dense head learns complex boundaries while regularization prevents memorization.
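The pooling and classifier head described above might look like this in Keras. Treat it as a sketch: the function name, the ReLU activations in the dense layers, the L2 strength, and the 256-channel input shape (matching the refinement block's filter count on a 4×4 map) are assumptions, since the write-up does not state them.

```python
from tensorflow import keras
from tensorflow.keras import layers, regularizers


def build_head(feature_shape=(4, 4, 256), num_classes=4, l2=1e-4):
    """Dual pooling followed by a two-stage dense classifier."""
    features = keras.Input(shape=feature_shape)
    # Average pooling summarizes distributed evidence; max pooling keeps the strongest response
    pooled = layers.Concatenate()(
        [layers.GlobalAveragePooling2D()(features), layers.GlobalMaxPooling2D()(features)]
    )
    x = layers.Dense(1024, activation="relu")(pooled)
    x = layers.BatchNormalization()(x)
    x = layers.Dropout(0.5)(x)
    x = layers.Dense(512, activation="relu", kernel_regularizer=regularizers.l2(l2))(x)
    x = layers.BatchNormalization()(x)
    x = layers.Dropout(0.3)(x)
    outputs = layers.Dense(num_classes, activation="softmax")(x)
    return keras.Model(features, outputs, name="classifier_head")
```

Concatenating the two pooled vectors doubles the feature width (512 values for a 256-channel map) before the first dense layer.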
Training setup
- Optimizer: Adam, learning rate $5\times10^{-4}$
- Loss: CategoricalCrossentropy with label smoothing = 0.15, which reduces overconfidence and improves robustness on borderline stages.
- Callbacks:
  - EarlyStopping (monitor val_auc, restore best weights)
  - ReduceLROnPlateau (monitor val_loss)
  - ModelCheckpoint (save best by val_auc)
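Label smoothing is easiest to see on the targets themselves. The NumPy sketch below blends one-hot labels with a uniform distribution, which is the usual formulation; the helper names are hypothetical and this is an illustration of the idea rather than the training code.

```python
import numpy as np


def smooth_labels(one_hot, smoothing=0.15):
    """Blend one-hot targets with a uniform distribution over the classes."""
    num_classes = one_hot.shape[-1]
    return one_hot * (1.0 - smoothing) + smoothing / num_classes


def categorical_crossentropy(targets, probs, eps=1e-7):
    """Mean cross-entropy between target distributions and predicted probabilities."""
    probs = np.clip(probs, eps, 1.0)  # avoid log(0)
    return float(-(targets * np.log(probs)).sum(axis=-1).mean())
```

With a smoothing value of 0.15 and four classes, the true class target becomes 0.8875 and each other class receives 0.0375, so a prediction that puts nearly all probability on one class is penalized more than a slightly softer one.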
Evaluation
The model is evaluated on the unseen test set using:
- accuracy,
- per-class precision/recall/F1,
- confusion matrix,
- multi-class AUC-ROC (OvR).
Observed performance (provided results):
- Test accuracy: ~0.98
- Multi-class AUC-ROC: ~0.9975
- Most confusion occurs between adjacent stages (especially 2 vs 3), consistent with gradual progression.
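The confusion matrix and per-class metrics listed above can be computed in a few lines of NumPy. The sketch below uses hypothetical helper names and is meant to show how the numbers relate to each other, not to reproduce the reported results.

```python
import numpy as np


def confusion_matrix(y_true, y_pred, num_classes=4):
    """Rows are true classes, columns are predicted classes."""
    cm = np.zeros((num_classes, num_classes), dtype=int)
    np.add.at(cm, (np.asarray(y_true), np.asarray(y_pred)), 1)
    return cm


def per_class_metrics(cm):
    """Return precision, recall and F1 arrays, one entry per class."""
    tp = np.diag(cm).astype(float)
    predicted = cm.sum(axis=0)  # column sums: how often each class was predicted
    actual = cm.sum(axis=1)     # row sums: how often each class truly occurs
    precision = np.divide(tp, predicted, out=np.zeros_like(tp), where=predicted > 0)
    recall = np.divide(tp, actual, out=np.zeros_like(tp), where=actual > 0)
    denom = precision + recall
    f1 = np.divide(2 * precision * recall, denom, out=np.zeros_like(tp), where=denom > 0)
    return precision, recall, f1
```

Off-diagonal entries such as `cm[2, 3]` count stage-2 scans predicted as stage 3, which is where adjacent-stage confusion shows up.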
Interpretability (Grad-CAM)
Grad-CAM generates heatmaps by computing the gradients of the predicted class score with respect to the feature maps of a selected convolutional layer. This gives a visual explanation of which regions contributed most to the prediction.
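The core Grad-CAM combination step can be sketched in NumPy, assuming the layer activations and the gradients of the class score have already been obtained from the framework's automatic differentiation. The function name is hypothetical, and resizing the heatmap to the input resolution is left out.

```python
import numpy as np


def grad_cam(activations, gradients):
    """Combine feature maps and gradients into a heatmap scaled to [0, 1].

    activations, gradients: arrays shaped (H, W, C) for one image, where
    gradients holds d(class score) / d(activations).
    """
    weights = gradients.mean(axis=(0, 1))                         # one importance weight per channel
    cam = np.maximum((activations * weights).sum(axis=-1), 0.0)   # weighted sum, then ReLU
    peak = cam.max()
    return cam / peak if peak > 0 else cam
```

The ReLU keeps only regions that push the class score up, which is why the resulting map highlights supporting evidence rather than everything the layer responds to.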
Challenges I ran into
Selecting the right model architecture: Achieving strong performance on Alzheimer’s stage classification required experimentation. A key challenge was finding a model powerful enough to capture subtle stage differences while remaining efficient and stable to train.
Balancing transfer learning with domain adaptation: Pretrained CNN backbones offer strong generic features, but MRI differs from natural images. It was challenging to decide how much of the backbone to fine-tune (what layers to freeze vs. unfreeze) to adapt to MRI patterns without overfitting or destabilizing training.
Integrating residual refinement effectively: Adding a custom residual block after the backbone improved task-specific learning, but designing it correctly was non-trivial. The block needed to enhance Alzheimer-relevant cues while preserving useful pretrained representations, and it required tuning of the filter count, normalization, and attention.
Preprocessing and augmentation sensitivity: MRI models can be highly sensitive to preprocessing. A major challenge was choosing augmentation types and strengths (flip/rotation/zoom/brightness/contrast) that improve generalization without distorting medically meaningful anatomy. Over-augmentation can introduce unrealistic images; under-augmentation can lead to overfitting.
Reducing confusion between adjacent stages: Alzheimer’s stages can overlap visually, especially neighboring classes. Minimizing confusion required careful regularization (dropout, L2), label smoothing, and monitoring of validation behavior to ensure the model learned consistent, discriminative boundaries.
Achieving stable convergence: Learning rate, label smoothing magnitude, augmentation intensity, and callback settings (early stopping and LR scheduling) all interact. Tuning these components was necessary to avoid validation oscillations, premature stopping, or overfitting.
Accomplishments that I am proud of
- Built a complete end-to-end system from Parquet bytes → model-ready tensors → training → evaluation → interpretability.
- Achieved high test performance (accuracy ~0.98, AUC ~0.9975) with strong per-class metrics.
- Obtained a sensible confusion-matrix pattern (errors fall mostly between neighboring stages rather than between distant classes).
- Added Grad-CAM explanations for transparency and debugging.
What I learned
- Transfer learning works well even for MRI when combined with careful fine-tuning and robust regularization.
- Residual + attention refinement can improve performance by adapting generic features to subtle medical cues.
- AUC-ROC and per-class metrics are essential in medical datasets—accuracy alone is not sufficient.
- Explainability tools (Grad-CAM) are valuable to validate whether the model focuses on plausible image regions.
What’s next for NeuroStageNet
- Improve validation rigor by ensuring subject-level splitting if patient identifiers are available.
- Test external generalization on additional cohorts or scanners.
- Add probability calibration so confidence scores better reflect true likelihood.
- Explore medical-pretrained encoders or self-supervised pretraining on MRI for even stronger representations.
- Package the pipeline into a reproducible tool: saved model + inference script + Grad-CAM visualization for end users.
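For the subject-level splitting idea, one possible approach is to assign whole subjects to either train or validation so that scans from the same person never appear on both sides. The sketch below assumes a hypothetical list of per-sample subject identifiers, which the current dataset may not provide, and it does not stratify by class.

```python
import random
from collections import defaultdict


def subject_level_split(subject_ids, val_fraction=0.2, seed=0):
    """Split sample indices so that no subject appears in both train and validation."""
    by_subject = defaultdict(list)
    for index, subject in enumerate(subject_ids):
        by_subject[subject].append(index)

    subjects = sorted(by_subject)           # sort first so the shuffle is reproducible
    random.Random(seed).shuffle(subjects)
    n_val = max(1, round(len(subjects) * val_fraction))

    val = [i for s in subjects[:n_val] for i in by_subject[s]]
    train = [i for s in subjects[n_val:] for i in by_subject[s]]
    return sorted(train), sorted(val)
```

The validation fraction here applies to subjects, not samples, so the sample-level ratio can drift when subjects contribute different numbers of scans.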
Built With
- deep-learning
- kaggle
- python