1. Introduction
Dementia is a condition characterized by a range of symptoms, including memory impairment and difficulties in learning. It occurs due to the loss of brain cells caused by injury or other medical conditions. Among the several types of dementia, Alzheimer’s disease (AD) is the most prevalent. This neurodegenerative disease is characterized by the loss of neurons, particularly in the cortex. The loss is driven by amyloid protein plaques, which damage cells from the outside, and tangles of the protein tau, which destroy cells from within. Amyloid plaques consist of aggregates of beta-amyloid. When these plaques accumulate between neurons, they block neural signaling, impairing cognitive abilities such as memory. The plaques are also believed to initiate an immunological response that induces inflammation, leading to further cellular damage [
1]. The formation of tangles is believed to be triggered by extracellular beta-amyloid, which activates internal signaling pathways within the cell. These pathways drive the phosphorylation of the tau protein by kinases, altering its conformation so that it no longer supports microtubules. The detached tau accumulates into clumps, forming neurofibrillary tangles. Neurons with impaired microtubules and tangles cannot effectively transmit signals, ultimately leading to programmed cell death [
2,
3,
4]. The gyri undergo atrophy, resulting in their narrowing, whereas the sulci and ventricles experience dilation when the brain undergoes shrinkage due to cellular apoptosis [
2]. Most Alzheimer’s disease (AD) cases with onset before the age of 60 (early onset) or around the age of 85 (late onset) are classified as sporadic AD, which accounts for about 90% and 50% of these cases, respectively.
These occurrences are frequently influenced by environmental, behavioral, and genetic factors [
5]. The lack of transparency in the internal mechanisms of cutting-edge AI models makes them appear as black boxes, raising concerns regarding trust [
6,
Justifying each prediction helps to bridge this gap, as metrics like accuracy do not offer conclusive assessments of dependability in real-life situations. This occurs when a model is trained on static data, which may include instances that aid categorization in an experimental scenario but do not accurately reflect real-world circumstances. To improve the generalization of models, it is important to develop a deeper understanding of their behavior through interpretable explanations [
8,
9].
Deep learning applications in image processing span segmentation, classification, and detection tasks. For example, research on multi-source domain adaptation (MSDA) for medical image segmentation improves the performance of segmentation models on unseen datasets by leveraging multiple labeled datasets from different source domains. This addresses the challenge of domain shift, which often causes a drop in performance when a model is trained on one dataset and tested on another [
10,
11]. A proposed AI system integrates unmanned aerial systems (UASs) with computer vision based on the You Only Look Once (YOLO) framework to enable quick and accurate detection and removal of foreign object debris (FOD) on airport runways. The framework utilizes open-world recognition to identify both known and new types of debris, addresses the challenge of limited data through problem-specific data augmentation, and shows improved detection capabilities compared with traditional methods, enhancing runway safety [
12]. To segment the left atrium from 3D MRIs using semi-supervised learning, transformers were used for capturing global context, along with V-Net for detailed local feature extraction. Combining these networks improves accuracy in medical image segmentation tasks with limited labeled data. The framework extends Transformer capabilities to 3D data, coupled with a discriminator module to enhance segmentation results [
13].
ResNet-50, or Residual Network with 50 layers, is a deep convolutional neural network architecture belonging to the ResNet family. It was developed to overcome the challenges associated with training extremely deep neural networks. The key innovation in ResNet-50 involves the incorporation of residual blocks, designed to enable the training of very deep networks without encountering the vanishing gradient problem. These residual blocks include shortcut connections, facilitating the direct flow of gradients through the block and preventing significant degradation. The ResNet-50 architecture consists of 50 layers, encompassing convolutional, pooling, and fully connected layers. It incorporates three main types of blocks: identity blocks, convolutional blocks with a shortcut, and the bottleneck architecture. The bottleneck architecture is particularly noteworthy for reducing computational complexity while preserving representational power.
Mathematically, the residual block in ResNet-50 is expressed as $y = \mathcal{F}(x, \{W_i\}) + x$, where $x$ represents the input to the block, $y$ is the output, $\mathcal{F}(x, \{W_i\})$ denotes the residual mapping to be learned, and $\{W_i\}$ represents the block’s weights. The skip connection allows the input to bypass the residual mapping, facilitating the smooth flow of gradients during backpropagation. Each constituent layer applies an affine transformation, which for a single layer can be written as $z = Wx + b$, contributing to the overall effectiveness of ResNet-50 in training deep neural networks.
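The residual computation described above can be illustrated with a minimal, framework-free Python sketch (a toy two-layer residual mapping with hand-written linear algebra, not the actual ResNet-50 implementation):

```python
def relu_vec(v):
    return [max(0.0, x) for x in v]

def linear(v, w, b):
    """A fully connected layer: y_i = sum_j w[i][j] * v[j] + b[i]."""
    return [sum(wi[j] * v[j] for j in range(len(v))) + bi
            for wi, bi in zip(w, b)]

def residual_block(x, w1, b1, w2, b2):
    """Toy residual block: y = F(x, {W1, W2}) + x, where F is
    linear -> ReLU -> linear and the skip connection adds x back."""
    f = linear(relu_vec(linear(x, w1, b1)), w2, b2)
    return [fi + xi for fi, xi in zip(f, x)]

# With all weights zero, F(x) = 0 and the block reduces to the identity
# map, so gradients can always flow through the skip connection unchanged.
x = [1.0, -2.0, 3.0]
zeros_w = [[0.0] * 3 for _ in range(3)]
zeros_b = [0.0] * 3
print(residual_block(x, zeros_w, zeros_b, zeros_w, zeros_b))  # → [1.0, -2.0, 3.0]
```

The zero-weight case illustrates why residual connections ease the training of very deep networks: a block can default to passing its input through unchanged rather than having to learn an identity mapping from scratch.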
ResNet-50 has demonstrated impressive performance across various computer vision tasks, with a notable emphasis on image classification. It functions as a robust feature extractor after pre-training on extensive datasets such as ImageNet. This quality is particularly advantageous for transfer learning scenarios, especially when dealing with applications with constraints on labeled data. The architecture’s depth, combined with the incorporation of skip connections, enhances its capability to capture intricate hierarchical features. These attributes collectively contribute to ResNet-50 being widely favored and integrated into cutting-edge deep learning models [
14,
15]. In this study, we appended a global average pooling layer to down-sample the output of the ResNet-50 model, computing the average value of each feature map in the input tensor so that every channel is reduced to a single value:

$GAP_c = \frac{1}{H \times W} \sum_{i=1}^{H} \sum_{j=1}^{W} x_{i,j,c}$,

where $x_{i,j,c}$ represents the value at position $(i, j)$ in channel $c$ of the feature map, $H$ is the height of the feature map, and $W$ is its width. The summation is performed over all positions in the feature map. This operation reduces dimensionality before feeding the features into fully connected layers. We then created a dense (fully connected) layer with three units and a softmax activation function to convert raw scores into probability distributions over the classes. For a given vector $z$ of $k$ real numbers, the softmax function is computed element-wise by

$\sigma(z)_i = \frac{e^{z_i}}{\sum_{j=1}^{k} e^{z_j}}$,

where $e$ is Euler’s number, $z_i$ is the $i$-th element of the input vector $z$, and the denominator is the sum of the exponentials of all elements in the vector. This function returns non-negative values that represent valid probabilities summing to 1. The predicted class is the one with the highest probability [
16]. The model was then compiled as a Keras model, specifying inputs from the predefined ResNet-50 model and outputs from the dense layer.
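The pooling-and-softmax head described above can be sketched without any deep learning framework (a toy example on a tiny hand-made feature map; the actual model applies Keras layers to ResNet-50 feature maps):

```python
import math

def global_average_pool(fmap):
    """fmap: H x W x C nested lists -> length-C vector of per-channel means."""
    H, W, C = len(fmap), len(fmap[0]), len(fmap[0][0])
    return [sum(fmap[i][j][c] for i in range(H) for j in range(W)) / (H * W)
            for c in range(C)]

def softmax(z):
    m = max(z)                               # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in z]
    total = sum(exps)
    return [e / total for e in exps]

# A 2 x 2 feature map with 2 channels collapses to one value per channel.
fmap = [[[1.0, 0.0], [3.0, 2.0]],
        [[5.0, 4.0], [7.0, 6.0]]]
pooled = global_average_pool(fmap)
probs = softmax(pooled)
print(pooled)  # → [4.0, 3.0]
```

The softmax output is non-negative and sums to 1, so the index of its largest entry can be read directly as the predicted class.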
The primary objectives of this study are as follows:
To develop a reliable and accurate classification model utilizing a deep transfer learning architecture.
To extract and represent deep features relevant for inclusion and exclusion during classification.
To enhance the interpretability of results by exploiting XAI techniques through visualization.
2. Related Works
Explainable artificial intelligence (XAI) has played a pivotal role in medical applications over recent decades, with particular emphasis on addressing the complexities inherent in diagnosing neurodegenerative diseases such as Alzheimer’s disease, vascular dementia, Parkinson’s disease, and other disorders related to cognitive decline. Classical machine learning (ML) methods, e.g., support vector machines and random forests, trade accuracy for interpretability, whereas deep learning (black-box) models trade interpretability for accuracy. The incorporation of XAI techniques in this domain is driven by the overarching goal of augmenting transparency, interpretability, and trust in machine learning models. This strategic integration aims to furnish clinicians with in-depth insights into the intricate decision-making processes of these black-box models. A review of the pertinent literature reveals a multitude of studies applying XAI methodologies to Alzheimer’s diagnosis, underscoring the substantive contributions within this rapidly evolving field.
Muthamil Sudar et al. [
17] endeavored to delineate the various stages of Alzheimer’s disease through the utilization of the layer-wise relevance propagation (LRP) method within the realm of explainable artificial intelligence (XAI), employing image data as input. Beyond LRP, the study incorporated additional algorithms, such as VGG-16 and CNN, with the goal of improving overall performance and achieving heightened batch accuracy. The primary focus of the project was to enable a thorough analysis of Alzheimer’s disease by leveraging XAI, accompanied by detailed feature explanations. The findings outlined in the article provide a lucid comprehension of Alzheimer’s analysis, reinforcing the results with elaborate explanations to bolster the trustworthiness and dependability of XAI.
The investigation by El-Sappagh, S. et al. [
18] highlights the intuitive significance of cognitive scores, including CDRSB and MMSE, in effectively identifying patients with Alzheimer’s disease (AD), a consensus validated by domain experts. However, when focusing on progression detection, the study reveals that the volumes of Hippocampus and MidTerp obtained from MRI images, coupled with FDG and SROI from PET images, exert notable influence. The study employs SHAP explainers to compute feature contributions of random forest (RF) models, as delineated in the Explainability Capabilities Section of the Material and Methods. A condensed overview of the explainer’s responsiveness to diverse feature values is presented for both the first and second layers of the research findings. However, this study employed SHAP on clinical variables that are unlikely to be interpretable by patients and entry-level practitioners, unlike the approach we propose.
In an investigation by Achraf Essemlali et al. [
19], an experiment employing explainable AI aimed to unravel the connectomic structure associated with Alzheimer’s disease (AD). Through the utilization of a CNN trained on the brain connectomes of ADNI patients, the researchers executed an ablation procedure to showcase that the manifestation of AD is not solely linked to a specific brain region but rather results from the cumulative effects across various cortical regions. The study underscored the entorhinal region as the most notable distinction between AD and normal control (NC) groups, while the hippocampus exhibited significance in the comparison between mild cognitive impairment (MCI) and NC. These findings align with established research methodologies such as voxel-based morphometry, cortical thickness, or functional connectomics in AD studies. The research signifies the potential of deep convolutional networks in providing intricate insights into the complexities of neurodegenerative diseases. However, the study emphasizes the necessity for a cautious interpretation of the saliency map, acknowledging that the correlation with neural net predictions may be influenced by variations in structural connectivity estimated from DW-MRI. Our study puts this notion into practice by applying a channel-wise attention mechanism to enhance the performance within Grad-CAM.
The study by Eduardo Nigri et al. [
20] introduces the swap test method as a novel approach to generate heatmaps, offering insight into the key brain regions indicating Alzheimer’s disease (AD) for improved interpretability by clinicians. Through axiomatic evaluation experiments, it is demonstrated that the swap test outperforms a conventional occlusion test in explaining AD diagnosis using MRI data. These findings suggest that the swap test has the potential to mitigate the inherent black box nature of deep neural networks commonly used in AD diagnosis, providing a valuable tool to enhance transparency and interpretability in the decision-making process. In our study, we employ XAI techniques with lower computational complexity yet similarly achieve results consistent with the medical literature.
Shangran Qiu et al. [
21] introduce a sophisticated deep learning pipeline that combines a fully convolutional network (FCN) with a multilayer perceptron (MLP) to directly predict Alzheimer’s disease status using MRI data or a blend of MRI and non-imaging data. The FCN produces high-resolution disease probability maps illustrating local cerebral morphology and Alzheimer’s risk. Leveraging these maps and non-imaging features like age, gender, and MMSE score, the MLP achieves accurate predictions across diverse cohorts. The FCN is specifically trained on randomly selected sub-volumes of MRI data, allowing for efficient processing without redundant decomposition of full-sized test images. The study underscores the interpretability of disease probability maps and their anatomical consistency, shedding light on structures most impacted by neuropathological changes in Alzheimer’s disease. Population-wide maps of the Matthews correlation coefficient contribute to identifying crucial regions for precise disease status predictions. Our approach, however, achieves high prediction performance with anatomical consistency as well, without the dependence on non-imaging data, which is prone to error if obtainable in practice.
In the context of predicting biological age (BA), I. Boscolo Galazzo et al. [
22] explore the valuable framework of the BA prediction paradigm, aiming to understand the underlying factors influencing an individual’s biological age and to characterize diverse aging trajectories. This paradigm not only offers insights into brain mechanisms but also provides a means to identify potential risks associated with cognitive aging and age-related brain disorders. The study emphasizes the promising potential of both ML and deep learning (DL) approaches, particularly in multimodal settings, and highlights the significance of investigating specific BA estimates derived from selective and regional ensembles of intrinsic disorder profiles (IDPs). Furthermore, there is an emphasis on the crucial role of explainable AI (XAI) in enhancing BA prediction, as it contributes to the interpretability of linear and latent variable models, providing user-friendly visualizations of essential features and supporting the application of complex deep models. Our study bridges the gap between complexity and interpretability cited by the authors using model-agnostic and model-specific approaches.
The recent study by Yousefzadeh et al. [
23] introduces a novel explainable AI framework called “Granular Neuron-level Explainer” (LAVA) aimed at assessing Alzheimer’s disease (AD) using retinal fundus images. LAVA delves into the intermediate layers of a CNN model to identify key neurons that play a significant role in distinguishing between various stages of AD, thus offering an interpretable diagnostic method. Leveraging data from the UK Biobank, the research demonstrates LAVA’s effectiveness in differentiating AD stages by analyzing retinal vascular features, suggesting that retinal imaging could be a valuable, non-invasive tool for early AD diagnosis. However, the study acknowledges the limitations of its small sample size and emphasizes the need for further research to validate these findings.
3. Materials and Methods
In this paper, the input data are sent to a deep learning model for multiclass classification. The resulting predicted output is evaluated using two explainable AI techniques (
Figure 1). The method we propose is based on a ResNet-50 network coupled with a channel-wise attention mechanism to perform classification. We created a local ResNet-50 model initialized with ImageNet weights, which we then trained internally on the MRI scans. This curbs the risk of inadvertent sharing of participant-level data while preserving the analytical advantage of our approach. To assess the model’s predictions, we first employ LIME with the quickshift segmentation method to highlight key features as superpixels. We use 150 perturbations, a quickshift kernel size of 70, a maximum distance of 200, and a ratio of 0.2. We generate superpixels with this set-up and analyze the output image. Second, we use Grad-CAM to create a superimposed image displaying a heat map based on the final feature map from the last convolutional layer. We apply channel-wise attention within Grad-CAM to enhance the quality of the three channels when generating the final jet heat map.
To conduct the experiments in this research, we used data from the Alzheimer’s Disease Neuroimaging Initiative (ADNI), disseminated by the Laboratory of Neuro Imaging at the University of Southern California. Samples can be seen in
Figure 2.
This study is based on the publicly available, large-scale ADNI dataset consisting of 10,346 sagittal brain MRI scans, categorized into three classes, as shown in
Table 1: normal cognition (NC), mild cognitive impairment (MCI), and Alzheimer’s disease (AD). Normal cognition is the label for patients who show no indicators of cognitive decline, such as memory loss or motor impairments, as evidenced by the lack of pathological changes in their brain scans. Patients with mild cognitive impairment show signs of forgetfulness, consistent with deterioration around the hippocampus. This progresses into moderate dementia, where patients begin to forget their personal history. As cognition continues to decline and the patient develops Alzheimer’s disease, more historical details are forgotten, accompanied by confusion due to neuronal damage throughout the cerebral cortex, as evidenced by the narrowing of the gyri coupled with the dilation of the sulci and ventricles. Severe dementia from this point requires supervision, as patients begin to forget their family members and require assistance with daily activities due to the development of motor symptoms; according to the global deterioration scale, this is the stage before death [
2,
3,
4].
These sets were then normalized through mean standardization and thereafter used to train a deep learning model based on ResNet-50. The positive outcome derived from the classification was used as input for the XAI models. The standardization is given by

$Z = \frac{X - \text{Mean}(X)}{\text{Standard Deviation}(X)}$,

where $Z$ represents the standardized value of the original variable $X$, $\text{Mean}(X)$ is the average of $X$, and $\text{Standard Deviation}(X)$ quantifies how much each data point differs from the dataset’s average. In this study, we implemented a channel-wise self-attention mechanism to improve feature representation in a convolutional neural network. This mechanism consists of an attention block. The input tensor is initially processed through a convolutional layer with a 1 × 1 kernel and ReLU activation to produce an intermediate feature map. This map then passes through another convolutional layer with a 1 × 1 kernel and sigmoid activation to generate attention weights. These weights are applied element-wise to the original input tensor, highlighting crucial channels and diminishing less important ones. The attention block is integrated twice within the network, after convolution and pooling layers, to dynamically modify channel significance at various stages, thereby enhancing feature extraction and classification accuracy. This methodology showcases the efficacy of channel-wise attention in boosting neural network performance. The mechanism is denoted as follows:

$M = \text{ReLU}(W_1 * F + b_1)$,

where the input feature map $F$ is a tensor representing the original input features extracted from the previous layers, with dimensions corresponding to the image’s height, width, and number of channels; the weight matrix $W_1$ of the first convolutional layer uses a 1 × 1 kernel to transform these input features, with learnable parameters adjusted during training; the bias vector $b_1$ adds bias terms to each channel output of the first convolutional layer; and the Rectified Linear Unit (ReLU) activation function introduces non-linearity by setting negative values to zero. The intermediate feature map $M$ results from this convolution and ReLU activation, serving as an intermediate representation of the input features.
The attention map

$A = \sigma(W_2 * M + b_2)$

is calculated by applying the sigmoid activation function $\sigma$ to the result of a second convolution. This involves the weight matrix $W_2$ of the second convolutional layer, which uses a 1 × 1 kernel to further process the intermediate feature map $M$, along with the bias vector $b_2$, which adds bias terms to each channel output. The sigmoid function scales these values to the range [0, 1], creating the attention weights in the attention map $A$, indicating each channel’s importance. The original input feature map $F$ is then element-wise multiplied with this attention map,

$F' = F \odot A$,

yielding the attended feature map $F'$, which enhances significant channels and suppresses less relevant ones [21, 22, 23, 24].
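As an illustration, the attention block above can be sketched in plain Python (a toy version operating on nested lists; the actual implementation uses 1 × 1 convolutional layers in a deep learning framework):

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def conv1x1(t, w, b, act):
    """A 1 x 1 convolution mixes only channels, so per pixel it reduces
    to a linear map over the channel vector followed by an activation."""
    return [[[act(sum(w[c][k] * px[k] for k in range(len(px))) + b[c])
              for c in range(len(b))]
             for px in row]
            for row in t]

def channel_attention(fmap, w1, b1, w2, b2):
    """M = ReLU(W1*F + b1); A = sigmoid(W2*M + b2); return F ⊙ A."""
    m = conv1x1(fmap, w1, b1, lambda v: max(0.0, v))   # intermediate map
    a = conv1x1(m, w2, b2, sigmoid)                    # attention in (0, 1)
    return [[[f * g for f, g in zip(fp, ap)]           # element-wise product
             for fp, ap in zip(frow, arow)]
            for frow, arow in zip(fmap, a)]

# With a strongly positive bias the sigmoid saturates near 1 and the
# block passes the input through; a strongly negative bias suppresses it.
fmap = [[[2.0, 3.0]]]                      # a 1 x 1 image with 2 channels
eye = [[1.0, 0.0], [0.0, 1.0]]
zero_w = [[0.0, 0.0], [0.0, 0.0]]
keep = channel_attention(fmap, eye, [0.0, 0.0], zero_w, [50.0, 50.0])
drop = channel_attention(fmap, eye, [0.0, 0.0], zero_w, [-50.0, -50.0])
print(keep, drop)
```

Because the sigmoid weights are learned per channel, the trained block can keep some channels near their original values while attenuating others, which is precisely the selective emphasis the mechanism is designed to provide.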
4. Results
Figure 3 and
Figure 4 depict the training and prediction performance of our model in a visual format. Training accuracy is given by

$\text{Training Accuracy} = \frac{\text{Number of correct predictions on the training set}}{\text{Total number of training samples}}$,

and validation accuracy is defined analogously over the validation set:

$\text{Validation Accuracy} = \frac{\text{Number of correct predictions on the validation set}}{\text{Total number of validation samples}}$.

Figure 3.
Model accuracy against the count of epochs during training and validation.
Figure 4.
Confusion matrix for the pre-trained model.
The confusion matrix in
Figure 4 displays the network’s predictive performance on the three classes: NC (normal cognition), MCI (mild cognitive impairment), and AD (Alzheimer’s disease).
Figure 3 is a visualization of the deep learning model’s performance, where the highest training accuracy of 85% is displayed, calculated through Equation (5). The pre-trained backbone was kept frozen during training, leaving 0 of the 23,587,712 ResNet-50 parameters trainable.
Figure 4 visually represents the label predictions using the confusion matrix. One example from the test set produced a positive prediction for mild cognitive impairment (MCI), as shown in
Figure 5. This image served as the input for our XAI experiments.
In this research, we utilized local interpretable model-agnostic explanations (LIME) as a method for explainable AI. LIME was implemented to demonstrate how the classifier behaves in the vicinity of the predicted instance, focusing on local fidelity. The quickshift segmentation technique generates superpixels, with each superpixel represented by a binary vector. Quickshift estimates the density of each pixel with a Parzen window,

$P(x_i) = \sum_{j} \exp\!\left(-\frac{\|x_i - x_j\|^2}{2\sigma^2}\right)$,

where the feature vectors $x_i$ and $x_j$ represent pixel positions and intensities, $\sigma$ is a bandwidth parameter, and a ratio constant balances color similarity against spatial proximity to speed up computations. The algorithm detects image regions by repeatedly moving each pixel toward the density mode of its neighbors until convergence, using the distance measure above to assess the similarity between pixels.
A value of 1 in the vector indicates that the original superpixel is kept, whereas 0 indicates a greyed-out superpixel. The perturbed data points are assigned weights based on their closeness to the original example in order to train an interpretable model on the corresponding predictions. A binary matrix is created with perturbations as rows and superpixels as columns: an activated superpixel is represented by 1, a deactivated (off) superpixel by 0. The cosine distance is calculated between each randomly generated perturbation and the explained image. The distances are transformed into a numerical range of 0 to 1 using a kernel function, and the coefficients of the fitted model are then sorted to identify the superpixels with higher magnitudes. This process masks less significant superpixels and produces an image that highlights the most significant ones. The main equation for LIME is given as follows:

$\xi(x) = \arg\min_{g \in G} \; \mathcal{L}(f, g, \pi_x) + \Omega(g)$,

where $\xi(x)$ represents the local surrogate explanation, $f$ is the black-box model being explained, $g$ is the interpretable (surrogate) model drawn from a family $G$, $\pi_x$ is the perturbation distribution for instance $x$, $\Omega(g)$ is a complexity measure kept low to preserve interpretability, and $x'$ is the interpretable representation of $x$. Instances from the neighborhood around the instance $x$ are sampled from the distribution $\pi_x$: $z$ denotes a perturbed instance in the original feature space and $z'$ its binary interpretable representation. The perturbation distribution $\pi_x$ indicates how instances are sampled from the neighborhood. The difference between the predictions of the black-box model $f$ for a perturbed instance $z$ and the predictions of the interpretable model $g$ for its representation $z'$ enters the locality-weighted loss,

$\mathcal{L}(f, g, \pi_x) = \sum_{z, z'} \pi_x(z)\left(f(z) - g(z')\right)^2$

(Figure 6).
Given that the LIME methodology includes perturbing the instance $x$, sampling perturbed instances $z$ from the neighborhood, obtaining predictions from the black-box model $f$, and fitting an interpretable model $g$ to locally approximate $f$’s behavior, the interpretable representation $x'$ is utilized to emphasize significant features or regions represented by superpixels. The deactivation of these superpixels returns a greyed-out representation of the features excluded from the model’s prediction (
Figure 6). Conversely, the activation of superpixels returns a mapping of the feature regions that were relevant to the model’s prediction (
Figure 7).
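The perturbation-and-weighting logic described above can be sketched in plain Python. For brevity, this toy version replaces LIME’s weighted linear regression with a kernel-weighted mean difference per superpixel, and the black-box model `f` is a hypothetical stand-in:

```python
import math, random

def cosine_distance(u, v):
    dot = sum(a * b for a, b in zip(u, v))
    nu = math.sqrt(sum(a * a for a in u))
    nv = math.sqrt(sum(b * b for b in v))
    if nu == 0.0 or nv == 0.0:
        return 1.0                   # an all-off mask is maximally distant
    return 1.0 - dot / (nu * nv)

def lime_importances(predict, n_superpixels, n_samples=150, kernel_width=0.25):
    """Sketch of LIME's locality weighting: sample binary masks, weight
    them by kernelized cosine distance to the unperturbed image, and
    score each superpixel by the weighted mean change it causes."""
    random.seed(0)
    ones = [1] * n_superpixels       # the original image: all superpixels on
    samples = [[random.randint(0, 1) for _ in range(n_superpixels)]
               for _ in range(n_samples)]
    weights = [math.exp(-(cosine_distance(z, ones) ** 2) / kernel_width ** 2)
               for z in samples]
    scores = [predict(z) for z in samples]
    wmean = lambda ws: sum(w * s for w, s in ws) / sum(w for w, _ in ws)
    imp = []
    for j in range(n_superpixels):
        on = [(w, s) for z, w, s in zip(samples, weights, scores) if z[j]]
        off = [(w, s) for z, w, s in zip(samples, weights, scores) if not z[j]]
        imp.append(wmean(on) - wmean(off))
    return imp

# Hypothetical black box that mostly depends on superpixel 0.
f = lambda z: 0.9 * z[0] + 0.05 * z[2]
imp = lime_importances(f, n_superpixels=4)
print(max(range(4), key=lambda j: imp[j]))  # → 0, the dominant superpixel
```

Sorting the resulting importances and masking the low-magnitude superpixels yields exactly the highlighted/greyed-out visualizations shown in Figures 6 and 7.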
Additionally, we employed gradient-weighted class activation mapping (Grad-CAM) to visually identify significant areas that contribute to the model’s prediction. A convolutional neural network analyzes the image and extracts features at various resolutions. The last layer of the network generates probability-based scores representing the classification of the image. The class score is calculated by

$y^c = \sum_{k} w_k^c A^k$,

where $y^c$ represents the score for class $c$, $w_k^c$ is the weight of the $k$-th feature map for class $c$, and $A^k$ is the $k$-th activation map.
The gradient of the predicted score $y^c$ with respect to the feature maps of the last CNN layer is computed to quantify the impact of each feature on the class score. The significance of each feature map is determined by taking the average of these gradients,

$\alpha_k^c = \frac{1}{Z} \sum_{i} \sum_{j} \frac{\partial y^c}{\partial A_{ij}^k}$,

where $y^c$ represents the score for class $c$ in the network, $A^k$ represents the activation map for the $k$-th feature in the final convolutional layer, $\frac{\partial y^c}{\partial A_{ij}^k}$ is the gradient of the score with respect to the activation at position $(i, j)$, and $Z$ is the number of spatial positions summed over in each activation map. These are the same gradients used during backpropagation to minimize loss and improve classification performance. The heat map is passed through the ReLU function so that only positive values contribute to the visualization, and the intensity of each pixel in the heat map corresponds to a spatial location in the image. Smoothing the weighted combination of activation maps softens sharp edges and yields the final interpretable visualization,

$L_{\text{Grad-CAM}}^c = \text{ReLU}\!\left(\sum_{k} \alpha_k^c A^k\right)$,

where ReLU, the rectified linear unit activation function, sets negative values to zero, $\sum_k$ denotes the summation over all activation maps, $\alpha_k^c$ represents the weight of the $k$-th activation map for class $c$, and $A^k$ is the $k$-th activation map. The final normalized heat map accentuates the areas upon which our deep learning algorithm relies to provide predictions [
8,
9,
25,
26,
27]. We employed a channel-wise self-attention mechanism to highlight significant features within each channel of the input tensor (
Figure 8). This mechanism was accomplished through two convolutional layers with (1, 1) kernels that generate an attention map, which is then element-wise multiplied with the input tensor. The initial convolution layer, followed by a ReLU activation, formed an intermediate representation, while the subsequent convolution layer with a sigmoid activation created the attention weights. These weights adjust the emphasis on different channels, enabling the model to enhance or reduce specific features, thereby improving the learning process and overall performance [
24,
28,
29,
30].
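The Grad-CAM computation summarized above, gradient-averaged weights followed by a ReLU over the weighted activation maps, can be sketched as a small framework-free example (the gradients here are supplied by hand rather than obtained via backpropagation):

```python
def grad_cam(activations, gradients):
    """Toy Grad-CAM: activations and gradients are lists of K feature
    maps, each an H x W nested list.  The weight alpha_k of map k is the
    global average of its gradients; the heat map is the ReLU of the
    weighted sum of activation maps."""
    K = len(activations)
    H, W = len(activations[0]), len(activations[0][0])
    Z = H * W
    alphas = [sum(gradients[k][i][j] for i in range(H) for j in range(W)) / Z
              for k in range(K)]
    return [[max(0.0, sum(alphas[k] * activations[k][i][j] for k in range(K)))
             for j in range(W)]
            for i in range(H)]

acts = [[[1.0, 0.0], [0.0, 0.0]],       # map 0 fires top-left
        [[0.0, 0.0], [0.0, 1.0]]]       # map 1 fires bottom-right
grads = [[[1.0, 1.0], [1.0, 1.0]],      # class score rises with map 0
         [[-1.0, -1.0], [-1.0, -1.0]]]  # ...and falls with map 1
print(grad_cam(acts, grads))  # → [[1.0, 0.0], [0.0, 0.0]]
```

The ReLU discards the bottom-right region whose activations push the score down, so only evidence in favor of the class survives in the heat map; the channel-wise attention we add on top of this re-weights the maps before they are combined.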
5. Discussion
The challenges associated with diagnosing and predicting the progression of Alzheimer’s disease (AD) have been mitigated by artificial intelligence (AI), particularly in image data analysis. However, in contrast to various brain disorders, AD remains difficult to comprehend despite being relatively easy to categorize based on established features. This study utilized a glass-box methodology to illustrate that the deep learning model generated positive predictions corresponding to the specified characteristics. What makes this research unique is the application of the channel-wise attention mechanism not only to the deep learning model but also to the XAI model, as shown in
Figure 9, wherein the left shows the base Grad-CAM with a generalized heat map while the right shows the model with channel-wise attention applied, returning a well-defined heat map. The image on the left of
Figure 9 was created using traditional Grad-CAM and, in comparison with our attention-based approach, it is evident that the heat map generated with channel-wise attention highlights the region of interest more accurately, in line with the medical literature on MCI diagnosis: the hippocampal area shows early signs of degradation in cases of mild cognitive impairment. A few challenges were met during the conceptualization of this study. Deciding which techniques to experiment with was not straightforward, as we sought to avoid redundant output while having the methods supplement the study collectively. Since the quickshift method is highly sensitive, gradual hyperparameter tuning was required to obtain an acceptable kernel size of 70. Applying the attention mechanism improves the quality of the explanation on a more granular scale than traditional tuning, since each channel in the feature map is considered separately and the focus falls only on the most relevant channels. This approach verifies the classification while enhancing localization, and the scope of the methodology is not limited to MRI studies; our method can be employed for general imaging tasks such as object detection. The flexibility of this framework also allows for experimentation with multimodal data such as MRI and diffusion tensor imaging (DTI).
Our proposed method employs a hybrid model combining ResNet-50 with attention mechanisms and XAI techniques to enhance the interpretability of Alzheimer’s disease (AD) classifications. The use of channel-wise attention improves feature extraction, and the integration of XAI methods provides visual insights into model predictions, helping to identify regions of interest like the hippocampus, a critical area associated with Alzheimer’s disease. In contrast to previous studies, such as Muthamil Sudar et al. [
17] who used layer-wise relevance propagation (LRP) to explain AD progression, and El-Sappagh et al. [
18], who used SHAP with random forest for clinical data, the current study leverages deeper models and focuses more on the explainability in an image-centric context. Sudar et al. aimed at batch-level accuracy through models like VGG-16 and CNN, while the current study integrates attention mechanisms directly into the deep learning architecture to boost performance and interpretability. Furthermore, Jahan Sobhana et al. [
31] used multimodal data and random forest with SHAP for predicting multiple AD classes. The current research, by focusing solely on MRI data and applying advanced explainability techniques like Grad-CAM, provides a more targeted approach in visualizing critical brain areas contributing to the diagnosis, differentiating it from models that depend heavily on multimodal inputs. This proposed framework delivers a highly explainable AI model capable of identifying specific pathological changes linked to AD, bridging the gap between model accuracy and interpretability for clinical use.