1. Introduction
Brain cancer is one of the ten leading causes of death globally among men and women [1,2]. The World Health Organization estimates that the 5-year survival rate is only 21% for people aged 40 and over [2]. In most clinical scenarios, low-grade gliomas (LGGs) are well-differentiated, slow-growing lesions, while high-grade gliomas (HGGs) are usually aggressive with a dismal prognosis [3,4]. Survival rates differ markedly across tumor grades. Identifying tumor grade at an early stage is therefore a major unmet need; it contributes to formulating better treatment strategies and enhances the overall quality of life of patients.
Magnetic resonance (MR) imaging is a non-invasive technique that remains the standard of care for brain tumor diagnosis and treatment planning in clinical practice [5,6]. It provides a reasonably good delineation of gliomas and conveys biological information on the tumor location, size, necrosis, edematous tissue, mass effect, and breakdown of the blood–brain barrier (which results in contrast enhancement in post-contrast-enhanced T1-weighted (ceT1w) MR images) [6]. In general, LGGs are less invasive: they usually have well-defined boundaries and homogeneous tumor cores without prominent mitosis, necrosis, or microvascular proliferation [6,7,8,9]. HGGs typically show more mass effect and usually exhibit microscopic peritumoral white matter tract invasion. The demonstration of this diffuse infiltration is an important discriminating feature for accurate glioma diagnosis [6].
Diagnosis of brain tumors from MR images is a time-consuming and challenging task that requires professional knowledge and careful observation. As alternatives, various automated diagnosis approaches have been developed to assist radiologists in the interpretation of brain MR images and reduce the likelihood of misdiagnosis. Convolutional neural networks (CNNs) provide a powerful technology for medical data analysis [10]. CNN-based deep learning architectures can automatically extract important low-level and high-level features from a training dataset of sufficient variety and quality [11]; they embed feature extraction and classification into a single self-learning procedure, allowing fully automatic classification without human interaction, which can be applied to the problem of tumor diagnosis.
Over the last decade, CNN-based methods have been extensively investigated for brain tumor classification due to their outstanding performance, with very high accuracy in a research context [12,13]. The differential classification of HGG and LGG is a comparatively simple task that has been tackled in numerous ways using different CNN methods, and the best-performing models have demonstrated close to 100% performance [10]. For example, Khazaee et al. [14] used a pre-trained EfficientNetB0 for HGG and LGG classification; the model achieved a mean classification accuracy of 98.87%. Chikhalikar et al. [15] proposed a custom CNN model to classify the type of tumor present in MRI images, achieving an accuracy of 99.46%. The authors in [16] used transfer learning, stacking InceptionResNetV2, DenseNet121, MobileNet, InceptionV3, Xception, VGG16, and VGG19 for the same classification task; the average classification accuracy on the test dataset reached 98.06%. Zhuge et al. [17] utilized a pre-trained ResNet50, and the classification accuracy of the proposed model reached 96.3%.
The above CNN-based methods all achieved remarkable performance on automated HGG and LGG classification. However, MR images are unlikely to be artifact-free [18], and the lesion signal measured by MRI is typically mixed with nuisance sources. The above-mentioned black-box CNNs may learn confounding sources from MR images for decision making, and the resulting health outcomes cannot easily gain the trust of physicians or patients because the evidence is unknown [6,19].
The lack of transparency and interpretability concerning the decision-making process still limits their adoption in clinical practice [12,19,20]. Visualizing features that are faithful to the underlying lesion is crucial to ensuring the interpretability and trustworthiness of classification outcomes. Interpretability is the ability to provide explanations in terms understandable to a human [21], based on domain knowledge related to the task or on common knowledge, according to the task characteristics. The need for interpretability has already been stressed by many papers [21,22,23], which emphasize cases where a lack of interpretability may be harmful. Can we explain why algorithms go wrong? When things go well, do we know why and how to exploit them further?
In order to deploy a system in practice, it is necessary to present classification results in such a way that they are acceptable to end users. This is only possible if users trust the decision-making process, which, as a consequence, must be transparent and interpretable. To date, a limited number of saliency-based interpretable methods have suggested different frameworks to improve the interpretability and trustworthiness of CNNs for brain tumor classification [24,25,26,27]. We divide the previous interpretable approaches into two categories: object-level methods and pixel-level/part-level methods.
At the coarsest level, models have been proposed to offer object-level explanations for brain tumor classification tasks, such as the class activation mapping method GradCAM [24,25], which highlights the entire object as the explanation behind the tumor predictions. The authors in [25] proposed a pre-trained ResNet-50 CNN architecture to classify three posterior fossa tumors and explained the classification decisions using GradCAM. The heatmap generated by the GradCAM technique identifies the area of emphasis and helps visualize where the classification model looks for individual predictions.
At a finer level, a few interpretable techniques have been applied to explain brain tumor classification results with pixel-level/part-level explanations, such as the pixel-level interpretable algorithms SHAP and Guided Backpropagation (GBP) [24], and a part-level interpretable model called LIME. The authors in [27] explained the tumor predictions made by a CNN model with the SHAP and LIME methods. The SHAP algorithm explains an individual prediction by computing the contribution of each pixel of the predicted image to the prediction using Shapley values, identifying the main pixels that affect the output of the model [28]. The LIME algorithm is a counterfactual explanation method that approximates the classification behavior of a complex neural network using a simpler, more understandable model without exploring the model itself [29]. In that study, the authors segmented the input image into superpixels and applied small perturbations around each superpixel to figure out the contribution/importance of each superpixel to the prediction result. Another study, conducted by Pereira et al. [24], utilized GradCAM and GBP maps to provide insights into the regions that support the prediction, in order to perform quality assessment of tumor grade prediction between HGG and LGG. GBP is a gradient-based visualization method that can show which pixels in the input image are most informative for the correct classification.
The above methods identify the most important pixels or objects of an image as the explanation for the prediction outcomes. To some extent, they verify the validity of the classification models. Nevertheless, it is worth stressing that knowing the most important pixels or objects of an image that determined a specific prediction does not always amount to a good-quality explanation.
Ideally, networks should be able to explain the reasoning process behind each individual decision, and this process would ideally be similar to that used by a radiologist, who looks at specific features of the MR image relevant to the task. For example, when a doctor classifies a tumor as HGG, this decision typically relies mainly on high-level class-representative features or properties, such as the tumor's irregularity, the necrotic area, or the enhancing ring [30].
The objective of this study was to build an interpretable multi-part attention [31] network (IMPA-Net) for brain tumor classification, to unbox the model and the reasoning process of individual predictions with understandable MR imaging features. The proposed IMPA-Net, motivated by [32], provides both global and local explanations for brain tumor classification on MRI images.
Figure 1 gives a more detailed illustration of the connections and distinctions between the two explanations. The global explanation is represented by a group of feature patterns that the model learns and uses for the classification. The quality of the feature patterns can be used to evaluate the ability and reliability of the model on the classification task. The local explanation interprets the reasoning process of an individual prediction by comparing the prototypical parts of the image with feature patterns. It can be used to evaluate the trustworthiness of individual predictions.
The main contribution of this paper is that it addresses the black-box problems of CNN classification models for glioma diagnosis by developing a model with the following characteristics:
- (i)
The first multi-part interpretable model that can provide both global and local explanations for brain tumor classification, enabling better human–machine collaboration for decision aid.
- (ii)
It presents the reasoning process of individual predictions to show how the model arrives at the decision making in this context, allowing health workers to evaluate the reliability of the prediction outcomes.
- (iii)
It allows the prediction results to be interpreted in a clinical context.
- (iv)
It highlights the most relevant information for predictions based on medical disease-related features that can be understood and interpreted by clinicians and patients.
The remainder of the paper is structured as follows. Section 2 gives a detailed introduction to the dataset, the proposed interpretable multi-part attention network, and the experimental setup. Results are given in Section 3. Section 4 evaluates the performance of the proposed method on both aspects of its classification and explanation. Section 5 discusses the key findings of this study. Section 6 concludes the proposed work and discusses future research directions.
2. Materials and Methods
The overall workflow of the development and evaluation of the proposed methodology is shown in
Figure 2. Input brain MRI images are first pre-processed by resizing, normalization, and cropping, and then three augmentation methods (rotation, shearing, and skewing) are performed to produce the training dataset. The proposed methodology classifies the input image by comparing its prototypical patches with pre-learned feature patterns of the classes HGG and LGG; in this stage, the feature patterns of both classes are optimized and produced. The quality of the feature patterns is evaluated in the next step in terms of their interpretability, class representability, and correctness, and poor-quality feature patterns are then excluded from the local explanation process. In the next stage, local explanations of individual predictions are given to illustrate how the model arrives at its final decisions, and each case is evaluated based on whether it satisfies two basic conditions identified for reliability assessment. Finally, the proposed model is evaluated on both aspects of its performance (classification and explanation), including classifier performance, global explanation evaluation, local explanation evaluation (correctness and confidence), and user evaluation.
2.1. Data and Image Processing
We trained and evaluated our network on data from the BraTS 2017 database [33,34,35]. The dataset contains 285 routinely acquired 3T multimodal clinical MRI scans from multiple institutions, comprising 210 patients with pathologically confirmed HGG and 75 patients with LGG. All images from the dataset were pre-processed by co-registration to the same anatomical template, interpolation to the same resolution (1 mm³), and skull stripping [33].
Slices that contain gliomas were extracted from each patient's MRI scan. Since the enhancing ring in post-contrast-enhanced T1-weighted (ceT1w) MR images is an important discriminating feature for accurate tumor diagnosis between HGG and LGG [6], only ceT1w MR images were considered in our experiments. The dataset was then partitioned into a training dataset (70%) and a testing dataset (30%). A push dataset of 60 images was randomly selected from the training dataset (30 images for each class).
All images were normalized by Z-score normalization and converted to PNG format, and then the background pixels were cropped to focus feature learning on the brain areas instead of the whole image. Moreover, the images were resized to 224 × 224 to fit the model’s training configurations.
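The pre-processing steps above can be sketched as follows. This is a minimal NumPy sketch, not the authors' exact pipeline: the function names, the use of non-zero pixels as the brain mask, and the nearest-neighbor resize are all illustrative assumptions.

```python
import numpy as np

def zscore_normalize(img: np.ndarray) -> np.ndarray:
    """Z-score normalization over the non-zero (brain) pixels."""
    mask = img > 0
    mu, sigma = img[mask].mean(), img[mask].std()
    out = np.zeros_like(img, dtype=np.float64)
    out[mask] = (img[mask] - mu) / (sigma + 1e-8)
    return out

def crop_background(img: np.ndarray) -> np.ndarray:
    """Crop rows/columns that contain only background (zero) pixels."""
    rows = np.any(img != 0, axis=1)
    cols = np.any(img != 0, axis=0)
    r0, r1 = np.where(rows)[0][[0, -1]]
    c0, c1 = np.where(cols)[0][[0, -1]]
    return img[r0:r1 + 1, c0:c1 + 1]

def resize_nearest(img: np.ndarray, size: int = 224) -> np.ndarray:
    """Nearest-neighbor resize to size x size (to fit the training configuration)."""
    h, w = img.shape
    ri = np.arange(size) * h // size
    ci = np.arange(size) * w // size
    return img[np.ix_(ri, ci)]
```

In practice a library resampler (e.g., from an imaging toolkit) would replace `resize_nearest`; the sketch only makes the order of operations concrete.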
2.2. Data Augmentation
To increase the size and variability of the training dataset, data augmentation was performed: rotating twice in the axial imaging plane by a random angle between 20° left and 20° right, shearing twice in the transverse direction by a random amount between 10° left and 10° right, and skewing twice by tilting the images left or right by a random amount (magnitude = 0.2). In this way, the training dataset was augmented six-fold, resulting in 6228 images (3546 HGG, 2682 LGG).
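The six-fold augmentation can be sketched with plain NumPy affine warps. This is an assumed reimplementation, not the authors' code: the nearest-neighbor resampling and the treatment of skew as a shear of the stated magnitude are simplifying assumptions.

```python
import numpy as np

def affine_warp(img, mat):
    """Apply a 2x2 inverse affine map about the image center (nearest neighbor)."""
    h, w = img.shape
    yy, xx = np.mgrid[0:h, 0:w]
    cy, cx = (h - 1) / 2, (w - 1) / 2
    src = mat @ np.stack([(yy - cy).ravel(), (xx - cx).ravel()])
    sy = np.clip(np.rint(src[0] + cy), 0, h - 1).astype(int)
    sx = np.clip(np.rint(src[1] + cx), 0, w - 1).astype(int)
    return img[sy, sx].reshape(h, w)

def rotate(img, deg):
    t = np.deg2rad(deg)
    return affine_warp(img, np.array([[np.cos(t), -np.sin(t)],
                                      [np.sin(t),  np.cos(t)]]))

def shear(img, deg):
    return affine_warp(img, np.array([[1.0, 0.0],
                                      [np.tan(np.deg2rad(deg)), 1.0]]))

def augment(img, rng):
    """Six augmented copies: 2 rotations, 2 shears, 2 skews (skew modeled as shear)."""
    out = [rotate(img, rng.uniform(-20, 20)) for _ in range(2)]
    out += [shear(img, rng.uniform(-10, 10)) for _ in range(2)]
    out += [shear(img, rng.choice([-1, 1]) * 0.2 * 45) for _ in range(2)]
    return out
```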
2.3. Interpretable Convolutional Neural Network
Figure 3 gives an overview of the proposed IMPA-Net, which consists of a feature extractor, multi-part attention (MPA), and a similarity-based classifier. Images are first propagated through convolutional layers for feature extraction, with a structure selected from VGG16. We chose VGG16 as the feature extractor because it combines simplicity, ease of implementation, and fine-tuning capability with adequate feature extraction effectiveness and generalization ability; a pre-trained VGG16 model is well suited to transfer learning or fine-tuning as a feature extractor for brain tumor classification tasks [12]. The non-linear activation function ReLU is used for all convolutional layers. These convolutional layers are followed by a multi-part attention module that computes similarities between the CNN outputs and the feature patterns pre-learned by the model. In particular, our network tries to find evidence for an image (such as the pre-processed HGG image in Figure 3) to be of class HGG by comparing its prototypical patches with the learned feature patterns of classes HGG and LGG, as illustrated in the similarity correlation units. This comparison produces a map of similarity scores for each feature pattern, which is upsampled and superimposed on the input image to show which part of the input image is activated by each feature pattern. The activation maps are then propagated into a max-pooling layer, producing a single similarity score for each comparison. Finally, the model classifies the input image based on the top 10 similarity scores: the output logits are the weighted sums of the top-10 similarity scores generated by the multi-part attention module.
2.3.1. Feature Extractor
The architecture consists of a regular convolutional neural network for feature extraction with a structure selected from VGG16 (kernel size 3 × 3), followed by two additional 1 × 1 convolutional layers. All of these convolutional layers use ReLU as the non-linear activation function.
For a given pre-processed input image x (such as the HGG sample image in Figure 3), the convolutional layers f extract useful features from x to use for prediction; their output z = f(x) has spatial dimension H × W × D, where D is the number of output channels of the last convolutional layer.
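A minimal PyTorch sketch of this extractor is given below. The toy one-layer stand-in for the VGG16 trunk and the latent depth D = 128 are illustrative assumptions; in practice the pre-trained `torchvision` VGG16 convolutional layers would be used as the backbone.

```python
import torch
import torch.nn as nn

# Stand-in for the VGG16 convolutional trunk (in a real run, a pretrained
# vgg16().features from torchvision); it maps 3 x 224 x 224 -> 512 x 7 x 7.
backbone = nn.Sequential(
    nn.Conv2d(3, 512, kernel_size=3, stride=32, padding=1),
    nn.ReLU(),
)

# Two additional 1 x 1 convolutional layers reducing to D channels.
D = 128  # assumed latent depth
add_on = nn.Sequential(
    nn.Conv2d(512, D, kernel_size=1), nn.ReLU(),
    nn.Conv2d(D, D, kernel_size=1), nn.ReLU(),
)

def extract_features(x: torch.Tensor) -> torch.Tensor:
    """z = f(x): the H x W x D feature map consumed by multi-part attention."""
    return add_on(backbone(x))
```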
2.3.2. Multi-Part Attention
In our experiments, we allocated a pre-determined number of feature patterns p_j^c for each class, where c ∈ {HGG, LGG} represents the class identity of the feature pattern and j is the index of that feature pattern among all feature patterns of class c, so that for each class a fixed number of feature patterns is learned and produced by the model from a push dataset. This dataset consists of a pre-determined number of MRI images that are randomly selected from the training dataset. The shape of each feature pattern is H₁ × W₁ × D, where H₁ ≤ H and W₁ ≤ W. In our experiments, H₁ and W₁ are set to 1. The depth of each feature pattern is the same as that of the convolutional output z, but the height and width are smaller; consequently, each feature pattern is supposed to represent some representative activation pattern in a patch of the convolutional output z, which in turn corresponds to some prototypical image patch in the original training image.
In our network, every feature pattern can be considered a representative pattern of one image from the push dataset, and these feature patterns are supposed to direct attention to enough medical semantic content for recognizing a class [36]. As a schematic illustration of the multi-part attention for the HGG sample image in Figure 3, the first feature pattern corresponds to the necrotic tumor core of an HGG training image, the fourth feature pattern to the enhancing tumor margin of an HGG training image, and the ninth feature pattern to the edematous area of an HGG image.
The similarity correlation units in the multi-part attention module compute the L2 distance between the CNN outputs and the feature patterns, as shown in Equation (1). The similarity correlation unit of class c calculates the squared Euclidean distances between its feature pattern p_j^c and each patch z̃ generated from the convolutional output z, and then inverts the distances to similarity scores. Mathematically, the similarity correlation unit calculates the following:

d_j^c(z̃) = ‖z̃ − p_j^c‖²₂,  (1)

sim_j^c(z̃) = log( (d_j^c(z̃) + 1) / (d_j^c(z̃) + ε) ).  (2)

These similarity scores calculated by Equation (2) define an activation map, which retains the spatial relation of the convolutional output z. The activation map can be upsampled to the size of the input image to visualize the part of the input image that looks most similar to the feature pattern [36]. In Figure 3, the similarity score between the first feature pattern, an HGG necrotic tumor core, and the most activated patch of the input image is higher than the similarity score between the fourth feature pattern, an HGG enhancing tumor margin, and the most activated patch. The ninth feature pattern, an HGG edematous area, activated mostly on the edematous tissue of the HGG sample image. This shows that our model finds that the necrotic tumor core of the HGG sample image has a stronger presence than the enhancing tumor margin in the input image.
Equation (2) indicates that the similarity is monotonically decreasing with respect to the squared Euclidean distance; that is, the highest similarity score of a similarity correlation unit is obtained when z̃ is the patch closest to p_j^c. In activation maps, warmer values indicate higher similarity between the learned feature pattern and the parts of the input image activated by it, which are enclosed in the yellow rectangles on the superimposed source images. The activation maps produced by the similarity scores are then max-pooled to a single similarity score for each feature pattern. Hence, if the similarity score of a similarity correlation unit is high, it indicates that there is a patch in the input image that is very similar to the feature pattern of class c in the latent space, and that the activated patch contains a similar pattern to that represented by the feature pattern.
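A similarity correlation unit with 1 × 1 × D feature patterns can be sketched in NumPy as follows. The log-ratio inversion of distances and the constant ε follow the prototype-based formulation of [32], which we assume here; the value of ε is an illustrative choice.

```python
import numpy as np

EPS = 1e-4  # assumed small constant in the distance-to-similarity inversion

def similarity_map(z: np.ndarray, pattern: np.ndarray) -> np.ndarray:
    """Activation map of one 1x1xD feature pattern over an HxWxD output z.

    Computes the squared L2 distance from the pattern to every spatial
    patch of z, then inverts distances into similarity scores."""
    d = ((z - pattern.reshape(1, 1, -1)) ** 2).sum(axis=-1)  # H x W distances
    return np.log((d + 1.0) / (d + EPS))                     # H x W similarities

def similarity_score(z: np.ndarray, pattern: np.ndarray) -> float:
    """Global max pooling: one similarity score per feature pattern."""
    return float(similarity_map(z, pattern).max())
```

A patch identical to the pattern yields distance 0 and hence the maximal similarity log(1/ε), consistent with the monotonic decrease described above.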
2.3.3. Similarity-Based Classifier
Finally, in the classifier block, the top 10 ranking similarity scores are multiplied by the class-connection weight matrix w to produce the output logit of each class. The matrix w represents the relationship between the feature patterns and the class logits; higher class-connection values indicate higher representability of a feature pattern for its class.
2.4. Model Training
The training of the proposed model is divided into three stages: stochastic gradient descent (SGD) of layers before the classifier layer, projection and optimization of feature patterns, and optimization of class-connection weights.
2.4.1. Stochastic Gradient Descent (SGD) of Layers before the Classifier Layer
The architecture aims to learn meaningful and task-relevant features that can be used to distinguish between HGG and LGG, where the most important patches for the classification task are clustered (in Euclidean distance) around similar feature patterns of the 'correct' class and separated from feature patterns of a different class [36]. To learn these features, the iterative SGD algorithm is used to simultaneously optimize the parameters of the convolutional layers f in the feature extractor and the feature patterns p_j^c in the multi-part attention module via backpropagation. In this step, the weight matrix (class-connection values) w of the last layer in the classifier block is frozen.
Formally, let X = {x₁, …, xₙ} be a set of training images and Y = {y₁, …, yₙ} the set of corresponding labels. The optimization problem to be solved here is to minimize a loss function that incorporates the cross-entropy loss (CELoss), cluster loss (ClstLoss), and separation loss (SepLoss):

L = CELoss + λ₁ ClstLoss + λ₂ SepLoss,  (5)

where ClstLoss and SepLoss are

ClstLoss = (1/n) Σᵢ min_{j: p_j ∈ P_{y_i}} min_{z̃ ∈ patches(f(x_i))} ‖z̃ − p_j‖²₂,  (6)

SepLoss = −(1/n) Σᵢ min_{j: p_j ∉ P_{y_i}} min_{z̃ ∈ patches(f(x_i))} ‖z̃ − p_j‖²₂.  (7)

The CELoss penalizes misclassification during the training process, and the aim is to minimize CELoss to give better classifications. The ClstLoss is minimized to encourage the prototypical parts to cluster around the correct class (Equation (6)), whereas the SepLoss is minimized to separate the prototypical parts from the incorrect class (Equation (7)).
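The cluster and separation terms for a single training image can be sketched as follows. Storing the patterns as a (P, D) array with a per-pattern class label is an assumed layout for illustration.

```python
import numpy as np

def clst_sep_losses(z, patterns, pattern_class, y):
    """Cluster / separation loss terms for one image.

    z             : (H, W, D) convolutional output f(x)
    patterns      : (P, D) feature patterns (each 1x1xD)
    pattern_class : (P,) class label of each pattern
    y             : ground-truth class of the image"""
    patches = z.reshape(-1, z.shape[-1])                             # (H*W, D)
    d = ((patches[:, None, :] - patterns[None, :, :]) ** 2).sum(-1)  # (H*W, P)
    d_min = d.min(axis=0)                        # closest patch per pattern
    clst = d_min[pattern_class == y].min()       # pull toward nearest same-class pattern
    sep = -d_min[pattern_class != y].min()       # push away from other-class patterns
    return clst, sep
```

Minimizing `clst` pulls some patch of the image toward a pattern of its own class; minimizing `sep` (which is negative) pushes all patches away from patterns of the other class.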
2.4.2. Projection of Feature Patterns
To visualize which parts of the training images from the push dataset are used as feature patterns, the network projects every feature pattern p_j^c onto the closest patch of the convolutional output z, i.e., the patch that has the smallest distance from p_j^c, where the closest patch comes from an image of the same class c as that of p_j^c [32]. The reason is that the patch of a training image x that corresponds to p_j^c should be the one on which p_j^c activates most strongly; we can visualize the part of x on which p_j^c has the strongest activation by forwarding x through the trained network. Mathematically, for feature pattern p_j^c of class c, the network performs the following update:

p_j^c ← argmin_{z̃ ∈ patches(f(x_i)), y_i = c} ‖z̃ − p_j^c‖₂.
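The projection update can be sketched as follows (NumPy; the list-of-feature-maps bookkeeping for the push set is an assumed simplification):

```python
import numpy as np

def project_pattern(pattern, push_feature_maps):
    """Replace a feature pattern by the closest latent patch of its own class.

    pattern           : (D,) one 1x1xD feature pattern p_j^c
    push_feature_maps : list of (H, W, D) outputs f(x) for push images of class c"""
    best, best_d = pattern, np.inf
    for z in push_feature_maps:
        patches = z.reshape(-1, z.shape[-1])          # flatten spatial patches
        d = ((patches - pattern) ** 2).sum(axis=1)    # squared L2 per patch
        i = d.argmin()
        if d[i] < best_d:
            best, best_d = patches[i].copy(), d[i]
    return best
```

After projection, every pattern equals an actual latent patch of a push image, so it can be traced back to a visible image region.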
2.4.3. Optimization of Class-Connection Weights
In this stage, all parameters of the convolutional layers and the multi-part attention block are frozen, and a convex optimization of the class-connection weight matrix w of the last layer is performed. To rely only on positive connections between feature patterns and logits, the negative connections in w are set to 0, which reduces the model's reliance on a negative reasoning process of the form "this image is of class HGG because it is not of class LGG". Mathematically, we perform this step by minimizing the cross-entropy loss over w while constraining the class-connection weights to be non-negative.
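This last-layer step amounts to logistic-regression-style updates on w under a non-negativity constraint. The projected-gradient sketch below is an illustrative reimplementation; the step size and iteration count are assumptions.

```python
import numpy as np

def optimize_last_layer(S, Y, w, lr=0.1, iters=200):
    """Convex optimization of class-connection weights with w >= 0.

    S : (N, P) similarity scores per training image (top-k-masked)
    Y : (N,)   integer class labels
    w : (C, P) class-connection weight matrix (all other layers frozen)"""
    n, C = len(Y), w.shape[0]
    onehot = np.eye(C)[Y]
    for _ in range(iters):
        logits = S @ w.T
        e = np.exp(logits - logits.max(axis=1, keepdims=True))
        prob = e / e.sum(axis=1, keepdims=True)
        grad = (prob - onehot).T @ S / n       # cross-entropy gradient wrt w
        w = np.maximum(w - lr * grad, 0.0)     # project negative connections to 0
    return w
```

The `np.maximum(..., 0.0)` projection enforces the positive-reasoning constraint described above at every iteration.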
2.5. Experimental Setup
All experiments were conducted on a PC with an Intel Core i7-6700K 4.00 GHz processor running Ubuntu 18.04.6, with one NVIDIA GeForce RTX 2060, using Python 3.9.7 and PyTorch 1.10.1.
The parameters of the convolutional layers from the VGG16 model were pre-trained on ImageNet [37], and the parameters of the additional convolutional layers were initialized with the Kaiming uniform method [38]. The parameters of the two additional convolutional layers are first trained and optimized for 5 epochs while the pre-trained parameters and biases are fixed. In the following joint training stage, the parameters of all convolutional layers are optimized from epoch 6 onwards, the model performs feature pattern projection every 20 epochs (i.e., at epochs 20, 40, 60, 80, and 100), and the convex optimization of the last layer is performed after each feature pattern projection for 20 iterations.
Separate learning rates are used for the additional convolutional layers, the layers pre-trained on ImageNet, the feature pattern optimization, and the last-layer convex optimization. For VGG16, the number of channels in a similarity correlation unit is set to the number of output channels of the additional convolutional layers.
5. Discussion
This work proposed an interpretable multi-part attention network for brain tumor classification. In detail, the widely used VGG16 was built into a specific interpretable architecture to ensure good classification performance on the BraTS 2017 dataset. The model was evaluated from both the classification and the explainability perspectives. Results demonstrated that the model produced accurate tumor classification, with accuracy on par with some of the best-performing CNN models. Furthermore, the proposed framework is able to provide higher-quality explanations for HGG and LGG classification, including a global explanation and local explanations.
In detail, the global explanation is interpreted as the set of feature patterns the model learns and uses to classify HGG and LGG. The quality of the feature patterns, in terms of their validity and representativity, was evaluated by radiologists to see whether they constituted valid evidence for decision aids. Results demonstrated that the model learns class-representative features of both classes for the classification task: the HGG feature patterns show higher responses in the contrast-enhancing tumor, the necrotic tumor core, and the edematous areas, which agrees with the actual imaging characteristics of HGG, while the LGG feature patterns show higher responses on the homogeneous tumor cores and the non-enhancing tumor margins.
Another important advantage of the proposed model is the local explanation it presents for individual predictions. Background areas, such as the ventricles, were found to be activated by the ‘tumor core’ feature patterns of the LGG class. These background patches are not faithful features to the underlying lesion. Therefore, unboxing the reasoning process is necessary; it allows the clinicians and patients to screen out ‘unreliable’ correct predictions.
The local explanations of individual predictions were also evaluated by radiologists to see whether they are reliable and acceptable for decision-making support. This form of reliability evaluation and model tuning is not available in the development of "black box" networks or the interpretable models mentioned above. According to the findings, the developed solution provided positive outcomes regarding the brain tumor classification and explanation targeted in this study.
Considering the limitations of the present study, these can be divided into methodological limitations in the construction of the network and limitations in the contextualization of the results.
It is reasonable to suppose that network construction limitations contribute to the lower classification accuracy of the proposed interpretable model compared with the baseline model. This discrepancy could be attributed to the model's classification inference process, which is greatly influenced by the feature patterns obtained from the randomly generated push dataset. In future work, optimizing the selection of the push dataset may help to improve the classification accuracy of the model. It is also possible that the training data augmentation process could be optimized: even though we used widely adopted augmentation methods, some recent evidence suggests that including image orientations not found in the testing set does not improve the generalization ability of the model [39].
Regarding the interpretation of the results, we did not find other interpretable deep learning methods applied to brain tumor classification on the same dataset, so we cannot confirm the degree to which the 86% reliability obtained by the model would be considered acceptable by health workers. Further collaboration with medical practitioners is important for the practical assessment of our model. Considering possible future developments of our work, several extensions are clear. The data modalities could be extended to incorporate a greater variety of structural images, such as T1w, T2w, and FLAIR, as well as more targeted sequences, including amide proton transfer [40] and MR spectroscopy [41]. It is also important to consider whether findings on the BraTS 2017 dataset carry over to other datasets. For example, many clinical scanners continue to use lower field strengths; publicly available datasets such as MNIBITE [42] and the recent ReMIND [43] could be leveraged to test IMPA-Net with 1.5 T data.