Article

Chain Graph Explanation of Neural Network Based on Feature-Level Class Confusion

Department of Electrical and Computer Engineering, Sungkyunkwan University, Suwon 16419, Korea
* Author to whom correspondence should be addressed.
Appl. Sci. 2022, 12(3), 1523; https://doi.org/10.3390/app12031523
Submission received: 28 December 2021 / Revised: 26 January 2022 / Accepted: 28 January 2022 / Published: 30 January 2022
(This article belongs to the Special Issue Explainable Artificial Intelligence (XAI))

Abstract

Despite increasing interest in developing interpretable machine learning methods, most recent studies provide explanations only for single instances, require additional datasets, or are sensitive to hyperparameters. This paper proposes a confusion graph that reveals model weaknesses by constructing a confusion dictionary. Other methods focus on the performance variation caused by single-neuron suppression. Our method instead defines the role of each neuron from two different perspectives: ‘correction’ and ‘violation’. Furthermore, our method can identify classes that lie close to each other at the feature level, which can suggest improvements to the model. Finally, the proposed graph construction is model-agnostic and does not require additional data or tedious hyperparameter tuning. Experimental results show that omitting the channels guided by the proposed graph causes a large performance degradation, from 91% to 33%, even though the graph contains less than 1% of all neurons.

1. Introduction

Thanks to remarkable advances in Convolutional Neural Networks (CNNs) [1,2,3,4], real-life applications of deep learning have become a popular subject. The feasibility of artificial intelligence applications has also increased, including self-driving vehicles [5,6,7], medical imaging [8,9,10], finance [11], and recommendation generation [12]. However, deep neural network-based approaches have been treated as ‘black box’ functions because of their complex structure and stacked non-linearities. Despite their high accuracy, they cannot justify their reliability, and interaction with users is also difficult. This opaque nature can cause problems when the model makes wrong decisions or malfunctions. Methods have therefore been proposed to provide insight into the outputs of deep learning models. There are two main ways to describe a model. The first explains the prediction for an individual instance as a heatmap, and the second explains the behavior of the model.
Class Activation Mapping (CAM) and its variants [13,14,15] are widely used methods for explaining the prediction for a single instance. They generate a heatmap that highlights the input pixels with the most influence on the decision, using a linear combination of the target layer's feature maps. How the coefficients of this linear combination are defined is therefore important. Grad-CAM [14] uses the spatially averaged gradients of the class-c score with respect to the target layer's feature maps as the coefficients. Score-CAM [15] uses the change in the model's prediction for class c when the input is perturbed. However, a CAM-based heatmap has the same spatial size as the target layer's feature map, so the final heatmap becomes blurry once it is resized to the input resolution. Approaches based on a relevance score [16,17,18] have therefore been proposed to produce a heatmap in which each pixel represents its importance to the model output. These approaches redistribute the relevance score from the final layer of the model back to the input layer, so no resizing is required. However, they produce similar heatmaps for any class c. As a result, the explanations look alike even when the model's predictions differ, which makes the interpretation ambiguous.
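The linear combination behind CAM-style heatmaps can be made concrete with the minimal plain-Python sketch below. It uses Grad-CAM-style coefficients, where each coefficient is the spatial mean of the gradient of the class-c score with respect to one feature map. The feature maps, gradients, and helper names are hypothetical toy values. Real implementations work on framework tensors and obtain the gradients by backpropagation.

```python
def grad_cam_weights(grads):
    """Global-average-pool each gradient map into one coefficient per channel."""
    weights = []
    for g in grads:
        cells = [v for row in g for v in row]
        weights.append(sum(cells) / len(cells))
    return weights


def cam_heatmap(feature_maps, weights):
    """ReLU of the weighted linear combination of (channels x H x W) feature maps."""
    h, w = len(feature_maps[0]), len(feature_maps[0][0])
    heat = [[0.0] * w for _ in range(h)]
    for alpha, fmap in zip(weights, feature_maps):
        for r in range(h):
            for c in range(w):
                heat[r][c] += alpha * fmap[r][c]
    # Keep only positive evidence for the class.
    return [[max(0.0, v) for v in row] for row in heat]


# Two 2x2 feature maps and made-up gradients of some class-c score.
fmaps = [[[1.0, 0.0], [0.0, 1.0]], [[0.0, 2.0], [2.0, 0.0]]]
grads = [[[0.5, 0.5], [0.5, 0.5]], [[-0.25, -0.25], [-0.25, -0.25]]]
alphas = grad_cam_weights(grads)
print(alphas)                      # [0.5, -0.25]
print(cam_heatmap(fmaps, alphas))  # [[0.5, 0.0], [0.0, 0.5]]
```

The heatmap has the 2 × 2 resolution of the feature maps rather than that of the input. This is why CAM heatmaps must be upsampled, and why the result looks blurry.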
To explain the output of the model, the following should be addressed: ‘Why did it judge that way?’, ‘How did it work?’, and ‘What are the possible failures of the model?’. However, methods based on the single instance explanation can only answer the first question, and they are only valid when the explanation of the correct label is valid.
To explain model behavior, research on model approximation and on learned representations has been proposed. Local Interpretable Model-agnostic Explanations (LIME) [19] explains a model through a local linear approximation. However, a linear surrogate cannot adequately approximate complex neural network structures, which can lead to poor approximations of CNNs. Other approaches interpret the learned representation of a CNN. Testing with Concept Activation Vectors (TCAV) [20] and Automatic Concept-based Explanations (ACE) [21] attempt to convert the internal states of a neural network into human-friendly concepts based on concept vectors. The Prototypical Part Network (ProtoPNet) [22] tries to attain model transparency by training with prototypes that summarize each layer's representation with a set of instances. However, these methods cannot answer the last of the three questions presented above. Furthermore, prototypes or concepts used to interpret learned representations must first be defined, and additional partial datasets must be built for them.
Recently, some graph-based model explanation methods [23,24,25] have been used to disentangle each CNN filter response with probabilistic modeling and learn a graphical model to reveal the hierarchy of CNN activations. They have the advantage of decomposing neural networks from a probabilistic point of view, but they require many hyperparameters, and the results are sensitive to how the hyperparameters are set. Moreover, they do not satisfy the third prerequisite of the model explanation.
To solve this problem, we present an algorithm that draws a chain graph of a model, revealing the channel activations in each layer that play a role at decision boundaries. We also provide a confusion information dictionary that reveals class relationships in which the formed features differ relatively little. The proposed confusion graph is a chain graph of activation maps that are common to many confusions and play an important role in changing the model's decisions. Our method does not require any additional data or hyperparameter settings. It can also satisfy the three requirements for an explanation method posed as questions above. Our contributions are as follows.
  • Using unit channel suppression in two directions, called ‘correction’ and ‘violation’, we propose a confusion graph. It connects the channels in each layer of a CNN that are vulnerable at the decision boundary.
  • We propose a confusion information dictionary that can reveal the class relationship in which the difference in the formed features is relatively small. Furthermore, it is possible to confirm the effect of the change in the feature level on the class prediction based on the proposed dictionary.
  • The proposed method is model-agnostic, requires no additional data, and does not need any hyperparameters.
The rest of the paper is organized as follows. Section 2 presents previous work relevant to our research. Section 3 describes how the confusion graph is constructed on consecutive layers of deep convolutional neural networks. In Section 4, we show comprehensively, both quantitatively and qualitatively, how the proposed confusion graph affects the model. We also depict possible class confusions based on the intermediate products of the graph construction. Finally, Section 5 concludes the paper and discusses possible applications of the proposed method.

2. Related Works

The related works in this paper can be grouped into the following two categories, which are related to understanding a CNN’s behavior: understanding the learned representation of CNN and chain graph modeling.

2.1. Understanding the Learned Representation of CNN

TCAV [20] creates a concept vector that separates the activations of random samples from those of concept samples. It then uses the sensitivity of the model's decision to this concept vector to quantify which concepts affect the decision. However, collecting additional data for each concept is difficult. ACE [21] therefore automates this process without additional data by dividing images into segments. However, depending on the image segmentation method, it can produce noisy or hard-to-interpret concepts. Moreover, in both cases, the concept vector cannot be considered valid unless the linear classifier used to create it performs well. Furthermore, these methods are highly sensitive to hyperparameter settings. They can only be regarded as indicating the model's tendencies on data, not as a direct interpretation of the model.
Other approaches to understanding the model's behavior are based on its learned representation. For example, [26] empirically quantified how much the model's output depends on individual units of a CNN. Furthermore, [27] reports the accuracy drop caused by each channel through an ablation study on layer activations. It also shows how this drop changes when the input is perturbed or when additional layers, such as batch normalization or dropout, are included. In [28], the Shapley values of the neurons in a neural network were computed to determine the neurons critical for classification accuracy. However, these studies concentrate on the accuracy drop related to a single neuron in a CNN, so they do not describe the role of a single neuron in class confusion. This paper therefore proposes a model confusion explanatory graph. It also uses unit channel suppression but focuses on class confusion relationships and on the role of each channel in decision making.

2.2. Modeling the Consecutive Layers in Neural Networks

Research into modeling the consecutive layer representations of a CNN includes probabilistic modeling and interpretable learning for model whitening. The authors of [23] studied the relationship between successive activations in two steps. They first clustered the activations of each layer with a Gaussian Mixture Model (GMM), then calculated the transition probabilities between clusters in successive layers. In this way, they can trace an ambiguous final activation that covers a relatively large portion of the object back to the previous-layer activations it originates from. However, if the GMM clustering contains errors, the approximation creates connections that differ from the real ones. The algorithm is also difficult to reproduce without a thorough understanding of probabilistic modeling.
In [22,29], the model is also trained on prototypes during the training stage, and the basis for the final classification is obtained from them layer by layer. This approach can confirm which prototypes the final prediction is derived from. However, prototypes must be defined over the whole dataset, and they require prototype annotations for all the training data, which are expensive to prepare. Therefore, [30] used a semi-supervised approach with partially annotated datasets for prototype training. These methods can visualize the concept or prototype used as the basis of the model prediction. However, they cannot provide insight into the model itself, such as which parts of the model are weak at decision boundaries or what the model's possible failure cases are.
To address these difficulties, this paper applies unit channel suppression to each layer to extract the channels that operate at each of the model's two types of decision boundary. The proposed method does not need additional data. It can reveal the weak channels of the model itself while also collecting possible feature-level boundary relations.

3. Proposed Method

Given the properties of visual representations learned by CNNs [31], there is no doubt that each layer of a CNN takes part in image recognition. It has also been shown that the channels of each layer can affect the accuracy for a particular class [27]. However, how these channels affect the relationships between classes has not been studied in detail. In this section, we use unit channel suppression to measure the effect of each channel on the class boundary in image classification. We then create a confusion graph that connects, in layer order, the key channels that commonly cause confusion.

3.1. Extract Key Channels for Class Confusion Based on Unit Channel Suppression

To suppress a channel, we set it to zero, because we want to observe the information loss caused by the absence of that channel. Figure 1 shows how unit channel suppression works on the i-th layer of a CNN model.
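Unit channel suppression can be illustrated with a minimal plain-Python sketch. It assumes an activation stored as a hypothetical nested list of shape channels × H × W and returns a copy with one channel set to zero. This is an illustration only, not the authors' code.

```python
import copy


def suppress_channel(activation, c):
    """Return a copy of a (channels x H x W) activation with channel c zeroed.

    The input is left untouched so that the full feature can still be reused.
    """
    out = copy.deepcopy(activation)
    out[c] = [[0.0 for _ in row] for row in out[c]]
    return out


act = [[[1.0, 2.0]], [[3.0, 4.0]], [[5.0, 6.0]]]  # 3 channels of 1x2 maps
print(suppress_channel(act, 1))  # [[[1.0, 2.0]], [[0.0, 0.0]], [[5.0, 6.0]]]
```

In a framework such as PyTorch, the same effect is commonly obtained by registering a forward hook (`nn.Module.register_forward_hook`) on the target layer that zeroes `output[:, c]`. We do not claim that the original experiments were implemented this way.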
To measure the effect of an individual channel on classification confusion, we first define the types of decision change that can be observed when a channel is removed. There are three cases in which the predicted class changes. ‘Violation confusion’ occurs when the model predicts correctly with the full feature but incorrectly when the unit channel is deleted. ‘Correction confusion’ is the opposite case: the model predicts wrongly with the full feature, but the decision becomes correct when the unit channel is deleted. In the third case, the model predicts wrongly both with the full feature and after channel deletion, but the two predictions differ. Based on these observations, we collect the confusion information for the ‘violation’ and ‘correction’ cases. This singles out the channels to which the model is sensitive in terms of correct and incorrect predictions.
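The three kinds of decision change can be expressed as a small classification function. In this minimal sketch, the label and predictions are plain class indices, and the function name is hypothetical.

```python
def confusion_type(y_true, y_full, y_suppressed):
    """Classify the decision change caused by suppressing one channel.

    y_full: prediction with the full feature.
    y_suppressed: prediction with the unit channel zeroed.
    Returns None when the prediction does not change.
    """
    if y_full == y_suppressed:
        return None
    if y_full == y_true:
        return "violation"   # correct -> wrong
    if y_suppressed == y_true:
        return "correction"  # wrong -> correct
    return "other"           # wrong -> a different wrong class


print(confusion_type(3, 3, 7))  # violation
print(confusion_type(3, 5, 3))  # correction
print(confusion_type(3, 5, 7))  # other
```

Only the ‘violation’ and ‘correction’ outcomes are kept for graph construction. The ‘other’ case is discarded.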
Suppose that we have a fixed target model $f(\cdot)$, input $X$, and label $y$. We select the $k$ layers of $f(\cdot)$ in which we want to observe key channels. The set of their $k$ activations is defined as $A := \{A_1, \ldots, A_k\} = \{a_0, \ldots, a_{C-1}\}$, where $c_j$ is the number of channels in $A_j$ and $C = \sum_{j=1}^{k} c_j$ is the total number of channels in the $k$ activations. For example, if we choose 3 layers with 64, 128, and 256 channels, then $A$ consists of $a_0, \ldots, a_{447}$. Channels $a_0, \ldots, a_{63}$ belong to the layer with 64 channels, $a_{64}, \ldots, a_{191}$ compose the layer with 128 channels, and the remaining $a_{192}, \ldots, a_{447}$ belong to the final layer. Algorithm 1 shows the pseudo code of the confusion graph construction together with the confusion dictionary collection.
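Because the channels of all $k$ layers are numbered consecutively as $a_0, \ldots, a_{C-1}$, an implementation needs to map a global index back to its layer and its local channel. A minimal sketch, with hypothetical helper names and 0-based layer indices:

```python
from bisect import bisect_right
from itertools import accumulate


def channel_offsets(channels_per_layer):
    """Global index of the first channel of each layer."""
    return [0] + list(accumulate(channels_per_layer))[:-1]


def locate(global_idx, channels_per_layer):
    """Map a global channel index a_i to (layer index, local channel index)."""
    if not 0 <= global_idx < sum(channels_per_layer):
        raise IndexError(global_idx)
    offsets = channel_offsets(channels_per_layer)
    layer = bisect_right(offsets, global_idx) - 1
    return layer, global_idx - offsets[layer]


layers = [64, 128, 256]          # the example from the text
print(sum(layers))               # 448, i.e. a_0 .. a_447
print(channel_offsets(layers))   # [0, 64, 192]
print(locate(64, layers))        # (1, 0)
print(locate(447, layers))       # (2, 255)
```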
First, we collect the violation confusion information from each layer. The confusion information contains the three items below:
  • The channel, together with its layer, whose suppression causes the confusion.
  • The class prediction of the model before and after the confusion.
  • The images whose prediction changes while the unit channel is suppressed.
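The collection criterion can be tried out on a toy model. In the minimal sketch below, the "model" is a hypothetical linear classifier over C channel values standing in for a CNN, so suppressing channel i simply zeroes the i-th input. Each correctly classified sample has every channel suppressed in turn. When a suppression causes a violation, the sample id is stored under a "before_after" key for that channel. Storing sample ids instead of images is a simplification.

```python
def predict(weights, features, suppressed=None):
    """Toy classifier: argmax of a linear layer over C channel values.

    `suppressed` is the index of the channel zeroed before the linear layer
    (unit channel suppression); None means the full feature is used.
    """
    x = [0.0 if i == suppressed else v for i, v in enumerate(features)]
    scores = [sum(w * v for w, v in zip(row, x)) for row in weights]
    return max(range(len(scores)), key=scores.__getitem__)


def collect_violations(weights, samples, labels):
    """Per channel, map 'before_after' keys to the ids of samples that flipped."""
    num_channels = len(samples[0])
    dictionary = [dict() for _ in range(num_channels)]
    for sid, (x, y) in enumerate(zip(samples, labels)):
        y_full = predict(weights, x)
        if y_full != y:
            continue  # a violation requires a correct full-feature prediction
        for i in range(num_channels):
            y_sup = predict(weights, x, suppressed=i)
            if y_sup != y_full:
                # A separator keeps keys unambiguous ("1_23" vs "12_3").
                key = f"{y_full}_{y_sup}"
                dictionary[i].setdefault(key, []).append(sid)
    return dictionary


W = [[1.0, 1.0, 0.0],   # class 0 relies on channels 0 and 1
     [0.0, 0.0, 1.5]]   # class 1 relies on channel 2
samples = [[1.0, 1.0, 1.2], [0.1, 0.1, 2.0], [2.0, 2.0, 0.0]]
labels = [0, 1, 1]      # the third sample is misclassified and is skipped
d = collect_violations(W, samples, labels)
print(d)  # [{'0_1': [0]}, {'0_1': [0]}, {'1_0': [1]}]
```

For correction confusions, the check on `y_full` is inverted: keep samples with `y_full != y` and record the flips where `y_sup == y`.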
Algorithm 1: Pseudo Code for Proposed Confusion Graph Explanation Construction (Violation)
Input: Fixed model f(·), input X := {X_1, …, X_L}, labels y, model predictions ȳ, channels of the model A = {a_0, …, a_{C−1}}, neighbor matrices N := {N_1, …, N_{k−1}};
Output:
    Dict_violation; Violation Confusion Information Dictionary
    G_violation; Violation Confusion Graph
Step (1) Confusion Relation Collection
    Dict_violation[i] ← {} for i in range(C)
    For X_l in X:
        For (i, a_i) in enumerate(A):
            ỹ ← argmax f(X_l | a_i = 0)
            If ȳ_l ≠ ỹ and ȳ_l = y_l:    # collection criterion (violation)
                                         # for correction: ȳ_l ≠ ỹ and ỹ = y_l
                key ← str(ȳ_l) + "_" + str(ỹ)
                If key not in Dict_violation[i]:
                    Dict_violation[i][key] ← [X_l]
                else:
                    Dict_violation[i][key].append(X_l)

Step (2) Filter common confusion relations
    Split Dict_violation by layer: Dict_violation = {D_1, D_2, …, D_k}
    assert len(D_j) = c_j for j = 1, …, k
    Common_Confusion ← ⋂_{j=1}^{k} (all confusion keys in D_j)

Step (3) Neighboring between A_i and A_{i+1}
    For i in 1, …, k−1:
        assert N_i ∈ R^{c_i × c_{i+1}}, initialized to zeros
        For key in Common_Confusion:
            a ← list of channels in D_i that contain key
            b ← list of channels in D_{i+1} that contain key
            For (a_m, b_n) in a × b:
                N_i[a_m, b_n] += 1

Step (4) Edge extraction from the neighbor matrices
    G_violation[A_i] ← [] for A_i in A
    For N_i in N:
        G_violation[A_i].append([row coordinates of N_i where the value is max(N_i)])
        G_violation[A_{i+1}].append([column coordinates of N_i where the value is max(N_i)])

Return Dict_violation, G_violation
Algorithm 2: Pseudo Code for Total Confusion Graph Construction
Input:
    G_violation; Violation Confusion Graph from Algorithm 1
    G_correction; Correction Confusion Graph from Algorithm 1
Output:
    G_total; Total Confusion Graph
Init:
    G_total ← {}

Process:
    For A_i in A:
        G_total[A_i] ← G_violation[A_i] ∪ G_correction[A_i]

Return G_total
After collecting this information for all $k$ layers, we extract the violation confusion relations common to the $k$ layers, which are the most common feature-level confusions in the target model $f(\cdot)$. We then extract from each layer the channels that induce these confusions. The selected channels are key channels because they cause the common confusions in each layer. Algorithm 2 shows the pseudo code of the total confusion graph construction based on the violation and correction confusion graphs.
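Step (2), filtering the confusions common to all $k$ layers, can be sketched as follows. The sketch assumes a per-channel dictionary like the one collected in Step (1), laid out in global channel order with `channels_per_layer` giving $c_1, \ldots, c_k$. The helper name and toy data are hypothetical.

```python
def common_confusions(dictionary, channels_per_layer):
    """Keys that appear in at least one channel of every selected layer."""
    per_layer_keys = []
    start = 0
    for c in channels_per_layer:
        keys = set()
        for channel_dict in dictionary[start:start + c]:  # D_j: this layer
            keys.update(channel_dict)                     # union of its keys
        per_layer_keys.append(keys)
        start += c
    return set.intersection(*per_layer_keys)


# Toy dictionary for two layers with 2 and 3 channels.
toy = [{"0_1": [5]}, {"2_0": [7]},
       {"0_1": [5, 9]}, {}, {"2_0": [7], "1_2": [3]}]
print(sorted(common_confusions(toy, [2, 3])))  # ['0_1', '2_0']
```

The key `"1_2"` is dropped because it occurs only in the second layer.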

3.2. Graph Construction Based on Key Channel Neighboring

Using the key channels acquired in the $k$ layers, we perform neighboring: we construct a neighbor matrix between each pair of lower and upper layers to mine channels with the same role. Thus, we need $k-1$ neighbor matrices to connect the $k$ layers consecutively.
Suppose that the zero-initialized neighbor matrix between $A_i$ with $m$ channels and $A_{i+1}$ with $n$ channels is $N \in \mathbb{R}^{m \times n}$. A zero matrix means that there are no relations between the two layers. Suppose that a list of channels $a$ in $A_i$ and a list of channels $b$ in $A_{i+1}$ share the same confusion relation. We form the Cartesian product $a \times b$ and increase by 1 every element of $N$ whose coordinate is in this product, indicating that the channels forming the coordinate play a common role. Thus, if we find $c$ common confusions, the maximum value an element of the neighbor matrix can take is $c$. An element equal to $c$ means that the two channels produce all of the common confusions, which indicates that they play similar roles in class confusion. More generally, a larger weight in the neighbor matrix indicates that the two connected nodes play more similar roles in class confusion. Therefore, when we extract the violation confusion graph, we keep the edges with the highest weight in the neighbor matrix.
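Steps (3) and (4) can be sketched in plain Python. The sketch takes the per-channel dictionaries of two adjacent layers and the set of common confusion keys, builds the neighbor matrix through the Cartesian product, and keeps the coordinates with the maximum weight. Returning no edges when the maximum is zero is our own assumption. The paper does not say how that case is handled.

```python
def neighbor_matrix(layer_a, layer_b, common_keys):
    """N[m][n] counts the common confusions shared by channel m of the lower
    layer and channel n of the upper layer.

    layer_a, layer_b: lists of per-channel dicts {confusion_key: samples}.
    """
    n_mat = [[0] * len(layer_b) for _ in range(len(layer_a))]
    for key in common_keys:
        a = [m for m, d in enumerate(layer_a) if key in d]
        b = [n for n, d in enumerate(layer_b) if key in d]
        for m in a:          # Cartesian product a x b
            for n in b:
                n_mat[m][n] += 1
    return n_mat


def max_edges(n_mat):
    """(row, col) coordinates that hold the maximum weight of the matrix."""
    best = max(max(row) for row in n_mat)
    if best == 0:
        return []            # assumption: no shared confusion, no edge
    return [(m, n) for m, row in enumerate(n_mat)
            for n, v in enumerate(row) if v == best]


lower = [{"0_1": [5]}, {"2_0": [7], "0_1": [1]}]
upper = [{"0_1": [5, 9]}, {}, {"2_0": [7], "0_1": [2]}]
N = neighbor_matrix(lower, upper, {"0_1", "2_0"})
print(N)             # [[1, 0, 1], [1, 0, 2]]
print(max_edges(N))  # [(1, 2)]
```

Here $c = 2$ common confusions exist. Lower channel 1 and upper channel 2 reach the maximum weight of 2 because both produce every common confusion.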
The correction confusion graph is constructed with exactly the same steps as the violation confusion graph. The only difference is that the collected confusion information concerns corrections instead of violations. The resulting violation and correction graphs each show how the most confusion-related channels are connected, in non-overlapping confusion directions. Therefore, by merging these two graphs, we obtain a chain graph of the network based on feature-level class confusions. Algorithm 1 is the pseudo code of our violation/correction confusion graph construction, and Algorithm 2 is the pseudo code for the total confusion graph.
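The merge in Algorithm 2 is a per-layer union of the selected channels. In the minimal sketch below, each graph is simplified to a mapping from a layer name to a flat list of channel indices. The layer names are hypothetical.

```python
def merge_graphs(g_violation, g_correction):
    """Per-layer union of the key channels of the two confusion graphs."""
    layers = sorted(set(g_violation) | set(g_correction))
    return {layer: sorted(set(g_violation.get(layer, []))
                          | set(g_correction.get(layer, [])))
            for layer in layers}


g_v = {"conv3": [4, 17], "conv4": [2]}
g_c = {"conv3": [17, 30], "fc": [9]}
print(merge_graphs(g_v, g_c))
# {'conv3': [4, 17, 30], 'conv4': [2], 'fc': [9]}
```

A channel selected by both graphs, such as channel 17 of `conv3` here, appears only once in the total graph.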

4. Experiments

For the following experiments, we use VGG16 [21] initialized with the ImageNet [22] pre-trained model and fine-tuned on the Animals with Attributes 2 (AwA2) dataset [23]. AwA2 has 37,322 images of 50 animal classes with pre-extracted feature representations for each image. To avoid confusion caused by imbalanced class sizes, we excluded the classes with fewer than 500 images. We then built the dataset by randomly sampling 500 images from each of the remaining 25 classes. This dataset was divided into training and evaluation data in an 8:2 ratio, and the graph analysis was performed on the evaluation data through unit channel suppression. The model was optimized with Stochastic Gradient Descent (SGD) at a learning rate of 0.01 and achieved 91.82% accuracy on the validation data. Our experiments were conducted on an NVIDIA Titan Xp (12 GB). We investigated the outputs of each convolution block and of the classification layers in VGG16.
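The dataset construction can be sketched as follows. The sketch drops classes with fewer than 500 images, samples 500 images from each remaining class, and splits 8:2. Splitting within each class is our assumption, since the text only states an 8:2 ratio. The class names and image counts in the usage example are hypothetical, not the real AwA2 statistics.

```python
import random


def balanced_split(images_by_class, per_class=500, train_ratio=0.8, seed=0):
    """Drop small classes, sample `per_class` images per class, split 8:2."""
    rng = random.Random(seed)
    train, evaluation = [], []
    for cls, images in sorted(images_by_class.items()):
        if len(images) < per_class:
            continue                       # excluded: fewer than 500 images
        picked = rng.sample(images, per_class)
        cut = int(per_class * train_ratio)  # 400 train / 100 evaluation
        train += [(img, cls) for img in picked[:cut]]
        evaluation += [(img, cls) for img in picked[cut:]]
    return train, evaluation


# Hypothetical counts: 25 classes with 600 images, 25 with only 300.
toy = {f"class{i:02d}": [f"img{i}_{j}" for j in range(600 if i < 25 else 300)]
       for i in range(50)}
train, evaluation = balanced_split(toy)
print(len(train), len(evaluation))  # 10000 2500
```

With 25 retained classes of 500 images each, the dataset has 12,500 images, which splits into 10,000 for training and 2,500 for evaluation.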
We also compare our method on the ImageNet validation data with [28], a state-of-the-art method for approximating the importance of neurons in a model. For a fair comparison, we followed the same settings as [28]. The released ImageNet validation set was divided into two halves of 25,000 images each, one for graph construction and the other as the test set. The proposed method requires no hyperparameters other than the batch size, which we set to 128, the same as in [28].

4.1. Confusion Graphs Based on Feature-Level Confusion and How They Affect the Model’s Decision

In this section, we describe the effect of our proposed graph on the model’s performance. Figure 2 shows the violation and correction confusion graphs of VGG16. As mentioned above, our target layers are the last convolution of each convolution block and the classification layers. The outputs of other methods are complicated to visualize. Our proposed graph, in contrast, selects relatively few channels, less than 5% of the total, so the key channels can be drawn in a single figure.
To test whether the proposed graph points to a weak point of the model, we measured the effect of each graph on model performance. The graph collects channels that behave identically for common confusion relationships. We compared the model's accuracy after deleting the channels that make up the confusion graph with its accuracy after random channel deletion. For fairness, random deletion turns off the same number of channels in each layer as the proposed confusion graph does in that layer. We repeat this procedure five times to account for randomness and report the average value.
Table 1 shows the accuracy of the full-featured model and the performance drop caused by deleting each type of graph. The selected layers contain 9152 channels in total, and the classification accuracy was 91.82% on the evaluation data.
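The matched random baseline draws, for every layer, as many random channels as the confusion graph uses in that layer, and averages the accuracy over five repetitions. A minimal sketch follows. `evaluate` is a hypothetical callback standing in for zeroing the channels and measuring accuracy. The layer names and sizes are illustrative, and the random draw does not exclude graph channels, which is our assumption.

```python
import random


def matched_random_channels(graph, channels_per_layer, rng):
    """For each layer, sample as many random channels as the graph uses there."""
    return {layer: rng.sample(range(channels_per_layer[layer]), len(chosen))
            for layer, chosen in graph.items()}


def average_random_accuracy(evaluate, graph, channels_per_layer,
                            repeats=5, seed=0):
    """Mean accuracy over `repeats` matched random deletions.

    `evaluate(deleted)` is a hypothetical callback that zeroes the channels in
    `deleted` (layer -> channel list) and returns evaluation accuracy.
    """
    rng = random.Random(seed)
    scores = [evaluate(matched_random_channels(graph, channels_per_layer, rng))
              for _ in range(repeats)]
    return sum(scores) / len(scores)


graph = {"block4": [3, 8], "fc1": [1]}   # hypothetical layer names
sizes = {"block4": 512, "fc1": 4096}


def fake_evaluate(deleted):
    # Stand-in for a real evaluation: accuracy falls 1 point per deleted channel.
    return 0.9 - 0.01 * sum(len(v) for v in deleted.values())


print(average_random_accuracy(fake_evaluate, graph, sizes))
```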
The first column shows the accuracy of the model according to each state. The full-featured state indicates a state in which no zero-out manipulation is applied to the model, and violation, correction, and total confusion are the results of measuring the accuracy of the model after deleting each type of graph from the full-featured state. The second column shows the model accuracy after zero-out by randomly selecting the same number of channels as the proposed graph for each layer. The last column indicates the total number of channels dropped by each graph state.
The results show the effect of deleting the channels in the confusion graphs, 54 in total and less than 1% of all channels. Accuracy drops by 58.7 percentage points, to roughly one-third of its original value. Deleting these 54 channels also yields 47.11% lower accuracy than randomly deleting the same number of channels, which indicates that the selected channels play a common role. The proposed graph therefore points to the weak channels of the model, which should be modified to improve its robustness in the future.
Table 2 compares the performance drop caused by the proposed graph with that of Neuron Shapley (NShap) [28]. According to [28], the pretrained Inception-V3 achieved 74% accuracy on the 25,000 ImageNet test images they used. Dropping 10 filters lowered the accuracy to 38%, and dropping 20 lowered it to as little as 8%. We also split the ImageNet validation set into two parts, one for graph construction and one for testing, so we report the performance of the pretrained Inception-V3 on our test part. This accuracy is 76%, shown in parentheses in the last column of Table 2. Rows 3 to 5 of Table 2 show the model performance after deleting the proposed graphs. The violation graph and the correction graph each delete fewer than 20 neurons, and deleting them lowers the accuracy to 0.75% and 2%, respectively. This suggests that the proposed method finds more critical neurons than [28]. Moreover, the total confusion graph uses more neurons than [28], but it completely destroys the model, reducing the accuracy to 0.1%. The time taken to obtain the graph is covered in Section 4.2.

4.2. Efficiency of Proposed Confusion Graph

We also checked the effect of each graph by randomly selecting channels within each type of confusion graph. We randomly turn off x% of the channels in each graph and leave the remaining (100 − x)% unchanged. For example, 0% selection on the violation graph means that none of its channels are zeroed out. A 90% selection zeroes out 14 × 0.9 = 12.6, rounded down to 12, of the graph's 14 channels and leaves the remaining two channels with their own values. Thus, the larger x is, the more channels are deleted. We also compare the results with random deletion of the same number of channels.
Figure 3 shows the results of random channel selection on the violation, correction, and total confusion graphs, respectively. In each figure, the blue line indicates the performance with x% selection of the graph, and the orange line indicates random deletion of the same number of channels. Each figure shows that even with fewer channels from the path, performance drops more than with random channel selection. Selecting only about 50% of the channels in each path can cause greater degradation than the random deletion reported in Table 1. This confirms that the channels found by the proposed method operate at the boundaries of the class features, which are the weak points of the model.
Efficiency covers not only how corrupting each graph affects model accuracy but also the time needed to create the graph. We therefore also analyzed the computational cost of the proposed algorithm. To measure the importance of each neuron for the model’s performance, NShap formulates the problem as a Multi-Armed Bandit (MAB) problem and applies early truncation to speed up a single iteration of its algorithm. However, it iterates repeatedly until the algorithm converges. For example, obtaining the Shapley values of Inception-V3 took 21 h and 3000 iterations on 100 parallel computing machines, which shows how much computing power convergence requires. In contrast, the proposed method has no convergence problem, so it runs on a single computing machine in a single iteration. Finding the neurons for our graph took 23 h 48 m on a single machine. This is about three hours longer than [28], but we used one computing machine instead of 100, which makes the method more practical for real-world applications.
Collecting the confusion dictionary for VGG16 took 8 h 51 m on a single machine, and this time could be halved by using multiple machines. Once the confusion dictionary is built, assembling it into the proposed confusion graph takes only 31 s without GPU support.
The total number of forward passes offers another way to compare computational complexity. According to [28], their algorithm can approximate the top 100 important neurons while skipping about 1500 of the 17,216 neurons in each iteration. They therefore evaluate about 15,000 neurons per batch and repeat this 3000 times, which requires about $4.5 \times 10^{7}$ forward passes. In the same setting, the proposed method needs only 17,216 forward passes per batch, far fewer than the state-of-the-art method.
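The x% selection can be sketched as follows. The sketch assumes integer percentages and rounds down as in the 14 × 0.9 = 12.6 → 12 example. The channel ids are stand-ins for a 14-channel graph.

```python
import random


def partial_attack(graph_channels, x_percent, rng):
    """Randomly pick floor(len * x / 100) of a graph's channels to zero out."""
    k = len(graph_channels) * x_percent // 100  # e.g. 14 * 90 // 100 == 12
    return rng.sample(graph_channels, k)


violation_graph = list(range(14))  # stand-in ids for a 14-channel graph
rng = random.Random(0)
for x in (0, 50, 90, 100):
    print(x, len(partial_attack(violation_graph, x, rng)))
# 0 0
# 50 7
# 90 12
# 100 14
```

Integer division avoids floating-point rounding surprises when computing the number of channels to delete.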
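The forward-pass counts can be checked with a few lines of arithmetic. The per-batch figures come from the text. The number of batches for the proposed method is our own assumption: the 25,000-image graph-construction half is processed once at batch size 128. The proposed method's total below is therefore an illustrative estimate, not a reported result.

```python
import math

# Figures quoted in the text for NShap [28].
nshap_neurons_per_iteration = 15_000   # "about 15,000 neurons per batch"
nshap_iterations = 3_000
nshap_passes = nshap_neurons_per_iteration * nshap_iterations

# Proposed method: one forward pass per suppressed neuron per batch.
neurons = 17_216
# Assumption (ours): 25,000 graph-construction images at batch size 128.
batches = math.ceil(25_000 / 128)
proposed_passes = neurons * batches

print(f"NShap:    {nshap_passes:.1e}")     # NShap:    4.5e+07
print(f"Proposed: {proposed_passes:.1e}")  # Proposed: 3.4e+06
```

Even when all batches are counted, this estimate stays about an order of magnitude below the NShap figure.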

4.3. Observation of Confusion Relations between Two Classes

Beyond the graph, the proposed method focuses on the role of each neuron, so the proposed confusion information dictionary can provide clues for understanding feature-level class confusions.
A confusion matrix built from image classification results cannot account for where each image lies at the feature level; it only reflects the output for that image. The confusion information dictionary, however, records which images caused class confusion when the feature was modified by unit channel suppression. This makes it possible to distinguish images that can change to another class at the feature level from those that cannot. Such confusion relations based on channel deletion can reveal a closer relationship between two specific classes.
We experimented with ox and cow as an example. Figure 4 shows the ox–cow confusion relationship as seen through a simple confusion matrix, and Figure 5 shows the same relationship drawn from the confusion information dictionary. In both figures, the blue frames represent samples from the cow class, and the red frames indicate samples from the ox class. Model decisions of ‘ox’ are placed on the left side of the figures, and decisions of ‘cow’ on the right side. In Figure 4, we can observe the samples labeled ‘ox’ but predicted as ‘cow’, and vice versa.
However, Figure 5 shows the classification of the samples in greater detail. Here, the red arrows indicate cases where the decision changed from cow to ox during unit channel suppression, and the blue arrows indicate the opposite. Dark brown or black animals predicted as cow show red arrows regardless of whether they have horns. Moreover, when a dark brown or black cow is incorrectly predicted as an ox, the decision does not change at the feature level. Similarly, for the blue arrows, animals of a relatively pale color between white and brown show ox–cow confusion at the feature level regardless of whether they have horns. Furthermore, a completely white ox that is incorrectly predicted as a cow does not change its decision at all during channel suppression.
In addition to explaining the cause of confusion for the two classes in Figure 5, we also examined the concepts extracted by ACE. Note that ACE is mainly used to reveal which concepts determined a specific class; it cannot identify the confusion between two specific classes. Therefore, once the proposed method has identified a confusion relationship, ACE can additionally be used to find the confusing patches for the confused classes.
Figure 6 presents concept patches from the ACE algorithm for distinguishing between ox and cow in our evaluation set. The collected patches are the subset of concepts with more than 90% accuracy in cow/ox classification and p-value < 0.05. Apart from the background-related patches, which are green or sky blue, the ox patches are darker than the cow patches. However, the concept patterns are too noisy, so ACE concepts are also difficult to translate into simple human language. We therefore investigated the difference between ox and cow in the pre-defined predicates provided by AwA2. The color-related predicates (e.g., brown, black, white) are shared by ox and cow, so color becomes a factor that can confuse the two classes. This indicates that class-discriminative characteristics can be inferred by observing the sample distribution, as shown in Figure 6, through the confusion dictionary of the proposed method. This can help when planning future training strategies, such as additional training with light-colored oxen and dark-colored cows, so that the two classes can be distinguished irrespective of color.
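Separating the images that flip between two classes from those that never do requires only a scan over the dictionary. The minimal sketch below assumes the per-channel "before_after" keys used earlier. The class names and sample ids are hypothetical.

```python
def pair_flips(dictionary, class_a, class_b):
    """Sample ids whose prediction flips between two classes under suppression.

    dictionary: per-channel dicts {"before_after": [sample ids]}.
    Returns (a_to_b, b_to_a) as sorted lists of unique sample ids.
    """
    key_ab, key_ba = f"{class_a}_{class_b}", f"{class_b}_{class_a}"
    a_to_b, b_to_a = set(), set()
    for per_channel in dictionary:
        a_to_b.update(per_channel.get(key_ab, []))
        b_to_a.update(per_channel.get(key_ba, []))
    return sorted(a_to_b), sorted(b_to_a)


confusion_dict = [{"cow_ox": [3, 8]}, {"ox_cow": [5]},
                  {"cow_ox": [8, 11], "ox_horse": [2]}]
print(pair_flips(confusion_dict, "cow", "ox"))  # ([3, 8, 11], [5])
```

Samples of the two classes that appear in neither list did not change between the two classes under any single-channel suppression.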

4.4. Possible Failures Based on Feature Relation from Confusion Information Dictionary

The confusion graphs were built from the most common confusions in the collected information, but the confusion information dictionary records, across the whole network, how the predictions change when part of each feature is lost.
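The toy sketch below illustrates how such a dictionary can be collected with unit channel suppression. It is a minimal illustration under stated assumptions, not the authors' implementation (their code is in the repository listed under Data Availability): the features are a single layer of pooled channel values, a hypothetical linear head stands in for the rest of the network, and a change is assumed to be a "correction" when suppression turns a wrong prediction into the true label and a "violation" when it turns a correct prediction into a wrong one; any other change is tagged "other".

```python
from collections import defaultdict
from typing import Dict, List, Sequence, Tuple

Record = Tuple[str, str, str, str]  # (label, prediction before, prediction after, role)


def predict(features: Sequence[float], weights: Sequence[Sequence[float]],
            classes: Sequence[str]) -> str:
    """Hypothetical linear head: score of class k = sum_c weights[k][c] * features[c]."""
    scores = [sum(w * f for w, f in zip(row, features)) for row in weights]
    return classes[max(range(len(scores)), key=scores.__getitem__)]


def suppress(features: Sequence[float], c: int) -> List[float]:
    """Unit channel suppression: copy of the features with channel c set to zero."""
    out = list(features)
    out[c] = 0.0
    return out


def build_confusion_dictionary(samples, weights, classes) -> Dict[int, List[Record]]:
    """Map each channel index to the decision changes its suppression causes."""
    dictionary: Dict[int, List[Record]] = defaultdict(list)
    for features, label in samples:
        before = predict(features, weights, classes)
        for c in range(len(features)):
            after = predict(suppress(features, c), weights, classes)
            if after == before:
                continue  # decision unchanged: nothing to record
            if before != label and after == label:
                role = "correction"  # assumed: wrong -> correct
            elif before == label:
                role = "violation"   # assumed: correct -> wrong
            else:
                role = "other"       # wrong -> a different wrong class
            dictionary[c].append((label, before, after, role))
    return dict(dictionary)


classes = ["ox", "cow"]
weights = [[1.0, 0.0, 0.0],  # ox score depends on channel 0 only
           [0.0, 1.0, 0.0]]  # cow score depends on channel 1 only
samples = [([0.6, 0.5, 0.1], "ox"), ([0.6, 0.5, 0.1], "cow")]
print(build_confusion_dictionary(samples, weights, classes))
# {0: [('ox', 'ox', 'cow', 'violation'), ('cow', 'ox', 'cow', 'correction')]}
```

In the toy example, channel 0 plays both roles: suppressing it breaks a correct ox prediction and repairs a wrong one, which is the kind of dual role the dictionary is meant to expose.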
For the i-th layer activation $A_i$, let $A_i^c$ denote its c-th channel and $A_i \setminus A_i^c$ the feature map with the c-th channel zeroed out. To investigate how important the suppressed channel is within its layer when unit channel suppression changes the model's decision, we measured two quantities.
First, only for suppressions that cause a confusion, i.e., when the predictions from $A_i$ and $A_i \setminus A_i^c$ differ, we measured the cosine similarity between the two features to quantify the change in angular terms:
$$\cos\left(A_i \setminus A_i^c,\, A_i\right) = \frac{\langle A_i \setminus A_i^c,\, A_i \rangle}{\lVert A_i \setminus A_i^c \rVert \, \lVert A_i \rVert}.$$
This value is close to 1 when the two features point in the same direction, 0 when they are orthogonal, and close to −1 when they point in opposite directions; a value close to 1 therefore means that $A_i^c$ has little importance from the angular perspective. Second, again only when suppressing $A_i^c$ induces a confusion, we measured the proportion of the layer's activation values carried by $A_i^c$:
$$\frac{\sum \lvert A_i^c \rvert}{\sum \lvert A_i \rvert}.$$
The larger this value ratio, the larger the share of the layer's values occupied by $A_i^c$ when the confusion occurs, and hence the more important the channel is in terms of magnitude. Table 3 reports both measures.
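Both measures can be sketched in plain Python. The sketch assumes a nonzero feature map stored as a list of channels, each a flat list of activations; the helper names are hypothetical. Because zeroing a channel only removes coordinates, $\langle A_i \setminus A_i^c, A_i\rangle = \lVert A_i \setminus A_i^c\rVert^2$, so for this particular comparison the cosine similarity simplifies to $\sqrt{1 - \lVert A_i^c\rVert^2 / \lVert A_i\rVert^2}$ and always lies between 0 and 1.

```python
import math
from typing import List

FeatureMap = List[List[float]]  # [channel][flattened spatial position]


def zero_channel(a: FeatureMap, c: int) -> FeatureMap:
    """A_i \\ A_i^c: copy of the feature map with channel c zeroed out."""
    return [[0.0] * len(ch) if k == c else list(ch) for k, ch in enumerate(a)]


def flatten(a: FeatureMap) -> List[float]:
    return [v for ch in a for v in ch]


def cosine_similarity(x: List[float], y: List[float]) -> float:
    dot = sum(p * q for p, q in zip(x, y))
    return dot / (math.sqrt(sum(p * p for p in x)) * math.sqrt(sum(q * q for q in y)))


def suppression_cosine(a: FeatureMap, c: int) -> float:
    """Cosine similarity between A_i \\ A_i^c and A_i."""
    return cosine_similarity(flatten(zero_channel(a, c)), flatten(a))


def value_ratio(a: FeatureMap, c: int) -> float:
    """sum|A_i^c| / sum|A_i|: share of the layer's absolute values in channel c."""
    return sum(abs(v) for v in a[c]) / sum(abs(v) for v in flatten(a))


A = [[3.0, 0.0], [0.0, 4.0]]    # two channels, two positions each
print(suppression_cosine(A, 1))  # 0.6, i.e. sqrt(1 - 16/25)
print(value_ratio(A, 1))         # 4/7, about 0.571
```

In a real layer with hundreds of channels, removing one small channel leaves the norm almost unchanged, which is consistent with the cosine similarities close to 1 reported in Table 3.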
According to Table 3, even for channels whose suppression changed the model's decision, the suppressed channel accounts for only a small share of the layer's values, and the features before and after zeroing out differ little: the value ratio is at most 0.04 in every layer, and the cosine similarity before and after zeroing out exceeds 0.98. The suppressed channels are therefore individually weak within their layers. In other words, the feature with a single channel removed remains very close to the full feature, and because such a small difference still causes a class confusion, the two classes must lie close together in feature space. Two classes in such a relationship can therefore be regarded as likely classification failure cases for each other. Reflecting this, Table 4 and Table 5 list the class relationships that are closest at the feature level.
Table 4 checks whether the confusion relations that occur frequently in the classification results are also frequent at the feature level. The first column lists the top 10 'label–prediction' relations with the highest number of confusions at the image level. The remaining columns list the correction and violation relations that rank in the top 10 by frequency across all layers and are shared with the image-level 'label–prediction' confusions. Entries in these columns that also appear in the first column show that relations confused at the image level are confused at the feature level as well; that is, the features of the two classes have smaller intra-class bias and inter-class variance than their relationships with other classes. These shared relations account for 65.04% of all feature-level confusions, so they should be a priority when improving model performance in the future.
Table 5 shows relations that do not appear at the image level but only at the feature level. They cannot be found by observing only the input–output relationship, yet they reveal how the class features are formed and should not be ignored, because they account for 34.96% of all feature-level confusions. Like the relations in Table 4, they have smaller intra-class bias and inter-class variance than relations that never appear as confusions, so they can be used to improve the robustness of the model.
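The split between Table 4 and Table 5 can be sketched with a frequency counter. Assumptions, all hypothetical: each confusion is a directed (label, prediction) pair, correction and violation records are pooled here although the tables report them separately, and pairs are filtered before being ranked. The two percentages are complementary by construction, which is why the reported shares of 65.04% and 34.96% sum to 100%.

```python
from collections import Counter
from typing import Iterable, List, Tuple

Pair = Tuple[str, str]  # (label, prediction), a directed confusion


def split_confusions(feature_pairs: Iterable[Pair], image_pairs: Iterable[Pair],
                     top_k: int = 10):
    """Split feature-level confusions into relations shared with the image
    level and relations that appear only at the feature level."""
    counts = Counter(feature_pairs)
    image_set = set(image_pairs)
    ranked = [pair for pair, _ in counts.most_common()]  # most frequent first
    shared = [p for p in ranked if p in image_set]
    unique = [p for p in ranked if p not in image_set]
    total = sum(counts.values())
    shared_pct = 100.0 * sum(counts[p] for p in shared) / total
    return shared[:top_k], unique[:top_k], shared_pct, 100.0 - shared_pct


feature_level = [("cow", "ox")] * 3 + [("dolphin", "seal")]
image_level = [("cow", "ox"), ("deer", "antelope")]
print(split_confusions(feature_level, image_level))
# ([('cow', 'ox')], [('dolphin', 'seal')], 75.0, 25.0)
```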

5. Discussion

The proposed method can identify a model's vulnerabilities in a reasonable time. Because these vulnerabilities are based on class confusion within the model, they can provide feedback for future model training. The method gives hints about which input images tend to be confused, as described in Section 4.3, and about which classes should be more separable in the feature space, as described in Section 4.4. If this feedback is used in training, the model's confusion should be relieved, which would naturally shrink the proposed confusion graph.
Another kind of feedback comes from the confusion dictionary itself: we can extract the channels related to a particular confusion relationship and zero them out to steer the model's decisions in a desired direction. For example, the VGG16 model used in our experiments predicts 104 ox and 93 cow when all features are intact. If all channels that turn ox predictions into cow according to the proposed confusion dictionary are deleted, the model predicts 0 ox and 296 cows. Conversely, if the channels that turn cow predictions into ox are deleted, the model predicts 3 cows and 94 oxen. This behavior is expected, as the proposed graph is built from role-based neurons. However, because each neuron was investigated independently, the effect of removing the connected graph as a whole, i.e., the cumulative information loss across layers, is less clear, so decision adjustment at the current stage tends to be unstable.
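Selecting the channels for one confusion direction and masking them can be sketched as below. This is a minimal illustration under assumptions: the dictionary layout ((layer, channel) mapped to (label, before, after) records), the layer names and the toy values are hypothetical. Zeroing is applied to a dictionary of per-layer values only to show the bookkeeping; in a real network the mask has to be applied inside the forward pass (for example with PyTorch forward hooks) so that later layers see the change.

```python
from typing import Dict, List, Set, Tuple

Channel = Tuple[str, int]     # (layer name, channel index)
Change = Tuple[str, str, str]  # (label, prediction before, prediction after)


def channels_for_confusion(dictionary: Dict[Channel, List[Change]],
                           source: str, target: str) -> Set[Channel]:
    """Channels whose suppression turned a `source` prediction into `target`."""
    return {ch for ch, changes in dictionary.items()
            if any(before == source and after == target
                   for _, before, after in changes)}


def zero_out(features: Dict[str, List[float]],
             channels: Set[Channel]) -> Dict[str, List[float]]:
    """Copy of per-layer values with the selected channels set to zero."""
    return {layer: [0.0 if (layer, k) in channels else v for k, v in enumerate(vals)]
            for layer, vals in features.items()}


toy_dictionary = {
    ("block5", 3): [("ox", "ox", "cow")],
    ("block5", 7): [("cow", "cow", "ox")],
    ("dense1", 2): [("ox", "ox", "cow"), ("deer", "deer", "antelope")],
}
ox_to_cow = channels_for_confusion(toy_dictionary, "ox", "cow")
print(sorted(ox_to_cow))  # [('block5', 3), ('dense1', 2)]
masked = zero_out({"block5": [1.0] * 8, "dense1": [0.5] * 4}, ox_to_cow)
```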
Based on these examples, the proposed method can be applied to various fields, such as medical imaging. Explainable AI methods proposed in the general computer vision domain might be directly applicable to the medical domain [32,33]. Several studies use model activations to indicate the location of lesions that drive the model's decision [34,35] or to estimate the importance of neurons for a CNN's output [36]. Because the proposed method analyzes a trained model regardless of the input domain, it is also applicable to models that use medical images. In that setting, the first kind of feedback described above, namely which inputs tend to be confused, is the most applicable: from the model's confusion tendencies, the model designer can ask physicians for additional data, and whether the same tendencies occur in the clinical setting can then be discussed.

6. Conclusions and Future Work

In this paper, we proposed a model-agnostic method to draw a graph of the key channels affecting decisions in a CNN. Unlike previous research, which concentrated on the effect of neurons on accuracy drop, we observe each neuron's role in class confusion. Although the proposed graph uses only a small fraction of the model's channels, removing it causes catastrophic performance degradation, so the channels that compose the graph can be seen as the weak points of the model. In addition, we confirmed that the channels whose suppression changes the model's decision occupy only a small portion of the values in their feature layer. Building on this observation, the confusion information dictionary makes it possible to analyze the cause of confusion between two specific classes and to distinguish relationships that appear in image-level confusion from those that appear only at the feature level. The proposed analysis can therefore be used to formulate strategies for improving the model in the future.
Furthermore, the proposed method has advantages in execution time and complexity over the state-of-the-art method. It runs on a single computing machine in a reasonable time and requires no tedious hyperparameter tuning, which makes it easier to use. However, because the method treats each layer separately during information collection and mines relations between consecutive layers post hoc, it cannot reflect consecutive information loss in neighboring layers or the effect of a feedback loop on the constructed graph. In future research, we will therefore study how to build the graph more quickly while reflecting the continuous information relation between layers, and how to apply more detailed training strategies to the weak-channel graph to enhance the model's robustness from the bias–variance perspective of each class feature cluster.

Author Contributions

Conceptualization, H.H. and J.S.; methodology, H.H. and J.S.; software, H.H.; validation, H.H. and J.S.; formal analysis, H.H.; investigation, H.H.; resources, E.P. and J.S.; data curation, H.H.; writing—original draft preparation, H.H.; writing—review and editing, E.P. and J.S.; visualization, H.H.; supervision, J.S. and E.P. All authors have read and agreed to the published version of the manuscript.

Funding

This work was supported in part by the National Research Foundation of Korea (NRF) Grant funded by the Korean Government Ministry of Science and ICT (MSIT) (NRF-2020R1F1A1065626), in part by the MSIT under the Information Technology Research Center (ITRC) support program (IITP-2021-2018-0-01798) supervised by the Institute for Information & Communications Technology Planning & Evaluation (IITP), and partly supported by the BK21 FOUR Project.

Institutional Review Board Statement

Not applicable.

Informed Consent Statement

Not applicable.

Data Availability Statement

The Animals with Attributes 2 dataset can be downloaded at the official website https://cvml.ist.ac.at/AwA2/ (accessed on 3 December 2021) and the code used in this paper can be found at the following GitHub repository: https://github.com/hailey94/ConfusionGraph (accessed on 21 December 2021).

Conflicts of Interest

The authors declare no conflict of interest.

References

  1. He, K.; Zhang, X.; Ren, S.; Sun, J. Identity Mappings in Deep Residual Networks. In Proceedings of the 14th European Conference on Computer Vision, Amsterdam, The Netherlands, 11–14 October 2016; pp. 630–645.
  2. Tan, M.; Le, Q. EfficientNet: Rethinking model scaling for convolutional neural networks. In Proceedings of the 36th International Conference on Machine Learning, Long Beach, CA, USA, 9–15 June 2019; pp. 6105–6114.
  3. Howard, A.G.; Zhu, M.; Chen, B.; Kalenichenko, D.; Wang, W.; Weyand, T.; Andreetto, M.; Adam, H. MobileNets: Efficient convolutional neural networks for mobile vision applications. arXiv 2017, arXiv:1704.04861.
  4. Zhang, C.; Benz, P.; Argaw, M.D.; Lee, S. ResNet or DenseNet? Introducing dense shortcuts to ResNet. In Proceedings of the IEEE/CVF Winter Conference on Applications of Computer Vision, Virtual, 5–9 January 2021; pp. 3550–3559.
  5. Ouyang, Z.; Niu, J.; Liu, Y.; Guizani, M. Deep CNN-based real-time traffic light detector for self-driving vehicles. IEEE Trans. Mob. Comput. 2020, 19, 300–313.
  6. Li, X.; Li, J.; Hu, X.; Yang, J. Line-CNN: End-to-end traffic line detection with line proposal unit. IEEE Trans. Intell. Transp. Syst. 2019, 21, 248–258.
  7. Yang, M.; Wang, S.; Bakita, J.; Vu, T.; Smith, F.D.; Anderson, J.H.; Frahm, J. Re-thinking CNN frameworks for time-sensitive autonomous-driving applications: Addressing an industrial challenge. In Proceedings of the IEEE Real-Time and Embedded Technology and Applications Symposium, Montreal, QC, Canada, 16–18 April 2019; pp. 305–317.
  8. Yadav, S.S.; Jadhav, S.M. Deep convolutional neural network based medical image classification for disease diagnosis. J. Big Data 2019, 6, 113.
  9. Meng, Y.; Wei, M.; Gao, D.; Zhao, Y.; Yang, X.; Huang, X.; Zheng, Y. CNN-GCN aggregation enabled boundary regression for biomedical image segmentation. In Proceedings of the International Conference on Medical Image Computing and Computer-Assisted Intervention, Lima, Peru, 4–8 October 2020; pp. 352–362.
  10. Tang, W.; Zou, D.; Yang, S.; Shi, J.; Dan, J.; Song, G. A two-stage approach for automatic liver segmentation with Faster R-CNN and DeepLab. Neural Comput. Appl. 2020, 32, 6769–6778.
  11. Lu, W.; Li, J.; Wang, J.; Qin, L. A CNN-BiLSTM-AM method for stock price prediction. Neural Comput. Appl. 2021, 33, 4741–4753.
  12. An, H.; Moon, N. Design of recommendation system for tourist spot using sentiment analysis based on CNN-LSTM. J. Ambient Intell. Humaniz. Comput. 2019, 1–11.
  13. Zhou, B.; Khosla, A.; Lapedriza, A.; Oliva, A.; Torralba, A. Learning deep features for discriminative localization. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, Las Vegas, NV, USA, 27–30 June 2016; pp. 2921–2929.
  14. Selvaraju, R.R.; Cogswell, M.; Das, A.; Vedantam, R.; Parikh, D.; Batra, D. Grad-CAM: Visual explanations from deep networks via gradient-based localization. In Proceedings of the IEEE International Conference on Computer Vision, Venice, Italy, 22–29 October 2017; pp. 618–626.
  15. Wang, H.; Wang, Z.; Du, M.; Yang, F.; Zhang, Z.; Ding, S.; Mardziel, P.; Hu, X. Score-CAM: Score-weighted visual explanations for convolutional neural networks. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition Workshops, Virtual, 14–19 June 2020; pp. 24–25.
  16. Montavon, G.; Lapuschkin, S.; Binder, A.; Samek, W.; Müller, K. Explaining nonlinear classification decisions with deep Taylor decomposition. Pattern Recognit. 2017, 65, 211–222.
  17. Binder, A.; Bach, S.; Montavon, G.; Müller, K.; Samek, W. Layer-wise relevance propagation for deep neural network architectures. In Proceedings of the Information Science and Applications, Ho Chi Minh City, Vietnam, 15–18 February 2016; pp. 913–922.
  18. Iwana, B.K.; Kuroki, R.; Uchida, S. Explaining convolutional neural networks using softmax gradient layer-wise relevance propagation. In Proceedings of the IEEE/CVF International Conference on Computer Vision Workshop, Seoul, Korea, 27 October–2 November 2019; pp. 4176–4185.
  19. Ribeiro, M.T.; Singh, S.; Guestrin, C. “Why should I trust you?” Explaining the predictions of any classifier. In Proceedings of the ACM SIGKDD International Conference on Knowledge Discovery and Data Mining, San Francisco, CA, USA, 13–17 August 2016; pp. 1135–1144.
  20. Kim, B.; Wattenberg, M.; Gilmer, J.; Cai, C.; Wexler, J.; Viegas, F.; Sayres, R. Interpretability beyond feature attribution: Quantitative testing with concept activation vectors. In Proceedings of the International Conference on Machine Learning, Stockholm, Sweden, 10–15 July 2018; pp. 2668–2677.
  21. Ghorbani, A.; Wexler, J.; Zou, J.; Kim, B. Towards Automatic Concept-based Explanations. In Proceedings of the Neural Information Processing Systems, Vancouver, BC, Canada, 8–14 December 2019; pp. 9277–9286.
  22. Chen, C.; Li, O.; Tao, C.; Barnett, A.J.; Su, J.; Rudin, C. This Looks Like That: Deep Learning for Interpretable Image Recognition. In Proceedings of the Neural Information Processing Systems, Vancouver, BC, Canada, 8–14 December 2019; pp. 8930–8941.
  23. Zhang, Q.; Wang, X.; Cao, R.; Wu, Y.N.; Shi, F.; Zhu, S. Extraction of an Explanatory Graph to Interpret a CNN. IEEE Trans. Pattern Anal. Mach. Intell. 2021, 43, 3863–3877.
  24. Shen, Y.; Cremers, D. A Chain Graph Interpretation of Real-World Neural Networks. arXiv 2020, arXiv:2006.16856.
  25. Konforti, Y.; Shpigler, A.; Lerner, B.; Bar-Hillel, A. Inference Graphs for CNN Interpretation. In Proceedings of the European Conference on Computer Vision, Virtual, 23–28 August 2020; pp. 69–84.
  26. Morcos, A.S.; Barrett, D.G.; Rabinowitz, N.C.; Botvinick, M. On the importance of single directions for generalization. arXiv 2018, arXiv:1803.06959.
  27. Zhou, B.; Sun, Y.; Bau, D.; Torralba, A. Revisiting the importance of individual units in CNNs via ablation. arXiv 2018, arXiv:1806.02891.
  28. Ghorbani, A.; Zou, J.Y. Neuron Shapley: Discovering the responsible neurons. In Proceedings of the Neural Information Processing Systems, Virtual, 6–12 December 2020.
  29. Koh, P.W.; Nguyen, T.; Tang, Y.S.; Mussmann, S.; Pierson, E.; Kim, B.; Liang, P. Concept bottleneck models. In Proceedings of the International Conference on Machine Learning, Virtual, 12–18 July 2020; pp. 5338–5348.
  30. Kazhdan, D.; Dimanov, B.; Jamnik, M.; Liò, P.; Weller, A. Now You See Me (CME): Concept-based Model Extraction. In Proceedings of the ACM International Conference on Information and Knowledge Management Workshop, Galway, Ireland, 19–23 October 2020.
  31. Zeiler, M.D.; Fergus, R. Visualizing and understanding convolutional networks. In Proceedings of the European Conference on Computer Vision, Zurich, Switzerland, 6–12 September 2014; pp. 818–833.
  32. Kamran, S.A.; Hossain, K.F.; Tavakkoli, A.; Zuckerbrod, S.L.; Baker, S.A. VTGAN: Semi-supervised Retinal Image Synthesis and Disease Prediction using Vision Transformers. In Proceedings of the IEEE/CVF International Conference on Computer Vision Workshop, Virtual, 17 October 2021.
  33. Asano, S.; Asaoka, R.; Murata, H.; Hashimoto, T.; Miki, A.; Mori, K.; Ikeda, Y.; Kanamoto, T.; Yamagami, J.; Inoue, K. Predicting the central 10 degrees visual field in glaucoma by applying a deep learning algorithm to optical coherence tomography images. Sci. Rep. 2021, 11, 2214.
  34. Pölsterl, S.; Aigner, C.; Wachinger, C. Scalable, Axiomatic Explanations of Deep Alzheimer’s Diagnosis from Heterogeneous Data. In Proceedings of the International Conference on Medical Image Computing and Computer-Assisted Intervention, Strasbourg, France, 27 September–1 October 2021; pp. 434–444.
  35. Tang, Y.; Tang, Y.; Zhu, Y.; Xiao, J.; Summers, R.M. A disentangled generative model for disease decomposition in chest X-rays via normal image synthesis. Med. Image Anal. 2021, 67, 101839.
  36. Quellec, G.; Al Hajj, H.; Lamard, M.; Conze, P.H.; Massin, P.; Cochener, B. ExplAIn: Explanatory artificial intelligence for diabetic retinopathy diagnosis. Med. Image Anal. 2021, 72, 102118.
Figure 1. Unit channel suppression on the c-th channel of the i-th layer. The activation map of the c-th channel is set to zero and passed through the subsequent weights to produce the (i + 1)-th activation.
Figure 2. (a) Violation confusion graph and (b) correction confusion graph of VGG16, built with the proposed algorithm. The final confusion graph is the union of the two graphs.
Figure 3. Model accuracy according to x% random selection on each type of graph. (a) violation confusion, (b) correction confusion and (c) total confusion graph.
Figure 4. The ox–cow confusion relationship in image-level classification confusion.
Figure 5. The ox–cow confusion relationship in feature-level confusion information dictionary.
Figure 6. Concept patches from the ACE algorithm. (a) Concept patches for ox that are distinguishable from cow. (b) Patches for cow that are distinguishable from ox.
Table 1. Accuracy (%) of VGG16 on the AwA2 dataset without any information loss (full feature) and with information loss after deleting each type of graph, together with the number of removed channels in each case. Random-off is the accuracy when the same number of channels is randomly turned off.
| Model | Accuracy (%) | Random-Off (%) | Number of Removed Channels |
|---|---|---|---|
| Full Feature | 91.82 | - | 0 |
| Violation | 78.03 | 89.67 | 14 |
| Correction | 51.04 | 80.36 | 44 |
| Total Confusion | 33.12 | 80.23 | 54 |
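The Random-Off baseline draws the same number of channels as a graph, uniformly at random. A minimal sketch, assuming a hypothetical channel pool (two 512-channel VGG16 convolution blocks) and a fixed seed for reproducibility; which layers the pool actually contains, and whether results are averaged over several draws, are not specified here. The same helper could select x% of a graph's channels, as in Figure 3, with k = round(x / 100 × graph size).

```python
import random
from typing import List, Sequence, Tuple

Channel = Tuple[str, int]  # (layer name, channel index)


def random_off(all_channels: Sequence[Channel], k: int, seed: int = 0) -> List[Channel]:
    """Pick k distinct channels uniformly at random (Random-Off baseline)."""
    rng = random.Random(seed)  # fixed seed for a reproducible baseline
    return rng.sample(list(all_channels), k)


# Hypothetical pool: VGG16's block4 and block5 convolution layers have 512 channels each.
layers = {"block4": 512, "block5": 512}
pool = [(name, c) for name, width in layers.items() for c in range(width)]
baseline = random_off(pool, k=54)  # same count as the total confusion graph in Table 1
```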
Table 2. Number of removed channels for Inception-V3 with NShap [28] and with the proposed graphs. Accuracy is measured on the ImageNet test split with the channels removed; the full-feature test-set accuracy is given in parentheses.
| Algorithm | Number of Removed Channels | Accuracy (Test Set Acc.) |
|---|---|---|
| NShap [28] | 10 | 38% (74%) |
| NShap [28] | 20 | 8% (74%) |
| Violation Graph | 16 | 0.756% (76%) |
| Correction Graph | 18 | 2.0% (76%) |
| Confusion Graph | 27 | 0.164% (76%) |
Table 3. Importance of the deleted channel in each layer when unit channel suppression changes the decision, measured at the final convolution layer of each convolution block and at the dense layers of VGG16.
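| VGG16 | Block1 | Block2 | Block3 | Block4 | Block5 | GAP | Dense | Dense1 | Dense2 |
|---|---|---|---|---|---|---|---|---|---|
| Cosine Similarity | 0.989 | 0.995 | 0.997 | 0.997 | 0.991 | 0.989 | 0.992 | 0.997 | 0.998 |
| Value Ratio | 0.018 | 0.009 | 0.005 | 0.003 | 0.01 | 0.009 | 0.04 | 0.02 | 0.002 |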
Table 4. Top 10 most frequent image-level and feature-level confusions. The feature-level columns show only confusions common to the image and feature levels.
| Top 10 Image-Level Confusions | Top 10 Correction Confusions | Top 10 Violation Confusions |
|---|---|---|
| deer–antelope | cow–ox | gorilla–chimpanzee |
| cow–ox | dolphin–whale | whale–dolphin |
| dolphin–whale | wolf–fox | deer–shepherd |
| gorilla–chimpanzee | collie–shepherd | dolphin–whale |
| ox–cow | rhinoceros–elephant | antelope–deer |
| seal–otter | antelope–deer | deer–antelope |
| antelope–deer | otter–seal | rhinoceros–otter |
| otter–seal | shepherd–collie | wolf–bobcat |
| collie–shepherd | chimpanzee–gorilla | ox–cow |
| chimpanzee–gorilla | seal–otter | bear–fox |
Table 5. Top 10 confusions that appear only in feature-level confusion and not among the image-level confusion relations.
| Top 10 Unique Corrections | Top 10 Unique Violations |
|---|---|
| dolphin–seal | bobcat–antelope |
| dalmatian–whale | deer–bobcat |
| chimpanzee–cow | giraffe–tiger |
| otter–rhinoceros | collie–cow |
| cat–shepherd | hamster–bobcat |
| otter–ox | cow–wolf |
| antelope–fox | tiger–wolf |
| rhinoceros–cow | ox–deer |
| chimpanzee–cat | giraffe–bobcat |
| fox–chimpanzee | dalmatian–cow |
Publisher’s Note: MDPI stays neutral with regard to jurisdictional claims in published maps and institutional affiliations.

Share and Cite

MDPI and ACS Style

Hwang, H.; Park, E.; Shin, J. Chain Graph Explanation of Neural Network Based on Feature-Level Class Confusion. Appl. Sci. 2022, 12, 1523. https://doi.org/10.3390/app12031523

AMA Style

Hwang H, Park E, Shin J. Chain Graph Explanation of Neural Network Based on Feature-Level Class Confusion. Applied Sciences. 2022; 12(3):1523. https://doi.org/10.3390/app12031523

Chicago/Turabian Style

Hwang, Hyekyoung, Eunbyung Park, and Jitae Shin. 2022. "Chain Graph Explanation of Neural Network Based on Feature-Level Class Confusion" Applied Sciences 12, no. 3: 1523. https://doi.org/10.3390/app12031523

Note that from the first issue of 2016, this journal uses article numbers instead of page numbers.
