**Advancement of Mathematical Methods in Feature Representation Learning for Artificial Intelligence, Data Mining and Robotics**

Editors

**Jianping Gou Weihua Ou Shaoning Zeng Lan Du**

MDPI • Basel • Beijing • Wuhan • Barcelona • Belgrade • Manchester • Tokyo • Cluj • Tianjin

*Editors* Jianping Gou Southwest University China

Weihua Ou Guizhou Normal University China

Shaoning Zeng University of Electronic Science and Technology of China China

Lan Du Monash University Australia

*Editorial Office* MDPI St. Alban-Anlage 66 4052 Basel, Switzerland

This is a reprint of articles from the Special Issue published online in the open access journal *Mathematics* (ISSN 2227-7390) (available at: https://www.mdpi.com/si/mathematics/Advancement_Mathematical_methods_Feature_Representation_Learning_Artificial_Intelligence_Data_Mining_Robotics).

For citation purposes, cite each article independently as indicated on the article page online and as indicated below:

LastName, A.A.; LastName, B.B.; LastName, C.C. Article Title. *Journal Name* **Year**, *Volume Number*, Page Range.

**ISBN 978-3-0365-7262-8 (Hbk) ISBN 978-3-0365-7263-5 (PDF)**

© 2023 by the authors. Articles in this book are Open Access and distributed under the Creative Commons Attribution (CC BY) license, which allows users to download, copy and build upon published articles, as long as the author and publisher are properly credited, which ensures maximum dissemination and a wider impact of our publications.

The book as a whole is distributed by MDPI under the terms and conditions of the Creative Commons license CC BY-NC-ND.

## **Contents**



### **Haibo Yu, Guojun Lu, Qianhua Cai and Xue Yun**

A KGE Based Knowledge Enhancing Method for Aspect-Level Sentiment Classification

Reprinted from: *Mathematics* **2022**, *10*, 3908, doi:10.3390/math10203908


## **About the Editors**

#### **Jianping Gou**

Jianping Gou (Senior Member, IEEE) received a Ph.D. degree in computer science from the University of Electronic Science and Technology of China, Chengdu, China, in 2012. He was previously a Post-Doctoral Research Fellow with the University of Sydney. He is currently a Professor in the College of Computer and Information Science, College of Software, Southwest University, Chongqing, China. His current research interests include pattern classification and machine learning. So far, he has published over 100 papers in international journals and conferences, such as IJCV, TNNLS, TII, TITS, T-CYB, and TKDD. He is an academic editor of *Scientific Programming*, an editorial board member of *Mathematics*, a senior member of CCF, and a senior member of CSIG.

#### **Weihua Ou**

Weihua Ou received a Ph.D. degree in Information and Communication Engineering from Huazhong University of Science and Technology (HUST), China. He is currently a full Professor at the School of Big Data and Computer Science, Guizhou Normal University, Guiyang, China. His current research interests include cross-modal retrieval, deep learning, image processing, and computer vision. He has published more than 70 papers in prominent journals and conferences, such as IEEE T-NNLS, IEEE T-MM, IEEE T-CSVT, PR, ICPR, and ICME. His publications have been cited more than 1800 times in Google Scholar, and his h-index is 23.

#### **Shaoning Zeng**

Shaoning Zeng received B.S. and M.S. degrees from Beihang University (BUAA), Beijing, China, in 2004 and 2007, respectively, and a Ph.D. degree in computer science from the Department of Computer and Information Science, Faculty of Science and Technology, University of Macau, in 2020. He is an Associate Professor at the Yangtze Delta Region Institute (Huzhou), University of Electronic Science and Technology of China. His research interests include computer vision, pattern recognition, machine learning, and deep learning for multimedia and image processing applications.

#### **Lan Du**

Dr Lan Du is a senior lecturer in Data Science and AI in the Faculty of IT, Monash University. His research lies at the intersection of machine/deep learning and natural language processing and their applications in different domains, such as public health, where he and his research team are developing cutting-edge NLP technologies for AI-enabled medical NLP. As a leading Australian researcher in topic modeling, he is best known for his work on learning and understanding the semantics of free-form text.

## *Editorial* **Preface to the Special Issue "Advancement of Mathematical Methods in Feature Representation Learning for Artificial Intelligence, Data Mining and Robotics"—Special Issue Book**

**Weihua Ou <sup>1</sup>, Jianping Gou <sup>2,\*</sup>, Shaoning Zeng <sup>3</sup> and Lan Du <sup>4</sup>**


Feature representation learning is a basic task that plays an important role in artificial intelligence, data mining and robotics. With the recent rapid development of deep learning, many advanced methods, such as auto-encoders, convolutional neural networks and generative adversarial networks, have been proposed and have achieved remarkable success in both academia and industry. However, many questions remain unsolved. What makes one representation better than another? What are appropriate objectives for learning good representations? How can the security of these algorithms be ensured and their behavior explained?

This Special Issue aims to highlight the latest results on mathematical methods in feature representation learning for artificial intelligence, data mining and robotics, covering several recently reported methods.

Representation learning is a basic problem in computer vision. For example, the authors of [1] comprehensively reviewed the development of vehicle re-identification and showed that representation learning plays a vital role in it. They further classified the feature representation approaches for vehicle re-identification into two groups: hand-crafted and deep learning-based feature representations. In [2], semantic intelligent detection of vehicle color under rainy conditions was studied, jointly deraining the image and recognizing the vehicle color. Specifically, the feature maps of the recovered clean image and those extracted from the input image are cascaded into a feature pyramid network (FPN) module to achieve joint semantic representation learning. Based on the YOLOX algorithm, works [3,4] proposed learning representation features for high-performance head counting and garbage quantity identification, respectively.

Based on the fact that low-level features contain information about small objects, while high-level features contain accurate information about large objects, the authors of [5] proposed an effective pedestrian detection approach that integrates the characteristics of different stages. To explore high-order information representation in vision tasks, the authors of [6] developed second-order spatial-temporal correlation filters for visual tracking, and the authors of [7] studied facial recognition via compact second-order image gradient orientations. In [8], the authors proposed deep learning-based cyber & physical feature fusion for anomaly detection in industrial control systems.

In [9], a discriminative multidimensional scaling model based on pairwise constraints was proposed for feature learning, considering both the topology of samples in the original space and the cluster structure in the new space. The authors of [10] proposed a deep large-margin rank loss for feature learning in multi-label image classification. Ref. [11] proposed a reinforcement learning-based representation approach for resource allocation in elastic optical networks, and Ref. [12] presented a deep reinforcement learning-based framework for gait adjustment for patients suffering from physical disabilities.

**Citation:** Ou, W.; Gou, J.; Zeng, S.; Du, L. Preface to the Special Issue "Advancement of Mathematical Methods in Feature Representation Learning for Artificial Intelligence, Data Mining and Robotics"—Special Issue Book. *Mathematics* **2023**, *11*, 940. https://doi.org/10.3390/math11040940

Received: 4 February 2023 Accepted: 7 February 2023 Published: 13 February 2023

**Copyright:** © 2023 by the authors. Licensee MDPI, Basel, Switzerland. This article is an open access article distributed under the terms and conditions of the Creative Commons Attribution (CC BY) license (https://creativecommons.org/licenses/by/4.0/).

Representation learning also plays an important role in low-level image processing tasks. The authors of [13] studied blind image deblurring and proposed and learned an innovative sparse channel prior. The authors of [14] proposed a joint deep recovery model that efficiently addresses motion blur and resolution reduction simultaneously; its multi-order attention mechanism comprehensively and hierarchically extracts multiple attention features and fuses them through drop-out gating. In [15], the authors studied image aesthetic quality assessment and proposed a method consisting of a representation learning step and a label propagation step. The authors of [16] developed a plug-and-play-based algorithm for mixed noise removal with a logarithm norm approximation model.

Since available source data are often collected from multiple related domains, multi-domain adaptation (MDA) has become increasingly popular. Although multiple source domains provide a significant amount of information, handling domain shifts becomes more challenging, especially when learning a common domain-invariant representation for all domains. In [17], to address the ambiguity of category boundaries, the authors employed Dempster–Shafer evidence theory (DST) to reduce this ambiguity and output reasonable decisions by combining adaptation outputs based on uncertainty. Inspired by generative adversarial networks (GANs), the authors of [18] proposed a novel adversarial domain adaptation method with an initial state fusion strategy followed by a domain similarity strategy based on information entropy. In [19], the authors adopted a domain adaptation strategy to address remaining useful life (RUL) prediction when sample data of equipment under complex operating conditions are insufficient. The authors of [20] proposed a geometric metric learning method for multi-output learning.

Sentiment classification is an important task in natural language processing. Traditional word-level vector representations assign the same representation to a word even when it expresses different sentiment polarities in different domains. In [21], the authors proposed a dual-word embedding model that considers syntactic information for cross-domain sentiment classification. The authors of [22] reported a graph convolutional network (GCN) for aspect-based sentiment analysis that simultaneously considers the dependencies between words and the types of these dependencies. The authors of [23] proposed a knowledge-enhanced dual-channel GCN for aspect-based sentiment analysis. In [24], the authors developed a triplet contrastive learning network that coordinates syntactic and semantic information for aspect-level sentiment classification. Works [25,26] show the effectiveness of knowledge-enhanced sentiment feature learning for aspect-level sentiment classification and hate speech detection. Ref. [27] studied embedding representation learning for uncertain temporal knowledge graphs, while Ref. [28] studied tensor affinity learning for hyperorder graph matching.

Some other representative works also show the importance of feature representation learning. For example, Ref. [29] studied the 3D reconstruction of self-rotating objects, and Ref. [30] presented a fusion verification method for cross-site scripting attacks. Ref. [31] proposed a novel feature transformation-based method that improves robustness against adversarial examples by transforming the features of data. Ref. [32] studied requirement analysis for the scheme design of complex mechanical products, while Ref. [33] studied the stability of switched systems with time-varying delays.

Briefly, this Special Issue received 65 submissions, 33 of which were published, including 32 research articles and 1 review article. The submissions covered topics ranging from low-level vision feature learning to high-level semantic representation learning, on texts, images and videos, and from single-domain to cross-domain settings. We believe that these works will effectively boost research on representation learning. We found the selection of papers for this Special Issue very inspiring, and we thank the editorial staff and reviewers for their efforts and assistance during the process.

**Author Contributions:** Conceptualization, W.O., J.G., S.Z. and L.D.; methodology, W.O., J.G., S.Z. and L.D.; software, W.O., J.G., S.Z. and L.D.; validation, J.G., S.Z. and L.D.; formal analysis, J.G., S.Z. and L.D.; investigation, J.G., S.Z. and L.D.; resources, J.G., S.Z. and L.D.; data curation, J.G., S.Z. and L.D.; writing—original draft preparation, J.G., S.Z. and L.D.; writing—review and editing, J.G., S.Z. and L.D.; visualization, J.G., S.Z. and L.D.; supervision, J.G., S.Z. and L.D.; project administration, J.G., S.Z. and L.D.; funding acquisition, J.G., S.Z. and L.D. All authors have read and agreed to the published version of the manuscript.

**Funding:** This research received no external funding.

**Conflicts of Interest:** The authors declare no conflict of interest.

#### **References**


**Disclaimer/Publisher's Note:** The statements, opinions and data contained in all publications are solely those of the individual author(s) and contributor(s) and not of MDPI and/or the editor(s). MDPI and/or the editor(s) disclaim responsibility for any injury to people or property resulting from any ideas, methods, instructions or products referred to in the content.

## *Article* **A Soft-YoloV4 for High-Performance Head Detection and Counting**

**Zhen Zhang <sup>1</sup>, Shihao Xia <sup>1</sup>, Yuxing Cai <sup>1</sup>, Cuimei Yang <sup>1</sup> and Shaoning Zeng <sup>2,\*</sup>**


**Abstract:** Occlusion between pedestrians causes inaccurate people counting, and people's heads are easily blocked by each other in crowded scenes. To reduce missed detections as much as possible and improve the capability of the detection model, this paper proposes a new people counting method, named Soft-YoloV4, which attenuates the scores of adjacent detection frames instead of discarding them, to prevent missed detections. The proposed Soft-YoloV4 improves the accuracy of people counting and reduces the incorrect elimination of detection frames when heads are blocked by each other. Compared with the state-of-the-art YoloV4, the AP value of the proposed head detection method increases from 88.52% to 90.54%. The Soft-YoloV4 model has higher robustness and a lower missed detection rate for head detection, and therefore it improves the accuracy of people counting.

**Keywords:** head detection; YoloV4; NMS; soft-NMS; people counting

### **1. Introduction**

People counting is the process of counting the number of people in images. It is one of the most important features of modern intelligent cameras. Without this artificial intelligence technique, the number of people in surveillance video would have to be counted manually, which is impractical because the volume of video data keeps growing. Worse still, a precise manual count is unlikely when the number of people is large. For this reason, many automatic people counting methods have been proposed based on the detection of skin color [1], facial features [2], and pedestrians [3]. Nowadays, deep learning, image recognition, and other artificial intelligence (AI) technologies are developing continuously [4] and are gradually being applied in daily life, e.g., face recognition [5] and human action recognition [6].

In typical places, like classrooms and shopping malls, pedestrians are easily blocked by other objects, which prevents a precise count. Fortunately, this problem occurs relatively infrequently for heads, so a computer can be adapted to detect human heads and, in turn, count the number of people. For example, a people counting system can detect students' heads in a classroom so that teachers know whether any student is absent. In another case, the number of people in a self-study room can be counted by head detection and sent to mobile phones in real time. In this way, students can quickly find out which self-study rooms still have available seats, without spending lots of time and energy searching for an unoccupied space. Besides these, a shopping mall owner can analyze patterns of customer flow by detecting heads in each store, which helps in making appropriate marketing strategies. All of these demonstrate that high-performance head detection and counting is one of the most crucial techniques in modern AI systems and applications.

**Citation:** Zhang, Z.; Xia, S.; Cai, Y.; Yang, C.; Zeng, S. A Soft-YoloV4 for High-Performance Head Detection and Counting. *Mathematics* **2021**, *9*, 3096. https://doi.org/10.3390/math9233096

Academic Editor: Radu Tudor Ionescu

Received: 26 October 2021 Accepted: 28 November 2021 Published: 30 November 2021

**Publisher's Note:** MDPI stays neutral with regard to jurisdictional claims in published maps and institutional affiliations.

**Copyright:** © 2021 by the authors. Licensee MDPI, Basel, Switzerland. This article is an open access article distributed under the terms and conditions of the Creative Commons Attribution (CC BY) license (https://creativecommons.org/licenses/by/4.0/).

#### **2. Related Work**

As a fundamental technique of people counting, head counting belongs to target detection in computer vision. Many machine learning methods have been proposed for this task. Traditional machine learning target detection algorithms include AdaBoost based on Haar features [7], and SVM based on HOG [8] and LBP [9] features. These detection algorithms rely mainly on traditional hand-crafted features. The procedure usually consists of extracting features from the images, constructing a classifier, and finally obtaining the targets. However, most of these traditional target detection algorithms neither achieve high enough accuracy for real applications nor generalize well enough.

Deep neural networks, on the other hand, perform much better in target detection. Hinton et al. proposed a deep neural network built from restricted Boltzmann machines (RBMs) [10]. Since then, deep learning methods have dominated target detection applications. Currently, deep target detection algorithms are mainly divided into three categories. The first comprises multi-stage algorithms such as R-CNN [11] and SPPNet [12]. The second comprises two-stage implementations such as Fast R-CNN [13], Faster R-CNN [14], Mask R-CNN [15], and HyperNet [16], which have shown very promising performance; however, these methods are not fast enough for real applications. The third comprises one-stage algorithms, including YoloV1 [17], YoloV2 [18], SSD [19], RetinaNet [20], AlignDet [21], CenterNet [22], FSAF [23], FCOS [24], and YoloV4 [25]. These one-stage algorithms have a fast recognition speed, but their accuracy is still not high enough, so there is a gap to be filled. For this reason, our goal is to improve the YoloV4 model, which represents the current state of the art, to create a high-performance head detection and counting model.

In the conventional YoloV4, non-maximum suppression (NMS) sets the score of an adjacent detection frame (which may well contain an object) to 0, so the final output does not contain this frame, which causes missed detections. This is harmful in the head-counting application. The Soft-NMS algorithm [26] was proposed to attenuate the score of the adjacent detection frame rather than set it to 0. As long as the attenuated score remains greater than the threshold, the final output contains this detection frame. Motivated by this, this paper proposes a novel head detection method based on YoloV4, which we call Soft-YoloV4 (the NMS in YoloV4 is replaced by Soft-NMS). We make the following novel contributions:


The remainder of this paper is organized as follows. Section 3 introduces the algorithm design of Soft-YoloV4, and Section 4 presents the experimental data sets and evaluation indexes. The results of Soft-YoloV4 in a real application and its comparison with several other methods are presented in Section 5. Conclusions are provided in Section 6.

#### **3. Methods**

#### *3.1. NMS Algorithm*

The YoloV4 model mainly consists of the following parts: CSPDarknet53 (the backbone feature extraction network), SPP (the strengthened feature extraction network), PANet, and Yolo Head [27]. Figure 1 shows the architecture formed by CSPDarknet53, SPP, PANet, and Yolo Head for an input image of size 416 × 416 × 3.

In particular, CSPDarknet53 mainly consists of a series of residual blocks (ResNet [28]). A detailed description can be found in the CSPDarknet53 module in Figure 1.

Max-pooling in the SPP architecture uses pooling kernels of three sizes: 5 × 5, 9 × 9, and 13 × 13. SPP pools the input feature layer with each kernel and stacks the outputs. In general, max-pooling reduces the features and parameters of the result while preserving some invariance, e.g., to rotation, translation, and scaling. The SPP architecture also effectively enlarges the receptive field of the output units.
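The pooling-and-stacking step described above can be sketched in NumPy. This is a minimal illustration, not the YoloV4 implementation: it assumes stride-1 max pooling with "same" padding, so that every branch keeps the spatial size of its input, and that the three pooled maps are concatenated along the channel axis together with the unpooled input. The helper names `max_pool_same` and `spp_block` are hypothetical, and the tiny channel count only keeps the example fast.

```python
import numpy as np

def max_pool_same(x, k):
    """Stride-1 max pooling with 'same' padding on an (H, W, C) array."""
    pad = k // 2
    # Pad with -inf so padded cells can never be the maximum.
    padded = np.pad(x, ((pad, pad), (pad, pad), (0, 0)), constant_values=-np.inf)
    h, w, _ = x.shape
    out = np.empty_like(x)
    for i in range(h):
        for j in range(w):
            out[i, j] = padded[i:i + k, j:j + k].max(axis=(0, 1))
    return out

def spp_block(x, kernel_sizes=(5, 9, 13)):
    """Pool x with each kernel size and stack the results with x on the channel axis."""
    pooled = [max_pool_same(x, k) for k in kernel_sizes]
    return np.concatenate([x] + pooled, axis=-1)

# Toy 13 x 13 feature map with 4 channels.
feat = np.random.default_rng(0).standard_normal((13, 13, 4)).astype(np.float32)
out = spp_block(feat)
print(out.shape)  # (13, 13, 16)
```

With C input channels the output has 4C channels. At the center cell of a 13 × 13 map, the 13 × 13 branch sees the entire map, which illustrates how SPP enlarges the receptive field.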

PANet was proposed by Shu Liu et al. [29]. This architecture makes full use of shallow and deep features, obtaining more effective feature layers by fusing them. In YoloV4, PANet is applied to three effective feature layers of sizes (13, 13, 1024), (26, 26, 512), and (52, 52, 256). After feature fusion in PANet, three effective feature layers of sizes 52 × 52 × 128, 26 × 26 × 256, and 13 × 13 × 512 are obtained. Each Yolo Head has two convolution layers: a 3 × 3 convolution followed by a 1 × 1 convolution. For example, Yolo Head1 takes the 52 × 52 × 128 feature layer as input and outputs a 52 × 52 × 18 feature layer. Likewise, Yolo Head2 outputs a 26 × 26 × 18 feature layer, and Yolo Head3 outputs a 13 × 13 × 18 feature layer. These three feature layers, 52 × 52 × 18, 26 × 26 × 18, and 13 × 13 × 18, form the output of YoloV4.
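The 18 output channels and the three grid sizes can be reproduced with a short calculation. This sketch assumes the usual YOLO output layout, which the text does not spell out: each grid cell predicts 3 anchor boxes, each described by 4 box offsets, 1 objectness score, and one score per class, and the three heads have strides 8, 16, and 32 relative to the input. With a single "head" class this gives 3 × (4 + 1 + 1) = 18 channels. The function name is hypothetical.

```python
def yolo_head_shapes(input_size=416, num_classes=1, anchors_per_scale=3, strides=(8, 16, 32)):
    """Return the (H, W, C) output shape of each Yolo Head.

    Assumes the standard YOLO layout: every grid cell predicts
    anchors_per_scale boxes, each with 4 box offsets, 1 objectness
    score and num_classes class scores.
    """
    channels = anchors_per_scale * (4 + 1 + num_classes)
    return [(input_size // s, input_size // s, channels) for s in strides]

print(yolo_head_shapes())
# [(52, 52, 18), (26, 26, 18), (13, 13, 18)]
```

With the 80 COCO classes, the same formula gives 3 × (4 + 1 + 80) = 255 channels per head.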

In the original YoloV4 model, NMS is used to keep the detection frame with the highest score among overlapping frames of the same category. However, the elimination mechanism of NMS is very strict: it considers only the *IOU* (Intersection over Union) between each detection frame and the highest-scoring frame, which easily leads to missed detections. Figure 2 shows an example of such a missed detection.

**Figure 2.** A missed detection happened using NMS.

There are three people in Figure 2, but only two were detected using NMS, which is a missed detection. In crowded scenes where people's heads block each other, using the NMS algorithm to remove redundant detection frames is therefore likely to cause missed detections.

In our approach, the key step in people counting is detecting people's heads. When there are many people, their heads easily block each other. Therefore, we replace NMS with Soft-NMS in the Soft-YoloV4 model to address this problem, as analyzed below.

#### *3.2. Principle of Soft-NMS Algorithm*

From a mathematical point of view, the mechanism of NMS to remove redundant frames can be expressed as:

$$\mathrm{score}_i = \begin{cases} 0, & \mathrm{IOU}(M, b_i) \ge \text{IOU threshold} \\ \mathrm{score}_i, & \mathrm{IOU}(M, b_i) < \text{IOU threshold} \end{cases} \tag{1}$$

where $\mathrm{score}_i$ is the score of the current detection frame $b_i$, and $M$ is the detection frame with the highest score. After repeated tuning on the data set used in this experiment, we found the best *IOU* threshold to be 0.5.
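Formula (1) can be illustrated with a minimal pure-Python sketch of hard NMS. It assumes boxes given as (x1, y1, x2, y2) corner coordinates and a single category; `iou` and `nms` are hypothetical helper names, not the YoloV4 code. In the example, the second frame overlaps the first with an *IOU* of 2/3, so it is discarded; if it actually covered a second, partly hidden head, this is the kind of missed detection shown in Figure 2.

```python
def iou(a, b):
    """Intersection over Union of two boxes given as (x1, y1, x2, y2)."""
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / union if union > 0 else 0.0

def nms(boxes, scores, iou_threshold=0.5):
    """Hard NMS following Formula (1); returns indices of the kept frames."""
    order = sorted(range(len(boxes)), key=lambda i: scores[i], reverse=True)
    keep = []
    while order:
        m = order.pop(0)  # M: the highest-scoring remaining frame
        keep.append(m)
        # Formula (1): a frame whose IOU with M reaches the threshold gets
        # score 0, i.e., it is dropped from further consideration.
        order = [i for i in order if iou(boxes[m], boxes[i]) < iou_threshold]
    return keep

boxes = [(0, 0, 10, 10), (2, 0, 12, 10), (30, 30, 40, 40)]
scores = [0.9, 0.8, 0.7]
print(nms(boxes, scores))  # [0, 2]
```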

Thus, NMS sets the score of any detection frame whose *IOU* with the highest-scoring frame reaches the threshold to 0, which removes that frame. This is very likely to cause a missed detection in the situation shown in Figure 2. The mechanism by which Soft-NMS removes redundant detection frames can be expressed as:

$$\mathrm{score}_i = \mathrm{score}_i \, e^{-\frac{\mathrm{IOU}(M, b_i)^2}{\theta}} \tag{2}$$

That is, Soft-NMS does not directly set to 0 the score of a detection frame that has a high *IOU* with the highest-scoring frame. Instead, it penalizes the score by multiplying it by a weight function. We used the Gaussian function $e^{-\mathrm{IOU}(M, b_i)^2/\theta}$ as the weight function, where $\theta$ is its parameter; after tuning, the detection effect was best with $\theta = 0.1$. The greater the overlap with the highest-scoring detection frame, the more the score of a detection frame is reduced. Finally, only detection frames with a score greater than or equal to 0.5 are kept. In this way, Soft-NMS can remove redundant detection frames and reduce the missed detection rate as well. The flow chart of Soft-NMS is shown in Figure 3.

In summary, the main idea of Soft-NMS is as follows. First, it finds all detection frames in an image whose confidence is higher than a manually set confidence level; a frame whose confidence is below this level is assumed to contain no target object. Second, it processes the detection frames that belong to the same category. Finally, it establishes a set *B* and puts all the detection frames that belong to the same category into this set. The specific algorithm of Soft-NMS is as follows.
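The steps above, combined with Formula (2), can be sketched as an illustrative Gaussian Soft-NMS. This is written under stated assumptions and is not the authors' exact algorithm: boxes are (x1, y1, x2, y2) tuples, θ = 0.1 and the final score threshold of 0.5 come from the text, while the confidence level `conf_thresh = 0.3` and all function names are hypothetical.

```python
import math
from collections import defaultdict

def iou(a, b):
    """Intersection over Union of two boxes given as (x1, y1, x2, y2)."""
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / union if union > 0 else 0.0

def soft_nms(detections, conf_thresh=0.3, theta=0.1, final_thresh=0.5):
    """detections: list of (box, score, label). Returns kept (box, score, label)."""
    # Step 1: drop frames below the confidence level (assumed to hold no object).
    candidates = [d for d in detections if d[1] >= conf_thresh]
    # Steps 2-3: put the frames of each category into their own set B.
    by_class = defaultdict(list)
    for box, score, label in candidates:
        by_class[label].append([box, score])
    kept = []
    for label, B in by_class.items():
        while B:
            # Take the frame M with the highest current score.
            m = max(range(len(B)), key=lambda i: B[i][1])
            box_m, score_m = B.pop(m)
            if score_m < final_thresh:
                break  # all remaining scores are lower still
            kept.append((box_m, score_m, label))
            # Formula (2): Gaussian penalty instead of zeroing the score.
            for item in B:
                item[1] *= math.exp(-iou(box_m, item[0]) ** 2 / theta)
    return kept

detections = [
    ((0, 0, 10, 10), 0.9, "head"),
    ((7, 0, 17, 10), 0.8, "head"),     # partly overlapping neighbor
    ((1, 0, 11, 10), 0.85, "head"),    # near-duplicate of the first frame
    ((50, 50, 60, 60), 0.2, "head"),   # below the confidence level
]
kept = soft_nms(detections)
for box, score, label in kept:
    print(box, round(score, 3), label)
```

In this example the low-confidence frame is dropped up front, the near-duplicate of the first frame is penalized far below 0.5 and removed, and the partly overlapping frame is kept with a reduced score of about 0.586.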


Figure 4 shows the detection result after the image in Figure 2 is processed by Soft-NMS.

**Figure 4.** Soft-NMS processing, no missed detection.

#### **4. Experimental Datasets and Evaluation Indexes**

The experiments were conducted on two human head data sets: Brainwash [30] and SCUT\_HEAD [31]. The Brainwash data set contains 11,438 images with a total of 81,975 human heads. Its scenes are from a coffee shop, and its annotations are not in Pascal VOC format, so they had to be converted to the Pascal VOC annotation format. The SCUT\_HEAD data set contains 4405 images with a total of 11,251 heads. Both data sets include many complex scenes, such as classrooms and cafes, in the daytime and at night.

For Brainwash, each image is 640 × 480; 300 images were randomly selected as the testing set and the remaining 11,138 images as the training set. For SCUT\_HEAD, the image sizes vary; 141 images were randomly selected as the testing set and the remaining 4264 images as the training set. A third dataset combines all images of Brainwash and SCUT\_HEAD, with 441 images randomly selected as the testing set and 15,402 images as the training set. The input size of the YoloV4 model is 416 × 416, so all images are resized to 416 × 416 before being fed into the model.
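The random selection of test images can be sketched as follows. The seed and the helper name are hypothetical, and the sketch assumes that the third dataset's split is the union of the two individual splits, which the reported counts are consistent with (300 + 141 = 441 and 11,138 + 4264 = 15,402).

```python
import random

def random_split(num_images, num_test, seed=0):
    """Randomly pick num_test image indices for testing; the rest are for training."""
    rng = random.Random(seed)
    test = set(rng.sample(range(num_images), num_test))
    train = [i for i in range(num_images) if i not in test]
    return train, sorted(test)

brainwash_train, brainwash_test = random_split(11438, 300)
scut_train, scut_test = random_split(4405, 141)
print(len(brainwash_train), len(brainwash_test))  # 11138 300
print(len(scut_train), len(scut_test))            # 4264 141
# Third dataset, assumed to be the union of both splits.
print(len(brainwash_train) + len(scut_train), len(brainwash_test) + len(scut_test))  # 15402 441
```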

The evaluation indexes used in this experiment are the Precision value, the Recall value, and the AP value [32]. Precision and Recall are calculated by Formulas (3) and (4), respectively:

$$\text{Precision} = \frac{TP}{TP + FP} \tag{3}$$

$$\text{Recall} = \frac{TP}{TP + FN} \tag{4}$$

In the above formulas, *TP* (true positives) is the number of samples correctly predicted as positive, *FP* (false positives) is the number of samples wrongly predicted as positive, and *FN* (false negatives) is the number of positive samples wrongly predicted as negative. The PR curve describes the relationship between the Precision value and the Recall value, as shown in Figure 5:

**Figure 5.** The PR curve.

AP is the area under the PR curve (the blue area in Figure 5). The higher the AP value, the better the predictive ability of the model.
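Formulas (3) and (4) and the AP definition can be sketched in Python. The paper does not state how the PR curve is interpolated, so this sketch assumes Pascal VOC-style all-point interpolation. It also assumes that each predicted frame has already been matched to a ground-truth head (commonly at an *IOU* of at least 0.5) and labeled as a TP or FP. The function names and the toy detections are hypothetical.

```python
def precision_recall(tp, fp, fn):
    """Formulas (3) and (4)."""
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    return precision, recall

def average_precision(scores, is_tp, num_gt):
    """Area under the PR curve for one class (all-point interpolation).

    scores: confidence of each predicted frame
    is_tp:  True if that frame was matched to a ground-truth head
    num_gt: number of ground-truth heads
    """
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    tp = fp = 0
    recalls, precisions = [], []
    for i in order:  # walk down the ranking, one PR point per frame
        if is_tp[i]:
            tp += 1
        else:
            fp += 1
        p, r = precision_recall(tp, fp, num_gt - tp)
        precisions.append(p)
        recalls.append(r)
    # Sentinel points at recall 0 and 1.
    mrec = [0.0] + recalls + [1.0]
    mpre = [0.0] + precisions + [0.0]
    # Make precision non-increasing from right to left (the PR envelope).
    for k in range(len(mpre) - 2, -1, -1):
        mpre[k] = max(mpre[k], mpre[k + 1])
    # Sum rectangle areas; steps where recall does not change add nothing.
    return sum((mrec[k + 1] - mrec[k]) * mpre[k + 1] for k in range(len(mrec) - 1))

# 4 ground-truth heads; 4 predicted frames, the second-ranked one is a false positive.
print(precision_recall(tp=3, fp=1, fn=1))  # (0.75, 0.75)
print(average_precision([0.9, 0.8, 0.7, 0.6], [True, False, True, True], num_gt=4))  # 0.625
```

A missed head raises *FN* and lowers Recall, while an unremoved duplicate frame raises *FP* and lowers Precision; AP summarizes both over all score thresholds.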

#### **5. Results**

#### *5.1. Comparison of NMS and Soft-NMS*

To verify the effectiveness of the Soft-YoloV4 model, the YoloV4 model was run for head detection with NMS and with Soft-NMS, using the same prediction parameters and data sets (more than 400 complex images). Whether a recognition is accurate is judged by whether any head is missed.

The original YoloV4 model achieves an AP value of 88.52%, a Precision of 91.15%, and a Recall of 86.93%. With Soft-NMS, the Soft-YoloV4 model improves the prediction result, achieving an AP value of 90.54%, a Precision of 91.94%, and a Recall of 85.55%.

Table 1 shows the comparison results between the YoloV4 model before and after the improvement on the third dataset.


**Table 1.** The comparison results.

This comparison shows that the AP value and the Precision value are improved compared with the original model, whereas the Recall declines. Soft-NMS removes redundant detection frames by penalizing their scores, and Formula (2) contains an adjustable parameter θ. A large θ results in a smaller penalty, so a redundant detection frame may not be removed, which means the model may indicate two objects when there is only one. We attribute the decline in Recall to the parameter θ being large. Since neither Recall nor Precision alone can evaluate the algorithm comprehensively, the AP index is used. The experiments showed that the AP value of Soft-YoloV4 was higher than that of the original YoloV4, even though Recall dropped slightly. Thus, replacing NMS with Soft-NMS in YoloV4 is effective.
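The effect of θ can be checked numerically with the Gaussian weight of Formula (2). Here θ = 0.1 is the value chosen in this paper, while θ = 0.5 is a hypothetical larger value for comparison; the helper name is also hypothetical.

```python
import math

def gaussian_weight(overlap, theta):
    """Soft-NMS penalty factor from Formula (2): exp(-IOU^2 / theta)."""
    return math.exp(-overlap ** 2 / theta)

for theta in (0.1, 0.5):
    weights = [round(gaussian_weight(overlap, theta), 3) for overlap in (0.1, 0.3, 0.5)]
    print(theta, weights)
# 0.1 [0.905, 0.407, 0.082]
# 0.5 [0.98, 0.835, 0.607]
```

For every overlap, the larger θ gives a weight closer to 1, i.e., a smaller penalty on the adjacent frame's score.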

#### *5.2. Comparison with State-of-the-Art Methods*

The experiments include the following comparison methods: end-to-end people detection (abbreviated as ReInspect [30]), head detection using a feature refine net and cascaded multi-scale architecture (abbreviated as FRN\_CMA [31]), a target detection algorithm based on YoloV3 (abbreviated as YoloV3 [33]), and a pedestrian head detection algorithm based on clustering and Faster R-CNN (abbreviated as CFR-PHD [34]). All methods use the same evaluation index. The detection results of each method on the Brainwash and SCUT\_HEAD data sets are shown in Table 2.


**Table 2.** Experimental results obtained on Brainwash and SCUT\_HEAD.

According to the experimental results on the Brainwash and SCUT\_HEAD data sets, our Soft-YoloV4 algorithm improves detection performance compared with the above algorithms. On the Brainwash data set, the AP value increases substantially: the improvements over the ReInspect, FRN\_CMA, YoloV3, and CFR-PHD algorithms are 14.19%, 4.19%, 7.18%, and 2.09%, respectively. On the SCUT\_HEAD data set, the AP improvements are 14.20%, 5.40%, 7.57%, and 4.00%, respectively. These results confirm the performance of the proposed improvement.

Figures 6–8 show three examples of people counting results.

**Figure 6.** One example of people counting results. There are 33 people in the classroom, and the model predicted 33 people. The result is completely correct.

**Figure 7.** One example of people counting results. There are 77 people in the classroom, and the model predicted 77 people. The result is completely correct.

**Figure 8.** One example of people counting results. There are 79 people in the classroom, and the model predicted 81 people. The result is not completely correct.

The result in Figure 8 is not completely correct. As pedestrian density in a scene increases, mutual occlusion increases and the visibility of heads decreases, which degrades head detection, as shown in Figure 8. A possible reason why the model cannot correctly predict objects that heavily overlap with others is that each detection frame predicts only a single object rather than a set of correlated objects.

#### **6. Conclusions**

Compared with other object detection models, the Soft-YoloV4 model proposed in this paper achieves higher recognition accuracy and better people counting performance. Soft-YoloV4 can be deployed on a server: the server recognizes the images sent by a client and returns the number of people to the client. In this way, the number of people in a classroom can be counted conveniently and quickly. This helps teachers count the number of students, and it lets students quickly choose a self-study room without visiting each classroom to check whether a seat is available.

The proposed model still cannot accurately recognize heads under heavy occlusion. In future work, a human body model could be combined with the detector to determine whether an occlusion is present in a detection frame. In addition, the network architecture of the detection model is large: although its accuracy is high, its detection speed is relatively slow. The next step is to modify the network architecture to speed up recognition without significantly decreasing accuracy. KuralNet is a lightweight deep learning model that strikes a good balance between the number of parameters and effectiveness [35]. In KuralNet, an inverted residual block with depthwise convolution and frequency-doubling convolution is used for signal processing to reduce the computational cost. A similar design might help reduce the complexity of Soft-YoloV4.

This paper proposes a head detection model, obtained by improving YoloV4, for counting people. The improved version of YoloV4 detects people's heads and uses Soft-NMS, so that the number of people can be counted more accurately, with performance close to the requirements of real applications. The original YoloV4 model uses the NMS algorithm to remove redundant detection frames, whereas the Soft-YoloV4 model uses the Soft-NMS algorithm. Comparative analysis shows that Soft-YoloV4 achieves higher accuracy in head detection: its AP value is 90.54%, 2.02% higher than that of the original YoloV4 model. Therefore, Soft-YoloV4 is better suited to head detection in crowded scenes.
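To illustrate why replacing NMS with Soft-NMS can help in crowded scenes, the following Python sketch contrasts the two on three toy head boxes. It assumes the Gaussian score-decay variant of Soft-NMS from the original Soft-NMS proposal (Bodla et al.); the decay parameter `sigma=0.5`, the IoU threshold `0.5`, the score threshold `0.3`, and all helper names are illustrative assumptions, not values or code reported in this paper.

```python
import math
from typing import List, Sequence, Tuple

# A box is (x1, y1, x2, y2) in pixel coordinates.
Box = Tuple[float, float, float, float]


def iou(a: Box, b: Box) -> float:
    """Intersection over union of two axis-aligned boxes."""
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / union if union > 0 else 0.0


def hard_nms(boxes: Sequence[Box], scores: Sequence[float],
             iou_thresh: float = 0.5, score_thresh: float = 0.3) -> List[Tuple[Box, float]]:
    """Classic NMS: discard any box whose IoU with an already kept box exceeds iou_thresh."""
    order = sorted(range(len(boxes)), key=lambda i: scores[i], reverse=True)
    kept: List[int] = []
    for i in order:
        if scores[i] < score_thresh:
            continue
        if all(iou(boxes[i], boxes[k]) <= iou_thresh for k in kept):
            kept.append(i)
    return [(boxes[i], scores[i]) for i in kept]


def soft_nms(boxes: Sequence[Box], scores: Sequence[float],
             sigma: float = 0.5, score_thresh: float = 0.3) -> List[Tuple[Box, float]]:
    """Gaussian Soft-NMS: decay overlapping scores by exp(-IoU^2 / sigma) instead of deleting."""
    remaining = list(zip(boxes, scores))
    kept: List[Tuple[Box, float]] = []
    while remaining:
        best = max(range(len(remaining)), key=lambda i: remaining[i][1])
        best_box, best_score = remaining.pop(best)
        if best_score < score_thresh:
            break  # every other remaining score is lower still
        kept.append((best_box, best_score))
        remaining = [(b, s * math.exp(-iou(best_box, b) ** 2 / sigma)) for b, s in remaining]
    return kept


if __name__ == "__main__":
    boxes = [(0, 0, 10, 10),       # head A
             (2.5, 0, 12.5, 10),   # head B, partly occluding A (IoU = 0.6)
             (0, 0, 10, 9)]        # near-duplicate detection of A (IoU = 0.9)
    scores = [0.9, 0.8, 0.7]
    print("hard NMS count:", len(hard_nms(boxes, scores)))
    print("Soft-NMS count:", len(soft_nms(boxes, scores)))
```

With these toy inputs, hard NMS counts 1 head because head B's IoU with head A exceeds 0.5, whereas Soft-NMS only lowers B's score, so B survives the final threshold while the near-duplicate box is still removed, giving a count of 2. In this sketch the people count is simply the number of surviving boxes.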

**Author Contributions:** Conceptualization, Z.Z. and S.Z.; Data curation, S.X.; Formal analysis, Z.Z., S.X., Y.C. and C.Y.; Funding acquisition, Z.Z.; Investigation, S.X.; Supervision, S.Z. All authors have read and agreed to the published version of the manuscript.

**Funding:** This research was funded by [Young innovative talents project of colleges and universities in Guangdong Province] grant number [2021KQNCX092]; [Doctoral program of Huizhou University] grant number [2020JB028]; [Outstanding youth cultivation project of Huizhou University] grant number [HZU202009].

**Institutional Review Board Statement:** Not applicable.

**Informed Consent Statement:** Not applicable.

**Data Availability Statement:** Data available in a publicly accessible repository that does not issue DOIs. Brainwash dataset and SCUT\_HEAD dataset were analyzed in this study. Brainwash dataset can be found here [https://github.com/aditya-vora/FCHD-Fully-Convolutional-Head-Detector] (accessed on 30 November 2021). SCUT\_HEAD dataset can be found here [https://github.com/HCIILAB/SCUT-HEAD-Dataset-Release] (accessed on 30 November 2021).

**Conflicts of Interest:** The authors declare no conflict of interest.

#### **References**


## *Review* **Trends in Vehicle Re-Identification Past, Present, and Future: A Comprehensive Review**

**Zakria <sup>1</sup>, Jianhua Deng <sup>1,\*</sup>, Yang Hao <sup>2,\*</sup>, Muhammad Saddam Khokhar <sup>3</sup>, Rajesh Kumar <sup>4</sup>, Jingye Cai <sup>1</sup>, Jay Kumar <sup>4</sup> and Muhammad Umar Aftab <sup>5</sup>**


**Abstract:** Vehicle re-identification (re-id) over a surveillance camera network with non-overlapping fields of view is an exciting and challenging task in intelligent transportation systems (ITS). Because of its versatile applicability in metropolitan cities, it has gained significant attention. Vehicle re-id matches a target vehicle across the non-overlapping views of a multi-camera network. The task is difficult because of inter-class similarity, intra-class variability, viewpoint changes, and spatiotemporal uncertainty. To draw a detailed picture of vehicle re-id research, this paper gives a comprehensive description of vehicle re-id technologies, their applicability, and datasets, together with a brief comparison of different methodologies. Our paper focuses specifically on vision-based vehicle re-id approaches, including those based on vehicle appearance, license plates, and spatio-temporal characteristics. In addition, we explore the main challenges as well as a variety of applications in different domains. Lastly, we summarize a detailed comparison of the performance of current state-of-the-art methods on the VeRi-776 and VehicleID datasets, together with future directions. We aim to facilitate future research by reviewing the work done on vehicle re-id to date.

**Keywords:** vehicle re-identification; license plate recognition; video surveillance; feature extraction

### **1. Introduction**

Because of the growing global population, commercial activities have been increasing extensively, leading more and more people to rely on road transportation for mobility. Because the road transportation system is easily accessible, road traffic is increasing massively, which not only causes heavy congestion but also drastically increases carbon dioxide emissions. Along with these issues, the risk of road accidents and the overall complexity of transportation increase as well. Growing commercial activities therefore require a smooth transportation system. Furthermore, traffic management authorities face serious challenges in maintaining an undisturbed transportation system: their tasks include tracking suspicious vehicles, handling traffic jams, and checking whether a vehicle is registered. Maintaining undisturbed transportation becomes harder when a large number of vehicles are on the road.

#### *1.1. Intelligent Transportation System*

Transport is essential for the daily functioning of the economy and society. Over the past few decades, there has been huge development, deployment, and growth in transport systems, with a notable effect on development in society and daily life. Therefore, transportation should be redefined as ITS. Currently, not only mechanical and engineering fields are doing research and development for better transportation facilities; concepts from computer science, for instance artificial intelligence (AI), communication, machine learning (ML), the internet, and many other emerging technologies, are also playing a major role.

**Citation:** Zakria; Deng, J.; Hao, Y.; Khokhar, M.S.; Kumar, R.; Cai, J.; Kumar, J.; Aftab, M.U. Trends in Vehicle Re-Identification Past, Present, and Future: A Comprehensive Review. *Mathematics* **2021**, *9*, 3162. https://doi.org/10.3390/math9243162

Academic Editor: Aleksandr Rakhmangulov

Received: 23 October 2021 Accepted: 27 November 2021 Published: 8 December 2021

**Publisher's Note:** MDPI stays neutral with regard to jurisdictional claims in published maps and institutional affiliations.

**Copyright:** © 2021 by the authors. Licensee MDPI, Basel, Switzerland. This article is an open access article distributed under the terms and conditions of the Creative Commons Attribution (CC BY) license (https://creativecommons.org/licenses/by/4.0/).

Because of traffic problems in China, the average vehicle speed has decreased to 20 km/h, and in some areas to between 7 and 8 km/h [1,2]. Such low vehicle speeds sustained over long periods threaten the natural environment, for example through exhaust emissions that deteriorate air quality. To deal with traffic problems and relieve the pressure of vehicles on roads, governments are investing heavily in research and ITS development. ITS-based infrastructure strengthens the relationship between people, vehicles, and road networks.

ITS can enhance the performance of the current transportation system, making it efficient, safe, and comfortable while reducing harmful environmental consequences. ITS-based real-time applications include electronic payment systems, traffic management systems, emergency vehicle pre-emption management systems, advanced vehicle control systems, weather precaution management systems, and commercial vehicle operations. ITS applications that are now regularly deployed include closed-circuit television surveillance, automatic car parking, electronic toll collection, border control, and in-car navigation equipment. Therefore, an ITS is needed to analyze recorded video, to control, maintain, and communicate with ground transport, to improve mobility, and to manage problems efficiently. Figure 1 illustrates an ITS-based environment.

**Figure 1.** Depicts smart city and intelligent transportation system.

#### *1.2. Video Surveillance*

In metropolitan cities, cameras are widely deployed in numerous areas to monitor activities [3]; however, most current video surveillance systems only capture, store, and distribute video, leaving the detection of unwanted events entirely to human operators. Monitoring a surveillance system with human operators is inefficient and very labour-intensive, as shown in Figure 2. It requires full visual attention while watching video in a control room, which is very difficult for a single person to sustain as an everyday task, particularly the ability to focus on and react to rarely occurring events that require full attention. Furthermore, the millions of hours of video generated by the many cameras of a surveillance network would require a large number of operators. Achieving real-time prevention in this way is almost infeasible, inefficient, and costly.

**Figure 2.** Manual traffic monitoring in a control room.

With digital cameras and the advent of powerful computing resources, automatic video analysis has become possible and increasingly common in video surveillance applications [4], reducing labour costs. In practice, the objective of automatic video analysis for safety, security, and surveillance is to automatically detect unwanted events or situations that need security attention. Automated video analysis not only processes data faster but also significantly improves the ability to pre-empt incidents in time. Augmenting security staff with automatic processing increases their efficiency and effectiveness. In the posterior (after-the-fact) mode, searching for a specific vehicle in hundreds of hours of recorded camera footage requires many officers and takes a lot of time; automated content-based video retrieval, which reproduces and assists human analysis of recorded videos, greatly enhances forensic capabilities. Overall, the main goal of surveillance system applications is to develop intelligent systems that automate the human decision-making process.

An important task in maintaining a smooth transport system is to re-identify a specific vehicle that appears in different cameras across the surveillance network. The vehicle re-id module in an ITS should recognize the same vehicle when it appears in surveillance cameras installed at different geographical locations. Vehicle re-id can be treated as a fine-grained recognition problem [5,6], which identifies the subordinate category of an input class. However, the granularity of vehicle re-id is much finer, since the system must find a specific target vehicle rather than merely a vehicle of the same model and type. Moreover, vehicle re-id has recently gained more attention in the research community because of its many significant real-world applications. Analyzing the surveillance environment for effective vehicle identification is difficult. Figure 3 shows an example of a practical environment, where surveillance cameras can be observed over roads and public places.

**Figure 3.** Illustrates the practical scenario of surveillance camera network.

#### *1.3. Re-Identification*

In a surveillance camera network without overlapping views, re-id is defined as the task of identifying an object across images captured by different cameras. It determines whether object images captured by multiple surveillance cameras belong to the same object or to different objects. Object re-id technology plays a significant role in multi-object tracking, intelligent monitoring, and other fields, and it has recently gained extensive attention in the computer vision research community. The main application fields of object re-id are vehicle re-id and person re-id.

Formally, re-id can be defined as a matching task. A target image (the query) is matched against the images of a gallery set, which represent previously captured images in the surveillance camera network. If the query image is represented by its descriptor *Q*, re-identifying the target is formulated as:

$$T^{*} = \arg\min_{T_i \in \mathcal{T}} D(T_i, Q) \tag{1}$$

where $\mathcal{T} = \{T_1, \ldots, T_N\}$ is the gallery set of *N* image descriptors and *D*(·,·) is a distance metric. Therefore, to solve the re-id problem, it is important first to answer how the target object can be represented by a descriptor that yields robust performance. The rest of the paper investigates this topic.
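Eq. (1) amounts to a nearest-neighbour search over the gallery. The sketch below assumes the Euclidean distance as *D* and uses hypothetical three-dimensional descriptors and gallery identifiers; in a real system the descriptors would come from a learned feature extractor. Sorting by the same distance yields the ranked list that retrieval systems typically return.

```python
import math
from typing import Dict, List, Sequence

Descriptor = Sequence[float]


def euclidean(a: Descriptor, b: Descriptor) -> float:
    """Distance D(T_i, Q); Euclidean is an assumption, any metric could be plugged in."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


def reidentify(query: Descriptor, gallery: Dict[str, Descriptor]) -> str:
    """Eq. (1): return the gallery id whose descriptor T_i minimizes D(T_i, Q)."""
    return min(gallery, key=lambda gid: euclidean(gallery[gid], query))


def rank_gallery(query: Descriptor, gallery: Dict[str, Descriptor]) -> List[str]:
    """Ranked list of gallery ids, closest first."""
    return sorted(gallery, key=lambda gid: euclidean(gallery[gid], query))


if __name__ == "__main__":
    # Hypothetical descriptors keyed by "camera_vehicle" ids.
    gallery = {
        "cam1_car07": [0.90, 0.10, 0.30],
        "cam2_car12": [0.10, 0.80, 0.50],
        "cam3_car07": [0.85, 0.15, 0.35],
    }
    query = [0.88, 0.12, 0.32]
    print(reidentify(query, gallery))
    print(rank_gallery(query, gallery))
```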

*Vehicle Re-identification*: Like person re-id, vehicle re-id is a demanding task in camera surveillance. The aim of vehicle re-id is to match a vehicle image against vehicle images already captured over the camera network [7–9]. With surveillance cameras deployed on roads for smart cities and traffic management, the demand for searching vehicles in a gallery set has increased. Vehicle re-id is related to several other applications, such as person re-id [10], behavior analysis [11], cross-camera tracking [12], vehicle classification [13], object retrieval [14], and object recognition [15,16].

To understand how to design a vehicle re-id system, we analyze how a person re-identifies a vehicle. A person re-identifies a vehicle by keeping in mind characteristics such as unique features, color, and size; our brain and eyes have learned to detect and identify different objects, as shown in Figure 4. How a machine identifies a vehicle is shown in Figure 5.

**Figure 4.** How a human re-identifies a vehicle.

**Figure 5.** How a machine re-identifies a vehicle.

#### *1.4. Vehicle Re-Identification Practical Application*

Vehicle re-id systems can be utilized in many significant real-world applications that meet important practical needs. Some major applications are briefly discussed as follows:

• Suspicious vehicle search: Terrorists often use vehicles in their criminal activities and quickly leave the scene in them. Quickly searching for a suspicious vehicle manually in surveillance camera footage is very difficult.


• Vehicle retrieval: In this case, re-id is associated with a recognition task. A query with a target vehicle is provided, and all related vehicles are searched for in the database. The re-id task is thus employed for image retrieval and usually provides ranked lists of similar items.

However, because of the vast range of practical applications that employ vehicle re-id systems, and to limit the scope of the paper, this review focuses mainly on vision-based methods. It is very hard to cover all vehicle re-id technologies in one survey paper; nevertheless, we summarize the strengths and weaknesses of all technologies in Table 1. This review article focuses on vision-based approaches, including those based on appearance, license plates, and contextual information. In the last few years, there has been a lack of comprehensive studies of the overall problem and its different solutions. This paper fills the gap by providing a detailed review covering the main challenges, different approaches, and their applicability. In addition, it analyzes and compares existing vehicle re-id methodologies. To facilitate other researchers, this review also provides the necessary information about publicly available datasets and discusses several important research directions and under-investigated open issues that can narrow the gap between closed-world and open-world applications, taking a step towards real-world re-id system design.


**Table 1.** Summary of strengths and weaknesses of different vehicle re-id technologies.

Two kinds of surveys can be found in the object re-id literature: the first gives deep insight into methodologies, whereas the second covers the overall perspective of the problem [17,18]. This survey includes both the methodologies and the overall perspective of the vehicle re-id literature. We also review recent developments in vision-based vehicle re-id along with other technologies. In addition, this survey draws a timeline of important milestones in vehicle re-id, shown in Figure 6.

**Figure 6.** Milestones of existing approaches in the history of vehicle re-id.

The paper is organized as follows. Sections 2–5 provide an overview of recent state-of-the-art methodologies across various technologies. Section 6 presents publicly available benchmark datasets that cover various real-world surveillance scenarios. Section 7 discusses the challenging problems in vehicle re-id. Section 8 sheds light on the evaluation measures for vehicle re-id. Section 9 analyzes and compares the experimental results of various approaches. Finally, the last section concludes and discusses future work.

The main contributions of this review paper are summarized as follows:


#### **2. Methods Used for Vehicle Re-Identification**

Traditionally, different traffic sensors have been adopted to obtain data on vehicle presence, volume, occupancy, and speed. Nowadays, newer sensor-based technologies are adopted to obtain more information for applications such as origin-destination estimation and travel time estimation. Based on the technology used, vehicle re-id approaches can be divided into six categories, as depicted in Figure 7.

**Figure 7.** Shows vehicle re-id methods.

#### *2.1. Magnetic Sensor-Based Vehicle Re-Identification*

An electromagnetic field is used to detect a vehicle when it passes, providing occupancy, counts, and vehicle speed. Because vehicles are made of metal, they disrupt the magnetic field, so the magnetic signature generated by one vehicle differs from that of another [19]. This property helps to re-identify a specific vehicle. Moreover, a Berkeley company sells magnetic sensors for ITS under the name "Sensys Network" [20]. The straight-line re-id rate is 50%, and the approach reduces the magnetic signature to a sequence of peak values for calculating the signature distance, which removes the dependency on vehicle speed [21]. For real-time vehicle re-id, a processing unit is connected to thousands of magnetic sensor nodes, and the large number of sensors generates massive data streams; to handle real-time data stream mining, high-performance FPGAs and low-performance microcontrollers are used [22,23]. Sylvie Charbonnier et al. [24] studied various approaches for vehicle re-id using the three-dimensional magnetic signature of a vehicle measured by a sensor: when a car passes the sensor, the induced changes in the magnetic field are measured in three directions (X, Y, and Z). Rene O. Sanchez et al. [25] investigated vehicle re-id using wireless magnetic sensors, comparing vehicle magnetic signatures to overcome the limitations of the system when a vehicle is stopped or moving slowly at a detection station.
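The speed-independence idea attributed to [21] can be illustrated with a small sketch: if a slower pass stretches the signature in time, the sequence of peak values stays the same. The code below is a hypothetical illustration, not the method of [21]; collapsing repeated samples, the strict local-maximum rule, and the L1 distance with a penalty for unmatched peaks are all assumptions made for clarity.

```python
from typing import List, Sequence


def peak_values(signature: Sequence[float]) -> List[float]:
    """Values of the local maxima of a 1-D magnetic signature, in time order."""
    # Collapse consecutive repeated samples so a time-stretched plateau counts once.
    collapsed: List[float] = []
    for v in signature:
        if not collapsed or v != collapsed[-1]:
            collapsed.append(v)
    return [collapsed[i] for i in range(1, len(collapsed) - 1)
            if collapsed[i - 1] < collapsed[i] > collapsed[i + 1]]


def peak_distance(sig_a: Sequence[float], sig_b: Sequence[float]) -> float:
    """L1 distance between peak-value sequences; unmatched peaks add their magnitude."""
    pa, pb = peak_values(sig_a), peak_values(sig_b)
    common = sum(abs(x - y) for x, y in zip(pa, pb))
    longer = pa if len(pa) > len(pb) else pb
    extra = sum(abs(v) for v in longer[min(len(pa), len(pb)):])
    return common + extra


if __name__ == "__main__":
    fast = [0, 3, 1, 5, 2, 4, 0]                 # hypothetical signature samples
    slow = [v for v in fast for _ in range(2)]   # same vehicle passing at half speed
    other = [0, 2, 1, 6, 0]                      # a different vehicle
    print(peak_distance(fast, slow), peak_distance(fast, other))
```

The stretched pass of the same vehicle yields distance 0, while the different vehicle yields a positive distance, regardless of how many samples each pass produced.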

#### *2.2. Inductive Loop-Based Vehicle Re-Identification*

Vehicles can be re-identified using inductive loops embedded in the road surface for vehicle detection. These loops capture a fingerprint of every passing car, and the travel time can be determined by comparing the fingerprints, or certain aspects of them, captured at different locations. Jeng and Chu [26] designed a real-time inductive loop signature-based vehicle re-id method named RTREID-2M. Inductive signatures are used for vehicle re-id, and much effort has been devoted to exploiting inductive loop signature technology. Inductive signature-based vehicle re-id algorithms identify a specific vehicle at a downstream detection station by matching its inductive signature from an upstream detection station, assuming that a vehicle produces the same signature when crossing different loop detection stations [27]. Researchers have proposed several algorithms, such as optimization, piecewise slope rate (PSR) matching [28], and lexicographic and blind deconvolution [29]; these approaches address raw signature processing, signature feature extraction, and vehicle matching. R.J. Blokpoel [30] proposed an algorithm for single loops of different sizes. Validation tests show re-id rates of up to 100% when the loops are of the same type and 88% when loops of different types are compared.
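The upstream/downstream matching and travel-time estimation described above can be sketched as follows. The sketch assumes that signatures are made comparable by normalizing them by their peak value and resampling them to a fixed length (to offset speed differences), and that the best match is the upstream signature with the smallest Euclidean distance; these choices, the helper names, and the toy signatures are hypothetical and do not reproduce RTREID-2M or any other cited algorithm.

```python
import math
from typing import List, Sequence, Tuple


def resample(sig: Sequence[float], n: int = 9) -> List[float]:
    """Linearly interpolate a signature to n evenly spaced points."""
    if len(sig) == 1:
        return [float(sig[0])] * n
    out = []
    for k in range(n):
        pos = k * (len(sig) - 1) / (n - 1)
        i = int(pos)
        j = min(i + 1, len(sig) - 1)
        frac = pos - i
        out.append(sig[i] * (1 - frac) + sig[j] * frac)
    return out


def normalize(sig: Sequence[float]) -> List[float]:
    """Scale a signature so that its peak magnitude is 1."""
    peak = max(abs(v) for v in sig)
    return [v / peak for v in sig] if peak else list(sig)


def match_downstream(down_time: float, down_sig: Sequence[float],
                     upstream: Sequence[Tuple[str, float, Sequence[float]]]) -> Tuple[str, float]:
    """Return (vehicle id, travel time) of the closest upstream signature."""
    target = resample(normalize(down_sig))

    def dist(record: Tuple[str, float, Sequence[float]]) -> float:
        _, _, sig = record
        return math.dist(target, resample(normalize(sig)))

    vid, up_time, _ = min(upstream, key=dist)
    return vid, down_time - up_time


if __name__ == "__main__":
    upstream = [("A", 0.0, [0, 2, 4, 2, 0]), ("B", 5.0, [0, 4, 1, 4, 0])]
    # Vehicle A seen downstream: slower (more samples) and with a stronger response.
    print(match_downstream(60.0, [0, 1, 2, 3, 4, 3, 2, 1, 0], upstream))
```

The downstream signature is a stretched, rescaled copy of vehicle A's, so it matches A, and the travel time is the difference between the two timestamps (60 s here).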

#### *2.3. Global Positioning Systems-Based Vehicle Re-Identification*

Global Positioning System (GPS) technology is an essential and valuable tool for ITS and traffic surveillance because it provides positioning data for every single vehicle [31,32]. However, vehicle re-id using GPS still has limitations, such as varying accuracy, minimal fleet penetration, and signal loss caused by tunnels, trees, tall buildings, etc. GPS is installed in vehicles to locate them and obtain travel information, including longitude, latitude, and timestamps. GPS is a special form of mobile sensing technology that enables devices moving with vehicles, such as GPS loggers, GPS-enabled cellular phones, and smartphones, to obtain speed and location information continuously. Different types of vehicles exhibit different behaviors, such as deceleration rates, acceleration, and speed variation. This motivated the authors of [33] to adopt GPS technology for vehicle classification and re-id.
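The kinematic cues mentioned above (speed, acceleration, and deceleration) can be derived directly from timestamped GPS fixes. The sketch below assumes fixes given as `(timestamp_s, latitude_deg, longitude_deg)` tuples, a spherical Earth with a mean radius of 6,371 km, and the haversine formula for distance; the feature names are hypothetical, and the sketch is not the classification method of [33].

```python
import math
from typing import Dict, List, Sequence, Tuple

EARTH_RADIUS_M = 6_371_000.0  # mean Earth radius, spherical approximation

Fix = Tuple[float, float, float]  # (timestamp in s, latitude in deg, longitude in deg)


def haversine_m(lat1: float, lon1: float, lat2: float, lon2: float) -> float:
    """Great-circle distance in metres between two latitude/longitude points."""
    phi1, phi2 = math.radians(lat1), math.radians(lat2)
    dphi = phi2 - phi1
    dlmb = math.radians(lon2 - lon1)
    a = math.sin(dphi / 2) ** 2 + math.cos(phi1) * math.cos(phi2) * math.sin(dlmb / 2) ** 2
    return 2 * EARTH_RADIUS_M * math.asin(math.sqrt(a))


def motion_features(fixes: Sequence[Fix]) -> Dict[str, float]:
    """Speed and acceleration summary of a GPS track (needs at least three fixes)."""
    speeds: List[float] = []  # m/s for each interval between consecutive fixes
    mids: List[float] = []    # mid-time of each interval, used to time-stamp its speed
    for (t0, la0, lo0), (t1, la1, lo1) in zip(fixes, fixes[1:]):
        if t1 <= t0:
            continue  # skip duplicate or out-of-order timestamps
        speeds.append(haversine_m(la0, lo0, la1, lo1) / (t1 - t0))
        mids.append((t0 + t1) / 2)
    accels = [(v1 - v0) / (m1 - m0)
              for v0, v1, m0, m1 in zip(speeds, speeds[1:], mids, mids[1:])]
    return {
        "mean_speed": sum(speeds) / len(speeds),
        "max_speed": max(speeds),
        "max_accel": max(accels),
        "max_decel": min(accels),
    }


if __name__ == "__main__":
    # Hypothetical track along the equator: the vehicle doubles its speed.
    track = [(0.0, 0.0, 0.0), (10.0, 0.0, 0.001), (20.0, 0.0, 0.003)]
    print(motion_features(track))
```

A classifier could then compare such features across vehicle types, for example a bus accelerating more gently than a passenger car; that use is an illustration, not a result from [33].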

#### *2.4. Vision-Based Vehicle Re-Identification*

In computer vision, the aim of vehicle re-id is to identify a specific vehicle that appears in a multi-camera network. Large surveillance camera networks are deployed in public places such as hospitals, parks, colleges, and roads. Manually tracking a target vehicle across a multi-camera network is a difficult and tiresome job for security officers. Computer vision techniques, in contrast, can re-identify a vehicle automatically; the five main steps of this process are discussed below (shown in Figure 8).

**Figure 8.** The flow of designing a practical vehicle re-id system, including five main steps.


vehicle videos or images of the dataset. It is a key step in vehicle re-id systems and a widely explored area in literature.

• Step 5: Vehicle Retrieval: Vehicle retrieval is the task of matching the target vehicle (query image) against a gallery set.

#### **3. Vision-Based State-of-the-Art Vehicle Re-Identification Approaches**

Vision-based methods focus on learning robust feature representations for computing the distance between the features of two vehicle images: images of the same vehicle should have a low distance, and images of different vehicles a high one. However, vehicle features are difficult to distinguish when captured vehicle images have similar colors and poses. This section gives an overview of recent work on computer vision-based methods for the vehicle re-id problem; the general approach of vision-based methods is shown in Figure 9. Several impressive vision-based methods have been proposed to improve vehicle re-id performance, either by modifying existing deep learning (DL) architectures or by designing new deep neural networks (DNNs). Generally speaking, eight different techniques have been employed in this research area: (A) feature representation for vehicle re-id, (B) similarity metrics for vehicle re-id, (C) traditional machine learning-based vehicle re-id, (D) view-aware vehicle re-id, (E) fine-grained visual recognition-based vehicle re-id, (F) generative adversarial network-based vehicle re-id, (G) attention mechanisms, and (H) license plate-based vehicle re-id.

**Figure 9.** The vehicle re-id problem: given a Query, find the matching candidate in the gallery.

#### *3.1. Feature Representation for Vehicle Re-Identification*

Feature representation plays a vital role in the progress of many computer vision tasks. Vehicle re-id feature representations can primarily be classified into two groups: hand-crafted and deep learning feature representations. Hand-crafted feature representations such as BOWCN [36] and LOMO [37] were initially used in person re-id and then applied directly to the vehicle re-id task. Well-known deep learning-based feature representations such as GoogLeNet [38], VGGNet [39], AlexNet [40], and ResNet [41] are used for vehicle re-id, and researchers also adopt these baseline models in their vehicle re-id approaches. For example, NuFACT [42] uses GoogLeNet [38], FACT [43] uses AlexNet [40], and DRDL [44] uses VGGNet [39] to extract vehicle features. Various loss functions are used to train deep learning-based vehicle re-id models that learn discriminative feature representations of vehicle images. For example, the deep joint discriminative learning (DJDL) [45] approach uses identification, verification, and triplet loss functions, and the improved triplet convolutional neural network [46] uses a classification-oriented loss and a triplet loss to extract discriminative feature representations.
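The triplet loss mentioned above can be stated compactly: an anchor embedding should be closer to a positive (the same vehicle) than to a negative (a different vehicle) by at least a margin. The sketch below uses the common hinge form with Euclidean distance and a hypothetical margin of 0.3; the exact variants used in [45,46] may differ.

```python
import math
from typing import Sequence, Tuple

Embedding = Sequence[float]


def euclidean(a: Embedding, b: Embedding) -> float:
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


def triplet_loss(anchor: Embedding, positive: Embedding,
                 negative: Embedding, margin: float = 0.3) -> float:
    """Hinge triplet loss: max(0, d(a, p) - d(a, n) + margin)."""
    return max(0.0, euclidean(anchor, positive) - euclidean(anchor, negative) + margin)


def batch_triplet_loss(triplets: Sequence[Tuple[Embedding, Embedding, Embedding]],
                       margin: float = 0.3) -> float:
    """Mean triplet loss over (anchor, positive, negative) embedding triplets."""
    losses = [triplet_loss(a, p, n, margin) for a, p, n in triplets]
    return sum(losses) / len(losses)


if __name__ == "__main__":
    easy = ([0.0, 0.0], [0.1, 0.0], [1.0, 0.0])  # negative already far away
    hard = ([0.0, 0.0], [0.1, 0.0], [0.2, 0.0])  # negative barely farther than positive
    print(triplet_loss(*easy), triplet_loss(*hard), batch_triplet_loss([easy, hard]))
```

The easy triplet contributes zero loss, while the hard negative that lies only slightly farther than the positive produces a positive loss, and therefore a training signal that pushes the two vehicles apart.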

#### *3.2. Traditional Machine Learning-Based Vehicle Re-Identification*

In traditional machine learning (TML), feature engineering is adopted to manually clean and refine data. Previously proposed TML approaches can be grouped into those that extract robust features and those that learn discriminative classifiers. In TML, the extracted features are computed directly from image pixels and form a low-level feature representation. Moreover, designing TML-based algorithms is expensive and difficult. Broadly, TML consists of two steps: feature extraction and feature classification. Many algorithms have been proposed for low-level feature extraction, for instance speeded up robust features (SURF) [47], the scale-invariant feature transform (SIFT) [48], and the histogram of oriented gradients (HOG). After feature extraction, various classifiers widely used in TML approaches are applied, such as linear regression, k-nearest neighbors (KNN) [49], logistic regression, support vector machines (SVM) [50], Bayes classification [51], and decision trees [52]. Features extracted with SIFT are local image features that are invariant to scaling, rotation, and brightness changes. In addition, they remain stable to a certain degree under affine transformations, viewpoint changes, and noise.

HOG is another feature descriptor adopted for object detection in image processing. HOG features of a large image region are formed by computing histograms of gradient directions over its local regions, and an overlapping local contrast normalization scheme is adopted to improve performance. Zapletal and Herout [53] use color histograms and HOG features with linear regression to re-identify vehicles. Chen et al. [54] designed a method that re-identifies vehicles grid by grid, using HOG features for a coarse search and then refining the result with histograms of matching pairs. In [55], local binary patterns and local variance measures are used as joint descriptors for vehicle re-id.
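The core HOG step, a histogram of gradient orientations weighted by gradient magnitude within one cell followed by normalization, can be sketched in plain Python. This is a simplified illustration under stated assumptions: central-difference gradients, 9 unsigned orientation bins over 0-180 degrees, hard assignment of each pixel to a single bin (full HOG implementations typically interpolate between bins and normalize over overlapping blocks of cells), and plain L2 normalization.

```python
import math
from typing import List, Sequence


def hog_cell_histogram(cell: Sequence[Sequence[float]], n_bins: int = 9) -> List[float]:
    """Magnitude-weighted histogram of unsigned gradient orientations (0-180 degrees)."""
    hist = [0.0] * n_bins
    bin_width = 180.0 / n_bins
    h, w = len(cell), len(cell[0])
    for y in range(1, h - 1):          # border pixels skipped: no central difference
        for x in range(1, w - 1):
            gx = cell[y][x + 1] - cell[y][x - 1]
            gy = cell[y + 1][x] - cell[y - 1][x]
            mag = math.hypot(gx, gy)
            if mag == 0:
                continue
            angle = math.degrees(math.atan2(gy, gx)) % 180.0
            hist[int(angle // bin_width) % n_bins] += mag
    return hist


def l2_normalize(vec: Sequence[float], eps: float = 1e-6) -> List[float]:
    """L2 normalization, as applied to a block of concatenated cell histograms."""
    norm = math.sqrt(sum(v * v for v in vec) + eps * eps)
    return [v / norm for v in vec]


if __name__ == "__main__":
    vertical_edge = [[0, 0, 10, 10]] * 4                       # dark left, bright right
    horizontal_edge = [[0] * 4, [0] * 4, [10] * 4, [10] * 4]   # dark top, bright bottom
    hv = hog_cell_histogram(vertical_edge)
    hh = hog_cell_histogram(horizontal_edge)
    print(hv, hh, l2_normalize(hv + hh))
```

A vertical intensity edge puts all of its gradient energy in the 0-20 degree bin, while a horizontal edge lands in the 80-100 degree bin, which is what lets HOG encode local shape.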

#### *3.3. Similarity Metric for Vehicle Re-Identification*

Regardless of the appearance representation, vehicle re-id performance can be improved by selecting an appropriate distance metric. Distance metric learning approaches [56] have been thoroughly studied in image retrieval and recognition tasks; they define a metric space in which features of the same class are kept close together and features of different classes far apart, as shown in Figure 10. In the re-id task, image features are known as appearance descriptors, and the learned distance metric in appearance space minimizes the distance between descriptors of the same vehicle and maximizes the distance between descriptors of different vehicles. Various face recognition algorithms [57,58] use Euclidean and cosine distance metrics to measure similarity, and FACT [43] likewise uses Euclidean and cosine distances to measure the similarity between a pair of vehicles for re-id. Similarly, NuFACT [42] uses the Euclidean distance to measure the similarity between probe and gallery vehicle images in a discriminative null space [59]. Furthermore, deep relative distance learning (DRDL) [44] uses a two-branch convolutional neural network to map raw vehicle images into a Euclidean space, so that the distance can be used directly to measure the similarity of two vehicles.

Metric learning requires pairwise constraints and is performed in a supervised fashion. During training, appearance descriptors are grouped into pairs labelled positive or negative, depending on whether the two descriptors belong to the same vehicle or to different vehicles. The appearance descriptors are denoted $\mathbf{x}_1, \mathbf{x}_2, \ldots, \mathbf{x}_n$, where *n* is the number of training instances and *m* is the dimensionality of each instance. The aim of metric learning is to learn a distance metric represented by a matrix $D \in \mathbb{R}^{m \times m}$; the distance between a pair of appearance descriptors $\mathbf{x}_i$ and $\mathbf{x}_j$ is then:

$$d(\mathbf{x}_i, \mathbf{x}_j) = (\mathbf{x}_i - \mathbf{x}_j)^{T} D \, (\mathbf{x}_i - \mathbf{x}_j) \tag{2}$$

For *d* in Eq. (2) to induce a valid distance metric, the matrix *D* must be symmetric positive semidefinite. This requirement is handled by formulating metric learning as a convex program:

$$\min_{D} \sum_{(\mathbf{x}_i, \mathbf{x}_j) \in \mathrm{Pos}} \|\mathbf{x}_i - \mathbf{x}_j\|_{D}^{2} \quad \text{s.t.} \quad D \succeq 0, \quad \sum_{(\mathbf{x}_i, \mathbf{x}_j) \in \mathrm{Neg}} \|\mathbf{x}_i - \mathbf{x}_j\|_{D}^{2} \ge 1 \tag{3}$$

where *Pos* denotes the set of positive training pairs (appearance descriptors of the same vehicle), *Neg* denotes the set of negative training pairs (appearance descriptors of different vehicles), and $\|\mathbf{x}_i - \mathbf{x}_j\|_{D}^{2}$ is the distance *d*(**x**<sub>i</sub>, **x**<sub>j</sub>) of Eq. (2).
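Eq. (2) and the two sums in Eq. (3) can be evaluated directly. The sketch below builds *D* as LᵀL, one common way to guarantee that *D* is symmetric positive semidefinite; this parameterization, the helper names, and the toy pairs are assumptions, and the sketch only evaluates the objective and constraint of Eq. (3) rather than solving the convex program.

```python
from typing import List, Sequence, Tuple

Vector = Sequence[float]
Matrix = Sequence[Sequence[float]]


def metric_distance(xi: Vector, xj: Vector, D: Matrix) -> float:
    """Eq. (2): (xi - xj)^T D (xi - xj)."""
    diff = [a - b for a, b in zip(xi, xj)]
    m = len(diff)
    return sum(diff[r] * D[r][c] * diff[c] for r in range(m) for c in range(m))


def psd_from_factor(L: Matrix) -> List[List[float]]:
    """Build D = L^T L, which is symmetric positive semidefinite for any real L."""
    m = len(L[0])
    return [[sum(L[k][r] * L[k][c] for k in range(len(L))) for c in range(m)]
            for r in range(m)]


def eq3_terms(D: Matrix, pos: Sequence[Tuple[Vector, Vector]],
              neg: Sequence[Tuple[Vector, Vector]]) -> Tuple[float, bool]:
    """Objective of Eq. (3) over positive pairs, and whether the negative-pair constraint holds."""
    objective = sum(metric_distance(a, b, D) for a, b in pos)
    constraint_ok = sum(metric_distance(a, b, D) for a, b in neg) >= 1.0
    return objective, constraint_ok


if __name__ == "__main__":
    identity = [[1.0, 0.0], [0.0, 1.0]]
    print(metric_distance([1, 2], [0, 0], identity))   # squared Euclidean distance
    D = psd_from_factor([[2.0, 0.0], [0.0, 0.0]])      # weights dim 0, ignores dim 1
    print(D, metric_distance([1, 2], [0, 0], D))
    print(eq3_terms(identity, [([1.0, 0.0], [1.1, 0.0])], [([0.0, 0.0], [1.0, 0.0])]))
```

With *D* set to the identity, Eq. (2) reduces to the squared Euclidean distance; a learned *D* reweights (and can correlate) feature dimensions, as the second example shows by ignoring the second dimension entirely.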

**Figure 10.** Vehicle re-id system based on metric-based methods.

#### *3.4. Fine-Grained Visual Recognition-Based Vehicle Re-Identification*

Vehicle re-id is a fine-grained recognition task, and fine-grained vehicle recognition approaches can be divided into two groups: representation learning models and part-based models. Many proposed approaches [60] use alignment and part localization to extract features of the main parts, which are then compared for vehicle re-id. Xiao et al. [61] studied a weakly supervised approach in the fine-grained domain that uses reinforcement learning to find discriminative vehicle parts. In addition, Lin et al. [62] present a bilinear architecture that captures pairwise interactions of local features, in which the output descriptors of two networks are combined in a translation-invariant way. Boonsim et al. [63] present an approach for fine-grained recognition of vehicles at night: the authors use the shape and lights of a vehicle that are visible at night, and their relative positions, to identify the make and model of a vehicle as seen from the front and rear.
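The bilinear pooling idea behind [62] can be illustrated as follows: at every spatial location, the outer product of the two streams' local feature vectors is computed, and the products are summed over all locations, which discards their positions and makes the descriptor insensitive to where features occur. The signed square root and L2 normalization follow common practice for bilinear CNN descriptors; the toy feature lists stand in for CNN feature maps and are hypothetical.

```python
import math
from typing import List, Sequence


def bilinear_pool(feats_a: Sequence[Sequence[float]],
                  feats_b: Sequence[Sequence[float]]) -> List[float]:
    """Sum-pooled outer products of two streams' features at the same locations."""
    ma, mb = len(feats_a[0]), len(feats_b[0])
    pooled = [0.0] * (ma * mb)
    for fa, fb in zip(feats_a, feats_b):   # one (fa, fb) pair per spatial location
        for i in range(ma):
            for j in range(mb):
                pooled[i * mb + j] += fa[i] * fb[j]
    # Signed square root, then L2 normalization of the flattened descriptor.
    signed = [math.copysign(math.sqrt(abs(v)), v) for v in pooled]
    norm = math.sqrt(sum(v * v for v in signed)) or 1.0
    return [v / norm for v in signed]


if __name__ == "__main__":
    stream_a = [[1.0, 0.0], [0.0, 2.0]]   # 2 locations x 2 channels (hypothetical)
    stream_b = [[3.0], [4.0]]             # 2 locations x 1 channel (hypothetical)
    print(bilinear_pool(stream_a, stream_b))
    print(bilinear_pool(stream_a[::-1], stream_b[::-1]))  # locations reordered
```

Reversing the order of the locations leaves the descriptor unchanged, which is the orderless property that makes bilinear pooling insensitive to where the local features occur.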

In fine-grained recognition, local region features are extracted from distinctive areas such as logos, annual inspection stickers, and decorations; to make the system more efficient and robust, vehicle attributes such as color, model, and type are also incorporated. For example, Figure 11 shows vehicles with similar global appearance, where the vehicles in each column are different and the differences between them are marked with red circles. As Figure 11 shows, the differences between vehicles with similar global appearance lie in some local regions.

**Figure 11.** Vehicles with the same global appearance that are distinguished by local regions (marked with red circles).

#### *3.5. View-Aware-Based Vehicle Re-Identification*

Most of the deep learning features discussed above [38,39,45] are generic, and these learned features end in multiple fully connected layers. Although these approaches perform reasonably well, they are not designed for the specific problem of viewpoint variation, which is a central challenge in vehicle re-id. Vehicle re-id is closely related to person re-id, in which intra-class variation is a major problem because the same person looks different from different viewpoints. Zhao et al. [64] designed a novel body-part-guided approach for person re-id that achieved satisfactory results. Wu et al. [65] proposed a method with a pose prior that made identification efficient and robust to viewpoint changes. Zheng et al. [66] proposed the PoseBox structure, which is generated by pose estimation followed by affine transformations. Viewpoint variation is also challenging and crucial in vehicle re-id, because the viewpoint of the captured images changes as a consequence of the vehicle's rigid motion. Wang et al. [67] studied orientation-invariant feature embedding to reduce the influence of viewpoint variation on vehicle re-id systems. Prokaj et al. [68] proposed a pose estimation-based approach to handle multiple viewpoints. Zhou et al. [69] studied viewpoint uncertainty in vehicle re-id and designed an end-to-end deep learning architecture that combines a concatenated CNN with a bidirectional Long Short-Term Memory (LSTM) loop, taking full advantage of LSTM and CNN to learn different viewpoints of a vehicle. Many more approaches have been proposed to handle viewpoint variation in vehicle re-id, such as the adversarial bidirectional LSTM network (ABLN) [70], and the spatially concatenated convolutional network (SCCN) and CNN-LSTM bidirectional loop (CLBL) [69]. However, all these approaches require vehicle datasets in which every vehicle is imaged from densely sampled camera viewpoints, and such data are hard to obtain in real-world surveillance camera systems. Therefore, there is still ample room to improve vehicle re-id by thoroughly considering viewpoint variations.

#### *3.6. Generative Adversarial Network-Based Vehicle Re-Identification*

GAN [71] is one of the most popular techniques in semi-supervised and unsupervised learning. It was proposed by Goodfellow et al. and derives backpropagation signals through a competitive process involving a pair of networks. GANs have been adopted in many applications, such as style transfer, image synthesis, image super-resolution, semantic image editing, classification, and person/vehicle re-id. The GAN-based vehicle re-id flow is shown in Figure 12. Many papers now adopt GANs to address problems in vehicle re-id. Existing datasets have low diversity and small scale, which leads to poor generalization of the trained models. To address this problem, GAN-based object re-id has become one of the latest research trends among deep learning approaches. GANs have achieved significant performance in many fields, such as translation [72] and image generation [73], and they have recently also been applied to re-id problems (person re-id and vehicle re-id) [74,75]. Zheng et al. [76] used DCGAN [73] with Gaussian noise to generate unlabeled person images before training. Wei et al. [77] proposed PT-GAN to reduce the domain gap by transferring person images between different styles. Zhou et al. [78] proposed a conditional generative network that addresses cross-view vehicle re-id by generating cross-view images from desired vehicle pairs. Lou et al. [74] designed a model that generates same-view and cross-view vehicle images from the original images to facilitate model training.
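For reference, the competitive process mentioned above is usually written as the two-player minimax game of Goodfellow et al. A generator maps noise **z** to images, and a discriminator estimates the probability that an image is real. Calligraphic letters are used here only to distinguish the discriminator from the metric matrix *D* in Eq. (3):

$$\min\_{\mathcal{G}} \max\_{\mathcal{D}} \; \mathbb{E}\_{\mathbf{x} \sim p\_{\text{data}}} \left[\log \mathcal{D}(\mathbf{x})\right] + \mathbb{E}\_{\mathbf{z} \sim p\_{\mathbf{z}}} \left[\log\left(1 - \mathcal{D}(\mathcal{G}(\mathbf{z}))\right)\right]$$

Once the generator is trained, its samples (for example, unlabeled or cross-view vehicle images) can be added to the re-id training data, as in the works cited above.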

**Figure 12.** Diagram of a GAN-based vehicle re-id system.

Aihua et al. [79] proposed a framework that mainly comprises a view transform model and a vehicle re-id model. The view transform model uses a GAN to generate vehicle images in different views, which helps overcome viewpoint-related issues. The vehicle re-id model consists of one backbone, three subnetworks, and one embedding network. The overall framework is illustrated in Figure 13.

**Figure 13.** Overview of deep feature representations guided by the meaningful attributes.

#### *3.7. Attention Mechanism*

Neural networks imitate, to some extent and in a simplified way, the workings of the human brain. The attention mechanism is likewise an attempt to make neural networks concentrate on the things or actions that are relevant to the task while neglecting the others. Researchers are actively designing efficient attention-based neural networks for vision applications such as image classification [80], fine-grained image recognition [81], action recognition [82], and re-id [83]. The common strategy in these approaches is to integrate a hard part-selection subnetwork or a soft mask branch into the deep network. For example, Zhao et al. [84] used a part-localization CNN to predict salient parts and exploited the features of these parts for person re-id. Wang et al. [80] used the residual learning technique [41] to develop the Residual Attention unit for soft mask learning and obtained significant image classification results. However, soft pixel-level attention alone contributes little to vehicle re-id performance. It mainly captures global information, whereas discriminative cues such as the vehicle logo, annual inspection stickers, and personalized decorations are local. Therefore, Guo et al. [85] presented a joint learning framework for vehicle re-id that uses both soft and hard attention. Their model has one trunk and two salient-part branches for hard part-level attention. The trunk branch extracts global features of the vehicle, while the salient branches extract features from the vehicle head and the windscreen. For soft pixel-level attention, residual attention modules are inserted into the trunk and salient branches. Finally, the global and salient-part features are combined into an effective feature representation under the supervision of a multi-grain ranking loss; the complete framework is shown in Figure 14. A comparison of different attention mechanism-based approaches is shown in Table 2.
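In the residual attention unit of Wang et al. [80], trunk features *F*(*x*) are modulated by a soft mask *M*(*x*) with values in [0, 1] as *H*(*x*) = (1 + *M*(*x*)) · *F*(*x*). With this form, the mask can emphasize features without suppressing the identity path. The following is a minimal NumPy sketch of this combination only. It assumes the mask branch outputs logits of the same shape as the trunk features; the convolutional trunk and mask branches are omitted, and the function names are illustrative.

```python
import numpy as np

def sigmoid(z):
    """Squash mask logits into (0, 1)."""
    return 1.0 / (1.0 + np.exp(-z))

def residual_attention(trunk_out, mask_logits):
    """Soft pixel-level attention: H(x) = (1 + M(x)) * F(x), applied element-wise.

    trunk_out:   F(x), e.g. a (C, H, W) feature map from the trunk branch.
    mask_logits: same shape; sigmoid(mask_logits) gives the soft mask M(x).
    """
    mask = sigmoid(mask_logits)
    return (1.0 + mask) * trunk_out
```

A mask near 0 leaves a location's features almost unchanged, while a mask near 1 doubles them. Attention can therefore only re-weight the trunk features, not erase them.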

**Figure 14.** An overview of Two-level Attention network supervised by a Multi-grain Ranking loss (TAMR) structure.



#### *3.8. License Plate-Based Vehicle Re-Identification*

License plate-based vehicle re-id relies on the system's ability to automatically detect, extract, and recognize license plate characters from a vehicle image. License plate recognition (LPR) is a conventional method for identifying a specific vehicle [90]. An automatic LPR system consists of two main parts. The first is license plate detection, and the second is interpreting the license plate image into a machine-readable form. Many approaches to LPR have been proposed in the past. However, LPR remains challenging for several reasons: the vehicle image may not be captured well, some characters may be occluded, and illumination, image size, camera distance, and zoom all vary. Li and Shen [91] studied a sequence labelling-based approach that uses recurrent neural networks (RNNs) to recognize vehicle license plates without character-level segmentation. The feature sequence fed to the RNN is extracted by a nine-layer CNN. Super-resolution has also been proposed to restore license plate images and improve performance. Shi et al. [92] designed the convolutional recurrent neural network (CRNN) for scene text recognition, which integrates feature extraction, sequence modeling, and transcription into a unified framework. Figure 15 shows the basic steps of license plate-based vehicle re-id.
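Segmentation-free recognizers such as CRNN [92] transcribe per-frame predictions with connectionist temporal classification (CTC). The following is a minimal sketch of greedy (best-path) CTC decoding, which takes the most likely class per frame, merges consecutive repeats and drops blanks. It assumes the blank is class 0 and uses a hypothetical plate alphabet. The per-frame probabilities would come from the recurrent layers, which are not shown.

```python
import numpy as np

BLANK = 0  # assumption: class 0 is the CTC blank; class k > 0 maps to alphabet[k - 1]

def ctc_greedy_decode(probs, alphabet):
    """Best-path CTC decoding.

    probs: (T, K) per-frame class probabilities, K = len(alphabet) + 1.
    """
    best = np.argmax(probs, axis=1).tolist()  # most likely class per frame
    chars = []
    prev = None
    for k in best:
        # Collapse consecutive repeats, then drop blanks.
        if k != prev and k != BLANK:
            chars.append(alphabet[k - 1])
        prev = k
    return "".join(chars)
```

The blank is what allows repeated characters: the frame path `A A _ A B B _ 7` decodes to `AAB7`, because the blank separates the two `A` runs.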

**Figure 15.** The flow of the license plate-based vehicle recognition.

#### **4. Spatio-Temporal Cues-Based Vehicle Re-Identification Approaches**

Introducing contextual information into a vehicle re-id system can increase efficiency and remove irrelevant vehicle images from the gallery. Unlike pedestrians, vehicles must follow traffic rules: in practice, they obey speed limits and follow routes and traffic lanes. The time and location at which a vehicle moves between cameras therefore help considerably in vehicle re-id. Spatio-temporal cues have been widely examined for associating objects across surveillance camera networks [93]. The authors of [94] drew a few key findings. First, a vehicle captured by one camera cannot appear at more than one location at the same time. Second, a vehicle moves continuously over time. Based on these findings, the authors used location and time slots to eliminate irrelevant vehicle images from the candidate list, as demonstrated in Figure 16. Ellis et al. [93] proposed an approach that learns temporal and topological transitions from trajectory data acquired from a surveillance camera network. Loy et al. [95] presented a method for obtaining the spatio-temporal topology of a surveillance camera network using multi-camera correlation analysis. Time and location information has also been exploited for the vehicle re-id task. Liu et al. [96] studied a spatio-temporal affinity method for quantifying the similarity of pairs of vehicle images, and Shen et al. [97] introduced spatio-temporal path information for vehicle re-id.
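The two findings of [94] translate directly into a simple gallery filter. The following is a minimal sketch, not the method of [94]. It assumes each observation carries hypothetical planar camera coordinates (km) and a timestamp (hours), and it uses a hypothetical maximum speed. A candidate seen at a different location at the same instant is rejected, and so is any candidate whose implied travel speed is implausible.

```python
import math
from dataclasses import dataclass

@dataclass
class Observation:
    image_id: str
    camera_xy: tuple  # hypothetical planar camera position, in km
    time_h: float     # hypothetical capture time, in hours

def feasible(query, cand, max_speed_kmh=120.0):
    """Finding 1: a vehicle cannot be at two places at the same instant.
    Finding 2: it moves continuously, so the implied speed must be plausible."""
    dist = math.dist(query.camera_xy, cand.camera_xy)  # straight line <= road distance
    dt = abs(query.time_h - cand.time_h)
    if dt == 0.0:
        return dist == 0.0
    return dist / dt <= max_speed_kmh

def filter_gallery(query, gallery, max_speed_kmh=120.0):
    """Keep only gallery observations that are spatio-temporally feasible for the query."""
    return [g for g in gallery if feasible(query, g, max_speed_kmh)]
```

Straight-line distance never exceeds road distance, so this filter underestimates travel speed and errs on the side of keeping candidates.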

**Figure 16.** Depicts the spatio-temporal information.

#### **5. Hybrid Methods-Based Vehicle Re-Identification**

To further enhance the robustness and efficiency of vehicle re-id systems, researchers have proposed approaches that combine two or more different techniques. For instance, Liu et al. [42] proposed a framework named PROVID, which considers not only the visual appearance of the vehicle but also its license plate and spatio-temporal cues, as shown in Figure 17. Jiang et al. [98] studied a vehicle re-id algorithm that uses appearance and contextual information. During training, it learns multiple attributes separately, such as vehicle model, color, and vehicle image features, and it then ranks vehicles on the basis of spatio-temporal cues. Shen et al. [97] designed a two-step architecture. First, for a pair of query vehicle images, visual-spatio-temporal paths are proposed from contextual information with a chain Markov Random Field (MRF) model. Then, the similarity score of the pair is computed.
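One simple way to combine several cues is late fusion of per-cue similarity scores. The following is a generic illustration with hypothetical weights and field names. It is not the specific procedure of PROVID [42] or of [97].

```python
def fuse_and_rank(candidates, w_app=0.6, w_plate=0.3, w_st=0.1):
    """Rank gallery candidates by a weighted sum of per-cue similarities.

    candidates: dict mapping candidate id -> (appearance, plate, spatio-temporal)
    similarities, each assumed to lie in [0, 1]. The weights are hypothetical and
    would normally be tuned on validation data.
    """
    fused = {
        cid: w_app * app + w_plate * plate + w_st * st
        for cid, (app, plate, st) in candidates.items()
    }
    # Highest fused score first.
    return sorted(fused, key=fused.get, reverse=True)
```

In this sketch, a strong plate match and a plausible spatio-temporal cue can lift a candidate above one that only looks more similar.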

**Figure 17.** The architecture of the PROVID framework.

#### **6. Vehicle Re-Identification Benchmark Datasets**

Datasets are key components for measuring the performance of vehicle re-id systems, and they should reflect practical surveillance camera data. Factors such as occlusion, background clutter, and illumination changes cannot be avoided when evaluating an approach [99]. Several benchmark datasets have been prepared by the research community to evaluate vehicle re-id techniques, including the well-known VeRi-776 and VehicleID. Table 3 and Figure 18 list the commonly used vehicle re-id datasets with their attributes. The most popular datasets are briefly described as follows:


**Table 3.** Characteristics of publicly available datasets.

**Figure 18.** Depicts the number of total images per vehicle re-id dataset.

*VeRi-776:* [43] VeRi-776 is a publicly available vehicle re-id dataset that is widely adopted by the computer vision research community. Its images were gathered in real-world scenarios by surveillance cameras, and it contains 50,000 images of 776 different vehicles. Each vehicle is captured from 2 to 18 viewpoints with different resolutions, occlusions, and illumination. Spatio-temporal relations and license plates are annotated for all vehicles. To make the dataset more comprehensive, images are also labelled with color, type, and vehicle model. Figure 19 shows various types of vehicles from the VeRi dataset.

**Figure 19.** Depicts the sample images of VeRi-776 dataset.

*PKU VehicleID:* [44] The VehicleID dataset was developed by Peking University at the National Engineering Laboratory for Video Technology (NELVT), with funding from the National Natural Science Foundation of China and the National Basic Research Program of China. It consists of 221,763 images of 26,267 vehicles, all captured during the daytime by multiple surveillance cameras in a small town in China. Model information (e.g., "Audi A6L", "MINI-cooper", and "BMW 1 Series") is manually labeled for 10,319 vehicles. Figure 20 shows different vehicles from the PKU VehicleID dataset.

**Figure 20.** Depicts the sample images of PKU VehicleID dataset.

*Vehicle-1M:* [100] The Vehicle-1M dataset was developed by the University of Chinese Academy of Sciences at the National Laboratory of Pattern Recognition, Institute of Automation. This benchmark contains 936,051 images of 55,527 vehicles covering 400 different vehicle models. All images were captured by surveillance cameras in a town in China during the day and at night, and each shows either the rear or the head (front) view of a vehicle. Each image is labeled with the vehicle's make, model, and year. Images from Vehicle-1M are shown in Figure 21.

**Figure 21.** Depicts the sample images of vehicle-1M dataset.

*BoxCars116k:* [35] The BoxCars116k dataset was collected with 37 surveillance cameras and consists of 116,286 images of 27,496 vehicles from 45 vehicle brands. The vehicles are captured from arbitrary viewpoints (side, back, front, and roof), and all images are annotated with a 3D bounding box, make, model, and type. Sample images are shown in Figure 22.

**Figure 22.** Depicts the sample images of BoxCar21k dataset.

*VehicleReId:* [53] The VehicleReId dataset provides 47,123 vehicle images extracted from five different video shots recorded by two surveillance cameras. Among these images, 24,530 vehicle image pairs are human annotated.

*CompCars:* [101] The CompCars dataset contains two kinds of images: (1) web-nature images and (2) surveillance-nature images. There are 136,726 web-nature images covering 163 car makes and 1716 car models. The surveillance-nature part contains 50,000 car images captured from the front view. Samples of the CompCars dataset are shown in Figure 23.

**Figure 23.** Depicts the sample images of CompCars dataset.

*VRIC:* [102] VRIC contains 60,430 images of 5622 vehicles, captured by different road traffic surveillance cameras during the day and at night. Images from the VRIC dataset with different angles, viewpoints, occlusions, and illumination are depicted in Figure 24.

**Figure 24.** Depicts the sample images of VRIC dataset.

*VRID:* [103] This dataset was developed specifically for vehicle re-id and contains 10,000 images. They were captured by 326 surveillance cameras from 7 a.m. to 5 p.m. over one week. The dataset covers 1000 vehicles of 10 commonly used vehicle models, and each vehicle is captured at least 10 times by a camera network in Guangdong, China. The surveillance cameras are installed in a practical environment with arbitrary directions and angles. As a result, the images vary in pose, and their resolutions range from 400 × 424 to 990 × 1134 pixels.

*VERI-Wild:* [104] VERI-Wild is a large-scale vehicle re-id dataset collected in an unconstrained environment. It was built from an existing large CCTV system of 174 cameras, recorded over one month (30 × 24 h). The cameras are spread over a large urban area covering 200 km². The raw data comprise 12 million vehicle images, which 11 volunteers cleaned over one month. After data cleaning and annotation, 416,314 vehicle images of 40,671 identities were collected. VERI-Wild images with viewpoint changes, illumination variations, occlusion, and background variations are presented in Figure 25, and the dataset statistics are shown in Figure 26.

**Figure 25.** Depicts the sample images of VERI-Wild dataset.

**Figure 26.** Illustrates the characteristics of the VERI-Wild dataset. (**a**) The number of identities across multiple surveillance cameras; (**b**) Total number of IDs captured in different time slots of each day; (**c**) Distribution of vehicle types; (**d**) Distribution of vehicle colors.

#### **7. Challenges Regarding Vehicle Re-Identification**

Vehicle re-id is an essential and challenging task. It is defined as determining whether a specific vehicle captured by one camera has already appeared elsewhere in a multi-camera network. With the growing need for automated video analysis, vehicle re-id is receiving increasing attention in the computer vision research community. Some key factors and their effects on performance are explained below.


**Figure 27.** Demonstration of two main challenges in vehicle re-id. (**a**) Intra-class variance; (**b**) Inter-class similarity.

**Figure 28.** Images of the same vehicle taken from different cameras to illustrate the appearance changes.


#### **8. Evaluation Metrics**

In re-id, images of the target object are usually cropped and aligned, so vehicle re-id can be treated as an instance retrieval task. Given a query image, the gallery images of the same vehicle should be placed at the top of the ranking list. To measure the performance of vehicle re-id approaches, researchers commonly use the cumulative matching characteristic (CMC) curve, HIT@1, and HIT@5. The CMC curve gives the probability that the identity of a query image appears among the top-ranked candidates, for candidate lists of different sizes, as shown in Figure 29. It plots the cumulative number of correctly matched queries against the rank at which they are re-identified. HIT@1 is the precision at rank 1, and HIT@5 is the precision at rank 5. A higher matching rate at a given rank indicates better system performance. Let *q*(*r*) be the number of query images correctly re-identified at rank *r*. The CMC value at rank *i* is then defined as:

$$\text{CMC}(i) = \sum\_{r=1}^{i} q(r) \tag{4}$$

where *r* is the rank index. The CMC curve reports not only the rank-1 accuracy but also how often correct matches are placed within the top ranks. This makes it a suitable way to compare the re-id performance of different approaches. When the gallery set contains multiple ground-truth matches for each query image, the mean average precision (mAP) is used in addition to the CMC curve to measure the overall performance of a vehicle re-id system. For a given query image, the average precision (*AP*) is defined as:

$$AP = \frac{\sum\_{k=1}^{n} P(k) \times G(k)}{N\_{gt}} \tag{5}$$

where *n* is the number of test tracks (i.e., the length of the ranked list) and *N*<sub>gt</sub> is the number of ground-truth matches. *P*(*k*) is the precision at cut-off *k* in the ranked list, and *G*(*k*) equals 1 if the *k*-th match is correct and 0 otherwise. The *mAP* averages *AP* over all queries and measures the overall performance of a vehicle re-id system:

$$mAP = \frac{\sum\_{q=1}^{Q} AP(q)}{Q} \tag{6}$$

where *Q* denotes the number of queries.
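The metrics in Eqs. (4)–(6) can be computed from ranked gallery lists as follows. This is a minimal sketch with illustrative function names, and it makes the following assumptions:

- Each query's gallery list is already sorted by decreasing similarity, and identities are compared by label.
- Every ground-truth match appears in the list, so *N*<sub>gt</sub> is the number of true matches in the list.
- The CMC counts of Eq. (4) are divided by the number of queries to give matching rates, so entries 1 and 5 of the returned curve correspond to HIT@1 and HIT@5.

```python
import numpy as np

def cmc_and_ap(ranked_ids, query_id):
    """Return (rank of first correct match or None, AP) for one query, per Eq. (5)."""
    matches = np.array([gid == query_id for gid in ranked_ids], dtype=float)  # G(k)
    n_gt = matches.sum()  # assumes every ground truth appears in the ranked list
    if n_gt == 0:
        return None, 0.0
    first = int(np.argmax(matches)) + 1  # 1-based rank of the first true match
    precision = np.cumsum(matches) / np.arange(1, len(matches) + 1)  # P(k)
    ap = float((precision * matches).sum() / n_gt)
    return first, ap

def evaluate(rankings, query_ids, max_rank=5):
    """Return the normalised CMC curve up to max_rank (Eq. 4 / Q) and mAP (Eq. 6)."""
    Q = len(query_ids)
    q = np.zeros(max_rank)  # q[r - 1]: queries whose first true match is at rank r
    aps = []
    for ranked, qid in zip(rankings, query_ids):
        first, ap = cmc_and_ap(ranked, qid)
        if first is not None and first <= max_rank:
            q[first - 1] += 1
        aps.append(ap)
    cmc = np.cumsum(q) / Q  # cumulative sum of Eq. (4), as a fraction of queries
    return cmc, float(np.mean(aps))
```

For two queries whose true matches sit at ranks {1, 3} and {2, 4}, the per-query *AP* values are 5/6 and 1/2. HIT@1 is 0.5 and HIT@5 is 1.0.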

**Figure 29.** Cumulative matching characteristics (CMC) curve.

The performance of vehicle re-id techniques can also be evaluated with a confusion matrix, whose numbers of rows and columns equal the number of classes. Its diagonal entries represent correct classifications, and its off-diagonal entries represent misclassifications.
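The following minimal sketch builds a confusion matrix with rows as true classes and columns as predicted classes, and reads the overall accuracy off its diagonal. It assumes integer class labels from 0 to num_classes − 1, and the function names are illustrative.

```python
import numpy as np

def confusion_matrix(y_true, y_pred, num_classes):
    """Rows: true class; columns: predicted class."""
    cm = np.zeros((num_classes, num_classes), dtype=int)
    for t, p in zip(y_true, y_pred):
        cm[t, p] += 1
    return cm

def accuracy_from_cm(cm):
    """Correct classifications (diagonal) divided by all samples."""
    return np.trace(cm) / cm.sum()
```

Each off-diagonal entry shows which class was mistaken for which, which is more informative than accuracy alone.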

#### **9. Performance Comparison of Recently Proposed State-of-the-Art Approaches**

The first phase in vehicle re-id is to decide whether the given vehicle image exists in the gallery set. In other words, before searching for a similar match, a vehicle re-id system should be able to decide whether the given probe image belongs to the gallery set at all. This is known as novelty detection, and it requires vehicle re-id systems to be able to discard mismatched vehicle images. Usually, once the gallery images have been ranked against the query image, the query is considered to belong to the gallery set if its similarity score exceeds an operating threshold. We summarize the mAP of some state-of-the-art methods on the VeRi-776 dataset described above, including CMGN+Pre+Track [111], DF-CVTC [79], PROVID [42], and RAM [112]. We chose VeRi-776 for comparison because it contains varying illumination, more viewpoints, and varying resolutions. In short, it covers most aspects of real-world camera surveillance data. The statistics of this dataset are provided in Table 3.
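The following minimal sketch shows this accept-or-reject decision. It assumes cosine similarity between feature vectors and a hypothetical operating threshold of 0.7; in practice, the threshold would be chosen on validation data. The function names are illustrative.

```python
import numpy as np

def cosine_similarity(a, b):
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))

def match_or_reject(query, gallery, threshold=0.7):
    """Return (index, score) of the best gallery match, or (None, score) when the best
    score is below the operating threshold, i.e. the query is treated as novel."""
    scores = [cosine_similarity(query, g) for g in gallery]
    best = int(np.argmax(scores))
    if scores[best] >= threshold:
        return best, scores[best]
    return None, scores[best]
```

Raising the threshold rejects more probes. This reduces false matches but also discards more genuine ones, so the operating point trades one error against the other.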

Table 4 lists recently proposed state-of-the-art approaches on the VeRi-776 dataset. For comparison, we report the performance of each method in terms of mAP, HIT@1, and HIT@5. Table 4 and Figure 30 show that the mAP of the models increased from 2016 to 2020. On VeRi-776, the performance of state-of-the-art methods improved from 12.76% to 85.20% over this period, an increase of 72.44 percentage points. Moreover, Figure 31 shows the CMC curves of different state-of-the-art approaches on the VehicleID dataset for different test sizes.


**Table 4.** Performance analysis of some state-of-the-art approaches on VeRi-776.



**Figure 30.** Demonstrates the performance comparison of different state of the art approaches.

**Figure 31.** Demonstrates the performance comparison of different state-of-the-art approaches on VehicleID dataset. (**a**) Test size = 800; (**b**) Test size = 1600; (**c**) Test size = 2400.

#### **10. Conclusions & Way Forward**

Vehicle re-id is one of the most critical and challenging areas in ITS. Despite its significance, it is less explored than the closely related problem of person re-id. In this review paper, the authors present recent advances in vehicle re-id. To draw a detailed picture of the field, the authors discuss different vehicle re-id technologies, especially vision-based ones, including appearance-, license plate-, and spatio-temporal-based methods. This discussion is accompanied by a quantitative and qualitative comparison of different vision-based methods on the VeRi-776 and VehicleID datasets. In addition, this review provides comprehensive synopses of the publicly available benchmark datasets used for performance evaluation, with a brief description of re-id evaluation techniques. This paper also presents applications and the main challenges of camera-based vehicle re-id, all of which affect performance. These challenges include complex and unconstrained environments, dirt, snow, occluded or blurry images, and sunshine, as well as varied road topology.

Many aspects of vehicle re-id can still be improved, and future work can explore ways to enhance its overall performance. In particular, there is significant potential to extend current approaches with the following concepts:

Firstly, CNNs capture edges, shapes, and other basic vehicle features, but they do not consider the relationships between these features. As a result, model performance is often unsatisfactory when the vehicle image is rotated or captured at a different orientation. The recently introduced capsule network [133] has shown improved performance in handling different poses, orientations, and occluded objects.

Secondly, attention-based deep neural network models have achieved encouraging results on various challenging tasks, including machine translation [134], caption generation [135], and object recognition [136]. However, attention-based neural network models are still not well investigated for vehicle re-id.

Lastly, the development of large-scale real-world datasets has significantly increased the performance of vehicle re-id systems. However, existing datasets offer a limited range of vehicle images with correlated data. Training on them causes over-fitting, because parameters become over-tuned to specific data, and as a result the systems do not generalize well to other data. Future work could therefore develop large-scale, real-world surveillance vehicle datasets captured in unconstrained environments with multiple views, to improve the training and performance of state-of-the-art approaches.

In short, vehicle re-id is a demanding and challenging area with massive opportunities for improvement and research. This review paper attempts to provide an overview of the vehicle re-id problem, its challenges, and applications, and simultaneously to present a way forward. We hope this paper will be valuable for anyone who wants to work in this area.

**Author Contributions:** Conceptualization, Z. and M.S.K.; methodology, Z., Y.H. and R.K.; software, J.D., and J.K.; writing—review and editing, M.U.A.; supervision, J.D. and J.C.; funding acquisition J.C. All authors have read and agreed to the published version of the manuscript.

**Funding:** This paper has been supported by the National Key Research and Development Program of China (2017YFC0821505), the Funding of Zhongyanggaoxiao (ZYGX2018J075), and the Sichuan Science and Technology Program (2019YFS0487).

**Institutional Review Board Statement:** Not applicable.

**Informed Consent Statement:** Not applicable.

**Data Availability Statement:** Not applicable.

**Acknowledgments:** I am grateful to my worthy supervisor as well as all the lab mates for their endless support.

**Conflicts of Interest:** The authors declare that they have no known competing financial interests or personal relationships that could have appeared to influence the work reported in this paper.

#### **References**


### *Article* **Cascaded Cross-Layer Fusion Network for Pedestrian Detection**

**Zhifeng Ding <sup>1</sup>, Zichen Gu <sup>2,\*</sup>, Yanpeng Sun <sup>1</sup> and Xinguang Xiang <sup>1</sup>**


**Abstract:** Anchor-free detection methods not only reduce the training cost of object detection but also avoid the imbalance problem caused by an excessive number of anchors. However, these methods focus only on the impact of the detection head on detection performance and ignore the impact of feature fusion. In this article, we take pedestrian detection as an example and propose an anchor-free, one-stage network, the Cascaded Cross-layer Fusion Network (CCFNet). It consists of a Cascaded Cross-layer Fusion module (CCF) and a novel detection head. CCF fully considers the distribution of high-level and low-level information in the feature maps at different stages of the network. First, deep features are used to remove a large amount of noise from the shallow features. Then, the high-level features are reused to obtain a more complete feature representation. Secondly, a novel detection head is designed for the pedestrian detection task. It uses a global smooth map (GSMap) to provide global information for the center map, which yields a more accurate center map. Finally, we verify the feasibility of CCFNet on the Caltech and CityPersons datasets.

**Keywords:** pedestrian detection; machine learning; end-to-end; anchor-free; feature reuse

#### **1. Introduction**

Pedestrian detection is a crucial but challenging task in computer vision and multimedia, and it has been applied in various fields. The goal of pedestrian detection is to find all pedestrians in images and videos. Early detection methods [1–6] show that directly using the output features of the backbone is not conducive to detecting small objects in the image. Recent detection methods show that obtaining high-resolution, high-quality feature representations is key to improving detection results. It is well known that the low-level features of the backbone contain accurate information about small objects, while the high-level features contain accurate information about large objects. Therefore, how to integrate the features of different stages more effectively has been a focus of pedestrian detection research in recent years.

According to how features are used for detection, we divide feature fusion methods into FPN-like (Feature Pyramid Network-like) methods and FCN-like (Fully Convolutional Network-like) methods. The specific difference is that FPN-like methods detect objects on features of different scales separately. FCN-like methods, by contrast, detect only on the final feature obtained by fusing features of different scales. The basic idea of FPN-like methods was proposed by the Single Shot MultiBox Detector (SSD) [2], whose main process is to detect objects in feature maps at different resolutions. However, SSD ignores the spatial information in the shallow feature maps and thus loses the information of small objects in the shallow features. To improve the recognition of small objects, Feature Pyramid Networks (FPN) [7] combine high-level feature maps, which carry strong semantic information, with low-level feature maps, which carry weak semantic information but rich spatial information. Several recent works have proposed FPN-like methods [8–14] to integrate features of different scales more effectively. However, these methods mainly focus on the features of adjacent stages in the feature fusion process. As a result, the deep features containing rich semantic information gradually weaken during the top-down process. High-level semantic information is therefore lost when detecting on shallow features, so small objects in the image cannot be detected effectively.

**Citation:** Ding, Z.; Gu, Z.; Sun, Y.; Xiang, X. Cascaded Cross-Layer Fusion Network for Pedestrian Detection. *Mathematics* **2022**, *10*, 139. https://doi.org/10.3390/math10010139

Academic Editors: Radu Tudor Ionescu, Jianping Gou, Weihua Ou, Shaoning Zeng and Lan Du

Received: 25 November 2021 Accepted: 22 December 2021 Published: 4 January 2022

**Publisher's Note:** MDPI stays neutral with regard to jurisdictional claims in published maps and institutional affiliations.

**Copyright:** © 2022 by the authors. Licensee MDPI, Basel, Switzerland. This article is an open access article distributed under the terms and conditions of the Creative Commons Attribution (CC BY) license (https://creativecommons.org/licenses/by/4.0/).

To avoid the shortcomings of FPN-like methods, some methods directly fuse features of different scales and then detect only on the fused features. This type of method originates from Fully Convolutional Networks (FCN) [15], which combine the features of different stages to obtain feature maps containing semantic information at different scales. In this paper, structures similar to FCN are collectively referred to as FCN-like methods [15–24]. Compared with FPN-like methods, FCN-like methods have lower computational complexity and faster computation. They also avoid the situation in which small objects cannot be detected because high-level semantic information has been lost. However, these methods assign the same weight to features of different scales during feature fusion. In this case, the noise in the shallow features directly affects the accuracy of the final feature. The earlier Semantic Structure Aware Inference (SSA) work [25] showed that small-object information is not confined to the shallow features, because the deep features also contain a small amount of it. However, the shallow features contain a large amount of noise. Reducing the impact of this noise on detection accuracy is a problem that current FCN-like methods have not solved.
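To make the FCN-like idea concrete, the following generic sketch fuses multi-stage feature maps by upsampling each one to the shallowest resolution and concatenating them with equal weight. It assumes nearest-neighbour upsampling and integer stride ratios between stages. Real networks often use learned deconvolutions and convolutions instead, and this is not the CCF module.

```python
import numpy as np

def upsample_nearest(x, factor):
    """x: (C, H, W) -> (C, H*factor, W*factor) by nearest-neighbour repetition."""
    return x.repeat(factor, axis=1).repeat(factor, axis=2)

def fcn_like_fuse(features):
    """Fuse multi-stage features with equal weights.

    features: list of (C_i, H_i, W_i) maps, shallowest first; each deeper map is assumed
    to be an integer fraction of the shallowest resolution in both H and W.
    """
    H = features[0].shape[1]
    up = [upsample_nearest(f, H // f.shape[1]) for f in features]
    return np.concatenate(up, axis=0)  # channels stacked; no per-scale weighting
```

The plain concatenation treats every stage alike, so the noise in the shallow map reaches the fused feature unfiltered. This is the weakness of FCN-like fusion that the paragraph above describes.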

Toward this end, this work takes pedestrian detection as an example and proposes a novel Cascaded Cross-layer Fusion Network (CCFNet), which consists of a backbone network, a Cascaded Cross-layer Fusion (CCF) module, and a novel detection head. The overall framework is shown in Figure 1. The CCF module first merges the features from different backbone stages into a final feature map, on which detection is then performed. Unlike previous methods, CCF uses deep features to denoise shallow features and then reuses the deep features to enrich the semantic information in the final feature map. To improve running speed, CCFNet adopts an anchor-free design based on pedestrian center-point detection: it generates no anchor points or anchor boxes and does not need to match multiple key points. In the detection head, we introduce a center map and a global smooth map (GSMap) to reduce the impact of complex scenes and object crowding on detection performance. Traditional anchor-free detection heads rely on the scale map alone to determine both *where* an object is and *how large* it is, which makes the detector harder to train. We therefore introduce the center map to handle *where the object is*, so that the scale map only needs to handle *how large the object is*. Because the center map is produced by convolution, it is inferred from local features only, and this limited context constrains its accuracy; we therefore introduce the GSMap to provide global information to the center map, as shown in the detection head in Figure 1. Extensive experiments on the Caltech and CityPersons datasets demonstrate the superior pedestrian detection performance of CCFNet compared with state-of-the-art methods.

The main contributions of this work are summarized as follows:


**Figure 1.** The overall structure of the Cascaded Cross-layer Fusion Network (CCFNet), which includes two parts: the CCF module and the detection head. CCF cascades and reuses features to generate a low-level feature map with contextual semantic information. From this feature map, the detection head generates a center map, a scale map, and a global smooth map. The center map and the global smooth map are then integrated into a new center map with global information. Finally, objects are located and marked.

#### **2. Related Work**

#### *2.1. Anchor-Based and Anchor-Free*

Object detection models can be divided into anchor-based and anchor-free detection networks. Anchor-based detection networks use anchor points and anchor boxes to generate high-quality candidate regions and then classify and regress these regions; they achieve high accuracy and can extract richer features. Examples include Faster Regions with CNN Features (Faster R-CNN) [1], Cascade Regions with CNN Features (Cascade R-CNN) [26], SSD [2], and You Only Look Once Version 2 (YOLOv2) [27]. However, anchor-based networks require manual tuning of the number of anchor points and the aspect ratios of the anchor boxes, which introduces many hyperparameters and limits flexibility.

This has motivated detection methods that rely on neither anchor points nor anchor boxes, known as anchor-free detection networks. They fall into two types: those based on key points and those based on the object center. The former generate a bounding box from a set of predefined or self-learned key points (usually corner points of the bounding box), e.g., CornerNet-Lite [28] and ExtremeNet [29]. The latter locate the object by computing the distances from the object center to the four sides of the bounding box, e.g., Center and Scale Prediction (CSP) [23] and CenterNet [30]. Center-based anchor-free networks resemble anchor-based networks but do not need to generate a large number of anchor points to predict bounding boxes, which improves detection speed. Zhang et al. [31] showed that the fundamental difference in performance between anchor-based and anchor-free detectors lies in how positive and negative training samples are defined. Accordingly, CCFNet is built on an anchor-free structure, and it matches or even exceeds the accuracy of anchor-based detection networks.
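To make the two anchor-free parameterizations concrete, the sketch below decodes a box from each: a pair of corner key points, and a center point with distances to the four sides. It is a plain-Python illustration that assumes boxes are given as (x1, y1, x2, y2) in pixels; the function names are hypothetical and are not taken from CornerNet-Lite, ExtremeNet, CSP, or CenterNet.

```python
from typing import Tuple

Point = Tuple[float, float]
Box = Tuple[float, float, float, float]  # (x1, y1, x2, y2) in pixels


def box_from_corners(top_left: Point, bottom_right: Point) -> Box:
    """Key-point style: the box is spanned by two corner points of the same object."""
    (x1, y1), (x2, y2) = top_left, bottom_right
    return (x1, y1, x2, y2)


def box_from_center_ltrb(center: Point, left: float, top: float,
                         right: float, bottom: float) -> Box:
    """Center style: the box is recovered from its center and the distances to its four sides."""
    cx, cy = center
    return (cx - left, cy - top, cx + right, cy + bottom)


print(box_from_center_ltrb((100, 200), 20, 60, 20, 60))  # (80, 140, 120, 260)
print(box_from_corners((80, 140), (120, 260)))            # (80, 140, 120, 260)
```

Both calls describe the same box; the difference is that the key-point variant must first decide which detected corners belong to the same object, a matching step the center-based variant avoids.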

#### *2.2. FPN-like Methods*

The main idea of FPN [7] is to build a top-down feature pyramid that fuses feature maps from different backbone stages and to detect objects of different sizes on feature maps of different scales. This idea has been adopted by many models. You Only Look Once Version 3 (YOLOv3) [8] obtains multi-scale information through multiple convolutions and repeated fusion of the features from the last three backbone stages. Adaptively Spatial Feature Fusion (ASFF) [9] adds an attention structure to YOLOv3 that controls how much the features of other stages contribute to the current feature, enabling selective use of feature information from different stages. Bi-Directional Feature Pyramid Network (BiFPN) [11] adaptively controls the size of the FPN by stacking efficient FPN blocks multiple times. Recursive Feature Pyramid Network (Recursive-FPN) [12] feeds the fused multi-scale feature maps back into the backbone to extract features again, achieving highly competitive performance. Multi-level Feature Pyramid Network (MLFPN) [13] proposes three modules, the Feature Fusion Module (FFM1), the Thinned U-shape Module (TUM), and the Scale-wise Feature Aggregation Module (SFAM), which integrate semantic and detailed information by stacking feature maps multiple times. However, FPN-like methods must not only fuse feature maps multiple times but also build detection heads on feature maps of several output sizes to handle objects of different sizes. As a result, they suffer from complex models and slow computation.

#### *2.3. FCN-like Methods*

As anchor-free detection networks have gained attention, the FCN-like idea has gradually shifted from segmentation to object detection. Unlike FPN-like methods, FCN-like methods output to the detection head a single feature map that integrates information from different scales. FCN [15] uses deconvolution layers to upsample the last-stage feature map of the backbone back to the input image size, preserving the spatial information of the input so that each pixel of the feature map can be classified. In contrast, [24] adopts a fully symmetric structure: it uses deconvolution to restore the image size and concatenates feature information of different scales along the channel dimension of the feature maps. However, it has few parameters and is therefore unsuitable for large-scale detection or segmentation tasks. CornerNet [21] and CSP [23] use FCN to generate feature maps suited to their detection heads. FCN-like methods compute quickly, but feature maps of different scales carry different information. If two feature layers with a large semantic gap are mixed through dimensionality reduction, much feature information is lost, and small objects in the image may be missed.

Unlike the methods above, CCF combines the advantages of FPN-like and FCN-like methods and, through feature reorganization, retains more low-level detailed information and high-level semantic information. In addition, CCFNet introduces a global smooth map that enhances the global perception of the center map to address object occlusion.

#### **3. Methods**

This section describes the proposed Cascaded Cross-layer Fusion Network (CCFNet) for pedestrian detection, which exploits feature fusion and global dependencies.

#### *3.1. Detection Network*

An object detection network is usually divided into a backbone network, a neck, and a detection head. The backbone extracts features from the image, and high-quality features significantly improve object localization. The neck connects the backbone and the detection head: it integrates the features obtained by the backbone and feeds the integrated features to the detection head. A high-quality neck integrates the high-level and low-level information of the image more fully, improving the representation ability of the model. The detection head performs classification and regression.

Most backbone networks [32–36] can be divided into five stages, and each successive stage halves the resolution of the feature map. Consequently, the feature map of the last stage is 1/32 the size of the input image, which is unfavorable for small objects. Previous work [37,38] proposed keeping the fifth-stage feature map at 1/16 of the input size, which preserves more detailed information in the deep feature map and improves the detection of small objects.
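The resolution arithmetic can be checked with a few lines of Python. This is a sketch that assumes the usual ResNet-style layout in which stage *n* has stride 2<sup>*n*</sup> (2, 4, 8, 16, 32); the `last_stage_stride` parameter is a hypothetical knob that models the 1/16 variant of [37,38]. The example uses 1024 × 2048, the CityPersons image resolution.

```python
def stage_sizes(height, width, last_stage_stride=32):
    """Return the (H, W) of the feature map produced by each of the five backbone stages.

    Assumes stage strides of 2, 4, 8, 16 and `last_stage_stride` (32 in a standard
    backbone, 16 when the fifth stage keeps 1/16 resolution as in [37,38]).
    """
    strides = [2, 4, 8, 16, last_stage_stride]
    return [(height // s, width // s) for s in strides]


print(stage_sizes(1024, 2048))
# [(512, 1024), (256, 512), (128, 256), (64, 128), (32, 64)]
print(stage_sizes(1024, 2048, last_stage_stride=16))
# [(512, 1024), (256, 512), (128, 256), (64, 128), (64, 128)]
```

With the 1/16 variant, the fourth and fifth stages share a resolution, which is what allows their features to be added directly without resampling.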

The input image *I* ∈ ℝ<sup>3×*H*×*W*</sup> passes through the stages of the backbone to produce a set of feature maps *F* = {*F*<sub>1</sub>, *F*<sub>2</sub>, *F*<sub>3</sub>, *F*<sub>4</sub>, *F*<sub>5</sub>}. Low-level feature maps from earlier stages contain more detailed information but also a lot of noise, whereas high-level feature maps from later stages contain more semantic information. The neck [13,19,39] reprocesses the feature map set *F* to obtain a feature map *f*<sub>det</sub> suitable for the detection head. The detection head [1,40,41] classifies and locates objects on *f*<sub>det</sub>. In an anchor-free detection network, the detection head is defined as *F*<sub>det</sub> = {*cls*(*f*<sub>det</sub>), *regr*(*f*<sub>det</sub>)}, where *cls*(·) is the classification branch that classifies objects by key points and *regr*(·) is the regression branch that locates objects by scale.

#### *3.2. Cascaded Cross-Layer Fusion Module*

Combining the advantages of FPN-like and FCN-like methods, we propose the Cascaded Cross-layer Fusion module (CCF) to extract object feature information more effectively. CCF uses deconvolution to rescale the deep feature maps so that they can be fused with the shallow feature maps. CCF passes the deep features to the shallow features in a top-down manner, enriching the shallow features while removing noise. However, the semantic information contained in the deep feature maps is gradually lost during this transfer, so CCF supplements the missing semantic information by reusing the deep feature maps. In this way, the final feature map retains both the detailed information of the shallow feature maps and the semantic information of the deep feature maps. Following [23,37], the final feature map of CCF has size [*H*/4, *W*/4], the same as the feature map of the second stage. The implementation is as follows:

As shown in Figure 2, because the feature maps of the fourth and fifth backbone stages contain rich semantic information, CCF uses *F*<sub>4</sub> and *F*<sub>5</sub> as the source that delivers deep semantic information and denoises the shallow feature maps. To reduce computational complexity, the channel dimensions of *F*<sub>4</sub> and *F*<sub>5</sub> are first reduced by 1 × 1 convolutions, giving *F*<sub>c4</sub> and *F*<sub>c5</sub>. *F*<sub>c4</sub> and *F*<sub>c5</sub> are then fused into the feature map *F*<sub>s4</sub>, which retains the semantic information of *F*<sub>4</sub> and *F*<sub>5</sub> and carries it forward to the subsequent stages. The fusion is expressed as:

$$F\_{s4} = \operatorname{Sum}(F\_{c4}, F\_{c5}) \tag{1}$$

where *Sum*(·) denotes element-wise addition of the feature maps *F*<sub>c4</sub> and *F*<sub>c5</sub>.
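Equation (1), together with the 1 × 1 channel reductions that precede it, can be sketched in NumPy as below. A 1 × 1 convolution is simply a per-pixel linear map over channels, which `np.einsum` expresses directly. The input channel counts (1024 and 2048) are those of ResNet-50's fourth and fifth stages; the reduced width of 256 and the 16 × 32 grid are hypothetical choices for illustration, since the text does not state them.

```python
import numpy as np


def conv1x1(x, weight, bias=None):
    """1 x 1 convolution as a per-pixel linear map over channels.

    x: (C_in, H, W) feature map; weight: (C_out, C_in); bias: optional (C_out,).
    Returns a (C_out, H, W) feature map with the same spatial size as x.
    """
    y = np.einsum("oc,chw->ohw", weight, x)
    if bias is not None:
        y = y + bias[:, None, None]
    return y


def fuse_sum(a, b):
    """Eq. (1): Sum(a, b), i.e. element-wise addition of two equally shaped maps."""
    if a.shape != b.shape:
        raise ValueError(f"shape mismatch: {a.shape} vs {b.shape}")
    return a + b


rng = np.random.default_rng(0)
F4 = rng.standard_normal((1024, 16, 32))      # stage-4 features (stride 16)
F5 = rng.standard_normal((2048, 16, 32))      # stage-5 features (stride 16)
W4 = rng.standard_normal((256, 1024)) * 0.01  # hypothetical 1x1 reduction weights
W5 = rng.standard_normal((256, 2048)) * 0.01

Fc4, Fc5 = conv1x1(F4, W4), conv1x1(F5, W5)
Fs4 = fuse_sum(Fc4, Fc5)
print(Fs4.shape)  # (256, 16, 32)
```

Because *Sum*(·) has no learnable weights, both inputs must already agree in channel count and spatial size; the 1 × 1 reductions take care of the channels, and the stride-16 fifth stage takes care of the spatial size.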

The feature map *F*<sub>s4</sub> serves two purposes. (1) As a new source, *F*<sub>s4</sub> is fused with the new receiver *F*<sub>3</sub> to continue conveying semantic information from the deep feature maps. Because only the last two backbone stages output feature maps of the same size, a deeper map must be upsampled by deconvolution to the size of the shallower one before fusion. *F*<sub>s4</sub> is therefore upsampled by deconvolution to obtain a feature map *F*<sub>sd4</sub> of the same size as *F*<sub>c3</sub>:

$$F\_{sd4} = \operatorname{DC}(F\_{s4}) \tag{2}$$

where *DC*(·) denotes a 4 × 4 deconvolution. *F*<sub>sd4</sub> then serves as the new source and is fused, according to Equation (1), with *F*<sub>c3</sub>, the channel-reduced version of *F*<sub>3</sub>, to obtain *F*<sub>s3</sub>. *F*<sub>s3</sub> carries the semantic and detailed information of *F*<sub>3</sub>, *F*<sub>4</sub>, and *F*<sub>5</sub>. (2) As noted above, the semantic information of the deep feature maps is gradually lost during the transfer in purpose (1), so *F*<sub>sd4</sub> is also transformed into a feature map *F*<sub>d4</sub> of size [*H*/4, *W*/4] for feature reuse (Equation (2)). *F*<sub>d4</sub> preserves the representation of the deep feature maps.

To keep transmitting the semantic information of the deep feature maps while retaining the detailed information of *F*<sub>3</sub>, *F*<sub>s3</sub> is upsampled by deconvolution to the size of *F*<sub>2</sub>, and the resulting *F*<sub>sd3</sub> is used in the subsequent operations (Equation (2)).
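The deconvolution *DC*(·) of Equation (2) can be sketched as an explicit transposed convolution in NumPy. The text specifies only the 4 × 4 kernel; the stride of 2 and padding of 1 used here are assumptions, chosen because they exactly double the spatial size, which is what each upsampling step in the cascade requires. The loops favor clarity over speed, and the channel count and weights are hypothetical.

```python
import numpy as np


def deconv4x4(x, weight, stride=2, padding=1):
    """Transposed convolution ("deconvolution") with a 4 x 4 kernel.

    x: (C_in, H, W) feature map.
    weight: (C_in, C_out, 4, 4), the layout commonly used for transposed convolutions.
    With the assumed stride=2 and padding=1 the output is (C_out, 2H, 2W).
    """
    _, h, w = x.shape
    _, c_out, k, _ = weight.shape
    full = np.zeros((c_out, (h - 1) * stride + k, (w - 1) * stride + k))
    for i in range(h):
        for j in range(w):
            # Each input pixel scatters a weighted k x k patch into the output.
            patch = np.einsum("c,cokl->okl", x[:, i, j], weight)
            full[:, i * stride:i * stride + k, j * stride:j * stride + k] += patch
    # Cropping `padding` rows/columns on every border yields the standard size
    # (H - 1) * stride - 2 * padding + k.
    return full[:, padding:full.shape[1] - padding, padding:full.shape[2] - padding]


rng = np.random.default_rng(0)
Fs4 = rng.standard_normal((16, 16, 32))            # hypothetical 16-channel map at stride 16
W_dc = rng.standard_normal((16, 16, 4, 4)) * 0.1   # hypothetical deconvolution weights
Fsd4 = deconv4x4(Fs4, W_dc)
print(Fsd4.shape)  # (16, 32, 64): the resolution of Fc3 at stride 8
```

Applying the same operation once more would bring a stride-8 map to stride 4, as in the step from *F*<sub>sd4</sub> to *F*<sub>d4</sub>. In practice a framework layer such as PyTorch's `torch.nn.ConvTranspose2d(c_in, c_out, kernel_size=4, stride=2, padding=1)` would replace the Python loops; it produces the same output size and uses the same (C_in, C_out, k, k) weight layout.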

**Figure 2.** Cascaded Cross-layer Fusion Module (CCF).

The feature map *F*<sub>3</sub> contains only part of the detailed information, which is not enough for the network to detect small objects, as shown in the ablation study (Section 4.3). CCF therefore also uses the feature map *F*<sub>2</sub> from the second stage, so that the final feature map fed to the detection head contains more detailed information. However, *F*<sub>2</sub> contains a lot of noise, so CCF uses *F*<sub>sd3</sub>, which carries deep semantics, to denoise it: *F*<sub>2</sub> is reduced by a 1 × 1 convolution to *F*<sub>c2</sub>, and *F*<sub>c2</sub> and *F*<sub>sd3</sub> are combined by Equation (1) into the feature map *F*<sub>s2</sub>. Since *F*<sub>s2</sub> already has size [*H*/4, *W*/4], it needs no further processing.

Finally, CCF merges these feature maps through *Concat*(·) to obtain the final feature map *F*<sub>lc</sub>, which is rich in both detailed and semantic information:

$$F\_{lc} = \operatorname{Concat}(F\_{d4}, F\_{sd3}, F\_{s2}) \tag{3}$$

Following [7], CCF applies a 3 × 3 convolution to *F*<sub>lc</sub> to reduce the aliasing effect introduced by deconvolution and feature fusion.
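The cascade of Equations (1)–(3) can be checked by tracing tensor shapes alone. The plain-Python sketch below does this under hypothetical assumptions that the text does not fix: backbone strides of 4, 8, 16, and 16 for *F*<sub>2</sub> to *F*<sub>5</sub>, 256 channels after every 1 × 1 reduction and deconvolution, and a deconvolution that exactly doubles the spatial size.

```python
def ccf_shapes(height, width, channels=256):
    """Trace (C, H, W) shapes through the CCF cascade of Section 3.2.

    Hypothetical assumptions: F2..F5 have strides 4, 8, 16, 16; every 1x1
    reduction and deconvolution outputs `channels` channels; DC(.) doubles H and W.
    """
    def at_stride(s):
        return (channels, height // s, width // s)

    def dc(shape):                      # Eq. (2): deconvolution doubles H and W
        c, h, w = shape
        return (c, 2 * h, 2 * w)

    def add(a, b):                      # Eq. (1): element-wise sum needs equal shapes
        if a != b:
            raise ValueError(f"Sum needs equal shapes, got {a} and {b}")
        return a

    fc2, fc3, fc4, fc5 = at_stride(4), at_stride(8), at_stride(16), at_stride(16)

    shapes = {}
    shapes["Fs4"] = add(fc4, fc5)              # deep source: Fc4 + Fc5
    shapes["Fsd4"] = dc(shapes["Fs4"])         # same size as Fc3
    shapes["Fd4"] = dc(shapes["Fsd4"])         # reused deep features at [H/4, W/4]
    shapes["Fs3"] = add(fc3, shapes["Fsd4"])
    shapes["Fsd3"] = dc(shapes["Fs3"])         # at [H/4, W/4]
    shapes["Fs2"] = add(fc2, shapes["Fsd3"])   # denoised F2, already [H/4, W/4]

    # Eq. (3): channel-wise concatenation requires equal spatial sizes.
    parts = [shapes["Fd4"], shapes["Fsd3"], shapes["Fs2"]]
    if len({p[1:] for p in parts}) != 1:
        raise ValueError("Concat needs equal spatial sizes")
    shapes["Flc"] = (sum(p[0] for p in parts), *parts[0][1:])
    return shapes


for name, shape in ccf_shapes(1024, 2048).items():
    print(name, shape)
# the last line printed is: Flc (768, 256, 512)
```

In this sketch the shapes only line up when the input height and width are multiples of 16; otherwise the `add` check raises an error, because integer division at stride 16 followed by doubling no longer matches the stride-8 size. The 3 × 3 convolution that follows would then map the 768 concatenated channels to the width the detection head expects.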

#### *3.3. Detection Head*

Our detection head contains a center map, a scale map, and a global smooth map. Following CSP [23], the center map uses a Gaussian heat map to locate objects, and the scale map determines object size. Although the Gaussian heat map reduces the weight of negative samples around the object center, the center map has only local perception and lacks global perception. We therefore add a global smooth map that is fused with the center map, giving the resulting new center map global perception. In addition, because a pedestrian's aspect ratio changes with the pedestrian's pose, we discard the scale-map design that predicts only the height and fixes the width; instead, our scale map predicts the height and width of pedestrians simultaneously.
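The Gaussian heat map *M* that accompanies the center map can be sketched as follows. The general recipe, a 2D Gaussian placed at each annotated center with spreads proportional to the object's width and height and combined across objects by an element-wise maximum, follows CSP [23]; the exact formula and the proportionality constant `sigma_scale` below are illustrative assumptions rather than the values used by CSP or CCFNet.

```python
import numpy as np


def gaussian_heat_map(grid_h, grid_w, boxes, sigma_scale=0.5):
    """Build a CSP-style Gaussian heat map M on a (grid_h, grid_w) output grid.

    boxes: iterable of (cx, cy, w, h) in output-grid coordinates.
    Each object contributes an axis-aligned 2D Gaussian centred on its center,
    with standard deviations proportional to its width and height; overlapping
    objects are combined with an element-wise maximum.
    sigma_scale is a hypothetical proportionality constant.
    """
    ys, xs = np.mgrid[0:grid_h, 0:grid_w]
    heat = np.zeros((grid_h, grid_w))
    for cx, cy, w, h in boxes:
        sx = max(sigma_scale * w, 1e-6)
        sy = max(sigma_scale * h, 1e-6)
        g = np.exp(-((xs - cx) ** 2 / (2 * sx ** 2) + (ys - cy) ** 2 / (2 * sy ** 2)))
        heat = np.maximum(heat, g)
    return heat


M = gaussian_heat_map(8, 8, [(2, 3, 2, 4), (5, 5, 2, 4)])
print(M[3, 2], M[5, 5])  # 1.0 1.0, the peaks at the two object centers
```

Because the background weight in the center loss is scaled by (1 − *M*<sub>ij</sub>)<sup>β</sup>, locations near a center, where *M* is close to 1, are penalized much less than distant background.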

As shown in Figure 3, the center map, global smooth map, and scale map are all obtained from the CCF output *F*<sub>lc</sub> through separate 1 × 1 convolutions. The global smooth map is then used to refine the center map into a more accurate new center map. Finally, the new center map and the scale map are used to generate the detection results. Optionally, an offset map can be added to the detection head to correct object positions.
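How the head outputs could be turned into boxes is sketched below. The paper does not give the rule for integrating the GSMap into the center map or the exact decoding, so three pieces are assumptions: the fusion is taken to be an element-wise product, the scale map is assumed to store log-height and log-width (as CSP does for height), and the output stride is 4. A full pipeline would also keep only local peaks and apply non-maximum suppression, which this sketch omits.

```python
import numpy as np


def decode_detections(center, gsmap, scale, offset=None, score_thr=0.5, stride=4):
    """Convert head outputs into boxes (x1, y1, x2, y2, score) in input-image pixels.

    center, gsmap: (H/4, W/4) maps with values in [0, 1].
    scale: (2, H/4, W/4); channel 0 = log-height, channel 1 = log-width
           (log-space regression is an assumption borrowed from CSP).
    offset: optional (2, H/4, W/4) sub-cell corrections (dy, dx).
    """
    # Hypothetical fusion rule: element-wise product of center map and GSMap.
    fused = center * gsmap
    boxes = []
    for r, c in zip(*np.nonzero(fused > score_thr)):
        if offset is not None:
            dy, dx = offset[0, r, c], offset[1, r, c]
        else:
            dy, dx = 0.0, 0.0
        cy, cx = (r + dy + 0.5) * stride, (c + dx + 0.5) * stride
        h = np.exp(scale[0, r, c]) * stride
        w = np.exp(scale[1, r, c]) * stride
        boxes.append((cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2, float(fused[r, c])))
    # A full pipeline would keep only local peaks and apply NMS; omitted here.
    return boxes


center = np.zeros((4, 4))
center[1, 2] = 0.9                      # one confident center cell
gsmap = np.ones((4, 4))                 # a GSMap that leaves the center map unchanged
scale = np.zeros((2, 4, 4))
scale[0, 1, 2] = np.log(2.0)            # height = 2 cells = 8 px; width = 1 cell = 4 px
print(decode_detections(center, gsmap, scale))
# one box centred at (10, 6) px, 4 px wide and 8 px tall, with score 0.9
```

Under the product rule, a low global smooth confidence can suppress a locally confident center response: with `gsmap` set to 0.5 everywhere, the fused score of 0.45 falls below the threshold and no box is produced.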

**Figure 3.** The overall architecture of the detection head mainly includes three map components, namely the center map, the scale map and the global smooth map (GSMap).

#### *3.4. Loss Function*

#### 3.4.1. Center Loss

Combined with the global smooth map, the center loss is modified as follows:

$$\mathcal{L}\_{\text{center}} = -\frac{1}{K} \sum\_{i=1}^{W/4} \sum\_{j=1}^{H/4} \left( s\_{ij} f\_{ij} \log(p\_{ij}) + (1 - s\_{ij})\, b\_{ij} \log(1 - p\_{ij}) \right) \tag{4}$$

where

$$\begin{cases} f\_{ij} = gs\_{ij}\,(1 - p\_{ij})^{\gamma} \\ b\_{ij} = p\_{ij}^{\gamma}\,(1 - M\_{ij})^{\beta} \end{cases} \tag{5}$$

In Equations (4) and (5), *K* is the total number of objects; *W* and *H* are the width and height of the input image, respectively; *s*<sub>ij</sub> is the ground-truth label at coordinates (*i*, *j*) (1 at an object center, 0 elsewhere); *p*<sub>ij</sub> is the predicted probability that (*i*, *j*) is positive; *gs*<sub>ij</sub> is the global smooth confidence; *M*<sub>ij</sub> is the Gaussian heat map [23]; *γ* and *β* are hyperparameters; and *f*<sub>ij</sub> and *b*<sub>ij</sub> are the foreground and background terms, respectively.
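The center loss can be sketched in NumPy as a focal-style loss. In this sketch the foreground term pairs (1 − *p*<sub>ij</sub>)<sup>γ</sup> with log *p*<sub>ij</sub> and the background term pairs *p*<sub>ij</sub><sup>γ</sup>(1 − *M*<sub>ij</sub>)<sup>β</sup> with log(1 − *p*<sub>ij</sub>), the standard focal-loss pairing. The values γ = 2 and β = 4 are assumptions (common choices in CornerNet-style focal losses), as the text does not state them, and the toy inputs are hypothetical.

```python
import numpy as np


def center_loss(p, s, gs, M, gamma=2.0, beta=4.0, eps=1e-6):
    """Focal-style center loss following Eqs. (4)-(5).

    p:  predicted center probabilities, shape (H/4, W/4)
    s:  binary ground-truth labels, 1 at object centers and 0 elsewhere
    gs: global smooth confidence per location
    M:  Gaussian heat map per location
    gamma, beta: assumed values (common CornerNet/CSP-style choices).
    """
    p = np.clip(p, eps, 1.0 - eps)                 # keep log() finite
    f = gs * (1.0 - p) ** gamma                    # foreground weight f_ij
    b = p ** gamma * (1.0 - M) ** beta             # background weight b_ij
    K = max(float(s.sum()), 1.0)                   # number of objects
    per_location = s * f * np.log(p) + (1.0 - s) * b * np.log(1.0 - p)
    return float(-per_location.sum() / K)


s = np.zeros((5, 5))
s[2, 2] = 1.0                                      # one object centred at (2, 2)
M = np.zeros((5, 5))
M[2, 2] = 1.0                                      # toy heat map with its peak at the center
gs = np.ones((5, 5))
good = np.full((5, 5), 0.01)
good[2, 2] = 0.99
bad = np.full((5, 5), 0.5)
bad[2, 2] = 0.1
print(center_loss(good, s, gs, M) < center_loss(bad, s, gs, M))  # True
```

Since *gs*<sub>ij</sub> multiplies only the foreground term, a low global smooth confidence reduces how strongly a positive location is pushed toward a high center score.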

#### 3.4.2. Scale Loss

The scale map is trained with the SmoothL1 loss [42], which measures the error between the predicted and ground-truth height and width of each object:

$$\mathcal{L}\_{scale} = \frac{1}{K} \left( \sum\_{k=1}^{K} \mathrm{SmoothL1}(h\_k, \hat{h}\_k) + \sum\_{k=1}^{K} \mathrm{SmoothL1}(w\_k, \hat{w}\_k) \right) \tag{6}$$

where *h*<sub>k</sub> and *ĥ*<sub>k</sub> are the predicted and ground-truth heights of the *k*-th positive sample, and *w*<sub>k</sub> and *ŵ*<sub>k</sub> are the corresponding predicted and ground-truth widths.
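A NumPy sketch of Equation (6) follows. It uses the usual SmoothL1 definition with a transition point of 1 (quadratic below it, linear above it); whether heights and widths are regressed in pixels or in log space is not stated in the text, so the function simply applies the loss to whatever values it receives. The sample values are hypothetical.

```python
import numpy as np


def smooth_l1(pred, target, beta=1.0):
    """Element-wise SmoothL1: 0.5 * d**2 / beta if |d| < beta, else |d| - 0.5 * beta."""
    d = np.abs(np.asarray(pred, dtype=float) - np.asarray(target, dtype=float))
    return np.where(d < beta, 0.5 * d ** 2 / beta, d - 0.5 * beta)


def scale_loss(h_pred, h_gt, w_pred, w_gt):
    """Eq. (6): SmoothL1 on heights plus SmoothL1 on widths, averaged over the K positives."""
    K = max(len(h_gt), 1)
    return float((smooth_l1(h_pred, h_gt).sum() + smooth_l1(w_pred, w_gt).sum()) / K)


# One positive: height error 0.5 (quadratic regime), width error 2.0 (linear regime)
print(scale_loss([1.5], [1.0], [3.0], [1.0]))  # 0.125 + 1.5 = 1.625
```

The quadratic region keeps gradients small for nearly correct predictions, while the linear region limits the influence of large errors, for example on heavily occluded pedestrians whose visible extent differs from the annotated box.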

#### 3.4.3. Total Loss

Optionally, if the offset map is added to correct the object position, the offset loss is:

$$\mathcal{L}\_{offset} = \frac{1}{K} \sum\_{k=1}^{K} \mathrm{SmoothL1}(o\_k, \hat{o}\_k) \tag{7}$$

where *o*<sub>k</sub> is the predicted offset of the *k*-th positive sample and *ô*<sub>k</sub> is its ground truth.

Therefore, the complete loss function is:

$$\mathcal{L} = \lambda\_{c} \mathcal{L}\_{\text{center}} + \lambda\_{s} \mathcal{L}\_{scale} + \lambda\_{o} \mathcal{L}\_{offset} \tag{8}$$

where *λ*<sub>c</sub>, *λ*<sub>s</sub>, and *λ*<sub>o</sub> are the weights of the center, scale, and offset losses, set to 0.01, 1, and 0.1, respectively, in our experiments. Although this loss function superficially resembles those of many other methods, it differs in its details: the center loss is weighted by the global smooth confidence, and the scale loss regresses both height and width.
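Equation (8) with the reported weights reduces to a one-line weighted sum; the sketch below makes the optional offset term explicit. The function and argument names are hypothetical.

```python
def total_loss(l_center, l_scale, l_offset=None,
               lambda_c=0.01, lambda_s=1.0, lambda_o=0.1):
    """Eq. (8) with the weights reported in the text; the offset term is optional."""
    loss = lambda_c * l_center + lambda_s * l_scale
    if l_offset is not None:
        loss += lambda_o * l_offset
    return loss


print(total_loss(2.0, 0.5))       # 0.01 * 2.0 + 1.0 * 0.5, about 0.52
print(total_loss(1.0, 1.0, 1.0))  # 0.01 + 1.0 + 0.1, about 1.11
```

With these weights, a unit of center loss contributes one hundredth as much as a unit of scale loss, so the raw center loss can be considerably larger than the scale loss without dominating training.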

#### **4. Experimental Results**

To evaluate the proposed CCFNet, we conducted comparative experiments on Caltech [43,44] and CityPersons [45]. This section introduces the datasets and experimental settings, verifies the effectiveness of the model through ablation studies on CityPersons, and finally compares CCFNet with state-of-the-art methods and visualizes its results to demonstrate its superiority.

Section 4.1 introduces the datasets and the evaluation metric for pedestrian detection. Section 4.2 describes the experimental settings. Section 4.3 analyzes the ablation studies on the CityPersons dataset. Section 4.4 verifies the effectiveness of the model through comparison with other methods on the Caltech and CityPersons datasets. Section 4.5 visualizes the detection results to further illustrate the advantages of CCFNet. Finally, Section 4.6 discusses all the experimental results.

#### *4.1. Datasets*

The Caltech dataset consists of about 10 hours of video divided into 11 subsets, 6 for training and 5 for testing. We split the videos into RGB frames, sampling one image every 3 frames for the training set (42,782 images in total) and one image every 30 frames for the test set (4024 images in total). As shown in Figure 4a,b, the training set contains 5564 pedestrians and 4992 ignored regions, and the test set contains 7596 pedestrians and no ignored regions.

**Figure 4.** The histogram and pie chart represent the distribution statistics of each category in the Caltech and CityPersons datasets. (**a**) represents the label distribution of the training set in the Caltech dataset. (**b**) represents the label distribution of the test set in the Caltech dataset. (**c**) represents the label distribution of the training set in the CityPersons dataset. (**d**) represents the label distribution of the validation set in the CityPersons dataset.

The CityPersons dataset is a subset of the Cityscapes dataset, with a training set of 2975 images and a validation set of 500 images. Figure 4c,d shows that, in the training set, 59.51% of objects are labeled pedestrian; 24.37% are labeled ignore, including objects less than 20 pixels tall, objects of unclear status, billboards, etc.; 6.05% are labeled rider; 3.72% are labeled sitting; 1.50% are labeled other, including people being held; and 4.85% are labeled group. Note that during evaluation, predicted boxes that match rider, sitting, other, or ignored regions are not counted as false positives. The label distribution of the validation set is similar to that of the training set.

Following [44], we use the log-average miss rate (*MR*<sup>−2</sup>) as the evaluation metric: the miss rate is averaged in log space over nine false positives per image (FPPI) values evenly spaced in log space in the range [10<sup>−2</sup>, 10<sup>0</sup>]; lower is better. The Caltech dataset is evaluated on the Reasonable and Reasonable\_Occ=Heavy subsets, and the CityPersons dataset on the Reasonable, Bare, Partial, and Heavy subsets. The rules defining each subset are given in Table 1, where *inf* denotes infinity.
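The *MR*<sup>−2</sup> computation can be sketched as follows, given a detector's miss-rate-versus-FPPI curve. The sketch samples the curve at nine reference FPPI values spaced evenly in log space between 10<sup>−2</sup> and 10<sup>0</sup>, using for each reference the miss rate of the last operating point whose FPPI does not exceed it (falling back to the first point if none does), and returns the geometric mean. The official Caltech and CityPersons evaluation code differs in details such as curve interpolation and ignore-region matching, so this is an illustration rather than a drop-in replacement.

```python
import numpy as np


def log_average_miss_rate(fppi, miss_rate, num_points=9):
    """MR^-2: log-average miss rate over FPPI values in [1e-2, 1e0].

    fppi, miss_rate: one operating point per score threshold, ordered so that
    fppi is non-decreasing.
    """
    fppi = np.asarray(fppi, dtype=float)
    miss_rate = np.asarray(miss_rate, dtype=float)
    refs = np.logspace(-2.0, 0.0, num_points)     # 0.01, 0.0178, ..., 1.0
    sampled = []
    for ref in refs:
        # Index of the last operating point with fppi <= ref (or the first point).
        idx = np.searchsorted(fppi, ref, side="right") - 1
        sampled.append(miss_rate[max(idx, 0)])
    sampled = np.maximum(np.array(sampled), 1e-10)   # guard against log(0)
    return float(np.exp(np.mean(np.log(sampled))))


# Toy curve: the miss rate drops from 0.8 to 0.4 once 0.05 false positives per image are allowed
print(log_average_miss_rate([0.001, 0.05, 5.0], [0.8, 0.4, 0.1]))
```

In the toy curve, three of the nine reference points fall below an FPPI of 0.05 and six fall at or above it, so the result is 0.8<sup>1/3</sup> · 0.4<sup>2/3</sup>, roughly 0.50; the operating point at FPPI 5.0 lies outside the evaluated range and does not contribute.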



#### *4.2. Experimental Setting*

Unless otherwise specified, CCFNet is built on mmdetection [46] and Pedestron [47]. The experiments in this paper are run on a TITAN RTX GPU. On the Caltech dataset, the batch size is 16, the initial learning rate is 2 × 10<sup>−4</sup>, and training runs for 20 epochs. On the CityPersons dataset, the batch size is 4, the initial learning rate is 2 × 10<sup>−4</sup>, and training runs for 150 epochs. Our experimental setup is based on [48,49].

#### *4.3. Ablation Study*

For CCF. To find an effective way of combining feature maps, we test how different fusion strategies affect model performance. CCF starts from the features of the second stage and keeps the final feature map at [*H*/4, *W*/4], the size of the second-stage feature map. In Table 2, *s*<sub>n</sub> denotes the feature map generated by the *n*-th stage of the backbone. The last model, which combines {*s*<sub>2</sub>, *s*<sub>3</sub>, *s*<sub>4</sub>, *s*<sub>5</sub>}, performs best. Removing *s*<sub>2</sub>, i.e., using {*s*<sub>3</sub>, *s*<sub>4</sub>, *s*<sub>5</sub>}, gives a poor result, indicating that without detailed information objects cannot be located accurately. Removing *s*<sub>5</sub>, i.e., using {*s*<sub>2</sub>, *s*<sub>3</sub>, *s*<sub>4</sub>}, also gives a poor result, showing that the semantic information contained in the deep features is crucial. In summary, {*s*<sub>2</sub>, *s*<sub>3</sub>, *s*<sub>4</sub>, *s*<sub>5</sub>} is the most suitable combination.

**Table 2.** Ablation study of different combinations of multi-scale features on the CityPersons dataset.


To verify the effectiveness of CCF, we connect the ResNet-50 backbone and the detection head [23] with different necks, namely FPN [7], Augmented FPN (AugFPN) [50], Attention-guided Context Feature Pyramid Network (ACFPN) [51], and CSP [23]. As shown in Table 3, CCF is highly competitive with the other necks on the Reasonable, Bare, and Partial subsets, and on the Heavy subset it outperforms some of them. Compared with FPN, CCF reuses the semantic information of deep feature maps to obtain more contextual information in the final feature map, and it does not need to output multi-scale feature maps to detect objects. Compared with CSP, CCF removes noise from the shallow feature maps and retains more detailed information through cascading.


**Table 3.** Ablation study of different neck modules on the CityPersons dataset.

For GSMap. Table 4 shows the ablation study on GSMap. The Baseline contains a neck and a detection head: the neck consists of a deconvolution applied to the fifth stage of ResNet-50, and the detection head contains a center map and a scale map. Baseline + GSMap adds GSMap to the detection head. Baseline + CCF replaces the neck of the baseline with CCF. Baseline + CCF + GSMap uses CCF as the neck and adds GSMap to the detection head. As shown in Table 4, adding GSMap alone to the baseline improves *MR*<sup>−2</sup> by 0.7% on the Reasonable subset, 0.3% on the Bare subset, 0.8% on the Partial subset, and 3.7% on the Heavy subset. When CCF and GSMap are used together, the four subsets improve by a further 0.4%, 0.3%, 0.6%, and 5.7%, respectively, compared with Baseline + CCF. This result shows that GSMap enhances localization by giving the center map global feature information, and that its benefit grows as the amount of useful feature information increases.

**Table 4.** Ablation study of the global smooth map on the CityPersons dataset.


For Scale Prediction. Table 5 shows the impact of scale prediction on CCFNet. Following previous work [23], we compare three scale-prediction settings: height only, width only, and height + width. Compared with predicting height only, predicting height + width improves the result by 0.6% on the Reasonable subset and 4.5% on the Heavy subset; compared with predicting width only, the improvements are 1.2% and 7.2%, respectively. Predicting the height and width simultaneously therefore further improves CCFNet, because the model can adapt to objects with different aspect ratios instead of being restricted to a fixed one. In addition, retaining more feature information benefits the prediction of object width. The Heavy-subset results indicate that predicting height and width simultaneously helps with dense and overlapping objects.

**Table 5.** Ablation study of different definitions for scale prediction on the CityPersons dataset.


#### *4.4. State-of-the-Art Comparisons*

Caltech Dataset: We compare CCFNet with several strong methods on the Reasonable and Reasonable\_Occ=Heavy subsets. As shown in Figure 5, CCFNet achieves an *MR*<sup>−2</sup> of 4.33% on the Reasonable subset, 0.37% lower than the best competing method. On the Reasonable\_Occ=Heavy subset, CCFNet achieves 43.21%, which is also competitive. When the model is initialized with weights trained on the CityPersons dataset, its performance improves by 6.04%, surpassing the other methods. CCFNet uses feature cascading and reorganization to retain more contextual information, and it improves the localization ability of the center map through the global smooth map.

**Figure 5.** The results of various models on the Caltech dataset. (**a**) Compare with existing methods on Reasonable subset. (**b**) Compare with existing methods on the Reasonable\_Occ=Heavy subset.

As shown in Table 6, we also compare CCFNet with advanced algorithms such as Repulsion Loss (RepLoss) [38], which targets occlusion, and the anchor-free network CSP. On the Reasonable subset, CCFNet achieves an *MR*<sup>−2</sup> of 4.3%, which is 0.7% and 0.2% lower than RepLoss and CSP, respectively. On the Reasonable\_Occ=Heavy subset, CCFNet reaches 43.2%, an improvement of 4.7% and 2.6% over RepLoss and CSP, respectively. When initialized with weights trained on the CityPersons dataset, CCFNet reaches 3.5% on the Reasonable subset and 36.2% on the Reasonable\_Occ=Heavy subset. These results show that reusing high-level features in a cascaded manner is effective.


**Table 6.** The results of various models on the Caltech dataset.

CityPersons Dataset: We evaluate CCFNet on the Reasonable, Heavy, Bare, and Partial subsets of the CityPersons dataset; the comparative results are shown in Table 7. CCFNet achieves an *MR*<sup>−2</sup> of 10.2% on the Reasonable subset, 6.8% on the Bare subset, 9.5% on the Partial subset, and 42.7% on the Heavy subset. On the Reasonable subset, CCFNet is 0.4% and 0.3% lower than Attribute-aware Pedestrian Detection (APD) [55] and Mask-Guided Attention Network (MGAN) [53], respectively. On the Heavy subset, CCFNet improves by 7.1% and 4.5% over APD and MGAN, respectively. Overall, CCFNet outperforms the other compared methods, reflecting its strong competitiveness.


**Table 7.** The results of various models on the CityPersons dataset.

#### *4.5. Visualization*

To further illustrate the advantages of CCFNet, we visualize its detection results on the CityPersons dataset in Figure 6. Row (a) shows original images from the CityPersons validation set, row (b) the ground truth, row (c) the results of CSP, and row (d) the results of CCFNet. CSP and CCFNet are visualized with the same confidence threshold.

To show the effectiveness of CCFNet, we select three images from different scenes for comparison with CSP: a crowded scene, a simple scene containing small objects, and a scene with low visibility, low exposure, and small objects. In the first image, both CSP and CCFNet generate many detection boxes, but CCFNet produces fewer false detections and better avoids assigning multiple boxes to a single object. In the second image, both methods produce overlapping detection boxes, but the results of CSP are much worse. In the third image, CSP detects the small objects but also produces many detections that should not occur, which CCFNet avoids. Therefore, CCFNet not only performs well quantitatively but also produces robust visual results.

In Figure 7, row (a) shows original images from the CityPersons validation set, and rows (b), (c), and (d) show the heat maps of ACFPN, CSP, and CCFNet, respectively. The three selected images cover a complex environment, a crowded scene, and a general scene. The highlights of ACFPN are scattered, those of CSP are concentrated, and those of CCFNet are multi-peaked. ACFPN cannot distinguish which category a person belongs to and cannot cope with crowded objects, which is related to its being a general object detection network. CSP responds to certain background regions, which degrades its visualization even though its false detection rate is low. CCFNet does not over-respond to the background and can distinguish categories of people, so it achieves both a lower false detection rate and better visualization results.

**Figure 6.** Visualization results of CSP and CCFNet without restricting the visibility of pedestrian objects. (**a**) Original input images from the CityPersons dataset; (**b**) ground truth corresponding to (**a**); (**c**) visualization results of CSP; (**d**) visualization results of CCFNet.

**Figure 7.** Visualization results of ACFPN, CSP, and CCFNet. (**a**) Original input images from the CityPersons dataset; (**b**) visualization results of ACFPN; (**c**) visualization results of CSP; (**d**) visualization results of CCFNet.

#### *4.6. Discussion*

The design of CCFNet is inspired by anchor-free object detection networks. In an anchor-free network, how effectively the neck uses the feature representations extracted by the backbone directly affects the performance of the detection head. Previous work [50,51] has achieved good performance in general object detection, but it does not generalize well to some specialized tasks, such as pedestrian detection.

Table 2 shows the ablation experiment on multi-scale features in the CCF module, in which different combinations of stage-wise feature maps are compared to find the optimal one. CCF reduces the noise in shallow feature maps by cascading and reusing deep semantic information, while retaining the semantic information lost through dimensionality reduction. The goal is to give the final feature map richer features.

Table 3 shows comparative experiments between CCF and other necks. The previously proposed FPN-like and FCN-like methods achieve state-of-the-art performance in general object detection, but they are not well suited to pedestrian detection. The CCF module shows very competitive performance.

Table 4 shows the ablation experiment on GSMap. The center map reduces the weight of negative samples through a Gaussian heat map, but this does not overcome the limitation that convolution operations capture only limited global information [57–59]. GSMap enables the center map to obtain more global information. In addition, the results on the heavy-occlusion subset show that the congestion problem between objects cannot be completely solved by enhancing the semantic information in the feature map alone, and that additional modules such as GSMap are required.

Table 5 shows the experiment on object scale prediction. Previous work determines object size by predicting only the height [23,48]. Our experiments show that predicting the height and width of objects simultaneously is the most suitable choice for CCFNet. This also helps to cope with dense and overlapping objects.

Figure 5 and Table 6 show the comparative experiments of CCFNet and other advanced algorithms on the Caltech dataset, and Table 7 shows those on the CityPersons dataset. These results demonstrate the effectiveness of CCFNet.

#### **5. Conclusions**

In this paper, we proposed the Cascaded Cross-layer Fusion (CCF) module, which combines deep semantics and shallow details to obtain features with richer contextual semantic information. To cope with highly congested and severely occluded objects, we designed the global smooth map (GSMap) and an improved center loss function, which effectively alleviate this problem at a small cost. The resulting Cascaded Cross-layer Fusion Network (CCFNet) achieves better performance without relying on anchors, multiple key points, or complex post-processing. Finally, we conducted extensive experiments on the Caltech and CityPersons datasets to verify the superiority of CCFNet. Although dimensionality reduction operations were introduced in the design to reduce computational complexity, the final model still has a large number of parameters and cannot meet the requirements of real-time systems. Therefore, designing an effective lightweight module is the focus of our future work.

**Author Contributions:** Formal analysis, Z.G., Y.S. and X.X.; methodology, Z.D. and Y.S.; project administration, Z.D.; validation, Z.G. and X.X.; visualization, Z.G. and X.X.; writing–original draft, Z.D.; writing–review and editing, Z.G., Y.S. and X.X. All authors have read and agreed to the published version of the manuscript.

**Funding:** Not applicable.

**Institutional Review Board Statement:** Not applicable.

**Informed Consent Statement:** Not applicable.

**Data Availability Statement:** Not applicable.

**Acknowledgments:** Not applicable.

**Conflicts of Interest:** The authors declare no conflict of interest.

#### **References**


## *Article* **Second-Order Spatial-Temporal Correlation Filters for Visual Tracking**

**Yufeng Yu <sup>1</sup>, Long Chen <sup>1</sup>, Haoyang He <sup>2</sup>, Jianhui Liu <sup>3</sup>, Weipeng Zhang <sup>4</sup> and Guoxia Xu <sup>5,\*</sup>**


**Abstract:** Discriminative correlation filters (DCFs) have been widely used in visual object tracking, but they often suffer from two problems: the boundary effect and temporal filtering degradation. To deal with these issues, many DCF-based variants have been proposed that improve the accuracy of visual object tracking. However, these trackers only adopt first-order data-fitting information and have difficulty maintaining robust tracking in unconstrained scenarios, especially under complex appearance variations. In this paper, by introducing a second-order data-fitting term into the DCF, we propose a second-order spatial–temporal correlation filter (SSCF) learning model. Specifically, the SSCF tracker incorporates both first-order and second-order data-fitting terms into the DCF framework, which makes the learned correlation filter more discriminative. Meanwhile, spatial–temporal regularization is integrated to make the model robust to complex appearance variations. Extensive experiments were conducted on the benchmark databases CVPR2013, OTB100, DTB70, UAV123, and UAVDT-M. The results demonstrate that SSCF achieves competitive performance compared with state-of-the-art trackers. With the penalty parameter *λ* set to 10<sup>−5</sup>, SSCF obtained DP scores of 0.882, 0.868, 0.706, 0.676, and 0.928 on the CVPR2013, OTB100, DTB70, UAV123, and UAVDT-M databases, respectively.

**Keywords:** correlation filters; second-order fitting; visual tracking

**MSC:** 68T45

#### **1. Introduction**

Visual object tracking is a fundamental problem in computer vision with a wide range of applications in human–computer interaction, video surveillance, autonomous driving, and so on. The task is challenged by appearance variations such as illumination variation, fast motion, out-of-plane rotation, and in-plane rotation. To deal with these challenges, various innovative trackers have been proposed and have achieved significant progress in tracking performance and robustness. Among these methods, discriminative-filter-based trackers [1–5] have received significant attention due to their competitive performance.

The standard discriminative correlation filter (DCF)-based tracker treats filter learning as a ridge regression problem whose objective function can be transformed to the frequency domain by the fast Fourier transform (FFT) and solved there. Bolme et al. [6] were the first to learn a correlation filter for target tracking and proposed the minimum output sum of squared error (MOSSE) model. MOSSE trains the filter by minimizing the sum of squared errors between the actual and desired correlation outputs over a sequence of training images. Inspired by MOSSE, Henriques et al. [7] showed that cyclic shifts can replace random sampling to achieve dense sampling and proposed a theoretical framework to explore the effect of dense sampling. This framework formulates a kernelized correlation filter to improve tracking performance. Zhang et al. [8] adopted the Bayesian principle to build a spatial–temporal context model for tracking. However, these CF-based trackers only utilize single-channel features, which are not robust in tracking scenarios with complex appearance variations. To tackle this issue, some CF-based methods [9–19] extract multiple features to learn the filters. Commonly used handcrafted features include the histogram of oriented gradients (HOG), color names (CNs), the local binary pattern (LBP), and the scale-invariant feature transform (SIFT). These features describe the shape and color information of the targets, and trackers using multiple features are more robust to fast motion and deformation of targets. For instance, Galoogahi et al. [17] employed multi-channel HOG descriptors in the frequency domain for filter learning and proposed a multi-channel CF tracker (MCCF). Huang et al. [14] learned filters from hybrid color features, extracting compressed CN features and HOG features based on the opponent color space, and used principal component analysis to reduce the computational cost. Li et al. [12] integrated raw pixel, HOG, and color label features into the DCF framework and presented an adaptive multiple feature tracker. Kumar et al. [19] exploited the LBP, color histogram, and pyramid of histograms of oriented gradients to model the object's appearance and developed an adaptive multi-cue particle filter method for real-time visual tracking.

**Citation:** Yu, Y.; Chen, L.; He, H.; Liu, J.; Zhang, W.; Xu, G. Second-Order Spatial-Temporal Correlation Filters for Visual Tracking. *Mathematics* **2022**, *10*, 684. https://doi.org/10.3390/math10050684

Academic Editors: Jianping Gou, Weihua Ou, Shaoning Zeng and Lan Du

Received: 10 January 2022 Accepted: 17 February 2022 Published: 22 February 2022

**Publisher's Note:** MDPI stays neutral with regard to jurisdictional claims in published maps and institutional affiliations.

**Copyright:** © 2022 by the authors. Licensee MDPI, Basel, Switzerland. This article is an open access article distributed under the terms and conditions of the Creative Commons Attribution (CC BY) license (https://creativecommons.org/licenses/by/4.0/).

Even though these DCF-based trackers using multi-channel features succeed to some extent, issues such as the redundancy of multi-channel features, the boundary effect, and data fitting have not been fully explored. To tackle these issues, many structurally regularized DCF methods [20–26] have been presented. Zhu et al. [2] proposed an adaptive attribute-aware strategy to distinguish the importance of different channel features. Jain et al. [20] presented a channel-graph-regularized CF model with a channel weighting strategy, in which a channel regularizer is integrated into the CF framework to learn the channel weights. Xu et al. [22] proposed a channel selection scheme for multi-channel feature representations and adopted a low-rank approximation to learn filters in a low-dimensional manifold. In addition, many trackers adopt a variety of strategies to address the boundary effect. SRDCF [23] incorporates a spatial regularizer into the DCF to deal with the problems caused by the periodic assumption. Li et al. [24] added a temporal regularization term to the SRDCF tracker [23] and proposed the spatial–temporal regularized CF (STRCF) framework. Specifically, STRCF integrates both temporal and spatial regularization into the standard DCF model and can perform model updating and DCF learning simultaneously. As a result, STRCF can be regarded as an approximation of SRDCF with multiple samples and achieves better tracking performance than SRDCF. BACF [25] utilizes a cropping matrix to extract patches densely from the background and expands the search area at a low computational cost. Xu et al. [26] combined temporal consistency constraints and spatial feature selection to propose an adaptive DCF model in which the multi-channel filters can be learned in a low-dimensional manifold space. However, the aforementioned trackers only employ first-order data-fitting information of the feature maps; in other words, they do not consider high-order data-fitting information for tracking.

Based on the above analysis, we propose a novel CF-based tracker, the second-order spatial–temporal correlation filter (SSCF) learning model. We formulate our tracking algorithm by incorporating a second-order data-fitting term into the DCF framework, which helps to take full advantage of target features against surrounding background clutter. The main contributions of SSCF are summarized as follows:

• We propose a new discriminative correlation filter model for visual tracking under complex appearance variations. Unlike prior DCF-based trackers, which use only first-order data-fitting information, we incorporate second-order data fitting and spatial–temporal regularization into the DCF framework and develop a more robust tracker;


The remainder of this paper is organized as follows. Section 2 introduces the related work. Section 3 describes the detailed mathematical formulation of the proposed model and introduces the optimization algorithm. Section 4 reports the experimental results and the corresponding analysis. Finally, Section 5 draws the conclusions.

#### **2. Related Work**

In this section, we mainly review three categories of tracking methods: trackers based on target detection, trackers based on clustering, and channel-reliability learning trackers.

Since target detection techniques [27–29] have attracted wide attention in computer vision, many detection-based trackers have been proposed. Guan et al. [30] proposed a joint detection and tracking framework in which the detection threshold is adaptively adjusted according to information fed back from the tracker to the detector. Zhang et al. [31] employed a faster region-based convolutional neural network (Faster R-CNN) to extract candidate detection areas and proposed a multi-target tracking algorithm. In [32], Liu et al. combined motion detection with correlation filtering and presented a new model for object tracking that determines the object position from the weighted outputs of the motion detector and the tracker. Considering that existing kernelized correlation filter tracking methods fail to identify occlusion, Min et al. [33] adopted a detector to assist occlusion judgment and improve tracking performance.

Clustering-based algorithms [34,35] are commonly used in pattern recognition and computer vision, for example, in image segmentation [36] and pattern classification [37]. Inspired by this, many researchers have used clustering algorithms to improve object tracking. For instance, Keuper et al. [38] combined motion segmentation with object tracking and presented a correlation co-clustering model to improve performance. In [39], Li et al. developed an intuitionistic fuzzy clustering model for object tracking in which the local information of the targets is incorporated into the clustering to improve robustness. Since DBSCAN clustering does not require the number of clusters to be specified in advance, He et al. [40] employed a DBSCAN-clustering-based track-to-track fusion strategy for multi-target tracking.

Recently, the idea of using different weights to distinguish the importance of different components has been widely used in pattern classification [41,42] and face recognition [43]. Similarly, some DCF-based channel-reliability learning trackers have been proposed to deal with model degradation. Du et al. [44] argued that different channels contribute differently to tracking and proposed a joint channel-reliability and correlation-filter learning model, which assigns each channel a weight reflecting its importance. To exploit the interaction between channels, Jain et al. [20] assigned similar weights to similar channels to emphasize important channels and developed a channel attention model. Li et al. [45] argued that existing trackers do not consider the complementary information of different channels and proposed a channel-feature integration method in which all channels of each feature share an importance map to avoid overfitting. In [46], the authors introduced channel and spatial reliability into the DCF framework and used the reliability scores to weight the per-channel filter responses; their experiments showed that channel weights improve tracking performance. These methods principally focus on overcoming model degradation by incorporating channel reliability, which enhances discriminative performance to some extent.

#### **3. The Proposed Model**

#### *3.1. Objective Function Construction*

As mentioned above, existing DCF-based methods only utilize first-order data-fitting information and ignore high-order data-fitting information for tracking; as a result, they cannot take full advantage of target features against surrounding background clutter and suffer from the stability–plasticity dilemma. To deal with these issues, we built a second-order spatial–temporal correlation-filter learning framework. Specifically, we incorporated a second-order data-fitting term and spatial–temporal regularization into the DCF framework and formulated a robust model, whose objective function is given below.

We first denote the dataset $\mathcal{S} = \{\mathbf{X}\_{t}\}\_{t=1}^{T}$, where each frame $\mathbf{X}\_{t} \in \mathbb{R}^{M \times N \times K}$ contains $K$ feature maps of size $M \times N$, and $\mathbf{Y} \in \mathbb{R}^{M \times N}$ is the Gaussian-shaped label. Our aim is to learn a multi-channel convolution filter $\mathbf{F} \in \mathbb{R}^{M \times N \times K}$ by minimizing the following objective function:

$$\begin{aligned} \min\_{\mathbf{F}} & \frac{1}{2} \left\| \sum\_{k=1}^{K} \mathbf{X}\_{t}^{k} \ast \mathbf{F}^{k} - \mathbf{Y} \right\|\_{F}^{2} + \frac{1}{2} \sum\_{k=1}^{K} \|\mathbf{W} \cdot \mathbf{F}^{k}\|\_{F}^{2} \\ & + \frac{\lambda}{2} \left\| \sum\_{k=1}^{K} \mathbf{X}\_{t}^{k} \ast \mathbf{F}^{k} \ast \mathbf{X}\_{t}^{k} - \mathbf{Y} \right\|\_{F}^{2} + \frac{\mu}{2} \|\mathbf{F} - \mathbf{F}\_{t-1}\|\_{F}^{2} \end{aligned} \tag{1}$$

where $\ast$ denotes the convolution operator and $\cdot$ denotes the Hadamard product. $\mathbf{W}$ is the spatial regularization matrix, and $\mathbf{F}\_{t-1}$ is the correlation filter used in the $(t-1)$-th frame. $\lambda$ and $\mu$ are penalty parameters. The first term is the first-order data-fitting term, the generic formulation for learning the filter in DCF-based trackers. The second term is the spatial regularizer, which addresses the boundary effect. The third term is the second-order data-fitting term, which helps to make full use of discriminative target features. The last term is the temporal regularizer, which keeps the current filter close to the previous one and thus mitigates the effect of corrupted samples.
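To make the four terms of Equation (1) concrete, the following minimal sketch evaluates the objective for 1-D, real-valued signals. It assumes circular convolution, which is consistent with the DFT-based solver in Section 3.2, and treats $\mathbf{W}$ as a per-position weight vector. The function names and the 1-D layout are hypothetical, and this is not the authors' MATLAB implementation.

```python
def circ_conv(x, f):
    """Circular convolution of two equal-length real sequences."""
    n = len(x)
    return [sum(x[m] * f[(i - m) % n] for m in range(n)) for i in range(n)]


def sq_norm(v):
    """Squared Euclidean (Frobenius) norm of a flat list."""
    return sum(e * e for e in v)


def sscf_objective(X, F, Y, W, F_prev, lam, mu):
    """Evaluate Eq. (1) for 1-D signals.

    X, F, F_prev: lists of K channels, each a list of n floats.
    Y: label of length n; W: spatial weights of length n (hypothetical 1-D layout).
    """
    n = len(Y)
    r1 = [0.0] * n  # first-order response: sum_k X^k * F^k
    r2 = [0.0] * n  # second-order response: sum_k X^k * F^k * X^k
    for xk, fk in zip(X, F):
        xf = circ_conv(xk, fk)
        xfx = circ_conv(xf, xk)
        for i in range(n):
            r1[i] += xf[i]
            r2[i] += xfx[i]
    first = 0.5 * sq_norm([r - y for r, y in zip(r1, Y)])
    spatial = 0.5 * sum(sq_norm([w * f for w, f in zip(W, fk)]) for fk in F)
    second = 0.5 * lam * sq_norm([r - y for r, y in zip(r2, Y)])
    temporal = 0.5 * mu * sum(
        sq_norm([a - b for a, b in zip(fk, pk)]) for fk, pk in zip(F, F_prev)
    )
    return first + spatial + second + temporal


# A shifted delta: the first-order term fits Y exactly, the second-order one does not.
X = [[0.0, 1.0, 0.0]]
F = [[1.0, 0.0, 0.0]]
Y = [0.0, 1.0, 0.0]
print(sscf_objective(X, F, Y, W=[0.0, 0.0, 0.0], F_prev=F, lam=1.0, mu=1.0))  # 1.0
```

The example isolates the second-order term. The first-order response reproduces $\mathbf{Y}$, but convolving once more with $\mathbf{X}$ shifts the peak again, so only the $\lambda$ term is non-zero.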

#### *3.2. Optimization Algorithm*

The objective function in Equation (1) is convex, since each term is a squared Frobenius norm of an affine function of $\mathbf{F}$, so the minimization problem can be solved by the alternating direction method of multipliers (ADMM). Specifically, we introduce an auxiliary variable $\mathbb{G} \in \mathbb{R}^{M \times N \times K}$ with the constraint $\mathbb{F} = \mathbb{G}$ and construct the augmented Lagrangian of Equation (1) as:

$$\begin{split} L(\mathbb{F}, \mathbb{G}, \mathbb{S}) &= \frac{1}{2} \left\| \sum\_{k=1}^{K} \mathbf{X}\_{t}^{k} \ast \mathbf{F}^{k} - \mathbf{Y} \right\|\_{F}^{2} + \frac{1}{2} \sum\_{k=1}^{K} \|\mathbf{W} \cdot \mathbf{G}^{k}\|\_{F}^{2} \\ &\quad + \frac{\lambda}{2} \left\| \sum\_{k=1}^{K} \mathbf{X}\_{t}^{k} \ast \mathbf{F}^{k} \ast \mathbf{X}\_{t}^{k} - \mathbf{Y} \right\|\_{F}^{2} + \frac{\mu}{2} \|\mathbb{F} - \mathbb{F}\_{t-1}\|\_{F}^{2} \\ &\quad + \frac{\gamma}{2} \sum\_{k=1}^{K} \|\mathbf{F}^{k} - \mathbf{G}^{k}\|\_{F}^{2} + \sum\_{k=1}^{K} \mathrm{Tr}\left((\mathbf{F}^{k} - \mathbf{G}^{k})^{T} \mathbf{S}^{k}\right) \end{split} \tag{2}$$

where $\mathbb{S} = [\mathbf{S}^{1}, \mathbf{S}^{2}, \cdots, \mathbf{S}^{K}] \in \mathbb{R}^{M \times N \times K}$ is the Lagrange multiplier and $\gamma$ is the penalty parameter (step size). Letting $\mathbb{H} = \frac{1}{\gamma}\mathbb{S}$, Equation (2) can be rewritten, up to a term independent of $\mathbb{F}$ and $\mathbb{G}$, as:

$$\begin{split} L(\mathbb{F}, \mathbb{G}, \mathbb{H}) &= \frac{1}{2} \left\| \sum\_{k=1}^{K} \mathbf{X}\_{t}^{k} \ast \mathbb{F}^{k} - \mathbf{Y} \right\|\_{F}^{2} + \frac{1}{2} \sum\_{k=1}^{K} \|\mathbb{W} \cdot \mathbb{G}^{k}\|\_{F}^{2} \\ &\quad + \frac{\lambda}{2} \left\| \sum\_{k=1}^{K} \mathbf{X}\_{t}^{k} \ast \mathbb{F}^{k} \ast \mathbf{X}\_{t}^{k} - \mathbf{Y} \right\|\_{F}^{2} + \frac{\mu}{2} \|\mathbb{F} - \mathbb{F}\_{t-1}\|\_{F}^{2} \\ &\quad + \frac{\gamma}{2} \sum\_{k=1}^{K} \left\| \mathbb{F}^{k} - \mathbb{G}^{k} + \mathbb{H}^{k} \right\|\_{F}^{2} \end{split} \tag{3}$$
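The passage from Equation (2) to Equation (3) is the standard completing-the-square step of scaled-form ADMM. For each channel, with $\mathbf{H}^{k} = \frac{1}{\gamma}\mathbf{S}^{k}$:

$$\begin{aligned} \frac{\gamma}{2}\left\|\mathbf{F}^{k}-\mathbf{G}^{k}+\mathbf{H}^{k}\right\|\_{F}^{2} &= \frac{\gamma}{2}\left\|\mathbf{F}^{k}-\mathbf{G}^{k}\right\|\_{F}^{2} + \gamma\,\mathrm{Tr}\left((\mathbf{F}^{k}-\mathbf{G}^{k})^{T}\mathbf{H}^{k}\right) + \frac{\gamma}{2}\left\|\mathbf{H}^{k}\right\|\_{F}^{2} \\ &= \frac{\gamma}{2}\left\|\mathbf{F}^{k}-\mathbf{G}^{k}\right\|\_{F}^{2} + \mathrm{Tr}\left((\mathbf{F}^{k}-\mathbf{G}^{k})^{T}\mathbf{S}^{k}\right) + \frac{1}{2\gamma}\left\|\mathbf{S}^{k}\right\|\_{F}^{2} \end{aligned}$$

The last term does not depend on $\mathbb{F}$ or $\mathbb{G}$. It can therefore be dropped when minimizing over these variables, which leaves the last line of Equation (3).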

The optimization problem can be divided into the following subproblems, where the superscript $(l)$ denotes the ADMM iteration index.

$$\begin{aligned} \mathbb{F}^{(l+1)} &= \arg\min\_{\mathbf{F}} \left\| \sum\_{k=1}^{K} \mathbf{X}\_{t}^{k} \ast \mathbf{F}^{k} - \mathbf{Y} \right\|\_{F}^{2} \\ &+ \lambda \left\| \sum\_{k=1}^{K} \mathbf{X}\_{t}^{k} \ast \mathbf{F}^{k} \ast \mathbf{X}\_{t}^{k} - \mathbf{Y} \right\|\_{F}^{2} \\ &+ \gamma \sum\_{k=1}^{K} \left\| \mathbf{F}^{k} - \mathbf{G}^{k} + \mathbf{H}^{k} \right\|\_{F}^{2} + \mu \|\mathbf{F} - \mathbf{F}\_{t-1}\|\_{F}^{2} \end{aligned} \tag{4}$$

$$\mathbb{G}^{(l+1)} = \arg\min\_{\mathbf{G}} \sum\_{k=1}^{K} \left\lVert \mathbf{W} \cdot \mathbf{G}^{k} \right\rVert\_{F}^{2} + \gamma \sum\_{k=1}^{K} \left\lVert \mathbf{F}^{k} - \mathbf{G}^{k} + \mathbf{H}^{k} \right\rVert\_{F}^{2} \tag{5}$$

$$\mathbb{H}^{(l+1)} = \mathbb{H}^{(l)} + \mathbb{F}^{(l+1)} - \mathbb{G}^{(l+1)} \tag{6}$$

Then, we alternately solve each subproblem as follows:

**Solving** F: According to Parseval's theorem, the subproblem in Equation (4) can be formulated in the Fourier domain as:

$$\begin{aligned} \arg\min\_{\mathbf{\hat{F}}} \left\| \sum\_{k=1}^{K} \mathbf{\hat{X}}\_{t}^{k} \cdot \mathbf{\hat{F}}^{k} - \mathbf{\hat{Y}} \right\|\_{F}^{2} + \lambda \left\| \sum\_{k=1}^{K} \mathbf{\hat{X}}\_{t}^{k} \cdot \mathbf{\hat{F}}^{k} \cdot \mathbf{\hat{X}}\_{t}^{k} - \mathbf{\hat{Y}} \right\|\_{F}^{2} \\ + \gamma \sum\_{k=1}^{K} \left\| \mathbf{\hat{F}}^{k} - \mathbf{\hat{G}}^{k} + \mathbf{\hat{H}}^{k} \right\|\_{F}^{2} + \mu \left\| \mathbf{\hat{F}} - \mathbf{\hat{F}}\_{t-1} \right\|\_{F}^{2} \end{aligned} \tag{7}$$

Here, $\hat{\mathbb{F}}$ denotes the discrete Fourier transform (DFT) of $\mathbb{F}$. From Equation (7), the $(i,j)$-th element of $\hat{\mathbf{Y}}$ depends only on the $(i,j)$-th elements of $\hat{\mathbb{F}}$ and $\hat{\mathbb{X}}\_{t}$ across all $K$ channels. Let $v\_{ij}(\mathbb{F})$ denote the $K$-dimensional vector containing the $(i,j)$-th elements of $\mathbb{F}$ along all $K$ channels. Optimizing Equation (7) is then equivalent to solving the following $MN$ independent subproblems:
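The per-position decomposition relies on the convolution theorem. After the DFT, the multi-channel circular convolution becomes a sum of element-wise products, so each frequency only involves $v\_{ij}(\hat{\mathbb{X}}\_{t})$ and $v\_{ij}(\hat{\mathbb{F}})$. The sketch below checks this numerically for 1-D signals, which is an assumption made to keep it short; the 2-D DFT is separable, so the same identity applies along each axis. It uses a naive $O(n^2)$ DFT rather than an FFT, and the helper names are hypothetical.

```python
import cmath


def dft(v):
    """Naive DFT: V[k] = sum_m v[m] * exp(-2*pi*i*k*m/n)."""
    n = len(v)
    return [sum(v[m] * cmath.exp(-2j * cmath.pi * k * m / n) for m in range(n))
            for k in range(n)]


def circ_conv(x, f):
    """Circular convolution of two equal-length sequences."""
    n = len(x)
    return [sum(x[m] * f[(i - m) % n] for m in range(n)) for i in range(n)]


def per_frequency_responses(X, F):
    """First- and second-order responses in the Fourier domain, one frequency at a time.

    At frequency i the first-order response is v_i(X_hat)^T v_i(F_hat), and the
    second-order response is (v_i(X_hat) . v_i(X_hat))^T v_i(F_hat).
    """
    Xh = [dft(x) for x in X]
    Fh = [dft(f) for f in F]
    n = len(X[0])
    first = [sum(xh[i] * fh[i] for xh, fh in zip(Xh, Fh)) for i in range(n)]
    second = [sum(xh[i] * xh[i] * fh[i] for xh, fh in zip(Xh, Fh)) for i in range(n)]
    return first, second


X = [[1.0, 2.0, 0.5, -1.0], [0.0, 1.0, 3.0, 2.0]]  # K = 2 channels, n = 4
F = [[0.5, -1.0, 2.0, 0.0], [1.0, 0.0, 0.0, 1.0]]

# Spatial-domain responses sum_k X^k * F^k and sum_k X^k * F^k * X^k.
conv1 = [circ_conv(x, f) for x, f in zip(X, F)]
conv2 = [circ_conv(c, x) for c, x in zip(conv1, X)]
r1 = [sum(c[i] for c in conv1) for i in range(4)]
r2 = [sum(c[i] for c in conv2) for i in range(4)]

first, second = per_frequency_responses(X, F)
err1 = max(abs(a - b) for a, b in zip(dft(r1), first))
err2 = max(abs(a - b) for a, b in zip(dft(r2), second))
print(err1 < 1e-9, err2 < 1e-9)  # True True
```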

$$\begin{aligned} \arg\min\_{v\_{ij}(\hat{\mathbb{F}})} & \| v\_{ij}(\hat{\mathbb{X}}\_{t})^{T} v\_{ij}(\hat{\mathbb{F}}) - \hat{y}\_{ij} \|\_{2}^{2} + \mu \| v\_{ij}(\hat{\mathbb{F}}) - v\_{ij}(\hat{\mathbb{F}}\_{t-1}) \|\_{2}^{2} \\ & + \lambda \| (v\_{ij}(\hat{\mathbb{X}}\_{t}) \cdot v\_{ij}(\hat{\mathbb{X}}\_{t}))^{T} v\_{ij}(\hat{\mathbb{F}}) - \hat{y}\_{ij} \|\_{2}^{2} \\ & + \gamma \| v\_{ij}(\hat{\mathbb{F}}) - v\_{ij}(\hat{\mathbb{G}}) + v\_{ij}(\hat{\mathbb{H}}) \|\_{2}^{2} \end{aligned} \tag{8}$$

where *i* = 1, ··· , *M* and *j* = 1, ··· , *N*.

Setting the derivative of Equation (8) with respect to $v\_{ij}(\hat{\mathbb{F}})$ to zero, we obtain:

$$v\_{ij}(\hat{\mathbb{F}}) = \left(\mathbf{Q} + (\gamma + \mu)\mathbf{I}\right)^{-1}\mathbf{z} \tag{9}$$

Here, $\mathbf{Q} = v\_{ij}(\hat{\mathbb{X}}\_{t})v\_{ij}(\hat{\mathbb{X}}\_{t})^{T} + \lambda\left(v\_{ij}(\hat{\mathbb{X}}\_{t}) \cdot v\_{ij}(\hat{\mathbb{X}}\_{t})\right)\left(v\_{ij}(\hat{\mathbb{X}}\_{t}) \cdot v\_{ij}(\hat{\mathbb{X}}\_{t})\right)^{T}$ and $\mathbf{z} = v\_{ij}(\hat{\mathbb{X}}\_{t})\hat{y}\_{ij} + \mu v\_{ij}(\hat{\mathbb{F}}\_{t-1}) + \lambda\left(v\_{ij}(\hat{\mathbb{X}}\_{t}) \cdot v\_{ij}(\hat{\mathbb{X}}\_{t})\right)\hat{y}\_{ij} + \gamma v\_{ij}(\hat{\mathbb{G}}) - \gamma v\_{ij}(\hat{\mathbb{H}})$, where $\hat{y}\_{ij}$ is the $(i,j)$-th element of $\hat{\mathbf{Y}}$.
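Equation (9) is a $K \times K$ linear solve at each position. The sketch below implements it for one position using real-valued vectors. This is a simplifying assumption: with complex Fourier coefficients, the outer products in $\mathbf{Q}$ and the vectors that multiply $\hat{y}\_{ij}$ would normally carry complex conjugates. The sketch then checks that the gradient of Equation (8) vanishes at the solution. Note that the $\lambda$ term of $\mathbf{z}$ must be multiplied by $\hat{y}\_{ij}$ for this to hold. The function names and the small Gaussian-elimination solver are hypothetical helpers, not part of the original method.

```python
def solve(A, b):
    """Solve A x = b by Gaussian elimination with partial pivoting (small dense A)."""
    n = len(b)
    M = [list(row) + [b[i]] for i, row in enumerate(A)]
    for c in range(n):
        piv = max(range(c, n), key=lambda r: abs(M[r][c]))
        M[c], M[piv] = M[piv], M[c]
        for r in range(c + 1, n):
            factor = M[r][c] / M[c][c]
            for j in range(c, n + 1):
                M[r][j] -= factor * M[c][j]
    x = [0.0] * n
    for r in range(n - 1, -1, -1):
        x[r] = (M[r][n] - sum(M[r][j] * x[j] for j in range(r + 1, n))) / M[r][r]
    return x


def f_step(a, y, f_prev, g, h, lam, mu, gamma):
    """Eq. (9) at one position: a = v_ij(X_t), y = y_ij (real-valued sketch)."""
    K = len(a)
    b = [ak * ak for ak in a]  # Hadamard product v_ij(X_t) . v_ij(X_t)
    A = [[a[p] * a[q] + lam * b[p] * b[q] + ((gamma + mu) if p == q else 0.0)
          for q in range(K)] for p in range(K)]  # Q + (gamma + mu) I
    z = [a[p] * y + mu * f_prev[p] + lam * b[p] * y + gamma * g[p] - gamma * h[p]
         for p in range(K)]
    return solve(A, z)


def grad_eq8(f, a, y, f_prev, g, h, lam, mu, gamma):
    """Half-gradient of Eq. (8) with respect to f; it should be ~0 at the minimizer."""
    b = [ak * ak for ak in a]
    ra = sum(ak * fk for ak, fk in zip(a, f)) - y
    rb = sum(bk * fk for bk, fk in zip(b, f)) - y
    return [a[p] * ra + lam * b[p] * rb + mu * (f[p] - f_prev[p])
            + gamma * (f[p] - g[p] + h[p]) for p in range(len(f))]


args = dict(a=[1.0, 0.5, -2.0], y=1.0, f_prev=[0.1, 0.2, 0.3],
            g=[0.0, 0.5, -0.5], h=[0.05, -0.05, 0.0], lam=0.1, mu=1.0, gamma=1.0)
f = f_step(**args)
print(max(abs(d) for d in grad_eq8(f, **args)))  # close to 0
```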

**Solving** G: From Equation (5), each element of $\mathbb{G}$ can be updated independently, so we adopt the same strategy as for $\mathbb{F}$. Let $v\_{ij}(\mathbb{G})$ be the $K$-dimensional vector containing the $(i,j)$-th elements of $\mathbb{G}$ along all $K$ channels. Optimizing Equation (5) is then equivalent to solving the following $MN$ subproblems:

$$\arg\min\_{v\_{ij}(\mathbb{G})} w\_{ij}^2 \parallel v\_{ij}(\mathbb{G}) \parallel\_2^2 + \gamma \parallel v\_{ij}(\mathbb{F}) - v\_{ij}(\mathbb{G}) + v\_{ij}(\mathbb{H}) \parallel\_2^2 \tag{10}$$

Setting the derivative of Equation (10) with respect to $v\_{ij}(\mathbb{G})$ to zero, we obtain:

$$v\_{ij}(\mathbb{G}) = (\mathbf{P}^T \mathbf{P} + \gamma \mathbf{I})^{-1} (\gamma v\_{ij}(\mathbb{F}) + \gamma v\_{ij}(\mathbb{H})) \tag{11}$$

where **P** is a diagonal matrix and each diagonal element is *wij*.
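Because $\mathbf{P} = w\_{ij}\mathbf{I}$, the inverse in Equation (11) reduces to a scalar, so the $\mathbb{G}$ update is an element-wise shrinkage. Positions with large spatial weights $w\_{ij}$ are pushed toward zero; with SRDCF-style weighting these are typically background positions. A minimal sketch for one position, with hypothetical names:

```python
def g_step(f, h, w, gamma):
    """Eq. (11) at one position: v_ij(G) = gamma * (v_ij(F) + v_ij(H)) / (w_ij^2 + gamma)."""
    scale = gamma / (w * w + gamma)
    return [scale * (fk + hk) for fk, hk in zip(f, h)]


f = [1.0, 2.0]
h = [0.0, -1.0]
print(g_step(f, h, w=0.0, gamma=1.0))  # [1.0, 1.0]: no penalty, G = F + H
print(g_step(f, h, w=1.0, gamma=1.0))  # [0.5, 0.5]
print(g_step(f, h, w=3.0, gamma=1.0))  # [0.1, 0.1]: strong shrinkage
```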

**Updating** H: Let $v\_{ij}(\mathbb{H})$ be the $K$-dimensional vector containing the $(i,j)$-th elements of $\mathbb{H}$ along all $K$ channels. In the $(l+1)$-th iteration of the ADMM, the scaled Lagrange multiplier vector $v\_{ij}(\mathbb{H})$ is updated as follows:

$$v\_{ij}(\mathbb{H})^{(l+1)} = v\_{ij}(\mathbb{H})^{(l)} + v\_{ij}(\mathbb{F})^{(l+1)} - v\_{ij}(\mathbb{G})^{(l+1)} \tag{12}$$

The details of the optimization procedure can be seen in Algorithm 1.


**Input**: Feature maps $\mathbb{X}\_{t}$, Gaussian-shaped label $\mathbf{Y}$, previous correlation filters $\mathbb{F}\_{t-1}$, spatial regularization matrix $\mathbf{W}$, initial values $\mathbb{G}^{(0)}$ and $\mathbb{H}^{(0)}$.

**Output**: Estimated correlation filters $\mathbb{F}$.

1: **repeat** Step 2–Step 5

2: Update $v\_{ij}(\hat{\mathbb{F}})^{(l+1)}$ via Equation (9);

3: Update $v\_{ij}(\mathbb{G})^{(l+1)}$ via Equation (11);

4: Update $v\_{ij}(\mathbb{H})^{(l+1)}$ via Equation (12);

5: $l = l + 1$;
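As a sanity check of Algorithm 1, the toy sketch below runs the three updates at a single position and compares the result with the direct minimizer of the unsplit problem. It deliberately ignores the DFT and treats $v\_{ij}(\hat{\mathbb{F}})$ and $v\_{ij}(\mathbb{F})$ as the same real-valued vector, so it only illustrates the ADMM mechanics, not the full tracker. The fixed iteration count is an assumption, since the stopping rule is not given in the text above, and all names are hypothetical.

```python
def solve(A, b):
    """Solve A x = b by Gaussian elimination with partial pivoting (small dense A)."""
    n = len(b)
    M = [list(row) + [b[i]] for i, row in enumerate(A)]
    for c in range(n):
        piv = max(range(c, n), key=lambda r: abs(M[r][c]))
        M[c], M[piv] = M[piv], M[c]
        for r in range(c + 1, n):
            factor = M[r][c] / M[c][c]
            for j in range(c, n + 1):
                M[r][j] -= factor * M[c][j]
    x = [0.0] * n
    for r in range(n - 1, -1, -1):
        x[r] = (M[r][n] - sum(M[r][j] * x[j] for j in range(r + 1, n))) / M[r][r]
    return x


def admm_single_position(a, y, f_prev, w, lam, mu, gamma, iters=200):
    """Steps 2-4 of Algorithm 1 at one position, starting from G^(0) = H^(0) = 0."""
    K = len(a)
    b = [ak * ak for ak in a]
    A = [[a[p] * a[q] + lam * b[p] * b[q] + ((gamma + mu) if p == q else 0.0)
          for q in range(K)] for p in range(K)]  # Q + (gamma + mu) I
    g = [0.0] * K
    h = [0.0] * K
    for _ in range(iters):
        z = [a[p] * y + mu * f_prev[p] + lam * b[p] * y + gamma * (g[p] - h[p])
             for p in range(K)]
        f = solve(A, z)                                                    # Eq. (9)
        g = [gamma * (fk + hk) / (w * w + gamma) for fk, hk in zip(f, h)]  # Eq. (11)
        h = [hk + fk - gk for hk, fk, gk in zip(h, f, g)]                  # Eq. (12)
    return f, g


def direct_solution(a, y, f_prev, w, lam, mu):
    """Minimizer of the unsplit problem: (Q + (mu + w^2) I) f = a*y + lam*b*y + mu*f_prev."""
    K = len(a)
    b = [ak * ak for ak in a]
    A = [[a[p] * a[q] + lam * b[p] * b[q] + ((mu + w * w) if p == q else 0.0)
          for q in range(K)] for p in range(K)]
    rhs = [a[p] * y + lam * b[p] * y + mu * f_prev[p] for p in range(K)]
    return solve(A, rhs)


params = dict(a=[1.0, 0.5, -2.0], y=1.0, f_prev=[0.1, 0.2, 0.3], w=1.0, lam=0.1, mu=1.0)
f_admm, g_admm = admm_single_position(gamma=1.0, **params)
f_star = direct_solution(**params)
print(max(abs(p - q) for p, q in zip(f_admm, f_star)))  # close to 0
```

At convergence $\mathbb{F} = \mathbb{G}$. The optimality conditions of the two subproblems then add up to those of the unsplit problem, which is why the two solutions agree.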


#### *3.3. Computational Complexity*

In this subsection, we discuss the computational complexity of SSCF. As shown in Section 3.2, the optimization problem is divided into several subproblems. By Parseval's theorem and the ADMM decomposition, solving **F** costs *O*(*KMN*) per iteration; taking the DFT and inverse DFT into account, its complexity becomes *O*(*KMN*log(*MN*)). The subproblems for **H** and **G** cost *O*(*KMN*). With *T* iterations, the overall computational complexity of SSCF is *O*(*TKMN*(log(*MN*) + 1)). Consequently, the tracker is not particularly fast.
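To give a feel for the bound *O*(*TKMN*(log(*MN*) + 1)), the sketch below evaluates it as an operation-count proxy. Constant factors are ignored, and the logarithm is taken base 2, an assumption typical of radix-2 FFT cost estimates. The example sizes are hypothetical and are not values reported in the paper: *K* = 31 HOG-style channels, a 64 × 64 feature grid, and *T* = 2 iterations.

```python
import math


def sscf_cost(T, K, M, N):
    """Operation-count proxy for O(T*K*M*N*(log(MN) + 1)), using log base 2."""
    return T * K * M * N * (math.log2(M * N) + 1)


base = sscf_cost(T=2, K=31, M=64, N=64)
print(base)  # 3301376.0
# Doubling both spatial sizes multiplies the cost by 4 * (14 + 1) / (12 + 1), about 4.6.
print(sscf_cost(T=2, K=31, M=128, N=128) / base)
```

The ratio shows that the cost grows slightly faster than the number of feature positions, because the FFT term adds a logarithmic factor.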

#### **4. Experiment Results and Analysis**

This section presents experiments that validate the superiority of the proposed SSCF for target tracking. To evaluate the proposed model, we compared it with state-of-the-art trackers, including spatially regularized discriminative correlation filters (SRDCFs) [23], kernelized correlation filters (KCFs) [47], spatial–temporal regularized correlation filters (STRCFs) [24], background-aware correlation filters (BACFs) [25], learning adaptive discriminative correlation filters (LADCFs) [26], discriminative scale space tracking (DSST) [48], the scale-adaptive with multiple features tracker (SAMF) [12], ECOHC [49], ARCF-HC [50], the MSCF [51], and AutoTrack [52]. These experiments were conducted on the CVPR2013 [53], OTB50 [54], OTB100 [54], DTB70 [55], UAV123 [56], and UAVDT-M [57] databases.

In the experiments, our tracker was implemented in MATLAB R2017a on a computer with an i7-8700K processor (3.7 GHz) and 48 GB of RAM. *λ* was set to 10<sup>−5</sup>, and the other parameters were set to the same values as in STRCF. Histogram of oriented gradients (HOG) features were used in the comparative experiments. In addition, we followed the one-pass evaluation (OPE) protocol [53] to evaluate the different trackers. The success and precision plots are based on the bounding box overlap and the center location error, respectively. The AUC is the area under the success plot, and the distance precision (DP) is the percentage of frames whose center location error is within 20 pixels.
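The two OPE summary scores can be computed from per-frame boxes as sketched below. The sketch makes several assumptions: boxes are in $(x, y, w, h)$ format, there are 21 overlap thresholds evenly spaced in $[0, 1]$, a frame counts as a success when its overlap exceeds the threshold, and the AUC is the mean success rate over those thresholds. These choices follow the common OTB toolkit convention but are assumptions here, as are the function names.

```python
import math


def iou(a, b):
    """Intersection-over-union of two boxes given as (x, y, w, h)."""
    iw = max(0.0, min(a[0] + a[2], b[0] + b[2]) - max(a[0], b[0]))
    ih = max(0.0, min(a[1] + a[3], b[1] + b[3]) - max(a[1], b[1]))
    inter = iw * ih
    union = a[2] * a[3] + b[2] * b[3] - inter
    return inter / union if union > 0 else 0.0


def center_error(a, b):
    """Euclidean distance between box centers, in pixels."""
    dx = (a[0] + a[2] / 2) - (b[0] + b[2] / 2)
    dy = (a[1] + a[3] / 2) - (b[1] + b[3] / 2)
    return math.hypot(dx, dy)


def auc_and_dp(pred, gt, dp_threshold=20.0, steps=21):
    """Return (AUC of the success plot, DP at dp_threshold pixels)."""
    n = len(gt)
    overlaps = [iou(p, g) for p, g in zip(pred, gt)]
    errors = [center_error(p, g) for p, g in zip(pred, gt)]
    thresholds = [i / (steps - 1) for i in range(steps)]
    success = [sum(o > t for o in overlaps) / n for t in thresholds]
    auc = sum(success) / steps
    dp = sum(e <= dp_threshold for e in errors) / n
    return auc, dp


gt = [(0, 0, 10, 10), (0, 0, 10, 10)]
pred = [(0, 0, 10, 10), (100, 100, 10, 10)]  # one perfect frame, one lost frame
print(auc_and_dp(pred, gt))
```

For the example, the lost frame contributes nothing to either score. DP is therefore 0.5, and the AUC is 10/21 because the strict comparison at threshold 1.0 excludes even the perfect frame.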

#### *4.1. Results on the CVPR2013 Database*

The CVPR2013 database contains 50 fully annotated video sequences with 11 different attributes, such as background clutter, low resolution, occlusion, and out of view. The overall performance, summarized by the success and precision plots, is shown in Figure 1. The proposed SSCF achieved the top-ranking results, with AUC and DP scores of 0.681 and 0.882, respectively. Specifically, the AUC and DP scores of SSCF were 1.2% and 0.9% higher than those of STRCF, which suggests that incorporating the second-order data-fitting term improves tracking performance.

**Figure 1.** Success plots (**a**) and precision plots (**b**) of the proposed SSCF and other trackers on the CVPR2013 database.

To evaluate the robustness of SSCF under different attributes, we constructed subsets with different dominant attributes. The 11 challenging factors are background clutter (BC), low resolution (LR), illumination variation (IV), motion blur (MB), out of view (OV), fast motion (FM), deformation (DEF), occlusion (OCC), out-of-plane rotation (OPR), scale variation (SV), and in-plane rotation (IPR). Table 1 shows the AUC and DP scores of SSCF and the other trackers on these 11 attributes of the CVPR2013 database. Although not all scores of SSCF are the highest, our method achieved the best overall robustness. In particular, in terms of AUC scores across the attributes, SSCF outperformed all other trackers except LADCF.

#### *4.2. Results on the OTB100 Database*

OTB100 is a database containing 100 challenging video sequences, and these sequences consist of more than 28,000 fully annotated frames. The results of the success and precision plots for all trackers are shown in Figure 2. From the figure, the proposed SSCF outperformed all the competing trackers in its overall performance. Our tracker achieved 0.664 and 0.868 in terms of the AUC and DP scores, respectively.

We also provide an attribute-based evaluation to validate the robustness of SSCF. The AUC and DP scores of all trackers on the 11 attributes are reported in Table 2. In terms of DP scores, SSCF outperformed all competing trackers on eight attributes; in terms of AUC scores, it performed better than the other trackers on seven attributes. On the remaining attributes, SSCF ranked among the top three trackers. These results demonstrate that SSCF is more robust than the other trackers.

**Table 1.** The area under the curve (AUC) and distance precision (DP) scores (AUC/DP) of the proposed SSCF and the other trackers on different attributes of the CVPR2013 database. The top three methods on each attribute are shown in red (best), blue (second best), and green (third best).


**Figure 2.** Success plots (**a**) and precision plots (**b**) of the proposed SSCF and the other trackers on the OTB100 database.

**Table 2.** The area under the curve (AUC) and distance precision (DP) scores (AUC/DP) of the proposed SSCF and the other trackers on different attributes of the OTB100 database. The top three methods on each attribute are shown in red (best), blue (second best), and green (third best).


#### *4.3. Results on the OTB50 Database*

Figure 3 shows the success plots of the proposed method and the existing trackers on OTB50. The overall performance is summarized in Figure 3a, where the proposed SSCF had the best success rate. The success plots of all trackers on the 11 attributes are shown in Figure 3b–l. The SSCF outperformed the existing trackers on eight attributes, i.e., fast motion, background clutter, motion blur, illumination variation, in-plane rotation, occlusion, out-of-plane rotation, and out of view, and its results on the other three attributes were among the top two. We attribute this to the incorporation of second-order data fitting and spatial–temporal regularization into the DCF framework, which yields a robust tracking model. These results further demonstrate the effectiveness and robustness of our tracker.

**Figure 3.** Success plots of the proposed SSCF and the other trackers on the OTB50 database. (**a**) Overall performance; (**b**–**l**) success plots on the 11 different attributes.

#### *4.4. Results on the DTB70 Database*

Figures 4 and 5 show the success and precision plots of the proposed method and the existing trackers on the DTB70 database. The overall performance is summarized in Figures 4a and 5a; our SSCF achieved the best overall results. The success and precision plots of all trackers on the 11 attributes are shown in Figures 4b–l and 5b–l. The SSCF outperformed the existing trackers on nine attributes, that is, on all attributes except motion blur and low resolution.

**Figure 4.** Success plots of the proposed SSCF and the other trackers on the DTB70 database. (**a**) Overall performance; (**b**–**l**) success plots on the 11 different attributes.

**Figure 5.** Precision plots of the proposed SSCF and the other trackers on the DTB70 database. (**a**) Overall performance; (**b**–**l**) precision plots on the 11 different attributes.

#### *4.5. Results on the UAV123 Database*

The UAV123 dataset contains 123 video sequences and is one of the most commonly used and comprehensive datasets for UAV tracking. The overall performance, summarized by the success and precision plots, is shown in Figure 6. The proposed SSCF achieved the top-ranking results, with area under the curve (AUC) and distance precision (DP) scores of 0.479 and 0.676, respectively.

To visually show the performance of the proposed SSCF during tracking, we selected three types of video sequences, namely person, boat, and car sequences. As shown in Figure 7, three randomly selected frames are displayed for each sequence. Five trackers were compared: our SSCF, AutoTrack, the MSCF, the STRCF, and the LADCF, marked in green, red, blue, yellow, and orange, respectively. Our SSCF consistently tracked the correct target and performed best, whereas the STRCF and LADCF were not robust when tracking small targets.

**Figure 6.** Success plots (**a**) and precision plots (**b**) of the proposed SSCF and the other trackers on the UAV123 database.

**Figure 7.** The qualitative analysis of different trackers on three video sequences.

#### *4.6. Results on the UAVDT-M Database*

In this section, we compare our SSCF with the existing methods on the UAVDT-M database and also report their running speeds, measured in frames per second (FPS). Table 3 shows the comparison results. Our SSCF achieved better performance than the existing trackers, with AUC and DP scores of 0.667 and 0.928, respectively. However, it should be noted that the performance improvement of our tracker came at the expense of reduced speed.

**Table 3.** The area under the curve (AUC), distance precision (DP) scores, and FPS of the proposed SSCF and other trackers on the UAVDT-M database.


#### **5. Conclusions**

In this paper, we proposed a new model called the second-order spatial–temporal correlation filter (SSCF) for visual object tracking. The SSCF is a DCF framework that combines a second-order data-fitting term with spatial–temporal regularization. To solve the proposed model, we used the ADMM algorithm, which divides the optimization problem into several subproblems that are solved in turn. By taking full advantage of second-order data-fitting information, the SSCF becomes more discriminative and robust in complex tracking situations. Extensive experiments on benchmark databases demonstrated that the SSCF achieves competitive performance compared with state-of-the-art trackers.

It should be noted that although the SSCF achieved better tracking results than the existing trackers on most attributes, it was less robust on a few attributes, such as low resolution and occlusion. Recently, occlusion-processing methods such as occlusion dictionary learning [58,59] and the occlusion-invariant model [60] have been proposed for face recognition. Can these methods be applied to object tracking under occlusion? If so, how can a new model be designed to enhance tracking performance? In addition, the performance improvement of our tracker came at the expense of running speed, and improving the speed of the SSCF is an important problem. Finally, although the SSCF outperformed the existing methods overall, its accuracy was not high when tracking small targets. Self-paced learning has been widely used in computer vision and machine learning [61], and combining self-paced learning with filter learning could potentially improve the tracking of small targets. We will focus on these topics in future work.

**Author Contributions:** Conceptualization, Y.Y. and G.X.; methodology, Y.Y. and G.X.; software, H.H. and J.L.; validation, L.C., H.H., W.Z. and G.X.; writing—original draft preparation, Y.Y.; writing—review and editing, Y.Y., L.C., J.L., W.Z. and G.X.; supervision, L.C. and G.X.; funding acquisition, L.C. All authors have read and agreed to the published version of the manuscript.

**Funding:** This work was supported in part by the Science and Technology Development Fund, Macau SAR (File no. 0119/2018/A3), in part by the National Natural Science Foundation of China under Grant 62006056, in part by the Natural Science Foundation of Guangdong Province under Grant 2019A1515011266, in part by the National Statistical Science Research Project of China under Grant 2020LY090, and in part by the Science and Technology Planning Project of Guangzhou under Grant 202102020699.

**Institutional Review Board Statement:** Not applicable.

**Informed Consent Statement:** Not applicable.

**Acknowledgments:** We greatly thank the Reviewers and Editors for the insightful comments and suggestions.

**Conflicts of Interest:** The authors declare no conflict of interest.

#### **References**


### *Article* **A RUL Prediction Method of Small Sample Equipment Based on DCNN-BiLSTM and Domain Adaptation**

**Wenbai Chen <sup>1,\*</sup>, Weizhao Chen <sup>1</sup>, Huixiang Liu <sup>1</sup>, Yiqun Wang <sup>1</sup>, Chunli Bi <sup>2</sup> and Yu Gu <sup>3,4,5</sup>**


**Abstract:** To address the low accuracy of remaining useful life (RUL) prediction caused by insufficient equipment sample data under complex operating conditions, an RUL prediction method for small-sample equipment based on a deep convolutional neural network-bidirectional long short-term memory network (DCNN-BiLSTM) and domain adaptation is proposed. First, to extract common equipment features with the help of sufficiently sampled data, a network model combining the deep convolutional neural network (DCNN) and the bidirectional long short-term memory network (BiLSTM) was trained on the source domain and target domain data simultaneously. The Maximum Mean Discrepancy (MMD) was used to constrain the distribution difference and achieve adaptive matching and feature alignment between the target domain samples and the source domain samples. After the pre-trained model was obtained, fine-tuning was used to transfer its network structure and parameters to the target domain for further training and network optimization, finally yielding an RUL prediction model better suited to the target domain data. The method was validated on the Commercial Modular Aero-Propulsion System Simulation (C-MAPSS) dataset provided by NASA, and the experimental results show that it improves the accuracy and generalization ability of equipment RUL prediction under cross-working-condition and small-sample conditions.

**Keywords:** DCNN-BiLSTM; domain adaptation; MMD; fine-tuning; C-MAPSS; cross-working; small sample

### **1. Introduction**

As one of the key technologies of Prognostics and Health Management (PHM), RUL prediction has become an important research topic. RUL refers to the length of continuous working time of an equipment component or system from the current moment to the moment when it can no longer perform a specific function [1]. Accurate RUL prediction plays a crucial role in guaranteeing system reliability and preventing system failures [2].

At present, the widely studied equipment RUL prediction methods can be divided into physical model-based methods and data-driven methods [3]. Because some systems have complex structures, diverse failure modes, and uncertain operating conditions, it is difficult to establish a physical failure model [4]. Data-driven methods, which require neither prior knowledge nor a complex physical modeling process [5], have therefore become a research hotspot in recent years. Among them, deep learning has attracted much attention due to its powerful nonlinear mapping and high-dimensional feature extraction abilities [6]. Babu et al. [7] were the first to apply a Convolutional Neural Network (CNN) to the RUL prediction of aero-engines; their model automatically extracts multi-dimensional sensor features and obtained better results than shallow regression models. Zheng et al. [8] proposed a prediction model based on a Long Short-Term Memory (LSTM) network, which can extract time-series features and is suitable for RUL prediction of most equipment.

**Citation:** Chen, W.; Chen, W.; Liu, H.; Wang, Y.; Bi, C.; Gu, Y. A RUL Prediction Method of Small Sample Equipment Based on DCNN-BiLSTM and Domain Adaptation. *Mathematics* **2022**, *10*, 1022. https://doi.org/10.3390/math10071022

Academic Editors: Jianping Gou, Weihua Ou, Shaoning Zeng and Lan Du

Received: 16 February 2022 Accepted: 19 March 2022 Published: 23 March 2022

**Publisher's Note:** MDPI stays neutral with regard to jurisdictional claims in published maps and institutional affiliations.

**Copyright:** © 2022 by the authors. Licensee MDPI, Basel, Switzerland. This article is an open access article distributed under the terms and conditions of the Creative Commons Attribution (CC BY) license (https://creativecommons.org/licenses/by/4.0/).

The premise of conventional data-driven methods is that the training and test data come from the same operating conditions. Transfer learning, a newer machine learning paradigm, relaxes the requirement that training and test samples follow the same data distribution: knowledge learned from a source domain is applied to a different but related target domain, which addresses the scarcity of labeled samples in the target domain. To a certain extent, transfer learning improves the generalization ability of machine learning models [9]. When the feature spaces and data distributions of the source and target domain samples differ considerably, how to use transfer learning to solve the small-sample problem becomes a research focus.

Domain adaptation is an important research direction in transfer learning. It addresses transfer problems in which the two domains share the same feature space and category space but have different feature distributions. Domain adaptation methods have already been applied to equipment RUL prediction. Fu et al. [10] proposed a domain adaptation SAE-LSTM model that adopted MMD to reduce the data distribution difference in RUL prediction. Li et al. [11] first proposed a convolutional neural network model based on multi-kernel MMD. Ragab [12] proposed a Contrastive Adversarial Domain Adaptation (CADA) method that learns similar features across domains and improves RUL prediction accuracy and noise immunity. Miao [13] proposed a Deep Domain Adaptative Network (DDAN) to handle the cross-domain shift in feature distributions under different operating conditions and failure modes. Costa et al. [14] proposed a domain adaptation method for RUL prediction under cross-working conditions based on LSTM and a Domain Adversarial Neural Network (DANN). To address the low RUL prediction accuracy caused by small-sample datasets, Lv et al. [15] proposed a Sequence Adaptation Adversarial Network (SAAN) to expand the dataset.

Traditional deep learning relies heavily on labeled data. To address the loss of RUL prediction accuracy caused by small-sample equipment status data collected under different working conditions, this paper proposes a small-sample equipment RUL prediction method based on DCNN-BiLSTM and domain adaptation. The method includes a pre-training stage, a parameter-transfer stage, and an RUL prediction stage. Pre-training with MMD constraints reduces the distribution differences of sample data under different working conditions, so that the model learns characteristics common to the source and target domain samples after domain adaptation. The pre-trained model is then transferred to the target domain and fine-tuned to obtain an RUL prediction model better suited to the target domain task. Finally, the Commercial Modular Aero-Propulsion System Simulation (C-MAPSS) dataset provided by NASA is used to verify the effectiveness of the proposed method.

#### **2. Related Work**

#### *2.1. CNN Convolution Model*

The CNN has powerful parameter learning and feature extraction capabilities and can process multi-dimensional matrix data. In practical engineering applications, each device is monitored by multiple sensors, and the collected data contain rich information about its operating status. To extract deeper features, this paper uses a DCNN consisting of multiple stacked CNN layers.

Because equipment degradation data are time series collected by sensors, the input data in this study form a two-dimensional matrix: one dimension corresponds to the features collected by the sensors and the other to the time series of each feature. After time-window processing, each sample has size $(N\_w, m)$, where $N\_w$ is the size of the time window and $m$ is the number of features. Each convolutional layer convolves the input along the time direction with kernels of different sizes to extract different features from the data, and the resulting local feature maps are finally combined and used as the input of the BiLSTM.
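The sketch below makes the $(N\_w, m)$ input concrete by applying convolutional layers along the time axis of one sample in NumPy. It assumes each kernel spans k consecutive time steps and all m features (like a 1-D convolution with the features as input channels), no padding, stride 1, and a ReLU activation; the text does not specify these choices, and the kernel sizes and filter counts are illustrative.

```python
import numpy as np


def conv_time(sample: np.ndarray, kernels: np.ndarray, bias: np.ndarray) -> np.ndarray:
    """Valid convolution (cross-correlation, as in deep learning libraries) along time.

    sample:  (Nw, m)   one time-window sample, rows are time steps
    kernels: (C, k, m) C filters, each spanning k consecutive steps and all m features
    bias:    (C,)
    returns: (Nw - k + 1, C) feature map after a ReLU activation
    """
    n_w, m = sample.shape
    n_filters, k, m_k = kernels.shape
    if m != m_k:
        raise ValueError("kernel feature dimension must match the sample")
    out = np.empty((n_w - k + 1, n_filters))
    for t in range(n_w - k + 1):
        window = sample[t:t + k]                                   # (k, m) slice
        out[t] = np.tensordot(kernels, window, axes=([1, 2], [0, 1])) + bias
    return np.maximum(out, 0.0)                                    # ReLU


# Two stacked layers with different kernel sizes on a (30, 14) sample (sizes illustrative).
rng = np.random.default_rng(0)
x = rng.standard_normal((30, 14))
h1 = conv_time(x, rng.standard_normal((10, 5, 14)), np.zeros(10))   # shape (26, 10)
h2 = conv_time(h1, rng.standard_normal((8, 3, 10)), np.zeros(8))    # shape (24, 8)
```

Each valid convolution shortens the time axis by k − 1 steps, so stacking layers with different kernel sizes gradually condenses the window before it reaches the BiLSTM.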

#### *2.2. BiLSTM Network Model*

The LSTM model is designed to process sequence data. Compared with the Recurrent Neural Network (RNN), the LSTM mainly alleviates the vanishing and exploding gradient problems that arise when training on long sequences. An LSTM model consists of an input layer, a hidden layer, and an output layer; its hidden units contain a memory cell and three gating units (the input gate, forget gate, and output gate), which regulate how historical information flows through the cell [16]. As a result, both long- and short-term dependencies in time series can be learned more effectively.

As shown in Figure 1, $i\_t$, $o\_t$, and $f\_t$ represent the input gate, output gate, and forget gate, respectively. The forget gate decides how much of the previous cell state $C\_{t-1}$ is retained; the input gate controls how new information is written into the long-term memory of the cell state; and the output gate controls the output of the current LSTM unit. $\tilde{C}\_t$ represents the current candidate (temporary) memory; $x\_t$ represents the input at time $t$; $h\_{t-1}$ represents the output at the previous time step; and $h\_t$ represents the output at the current time step. The state of each gate in the forward propagation of the LSTM is computed as follows:

$$\tilde{C}\_t = \tanh\left(W\_{xc} x\_t + W\_{hc} h\_{t-1} + b\_c\right) \tag{1}$$

$$C\_t = f\_t C\_{t-1} + i\_t \tilde{C}\_t \tag{2}$$

$$i\_t = \sigma\left(W\_{xi} x\_t + W\_{hi} h\_{t-1} + W\_{ci} C\_{t-1} + b\_i\right) \tag{3}$$

$$f\_t = \sigma\left(W\_{xf} x\_t + W\_{hf} h\_{t-1} + W\_{cf} C\_{t-1} + b\_f\right) \tag{4}$$

$$o\_t = \sigma\left(W\_{xo} x\_t + W\_{ho} h\_{t-1} + W\_{co} C\_{t-1} + b\_o\right) \tag{5}$$

$$h\_t = o\_t \cdot \tanh\left(C\_t\right) \tag{6}$$

where $\sigma$ represents the sigmoid activation function, $\tanh$ is the hyperbolic tangent activation function, and $W\_{xc}$, $W\_{hc}$, $W\_{xi}$, $W\_{hi}$, $W\_{ci}$, $W\_{xf}$, $W\_{hf}$, $W\_{cf}$, $W\_{xo}$, $W\_{ho}$, $W\_{co}$ and $b\_c$, $b\_i$, $b\_f$, $b\_o$ are the weights and bias terms of the respective gates. An LSTM layer contains many such units, which exchange information through their recurrent connections to extract time-dependent features of the data.

**Figure 1.** LSTM structure diagram.
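To make Equations (1)–(6) concrete, here is a NumPy sketch of a single forward step of this LSTM cell, including the terms that feed the previous cell state $C\_{t-1}$ into all three gates as the equations are written. The parameter names mirror the symbols above; the input size (14, matching the number of retained sensors), the hidden size, and the random initialization are illustrative assumptions.

```python
import numpy as np


def sigmoid(z: np.ndarray) -> np.ndarray:
    return 1.0 / (1.0 + np.exp(-z))


def lstm_step(x_t, h_prev, c_prev, p):
    """One forward step of Eqs. (1)-(6); p maps weight names to NumPy arrays."""
    c_tilde = np.tanh(p["W_xc"] @ x_t + p["W_hc"] @ h_prev + p["b_c"])                   # Eq. (1)
    i_t = sigmoid(p["W_xi"] @ x_t + p["W_hi"] @ h_prev + p["W_ci"] @ c_prev + p["b_i"])  # Eq. (3)
    f_t = sigmoid(p["W_xf"] @ x_t + p["W_hf"] @ h_prev + p["W_cf"] @ c_prev + p["b_f"])  # Eq. (4)
    o_t = sigmoid(p["W_xo"] @ x_t + p["W_ho"] @ h_prev + p["W_co"] @ c_prev + p["b_o"])  # Eq. (5)
    c_t = f_t * c_prev + i_t * c_tilde                                                   # Eq. (2)
    h_t = o_t * np.tanh(c_t)                                                             # Eq. (6)
    return h_t, c_t


def init_params(n_in: int, n_hid: int, seed: int = 0) -> dict:
    """Small random weights (illustrative initialization, not from the paper)."""
    rng = np.random.default_rng(seed)
    p = {}
    for g in "cifo":                      # candidate memory and the three gates
        p[f"W_x{g}"] = 0.1 * rng.standard_normal((n_hid, n_in))
        p[f"W_h{g}"] = 0.1 * rng.standard_normal((n_hid, n_hid))
        p[f"b_{g}"] = np.zeros(n_hid)
    for g in "ifo":                       # cell-state terms W_ci, W_cf, W_co
        p[f"W_c{g}"] = 0.1 * rng.standard_normal((n_hid, n_hid))
    return p


# Run the cell over a 30-step sequence of 14-dimensional inputs.
params = init_params(n_in=14, n_hid=8)
h, c = np.zeros(8), np.zeros(8)
for x_t in np.random.default_rng(1).standard_normal((30, 14)):
    h, c = lstm_step(x_t, h, c, params)
```

Because $h\_t$ is an output-gate value in (0, 1) times a tanh, every entry of the hidden state stays strictly between −1 and 1.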

A bidirectional LSTM (BiLSTM) contains two LSTM layers running in opposite directions, namely a forward propagation layer and a backward propagation layer. Both layers are connected to the same input and output layers; the forward layer processes the sequence in chronological order and the backward layer in reverse order, so that each time step has a forward and a backward hidden output. The final output at each time step is obtained by combining the corresponding outputs of the two layers. The BiLSTM structure is shown in Figure 2, and its computation is as follows:

$$h\_t = f\left(w\_1 x\_t + w\_2 h\_{t-1}\right) \tag{7}$$

$$h'\_t = f\left(w\_3 x\_t + w\_5 h'\_{t+1}\right) \tag{8}$$

$$o\_t = g\left(w\_4 h\_t + w\_6 h'\_t\right) \tag{9}$$

where $h\_t$ and $h'\_t$ are the outputs of the forward and backward propagation layers at time $t$, respectively; $w\_1$ and $w\_3$ are the weight matrices from the input layer to the forward and backward propagation layers, respectively; $w\_2$ and $w\_5$ are the recurrent weight matrices that connect the forward and backward propagation layers, respectively, to themselves across time steps; $w\_4$ and $w\_6$ are the weight matrices from the forward and backward propagation layers to the output layer, respectively; $o\_t$ is the final output at time $t$; and $g$ is the function that splices (combines) the forward and backward propagation results.

**Figure 2.** BiLSTM structure diagram.
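Equations (7)–(9) leave $f$ and $g$ generic. The NumPy sketch below assumes $f = \tanh$, $g$ = identity, and plain matrices for $w\_1, \dots, w\_6$. It uses a simple recurrent update in each direction instead of full LSTM cells, so it only illustrates how the forward layer runs in chronological order, the backward layer in reverse order, and how their states are combined at every time step.

```python
import numpy as np


def bidirectional_pass(xs: np.ndarray, w: dict) -> np.ndarray:
    """Run Eqs. (7)-(9) over xs of shape (T, n_in); return o_t for every t, shape (T, n_out)."""
    T = xs.shape[0]
    n_hid = w["w2"].shape[0]
    h_fwd = np.zeros((T, n_hid))
    h_bwd = np.zeros((T, n_hid))

    h = np.zeros(n_hid)
    for t in range(T):                      # forward layer, chronological order, Eq. (7)
        h = np.tanh(w["w1"] @ xs[t] + w["w2"] @ h)
        h_fwd[t] = h

    h = np.zeros(n_hid)
    for t in reversed(range(T)):            # backward layer, reverse order, Eq. (8)
        h = np.tanh(w["w3"] @ xs[t] + w["w5"] @ h)
        h_bwd[t] = h

    # Eq. (9) with g = identity: combine both directions at each time step.
    return h_fwd @ w["w4"].T + h_bwd @ w["w6"].T


rng = np.random.default_rng(0)
n_in, n_hid, n_out = 14, 6, 4
shapes = {"w1": (n_hid, n_in), "w2": (n_hid, n_hid), "w3": (n_hid, n_in),
          "w5": (n_hid, n_hid), "w4": (n_out, n_hid), "w6": (n_out, n_hid)}
w = {name: 0.5 * rng.standard_normal(s) for name, s in shapes.items()}
outputs = bidirectional_pass(rng.standard_normal((5, n_in)), w)   # shape (5, 4)
```

The output at the first time step already depends on the last input through the backward layer, which is what lets the BiLSTM use "future" context. In the actual model each direction would be an LSTM layer as in Equations (1)–(6).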

#### **3. Proposed Method**

#### *3.1. RUL Prediction Model Based on DCNN-BiLSTM*

The multi-dimensional sensor data obtained through time-window processing are used as the input of the DCNN-BiLSTM fusion model, whose structure is shown in Figure 3. The DCNN consists of four CNN layers with activation functions; each CNN layer extracts low-level features using convolution kernels of a different size. The resulting features are fed into two BiLSTM layers to extract time-series features, followed by two fully connected layers. At each moment, the BiLSTM network considers both historical and future information, making full use of the preceding and subsequent moments; this makes feature extraction more comprehensive, improves the prediction accuracy of the time series model, and reduces the risk of overfitting. The output of the first fully connected layer is used as the feature representation on which the MMD is computed. The second fully connected layer is the final prediction layer, and its output is the RUL value of the device.
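The data flow of the fusion model can be checked by tracing tensor shapes from the input window to the RUL output. In this pure-Python sketch every size (batch, window length, kernel sizes, channel counts, BiLSTM hidden sizes, fully connected widths) is a hypothetical placeholder. It also assumes valid convolutions with stride 1, BiLSTM outputs formed by concatenating both directions, and a first fully connected layer that reads the last BiLSTM time step; none of these details are given in the text.

```python
from typing import Dict, List, Tuple


def trace_shapes(batch: int, n_w: int, m: int,
                 conv_kernels: List[int], conv_channels: List[int],
                 lstm_hidden: List[int], fc_widths: List[int]) -> Dict[str, Tuple[int, ...]]:
    """Return the tensor shape after each stage of a DCNN-BiLSTM-FC stack."""
    shapes = {"input": (batch, n_w, m)}
    steps, feats = n_w, m
    for i, (k, c) in enumerate(zip(conv_kernels, conv_channels), start=1):
        steps, feats = steps - k + 1, c            # valid conv along time with c filters
        shapes[f"conv{i}"] = (batch, steps, feats)
    for i, h in enumerate(lstm_hidden, start=1):
        feats = 2 * h                              # forward and backward outputs concatenated
        shapes[f"bilstm{i}"] = (batch, steps, feats)
    shapes["fc1 (MMD features)"] = (batch, fc_widths[0])   # assumed: reads the last time step
    shapes["fc2 (RUL)"] = (batch, fc_widths[1])
    return shapes


# Hypothetical configuration: 30-step windows of the 14 retained sensors.
trace = trace_shapes(batch=64, n_w=30, m=14,
                     conv_kernels=[9, 7, 5, 3], conv_channels=[16, 16, 32, 32],
                     lstm_hidden=[64, 32], fc_widths=[32, 1])
for stage, shape in trace.items():
    print(stage, shape)
```

With these placeholder sizes the time axis shrinks from 30 to 10 steps before the BiLSTM, and the first fully connected layer produces the 32-dimensional features on which the MMD would be computed.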

#### *3.2. Domain Adaptation Method Based on MMD*

Domain adaptation is a branch of transfer learning that deals with the distribution difference between the source and target domains. Most domain adaptation methods learn domain-invariant features of the two domains, so that knowledge learned from the source domain can be applied to a target domain that has few or no labels.

MMD is one of the most widely used loss functions in transfer learning, especially in domain adaptation, and is mainly used to measure the distance between two different but related distributions. It is a nonparametric distance that can be estimated directly from samples, without first estimating intermediate quantities such as the probability densities. MMD maps the source and target domains into a Reproducing Kernel Hilbert Space (RKHS) and then calculates the distance between the two distributions there. MMD is defined as:

$$\mathrm{MMD}(X, Y) = \left\| \frac{1}{n\_s} \sum\_{i=1}^{n\_s} \phi(x\_i) - \frac{1}{n\_t} \sum\_{j=1}^{n\_t} \phi(y\_j) \right\|\_{\mathcal{H}}^2 \tag{10}$$

where $\mathcal{H}$ represents the RKHS, $n\_s$ and $n\_t$ represent the numbers of samples in the source domain and the target domain, respectively, and $\phi(\cdot): X \rightarrow \mathcal{H}$ is the mapping from the original feature space to the RKHS. In practice, the inner products in the RKHS are computed with a kernel function (the kernel trick), which avoids explicit high-dimensional computation; a Gaussian kernel is usually used:

$$K(\mu, \nu) = e^{\frac{-\|\mu - \nu\|^2}{\sigma}} \tag{11}$$

where *μ* and *ν* represent different samples and σ is the width parameter of the function, which controls the radial range of the function.

When the distributions of the source domain data and the target domain data differ, the MMD is added to the loss function as an additional optimization objective. The loss function of the pre-trained network model is therefore defined as:

$$\text{Loss} = \mathrm{MSE}\_{\text{loss}} + \lambda \, \mathrm{MMD}\_{\text{loss}} \tag{12}$$

where $\mathrm{MSE}\_{\text{loss}}$ is the mean squared error loss, $\mathrm{MMD}\_{\text{loss}}$ is the MMD term of Equation (10), and $\lambda > 0$ is a coefficient that balances the two terms.
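Equations (10)–(12) can be sketched in NumPy. Expanding the squared norm in Equation (10) with the kernel of Equation (11) gives the usual biased estimate: the mean kernel value among source features, plus the mean among target features, minus twice the mean between the two. The sketch uses a single Gaussian bandwidth σ, applies the MSE to whatever predictions and labels are passed in (the text does not say which samples carry labels during pre-training), and defaults to λ = 0.001, the value selected in Section 4.5. The function names are hypothetical.

```python
import numpy as np


def gaussian_kernel(a: np.ndarray, b: np.ndarray, sigma: float = 1.0) -> np.ndarray:
    """K(mu, nu) = exp(-||mu - nu||^2 / sigma) for every pair of rows, Eq. (11)."""
    sq_dist = (
        np.sum(a ** 2, axis=1)[:, None]
        + np.sum(b ** 2, axis=1)[None, :]
        - 2.0 * a @ b.T
    )
    return np.exp(-np.maximum(sq_dist, 0.0) / sigma)   # clip tiny negative round-off


def mmd(xs: np.ndarray, xt: np.ndarray, sigma: float = 1.0) -> float:
    """Biased estimate of Eq. (10) between source rows xs and target rows xt."""
    k_ss = gaussian_kernel(xs, xs, sigma).mean()
    k_tt = gaussian_kernel(xt, xt, sigma).mean()
    k_st = gaussian_kernel(xs, xt, sigma).mean()
    return float(k_ss + k_tt - 2.0 * k_st)


def pretrain_loss(y_pred, y_true, feat_src, feat_tgt, lam=1e-3, sigma=1.0) -> float:
    """Eq. (12): MSE of the RUL predictions plus lambda times the MMD of the features."""
    mse = float(np.mean((np.asarray(y_pred) - np.asarray(y_true)) ** 2))
    return mse + lam * mmd(feat_src, feat_tgt, sigma)
```

Here `feat_src` and `feat_tgt` stand for the outputs of the first fully connected layer on a source batch and a target batch. The bandwidth σ has to suit the feature scale: if it is much smaller than typical squared distances, all cross-sample kernel values vanish and the estimate stops distinguishing the domains.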

The transfer learning in this paper is based on domain adaptation. During pre-training, the source domain and target domain datasets are trained simultaneously. The output of the first fully connected layer of the DCNN-BiLSTM network is used as the sample space for calculating the distribution distance between the two domains, and the pre-trained model after domain adaptation is finally obtained. The training process is shown in Figure 4.

**Figure 4.** The pre-training framework based on domain adaptation.

#### *3.3. Fine-Tune the Target Model*

To shorten the training time of the target model, make it more adaptable to different operating conditions and environments, and improve its generalization ability, the Adam optimizer is used in this section to fine-tune the pre-trained model. The fine-tuning flowchart is shown in Figure 5. First, the target model is initialized with the weights and parameters of the pre-trained model. Then, the parameters of the feature extraction layers, namely the four CNN layers and the two BiLSTM layers, are frozen, and only the parameters of the task-specific layers, i.e., the two fully connected layers, are updated. Furthermore, to prevent overfitting, different learning rates are set for the two fully connected layers. Finally, training yields a prediction model that is better suited to the target domain task and has strong generalization ability.

**Figure 5.** The flowchart of fine-tuning.
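As an illustration of this fine-tuning step without a deep learning framework, the NumPy sketch below treats the outputs of the frozen CNN and BiLSTM layers as fixed input features and updates only the two fully connected layers with a hand-written Adam optimizer, using a separate learning rate for each layer. The feature size, layer widths, ReLU between the layers, learning rates, epoch count, and synthetic data are all assumptions for illustration. In a framework such as PyTorch, the same effect is usually obtained by disabling gradients for the frozen layers and giving each layer its own optimizer parameter group.

```python
import numpy as np


class Adam:
    """Minimal Adam optimizer for one parameter array (standard default betas)."""

    def __init__(self, lr: float, b1: float = 0.9, b2: float = 0.999, eps: float = 1e-8):
        self.lr, self.b1, self.b2, self.eps = lr, b1, b2, eps
        self.m = None
        self.v = None
        self.t = 0

    def step(self, param: np.ndarray, grad: np.ndarray) -> None:
        if self.m is None:
            self.m, self.v = np.zeros_like(param), np.zeros_like(param)
        self.t += 1
        self.m = self.b1 * self.m + (1 - self.b1) * grad
        self.v = self.b2 * self.v + (1 - self.b2) * grad ** 2
        m_hat = self.m / (1 - self.b1 ** self.t)
        v_hat = self.v / (1 - self.b2 ** self.t)
        param -= self.lr * m_hat / (np.sqrt(v_hat) + self.eps)   # in-place update


def fine_tune(feats, y, w1, b1, w2, b2, lr1=1e-3, lr2=1e-4, epochs=200):
    """Update only the FC head (w1, b1, w2, b2); `feats` come from frozen layers."""
    opts = [Adam(lr1), Adam(lr1), Adam(lr2), Adam(lr2)]   # one learning rate per FC layer
    losses = []
    for _ in range(epochs):
        z = feats @ w1 + b1                 # FC1 pre-activation
        a = np.maximum(z, 0.0)              # ReLU
        pred = a @ w2 + b2                  # FC2: RUL prediction, shape (n, 1)
        err = pred - y
        losses.append(float(np.mean(err ** 2)))
        g_pred = 2.0 * err / len(y)         # gradient of the MSE w.r.t. pred
        g_w2, g_b2 = a.T @ g_pred, g_pred.sum(axis=0)
        g_z = (g_pred @ w2.T) * (z > 0)     # back through the ReLU
        g_w1, g_b1 = feats.T @ g_z, g_z.sum(axis=0)
        for opt, p, g in zip(opts, (w1, b1, w2, b2), (g_w1, g_b1, g_w2, g_b2)):
            opt.step(p, g)
    return losses


rng = np.random.default_rng(0)
feats = rng.standard_normal((128, 16))        # stand-in for frozen DCNN-BiLSTM outputs
y = 2.0 * feats[:, :1] + 1.0                  # synthetic target, for illustration only
w1, b1 = 0.1 * rng.standard_normal((16, 8)), np.zeros(8)
w2, b2 = 0.1 * rng.standard_normal((8, 1)), np.zeros(1)
losses = fine_tune(feats, y, w1, b1, w2, b2, lr1=1e-2, lr2=1e-3, epochs=300)
```

Because the frozen layers are never differentiated, each epoch only costs a forward and backward pass through the small head, which is one reason fine-tuning shortens training.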

#### **4. Experimental Results and Analysis**

#### *4.1. Dataset Description*

The method in this paper was evaluated using the turbofan engine degradation data of the Commercial Modular Aero-Propulsion System Simulation (C-MAPSS) dataset provided by NASA. Details of the dataset are presented in Table 1. The dataset consists of four sub-datasets with different operating conditions and fault modes. Each sub-dataset contains time-series data collected by 21 sensors and 3 operational-condition measurements. The training and test sets contain different numbers of degrading engines, each starting with a different degree of initial wear; as the number of cycles increases, each engine gradually degrades until it fails. The training set records the degradation process over the entire life cycle of each engine, whereas each test trajectory ends at some point before failure. The task is to predict the remaining useful life (RUL) of the engine units in the test set.


**Table 1.** Information of the C-MAPSS dataset.

The values of seven sensors were observed to remain unchanged in the FD001 subset. To save computing resources, these uninformative sensors were removed, leaving 14 sensors: 2, 3, 4, 7, 8, 9, 11, 12, 13, 14, 15, 17, 20, and 21.
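Identifying such sensors is a simple column check. The sketch below assumes rows are cycles and columns are the sensor channels in order (reported as 1-based sensor numbers), and flags a sensor as uninformative when its range does not exceed a small tolerance. Near-constant channels may call for a larger tolerance or a variance threshold, a judgment call the text does not specify.

```python
from typing import List, Sequence


def informative_sensors(rows: Sequence[Sequence[float]], tol: float = 1e-8) -> List[int]:
    """Return 1-based indices of sensor columns whose range exceeds tol."""
    n_sensors = len(rows[0])
    keep = []
    for j in range(n_sensors):
        column = [row[j] for row in rows]
        if max(column) - min(column) > tol:   # constant columns carry no degradation signal
            keep.append(j + 1)
    return keep
```

Applied to the 21 sensor columns of FD001, the text reports that this kind of filtering leaves 14 sensors.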

#### *4.2. Data Processing*

The raw data are measurements from multiple sensors. The sequences in different datasets have different lengths, and the data are high-dimensional, with different units and scales across sensors. Therefore, min-max normalization is used to map the data into the range [−1, 1]. Each measurement $x\_{i,j}$ is normalized as follows [17]:

$$
\tilde{x}\_{i,j} = \frac{2\left(x\_{i,j} - x\_{\min}^{j}\right)}{x\_{\max}^{j} - x\_{\min}^{j}} - 1 \tag{13}
$$

where $\tilde{x}\_{i,j}$ represents the normalized data, and $x\_{\min}^{j}$ and $x\_{\max}^{j}$ represent the minimum and maximum values of the data monitored by the $j$th sensor in one operating cycle, respectively.
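Equation (13) rescales each sensor column independently. In this pure-Python sketch the minimum and maximum are taken over whatever rows are passed in (which rows to use is left to the caller), and a column with zero range is mapped to 0 to avoid division by zero; that last choice is an assumption and not part of Equation (13).

```python
from typing import List, Sequence


def minmax_scale(rows: Sequence[Sequence[float]]) -> List[List[float]]:
    """Scale every column to [-1, 1] following Eq. (13)."""
    cols = list(zip(*rows))
    lo = [min(c) for c in cols]
    hi = [max(c) for c in cols]
    scaled = []
    for row in rows:
        scaled.append([
            2.0 * (x - l) / (h - l) - 1.0 if h > l else 0.0   # zero-range column -> 0 (choice)
            for x, l, h in zip(row, lo, hi)
        ])
    return scaled
```

For example, a column holding 0, 5, and 10 becomes −1, 0, and 1, so sensors with very different units end up on the same scale.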

To obtain more useful temporal information from the input data, the normalized data are segmented with a time window. For continuous time-series data, a sliding time window is used to generate samples and define their labels, and the length of each input sequence is determined by the window size.

A window of size $N\_w$ slides along the time series with a step of $l$, and the data covered by the window at each position form one input sample of the prediction model, so the input size of the network is $N\_w \times m$. To obtain more samples and reduce the risk of overfitting, the sliding step $l$ is set to 1.
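The windowing step can be sketched in pure Python as follows. It assumes one engine's normalized sequence is given as $T$ rows of $m$ sensor values and labels each window with the RUL of its last cycle; this labeling rule is a common convention and an assumption here. With a step of 1, an engine with $T$ cycles yields $T - N\_w + 1$ windows, and an engine shorter than the window yields none (the text does not say how such engines are handled).

```python
from typing import List, Sequence, Tuple


def sliding_windows(seq: Sequence[Sequence[float]], labels: Sequence[float],
                    n_w: int, stride: int = 1) -> List[Tuple[List[List[float]], float]]:
    """Cut one engine's (T, m) sequence into (n_w, m) samples.

    Each sample is paired with the label of its last cycle (an assumption).
    """
    samples = []
    for start in range(0, len(seq) - n_w + 1, stride):
        window = [list(row) for row in seq[start:start + n_w]]
        samples.append((window, labels[start + n_w - 1]))
    return samples
```

A stride of 1 makes consecutive windows overlap in all but one cycle, which is how the step setting produces as many training samples as possible from each engine.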

The RUL is defined as the number of remaining operating cycles. Because degradation is barely noticeable while the engine runs under normal conditions, a piecewise linear RUL function is used [18]: the RUL is capped at an initial value of 125 cycles and decreases linearly after that point. This target function is applied to both the training and test sets.
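In code, the piecewise linear target is simply the true number of remaining cycles capped at 125. This sketch assumes 0-based cycle indices and that the last recorded cycle of a run-to-failure training trajectory has RUL 0.

```python
from typing import List


def piecewise_rul(n_cycles: int, max_rul: int = 125) -> List[int]:
    """RUL label for each cycle of a run-to-failure trajectory, capped at max_rul."""
    return [min(n_cycles - 1 - t, max_rul) for t in range(n_cycles)]
```

For a 200-cycle engine, the label stays at 125 for the first 75 cycles and then falls by one per cycle to 0 at failure.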

#### *4.3. Selection of Evaluation Indicators*

To verify the effectiveness of the proposed method, two evaluation metrics were used: the Root Mean Square Error (RMSE) and the Score function [19]. The RMSE is defined as:

$$RMSE = \sqrt{\frac{1}{N} \sum\_{i=1}^{N} \left(\hat{y}\_i - y\_i\right)^2} \tag{14}$$

The formula for the Score function is:

$$Score = \sum\_{i=1}^{N} s\_i, \qquad s\_i = \begin{cases} e^{-\frac{\hat{y}\_i - y\_i}{15}} - 1, & \hat{y}\_i - y\_i < 0 \\ e^{\frac{\hat{y}\_i - y\_i}{10}} - 1, & \hat{y}\_i - y\_i \ge 0 \end{cases} \tag{15}$$

where $\hat{y}\_i$ and $y\_i$ represent the predicted and actual RUL values of the $i$th sample, respectively, and $N$ is the number of samples.

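RMSE reflects how well the predicted life fits the actual life, while the Score measures whether the life predictions are reasonable: because the denominator for positive errors (10) is smaller than that for negative errors (15), late predictions ($\hat{y}\_i > y\_i$), which overestimate the remaining life, are penalized more heavily than early ones. Lower RMSE and Score values indicate better predictive ability.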
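Both metrics follow directly from Equations (14) and (15). The pure-Python sketch below keeps the denominators 15 and 10 from Equation (15) as default parameters, since other works sometimes use different constants (the PHM08 challenge scoring function, for instance, uses 13 for early predictions).

```python
import math
from typing import Sequence


def rmse(pred: Sequence[float], true: Sequence[float]) -> float:
    """Eq. (14): root mean squared error."""
    return math.sqrt(sum((p - t) ** 2 for p, t in zip(pred, true)) / len(true))


def score(pred: Sequence[float], true: Sequence[float],
          early: float = 15.0, late: float = 10.0) -> float:
    """Eq. (15): asymmetric exponential penalty, heavier for late predictions."""
    total = 0.0
    for p, t in zip(pred, true):
        d = p - t                          # d < 0: early prediction (RUL underestimated)
        total += math.exp(-d / early) - 1 if d < 0 else math.exp(d / late) - 1
    return total
```

For a single engine with true RUL 10, predicting 20 (ten cycles late) costs e − 1 ≈ 1.72, whereas predicting 0 (ten cycles early) costs about 0.95, which reflects the asymmetry described above.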

#### *4.4. Experimental Configuration and Parameters*

All experiments were performed on a machine with 16 GB of memory (RAM), an NVIDIA GeForce TITAN Xp graphics card, and an Intel(R) Xeon(R) CPU E5-2620 v4 @ 2.10 GHz. The proposed network model was implemented in Python 3.6 with the PyTorch deep learning framework. Considering the influence of sample size on prediction accuracy and the influence of different operating conditions and fault modes, and to improve the generalization ability of the RUL prediction model, we selected the datasets according to their size: FD002 and FD004 in C-MAPSS, which have sufficient samples, were used as source domain datasets, and FD001 and FD003, which have insufficient samples, were used as target domains. We evaluated the performance of transfer learning for RUL prediction in the target domain and investigated how the working conditions and the numbers of samples in the source and target domain datasets affect the final prediction model. The resulting experimental tasks are listed in Table 2.


**Table 2.** Transfer learning experiment tasks.

#### *4.5. Model Prediction Results and Analysis*

To evaluate the effectiveness of the proposed transfer learning method, its results were compared with experiments without transfer, as shown in Table 3. Source-Only refers to testing on the target domain directly with the pre-trained model, and Target-Only refers to training and testing only on the target domain.


**Table 3.** Comparison of transfer learning with no transfer.

As Table 3 shows, the proposed transfer learning algorithm greatly improves the accuracy of the prediction model. Because of the difference between the data distributions, directly testing with the Source-Only pre-trained model performed very poorly. Compared with the traditional Target-Only approach, loading the pre-trained model obtained after domain adaptation and then optimizing it on the target domain improved the prediction accuracy. Taking FD002→FD003 as an example, the RMSE improved by at least 6.82%, and the Score improved by at least 13.48%.

In the pre-training stage, the MMD term not only affects the prediction accuracy on the dataset but also affects how well the conditional distributions are matched. Therefore, the coefficient $\lambda$ of MMD\_loss in the loss function has a large impact on the adaptation effect. Taking FD002→FD001 as an example, a large number of comparative experiments were carried out with $\lambda = 10^{-1}$ as the central value, increasing or decreasing it by orders of magnitude. In Figure 6, the horizontal axis is the value of $\lambda$, and the vertical axis shows the RMSE and Score values. Both RMSE and Score reach their minimum at $\lambda = 0.001$, so the coefficient $\lambda$ of MMD\_loss was set to 0.001 in this experiment.

**Figure 6.** The impact of different *λ* values on the prediction results.
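The order-of-magnitude search for λ can be organized as a small grid loop. In this sketch, `toy_evaluate` is a synthetic stand-in for the full pre-train, fine-tune, and test cycle, and its numbers are not results from the paper. Centering the grid on $10^{-1}$ reads the text's central value of −1 as an exponent, and picking the lowest RMSE while checking whether the same λ also gives the lowest Score is an assumption that mirrors the selection described above.

```python
import math
from typing import Callable, Dict, List, Tuple


def lambda_grid(center_exp: int = -1, span: int = 3) -> List[float]:
    """Candidate values 10**k for k from center_exp - span to center_exp + span."""
    return [10.0 ** k for k in range(center_exp - span, center_exp + span + 1)]


def select_lambda(evaluate: Callable[[float], Tuple[float, float]],
                  grid: List[float]) -> Tuple[float, bool, Dict[float, Tuple[float, float]]]:
    """Return the lambda with the lowest RMSE, whether it also has the lowest Score,
    and all (RMSE, Score) results."""
    results = {lam: evaluate(lam) for lam in grid}
    best_rmse = min(grid, key=lambda lam: results[lam][0])
    best_score = min(grid, key=lambda lam: results[lam][1])
    return best_rmse, best_rmse == best_score, results


def toy_evaluate(lam: float) -> Tuple[float, float]:
    """Synthetic (RMSE, Score) pair that is smallest at 1e-3; illustration only."""
    gap = abs(math.log10(lam) + 3)          # distance from 1e-3 in orders of magnitude
    return 12.0 + gap, 250.0 + 40.0 * gap


grid = lambda_grid()                         # 1e-4, 1e-3, ..., 1e2
best, agrees, results = select_lambda(toy_evaluate, grid)
```

In a real run, each call to the evaluation function would be a full training cycle, so a coarse logarithmic grid keeps the number of runs small before any finer search.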

To examine the prediction performance of the transfer-learning-based DCNN-BiLSTM model on small-sample datasets, Figure 7 shows the predictions for all engine units in the test sets of the four experimental tasks, sorted by RUL in ascending order. The horizontal axis represents the test engine unit, and the vertical axis represents the RUL. Figure 7 shows that the DCNN can effectively extract detailed and shared features of engine degradation; even at the beginning of operation, where prediction is difficult, the predicted values stay close to the set value of 125. As the running time increases, the BiLSTM effectively captures the relationship between earlier and later parts of the time series. With the fusion model and domain adaptation combined, the predicted trend is stable and fits the real degradation curve well. Therefore, the proposed transfer learning model shows good prediction performance.

**Figure 7.** RUL prediction results of four tasks of experiments. (**a**) FD001 Engine Prediction Results (FD002→FD001). (**b**) FD001 Engine Prediction Results (FD004→FD001). (**c**) FD003 Engine Prediction Results (FD002→FD003). (**d**) FD003 Engine Prediction Results (FD004→FD003).

Taking FD002→FD001 as an example, the error and relative error for all engines in FD001 are used to illustrate the accuracy of the RUL predictions obtained with the proposed method; the results are shown in Figure 8. Figure 8a shows that when an engine starts to run, its RUL is relatively large and so is the prediction error. As the engine runs longer and approaches failure, the degradation information becomes more evident and the prediction performance improves significantly. With limited samples, it is difficult to accurately predict the life of equipment under one set of operating conditions using sensor data recorded under another. The proposed method mitigates this problem to a certain extent, keeping the relative error generally within [−25%, 25%], as shown in Figure 8b.
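The quantities plotted in Figure 8 can be computed from per-engine predictions. A minimal sketch, assuming the error is predicted minus true RUL and the relative error is that error divided by the true RUL (the paper's exact normalisation is not stated here); engines with a true RUL of zero are excluded because the relative error is undefined for them.

```python
import numpy as np

def error_profile(rul_true, rul_pred, band=0.25):
    """Signed error, relative error and share of engines within +/- band."""
    rul_true = np.asarray(rul_true, dtype=float)
    rul_pred = np.asarray(rul_pred, dtype=float)
    valid = rul_true > 0                    # relative error undefined at RUL = 0
    err = rul_pred - rul_true               # signed error (assumed convention)
    rel = err[valid] / rul_true[valid]      # relative error (assumed convention)
    within = float(np.mean(np.abs(rel) <= band))
    return err, rel, within

err, rel, within = error_profile([100, 50, 20], [90, 55, 30])
print(err, rel, within)
```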

**Figure 8.** Error curve of RUL prediction results in task FD002→FD001. (**a**) Absolute error. (**b**) Relative error.

To verify the effectiveness of DCNN-BiLSTM, the hybrid network was compared with five other network models; the RUL prediction results are shown in Table 4. The DCNN-BiLSTM model performed significantly better than SVM, MLP, CNN, LSTM, and CNN-LSTM on datasets FD001 and FD003. DCNN-BiLSTM combines a multi-layer convolutional network with a bidirectional long short-term memory network, which extracts spatial and temporal features in detail, strengthens feature extraction, and effectively improves prediction accuracy.

**Table 4.** The results of the hybrid network model in this paper are compared with other network models on the C-MAPSS dataset.


To further verify the effectiveness and superiority of the proposed method, it is compared with recent advanced methods, namely CORAL, WDGRL, DDC, ADDA, and RULDDA; the comparison results are shown in Table 5.


**Table 5.** The results of the methods in this paper are compared with other methods on the C-MAPSS dataset.

Table 5 shows that the proposed DCNN-BiLSTM (TL) method substantially improves RMSE and Score on all tasks. Compared with the best result among the state-of-the-art methods, RMSE on the four tasks was reduced by 5.77%, 63.26%, 49.49%, and 18.65%, and Score by 41.89%, 97.79%, 96.76%, and 73.47%, respectively. In addition, knowledge transfer between simple and complex datasets is challenging because of the large domain shift. For example, FD002→FD003 and FD004→FD001 are transfer tasks between complex and simple datasets; on these tasks, our method obtained the greatest improvement and successfully aligned the two distant domains. The results show that the proposed transfer learning method reduces the impact of operating conditions and fault modes on the RUL prediction accuracy in the target domain and effectively transfers knowledge from the source domain, which has a large sample size. This acts as an effective form of data augmentation for the target domain, which has a small sample size, and improves the performance of the RUL prediction model. This is significant for equipment RUL prediction with small sample sizes in complex environments, and the proposed method is therefore promising for addressing the small-sample problem in RUL prediction.

#### **5. Conclusions**

Traditional data-driven RUL prediction methods require the condition-monitoring data of the training and test sets to follow the same or similar distributions. In real working environments, however, different operating conditions, fault modes, and force majeure factors make it generally difficult to obtain datasets that share the same distribution. To address the difficulty of collecting equipment operating data in some specific environments, together with the low RUL prediction accuracy and weak generalization of existing models under different working conditions, this paper proposes a transfer-learning-based RUL prediction method for small-sample equipment.

The proposed method uses the DCNN-BiLSTM model to train on the source- and target-domain data simultaneously and uses MMD to constrain the distribution difference between the two domains, thereby matching and aligning the features of the target-domain samples with those of the source-domain samples. The extracted deep features yield a pre-trained model. The network structure and parameters of the pre-trained model are then transferred to the target domain and trained with a fine-tuning transfer learning strategy to optimize the network, finally producing an RUL prediction model better suited to the target-domain data. Comparisons with other state-of-the-art methods on the C-MAPSS dataset verify the effectiveness of the proposed method for predicting the RUL of aero-engines. Using the subsets FD002 and FD004, which have complex operating conditions and sufficient sample data, the transfer learning method significantly improves prediction on the subsets FD001 and FD003, which have a single operating condition and small data samples.

In future research, more experiments will be conducted on different degradation datasets to demonstrate the reliability and generality of the proposed model. Domain adaptation methods will also be applied to make unsupervised predictions on incomplete target-domain data with missing labels. Although the experiments in this paper achieved good results, the network structure and parameters still need further optimization to improve the performance of the RUL model.

**Author Contributions:** Conceptualization, W.C. (Weizhao Chen) and W.C. (Wenbai Chen); methodology, W.C. (Weizhao Chen); validation, H.L. and Y.W.; formal analysis, H.L. and Y.W.; investigation, W.C. (Weizhao Chen) and W.C. (Wenbai Chen); resources, C.B. and Y.G.; writing—original draft preparation, W.C. (Weizhao Chen); writing—review and editing, W.C. (Weizhao Chen); supervision, W.C. (Wenbai Chen); project administration, W.C. (Wenbai Chen); funding acquisition, W.C. (Wenbai Chen), C.B. and Y.G. All authors have read and agreed to the published version of the manuscript.

**Funding:** This research was funded by The Major Project of Scientific and Technological Innovation 2030 (2021ZD0113603), The Qin Xin Talents Cultivation Program, Beijing Information Science and Technology University (QXTCP A202102), and The General Project of Beijing Municipal Education Commission Scientific Research Program (KM202011232023).

**Institutional Review Board Statement:** Not applicable.

**Informed Consent Statement:** Not applicable.

**Data Availability Statement:** The data used in this paper came from the NASA Prognostics Center of Excellence and were obtained from https://ti.arc.nasa.gov/tech/dash/groups/pcoe/prognostic-data-repository/#turbofan, accessed on 10 February 2022.

**Conflicts of Interest:** The authors declare no conflict of interest.



### *Article* **Blind Image Deblurring via a Novel Sparse Channel Prior**

**Dayi Yang <sup>1,2,\*</sup>, Xiaojun Wu <sup>1,2,\*</sup> and Hefeng Yin <sup>1,2</sup>**


**Abstract:** Blind image deblurring (BID) is a long-standing and challenging problem in low-level image processing. To achieve visually pleasing results, it is essential to select good image priors. In this work, we use the ratio of the dark channel prior (DCP) to the bright channel prior (BCP) as an image prior for solving the BID problem. Specifically, the two channel priors obtained from RGB images are first used to construct a novel sparse channel prior, which is then incorporated into the BID task. The proposed sparse channel prior enhances the sparsity of the DCP and also captures the inverse relationship between the DCP and BCP. We employ the auxiliary variable technique to integrate the proposed sparse prior into the iterative restoration procedure. Extensive experiments on real and synthetic blurry datasets show that the proposed algorithm is efficient and competitive with state-of-the-art methods and that the proposed sparse channel prior is effective for blind deblurring.

**Keywords:** blind image deblurring; image prior; sparse channel; sparsity

**MSC:** 68U10

### **1. Introduction**

The goal of blind image deblurring is to restore a sharp image and a blur kernel from a degraded input image. Degradations include motion blur, noise, out-of-focus blur and camera shake. Assuming that the blur is uniform and spatially invariant, the blurring process can be modeled as

$$b = l \ast k + n \tag{1}$$

where *b* is the blurry input, *l* is the latent sharp image, *k* is the blur kernel and *n* is additive noise; ∗ denotes the convolution operator. This problem is highly ill-posed because both the latent sharp image *l* and the blur kernel *k* are unknown. To make it tractable, most existing methods exploit the statistics of natural images to estimate the blur kernel. For example, a heavy-tailed distribution [1], a patch recurrence prior [2], the nuclear norm [3,4], a low-rank prior [5], a sparse prior [6], a multiscale latent prior [7] or additional information about a specific image [8–10] have been used to obtain a better kernel estimate.
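Equation (1) can be simulated directly. The sketch below assumes periodic (circular) boundaries so that the convolution can be computed with the FFT, a normalised non-negative kernel, and i.i.d. Gaussian noise; these are illustrative choices, not the paper's data-generation protocol, and `psf2otf`/`blur` are hypothetical helper names.

```python
import numpy as np

def psf2otf(k, shape):
    """Zero-pad kernel k to `shape`, centre it at the origin, and take its FFT."""
    pad = np.zeros(shape)
    kh, kw = k.shape
    pad[:kh, :kw] = k
    pad = np.roll(pad, (-(kh // 2), -(kw // 2)), axis=(0, 1))
    return np.fft.fft2(pad)

def blur(l, k, noise_sigma=0.0, seed=0):
    """b = l * k + n with circular convolution and Gaussian noise n."""
    b = np.real(np.fft.ifft2(np.fft.fft2(l) * psf2otf(k, l.shape)))
    rng = np.random.default_rng(seed)
    return b + noise_sigma * rng.standard_normal(l.shape)

l = np.zeros((32, 32))
l[:, 16:] = 1.0                      # sharp vertical step edge
k = np.ones((1, 5)) / 5.0            # horizontal box blur, sums to 1
b = blur(l, k)
print(b[0, 12:20].round(2))          # the step becomes a 5-pixel ramp
```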

Strong sparsity of image intensities and gradients has been widely exploited in low-level computer vision and is well established in image deblurring [6,11–13], for example through the *L*<sup>1</sup>/*L*<sup>2</sup> norm [14], the reweighted *L*<sup>1</sup> norm [15], the *L*<sup>0</sup> norm prior [16–19] and the local maximum gradient (LMG) sparse prior [20]. To favor clear images over blurry ones, edge selection methods [21–23] have been embedded in the blind deconvolution framework. However, strong edges are not available in many cases. The channel prior was introduced by He et al. [24] for image dehazing. Pan et al. [18] then enforced the sparsity of the dark channel with the *L*<sup>0</sup> norm for kernel estimation. Unfortunately, this prior does not work well on images with large noise and large numbers of pixels. To solve this problem, Yan et al. [19] proposed an extreme channel prior (ECP), which uses both the dark channel and the bright channel to estimate the blur kernel.

**Citation:** Yang, D.; Wu, X.; Yin, H. Blind Image Deblurring via a Novel Sparse Channel Prior. *Mathematics* **2022**, *10*, 1238. https://doi.org/10.3390/math10081238

Academic Editors: Jianping Gou, Weihua Ou, Shaoning Zeng and Lan Du

Received: 25 February 2022 Accepted: 6 April 2022 Published: 9 April 2022

**Publisher's Note:** MDPI stays neutral with regard to jurisdictional claims in published maps and institutional affiliations.

**Copyright:** © 2022 by the authors. Licensee MDPI, Basel, Switzerland. This article is an open access article distributed under the terms and conditions of the Creative Commons Attribution (CC BY) license (https://creativecommons.org/licenses/by/4.0/).

In this paper, a novel sparse channel prior is proposed for blind image deblurring. Inspired by [18,19,24], we combine the advantages of the DCP and BCP to construct a ratio constraint, D/B, that sets the two channels against each other. We establish its properties mathematically and show how they can be used to estimate the blur kernel. Optimizing the proposed prior is challenging, so we use auxiliary variables and alternating minimization to decompose the problem into independent subproblems, which are solved by the alternating direction minimization (ADM) method. The main contributions of this work can be stated as follows:


The rest of this paper is organized as follows. Section 2 introduces the related work. The proposed D/B is detailed in Section 3. Our blind deblurring model and optimization strategy are presented in Section 4. Section 5 shows the experimental results. Further discussion of our proposed deblurring algorithm is given in Section 6. Section 7 summarizes this paper.

#### **2. Related Work**

Blind image deblurring algorithms have made great progress owing to increasingly effective kernel estimation models. This section reviews the methods most closely related to our work.

The success of many blind image deblurring algorithms rests on the statistical characteristics of image intensities and gradients. Krishnan et al. [14] presented the *L*<sup>1</sup>/*L*<sup>2</sup> norm based on the sparsity of image intensity. The *L*<sup>1</sup>/*L*<sup>2</sup> norm is a normalized version of *L*<sup>1</sup>, which enhances the sparsity of *L*<sup>1</sup>. Levin et al. [1] observed the heavy-tailed distribution of image intensities and introduced a maximum a posteriori (MAP) framework. Shan et al. [25] introduced a probability model to fit the sparse gradient distribution of natural images. Pan et al. [16] developed a method in which both intensity and gradient are regularized by the *L*<sup>0</sup> norm for text image deblurring. These methods are limited in their ability to model more complex image structures and contexts.
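The appeal of a normalised measure such as *L*<sup>1</sup>/*L*<sup>2</sup> is that it is invariant to scaling and grows as a signal becomes less sparse, so blurred gradients score higher than sharp ones. A small illustration on made-up 1-D signals; `l1_over_l2` is a hypothetical helper, not code from [14].

```python
import numpy as np

def l1_over_l2(x):
    """Normalised sparsity measure ||x||_1 / ||x||_2 (smaller = sparser)."""
    x = np.asarray(x, dtype=float).ravel()
    return float(np.sum(np.abs(x)) / np.linalg.norm(x))

sharp = np.array([0, 0, 0, 1, 1, 1, 1], dtype=float)    # step edge
blurred = np.array([0, 0, 0.25, 0.5, 0.75, 1, 1])        # same edge, smoothed
g_sharp, g_blur = np.diff(sharp), np.diff(blurred)
print(l1_over_l2(g_sharp), l1_over_l2(g_blur))           # blur raises the ratio
print(np.isclose(l1_over_l2(3.0 * g_blur), l1_over_l2(g_blur)))  # scale invariant
```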

Another group of blind image deblurring methods [22,23] employs an explicit salient-edge detection step for kernel estimation. Specifically, Cho et al. [21] predicted sharp edges with bilateral and shock filters. Joshi et al. [26] detected image contours by locating subpixel extrema. These methods cannot capture sparse kernels and structures, which sometimes leaves the restored image blurry and noisy. To address these problems, researchers have proposed better models for estimating the blur kernel. Xu et al. [27] presented a two-phase kernel estimation algorithm that separates kernel initialization from the iterative support detection (ISD)-based kernel refinement step, giving an efficient estimation process and preserving many small structures. Zoran and Weiss [28] proposed the expected patch log likelihood (EPLL) method, which imposes a prior on the patches of the final image; however, it restores the degraded image iteratively. Papyan et al. [29] exploited a multiscale prior to further improve EPLL and narrow the gap to global modeling. Bai et al. [7] developed a multiscale latent structures (MSLS) prior; their deblurring algorithm based on it consists of two stages: sharp image estimation at the coarse scales and a refinement process at the finest scale. For patch-based methods, global modeling remains a difficult problem.

With the rapid development of deep learning, remarkable results have been achieved in blind image deblurring [30–34]. For example, convolutional neural networks (CNN) [35], Wasserstein generative adversarial networks (GAN) [36], deep hierarchical multipatch networks (DMPHN) [37], ConvLSTM [38] and scale-recurrent networks (SRN) [39] have all been designed for image deblurring. Zheng et al. [40] presented an edge heuristic multiscale GAN, which uses edge information to deblur nonuniformly blurred images in a coarse-to-fine manner. Liang et al. [41] learned novel neural network structures from RAW images and achieved superb performance. Chang et al. [42] proposed a long–short-exposure fusion network (LSFNet) that restores low-light images from pairs of long- and short-exposure images. The success of deep-learning-based methods relies mainly on the consistency between training and test data, which limits their generalization ability.

Recently, the classical dark channel prior (DCP) has proven effective for image deblurring. The DCP was introduced by He et al. [24] for image dehazing. It is based on the observation that, in most haze-free, non-sky outdoor image patches, at least one color channel has very low, close-to-zero pixel values. Pan et al. [18] further found that most elements of the dark channel of natural images are zero and enforced the sparsity of the dark channel for image deblurring. Inspired by the DCP, the bright channel prior (BCP) was proposed: in most natural image patches, at least one color channel has very high pixel values. Yan et al. [19] simply added the DCP and BCP to form an extreme channel prior (ECP) for blind image deblurring. However, the relationship between the BCP and DCP is not fully explored in the ECP.

#### **3. Proposed Sparse Channel Prior**

To show how the proposed sparse channel changes after blurring, we model the blurring process as described in [43]. For an image *I*, assuming the noise is small enough to be neglected, we have:

$$b(\mathbf{x}) = \sum\_{z \in \Psi(\mathbf{x})} l\left(\mathbf{x} + \left[\frac{m}{2}\right] - z\right) k(z) \tag{2}$$

where *x* and *m* denote the pixel coordinates and the size of the blur kernel *k*, respectively; Ψ(*x*) represents an image patch centered at *x*; ∑<sub>*z*∈Ψ(*x*)</sub> *k*(*z*) = 1 and *k*(*z*) ≥ 0; and [·] is a rounding operator.

Inspired by the dark and bright channels and by image statistics, we observe that the larger the difference between the dark channel and the bright channel of an image patch, the more salient its edges, which helps in estimating an accurate blur kernel. To describe this observation formally, the proposed sparse channel prior is defined by:

$$\begin{aligned} R(\mathbf{x}) &= \frac{\min\_{\mathbf{y} \in \Psi(\mathbf{x})} \left( \min\_{c \in (r, g, b)} I^c(\mathbf{y}) \right)}{\max\_{\mathbf{y} \in \Psi(\mathbf{x})} \left( \max\_{c \in (r, g, b)} I^c(\mathbf{y}) \right) + \epsilon} \\ &= \frac{D(\mathbf{x})}{B(\mathbf{x}) + \epsilon} \end{aligned} \tag{3}$$

where *x* and *y* denote pixel coordinates, *ϵ* is a non-negative constant and Ψ(*x*) represents an image patch centered at *x*. *I<sup>c</sup>* is the *c*-th color channel of image *I*. As described in Equation (3), *B*(*x*) = max<sub>*y*∈Ψ(*x*)</sub> max<sub>*c*∈(*r*,*g*,*b*)</sub> *I<sup>c</sup>*(*y*) represents the BCP and *D*(*x*) = min<sub>*y*∈Ψ(*x*)</sub> min<sub>*c*∈(*r*,*g*,*b*)</sub> *I<sup>c</sup>*(*y*) represents the DCP. The dark channel is obtained by two minimization operations, min<sub>*c*∈(*r*,*g*,*b*)</sub> and min<sub>*y*∈Ψ(*x*)</sub>, and the bright channel by two maximization operations, max<sub>*c*∈(*r*,*g*,*b*)</sub> and max<sub>*y*∈Ψ(*x*)</sub>. If *I* is a gray image, only the latter (spatial) operation is performed when computing the DCP and BCP. A small value of *R*(*x*) implies that there are salient edges in the image patch; conversely, a large *R*(*x*) implies that the patch contains fine structures. The reason is that across a salient edge, the pixel values on the two sides differ greatly, so the minimum value of the patch differs greatly from its maximum value. When the difference between the DCP and BCP is small, the image edge is unclear and *R*(*x*) is large. It is therefore natural to expect that, where the DCP is equal to or only slightly smaller than the BCP, small edges can be accurately removed by minimizing Equation (3).
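Equation (3) maps onto a channel-wise min/max followed by a patch-wise min/max filter. The sketch below assumes an RGB image in [0, 1] stored as an H×W×3 array, edge-replicated borders and an odd patch size; the default of 35 matches the setting in Section 5, while the value of *ϵ* and the helper names are illustrative assumptions.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

def _patch_reduce(channel, size, reduce):
    """Apply min or max over a size x size patch centred on every pixel."""
    r = size // 2
    padded = np.pad(channel, r, mode="edge")           # replicate borders
    windows = sliding_window_view(padded, (size, size))
    return reduce(windows, axis=(-2, -1))

def dark_bright_ratio(img, size=35, eps=1e-3):
    """Return D(x), B(x) and R(x) = D(x) / (B(x) + eps) as in Equation (3)."""
    if size % 2 == 0:
        raise ValueError("patch size must be odd")
    img = np.asarray(img, dtype=float)
    if img.ndim == 3:                   # colour: reduce over c in (r, g, b) first
        min_c, max_c = img.min(axis=2), img.max(axis=2)
    else:                               # grey image: only the spatial step
        min_c = max_c = img
    D = _patch_reduce(min_c, size, np.min)
    B = _patch_reduce(max_c, size, np.max)
    return D, B, D / (B + eps)

img = np.full((16, 16, 3), 0.1)
img[:, 8:, :] = 0.9                     # a salient vertical edge
D, B, R = dark_bright_ratio(img, size=5)
print(R[8, 8], R[8, 0])                 # small on the edge, close to 1 in a flat area
```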

Consider a natural image blurred by a blur kernel. Blur reduces the maximum pixel value and increases the minimum pixel value of a patch; in other words, the DCP of the patch increases and the BCP decreases. Let *R*(*b*) and *R*(*l*) denote the proposed sparse channel of the blurred and clear image, respectively. When *l*(*x*) = max<sub>*y*∈Ψ(*x*)</sub> *l*(*y*) = min<sub>*y*∈Ψ(*x*)</sub> *l*(*y*), *R*(*b*)(*x*) ≥ *R*(*l*)(*x*). Applying this proposition to the definition of the proposed sparse channel, we have:

$$\begin{split} R(b)(\mathbf{x}) &= \frac{\min\_{\mathbf{y} \in \Psi(\mathbf{x})} \left( \min\_{c \in (r, g, b)} b^c(\mathbf{y}) \right)}{\max\_{\mathbf{y} \in \Psi(\mathbf{x})} \left( \max\_{c \in (r, g, b)} b^c(\mathbf{y}) \right) + \epsilon} \\ &= \frac{\min\_{\mathbf{y} \in \Psi(\mathbf{x})} b(\mathbf{y})}{\max\_{\mathbf{y} \in \Psi(\mathbf{x})} b(\mathbf{y}) + \epsilon} \\ &= \frac{\min\_{\mathbf{y} \in \Psi(\mathbf{x})} \sum\_{z \in \Phi(\mathbf{x})} l\left(\mathbf{y} + \left[\frac{m}{2}\right] - z\right) k(z)}{\max\_{\mathbf{y} \in \Psi(\mathbf{x})} \sum\_{z \in \Phi(\mathbf{x})} l\left(\mathbf{y} + \left[\frac{m}{2}\right] - z\right) k(z) + \epsilon} \\ &\geq \frac{\sum\_{z \in \Phi(\mathbf{x})} \min\_{\mathbf{y} \in \Psi(\mathbf{x})} l\left(\mathbf{y} + \left[\frac{m}{2}\right] - z\right) k(z)}{\sum\_{z \in \Phi(\mathbf{x})} \max\_{\mathbf{y} \in \Psi(\mathbf{x})} l\left(\mathbf{y} + \left[\frac{m}{2}\right] - z\right) k(z) + \epsilon} \\ &\geq \frac{\sum\_{z \in \Phi(\mathbf{x})} \min\_{\hat{\mathbf{y}} \in \Psi'(\mathbf{x})} l(\hat{\mathbf{y}}) k(z)}{\sum\_{z \in \Phi(\mathbf{x})} \max\_{\hat{\mathbf{y}} \in \Psi'(\mathbf{x})} l(\hat{\mathbf{y}}) k(z) + \epsilon} \\ &= \frac{\min\_{\hat{\mathbf{y}} \in \Psi'(\mathbf{x})} l(\hat{\mathbf{y}})}{\max\_{\hat{\mathbf{y}} \in \Psi'(\mathbf{x})} l(\hat{\mathbf{y}}) + \epsilon} \\ &= R(l)(\mathbf{x}) \end{split} \tag{4}$$

Here Φ(*x*) is the support of the blur kernel *k* (the summation domain in Equation (2)). Let *m*′ and *S*<sub>Ψ</sub> denote the sizes of Ψ′(*x*) and Ψ(*x*), respectively, where Ψ′(*x*) is the enlarged patch centered at *x* that contains every pixel *y* + [*m*/2] − *z* appearing in Equation (4). Then *m*′ = *S*<sub>Ψ</sub> + *m*. Equation (4) shows that *R*(*x*) of the image patch centered at *x* after blurring is no less than its value for the original image patch centered at *x*.
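The inequality in Equation (4) can be spot-checked numerically. For brevity, this sketch works in 1-D on a grey signal: it blurs a random non-negative signal with a random non-negative kernel that sums to one, then compares the min/max ratio of each blurred patch with the ratio over all sharp samples that feed that patch (the analogue of Ψ′(*x*)). It is a sanity check on synthetic data, not a proof; the signal length, patch size and *ϵ* are arbitrary assumptions.

```python
import numpy as np

def ratio(v, eps):
    """min(v) / (max(v) + eps): the 1-D analogue of R(x) for a grey signal."""
    return v.min() / (v.max() + eps)

def check_blur_raises_ratio(n=200, m=9, S=7, eps=1e-3, seed=1):
    rng = np.random.default_rng(seed)
    l = rng.uniform(0.0, 1.0, n)              # non-negative sharp signal
    k = rng.uniform(0.0, 1.0, m)
    k /= k.sum()                              # k(z) >= 0 and sum(k) = 1
    b = np.convolve(l, k, mode="valid")       # b[i] depends only on l[i : i + m]
    results = []
    for i in range(b.size - S + 1):
        r_b = ratio(b[i:i + S], eps)          # patch Psi(x) of the blurred signal
        r_l = ratio(l[i:i + S + m - 1], eps)  # enlarged patch: every sample feeding it
        results.append(r_b >= r_l - 1e-12)
    return all(results)

print(check_blur_raises_ratio())
```

The check holds because each blurred value is a convex combination of sharp values, so the blurred minimum cannot fall below, and the blurred maximum cannot rise above, the extremes of the enlarged sharp patch.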

Equation (4) proves that *R*(*l*)(*x*) ≤ *R*(*b*)(*x*). This means that after blurring, the difference between the DCP and the BCP is smaller than in the corresponding patch of a sharp image; in other words, minimizing *R*(*x*) favors the sharp image. We further validate our analysis on the dataset [44]. Figure 1a–c show histograms of the average number of dark channel pixels, bright channel pixels and D/B channel pixels, respectively. A large portion of the pixels in the dark channels and bright channels have very small and very large values, respectively, and our D/B channel pixels have smaller values than those of the DCP and BCP. As shown in Figure 1, the proposed sparse channels of clear images have significantly more zero elements than those of blurred images. The sparsity of the proposed channel is therefore a natural metric for distinguishing clear images from blurred ones. This observation motivates us to introduce a new regularization term that enforces sparsity of the proposed channels in latent images.

#### *Proposed Sparse Channel as an Image Prior*

Equation (4) shows that after blurring, the difference between the DCP and BCP is smaller than that of the corresponding patch in a sharp image. Therefore, in order to generate sharp and reliable salient edges, we propose a novel sparse channel prior which combines the D/B and *L*<sup>0</sup> norm:

$$P(\mathbf{x}) = \frac{||D(\mathbf{x})||\_0}{||B(\mathbf{x})||\_0 + \epsilon} \tag{5}$$

**Figure 1.** The statistics of the DCP, the BCP and our proposed D/B prior: (**a**–**c**) average channel pixels distribution of bright, dark and our D/B, respectively.

We define *P*(*x*) as the D/B prior, where the *L*<sup>0</sup> norm is used to measure sparsity. Let Ψ(*x*) denote a patch of the image *I*. If there exist pixels *y* ∈ Ψ(*x*) such that *I*(*y*) = 0, we have

$$P(b)(\mathbf{x}) \ge P(l)(\mathbf{x})\tag{6}$$

where *P*(*b*)(*x*) and *P*(*l*)(*x*) denote the D/B prior of the blurred and clear image, respectively. This property follows directly from Equation (4). Within the MAP framework, minimizing the sparse prior *P*(*x*) therefore favors a sharp image. This property is also validated on dataset [44]: as shown in Figure 1c, the D/B channels of clear images have significantly more zero elements on average than those of blurred images.
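Equation (5) replaces the channel values with *L*<sup>0</sup> counts. A minimal sketch, assuming `D` and `B` are precomputed dark- and bright-channel maps and counting an entry as non-zero when its magnitude exceeds a small tolerance, because exact zeros are rare in floating-point images; the tolerance, *ϵ* and the toy arrays are illustrative assumptions.

```python
import numpy as np

def l0(x, tol=1e-6):
    """Number of entries treated as non-zero (|x| > tol)."""
    return int(np.count_nonzero(np.abs(x) > tol))

def db_prior(D, B, eps=1e-3, tol=1e-6):
    """P = ||D||_0 / (||B||_0 + eps), as in Equation (5)."""
    return l0(D, tol) / (l0(B, tol) + eps)

D_sharp = np.array([[0.0, 0.0], [0.0, 0.3]])   # sparse dark channel (toy example)
D_blur = np.array([[0.1, 0.05], [0.2, 0.3]])   # blur fills in the zeros
B = np.ones((2, 2))
print(db_prior(D_sharp, B), db_prior(D_blur, B))
```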

#### **4. Proposed Blind Deblurring Model**

Based on the proposed D/B prior, we construct the blind deblurring model under the maximum a posteriori (MAP) framework.

$$\arg\min\_{l,k} \|l \otimes k - b\|\_{2}^{2} + \mu P(l) + \theta \|\nabla l\|\_{0} + \gamma \|k\|\_{2}^{2} \tag{7}$$

where *P*(*l*) is the proposed prior, ∇ denotes the gradient operator and *μ*, *ϑ* and *γ* are non-negative weights. The data-fitting term ensures that the latent sharp image, convolved with the kernel, is consistent with the observed image. ‖∇*l*‖<sub>0</sub> is the *L*<sup>0</sup> norm of the image gradient, which is used to suppress ringing and artifacts. Finally, the squared *L*<sup>2</sup> norm regularizes the blur kernel and stabilizes its estimation.

#### *4.1. Optimization*

In this part, we adopt the ADM method to obtain the solution to the objective function. By using the idea of alternating optimization, we can obtain two independent subproblems about *l* and *k*, respectively:

$$\arg\min\_{l} \|l \otimes k - b\|\_{2}^{2} + \mu \frac{\|D(l)\|\_{0}}{\|B(l)\|\_{0} + \epsilon} + \theta \|\nabla l\|\_{0} \tag{8}$$

and

$$\arg\min\_{k} \|l \otimes k - b\|\_{2}^{2} + \gamma \|k\|\_{2}^{2} \tag{9}$$

Equation (9) is a classical least squares problem with respect to *k*. By introducing the auxiliary variable *g*, which is related to ∇*l*, Equation (8) can be written as follows:

$$\arg\min\_{l, \mathbf{g}} \|l \otimes k - b\|\_{2}^{2} + \lambda \|\nabla l - \mathbf{g}\|\_{2}^{2} + \mu \frac{\|D(l)\|\_{0}}{\|B(l)\|\_{0} + \epsilon} + \theta \|\mathbf{g}\|\_{0} \tag{10}$$

Equation (10) can be decomposed into:

$$\arg\min\_{l} \|l \otimes k - b\|\_{2}^{2} + \lambda \|\nabla l - \mathbf{g}\|\_{2}^{2} + \mu \frac{\|D(l)\|\_{0}}{\|B(l)\|\_{0} + \epsilon} \tag{11}$$

and

$$\arg\min\_{\mathbf{g}} \lambda \|\nabla l - \mathbf{g}\|\_{2}^{2} + \theta \|\mathbf{g}\|\_{0} \tag{12}$$

Equation (12) is an *L*<sup>0</sup> norm minimization problem for *g*.

#### *4.2. Estimating Intermediate Image l*

For the *k*-th iteration, we consider *B*(*l*) estimated in the (*k* − 1)-th iteration as a constant. Denoting

$$w\_{k} = \frac{\mu}{\|B(l)\|\_{0} + \epsilon} \tag{13}$$

Equation (11) can be rewritten as follows:

$$\arg\min\_{l} \|l \otimes k - b\|\_{2}^{2} + \lambda \|\nabla l - \mathbf{g}\|\_{2}^{2} + w\_{k} \|D(l)\|\_{0} \tag{14}$$

By introducing an auxiliary variable, *p*, which is related to *D*(*l*), Equation (14) can be reformulated as follows:

$$\arg\min\_{l, p} \|l \otimes k - b\|\_{2}^{2} + \xi \|D(l) - p\|\_{2}^{2} + \lambda \|\nabla l - \mathbf{g}\|\_{2}^{2} + w\_{k} \|p\|\_{0} \tag{15}$$

Using the idea of alternating optimization, we can obtain two independent subproblems to solve for *l* and *p*, respectively:

$$\arg\min\_{l} \|l \otimes k - b\|\_{2}^{2} + \xi \|D(l) - p\|\_{2}^{2} + \lambda \|\nabla l - \mathbf{g}\|\_{2}^{2} \tag{16}$$

and

$$\arg\min\_{p} \xi \|D(l) - p\|\_{2}^{2} + w\_{k} \|p\|\_{0} \tag{17}$$

Equation (16) consists only of quadratic terms, so its solution can be obtained by least squares. In each iteration, the fast Fourier transform (FFT) is used to accelerate the computation, giving the closed-form solution:

$$l = \mathcal{F}^{-1}\left(\frac{\overline{\mathcal{F}(k)}\mathcal{F}(b) + \xi\mathcal{F}(p) + \lambda\mathcal{F}\_{\mathbf{g}}}{\overline{\mathcal{F}(k)}\mathcal{F}(k) + \lambda\overline{\mathcal{F}(\nabla)}\mathcal{F}(\nabla) + \xi}\right) \tag{18}$$

where $\mathcal{F}\_{\mathbf{g}} = \overline{\mathcal{F}(\nabla\_v)}\mathcal{F}(g\_v) + \overline{\mathcal{F}(\nabla\_h)}\mathcal{F}(g\_h)$; $\mathcal{F}(\cdot)$ and $\mathcal{F}^{-1}(\cdot)$ are the fast Fourier transform (FFT) and its inverse, respectively; $\overline{\mathcal{F}(\cdot)}$ denotes the complex conjugate of the FFT; and ∇<sub>*v*</sub> and ∇<sub>*h*</sub> are the gradient operators in the vertical and horizontal directions, respectively.
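Because every operator in Equation (18) is diagonalised by the FFT, the update reduces to a per-frequency division. The sketch assumes periodic boundaries, forward-difference filters for ∇<sub>*h*</sub> and ∇<sub>*v*</sub>, and, as the *ξ*F(*p*) and *ξ* terms of Equation (18) do, treats the *D*(*l*) − *p* term as acting on *l* directly; `psf2otf` and `solve_latent` are hypothetical names.

```python
import numpy as np

def psf2otf(k, shape):
    """Zero-pad kernel k to `shape`, circularly centre it and take the FFT."""
    pad = np.zeros(shape)
    pad[:k.shape[0], :k.shape[1]] = k
    pad = np.roll(pad, (-(k.shape[0] // 2), -(k.shape[1] // 2)), axis=(0, 1))
    return np.fft.fft2(pad)

def solve_latent(b, k, p, g_h, g_v, lam, xi):
    """Closed-form update of l following Equation (18)."""
    shape = b.shape
    Fk = psf2otf(k, shape)
    Fh = psf2otf(np.array([[1.0, -1.0]]), shape)      # horizontal difference
    Fv = psf2otf(np.array([[1.0], [-1.0]]), shape)    # vertical difference
    Fg = np.conj(Fh) * np.fft.fft2(g_h) + np.conj(Fv) * np.fft.fft2(g_v)
    num = np.conj(Fk) * np.fft.fft2(b) + xi * np.fft.fft2(p) + lam * Fg
    den = np.conj(Fk) * Fk + lam * (np.conj(Fh) * Fh + np.conj(Fv) * Fv) + xi
    return np.real(np.fft.ifft2(num / den))
```

A quick consistency check: if *b*, *p* and **g** are all generated from the same image, every residual in Equation (16) vanishes at that image, so the update should return it exactly (up to rounding).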

#### *4.3. Estimating p and g*

Equations (12) and (17) are *L*<sup>0</sup>-norm minimization problems. Because *L*<sup>0</sup>-norm minimization is difficult to solve directly, we adopt the method described in Ref. [13]. The solution of Equation (17) can then be expressed as:

$$p = \begin{cases} D(l), & |D(l)|^2 \ge \frac{w\_k}{\xi} \\ 0, & \text{otherwise} \end{cases} \tag{19}$$

Given *l*, the solution of Equation (12) can be expressed as:

$$\mathbf{g} = \begin{cases} \nabla l, & |\nabla l|^2 \ge \frac{\theta}{\lambda} \\ 0, & \text{otherwise} \end{cases} \tag{20}$$
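Both updates are element-wise hard-thresholding rules. For each entry, keeping the value costs the *L*<sup>0</sup> weight, while setting it to zero costs the quadratic penalty times the squared value, so the value is kept when the latter is at least as large. A sketch assuming that reading, with |∇*l*|² taken as the sum of squared horizontal and vertical gradients; the function names are illustrative.

```python
import numpy as np

def update_p(D, w_k, xi):
    """p-update: keep D(l) where |D(l)|^2 >= w_k / xi, else zero."""
    return np.where(D ** 2 >= w_k / xi, D, 0.0)

def update_g(grad_h, grad_v, theta, lam):
    """g-update: keep (grad_h, grad_v) where |grad l|^2 >= theta / lam, else zero."""
    keep = grad_h ** 2 + grad_v ** 2 >= theta / lam
    return np.where(keep, grad_h, 0.0), np.where(keep, grad_v, 0.0)

D = np.array([0.01, 0.2, 0.5])
print(update_p(D, w_k=0.004, xi=0.2))   # threshold 0.02 on D^2
```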

#### *4.4. Estimating Blur Kernel k*

Because the blur kernel update is an independent subproblem, we estimate *k* in the gradient space. Specifically, given the intermediate image *l*, the blur kernel is obtained by solving:

$$\arg\min\_{k} \|\nabla l \otimes k - \nabla y\|\_{2}^{2} + \gamma \|k\|\_{2}^{2} \tag{21}$$

where ∇ denotes the gradient operator and *y* is the blurry input. Estimating the kernel from Equation (21) instead of Equation (9) helps suppress ringing artifacts and eliminate noise. The closed-form solution of Equation (21) is obtained with the FFT:

$$k = \mathcal{F}^{-1}\left(\frac{\overline{\mathcal{F}(\nabla l)}\mathcal{F}(\nabla y)}{\overline{\mathcal{F}(\nabla l)}\mathcal{F}(\nabla l) + \gamma}\right) \tag{22}$$

A coarse-to-fine strategy, similar to that used in [26,45], is adopted for blur kernel estimation. During the solution process, it is very important to suppress small values of the blur kernel by thresholding at fine scales, which makes the algorithm more robust to noise.
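The kernel update of Equation (22), followed by the thresholding step just described, can be sketched as follows. Assumptions (not from the paper): forward-difference gradients with periodic boundaries; horizontal and vertical gradient terms summed in numerator and denominator; an odd kernel size; and a clean-up that clips negative values, zeroes entries below 5% of the maximum, and renormalises the kernel to sum to one. All names and the 5% threshold are hypothetical.

```python
import numpy as np

def estimate_kernel(l, y, ksize, gamma, rel_thresh=0.05):
    """Gradient-domain kernel update (Equation (22)) plus simple clean-up."""
    def grads(img):
        return (np.roll(img, -1, axis=1) - img, np.roll(img, -1, axis=0) - img)
    (lh, lv), (yh, yv) = grads(l), grads(y)
    F = np.fft.fft2
    num = np.conj(F(lh)) * F(yh) + np.conj(F(lv)) * F(yv)
    den = np.conj(F(lh)) * F(lh) + np.conj(F(lv)) * F(lv) + gamma
    k_full = np.real(np.fft.ifft2(num / den))
    # The estimate is centred at the origin with circular wrap: shift, then crop.
    k_full = np.fft.fftshift(k_full)
    cy, cx = k_full.shape[0] // 2, k_full.shape[1] // 2
    r = ksize // 2
    k = k_full[cy - r:cy + r + 1, cx - r:cx + r + 1]
    k = np.maximum(k, 0.0)                    # a blur kernel is non-negative
    k[k < rel_thresh * k.max()] = 0.0         # suppress small, noisy entries
    return k / k.sum()                        # kernel sums to one
```

The regulariser *γ* keeps the division stable at frequencies where the gradient spectrum of *l* nearly vanishes; the zero-frequency term, lost because gradients remove the mean, is restored in effect by the final normalisation.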

#### *4.5. Estimating Latent Sharp Image*

Although the latent sharp image can be estimated from Equation (18), this formulation is less effective for fine texture details. To suppress ringing and artifacts, we refine the final restored image. Given the estimated blur kernel and the blurry input image *y*, we use a non-blind deconvolution method to obtain the final latent sharp image *llatent*. Algorithm 1 summarizes the main steps of this final restoration. First, we estimate the restored image *lh* using the method in Ref. [46] with a hyper-Laplacian prior. Then we restore the image *lr* using the method in Ref. [47] with a total variation prior. Finally, the latent sharp image *llatent* is the average of the two restored images, i.e., *llatent* = (*lh* + *lr*)/2. The main steps of the proposed algorithm are summarized in Algorithm 2.

#### **Algorithm 1** Final latent sharp image restoration.

**Input:** Blurry image *b* and estimated kernel *k*.


```
llatent = (lh + lr)/2.
```
**Output:** Sharp latent image *llatent*.

#### **Algorithm 2** The proposed blind deblurring algorithm.

**Input:** Blurry image *y*;


**Output:** Sharp latent image *llatent*.

We first initialize the intermediate image *l* and the blur kernel *k* from the blurry input and then update *l* and *k* alternately. To avoid falling into a local minimum, the algorithm is executed in a coarse-to-fine manner: the results of each coarse layer are up-sampled with bilinear interpolation to initialize the next finer layer. Finally, the latent sharp image is obtained by Algorithm 1 with the estimated blur kernel.

#### **5. Results**

We evaluate our method against state-of-the-art BID methods on different image datasets, including a synthetic image dataset and real-world blurred images. Deblurring quality is measured with two metrics: the peak signal-to-noise ratio (PSNR, in dB) and the cumulative error ratio (CER). Higher values of either metric indicate a better model.
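For reference, here is a PSNR sketch assuming intensities in [0, 1] (peak value 1). It also takes the maximum PSNR over several reference frames, mirroring how results on dataset [44] are reported as the largest PSNR against the 199 ground-truth images along the camera-shake trajectory. The helper names are illustrative.

```python
import numpy as np

def psnr(restored, reference, peak=1.0):
    """Peak signal-to-noise ratio in dB."""
    diff = np.asarray(restored, dtype=float) - np.asarray(reference, dtype=float)
    mse = np.mean(diff ** 2)
    return float("inf") if mse == 0 else float(10.0 * np.log10(peak ** 2 / mse))

def max_psnr(restored, references, peak=1.0):
    """Largest PSNR of one restored image over a set of reference frames."""
    return max(psnr(restored, ref, peak) for ref in references)

ref = np.full((4, 4), 0.5)
print(psnr(ref + 0.1, ref))   # MSE of about 0.01 gives about 20 dB
```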

In all experiments, the parameters of our model are set as follows: *μ* = *ϑ* = 0.003, *γ* = 2, and the size of the image patch used to compute the D/B channel is 35. The maximum number of iterations is empirically set to 5 as a trade-off between accuracy and speed.
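As a hedged illustration of the patch-based channels, the sketch below computes the dark channel (minimum over a window) and bright channel (maximum over a window) of a grayscale image in pure Python. We assume, purely for illustration, that the D/B map is the pixel-wise ratio of the dark channel to the bright channel with a small `eps`; the paper's exact definition is given in an earlier section not reproduced here. `l0_count` counts non-zero entries, the quantity an *L*<sup>0</sup> penalty on such a map measures. All names and `eps`/`tol` are hypothetical.

```python
def local_extreme(img, patch, reduce):
    """Apply `reduce` (min or max) over a patch x patch window at each pixel.

    `img` is a 2-D list (grayscale); windows are clipped at image borders.
    For a colour image, the min/max would also run over the channels.
    """
    h, w = len(img), len(img[0])
    r = patch // 2
    out = []
    for i in range(h):
        row = []
        for j in range(w):
            window = [img[a][b]
                      for a in range(max(0, i - r), min(h, i + r + 1))
                      for b in range(max(0, j - r), min(w, j + r + 1))]
            row.append(reduce(window))
        out.append(row)
    return out


def dark_bright_ratio(img, patch=35, eps=1e-8):
    """Return (dark, bright, ratio) maps; the ratio form of D/B is an assumption."""
    dark = local_extreme(img, patch, min)    # dark channel (grayscale case)
    bright = local_extreme(img, patch, max)  # bright channel (grayscale case)
    ratio = [[d / (b + eps) for d, b in zip(row_d, row_b)]
             for row_d, row_b in zip(dark, bright)]
    return dark, bright, ratio


def l0_count(values, tol=1e-6):
    """Number of entries with magnitude above `tol` (an L0 'norm' proxy)."""
    return sum(abs(v) > tol for row in values for v in row)
```

The default `patch=35` mirrors the setting stated above; this direct loop costs O(*h*·*w*·patch<sup>2</sup>), so a practical implementation would use a sliding-window min/max filter instead.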

#### *5.1. Synthetic Image Deblurring*

We first test our method on the synthetic image dataset [44] for quantitative evaluation. This dataset contains 4 ground truth images and 12 different kernels. We compare our results with the state-of-the-art methods [11,14,18,19,21,27,48], and our algorithm performs favorably against them on this benchmark. Figure 2 presents a challenging example. For each restored result, we record the largest PSNR obtained by comparing it with the 199 ground truth images captured along the camera shake trajectory; the results are shown in Figure 3. Because the proposed method considers not only the BCP and DCP information but also the relationship between them, the PSNR values of the images restored by our method are higher than those of the state-of-the-art algorithms [11,14,18,19,25,45,48–50].
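The "largest PSNR over the 199 ground truth frames" protocol can be sketched as below. The 2-D-list image format, the `peak` value and the helper names are assumptions; `psnr` is repeated so that the snippet runs on its own.

```python
import math


def psnr(reference, estimate, peak=255.0):
    """Peak signal-to-noise ratio (dB) between two equally sized 2-D lists."""
    diffs = [(r - e) ** 2
             for row_r, row_e in zip(reference, estimate)
             for r, e in zip(row_r, row_e)]
    mse = sum(diffs) / len(diffs)
    return math.inf if mse == 0 else 10.0 * math.log10(peak ** 2 / mse)


def best_psnr(estimate, references, peak=255.0):
    """Largest PSNR of `estimate` against any frame along the shake trajectory."""
    return max(psnr(ref, estimate, peak) for ref in references)
```

Taking the maximum makes the score insensitive to which point of the camera trajectory the restored image happens to align with.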

**Figure 2.** Visual comparison of the results using one challenging image from dataset [44]. The image (**a**) is blurry input; (**b**–**h**) are deblurring results of Ref. [21], Ref. [27], Ref. [14], Ref. [48], Ref. [18], Ref. [19] and our proposed method, respectively.

We also test our algorithm against the competing methods [6,14,18,19,21,48,51,52] on another benchmark dataset [12], which includes four ground truth images and eight different kernels. Figure 4 shows one example with a visual comparison against the state-of-the-art methods [18,19]. Although the method of Pan et al. [18] performs well against other approaches, its result still contains significant fake textures and blurred regions (Figure 4b). The algorithm of Yan et al. [19] considers both the DCP and the BCP, but its result still has unclear edges, as Figure 4c shows. In contrast, our method generates a sharp image with fine textures (Figure 4d), which is visually more pleasing than the other results. The main reason is that the enhanced edges in local patches help to remove small textures and fine details from the intermediate latent images, leaving the salient edges that guide kernel estimation. Figure 5a plots the cumulative error ratios of our method and the other competing methods. Note that our D/B-based method reaches a success rate of 100% at an error ratio of 2, outperforming the state-of-the-art algorithms. All the experimental results consistently show that our method is competitive on this dataset.

**Figure 4.** A comparison of our method with state-of-the-art methods. The images (**a**–**d**) are blurry input, result of Pan et al. [18], result of Yan et al. [19] and our result, respectively. The PSNR values of (**b**–**d**) are 30.19, 30.33 and 32.15, respectively.

We further compare our method with the state-of-the-art approaches [16,19] on text images from the dataset [16]. This dataset consists of 15 images and eight different kernels ranging in size from 13 × 13 to 27 × 27. Figure 6 shows that our method performs well on a challenging blurry image in comparison with [19] and with the method designed specifically for text images [16]. As shown in the figure, the DCP and ECP also help the blind deblurring of text images. Our deblurred result in Figure 6d, which uses the proposed D/B, has sharper edges and clearer text than the other results [16,19]. Another text example is shown in Figure 7. The text becomes much sharper after deblurring, which demonstrates that our proposed *L*<sup>0</sup> norm based on the D/B is helpful for kernel estimation and image deblurring. In particular, sharp text images contain more salient edges in local patches, which allows our D/B to perform well. Table 1 presents the average PSNR values of the deblurred results on the text image dataset [16] compared with the state-of-the-art methods; our method achieves the highest average PSNR.

**Figure 5.** Quantitative results of our method on two benchmark datasets [12,22]: (**a**) error ratios comparison between our approach and the other methods on the benchmark dataset [12]; (**b**) quantitative evaluations on the benchmark dataset [22].

**Figure 6.** A comparison of our method with state-of-the-art methods. The images (**a**–**d**) are blurry input, result of Pan et al. [16], result of Yan et al. [19] and our result, respectively.

**Figure 7.** Visual comparison of the results using one challenging image: (**a**) blurry image; (**b**–**h**) deblurring results generated by Ref. [14], Ref. [6], Ref. [52], Ref. [48], Ref. [19], Ref. [18] and our method, respectively. The recovered image by the proposed algorithm is visually more pleasing.


#### **Table 1.** PSNR values of state-of-the-art text image deblurring methods.

#### *5.2. Real Image Deblurring*

In this part, we test our method on real-world blurred images against recent state-of-the-art blind single-image deblurring methods [11,14,18,19,21,48]. Because the blur kernels and ground truth images are unknown, we analyze the deblurring results qualitatively. Figure 8 shows one challenging real-world blurred image. The images recovered by the proposed algorithm are sharper and clearer than those generated by [11,14,18,19,21,48]. As shown in Figure 8, the blurry image contains both large and small edges and textures, which causes trouble for methods designed for natural images. Pan et al. [18] exploited the dark channel and achieved encouraging results; however, their deblurred image still contains visible blur artifacts. In contrast, by further exploiting the edge information in local patches, our method recovers sharper and clearer image details than the other methods, as shown in Figure 8. As a second example, Figure 9 presents deblurring results on another challenging image. Our deblurred image has a clearer background and sharper edges than the other results.

**Figure 8.** Visual comparison of the results using one challenging image. (**a**) is blurry input and (**b**–**h**) are generated by [11,14,18,19,21,48] and our proposed method, respectively.

**Figure 9.** An example of real-world image results. The images (**a**–**e**) are blurry input, result of Krishnan et al. [14], result of Pan et al. [18], result of Yan et al. [19] and our result, respectively.

#### *5.3. The Effectiveness of Proposed Sparse Channel Prior*

In this subsection, we conduct experiments to verify the effectiveness of the proposed D/B for blind image deblurring. As mentioned above, the proposed D/B regularization term considers the contrast and salient-edge information in local patches. To demonstrate the effectiveness of the proposed prior, we compare the proposed method with the DCP-based method [18] and the ECP-based method [19]. Figure 10 shows how the DCP, the BCP and the proposed sparse channel prior change at each stage of the restoration. Initially, the contrast and clarity of the DCP, the BCP and the proposed sparse channel prior of the input blurred images are very low; the contrast is significantly improved at the intermediate stage, and the final restored images have higher contrast and sharper contours, with ringing and artifacts greatly reduced. Note that at each stage, the proposed sparse channel prior has a clearer outline than the DCP and BCP. Compared with the method in [18], the proposed method estimates the blur kernels better, with fewer artifacts. Figure 11 shows the quantitative evaluations on the benchmark dataset [12] for the ECP and for our method with and without the proposed D/B. The PSNR (Figure 11a) of the proposed D/B-based method is higher than that of the ECP and of our method without D/B. Moreover, our method with the proposed D/B prior performs more favorably in terms of error ratio (Figure 11b) than without the D/B regularization, which further demonstrates the effectiveness of the proposed D/B-based method.

**Figure 10.** Visual comparison of the intermediate results generated during iteration. (**a**–**c**) are intermediate results generated during iteration using the DCP, ECP and our sparse channel prior, respectively.

**Figure 11.** Quantitative results of our method on benchmark dataset [12]: (**a**) quantitative evaluations on the benchmark dataset by [12] and our method with and without D/B; (**b**) error comparison between our approach and the other methods.

In addition, our method has a higher success rate on the dataset [22], as shown in Figure 5b. All the results consistently demonstrate that the proposed sparse channel prior improves the deblurring performance.

#### **6. Discussion**

#### *6.1. Comparison with Other Related Methods*

In this part, we discuss the methods most closely related to our algorithm. Pan et al. [18] used the dark channel prior for blind image deblurring; they enhanced the sparsity of the DCP and achieved good results on low-light images. Yan et al. [19] used the ECP to address the problem that the DCP is less effective on sky images. However, the ECP is a simple addition of the DCP and BCP, and the relationship between the two has not been studied in depth.

Figure 10 shows the intermediate images of three different methods (Refs. [18,19] and ours). Although the intermediate results of all three become clearer and sharper as iterations increase, the images generated by our method (Figure 10c) have sharper edges and clearer contents than those of Refs. [18] (Figure 10a) and [19] (Figure 10b). Figure 12 shows the results of these three methods on some challenging images, including real blurred and low-light images. Our results have fewer blurred areas and less ringing, and look more pleasant. Table 2 reports the error ratios of the two related approaches [18,19] on dataset [22]; the proposed method fails on one image, for which its error ratio exceeds 4.

**Figure 12.** Deblurring results on some challenging examples. (**a**) Blurry inputs. (**b**–**d**) Deblurring results generated by Ref. [18], Ref. [19] and our method, respectively.

To analyze the three methods in more detail, Figure 13 shows the dark channel, bright channel and D/B maps. Although the dark channel, bright channel and D/B maps of the recovered image all improve on those of the corresponding blurry image, our D/B map improves more than the dark and bright channels. Moreover, our D/B map is clearer (higher contrast and sharper edges) than the dark and bright channels for both the blurry image and the recovered image.

**Table 2.** Quality evaluation of competitive methods on dataset [22] in terms of error ratio.


**Figure 13.** Visual comparison of different maps. (**a**) is the blurry image. (**b**–**d**) are the dark channel, bright channel and our D/B map of (**a**), respectively. (**e**) is the recovered image. (**f**–**h**) are the dark channel, bright channel and our D/B map of (**e**), respectively.

#### *6.2. Convergence Analysis*

Blind deconvolution is a highly ill-posed problem, and in this paper we introduce a new sparse prior so that it yields feasible results. Because the optimization of our model is challenging and relies on auxiliary variables and the alternating direction minimization (ADM) method, its convergence may be questioned. We therefore plot the objective function value (computed from Equation (8)) and the kernel similarity [53] on dataset [12] against the number of iterations in Figure 14. Figure 14a shows that our method converges in fewer than 30 iterations, and Figure 14b shows that the kernel similarity [53] increases with more iterations.
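Kernel similarity [53] is commonly defined as the maximum normalized cross-correlation between the estimated and ground-truth kernels over all relative shifts, so a correct kernel that is merely translated still scores 1. The pure-Python sketch below assumes that definition; the function name is hypothetical.

```python
import math


def kernel_similarity(k_true, k_est):
    """Maximum normalized cross-correlation of two kernels over all shifts.

    Kernels are 2-D lists (sizes may differ); taps outside a kernel count as 0.
    """
    h1, w1 = len(k_true), len(k_true[0])
    h2, w2 = len(k_est), len(k_est[0])
    norm1 = math.sqrt(sum(v * v for row in k_true for v in row))
    norm2 = math.sqrt(sum(v * v for row in k_est for v in row))
    best = 0.0
    # Shift (dy, dx) aligns k_true[i][j] with k_est[i + dy][j + dx].
    for dy in range(-(h1 - 1), h2):
        for dx in range(-(w1 - 1), w2):
            acc = 0.0
            for i in range(h1):
                ii = i + dy
                if 0 <= ii < h2:
                    for j in range(w1):
                        jj = j + dx
                        if 0 <= jj < w2:
                            acc += k_true[i][j] * k_est[ii][jj]
            best = max(best, acc / (norm1 * norm2))
    return best
```

Averaging this score over the dataset after each iteration would produce a curve of the kind plotted in Figure 14b.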

**Figure 14.** Convergence analysis of the proposed algorithm. (**a**) Energy value computed from Equation (8). (**b**) Average kernel similarity [53] becomes higher with more iterations.

#### *6.3. Running Time*

We briefly assess the computational complexity through the running time of the algorithms. We run several competing algorithms closely related to this paper on the same dataset as our algorithm, and all experiments are carried out on the same computer. The running times on images of different sizes are summarized in Table 3. As the table shows, our algorithm is faster than [19] and slower than [14].
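For readers reproducing a comparison like Table 3, a small timing harness is sketched below. Taking the median of several runs and using `time.perf_counter` are our choices for illustration; the paper does not state its timing protocol, and the function name is hypothetical.

```python
import time


def median_runtime(fn, *args, repeats=3):
    """Median wall-clock time in seconds of fn(*args) over `repeats` runs."""
    times = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn(*args)
        times.append(time.perf_counter() - start)
    times.sort()
    return times[len(times) // 2]
```

The median is less sensitive than the mean to a single slow run caused by caching or background load.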


**Table 3.** Running time (s) of competing approaches.

#### **7. Conclusions**

In this paper, a novel, simple yet efficient image prior D/B for blind image deblurring is proposed, which builds on the DCP and BCP. An extensive investigation on natural images shows that the DCP behaves inversely to the BCP, and a large difference between the DCP and BCP preserves salient edges. For blind image deblurring, salient edges are helpful to estimate the blur kernel. In order to utilize the advantages of the DCP and BCP and further exploit the edge information in a local patch, we propose the D/B prior for image deblurring. The D/B prior preserves the main edges and eliminates the fine textures of intermediate latent images. Meanwhile, it retains the advantages of the DCP and BCP. The feasibility and effectiveness of using the D/B prior to estimate the blur kernel are discussed. The experimental results show that our algorithm is competitive with the state-of-the-art algorithms. In addition, experiments with our proposed prior show that it can significantly improve the performance of the deblurring algorithm.

**Author Contributions:** Conceptualization, D.Y. and X.W.; methodology, D.Y.; software, D.Y.; validation, D.Y., H.Y. and X.W.; formal analysis, D.Y.; investigation, D.Y., H.Y. and X.W.; resources, D.Y.; data curation, D.Y.; writing—original draft preparation, D.Y.; writing—review and editing, D.Y., H.Y. and X.W.; visualization, D.Y. and H.Y.; supervision, D.Y. and X.W.; project administration, X.W.; funding acquisition, X.W. All authors have read and agreed to the published version of the manuscript.

**Funding:** The research was funded by the National Natural Science Foundation of China (Grant Nos. 62020106012 and U1836218) and the 111 Project of Ministry of Education of China (Grant No. B12018).

**Institutional Review Board Statement:** Not applicable.

**Informed Consent Statement:** Not applicable.

**Data Availability Statement:** The data presented in this study are available on request from the corresponding author.

**Conflicts of Interest:** The authors declare no conflict of interest.

#### **References**


### *Article* **Geometric Metric Learning for Multi-Output Learning**

**Huiping Gao and Zhongchen Ma \***

The School of Computer Science & Communications Engineering, Jiangsu University, Zhenjiang 212013, China; huiping.gao@nuaa.edu.cn

**\*** Correspondence: 555mzc@163.com

**Abstract:** Owing to its wide applications, multi-output learning, which predicts multiple output values for a single input simultaneously, is attracting increasing attention. The k-nearest neighbor (kNN) algorithm is one of the most popular frameworks for multi-output learning, and its performance mainly depends on the metric used to compute distances between instances. In this paper, we propose a novel cost-weighted geometric mean metric learning method for multi-output learning. Specifically, this method learns a geometric mean metric that makes the distance between an input embedding and its correct output smaller than the distance between the input embedding and the outputs of its nearest neighbors. The learned geometric mean metric can discover output dependencies and push instances with different outputs far apart in the embedding space. In addition, our objective function has a closed-form solution, so it can be computed very quickly. Compared with state-of-the-art methods, our method is easier to interpret and faster to compute. Experiments on two multi-output learning tasks (i.e., multi-label classification and multi-objective regression) confirm that our method provides better results than state-of-the-art methods.

**Keywords:** multi-output; kNN; metric learning; cost-weighted; geometric mean metric

**MSC:** 68T10

#### **1. Introduction**

In real-world applications, many machine-learning problems that involve multiple predictions, e.g., multi-label learning and multi-target regression, can be cast as multi-output learning. Multi-output learning is an emerging machine-learning paradigm that aims to predict multiple output values for a given input simultaneously [1]. For example, text documents or semantic scenes can be assigned to multiple topics; one sensor can output different environmental coefficients; a gene can have multiple biological functions; a patient may suffer from multiple diseases, and so on.

Let D = {(**x**<sub>*j*</sub>, **y**<sub>*j*</sub>) | 1 ≤ *j* ≤ *n*} be a multi-output training set, where *n* is the number of instances, **x**<sub>*j*</sub> ∈ X and **y**<sub>*j*</sub> ∈ Y are the feature vector and the output vector of the *j*-th instance, respectively, X ⊆ R<sup>*p*</sup> denotes the *p*-dimensional input space, and Y ⊆ R<sup>*c*</sup> denotes the output space with *c* output variables. Multi-output learning aims to learn a mapping function *h* : X → Y from D that assigns a proper output vector to an instance. Compared with traditional single-output learning, multi-output learning is multivariate in nature and its output values can have diverse data types; thus, it subsumes many learning problems in real-world applications. For example, binary outputs **y**<sub>*j*</sub> ∈ {0, 1}<sup>*c*</sup> correspond to a multi-label classification problem [2], and real-valued outputs **y**<sub>*j*</sub> ∈ R<sup>*c*</sup> to a multi-target regression problem [3].

The *k* nearest neighbor (*k*NN) algorithm is one of the most popular frameworks for solving multi-output problems, and its prediction performance has been shown to improve significantly when a proper distance metric is learned. For example, by imposing the constraint that two nearby instances from different classes are pushed apart by a large margin, Gou et al. [4,5] showed that the prediction performance of *k*NN can be greatly improved. For multi-label learning, Zhang et al. [6] proposed the maximum margin output coding (MMOC) method based on structural SVMs [7,8]. It learns a distance metric such that instances with different multiple outputs are moved far apart. Unfortunately, both training and testing of MMOC are time-consuming: training involves solving a box-constrained quadratic programming (QP) problem for each training sample, and testing involves solving a QP problem over the space {0, 1}<sup>*c*</sup>. Even with approximate inference, solving this QP problem remains computationally expensive. Inspired by *k*NN and MMOC, Liu et al. [9] proposed a large margin metric learning paradigm for multi-output tasks (LMMO) with only *k* nearest neighbor constraints. LMMO reduces the per-iteration training complexity from O(*nc*<sup>3</sup> + *npc*<sup>2</sup> + *n*<sup>4</sup>) for MMOC to O(*c*<sup>3</sup> + *knpc*<sup>2</sup>), and the testing complexity from O(*c*<sup>3</sup>) for MMOC to O(*cn* + *pc*), thus largely removing the bottleneck of MMOC. Nevertheless, LMMO, the state-of-the-art metric learning method for multi-output learning, is trained with the accelerated proximal gradient (APG) method and cannot directly obtain the optimal metric in closed form. To achieve an *ε*-accurate solution, APG needs at least O(1/√*ε*) iterations, so obtaining a well-performing metric requires more APG iterations to reach a more accurate solution.

**Citation:** Gao, H.; Ma, Z. Geometric Metric Learning for Multi-Output Learning. *Mathematics* **2022**, *10*, 1632. https://doi.org/10.3390/math10101632

Academic Editor: Daniel Gómez Gonzalez

Received: 13 March 2022 Accepted: 5 May 2022 Published: 11 May 2022

**Copyright:** © 2022 by the authors. Licensee MDPI, Basel, Switzerland. This article is an open access article distributed under the terms and conditions of the Creative Commons Attribution (CC BY) license (https://creativecommons.org/licenses/by/4.0/).

Therefore, it is worthwhile to develop a gradient-free metric learning algorithm for multi-output tasks. To achieve this goal, this paper presents a novel geometric mean metric learning method for multi-output tasks, which learns a cost-weighted metric such that instances with very different multiple outputs are moved far apart. Our formulation also possesses several attractive properties: a closed-form solution, ease of interpretation, and a computational speed several orders of magnitude faster than the state-of-the-art method.

Our contributions are as follows. (1) We propose a novel geometric mean metric learning method for multi-output tasks, which possesses several attractive properties: closed-form solution, ease of interpretability, and computational speed several orders of magnitude faster than the state-of-the-art method. (2) Experiments conducted on two multi-output learning tasks have confirmed that our method provides better results than the state-of-the-art methods.

This paper is organized as follows. Section 2 reviews related work. Section 3 presents our geometric mean metric learning method for multi-output tasks. The performance of the proposed method for multi-label classification (MLC) and multi-target regression (MTR) is evaluated in Section 4. Section 5 concludes the work.

#### **2. Related Work**

#### *2.1. Multi-Output Learning*

Multi-output learning is an important machine-learning paradigm that subsumes many learning problems in practical applications. This paper focuses on the two most popular multi-output learning tasks, namely multi-label classification and multi-objective regression.

Multi-label Classification aims to predict multiple labels for a single sample. It has become an attractive emerging field with many practical applications, such as document classification [10], image retrieval [11], and image annotation [12]. In the past few years, many multi-label classification algorithms have been proposed. According to [2], these methods can be roughly divided into two categories: problem transformation and algorithm adaptation. Algorithm adaptation methods adapt popular learning techniques to deal with multi-label learning problems directly; ML-*k*NN [13], ML-DT [14], and Rank-SVM [15] are typical examples. Problem transformation methods convert the original problem into other well-established learning problems so that off-the-shelf techniques can be applied; binary relevance [16], random *k*-labelsets [17], calibrated label ranking [18] and classifier chains [19] are representative methods.

Multi-target Regression aims to predict the values of multiple continuous target variables from a set of predictor variables. Similar to multi-label learning methods, multi-objective regression methods can be roughly divided into two categories: algorithm adaptation and problem transformation [20]. Compared with problem transformation methods, algorithm adaptation methods usually produce a single multi-output model, which is easier to interpret and scales better to larger output spaces. On the other hand, by adopting suitable base learners, problem transformation methods can easily adapt to the problem at hand, and they have been found to be generally more accurate than algorithm adaptation methods [21].

#### *2.2. Metric Learning*

Given a set of similar and dissimilar pairs of points, metric learning aims to learn a distance metric that keeps similar points close and dissimilar points far apart in the embedding space. The distance metric retains the distance relationships among the training data [22]. Previous works [23–25] show that designing appropriate metrics can significantly improve the *k*NN classification accuracy of both single-output and multi-output learning tasks.
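As standard background (not specific to this paper), metric learning for *k*NN usually means learning a positive semidefinite matrix **M** that defines a Mahalanobis-type distance. The NumPy sketch below, with a hypothetical function name, shows that learning **M** = **L**<sup>*T*</sup>**L** is equivalent to learning a linear embedding **L** and measuring Euclidean distances after it.

```python
import numpy as np


def mahalanobis(x, y, M):
    """Distance sqrt((x - y)^T M (x - y)) under a positive semidefinite M."""
    d = np.asarray(x, dtype=float) - np.asarray(y, dtype=float)
    return float(np.sqrt(d @ M @ d))


rng = np.random.default_rng(0)
L = rng.normal(size=(3, 3))
M = L.T @ L  # any matrix of the form L^T L is positive semidefinite
x, y = rng.normal(size=3), rng.normal(size=3)

# Same value as the Euclidean distance between the embedded points L x and L y.
print(mahalanobis(x, y, M), np.linalg.norm(L @ x - L @ y))
```

With **M** = **I** the distance reduces to the ordinary Euclidean distance.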

Metric learning methods can be roughly divided into global and local distance metric learning. Global distance metric learning learns metrics that keep all data points of the same class close while pushing instances of different classes apart; representative methods are found in [26,27]. Local distance metric learning instead learns a metric that satisfies local pairwise constraints, which is particularly useful for the *k*NN classifier; representative methods are found in [9,28]. However, these methods usually rely on gradient-based optimization to obtain the metric. In contrast, we propose a novel cost-weighted geometric mean metric learning method for multi-output tasks, which learns a cost-weighted metric with a gradient-free optimization method. This makes the learned metric more accurate and the training procedure more efficient.

#### **3. The Proposed Method**

#### *3.1. Background*

Suppose we are given a multi-output training set with *n* instances, i.e., D = {(**x**<sub>*j*</sub>, **y**<sub>*j*</sub>) | 1 ≤ *j* ≤ *n*}, where **x**<sub>*j*</sub> ∈ X and **y**<sub>*j*</sub> ∈ Y are the feature vector and the output vector of the *j*-th instance, respectively. Multi-output learning aims to learn a function *h* : X → Y from D to predict the corresponding output vector of an instance.

To address this problem, a linear regression model simply learns the matrix **W** according to the following formulation:

$$\min\_{\mathbf{W}\in\mathbb{R}^{p\times c}} \frac{1}{2} \|\mathbf{XW} - \mathbf{Y}\|\_{F}^{2},\tag{1}$$

where ‖·‖<sub>*F*</sub> is the Frobenius norm, **X** ∈ R<sup>*n*×*p*</sup> is the input matrix and **Y** ∈ R<sup>*n*×*c*</sup> is the output matrix. However, because it does not model correlations in the output space, this method usually performs poorly.
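Problem (1) is ordinary least squares with a matrix-valued target, so it decouples into *c* independent regressions, one per output column; this is exactly why it cannot capture output correlations. A minimal NumPy sketch (the function name is hypothetical):

```python
import numpy as np


def fit_linear_multi_output(X, Y):
    """Least-squares solution of min_W 0.5 * ||X W - Y||_F^2.

    X: (n, p) inputs; Y: (n, c) outputs. Each column of W is fitted
    independently of the others, so output correlations are ignored.
    """
    W, *_ = np.linalg.lstsq(X, Y, rcond=None)
    return W


rng = np.random.default_rng(0)
X = rng.normal(size=(50, 4))
W_true = rng.normal(size=(4, 3))
Y = X @ W_true  # noiseless outputs, so the fit recovers W_true
W = fit_linear_multi_output(X, Y)
```

Fitting only the first output column gives the same first column of **W**, which illustrates the decoupling.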

LMMO [9] learns a large margin metric to model correlations in the output space. It forces the distance between the input embedding **W**<sup>*T*</sup>**x**<sub>*i*</sub> and its corresponding output **y**<sub>*i*</sub> to be smaller than the distance between **W**<sup>*T*</sup>**x**<sub>*i*</sub> and the output **y** of each nearest neighbor of **x**<sub>*i*</sub> by at least a margin Δ(**y**<sub>*i*</sub>, **y**), which measures the difference between **y**<sub>*i*</sub> and **y**. The large margin metric learning problem is formulated as follows:

$$\begin{array}{ll}\min\_{\mathbf{Q}\in S\_{c}^{+}, \{\xi\_{i}\geq 0\}\_{i=1}^{n}} & \frac{1}{2}\operatorname{trace}(\mathbf{Q}) + \frac{C}{n}\sum\_{i=1}^{n}\xi\_{i}^{2} \\ \text{s.t.} & \boldsymbol{\phi}\_{\mathbf{x}\_{i},\mathbf{y}\_{i}}^{T}\mathbf{Q}\boldsymbol{\phi}\_{\mathbf{x}\_{i},\mathbf{y}\_{i}} + \Delta(\mathbf{y}\_{i},\mathbf{y}) - \xi\_{i} \leq \boldsymbol{\phi}\_{\mathbf{x}\_{i},\mathbf{y}}^{T}\mathbf{Q}\boldsymbol{\phi}\_{\mathbf{x}\_{i},\mathbf{y}}, \quad \forall\,\mathbf{y} \in \mathit{Nei}(i),\ \forall i,\end{array} \tag{2}$$

where *S*<sub>*c*</sub><sup>+</sup> denotes the set of *c* × *c* symmetric positive semidefinite matrices, *φ*<sub>**x**<sub>*i*</sub>,**y**</sub> = **W**<sup>*T*</sup>**x**<sub>*i*</sub> − **y** (in particular, *φ*<sub>**x**<sub>*i*</sub>,**y**<sub>*i*</sub></sub> = **W**<sup>*T*</sup>**x**<sub>*i*</sub> − **y**<sub>*i*</sub>), *ξ*<sub>*i*</sub> is a slack variable, *C* is a positive constant that controls the trade-off between the squared loss and the regularizer, and *Nei*(*i*) is the set of outputs of the *k* nearest neighbors of input instance **x**<sub>*i*</sub>. The constraints in Equation (2) keep **W**<sup>*T*</sup>**x**<sub>*i*</sub> close to its correct output **y**<sub>*i*</sub> while enlarging its distance to every neighboring output **y** in the metric space.

However, although LMMO is the state-of-the-art metric learning method for multi-output learning, it cannot directly obtain the optimal metric in closed form. To achieve an *ε*-accurate solution, it needs at least O(1/√*ε*) iterations. Thus, it is worthwhile to further improve the computational efficiency of metric learning for multi-output learning.

#### *3.2. Proposed Formulation*

It is non-trivial to obtain a closed-form solution for LMMO. Inspired by GMML [29], we propose a novel metric learning method with a closed-form solution for multi-output learning, namely geometric metric learning for cost-weighted multi-output learning (GCMoL), formulated as follows:

$$\min\_{\mathbf{G}\in S\_{c}^{+}}\sum\_{i=1}^{n}\left(\boldsymbol{\phi}\_{\mathbf{x}\_{i},\mathbf{y}\_{i}}^{T}\mathbf{G}\boldsymbol{\phi}\_{\mathbf{x}\_{i},\mathbf{y}\_{i}}+\sum\_{\forall\mathbf{y}\in\mathit{Nei}(i)}\Delta(\mathbf{y}\_{i},\mathbf{y})\boldsymbol{\phi}\_{\mathbf{x}\_{i},\mathbf{y}}^{T}\mathbf{G}^{-1}\boldsymbol{\phi}\_{\mathbf{x}\_{i},\mathbf{y}}\right),\tag{3}$$

where *S*<sub>*c*</sub><sup>+</sup> denotes the set of *c* × *c* symmetric positive semidefinite matrices, *φ*<sub>**x**<sub>*i*</sub>,**y**<sub>*i*</sub></sub> = **W**<sup>*T*</sup>**x**<sub>*i*</sub> − **y**<sub>*i*</sub>, *Nei*(*i*) is the set of outputs of the *k* nearest neighbors of input instance **x**<sub>*i*</sub>, and Δ(·) is the cost function of interest.

Compared with the LMMO formulation in Equation (2), we have turned its many independent inequality constraints into a single unified objective. The distance term between the input embedding **W**<sup>*T*</sup>**x**<sub>*i*</sub> and its correct output **y**<sub>*i*</sub>, which is measured by **G**, increases monotonically in **G**, whereas, by Lemma 1, the distance term between **W**<sup>*T*</sup>**x**<sub>*i*</sub> and the output **y** of a nearest neighbor of **x**<sub>*i*</sub>, which is measured by **G**<sup>−1</sup>, decreases monotonically in **G**. Minimizing the objective function in Equation (3) therefore naturally makes the distance between **W**<sup>*T*</sup>**x**<sub>*i*</sub> and its correct output **y**<sub>*i*</sub> smaller than the distance between **W**<sup>*T*</sup>**x**<sub>*i*</sub> and the outputs of the nearest neighbors of **x**<sub>*i*</sub>.

In GCMoL (Equation (3)), each distance between **W**<sup>*T*</sup>**x**<sub>*i*</sub> and the output **y** of a nearest neighbor of **x**<sub>*i*</sub> is weighted by the cost Δ(**y**<sub>*i*</sub>, **y**). Thus, the metric **G** is learned in a cost-sensitive way. For simplicity, we use the loss function Δ(**y**<sub>*i*</sub>, **y**) = ‖**y**<sub>*i*</sub> − **y**‖<sub>1</sub> to measure the difference between outputs for both multi-label learning and multi-target regression.

### **Lemma 1.** *Let* **A***,* **B** *be (strictly) positive definite matrices such that* **A** ≻ **B***. Then,* **A**<sup>−1</sup> ≺ **B**<sup>−1</sup>*.*
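Lemma 1 is the standard fact that matrix inversion reverses the Loewner order on positive definite matrices. The NumPy sketch below is a numerical sanity check on randomly generated matrices; it illustrates the lemma but does not prove it, and the helper name is hypothetical.

```python
import numpy as np


def is_positive_definite(M, tol=1e-12):
    """True if the symmetric matrix M has all eigenvalues greater than tol."""
    return bool(np.all(np.linalg.eigvalsh(M) > tol))


rng = np.random.default_rng(1)
R = rng.normal(size=(4, 4))
B = R @ R.T + np.eye(4)            # B is positive definite
P = rng.normal(size=(4, 4))
A = B + P @ P.T + 0.1 * np.eye(4)  # A - B is positive definite, i.e., A > B

# Lemma 1 predicts that B^{-1} - A^{-1} is positive definite (A^{-1} < B^{-1}).
diff = np.linalg.inv(B) - np.linalg.inv(A)
diff = (diff + diff.T) / 2         # remove round-off asymmetry before eigvalsh
print(is_positive_definite(A - B), is_positive_definite(diff))
```

This order reversal is what makes the **G**<sup>−1</sup>-weighted neighbor term in Equation (3) shrink as **G** grows.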

In the following, we further simplify the objective function in Equation (3). Let us define the following two matrices:

$$\mathbf{S} := \sum\_{i=1}^{n} \phi\_{\mathbf{x}\_i, \mathbf{y}\_i} \boldsymbol{\phi}\_{\mathbf{x}\_i, \mathbf{y}\_i}^T \tag{4}$$

$$\mathbf{D} := \sum\_{i=1}^{n} \sum\_{\forall \mathbf{y} \in \mathit{Nei}(i)} \Delta(\mathbf{y}\_{i}, \mathbf{y}) \boldsymbol{\phi}\_{\mathbf{x}\_{i},\mathbf{y}} \boldsymbol{\phi}\_{\mathbf{x}\_{i},\mathbf{y}}^{T} . \tag{5}$$
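A minimal NumPy sketch of how **S** and **D** in Equations (4) and (5) could be assembled, using the *ℓ*<sub>1</sub> cost stated in the text. Assumptions for illustration: the nearest neighbors are found by Euclidean distance in the input space, **W** is given (for example, from Equation (1)), and the function names are hypothetical.

```python
import numpy as np


def knn_indices(X, k):
    """Indices of the k nearest neighbors (Euclidean, excluding self) of each row."""
    d2 = ((X[:, None, :] - X[None, :, :]) ** 2).sum(axis=-1)
    np.fill_diagonal(d2, np.inf)  # an instance is not its own neighbor
    return np.argsort(d2, axis=1)[:, :k]


def scatter_matrices(X, Y, W, neighbors):
    """Build S (Eq. 4) and D (Eq. 5) with Delta(y_i, y) = ||y_i - y||_1.

    X: (n, p) inputs; Y: (n, c) outputs; W: (p, c) linear map;
    neighbors[i]: indices whose outputs form Nei(i).
    """
    c = Y.shape[1]
    S = np.zeros((c, c))
    D = np.zeros((c, c))
    for i in range(X.shape[0]):
        z = W.T @ X[i]                    # embedded input W^T x_i
        phi_true = z - Y[i]               # phi_{x_i, y_i}
        S += np.outer(phi_true, phi_true)
        for j in neighbors[i]:
            phi_nb = z - Y[j]             # phi_{x_i, y} with y = y_j
            cost = np.abs(Y[i] - Y[j]).sum()
            D += cost * np.outer(phi_nb, phi_nb)
    return S, D
```

Both matrices are *c* × *c*, so the cost of everything that follows depends on the number of outputs rather than on *n*.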

Then, the objective function in Equation (3) can be reformulated as

$$\min\_{\mathbf{G}} tr(\mathbf{G}\mathbf{S}) + tr(\mathbf{G}^{-1}\mathbf{D}).\tag{6}$$

The minimization problem (6) is both strictly convex and strictly geodesically convex (Theorem 3 of [29]), since it has the same form as problem (13) of [29]. Hence, it has a unique global minimizer, which is given in closed form by

$$\mathbf{G} = \mathbf{S}^{-1} \sharp\_{1/2} \mathbf{D} = \mathbf{S}^{-1/2} \left( \mathbf{S}^{1/2} \mathbf{D} \mathbf{S}^{1/2} \right)^{1/2} \mathbf{S}^{-1/2}. \tag{7}$$
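To make Equations (4)–(7) concrete, the following NumPy sketch forms **S** and **D** from difference vectors and evaluates the closed-form solution. It is an illustration under stated assumptions, not the authors' implementation: the φ vectors are random stand-ins (their definition accompanies Equation (3)), the costs play the role of Δ(**y**<sub>*i*</sub>, **y**), and the helper names `spd_power`, `build_S_D`, `gmml_closed_form` and `objective` are hypothetical. The final check uses the stationarity condition of (6), **S** = **G**<sup>−1</sup> **D** **G**<sup>−1</sup>, which is equivalent to **G** **S** **G** = **D**.

```python
import numpy as np


def spd_power(A, p):
    """A**p for a symmetric positive definite A, via A = V diag(w) V^T."""
    w, V = np.linalg.eigh(A)
    return (V * w**p) @ V.T


def build_S_D(phi_pos, phi_neg, costs):
    """Eqs. (4)-(5): S sums phi phi^T over correct outputs,
    D sums Delta * phi phi^T over neighbor outputs.

    phi_pos: (n, c) array, row i holds phi_{x_i, y_i}
    phi_neg: (m, c) array, one row per (i, y) pair with y a neighbor output of x_i
    costs:   (m,) array, Delta(y_i, y) for the matching row of phi_neg
    """
    S = phi_pos.T @ phi_pos
    D = (phi_neg * costs[:, None]).T @ phi_neg
    return S, D


def gmml_closed_form(S, D):
    """Eq. (7): G = S^{-1/2} (S^{1/2} D S^{1/2})^{1/2} S^{-1/2}."""
    S_half = spd_power(S, 0.5)
    S_ihalf = spd_power(S, -0.5)
    M = S_half @ D @ S_half
    M = (M + M.T) / 2  # remove floating-point asymmetry before eigh
    return S_ihalf @ spd_power(M, 0.5) @ S_ihalf


def objective(G, S, D):
    """Eq. (6): tr(GS) + tr(G^{-1} D)."""
    return np.trace(G @ S) + np.trace(np.linalg.solve(G, D))


# Random stand-ins for the phi vectors (c = 3 outputs); costs stand in for Delta.
rng = np.random.default_rng(0)
phi_pos = rng.normal(size=(50, 3))
phi_neg = rng.normal(size=(200, 3))
costs = rng.uniform(0.5, 2.0, size=200)

S, D = build_S_D(phi_pos, phi_neg, costs)
G = gmml_closed_form(S, D)
# Stationarity of (6): S - G^{-1} D G^{-1} = 0, i.e. G S G = D.
print(np.allclose(G @ S @ G, D))
```

Matrix powers are taken with `numpy.linalg.eigh`, which is appropriate here because every matrix raised to a fractional power is symmetric positive definite.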

Clearly, the solution of (6) is the geometric mean of **S**<sup>−1</sup> and **D**. In practice, however, the matrix **S** may be non-invertible or nearly singular. To address this issue, we add to the objective function a regularization term, which can also be used to incorporate prior knowledge about the distance function:

$$\min\_{\mathbf{G}\succ\mathbf{0}} \quad \lambda D\_{\text{sld}}(\mathbf{G}, \mathbf{G}\_0) + \text{tr}(\mathbf{G}\mathbf{S}) + \text{tr}\left(\mathbf{G}^{-1}\mathbf{D}\right),\tag{8}$$

where **G**<sub>0</sub> is the "prior" and *D*<sub>sld</sub>(**G**, **G**<sub>0</sub>) is the symmetrized LogDet divergence, defined as

$$D\_{\rm sld}(\mathbf{G}, \mathbf{G}\_0) := \text{tr}\left(\mathbf{G}\mathbf{G}\_0^{-1}\right) + \text{tr}\left(\mathbf{G}^{-1}\mathbf{G}\_0\right) - 2c.\tag{9}$$

The minimization problem in (8) also has a closed-form solution,

$$\mathbf{G}\_{\rm reg} = \left(\mathbf{S} + \lambda \mathbf{G}\_0^{-1}\right)^{-1} \sharp\_{\frac{1}{2}} (\mathbf{D} + \lambda \mathbf{G}\_0). \tag{10}$$

Equation (10) shows that the solution is the midpoint of the geodesic joining (**S** + *λ***G**<sub>0</sub><sup>−1</sup>)<sup>−1</sup> and **D** + *λ***G**<sub>0</sub>. From this geodesic viewpoint, the two matrices need not be weighted equally, and the choice of weights is pivotal for the solution of (3). Therefore, we introduce a nonlinear cost guided by the Riemannian geometry of the SPD manifold and obtain the following weighted version of (3):

$$\min\_{\mathbf{G}\succ \mathbf{0}} h\_t(\mathbf{G}) := (1-t)\,\delta\_{R}^2\big(\mathbf{G}, \mathbf{S}^{-1}\big) + t\,\delta\_{R}^2(\mathbf{G}, \mathbf{D}),\tag{11}$$

where *t* is a parameter that balances the cost terms *δ*<sub>*R*</sub><sup>2</sup>(**G**, **S**<sup>−1</sup>) and *δ*<sub>*R*</sub><sup>2</sup>(**G**, **D**), and *δ*<sub>*R*</sub> denotes the Riemannian distance on SPD matrices,

$$\delta\_{R}(\mathbf{X}, \mathbf{Y}) := \left\| \log \left( \mathbf{Y}^{-1/2} \mathbf{X} \mathbf{Y}^{-1/2} \right) \right\|\_{F} \quad \text{for } \mathbf{X}, \mathbf{Y} \succ \mathbf{0}. \tag{12}$$
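The Riemannian distance in Equation (12) can be evaluated from eigenvalues alone, because the logarithm of an SPD matrix is symmetric, so its Frobenius norm is the Euclidean norm of the log-eigenvalues. The sketch below (the helper names are hypothetical) also exercises two standard properties of this distance: symmetry, and invariance under the congruence **X** ↦ **AXA**<sup>*T*</sup> for an invertible **A**.

```python
import numpy as np


def spd_power(A, p):
    """A**p for a symmetric positive definite A."""
    w, V = np.linalg.eigh(A)
    return (V * w**p) @ V.T


def riemannian_distance(X, Y):
    """Eq. (12): ||log(Y^{-1/2} X Y^{-1/2})||_F for SPD X, Y."""
    Y_ihalf = spd_power(Y, -0.5)
    M = Y_ihalf @ X @ Y_ihalf
    M = (M + M.T) / 2
    # log(M) is symmetric with eigenvalues log(w), so its Frobenius norm
    # equals the Euclidean norm of the log-eigenvalues.
    w = np.linalg.eigvalsh(M)
    return float(np.linalg.norm(np.log(w)))


rng = np.random.default_rng(1)
B = rng.normal(size=(4, 4))
X = B @ B.T + np.eye(4)
C = rng.normal(size=(4, 4))
Y = C @ C.T + np.eye(4)
A = rng.normal(size=(4, 4)) + 4 * np.eye(4)  # assumed invertible congruence

print(riemannian_distance(X, Y), riemannian_distance(Y, X))
print(riemannian_distance(A @ X @ A.T, A @ Y @ A.T))
```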

Problem (11) is geodesically convex, and its unique solution is the weighted geometric mean

$$\mathbf{G} = \mathbf{S}^{-1} \sharp\_t \mathbf{D}.\tag{13}$$

Similar to the regularized solution to problem (8), the solution to the regularized form of problem (11) is given by

$$\mathbf{G}\_{\text{reg}} = \left(\mathbf{S} + \lambda \mathbf{G}\_0^{-1}\right)^{-1} \sharp\_{t} (\mathbf{D} + \lambda \mathbf{G}\_0), \tag{14}$$

for *t* ∈ [0, 1]; for *t* = 1/2, it reduces to (10). Several approaches, e.g., the Cholesky–Schur and scaled Newton methods, can compute Riemannian geodesics of SPD matrices efficiently. In this paper, we use the Cholesky–Schur method. Algorithm 1 summarizes our GCMoL algorithm for multi-output learning.
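A minimal sketch of Equations (13) and (14), assuming SciPy is available. It computes **A** #<sub>*t*</sub> **B** from a Cholesky factor **A** = **LL**<sup>*T*</sup> and the congruence identity **A** #<sub>*t*</sub> **B** = **L**(**L**<sup>−1</sup>**BL**<sup>−*T*</sup>)<sup>*t*</sup>**L**<sup>*T*</sup>; because the inner matrix is symmetric, its Schur decomposition is an eigendecomposition, so `eigh` supplies the fractional power. This follows the idea of the Cholesky–Schur method but is not claimed to match the authors' code; the function names and the identity default for **G**<sub>0</sub> are assumptions.

```python
import numpy as np
from scipy.linalg import cholesky, solve_triangular


def weighted_geometric_mean(A, B, t):
    """A #_t B = L (L^{-1} B L^{-T})^t L^T, where A = L L^T (Cholesky)."""
    L = cholesky(A, lower=True)                 # A = L L^T
    X = solve_triangular(L, B, lower=True)      # L^{-1} B
    V = solve_triangular(L, X.T, lower=True)    # L^{-1} B L^{-T} (B symmetric)
    V = (V + V.T) / 2
    # For symmetric V the Schur form is diagonal, i.e. an eigendecomposition.
    w, Q = np.linalg.eigh(V)
    return L @ ((Q * w**t) @ Q.T) @ L.T


def gcmol_metric(S, D, t=0.5, lam=0.0, G0=None):
    """Eq. (13) when lam == 0, and the regularized Eq. (14) when lam > 0."""
    if lam > 0:
        G0 = np.eye(S.shape[0]) if G0 is None else G0  # identity prior: an assumption
        A = np.linalg.inv(S + lam * np.linalg.inv(G0))
        B = D + lam * G0
    else:
        A, B = np.linalg.inv(S), D
    return weighted_geometric_mean(A, B, t)


rng = np.random.default_rng(2)
P1 = rng.normal(size=(3, 3))
P2 = rng.normal(size=(3, 3))
S = P1 @ P1.T + np.eye(3)
D = P2 @ P2.T + np.eye(3)
G_half = gcmol_metric(S, D, t=0.5)           # Eq. (13) with t = 1/2, i.e. Eq. (7)
G_reg = gcmol_metric(S, D, t=0.7, lam=0.1)   # Eq. (14) with an identity prior
```

With `t=0.5` and `lam=0` the result satisfies **GSG** = **D**, matching Equation (7), while `t=0` and `t=1` return the two endpoints of the geodesic.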

#### **Algorithm 1:** GCMoL.


#### *3.3. Prediction*

In the learned metric space, our formulation makes each input **W**<sup>*T*</sup>**x**<sub>*i*</sub> as close as possible to its correct output **y**<sub>*i*</sub>. For a new test instance **x**, the output can be obtained by a decoding method. In general, decoding requires solving a QP problem over a combinatorial space [6], which is computationally expensive. In this paper, we follow the prediction method of [9]: we find the *k* nearest neighbors of a new test instance **x** in the learned metric space, where the distance between **x** and **x**<sub>*i*</sub> is computed as (**W**<sup>*T*</sup>**x** − **W**<sup>*T*</sup>**x**<sub>*i*</sub>)<sup>*T*</sup>**G**(**W**<sup>*T*</sup>**x** − **W**<sup>*T*</sup>**x**<sub>*i*</sub>). We then predict the output by weighted voting over these neighbors. For multi-label classification problems, we use 0.5 as the threshold.
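The prediction step can be sketched as follows. The distance is the one given above; the inverse-distance voting weights and the `>= 0.5` comparison are assumptions, since the text specifies weighted voting and a 0.5 threshold but not the exact weights. `predict_knn` is a hypothetical name.

```python
import numpy as np


def predict_knn(x, X_train, Y_train, W, G, k=10, task="mlc", eps=1e-12):
    """Weighted k-NN prediction in the learned metric space.

    x: (p,) test input; X_train: (n, p); Y_train: (n, c) outputs
    W: (p, c) projection; G: (c, c) learned SPD metric
    """
    diff = X_train @ W - x @ W                        # rows: W^T x_i - W^T x
    dist = np.einsum("ij,jk,ik->i", diff, G, diff)   # quadratic form under G
    nn = np.argsort(dist)[:k]
    weights = 1.0 / (dist[nn] + eps)                  # inverse-distance weights (assumption)
    weights /= weights.sum()
    scores = weights @ Y_train[nn]
    if task == "mlc":
        return (scores >= 0.5).astype(int)            # 0.5 threshold for multi-label
    return scores                                     # weighted average for regression


X_train = np.array([[0.0, 0.0], [0.1, 0.0], [5.0, 5.0], [5.1, 5.0]])
Y_mlc = np.array([[1, 0], [1, 0], [0, 1], [0, 1]])
Y_mtr = np.array([[1.0, 0.0], [3.0, 2.0], [10.0, 10.0], [10.0, 10.0]])
W, G = np.eye(2), np.eye(2)
x = np.array([0.05, 0.0])
print(predict_knn(x, X_train, Y_mlc, W, G, k=2))
print(predict_knn(x, X_train, Y_mtr, W, G, k=2, task="mtr"))
```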

#### *3.4. Complexity Analysis*

In this subsection, we compare the training and testing time complexity of different methods.

#### 3.4.1. Training Time

The training of MMOC involves an exponential number of constraints and a box-constrained QP problem for each training instance. Its authors therefore use the over-generating technique with the cutting-plane method to handle the constraints and CVX (http://cvxr.com/cvx/, accessed on 8 March 2022) to solve the QP problems. Because MMOC is optimized iteratively with a gradient method, we assume that it needs at least *η* iterations to reach the desired performance. According to Liu et al. [9], each iteration of MMOC costs at least O(*nc*<sup>3</sup> + *npc*<sup>2</sup> + *n*<sup>4</sup>), so its total training time complexity is at least O(*ηnc*<sup>3</sup> + *ηnpc*<sup>2</sup> + *ηn*<sup>4</sup>). The training time of LMMO is dominated by the APG algorithm, which needs O(1/√*ε*) iterations to reach an *ε*-accurate solution. According to Liu et al. [30], each iteration costs O(*c*<sup>3</sup> + *knpc*<sup>2</sup>), so the total training time complexity of LMMO is at least O(*c*<sup>3</sup>/√*ε* + *knpc*<sup>2</sup>/√*ε*). The training time of GCMoL is dominated by the computation of Riemannian geodesics between SPD matrices, which, following [29], we carry out with the Cholesky–Schur method. The training time complexity of GCMoL is therefore O(*c*<sup>3</sup> + *knpc*<sup>2</sup>).

#### 3.4.2. Testing Time

We analyze the testing time per test instance. At test time, MMOC must solve a QP problem over the space {0, 1}<sup>*c*</sup>, which is essentially a combinatorial optimization problem and therefore intractable. To address this, MMOC uses a mean-field approximation that obtains approximate solutions iteratively. Each mean-field iteration costs O(*c*<sup>2</sup>); if it iterates many times until convergence, its testing time complexity is at least O(*c*<sup>3</sup>). LMMO and GCMoL use the same prediction method and therefore have the same prediction time complexity, i.e., O(*nc* + *pc*).

#### **4. Experiments**

In this section, we extensively compare the proposed GCMoL method with related approaches on real-world multi-label classification and multi-target regression datasets. All compared methods are implemented in MATLAB, and all experiments are conducted on a Windows desktop with a 3.2 GHz Intel CPU and 32 GB of main memory.

#### *4.1. Experimental Setup*

(1) Datasets: We conduct experiments on five benchmark multi-label datasets (http://mulan.sourceforge.net/, accessed on 9 March 2022), including emotions, scene, cal500 and genbase, and on four benchmark multi-target regression datasets (http://mulan.sourceforge.net/, accessed on 9 March 2022), including edm, enb, jura and scpf. Table 1 summarizes the dataset details, where |*S*| is the number of examples, *dim*(*S*) the number of features, *L*(*S*) the number of class labels, *Card*(*S*) the average number of labels per example, *Dom*(*S*) the feature type of dataset *S*, and *Cat* the task category.

**Table 1.** Characteristics of datasets.


(2) Evaluation Metrics: We evaluate performance with two metrics on the multi-label classification datasets, Micro-F1 and Macro-F1, and one metric on the multi-target regression datasets, aRMAE. For Micro-F1 and Macro-F1, larger values indicate better performance; their definitions are given in [2]. For aRMAE, smaller values indicate better performance. It is defined as:

$$aRMAE(\mathbf{h}, \mathbf{D}) = \frac{1}{m} \sum\_{j=1}^{m} \frac{\sum\_{(\mathbf{x}, \mathbf{y}) \in \mathbf{D}} |\hat{y}\_j - y\_j|}{\sum\_{(\mathbf{x}, \mathbf{y}) \in \mathbf{D}} |\bar{Y}\_j - y\_j|},\tag{15}$$

where *Ȳ*<sub>*j*</sub> is the mean value of *Y*<sub>*j*</sub> over dataset **D**, *ŷ*<sub>*j*</sub> is the prediction of **h** for *Y*<sub>*j*</sub>, and *m* is the number of targets. Intuitively, aRMAE measures how much better (*aRMAE* < 1) or worse (*aRMAE* > 1) the prediction model is than a naive baseline that always predicts the mean value of each target.

(3) Comparing Methods: We compare our proposed method GCMoL with the following state-of-the-art multi-output learning methods.


• LMMO [9] is a recently proposed large-margin metric learning method for multi-output tasks. It projects both inputs and outputs into the same embedding space and then learns a distance metric that keeps instances with the same output close and pushes instances with very different outputs farther away. Its formulation is presented in Equation (2), and it can only be used for multi-label learning tasks. Its parameter *λ* is selected from {10<sup>−5</sup>, 10<sup>−4</sup>, ..., 10<sup>4</sup>, 10<sup>5</sup>}.

The hyper-parameters of the compared methods are selected via 10-fold cross-validation on the training set. For GCMoL, *λ* is selected from {10<sup>−5</sup>, 10<sup>−4</sup>, ..., 10<sup>4</sup>, 10<sup>5</sup>} and *t* from {0.2, 0.5, 0.7}. We adopt *k* = 10, which yields the best performance.
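The aRMAE metric of Equation (15) translates directly into NumPy. This sketch (the name `armae` is hypothetical) computes the target means over the same evaluation set, as the definition of *Ȳ*<sub>*j*</sub> specifies, so a model that always predicts those means scores exactly 1.

```python
import numpy as np


def armae(Y_true, Y_pred):
    """Eq. (15): mean over targets of sum|y_hat_j - y_j| / sum|Ybar_j - y_j|.

    Y_true, Y_pred: (n_samples, m) arrays over the evaluation set D.
    """
    Y_true = np.asarray(Y_true, dtype=float)
    Y_pred = np.asarray(Y_pred, dtype=float)
    Y_bar = Y_true.mean(axis=0)                   # per-target mean over D
    num = np.abs(Y_pred - Y_true).sum(axis=0)     # model error per target
    den = np.abs(Y_bar - Y_true).sum(axis=0)      # mean-predicting baseline error
    return float(np.mean(num / den))


Y = np.array([[0.0, 1.0], [2.0, 5.0]])
Y_hat = np.array([[0.5, 1.0], [1.5, 5.0]])
print(armae(Y, Y_hat))
```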

#### *4.2. Experimental Results*

Table 2 reports the detailed experimental results, with the performance rank on each dataset shown in parentheses. To test whether GCMoL achieves statistically superior performance against the compared approaches, we employ the Nemenyi test at the 0.05 significance level; the results are summarized in Figure 1. Two methods perform significantly differently if their average ranks differ by at least the critical difference *CD* = *q*<sub>*α*</sub>√(*k*(*k* + 1)/(6*N*)), where *k* is the number of compared methods and *q*<sub>*α*</sub> is the critical value of the Studentized range statistic divided by √2. At significance level *α* = 0.05, this gives *CD* = 1.2075 (*k* = 4, *N* = 12). In Figure 1, connected algorithms have average ranks within one CD of each other, and any unconnected pair of algorithms is considered to differ significantly in performance.
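The Nemenyi decision rule reduces to a few lines. In this sketch, `q_alpha` must be taken from a table of the Studentized range statistic divided by √2 (it is not restated here), and the average ranks in the usage lines are hypothetical values, not the results in Table 2.

```python
import math


def nemenyi_cd(q_alpha, k, n):
    """Critical difference CD = q_alpha * sqrt(k (k + 1) / (6 N))."""
    return q_alpha * math.sqrt(k * (k + 1) / (6.0 * n))


def significantly_different(avg_rank_a, avg_rank_b, cd):
    """Two methods differ significantly if their average ranks differ by at least CD."""
    return abs(avg_rank_a - avg_rank_b) >= cd


# Hypothetical average ranks, checked against the reported CD = 1.2075.
ranks = {"method_A": 1.3, "method_B": 2.6, "method_C": 3.0}
cd = 1.2075
print(significantly_different(ranks["method_A"], ranks["method_B"], cd))
print(significantly_different(ranks["method_B"], ranks["method_C"], cd))
```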


**Table 2.** Experimental results for multi-output learning. The best ones are in bold.

**Figure 1.** Comparison of GCMoL against other comparing algorithms with the Nemenyi test.

Based on the reported experimental results, we make the following observations: (1) In terms of Micro-F1 on the MLC task, GCMoL generally outperforms the other methods and is only slightly inferior to MLkNN on the yeast dataset. (2) In terms of Macro-F1 on the MLC task, GCMoL always outperforms the other methods. (3) In terms of aRMAE on the MTR task, GCMoL also generally outperforms the other methods and is only slightly inferior to BR on the wq dataset. (4) According to the Nemenyi test results, MLkNN, LMMO and BR do not differ significantly from one another, whereas GCMoL performs significantly better than the other methods, which verifies the effectiveness of our method.

#### *4.3. Analysis*

#### 4.3.1. Hyper-Parameter Sensitivity Analysis

Our proposed method has two hyper-parameters, *λ* and *t*. To analyze their sensitivity, we conduct experiments on the CAL500 and edm datasets. Figure 2a–f depicts the results of GCMoL with different values of *λ* and *t*; they show that the performance of GCMoL is relatively insensitive to the values of *λ* and *t*.


**Figure 2.** Sensitivity analysis about GCMoL with different *λ* and *t*. (**a**) Micro-F1 scores of different *λ* on CAL500 dataset. (**b**) Micro-F1 scores of different *t* on CAL500 dataset. (**c**) Macro-F1 scores of different *λ* on CAL500 dataset. (**d**) Macro-F1 scores of different *t* on CAL500 dataset. (**e**) aRMAE of different *λ* on edm dataset. (**f**) aRMAE of different *t* on edm dataset.

#### 4.3.2. Time Consumption Analysis

To further compare the time consumption of different methods, Figure 3 reports the training time of our method and the baseline approaches for a single run of 10-fold cross-validation on the CAL500 dataset. BR, MLkNN and GCMoL complete training within 3 s, whereas LMMO takes more than 240 s; our method is almost 100 times faster than LMMO. The theoretical time complexity analysis in Section 3.4 indicates that our method runs slower than BR and MLkNN but faster than LMMO, and the experimental results confirm this conclusion.

**Figure 3.** Running time results of different methods on the CAL500 dataset.

#### **5. Conclusions**

In this paper, we proposed a novel cost-weighted geometric mean metric learning method for multi-output tasks. Our method models output dependency through the learned geometric mean metric, which pushes instances with very different outputs far apart. It also admits a closed-form solution and, in our experiments, trains roughly two orders of magnitude faster than the state-of-the-art LMMO method. Experiments show that our method outperforms state-of-the-art methods on multi-output learning tasks.

There are several directions worth exploring in future work. First, we will try to design a robust geometric metric learning method that generalizes our technique to weakly supervised multi-output learning tasks, in which missing or noisy supervision information may pose great challenges to metric learning. Second, we will try to design new weakly supervised contrastive learning methods to effectively apply self-supervised learning techniques to multi-output learning tasks.

**Author Contributions:** Methodology, H.G.; Writing—review & editing, Z.M. All authors have read and agreed to the published version of the manuscript.

**Funding:** This work is supported by the National Natural Science Foundation of China (No. 62006098) and the China Postdoctoral Science Foundation (No. 2020M681515).

**Conflicts of Interest:** The authors declare no conflict of interest.

#### **References**


## *Article* **Decoupling Induction and Multi-Order Attention Drop-Out Gating Based Joint Motion Deblurring and Image Super-Resolution**

**Yuezhong Chu, Xuefeng Zhang and Heng Liu \***

School of Computer Science and Technology, Anhui University of Technology, Ma'anshan 243002, China; yzchu@ahut.edu.cn (Y.C.); zxf\_06@ahut.edu.cn (X.Z.)

**\*** Correspondence: hengliu@ahut.edu.cn

**Abstract:** Resolution reduction and motion blur are two typical image degradation processes that are usually addressed by deep networks, specifically convolutional neural networks (CNNs). However, since real images are usually produced by multiple degradations, the vast majority of current CNN methods, which model a single degradation process, need to be extended to account for multiple degradation effects. In this work, motivated by degradation decoupling and multiple-order attention drop-out gating, we propose a joint deep recovery model that efficiently addresses motion blur and resolution reduction simultaneously. Our degradation decoupling approach improves the convenience and efficiency of model construction and training. Moreover, the proposed multi-order attention mechanism comprehensively and hierarchically extracts multiple attention features and fuses them properly by drop-out gating. The proposed approach is evaluated on diverse benchmark datasets, including natural and synthetic images. The experimental results show that our method can efficiently perform joint motion deblurring and image super-resolution (SR).

**Keywords:** motion deblurring; image super-resolution; multi-order attention; gated learning; decoupling

**MSC:** 37M99

#### **1. Introduction**

Motion blur and resolution reduction are the two dominant forms of image quality degradation. The former is caused by relative motion between the camera and the object, while the latter generally results from down-sampling. The corresponding inverse processes, motion deblurring and image SR, recover clear images from blurred ones and reconstruct high-resolution (HR) images from low-resolution (LR) ones, respectively; they are the main practical means of handling image quality degradation.

Let *x* be the original sharp image and *y* the blurred image. Ignoring the effect of the non-linear camera response function (CRF), motion blur degradation can theoretically be represented as:

$$y = (x \* h) + n, \tag{1}$$

where *h* represents the motion blur kernel, ∗ denotes the convolution operation and *n* denotes random noise. According to Equation (1), inverse motion deblurring is a typical ill-posed problem, because the same blurred image may correspond to many different sharp images and blur kernels.
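The following sketch simulates Equation (1) with SciPy. The horizontal linear kernel is a hypothetical, uniform stand-in (real motion blur is often non-uniform), and the function names are illustrative. The usage lines illustrate the ill-posedness argument: a flat image blurred by two different kernels yields the same observation, so the kernel cannot be identified from *y* alone.

```python
import numpy as np
from scipy.signal import convolve2d


def linear_motion_kernel(length):
    """Hypothetical uniform horizontal motion kernel of the given length (sums to 1)."""
    h = np.zeros((length, length))
    h[length // 2, :] = 1.0 / length
    return h


def motion_blur(x, h, noise_sigma=0.0, rng=None):
    """Eq. (1): y = (x * h) + n, with n drawn as Gaussian noise."""
    rng = np.random.default_rng() if rng is None else rng
    y = convolve2d(x, h, mode="same", boundary="symm")
    return y + noise_sigma * rng.standard_normal(x.shape)


# A flat image looks identical under two different kernels.
flat = np.ones((8, 8))
y3 = motion_blur(flat, linear_motion_kernel(3))
y5 = motion_blur(flat, linear_motion_kernel(5))
print(np.allclose(y3, y5))
```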

There are two families of motion deblurring methods: non-blind and blind. Non-blind methods recover the clear image *x* from the observation *y* using a known or previously estimated blur kernel, whereas blind methods must work without a known blur kernel. Owing to their end-to-end mapping and powerful approximation capabilities, CNNs are particularly well suited to blind motion deblurring; for example, recent CNN-based works [1,2] address blind deblurring.

**Citation:** Chu, Y.; Zhang, X.; Liu, H. Decoupling Induction and Multi-Order Attention Drop-Out Gating Based Joint Motion Deblurring and Image Super-Resolution. *Mathematics* **2022**, *10*, 1837. https://doi.org/10.3390/math10111837

Academic Editors: Jianping Gou, Weihua Ou, Shaoning Zeng, Lan Du and Catalin Stoean

Received: 11 March 2022 Accepted: 24 May 2022 Published: 26 May 2022

**Publisher's Note:** MDPI stays neutral with regard to jurisdictional claims in published maps and institutional affiliations.

**Copyright:** © 2022 by the authors. Licensee MDPI, Basel, Switzerland. This article is an open access article distributed under the terms and conditions of the Creative Commons Attribution (CC BY) license (https://creativecommons.org/licenses/by/4.0/).

Resolution degradation involves, in addition to the low-pass blur filter *k* and the noise *n*, a down-sampling operator that motion blur degradation lacks. For an HR image *x*, a typical resolution degradation model producing the corresponding LR image *y* is formulated as:

$$\mathbf{y} = (\mathbf{x} \* \mathbf{k}) \downarrow\_{s} + \mathbf{n}, \tag{2}$$

where *k* denotes a low-pass blur filter, ∗ denotes the convolution operation, ↓<sub>*s*</sub> denotes *s*× down-sampling (decimation), and *n* represents the noise. The inverse problem of Equation (2), image SR, is also ill-posed, since its solution is usually not unique.
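Equation (2) can be simulated in the same way. This sketch assumes a Gaussian low-pass kernel *k* and plain decimation for ↓<sub>*s*</sub>; both choices, and the name `degrade_sr`, are illustrative. The size comparison in the usage lines makes the non-uniqueness concrete: a 4× decimation leaves 256 observations for 4096 unknown HR pixels.

```python
import numpy as np
from scipy.ndimage import gaussian_filter


def degrade_sr(x, s=4, blur_sigma=1.2, noise_sigma=0.0, rng=None):
    """Eq. (2): y = (x * k) down-sampled by s, plus n.

    Assumptions: k is a Gaussian low-pass kernel with std blur_sigma, and
    the s-fold decimation keeps every s-th pixel in each direction.
    """
    rng = np.random.default_rng() if rng is None else rng
    blurred = gaussian_filter(x, sigma=blur_sigma)   # x * k
    lr = blurred[::s, ::s]                           # decimation by s
    return lr + noise_sigma * rng.standard_normal(lr.shape)


hr = np.random.default_rng(0).random((64, 64))
lr = degrade_sr(hr, s=4)
# 4096 unknown HR pixels are observed through only 256 LR pixels,
# so many HR images map to the same LR image.
print(hr.size, lr.size)
```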

At first glance, motion blur degradation appears simpler than resolution degradation because it involves no down-sampling. However, motion blur is often non-linear or non-uniform, which usually makes it more complex than resolution degradation, whose blur kernel is generally linear and uniform. This makes it difficult to estimate the blur kernel directly for non-blind deblurring.

In recent years, deep learning-based networks, especially CNN-based methods, have become the mainstream approach in image SR and motion deblurring research [3–5]. Although CNN-based image SR and deblurring methods have reported fairly good results, CNN-based image restoration is far from simple when low resolution and motion blur occur simultaneously. In this case, neither image SR nor motion deblurring alone works well, because the degradation kernels of the two processes are not equivalent and the two degradations do not complement each other when they occur together.

Recently, several CNN works [6–8] have addressed simultaneous image SR and motion deblurring. All of these methods explicitly or implicitly adopt a global or local feature coupling structure, in which a deblurring part and an SR part are entangled with each other, to recover resolution and motion details at the same time. In effect, these methods only build comprehensive CNN mappings from the degraded images to the corresponding sharp HR images, without fully exploiting the characteristics of motion deblurring and image SR to decouple them. Therefore, although these methods achieve good results, they lack interpretability and are inefficient.

On the other hand, typical deep image restoration models use residual connections to convey features. Because plain residual connections cannot fully exploit feature information across different layers, more complex residual variants have been proposed, such as DRRN [9] and RDN [10]. RDN (Residual Dense Network) is representative: it uses both local dense residual learning and global residual learning to extract and adaptively fuse local and global features from all layers. Because RDN makes full use of hierarchical features, it is well suited as a basis for image restoration models. Beyond feature learning across layers, recent attention-based methods, for example, RCAN [11] and SAN [12], select and enhance useful channel feature maps within the same layer through weighting for image or video restoration. These channel attention methods use first- or second-order statistics of the channel maps of certain layers to compute the dependencies among channel features, and then select and weight the important features. However, first- or second-order feature statistics alone cannot fully exploit the relationships between different channel feature maps.

To overcome the limitations of coupled recovery and single-order attention feature weighting, in this work we first analyze the compound degradation model of motion blur and resolution reduction and discuss its maximum likelihood (ML) solution. Based on this analysis, we then discuss decoupling induction multi-task learning and a CNN model construction method for restoring images with multiple degradations. In addition, we obtain the first-order and second-order attention features of the decoupled structures for motion deblurring and SR, respectively, and obtain third-order attention features by combining local series and parallel features. On this basis, to reduce feature redundancy and improve the generalization ability of multi-order attention fusion, we adopt a drop-out gating integration method, which enhances the robustness and stability of the proposed multi-order attention mechanism.

Figure 1 shows an example result of the proposed method on a compound degradation (motion blur together with 4× down-sampling), compared with the results of RCAN [11] and SCGAN [6]. The main contributions of this work are summarized as follows:



**Figure 1.** An example result of the presented decoupling induction and multi-order attention gating model for joint deblurring and 4× super-resolution: (**a**) the LR, blurred input; (**b**) RCAN [11]; (**c**) SCGAN [6]; (**d**) our method. The details recovered by our method are much clearer than those of RCAN and SCGAN.

#### **2. Related Work**

#### *2.1. Joint Image Deblurring and SR*

Compared with traditional image SR methods, SRCNN [3,4], the first CNN-based image SR method, proposed by Dong et al., can generate more accurate HR details owing to its powerful non-linear mapping. To extend the three convolutional layers of SRCNN to a much deeper network, Kim et al. [13] presented VDSR, a very deep image SR model built on residual connections [14]. Liu et al. [15] recently proposed MSDEPC, a multi-scale deep encoder–decoder network that super-resolves LR images using edge maps as prior information. In addition, Ledig et al. [16] applied a generative adversarial network (GAN) [17] to image SR in a model called SRGAN, which uses a perceptual loss and an adversarial loss to supervise the reconstruction of super-resolved images and obtains more realistic SR results.

CNNs also play an effective role in motion deblurring. Xu et al. [18] and Sun et al. [1] developed CNN-based methods that recover blurred images through blur kernel estimation. Besides these non-blind deep methods, several deep blind deblurring methods [2,5] have also been proposed. Nah et al. [2] applied a multi-scale CNN to recover clear images directly. Motivated by this work, Tao et al. [5] designed a structurally simple, scale-recurrent motion deblurring network. Moreover, inspired by work on image translation [19], Ramakrishnan et al. [20] first applied GANs to motion deblurring. Kupyn et al. [21] then proposed DeblurGAN for blind motion deblurring, which uses WGAN [22] with a gradient penalty to avoid the mode collapse issue of the classical GAN. Subsequently, Kupyn et al. [23] presented DeblurGAN-v2, a new and very efficient GAN-based model for single-image motion deblurring based on a relativistic conditional GAN with a double-scale discriminator. Furthermore, for meteorological prediction, Manzo et al. [24] adopted a pretrained deep network-based architecture for cloud image description and classification. Recently, to address blurred images that also suffer from other degradations such as down-scaling and compression, Xu et al. [25] proposed the enhanced deep pyramid network (EDPN) for blurry image restoration, which fully exploits self-scale and cross-scale similarities.

Few works use CNNs for simultaneous motion deblurring and single-image SR (SISR). Xu et al. [6] addressed the super-resolution of blurred facial images with SCGAN; however, their method is restricted to facial images and does not easily achieve good performance in real scenarios. Zhang et al. [7] proposed a deep encoder–decoder model for joint motion deblurring and image SR, and Zhang et al. [8] later proposed a gated fusion method for concurrent motion deblurring and image SR. Recently, Liang et al. [26] used a dual supervised network to address this problem, but their results are still not satisfactory. Moreover, for plug-and-play image SR, Zhang et al. [27] proposed a blind SR framework that can handle arbitrary blur kernels. Zhang et al. [28] also proposed a dual supervised learning strategy, which imposes constraints between LR and HR images, to fully exploit the representation capacity of their deep model.

#### *2.2. Attention*

Besides feature transfer through residual connections, the attention mechanism is another widely used means of feature preservation and enhancement in image SR models [11,12,29,30]. Zhang et al. presented RCAN [11] (residual channel attention network), which combines channel attention with residual blocks to adapt channel features to the task and strengthen their representational power. Since RCAN uses only first-order statistics of channel features, Dai et al. [12] presented SAN (second-order attention network), which replaces global average pooling with global covariance pooling (second-order statistics) to select and enhance channel features more effectively. Very recently, Niu et al. [31] designed a pixel-guided dual-branch attention network (PDAN) to jointly restore image details and spatial scale.

In addition, Wang et al. [29] proposed extracting and fusing temporal and spatial attention features for video restoration. Furthermore, Fu et al. [30] introduced a dual attention network for scene segmentation, with one spatial branch and one channel branch, which adaptively extracts and integrates local and non-local features through spatial and channel attention.

#### **3. Methodology**

#### *3.1. Multiple Degradation Decoupling Induction*

For motion deblurring and image SR, we use the following equations to describe the corresponding degradation models, which are used to generate the blurred and LR images for training:

$$y = \left(\sum\_{i=1}^{N} x\_i\right) / N + n \tag{3}$$

$$y = (x \downarrow\_{s}) + n,\tag{4}$$

where Equation (3) represents a typical motion degradation, namely the blur produced by averaging an image sequence, and Equation (4) denotes the process of image resolution reduction. In Equation (3), *x*<sub>*i*</sub> is the *i*-th sharp frame of a clear HR image sequence containing *N* images and *y* is the corresponding blurred image; in Equation (4), *x* and *y* are the HR image and the corresponding LR image, respectively. In both equations, *n* denotes additive noise (normally Gaussian white noise), and ↓<sub>*s*</sub> is the *s*× down-sampling operator (e.g., bicubic sub-sampling).

Based on Equations (3) and (4), the motion blur and the resolution reduction compound degeneration may be formulated as

$$y = \left( \left( \sum\_{i=1}^{N} x\_i \right) / N \right) \downarrow\_{s} + n \tag{5}$$

Averaging *N* frames produces blur, and the subsequent down-sampling further reduces the resolution of the resulting blurred image.
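A minimal sketch of the compound degradation in Equation (5), included only to illustrate the data-generation step. The synthetic sequence (a random image shifted by one pixel per frame) and plain decimation standing in for bicubic sub-sampling are assumptions; `compound_degrade` is a hypothetical name.

```python
import numpy as np


def compound_degrade(frames, s=4, noise_sigma=0.0, rng=None):
    """Eq. (5): average N sharp frames (motion blur), down-sample by s, add noise.

    frames: (N, H, W) array of consecutive sharp HR frames.
    Plain s-fold decimation stands in for the down-sampling operator.
    """
    rng = np.random.default_rng() if rng is None else rng
    frames = np.asarray(frames, dtype=float)
    blurred = frames.mean(axis=0)     # (sum_i x_i) / N -> averaging blur
    lr = blurred[::s, ::s]            # resolution reduction
    return lr + noise_sigma * rng.standard_normal(lr.shape)


# Synthetic sequence: one sharp image shifted horizontally to mimic camera motion.
base = np.random.default_rng(0).random((64, 64))
frames = np.stack([np.roll(base, shift=i, axis=1) for i in range(7)])
y = compound_degrade(frames, s=4)
print(y.shape)
```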

Theoretically, if the frame-averaging blur kernel and the spatial down-sampling kernel are denoted as *h* and *k*, respectively, Equation (5) can be generalized as follows:

$$y = (x \* h) \* k + n = (x \* S) + n \tag{6}$$

Here, the composite kernel *h* ∗ *k* is defined as a new kernel *S*; that is, the combined effect of the multiple degradations is essentially equivalent to a single blur operation. Moreover, according to Equation (6), the residual *r* between the sharp HR image *x* and the degraded observation *y* is easily calculated. Assuming the image data follow a Gaussian distribution, a maximum likelihood estimation (MLE) solution of Equation (6) can be obtained as *x̃* = *y* + *r*. Naturally, if the residual *r* is regarded as the high-frequency details of *x*, the observation *y* becomes its approximation component. If we further assume that the details *r* can be decoupled into the deblurring details *r*<sub>*db*</sub> and the SR details *r*<sub>*sr*</sub>, that is, *r* = *r*<sub>*db*</sub> + *r*<sub>*sr*</sub>, the MLE solution can be expressed as

$$
\widetilde{\mathbf{x}} = \mathbf{y} + \mathbf{r}\_{db} + \mathbf{r}\_{sr} \tag{7}
$$

According to Equation (7), if the deblurring details and the SR details can be obtained separately through deep decoupling induction learning, the original clear HR image can be recovered. Moreover, although changing the order of motion blur and resolution reduction leads to a compound degradation model different from Equation (5), the MLE solution with decoupled details (Equation (7)) remains the same. This indicates that the proposed decoupling induction method is robust to the order in which the multiple degradations occur.

#### *3.2. Multi-Order Attention Gating*

For both the decoupled deblurring features and the SR features, we then computed the first-order channel attention (FOCA) and the second-order channel attention (SOCA). Their SOCAs were further concatenated to compute a FOCA again, which yields the so-called third-order channel attention (TOCA). All of these attentions of different orders were then fused through multi-route gating, where closing a route means that the corresponding feature attention is blocked and cannot be used in subsequent reconstruction. Specifically, we used the drop-out mechanism: feature attentions were turned off randomly with a probability of 0.5. We call this process multi-order attention drop-out gating. We computed the FOCA and the SOCA in a manner similar to RCAN [11] and SAN [12], respectively. Based on the principles of the FOCA and SOCA, the following describes the mathematical formulation of the TOCA and of multi-order attention drop-out gating learning.
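The drop-out gating step can be sketched as follows. The text specifies that attention routes are closed at random with probability 0.5; how the open routes are fused is not stated in this passage, so averaging them, keeping at least one route open, and opening all routes at inference are assumptions of this sketch, as is the name `dropout_gate_fuse`.

```python
import numpy as np


def dropout_gate_fuse(attentions, p=0.5, training=True, rng=None):
    """Fuse K channel-attention vectors (e.g. FOCA, SOCA, TOCA) with random route closing.

    attentions: (K, C) array, one row of channel weights per attention route.
    During training each route is closed independently with probability p.
    Averaging the open routes and keeping at least one open are assumptions.
    """
    A = np.asarray(attentions, dtype=float)
    if not training:
        return A.mean(axis=0)                    # all routes open at inference
    rng = np.random.default_rng() if rng is None else rng
    keep = rng.random(A.shape[0]) >= p           # True = route stays open
    if not keep.any():
        keep[rng.integers(A.shape[0])] = True    # never close every route
    return A[keep].mean(axis=0)


att = np.array([[0.2, 0.9, 0.5],    # hypothetical first-order weights
                [0.4, 0.7, 0.1],    # hypothetical second-order weights
                [0.6, 0.5, 0.3]])   # hypothetical third-order weights
print(dropout_gate_fuse(att, training=False))
print(dropout_gate_fuse(att, rng=np.random.default_rng(0)))
```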

Given the deblurring feature maps $\mathbf{x}\_{db}$ and the SR feature maps $\mathbf{x}\_{sr}$, assume that each has *C* channels of spatial size *H* × *W*. The channel numbers of $\mathbf{x}\_{db}$ and $\mathbf{x}\_{sr}$ need not be equal; in the following, we take $\mathbf{x}\_{db}$ as an example. We reshape $\mathbf{x}\_{db}$ into a matrix **X** of size *C* × *s*, where *s* = *H* × *W*, so that each of its *s* columns is a *C*-dimensional feature vector. Treating these feature vectors as samples, the covariance matrix can be calculated and decomposed by eigenvalue decomposition (EIG) as:

$$\boldsymbol{\Sigma} = \mathbf{X}\bar{\mathbf{I}}\mathbf{X}^T = \mathbf{U}\boldsymbol{\Lambda}\mathbf{U}^T \tag{8}$$

where $\bar{\mathbf{I}} = \frac{1}{s}\left(\mathbf{I} - \frac{1}{s}\mathbf{1}\right)$, $s = H \times W$, and $\mathbf{I}$ and $\mathbf{1}$ are the $s \times s$ identity matrix and all-ones matrix, respectively. In addition, $\mathbf{U}$ is an orthogonal matrix and $\boldsymbol{\Lambda} = \mathrm{diag}(\lambda\_1, \ldots, \lambda\_C)$ is a diagonal matrix with the eigenvalues in decreasing order. The normalized covariance matrix is then $\hat{\mathbf{Y}} = \boldsymbol{\Sigma}^{\alpha} = \mathbf{U}\boldsymbol{\Lambda}^{\alpha}\mathbf{U}^T$, where $\alpha$ is a positive real number. The normalized covariance captures the correlations between channel-wise features. Letting $\hat{\mathbf{Y}} = [\mathbf{y}\_1, \ldots, \mathbf{y}\_C]$, the $c$-th channel-wise statistic $z\_c$ is obtained by global pooling of $\hat{\mathbf{Y}}$ as:

$$z\_c = \frac{1}{C} \sum\_{i=1}^{C} y\_c(i) \tag{9}$$
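Equations (8) and (9) can be sketched in NumPy as below, following the covariance-normalization recipe of SAN-style second-order attention. The assumptions are: the feature tensor is laid out as (C, H, W), `alpha` defaults to 0.5 (a hypothetical choice; the text only says α is positive), and `numpy.linalg.eigh` is used because Σ is symmetric. The function name is illustrative, not the authors' code.

```python
import numpy as np

def covariance_pooling(feat, alpha=0.5):
    """Second-order channel statistics (sketch of Eqs. (8)-(9)).

    feat:  array of shape (C, H, W).
    alpha: positive exponent for the eigenvalues (0.5 is a hypothetical default).
    Returns z (C,) and the normalised covariance Y_hat (C, C).
    """
    C, H, W = feat.shape
    s = H * W
    X = feat.reshape(C, s)                            # columns are C-dim samples
    I_bar = (np.eye(s) - np.ones((s, s)) / s) / s     # centring matrix of Eq. (8)
    sigma = X @ I_bar @ X.T                           # C x C covariance
    lam, U = np.linalg.eigh(sigma)                    # symmetric EIG, ascending order
    lam = np.clip(lam, 0.0, None)                     # remove tiny negative round-off
    Y_hat = (U * lam ** alpha) @ U.T                  # U diag(lam^alpha) U^T
    z = Y_hat.mean(axis=0)                            # Eq. (9): average of each column y_c
    return z, Y_hat

rng = np.random.default_rng(1)
feat = rng.random((4, 3, 5))        # toy feature map: C = 4, H = 3, W = 5
z, Y_hat = covariance_pooling(feat)
print(z.shape, Y_hat.shape)
```

The eigenvalue ordering returned by `eigh` (ascending) does not matter here, because raising each eigenvalue to the power α is independent of order. With `alpha=1.0`, `Y_hat` equals the biased covariance `np.cov(X, bias=True)`, which is a quick check of the centring matrix in Equation (8). Building the explicit *s* × *s* matrix mirrors the equation; subtracting each channel's mean is the cheaper equivalent.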

Based on Equation (9), the feature weighting coefficient can be obtained through a simple sigmoid gating function [32] as:

$$
\omega\_c = f(\mathbf{W}\_U \delta(\mathbf{W}\_D z\_c)) \tag{10}
$$

where $\mathbf{W}\_D$ and $\mathbf{W}\_U$ are convolution layers that reduce the number of feature channels to *C*/*r* and restore it to *C*, respectively (*r* being the channel reduction ratio), and *f*(·) and *δ*(·) denote the sigmoid and ReLU functions, respectively. Thus, for the deblurring features $\mathbf{x}\_{db}$, the second-order channel attention (SOCA) weighting is:

$$
\overline{\mathbf{x}}\_{db,c} = \omega\_c \cdot \mathbf{x}\_{db,c} \tag{11}
$$
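Equations (10) and (11) amount to a small bottleneck gate followed by per-channel scaling. In the sketch below, dense matrices `W_D` and `W_U` stand in for the channel down- and up-scaling convolutions. This is an assumption: a 1 × 1 convolution applied to a pooled vector acts as a matrix product. The reduction ratio `r` and all sizes are hypothetical, and the gate is applied to the whole vector of channel statistics at once.

```python
import numpy as np

def sigmoid(v):
    return 1.0 / (1.0 + np.exp(-v))

def channel_gate(z, W_D, W_U):
    """Sketch of Eq. (10): omega = f(W_U delta(W_D z)).

    z:   (C,) channel statistics, e.g. from Eq. (9).
    W_D: (C // r, C) down-scaling weights; W_U: (C, C // r) up-scaling weights.
    """
    hidden = np.maximum(W_D @ z, 0.0)   # delta = ReLU, C -> C/r channels
    return sigmoid(W_U @ hidden)        # f = sigmoid, C/r -> C channels, values in (0, 1)

def reweight(feat, omega):
    """Sketch of Eq. (11): scale channel c of feat (C, H, W) by omega[c]."""
    return omega[:, None, None] * feat

rng = np.random.default_rng(0)
C, r = 8, 4                              # hypothetical channel count and reduction ratio
z = rng.random(C)
W_D = rng.standard_normal((C // r, C))
W_U = rng.standard_normal((C, C // r))
feat = rng.random((C, 3, 3))
x_bar = reweight(feat, channel_gate(z, W_D, W_U))
print(x_bar.shape)
```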

Based on Equation (11), the SOCA weighting of the SR features $\mathbf{x}\_{sr}$ can be described similarly as $\overline{\mathbf{x}}\_{sr,c} = \omega\_c \cdot \mathbf{x}\_{sr,c}$. Then, $\overline{\mathbf{x}}\_{db}$ and $\overline{\mathbf{x}}\_{sr}$ are concatenated and passed through the FOCA to obtain the final TOCA. Let $\mathbf{x}\_{cat} = \mathrm{concat}(\overline{\mathbf{x}}\_{db}, \overline{\mathbf{x}}\_{sr}) = [\mathbf{x}\_{cat,1}, \ldots, \mathbf{x}\_{cat,2C}]$; we apply global average pooling to each channel and then transform the statistics with channel-scaling convolution layers and suitable activation functions to obtain the FOCA weighting:

$$z\_{toa,c} = \frac{1}{H \times W} \sum\_{i=1}^{H} \sum\_{j=1}^{W} \mathrm{concat}(\overline{\mathbf{x}}\_{db}, \overline{\mathbf{x}}\_{sr})\_c(i, j), \tag{12}$$

$$S\_{toa,c} = f(\mathbf{W}\_S \delta(\mathbf{W}\_I z\_{toa,c})) \tag{13}$$

where $\mathbf{x}\_{cat,c}(i, j)$ is the value at position (*i*, *j*) of the *c*-th channel of the concatenated SOCA features $\mathbf{x}\_{cat}$, and $\mathbf{W}\_S$ and $\mathbf{W}\_I$ are the channel up-scaling and down-scaling convolution layers, analogous to $\mathbf{W}\_U$ and $\mathbf{W}\_D$ in Equation (10). Finally, the third-order channel attention (TOCA) weighting is:

$$
\hat{\mathbf{x}}\_{toa,c} = S\_{toa,c} \cdot \mathbf{x}\_{cat,c} \tag{14}
$$
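The FOCA and TOCA steps in Equations (12) to (14) can be sketched the same way. Here `first_order_attention` is an RCAN-style channel attention (global average pooling followed by a sigmoid bottleneck), and `third_order_attention` applies it to the concatenated SOCA features. The weight shapes and the dense-matrix stand-ins for `W_I` and `W_S` are assumptions. Applying `first_order_attention` to the features of a single branch gives the weights of that branch's FOCA features.

```python
import numpy as np

def sigmoid(v):
    return 1.0 / (1.0 + np.exp(-v))

def first_order_attention(feat, W_I, W_S):
    """Sketch of Eqs. (12)-(13): channel weights from global average pooling.

    feat: (C2, H, W); W_I: (C2 // r, C2) down-scaling; W_S: (C2, C2 // r) up-scaling.
    """
    z = feat.mean(axis=(1, 2))                           # Eq. (12): one value per channel
    return sigmoid(W_S @ np.maximum(W_I @ z, 0.0))      # Eq. (13)

def third_order_attention(x_db_bar, x_sr_bar, W_I, W_S):
    """Sketch of Eq. (14): first-order attention on the concatenated SOCA features."""
    x_cat = np.concatenate([x_db_bar, x_sr_bar], axis=0)
    S = first_order_attention(x_cat, W_I, W_S)
    return S[:, None, None] * x_cat

rng = np.random.default_rng(0)
C, H, W, r = 4, 5, 5, 2                  # toy sizes; r is a hypothetical reduction ratio
x_db_bar = rng.random((C, H, W))          # stand-ins for the SOCA outputs of Eq. (11)
x_sr_bar = rng.random((C, H, W))
W_I = rng.standard_normal((2 * C // r, 2 * C))
W_S = rng.standard_normal((2 * C, 2 * C // r))
x_toa = third_order_attention(x_db_bar, x_sr_bar, W_I, W_S)
print(x_toa.shape)
```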

If the FOCA features of the deblurring and SR branches are denoted as $\dot{\mathbf{x}}\_{db}$ and $\dot{\mathbf{x}}\_{sr}$, respectively, then all the multi-order attention features, $\dot{\mathbf{x}}\_{db}$, $\overline{\mathbf{x}}\_{db}$, $\dot{\mathbf{x}}\_{sr}$, $\overline{\mathbf{x}}\_{sr}$, and $\hat{\mathbf{x}}\_{toa}$, are sent to a single five-route gate for fusion. The gate works with the drop-out mechanism. Let the *j*-th route switch be a random variable $r\_j$ that follows a Bernoulli distribution with parameter *p* (set to 0.5 in our practice), i.e., $r\_j \sim \mathrm{Bernoulli}(p)$. All the attention features that pass through the gate are then fused by concatenation as:

$$\widetilde{\mathbf{x}} = \mathrm{concat}\left(r\_1 \dot{\mathbf{x}}\_{db}, r\_2 \overline{\mathbf{x}}\_{db}, r\_3 \dot{\mathbf{x}}\_{sr}, r\_4 \overline{\mathbf{x}}\_{sr}, r\_5 \hat{\mathbf{x}}\_{toa}\right) \tag{15}$$
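The five-route gate of Equation (15) can be sketched as below: each route is multiplied by an independent switch $r\_j \sim \mathrm{Bernoulli}(p)$, and the results are concatenated along the channel axis. The sketch follows Equation (15) literally, without the 1/*p* rescaling used by standard inverted dropout; the text does not say how the gate behaves at test time, so that part is left out. Route contents and sizes are toy values.

```python
import numpy as np

def dropout_gate(routes, p=0.5, rng=None):
    """Sketch of Eq. (15): r_j ~ Bernoulli(p) switches, then channel concatenation.

    routes: list of arrays shaped (C_j, H, W), e.g. the five attention features.
    Returns the fused array and the sampled switches.
    """
    rng = np.random.default_rng() if rng is None else rng
    switches = rng.binomial(1, p, size=len(routes))          # each r_j is 0 or 1
    fused = np.concatenate([r_j * x for r_j, x in zip(switches, routes)], axis=0)
    return fused, switches

rng = np.random.default_rng(0)
routes = [np.full((2, 3, 3), j + 1.0) for j in range(5)]     # five toy routes
fused, switches = dropout_gate(routes, p=0.5, rng=rng)
print(switches, fused.shape)
```

A closed route still occupies its channels (filled with zeros), so the fused tensor always has the same shape and the reconstruction module can consume it unchanged.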

#### *3.3. Network Architecture*

Based on Equation (7), we design two CNN branches to learn the deblurring details $\mathbf{r}\_{db}$ and the SR details $\mathbf{r}\_{sr}$ separately; we call this step decoupling induction learning. We then calculate the multi-order attention features of each branch and fuse them with the drop-out gating method; we call this step multi-order attention drop-out gating. The fused attention features, concatenated with the blurry LR input image, are then sent to the subsequent reconstruction module to obtain the final SR result.

The overall architecture of the proposed model is illustrated in Figure 2. Our model contains four main modules: (1) the deblurring feature extraction module, which predicts the sharp LR image; (2) the SR feature extraction module, which produces the super-resolved blurry image; (3) the proposed multi-order attention drop-out gating module, which calculates attention of different orders and fuses them with the drop-out gating mechanism; and (4) the reconstruction module, which recovers the final sharp SR result. In the figure, the four modules are indicated by dashed boxes of different colors.

**Figure 2.** The overall architecture of the proposed model. The model contains four main modules: deblurring feature extraction, SR feature extraction, multi-order attention drop-out gating, and reconstruction. A blurry LR input image is first passed through separate SR and deblurring branches to obtain decoupled features, which then go through multi-order attention drop-out gating fusion before being reconstructed into a super-resolved, sharp image.

#### 3.3.1. Deblurring Feature Extraction

This module aims to acquire the decoupled deblurring features, and hence sharp LR images, from blurry LR images $I\_{LR+blur}$. Inspired by [21], we adopted a residual encoder–decoder structure in this module. The encoder is composed of several convolution layers that reduce the feature maps to a quarter of the input image size. Nine residual blocks are then placed between the encoder and the decoder to refine the deblurring features, and the decoder uses two deconvolution (upscaling) layers to restore the resolution of the deblurring feature maps. Additionally, two further convolution operations on the deblurring features yield a deblurred LR image $I\_{LR+deblur}$ (see Figure 2).

Here, we denote the output deblurring features of the decoder as $\mathbf{x}\_{db}$, which are later sent to the multi-order attention gating module. All activation layers are leaky rectified linear units (LeakyReLU), and we used instance normalization (IN) in the residual blocks instead of batch normalization (BN), because BN layers may reduce the flexibility of the network, discard scale information by normalizing the features, and increase computation. The mapping learned by this module from the input $I\_{LR+blur}$ to the output $\mathbf{x}\_{db}$ can be described as:

$$\mathbf{x\_{db}} = \mathbf{H\_{\uparrow 2}}\left(\mathbf{H\_{\uparrow 1}}\left(\mathbf{R}\mathbf{B}\left(\mathbf{H\_{\downarrow 2}}\left(\mathbf{H\_{\downarrow 1}}\left(\mathbf{H\_{c}}\left(\mathbf{I\_{LR+blur}}\right)\right)\right)\right)\right)\right) \tag{16}$$

where $H\_{\downarrow 1}$ and $H\_{\downarrow 2}$ are the down-scaling convolution layers of the encoder, $H\_{\uparrow 1}$ and $H\_{\uparrow 2}$ are the deconvolution layers of the decoder, RB represents the nine residual blocks, and $H\_c$ is the first convolution layer acting on the input $I\_{LR+blur}$. The activation and normalization operations are included in these layers by default.
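To make the resolutions in Equation (16) concrete, the sketch below traces the spatial size of the feature maps for a 180 × 320 (height × width) LR input, i.e., a 1280 × 720 GOPRO frame down-sampled by 4. It assumes that each of $H\_{\downarrow 1}$ and $H\_{\downarrow 2}$ halves the height and width and that each of $H\_{\uparrow 1}$ and $H\_{\uparrow 2}$ doubles them; the strides are not stated in the text, so this reading of "a quarter of the input image size" is an assumption.

```python
def trace_deblur_shapes(h, w, n_down=2):
    """List (layer, height, width) through Eq. (16), assuming stride-2 down/up layers."""
    sizes = [("H_c", h, w)]                  # first convolution keeps the size
    for k in range(1, n_down + 1):
        h, w = h // 2, w // 2                # encoder: halve height and width
        sizes.append((f"H_down{k}", h, w))
    sizes.append(("RB x9", h, w))            # residual blocks keep the size
    for k in range(1, n_down + 1):
        h, w = h * 2, w * 2                  # decoder: double height and width
        sizes.append((f"H_up{k}", h, w))
    return sizes

for name, h, w in trace_deblur_shapes(180, 320):
    print(f"{name:8s} {h} x {w}")
```

Under these assumptions, the decoder output returns exactly to the input size only when both dimensions are divisible by 4, which is a common reason to crop training patches to multiples of the total stride.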

#### 3.3.2. SR Feature Extraction

The purpose of this module is to obtain decoupled SR image details. We used eight residual dense blocks [10] (each containing five convolution operations and four LeakyReLU layers; see Figure 2) and one convolution layer to build a deep structure that extracts high-frequency spatial detail features. From these features, the super-resolved blurry image $I\_{SR+blur}$ can also be obtained through two consecutive pixel shuffle layers and several convolution layers. To preserve spatial information, neither pooling layers nor strided operations are used in this module, and no normalization is applied. Denoting the extracted SR features as $\mathbf{x}\_{sr}$, the mapping learned by this module from the input $I\_{LR+blur}$ to the output $\mathbf{x}\_{sr}$ can be described as:

$$\mathbf{x}\_{\rm sr} = RDB\_8(H\_c(I\_{LR+blur})) \tag{17}$$

where $RDB\_8$ represents the eight consecutive residual dense blocks.
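Both the SR branch and the reconstruction module upscale with two consecutive pixel shuffle layers. A NumPy sketch of one pixel shuffle (depth-to-space) step follows. It is intended to use the same channel ordering as PyTorch's `torch.nn.PixelShuffle`, and the upscale factor of 2 per layer is an assumption consistent with the overall 4× factor stated for the reconstruction module.

```python
import numpy as np

def pixel_shuffle(x, r):
    """Depth-to-space: rearrange (C * r * r, H, W) into (C, H * r, W * r)."""
    c_rr, H, W = x.shape
    if c_rr % (r * r):
        raise ValueError("channel count must be divisible by r * r")
    C = c_rr // (r * r)
    y = x.reshape(C, r, r, H, W)       # split channels into (C, sub-row, sub-column)
    y = y.transpose(0, 3, 1, 4, 2)     # -> (C, H, sub-row, W, sub-column)
    return y.reshape(C, H * r, W * r)

x = np.arange(16).reshape(4, 2, 2)
print(pixel_shuffle(x, 2)[0])
```

Each shuffle with `r = 2` trades a factor of 4 in channels for a factor of 2 in height and width, so two of them give the 4× spatial upscaling; each layer therefore needs r² times as many input channels as it outputs.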

#### 3.3.3. Multi-Order Attention Drop-Out Gating

This module aggregates the multi-order attention of the learned deblurring features $\mathbf{x}\_{db}$ and SR features $\mathbf{x}\_{sr}$ to obtain high-frequency details for image recovery. With $\mathbf{x}\_{db}$ and $\mathbf{x}\_{sr}$ as inputs, the module calculates their first-order and second-order attention maps as well as their common third-order attention map. All these attention features are then passed through the drop-out gate and concatenated to obtain the final feature maps $\widetilde{\mathbf{x}}$. The mapping of this module and its processing details are given in Section 3.2 and Figure 2.

#### 3.3.4. Reconstruction Module

In this module, the gated attention features $\widetilde{\mathbf{x}}$ and the blurry LR image are fed into 16 residual dense blocks [10], and the result is passed through two pixel shuffle layers that raise the spatial resolution by a factor of 4. Two convolution layers then produce the final sharp SR image $I\_{SR+clear}$. Since most operations of our model are performed in the low-resolution space, the computational cost in both the training and testing stages is relatively low. The mapping of this module is described as:

$$I\_{SR+clear} = H\_{2c}(P\_2(P\_1(RDB\_{16}(\mathrm{concat}(\widetilde{\mathbf{x}}, I\_{LR+blur}))))) \tag{18}$$

where $RDB\_{16}$ denotes the 16 consecutive residual dense blocks, $P\_1$ and $P\_2$ are the two pixel shuffle layers, and $H\_{2c}$ represents the two convolution operations.

#### *3.4. Loss Functions*

Our proposed model has three outputs: the deblurred LR image $I\_{LR+db}$, the blurry SR image $I\_{SR+blur}$, and the sharp SR image $I\_{SR+clear}$. Accordingly, the total loss contains three parts: the loss on the LR sharp image, the loss on the HR blurry image, and the loss on the final HR sharp image. In each case, we measure the difference between an output and its ground truth with the $\ell\_1$ norm and treat it as the loss. The three losses of our model are:

$$\ell\_{\mathbf{1}} = \sum\_{i=\mathbf{1}}^{N} ||y\_{HR+clear,i} - I\_{\mathbf{SR}+clear,i}||\_{\mathbf{1}} \tag{19}$$

$$\ell\_2 = \sum\_{i=1}^{N} ||y\_{LR+clear,i} - I\_{LR+db,i}||\_1 \tag{20}$$

$$\ell\_3 = \sum\_{i=1}^{N} ||y\_{HR+blur,i} - I\_{SR+blur,i}||\_1 \tag{21}$$

where $y\_{HR+clear,i}$, $y\_{LR+clear,i}$, and $y\_{HR+blur,i}$ are the ground-truth targets of the three outputs, respectively, and *N* is the number of training samples. The total loss is a weighted sum of the three losses:

$$L = \ell\_1 + \alpha\ell\_2 + (1 - \alpha)\,\ell\_3 \tag{22}$$

where *α* is the loss balance factor, which is set to be 0.5 in our experiments.
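A minimal NumPy sketch of Equation (22), using the plain $\ell\_1$ terms of Equations (19) to (21), follows. The function and argument names are hypothetical; each term sums absolute errors over all pixels and all *N* samples, and `alpha` defaults to the value 0.5 used in the experiments.

```python
import numpy as np

def l1(target, pred):
    """Sum of absolute errors over every pixel of every sample in the batch."""
    return np.abs(target - pred).sum()

def total_loss(y_hr_clear, i_sr_clear, y_lr_clear, i_lr_db, y_hr_blur, i_sr_blur, alpha=0.5):
    """Sketch of Eq. (22): l_1 + alpha * l_2 + (1 - alpha) * l_3."""
    loss_1 = l1(y_hr_clear, i_sr_clear)    # Eq. (19): final sharp SR output
    loss_2 = l1(y_lr_clear, i_lr_db)       # Eq. (20): deblurred LR output
    loss_3 = l1(y_hr_blur, i_sr_blur)      # Eq. (21): blurry SR output
    return loss_1 + alpha * loss_2 + (1 - alpha) * loss_3

# Toy batch: N = 2 samples, 2 x 2 LR and 8 x 8 HR (a 4x factor), all predictions zero.
N = 2
loss = total_loss(np.ones((N, 8, 8)), np.zeros((N, 8, 8)),
                  np.ones((N, 2, 2)), np.zeros((N, 2, 2)),
                  2 * np.ones((N, 8, 8)), np.zeros((N, 8, 8)))
print(loss)
```

In this toy batch the three terms are 128, 8, and 256, so the total is 128 + 0.5 × 8 + 0.5 × 256 = 260. With `alpha = 0.5`, the two auxiliary losses share one unit of weight while the final sharp SR output keeps full weight.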

In addition, to generate more realistic images, we sometimes introduce an SSIM [33] measure into the loss $\ell\_1$, which is then modified as:

$$\ell\_1 = \sum\_{i=1}^{N} \left( \beta\, \mathrm{SSIM}\left( y\_{HR+clear,i}, I\_{SR+clear,i} \right) + (1-\beta) \left\| y\_{HR+clear,i} - I\_{SR+clear,i} \right\|\_2 \right) \tag{23}$$

where $\beta$ balances the two terms and is set to 0.84.

#### **4. Experimental Results**

#### *4.1. Datasets and Training Details*

We performed experiments and performance comparisons on two well-known public blur datasets: the GOPRO dataset [2] and the dataset of Lai et al. [34]. Derived from natural video sequences, the GOPRO dataset [2] contains 2103 high-resolution training pairs (sharp and blurry images) and 1111 test images, each of size 1280 × 720. Each motion-blurred image is obtained by averaging several neighboring video frames, and the LR image is obtained by bicubic down-sampling of the corresponding HR image. In contrast to GOPRO, the dataset of Lai et al. [34] consists of synthetically blurred images, in which each degraded image is the convolution of a sharp image with a blur kernel whose size ranges from 21 × 21 to 75 × 75. Note that Lai et al.'s dataset [34] contains both uniformly and non-uniformly blurred images. The main characteristics of the two datasets are summarized in Table 1.

Training proceeds in two steps. In the first step, the model is trained with supervision from the blurry LR patches $I\_{LR+blur}$, the sharp LR patches $I\_{LR+clear}$, and the sharp HR patches $I\_{HR+clear}$ by minimizing the loss in Equation (22). In the second step, the trained model is fine-tuned with Equation (23) replacing the original $\ell\_1$ in Equation (22). Training uses the SGD solver in PyTorch [35], with the learning rate decreasing from 0.01 to 0.00001 by a decay factor of 0.5, a momentum of 0.9, and a batch size of 12. Training the proposed model takes about two days on an Nvidia Titan GTX1080ti GPU.
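The learning-rate schedule described above (start at 0.01, multiply by 0.5 at each decay, stop at 0.00001) can be written as a small step function. How often the decay happens is not stated, so `step_size` is a hypothetical parameter. In PyTorch the same idea is usually expressed with `torch.optim.SGD(..., lr=0.01, momentum=0.9)` and `torch.optim.lr_scheduler.StepLR(..., gamma=0.5)`, although `StepLR` itself does not enforce a lower bound.

```python
def step_lr(epoch, base_lr=0.01, gamma=0.5, min_lr=1e-5, step_size=10):
    """Learning rate after `epoch` epochs of step decay, floored at `min_lr`.

    step_size (epochs between decays) is a hypothetical choice.
    """
    return max(base_lr * gamma ** (epoch // step_size), min_lr)

# Ten halvings would give 0.01 / 1024 (about 9.8e-6), so the floor takes over at the tenth decay.
for epoch in (0, 10, 90, 100):
    print(epoch, step_lr(epoch))
```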

**Table 1.** Basic dataset characteristics of GOPRO [2] and Lai et al. [34].


#### *4.2. Experiments and Comparisons*

We performed extensive joint deblurring and SR experiments on numerous blurry LR input images from different test datasets, and compared our model with recent state-of-the-art (SOTA) image SR models [10–12], the deblurring method [5], and multi-degradation recovery approaches [6–8,36]. We also compared against a combination of the SR algorithm [10] and the deblurring method [5]. For a fair comparison, all methods were evaluated with the public code released by their authors. For methods whose code is not publicly available (such as ED-DSRN [7]), we retrained the original networks on our dataset. Tables 2 and 3 report the comparisons on the GOPRO [2] and Lai et al. [34] datasets in terms of PSNR, SSIM, model parameters, and test time, and Figures 3–5 compare the visual results.

**Figure 3.** The details in the deblurred and super-resolved (4×) images generated by the presented decoupled induction and multi-attention drop-out gating model on GOPRO [2] and Lai et al. [34]; using our method, the image details are clearer than the ones acquired from RCAN [11], SCGAN [6], and GFN [8].


**Table 2.** The comparisons with SOTA methods of the quantitative performance on GOPRO dataset [2]. Best results are marked in bold.

**Table 3.** The comparisons with SOTA methods of the quantitative performance on Lai et al. dataset [34]. Best results are marked in bold.


(**a**) HR (PSNR/SSIM); (**b**) RCAN (24.555/0.72); (**c**) SCGAN (24.39/0.67); (**d**) GFN (25.297/0.713); (**e**) Ours (25.279/0.731).

(**f**) HR (PSNR/SSIM); (**g**) RCAN (21.446/0.61); (**h**) SCGAN (21.295/0.56); (**i**) GFN (21.358/0.578); (**j**) Ours (21.78/0.61).

**Figure 5.** More visual comparisons of our model with other methods on the dataset of Lai et al. [34].

According to Tables 2 and 3, our model achieves the best multi-degradation recovery in most cases and is only slightly inferior to GFN [8] in certain special scenarios (see Figure 3d,e); GFN appears to be the strongest joint image SR and deblurring algorithm currently available. In Figure 3d,e and Figure 4, although our PSNR is slightly lower, the image we recover looks better than the one generated by GFN [8]. This discrepancy may arise because PSNR and SSIM are computed from pixel-wise or local-neighborhood statistics and do not fully reflect human visual perception. The quantitative metrics in Tables 2 and 3 show that, compared with the other methods, the proposed method achieves the best or second-best PSNR and SSIM across different blur types and LR datasets.
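The caveat about PSNR can be made concrete: PSNR depends only on the mean squared error between two images, so error patterns that look very different can score identically. The NumPy sketch below assumes 8-bit intensities (`max_val = 255`) and compares a uniform error of 1 on every pixel with an error of 2 concentrated on a quarter of the pixels; both have a mean squared error of 1.

```python
import numpy as np

def psnr(reference, estimate, max_val=255.0):
    """Peak signal-to-noise ratio in dB: 10 * log10(max_val**2 / MSE)."""
    ref = np.asarray(reference, dtype=np.float64)
    est = np.asarray(estimate, dtype=np.float64)
    mse = np.mean((ref - est) ** 2)
    if mse == 0:
        return float("inf")            # identical images
    return 10.0 * np.log10(max_val ** 2 / mse)

reference = np.zeros((4, 4))
uniform_error = np.ones((4, 4))        # every pixel off by 1
local_error = np.zeros((4, 4))
local_error[:2, :2] = 2.0              # 4 of 16 pixels off by 2
print(psnr(reference, uniform_error), psnr(reference, local_error))
```

Both calls return about 48.13 dB (that is, 20 log10 255), even though one error is spread evenly and the other is a concentrated blotch that a viewer would notice far more readily.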

Figure 5 shows that our approach brings a significant improvement on the dataset of Lai et al. [34]. Although RCAN was fine-tuned on this dataset, it still cannot compete with our trained network (see Figure 5b,g): Figure 5b,g contains less texture detail than Figure 5e,j. We attribute this performance gap mainly to RCAN's lack of an encoder–decoder structure, which is a key architectural component of blind deblurring networks. Although the retrained SCGAN performs better than its pre-trained model, its small model capacity prevents it from handling complex non-uniform blur well.

In general, compared with other methods, and especially with GFN, our method has two advantages: (1) its two branches (super-resolution and motion deblurring) are fully disentangled, whereas GFN's are not; and (2) it uses multi-order attention to obtain the attention features of the two branches at different orders separately and fuses them through drop-out gating, whereas GFN fuses the branches by computing their correlation. Owing to its simpler structure, GFN has fewer parameters and runs faster than our approach. However, in practical applications where memory and computing speed are not critical, our method is preferable when the scene is rich in salient textures and the objects exhibit multi-scale variations. Benefiting from joint attention learning, our method produces sharper, higher-resolution images with good perceptual quality.

#### **5. Ablation Study**

To analyze the role of the key components of the proposed decoupling induction and multi-order attention gating model, we developed and tested several variants: (1) deblurring only, (2) SR only, (3) without TOCA, and (4) without drop-out gating. These variants were trained with nearly the same hyper-parameters as the original model. For the deblurring-only and SR-only variants, the FOCA and SOCA features were concatenated and passed to the reconstruction module. For the variant without TOCA, only four attention routes ($\dot{\mathbf{x}}\_{db}$, $\overline{\mathbf{x}}\_{db}$, $\dot{\mathbf{x}}\_{sr}$, $\overline{\mathbf{x}}\_{sr}$) were used for drop-out gating. The last variant replaced drop-out gating with direct concatenation. The results are shown in Table 4.


**Table 4.** Ablation study on GOPRO [2] dataset. The best results are indicated in bold.

Table 4 shows that, without drop-out gating, the performance of the proposed approach drops markedly. The high-order attention TOCA also helps to improve the reconstruction. In addition, the deblurring branch appears to contribute more than the SR branch to multi-degradation decoupled reconstruction. We therefore conclude that the proposed multi-order attention and drop-out gating mechanism is effective for joint deblurring and super-resolution.

#### **6. Conclusions**

In this work, we proposed an effective end-to-end deep model that handles multiple degradations for simultaneous motion deblurring and image SR. Inspired by decoupled learning and multi-order attention feature selection, our model first constructs separate network branches for motion deblurring and image SR, and then performs selective feature enhancement and fusion through multi-order attention drop-out gating. Extensive experiments and comparisons with other SOTA methods demonstrate the superior compound-degradation recovery performance and generalization ability of our method.

Future work will focus on two aspects. The first is to investigate why the deblurring branch matters more than the SR branch in the proposed multi-degradation reconstruction approach. The second is to investigate how to achieve good image recovery when further degradations, such as noise, are introduced in addition to blur and resolution reduction.

**Author Contributions:** Conceptualization, Y.C. and H.L.; methodology, Y.C. and H.L.; software, Y.C. and X.Z.; writing—original draft preparation, Y.C.; writing—review and editing, H.L.; visualization, X.Z.; funding acquisition, H.L. All authors have read and agreed to the published version of the manuscript.

**Funding:** This research was funded by the National Natural Science Foundation of China, grant No. 61971004, the Natural Science Foundation of Anhui Province, grant No. 2008085MF190, the Key Project of Natural Science of Anhui Provincial Department of Education, grant No. KJ2019A0083 and KJ2021A1289, and the Open Project Fund of the Key Laboratory of Computational Intelligence and Signal Processing of the Ministry of Education (Anhui University), grant No. 2020A002.

**Data Availability Statement:** The links to the public datasets used in the paper are as follows: GOPRO dataset: https://github.com/SeungjunNah/DeepDeblur\_release (accessed on 1 December 2021), Lai's dataset [34]: http://vllab.ucmerced.edu/wlai24/cvpr16\_deblur\_study/ (accessed on 1 December 2021).

**Conflicts of Interest:** The authors declare no conflict of interest.

#### **References**


### *Article* **Face Recognition via Compact Second-Order Image Gradient Orientations**

**He-Feng Yin 1,2,\*, Xiao-Jun Wu 1,2,\*, Cong Hu 1,2 and Xiaoning Song 1,2**


**Abstract:** Conventional subspace learning approaches based on image gradient orientations employ only first-order gradient information, which may ignore second-order or higher-order gradient information. Moreover, recent research on the human visual system (HVS) has revealed that the neural image is a landscape or surface whose geometric properties can be captured through second-order gradient information. Second-order image gradient orientations (SOIGO) can mitigate the adverse effect of noise in face images. To reduce the redundancy of SOIGO, we propose compact SOIGO (CSOIGO), which applies linear complex principal component analysis (PCA) to SOIGO. Specifically, the SOIGO of the training data are first obtained, and linear complex PCA is then applied to obtain features of reduced dimensionality. Combined with the collaborative-representation-based classification (CRC) algorithm, the classification performance of CSOIGO is further enhanced. CSOIGO is evaluated under real-world disguise, synthesized occlusion, and mixed variations. Under the real disguise scenario, CSOIGO improves accuracy by 2.67% and 1.09% when one and two neutral face images per subject are used as training samples, respectively. For mixed variations, CSOIGO improves accuracy by 0.86%. These results indicate that the proposed method is superior to its competing approaches with few training samples, and even outperforms some prevailing deep-neural-network-based approaches.

**Keywords:** face recognition; second-order gradient; image gradient orientations; collaborative-representation-based classification

**MSC:** 68T10

#### **1. Introduction**

As one of the most active research topics, face recognition (FR) has attracted great attention in pattern recognition and computer vision. Considerable progress has been made over the past decades, and many successful methods have been proposed. Nevertheless, complicated variations in face images (e.g., occlusion, illumination, and expression) pose a great challenge for FR systems. To increase robustness to occlusion, researchers have developed a variety of approaches. Sparse representation-based classification (SRC) [1] was developed for FR and, when combined with the block partition technique, is robust to occlusion and corruption in the test images. Naseem et al. [2] proposed a modular linear regression classification (Modular LRC) approach with a distance-based evidence fusion (DEF) algorithm to tackle contiguous occlusion. Dividing an image into blocks is an effective strategy for feature extraction. Adjabi et al. [3] developed the multiblock color-binarized statistical image features (MB-C-BSIF) method for single-sample face recognition. Abdulhussain et al. [4] presented a method for fast calculation of the features of overlapping image blocks. To further enhance the performance of SRC, Li et al. [5] proposed a sparsity augmented weighted CRC approach for image recognition. Dong et al. [6] designed a low-rank Laplacian-uniform mixed (LR-LUM) model, which characterizes complex errors as a combination of continuous structured noise and random noise. Yang et al. [7] presented nuclear norm-based matrix regression (NMR), which employs a two-dimensional image-matrix-based error model rather than a one-dimensional pixel-based error model. In NMR, the representation vector is regularized by the $\ell\_2$ norm; to exploit the discriminative property of sparsity, Chen et al. [8] proposed sparse regularized NMR (SR-NMR), which replaces the $\ell\_2$ norm constraint on the representation vector with the $\ell\_1$ norm. However, the above approaches require uncorrupted training images, and their performance deteriorates when the training data are corrupted. When both the training and test data are corrupted, low-rank matrix recovery (LRMR) can be applied. Chen et al. [9] proposed a discriminative low-rank representation (DLRR) method, which introduces structural incoherence into the framework of low-rank representation (LRR) [10]. Gao et al. [11] proposed to learn a robust and discriminative low-rank representation (RDLRR) by introducing a low-rank constraint to simultaneously model the representation and each error term. Hu et al. [12] presented a robust FR method that employs dual nuclear norm low-rank representation and a self-representation induced classifier. Yang et al. [13] developed a sparse low-rank component-based representation (SLCR) method for FR with low-quality images. Recently, Yang et al. [14] extended SLCR and proposed an FR technique named sparse individual low-rank component representation (SILR) for IoT-based systems. Inspired by LRR and deep learning techniques, Xia et al. [15] developed an embedded conformal deep low-rank autoencoder (ECLAE) neural network architecture for matrix recovery.

**Citation:** Yin, H.-F.; Wu, X.-J.; Hu, C.; Song, X. Face Recognition via Compact Second-Order Image Gradient Orientations. *Mathematics* **2022**, *10*, 2587. https://doi.org/10.3390/math10152587

Academic Editors: Jianping Gou, Weihua Ou, Shaoning Zeng and Lan Du

Received: 9 June 2022 Accepted: 22 July 2022 Published: 25 July 2022

**Publisher's Note:** MDPI stays neutral with regard to jurisdictional claims in published maps and institutional affiliations.

**Copyright:** © 2022 by the authors. Licensee MDPI, Basel, Switzerland. This article is an open access article distributed under the terms and conditions of the Creative Commons Attribution (CC BY) license (https://creativecommons.org/licenses/by/4.0/).

Recently, image gradient orientation (IGO) has attracted much attention due to its impressive results in occluded FR. Wu et al. [16] presented a gradient direction-based hierarchical adaptive sparse and low-rank (GD-HASLR) model, which performs in the image gradient direction domain rather than the image intensity domain. Li et al. [17] incorporated IGO into robust error coding and proposed an IGO-embedded structural error coding (IGO-SEC) model for FR with occlusion. Apart from the above two works, Zhang et al. [18] designed Gradientfaces for FR under varying illumination conditions. In essence, Gradientfaces is the IGO. Tzimiropoulos et al. [19] introduced the notion of subspace learning from IGO and developed approaches such as IGO-PCA and IGO-LDA. Vu [20] proposed a face representation approach called patterns of orientation difference (POD), which explores the relations of both gradient orientations and magnitudes. Zheng et al. [21] presented an online image alignment method via subspace learning from IGO. Qian et al. [22] presented a method called ID-NMR, in which the local gradient distribution is exploited to decompose the image into several gradient images. Wu et al. [23] proposed a new feature descriptor called the histogram of maximum gradient and edge orientation (HGEO) for the purpose of multispectral image matching.

The above IGO-based approaches take only first-order gradient information into account and thus neglect second-order or higher-order gradient information. Recent research on human vision has found that the neural image is a landscape or surface whose geometric properties can be described by the local curvatures of differential geometry through second-order gradient information [24,25]. Based on the second-order gradient, Huang et al. [24] presented a local image descriptor called histograms of second-order gradient (HSOG). Li et al. [26] proposed a patterned fabric defect detection method based on a second-order, orientation-aware descriptor. Zhang et al. [27] designed a blind image quality assessment (IQA) method based on multi-order gradient statistics. Bastian et al. [28] developed a pedestrian detector that uses both first-order and second-order gradient information in the image. Nevertheless, these second-order-gradient-based approaches do not involve a dimensionality reduction technique, which leaves redundant information in the features. To alleviate this problem, we introduce PCA into the SOIGO framework to extract more compact features. Moreover, we employ CRC as the final classifier because of its effectiveness and efficiency. Experimental results show that our proposed method (CSOIGO) is robust to real disguise, synthesized occlusion, and mixed variations and is superior to some popular deep-neural-network-based approaches.

Our main contributions are outlined as follows:


The remainder of this paper is arranged as follows. Section 2 reviews some related work. In Section 3, we present our proposed approach. Section 4 conducts several experiments to demonstrate the efficacy of our proposed method. Finally, conclusions are drawn in Section 5.

#### **2. Related Work**

#### *2.1. IGO-PCA*

Consider a set of images $\{\mathbf{Z}\_i\}$ ($i = 1, 2, \ldots, N$), where *N* is the number of training images and $\mathbf{Z}\_i \in \mathbb{R}^{m \times n}$. Let $\mathbf{I}(x, y)$ denote the image intensity at pixel coordinates $(x, y)$ of sample $\mathbf{Z}\_i$. The horizontal and vertical gradients can then be obtained as follows:

$$\begin{aligned} \mathbf{G}\_{i,\mathbf{x}} &= h\_{\mathbf{x}} \ast \mathbf{I}(\mathbf{x}, \mathbf{y}) \\ \mathbf{G}\_{i,\mathbf{y}} &= h\_{\mathbf{y}} \ast \mathbf{I}(\mathbf{x}, \mathbf{y}), \end{aligned} \tag{1}$$

where ∗ denotes convolution, and $h\_x$ and $h\_y$ are filters that approximate the ideal differentiation operator along the horizontal and vertical image directions, respectively [29]. The image gradient contains edge information and characterizes the structure of an image. In [30], a gradient feature map is extracted from the input image and exploited as a structural prior to guide image reconstruction. Because real-world image data are discrete, the gradients are usually computed by finite differences, i.e., as the differences between the gray values of adjacent pixels. The horizontal and vertical gradients can thus be reformulated as

$$\begin{aligned} \mathbf{G}\_{i,\mathbf{x}} &= \mathbf{I}(\mathbf{x} + \mathbf{1}, \mathbf{y}) - \mathbf{I}(\mathbf{x}, \mathbf{y}) \\ \mathbf{G}\_{i,\mathbf{y}} &= \mathbf{I}(\mathbf{x}, \mathbf{y} + \mathbf{1}) - \mathbf{I}(\mathbf{x}, \mathbf{y}). \end{aligned} \tag{2}$$
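The finite differences in Equation (2) map directly onto array slicing. The following is a minimal NumPy sketch, not the authors' released code; it assumes that $x$ indexes columns and $y$ indexes rows, and that the last column (for the horizontal gradient) and the last row (for the vertical gradient), which have no forward neighbor, are set to zero so that the output keeps the input shape. This border handling is our assumption.

```python
import numpy as np

def first_order_gradients(img):
    """Forward differences of Eq. (2).

    Assumes x indexes columns and y indexes rows; the last column (for G_x)
    and the last row (for G_y) have no forward neighbor and are set to zero.
    """
    img = np.asarray(img, dtype=float)
    gx = np.zeros_like(img)
    gy = np.zeros_like(img)
    gx[:, :-1] = img[:, 1:] - img[:, :-1]   # I(x + 1, y) - I(x, y)
    gy[:-1, :] = img[1:, :] - img[:-1, :]   # I(x, y + 1) - I(x, y)
    return gx, gy
```

For example, the 2 × 3 image `[[0, 1, 3], [2, 4, 7]]` yields `gx = [[1, 2, 0], [2, 3, 0]]` and `gy = [[2, 3, 4], [0, 0, 0]]`.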

Then, the gradient orientation of the pixel location (*x*, *y*) is

$$\Phi\_i(x, y) = \arctan \frac{\mathbf{G}\_{i, y}}{\mathbf{G}\_{i, x}}, i = 1, 2, \dots, N. \tag{3}$$
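Equation (3) writes the orientation with arctan, while the text below states that $\Phi\_i \in [0, 2\pi)$; a plain arctan of the ratio only covers $(-\pi/2, \pi/2)$ and is undefined when $\mathbf{G}\_{i,x} = 0$. A common way to obtain the full range, assumed in this sketch rather than taken from the paper, is the four-quadrant `numpy.arctan2` followed by a shift into $[0, 2\pi)$:

```python
import numpy as np

def gradient_orientation(gx, gy):
    """Orientation of Eq. (3), mapped to [0, 2*pi).

    Assumption: the four-quadrant arctan2 is used so that the signs of both
    G_x and G_y are respected and G_x = 0 causes no division by zero.
    """
    phi = np.arctan2(gy, gx)          # values in (-pi, pi]
    return np.mod(phi, 2 * np.pi)     # shift into [0, 2*pi) (up to float rounding)
```

Gradients pointing right, up, left, and down then map to $0$, $\pi/2$, $\pi$, and $3\pi/2$, respectively.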

For each image $\mathbf{Z}\_i$ of size $m \times n$, we obtain a corresponding gradient orientation matrix $\Phi\_i \in [0, 2\pi)^{m \times n}$. The 2D orientation matrices $\Phi\_i$ are then converted into 1D sample vectors $\phi\_i$. Following [19], we define the mapping from $[0, 2\pi)^K$ ($K = m \times n$) onto a subset of the complex sphere with radius $\sqrt{K}$:

$$\mathbf{t}\_i(\phi\_i) = e^{j\phi\_i},\tag{4}$$

where $e^{j\phi\_i} = [e^{j\phi\_i(1)}, e^{j\phi\_i(2)}, \ldots, e^{j\phi\_i(K)}]^T$, $\phi\_i(k)$ is the $k$th element of $\phi\_i$, and $e^{j\theta}$ is given by Euler's formula, i.e., $e^{j\theta} = \cos\theta + j\sin\theta$. We can then apply complex linear PCA to the transformed vectors $\mathbf{t}\_i$; that is, we seek a set of $d < K$ orthonormal bases $\mathbf{U} = [\mathbf{u}\_1, \mathbf{u}\_2, \ldots, \mathbf{u}\_d] \in \mathbb{C}^{K \times d}$ by minimizing

$$\epsilon(\mathbf{U}) = \left\| \mathbf{X} - \mathbf{U}\mathbf{U}^H \mathbf{X} \right\|\_{F}^2, \tag{5}$$

where $\mathbf{X} = [\mathbf{t}\_1, \mathbf{t}\_2, \ldots, \mathbf{t}\_N] \in \mathbb{C}^{K \times N}$, $\mathbf{U}^H$ is the conjugate transpose of $\mathbf{U}$, and $\|\cdot\|\_F$ denotes the Frobenius norm. Minimizing Equation (5) is equivalent to the following problem:

$$\mathbf{U}\_o = \arg\max\_{\mathbf{U}} tr(\mathbf{U}^H \mathbf{X} \mathbf{X}^H \mathbf{U}), \text{ s.t. } \mathbf{U}^H \mathbf{U} = \mathbf{I}. \tag{6}$$

The solution is given by the $d$ eigenvectors of $\mathbf{X}\mathbf{X}^H$ corresponding to the $d$ largest eigenvalues. The $d$-dimensional embedding $\mathbf{Y} \in \mathbb{C}^{d \times N}$ of $\mathbf{X}$ is then given by $\mathbf{Y} = \mathbf{U}^H\mathbf{X}$.
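Equations (4)–(6) amount to an eigendecomposition of $\mathbf{X}\mathbf{X}^H$. The following minimal NumPy sketch assumes that the orientation vectors are stacked as the columns of a $K \times N$ array; the function name is hypothetical, and, as in Equation (5), no mean is subtracted before the decomposition.

```python
import numpy as np

def igo_pca(phis, d):
    """Complex linear PCA of orientation vectors, Eqs. (4)-(6).

    phis: real array of shape (K, N); column i is the vectorized orientation phi_i.
    d:    number of principal components, d < K.
    Returns U (K x d, orthonormal columns) and the embedding Y = U^H X (d x N).
    """
    X = np.exp(1j * np.asarray(phis))      # Eq. (4), applied column-wise
    C = X @ X.conj().T                     # K x K Hermitian matrix X X^H (no centering)
    evals, evecs = np.linalg.eigh(C)       # eigenvalues in ascending order
    order = np.argsort(evals)[::-1][:d]    # indices of the d largest eigenvalues
    U = evecs[:, order]
    Y = U.conj().T @ X                     # d-dimensional embedding Y = U^H X
    return U, Y
```

Every column of $\mathbf{X}$ has unit-modulus entries and therefore norm $\sqrt{K}$, which is the complex sphere mentioned above. Forming the $K \times K$ matrix explicitly mirrors the cost analysis given in Section 3.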

#### *2.2. Collaborative-Representation-Based Classification*

During the past few years, representation-based classification methods (RBCMs) have attracted much attention in the pattern recognition community. The pioneering work is SRC [1], in which the $\ell\_1$ norm constraint is employed to obtain the sparse coefficients of the test data. Zhang et al. [31] argued that it is the collaborative representation mechanism, rather than the $\ell\_1$ norm constraint, that makes SRC successful for FR. Therefore, they developed the CRC method, which replaces the $\ell\_1$ norm constraint with the $\ell\_2$ norm. Afterwards, many improved methods were proposed to further boost the classification performance of CRC. Gou et al. [32] developed a class-specific mean vector-based weighted competitive and collaborative representation (CMWCCR) method, which fully exploits discriminative information in different ways. Motivated by the idea of linear representation, Gou et al. [33] proposed a representation coefficient-based k-nearest centroid neighbor (RCKNCN) method. Recently, Gou et al. [34] presented a hierarchical graph augmented deep collaborative dictionary learning (HGDCDL) model, which applies collaborative representation to the deepest-level representation learning. For simplicity, in this paper we employ the original CRC as the classifier, whose objective function is formulated as follows:

$$\min\_{\boldsymbol{\alpha}} \left\{ \left\| \mathbf{y} - \mathbf{D}\boldsymbol{\alpha} \right\|\_{2}^{2} + \lambda \left\| \boldsymbol{\alpha} \right\|\_{2}^{2} \right\}, \tag{7}$$

where *y* is the test sample, **D** is the dictionary that contains all the training data from *C* classes, and *λ* is a balancing parameter. Equation (7) has the following closed-form solution,

$$\boldsymbol{\alpha} = \left(\mathbf{D}^{\mathrm{T}}\mathbf{D} + \lambda\mathbf{I}\right)^{-1}\mathbf{D}^{\mathrm{T}}\mathbf{y}.\tag{8}$$
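Equation (8) is a ridge-regression solution. The hypothetical helper below solves the regularized normal equations with `numpy.linalg.solve` instead of forming the inverse explicitly, which is numerically preferable but otherwise equivalent:

```python
import numpy as np

def crc_coefficients(D, y, lam):
    """Closed-form CRC coding vector of Eq. (8).

    D:   (n_features, N) dictionary, one training sample per column.
    y:   (n_features,) test sample.
    lam: regularization parameter lambda > 0.
    """
    n_atoms = D.shape[1]
    # Solve (D^T D + lam I) alpha = D^T y rather than inverting the matrix.
    return np.linalg.solve(D.T @ D + lam * np.eye(n_atoms), D.T @ y)
```

Because $(\mathbf{D}^{\mathrm{T}}\mathbf{D} + \lambda\mathbf{I})^{-1}\mathbf{D}^{\mathrm{T}}$ does not depend on $\mathbf{y}$, it can also be computed once and reused for every test sample, which is why only a per-sample matrix-vector product remains at test time.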

In the classification stage, in addition to the class-specific reconstruction error $\|\mathbf{y} - \mathbf{D}\_j\boldsymbol{\alpha}\_j\|\_2$ ($j = 1, 2, \ldots, C$), where $\mathbf{D}\_j$ contains the training samples of the $j$th class and $\boldsymbol{\alpha}\_j$ is the corresponding coefficient vector, Zhang et al. [31] found that $\|\boldsymbol{\alpha}\_j\|\_2$ also carries discriminative information for classification. They therefore proposed the following regularized residual for classification:

$$\text{identity}(\mathbf{y}) = \arg\min\_{j} \frac{\left\| \mathbf{y} - \mathbf{D}\_j \boldsymbol{\alpha}\_j \right\|\_2}{\left\| \boldsymbol{\alpha}\_j \right\|\_2}. \tag{9}$$
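The decision rule in Equation (9) can be sketched as follows, assuming each dictionary column carries a class label. The function name is hypothetical, and the small floor on $\|\boldsymbol{\alpha}\_j\|\_2$ is our addition to avoid division by zero; it is not part of the original rule.

```python
import numpy as np

def crc_classify(D, labels, y, lam):
    """Classify y with CRC and the regularized residual rule of Eq. (9).

    D:      (n_features, N) dictionary, one training sample per column.
    labels: length-N sequence of class labels for the columns of D.
    Returns the predicted label and the regularized residual of every class.
    """
    labels = np.asarray(labels)
    # Eq. (8): code y collaboratively over the whole dictionary.
    alpha = np.linalg.solve(D.T @ D + lam * np.eye(D.shape[1]), D.T @ y)
    classes = np.unique(labels)
    residuals = np.empty(len(classes))
    for k, c in enumerate(classes):
        mask = labels == c
        a_c = alpha[mask]                                  # coefficients of class c
        recon_err = np.linalg.norm(y - D[:, mask] @ a_c)   # class-specific residual
        # Floor on ||a_c|| is an added safeguard against division by zero.
        residuals[k] = recon_err / max(np.linalg.norm(a_c), 1e-12)
    return classes[int(np.argmin(residuals))], residuals
```

Dividing by $\|\boldsymbol{\alpha}\_j\|\_2$ favors classes that both reconstruct $\mathbf{y}$ well and receive a large share of the coding energy.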

#### **3. Proposed Method**

Previous studies revealed that gradient information of different orders characterizes different structural features of natural scenes. First-order gradient information is related to the slope and elasticity of a surface, while the second-order gradient conveys curvature-related geometric properties. Figure 1 depicts two images and their corresponding landscapes plotted as surfaces; these landscapes contain a variety of local shapes, such as cliffs, ridges, summits, valleys, and basins. Inspired by these findings, we propose a new FR method that exploits the SOIGO. The second-order gradient is obtained from the first-order gradient defined in Equation (2):

$$\begin{aligned} \mathbf{G}\_{i,\mathbf{x}}^2 &= \mathbf{G}\_{i,\mathbf{x}}(\mathbf{x}+1,\mathbf{y}) - \mathbf{G}\_{i,\mathbf{x}}(\mathbf{x},\mathbf{y}) \\ \mathbf{G}\_{i,\mathbf{y}}^2 &= \mathbf{G}\_{i,\mathbf{y}}(\mathbf{x},\mathbf{y}+1) - \mathbf{G}\_{i,\mathbf{y}}(\mathbf{x},\mathbf{y}), \end{aligned} \tag{10}$$

where $\mathbf{G}\_{i,x}^2$ and $\mathbf{G}\_{i,y}^2$ are the second-order gradients along the horizontal and vertical directions, respectively. The SOIGO is then computed as follows:

$$\Phi\_i^2(x, y) = \arctan \frac{\mathbf{G}\_{i,y}^2}{\mathbf{G}\_{i,x}^2}, i = 1, 2, \dots, N. \tag{11}$$
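The second-order step in Equation (10) applies the same forward difference once more to each first-order gradient, and Equation (11) takes the orientation of the result. The self-contained sketch below assumes zero padding at the far border and the four-quadrant `numpy.arctan2` mapped to $[0, 2\pi)$; both are implementation choices not specified in the text, and the function names are hypothetical.

```python
import numpy as np

def forward_diff(a, axis):
    """Forward difference along axis (1 = horizontal x, 0 = vertical y).

    The last column/row has no forward neighbor and is set to zero (assumption).
    """
    out = np.zeros_like(a, dtype=float)
    if axis == 1:
        out[:, :-1] = a[:, 1:] - a[:, :-1]
    else:
        out[:-1, :] = a[1:, :] - a[:-1, :]
    return out

def soigo(img):
    """Second-order image gradient orientation, Eqs. (2), (10) and (11)."""
    img = np.asarray(img, dtype=float)
    gx2 = forward_diff(forward_diff(img, axis=1), axis=1)  # G^2_x
    gy2 = forward_diff(forward_diff(img, axis=0), axis=0)  # G^2_y
    # Four-quadrant orientation shifted into [0, 2*pi) (assumed convention).
    return np.mod(np.arctan2(gy2, gx2), 2 * np.pi)
```

On the one-row quadratic ramp `[[0, 1, 4, 9, 16]]`, the interior second-order differences are all equal to 2, so their orientations are 0, whereas the zero-padded border produces an artificial value of $\pi$ at the fourth pixel; border handling therefore deserves some care.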

**Figure 1.** Original images (**left part**) and their surface plots (**right part**).

Figure 2 presents an original face image and its first- and second-order gradient orientations. Compared with the first-order IGO, the SOIGO noticeably suppresses noise in the orientation domain. Moreover, the SOIGO contains more fine-grained information than the first-order IGO, e.g., in the areas around the eyes, nose, and mouth.

**Figure 2.** Original face image and its gradient orientations of the first and second orders, respectively.

To further illustrate the effectiveness of the SOIGO, Figure 3 visualizes the original data, the first-order IGO, and the SOIGO on the AR database using the t-SNE algorithm [35]. The data are selected from the first ten subjects of the AR database; for each person, the seven nonoccluded face images in Session 1 are used, and these images are then occluded by a square baboon image covering 30% of the image area. For detailed experimental settings, please refer to Section 4.3. As can be seen from Figure 3, although the first-order IGO is better separated than the original data, clusters of different classes are still mixed together. In Figure 3c, the clusters of the same class are more compact than those in Figure 3b, which is beneficial for subsequent classification.

The procedure for obtaining the projection matrix $\mathbf{U}$ is the same as in IGO-PCA. For a test image $\mathbf{Z}\_t$, we first compute its SOIGO and obtain $\mathbf{t}$ via the mapping defined in Equation (4). The embeddings of the training and test images are then derived as follows:

$$\mathbf{Y} = \mathbf{U}^H \mathbf{X}, \ z = \mathbf{U}^H \mathbf{t},\tag{12}$$

where $\mathbf{Y} \in \mathbb{C}^{d \times N}$ and $\mathbf{z} \in \mathbb{C}^{d \times 1}$. To make the embeddings of the training and test images suitable for CRC, we use both the real and imaginary parts of $\mathbf{Y}$ and $\mathbf{z}$ as the input of CRC; let

$$\mathbf{D} = \begin{bmatrix} \text{real}(\mathbf{Y}) \\ \text{imag}(\mathbf{Y}) \end{bmatrix}, \mathbf{y} = \begin{bmatrix} \text{real}(\mathbf{z}) \\ \text{imag}(\mathbf{z}) \end{bmatrix} \tag{13}$$

where real(·) and imag(·) denote the real and imaginary parts of a complex number, respectively. We then compute the representation coefficient vector of $\mathbf{y}$ over $\mathbf{D}$ and assign the test image to the class with the smallest regularized residual. The pipeline of the proposed CSOIGO is illustrated in Figure 4, and the complete procedure is outlined in Algorithm 1.
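Equations (12) and (13) can be sketched as follows; the function name is hypothetical. Because $|a + jb|^2 = a^2 + b^2$, stacking the real and imaginary parts preserves Euclidean norms and distances, so for real coefficient vectors the residuals that CRC measures on the stacked vectors equal those of the complex embeddings.

```python
import numpy as np

def stack_real_imag(U, X, t):
    """Eqs. (12)-(13): project with U^H and stack real/imag parts for CRC.

    U: (K, d) complex projection matrix; X: (K, N) mapped training SOIGOs;
    t: (K,) mapped test SOIGO.
    """
    Y = U.conj().T @ X                     # d x N complex training embeddings
    z = U.conj().T @ t                     # d complex test embedding
    D = np.vstack([Y.real, Y.imag])        # 2d x N real dictionary
    y = np.concatenate([z.real, z.imag])   # 2d real test vector
    return D, y
```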

When assessing an algorithm, its computational complexity should also be taken into account. The main computational cost of CSOIGO lies in the complex linear PCA and CRC, both of which involve matrix operations. In PCA, computing the covariance matrix takes $O(K^2N)$ and its eigendecomposition takes $O(K^3)$, where $K = m \times n$ and $N$ denote the dimensionality and the total number of training images, respectively. From Equation (8), CRC involves matrix multiplication and matrix inversion: computing $\mathbf{D}^{\mathrm{T}}\mathbf{D}$ takes $O(N^2d)$ and the matrix inversion takes $O(N^3)$, where $d$ is the reduced dimensionality. Given $p$ test samples, CRC takes $O(N^2d + N^3 + Ndp)$ to classify all of them. Therefore, the total computational complexity of CSOIGO is $O(K^2N + K^3 + N^2d + N^3 + Ndp)$.
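As an illustrative order-of-magnitude check, assume the AR setting of Section 4.1, where images are resized to 42 × 30 pixels and one neutral image of each of the 100 subjects is used for training, so that $K = 1260$ and $N = 100$. Ignoring constant factors, the PCA terms dominate:

$$\begin{aligned} K^2N &= 1260^2 \times 100 \approx 1.59 \times 10^{8}, \\ K^3 &= 1260^3 \approx 2.00 \times 10^{9}, \\ N^3 &= 100^3 = 10^{6}. \end{aligned}$$

Under these assumptions, the eigendecomposition term is roughly 2000 times the $N^3$ term of CRC. Both are one-off training costs, whereas only the $O(Ndp)$ term grows with the number of test samples.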

**Figure 4.** The pipeline of our proposed CSOIGO.

#### **Algorithm 1** CSOIGO

**Input:** A set of *N* training images {**Z***i*}(*i* = 1, 2, ... , *N*) from *C* classes, test image **Z***t*, the number of principal components *d*, and the regularization parameter *λ* for CRC.

1. Obtain the SOIGO $\Phi\_i^2$ of each training image and convert it to a 1D vector $\phi\_i^2$.

2. Compute $\mathbf{t}\_i(\phi\_i^2) = e^{j\phi\_i^2}$; the mapped SOIGO vectors of all training images form the matrix $\mathbf{X} = [\mathbf{t}\_1, \mathbf{t}\_2, \ldots, \mathbf{t}\_N]$.

3. Obtain the projection matrix **U** via Equation (6).

4. For the test image $\mathbf{Z}\_t$, obtain its SOIGO $\Phi\_t^2$ and convert it to a 1D vector $\phi\_t^2$; then, compute $\mathbf{t} = e^{j\phi\_t^2}$.

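5. Obtain the embeddings of training and test images via Equation (12).

6. Stack the real and imaginary parts of the embeddings to form $\mathbf{D}$ and $\mathbf{y}$ via Equation (13).

7. Compute the representation coefficient vector $\boldsymbol{\alpha}$ of $\mathbf{y}$ over $\mathbf{D}$ via Equation (8).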


8. Compute the regularized residuals $r\_j = \|\mathbf{y} - \mathbf{D}\_j\boldsymbol{\alpha}\_j\|\_2 / \|\boldsymbol{\alpha}\_j\|\_2$, $j = 1, 2, \ldots, C$.

**Output:** identity($\mathbf{Z}\_t$) = $\arg\min\_j r\_j$.
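Putting the pieces together, the following end-to-end sketch mirrors Algorithm 1 for a single test image. It is an illustration under the same assumptions as the smaller sketches in Sections 2 and 3 (zero-padded forward differences, arctan2 orientations, a small floor on $\|\boldsymbol{\alpha}\_j\|\_2$), with hypothetical function names; the authors' reference implementation is the repository linked at the start of Section 4.

```python
import numpy as np

def _fdiff(a, axis):
    # Forward difference; the far border is zero-padded (assumption).
    out = np.zeros_like(a, dtype=float)
    if axis == 1:
        out[:, :-1] = a[:, 1:] - a[:, :-1]
    else:
        out[:-1, :] = a[1:, :] - a[:-1, :]
    return out

def soigo_vector(img):
    # Steps 1 and 4: SOIGO via Eqs. (2), (10), (11), flattened to 1D.
    img = np.asarray(img, dtype=float)
    gx2 = _fdiff(_fdiff(img, 1), 1)
    gy2 = _fdiff(_fdiff(img, 0), 0)
    return np.mod(np.arctan2(gy2, gx2), 2 * np.pi).ravel()

def csoigo_fit(train_imgs, d):
    # Step 2: map every training SOIGO onto the complex sphere, Eq. (4).
    X = np.stack([np.exp(1j * soigo_vector(z)) for z in train_imgs], axis=1)
    # Step 3: d leading eigenvectors of X X^H, Eq. (6).
    evals, evecs = np.linalg.eigh(X @ X.conj().T)
    U = evecs[:, np.argsort(evals)[::-1][:d]]
    # Step 5 and Eq. (13): training embeddings with real/imag parts stacked.
    Y = U.conj().T @ X
    return U, np.vstack([Y.real, Y.imag])

def csoigo_predict(U, D, labels, test_img, lam):
    # Steps 4-5 and Eq. (13) for the test image.
    z = U.conj().T @ np.exp(1j * soigo_vector(test_img))
    y = np.concatenate([z.real, z.imag])
    # CRC coefficients, Eq. (8).
    alpha = np.linalg.solve(D.T @ D + lam * np.eye(D.shape[1]), D.T @ y)
    # Step 8 and output: regularized residuals, Eq. (9).
    labels = np.asarray(labels)
    classes = np.unique(labels)
    r = [np.linalg.norm(y - D[:, labels == c] @ alpha[labels == c])
         / max(np.linalg.norm(alpha[labels == c]), 1e-12) for c in classes]
    return classes[int(np.argmin(r))]
```

In this sketch, `csoigo_fit` runs once on the training set, and `csoigo_predict` is then called per test image with the returned `U` and `D`.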

#### **4. Experimental Results and Analysis**

In this section, experiments are conducted under different scenarios to validate the effectiveness of the proposed method. For reproduction, the source code of CSOIGO is available at https://github.com/yinhefeng/SOIGO.

#### *4.1. Recognition with Real Disguise*

The AR database contains over 4000 images of 126 subjects. For each individual, 26 images were taken in two separate sessions. Each session contains 13 images: three with sunglasses, three with scarves, and seven with different illumination and expression changes. The 13 images of one subject from Session 1 are shown in Figure 5. Each image is 165 × 120 pixels. For a fair comparison, we use the same subset as in [16], which consists of 50 men and 50 women, and all images are resized to 42 × 30 pixels. The neutral face image of each subject is used as training data, and the sunglasses- or scarf-occluded images in each session are used for testing. The proposed method is compared with other state-of-the-art approaches, including HQPAMI [36], NR [37], ProCRC [38], F-LR-IRNNLS [39], EGSNR [40], LDMR [41], and GD-HASLR [16]. To better illustrate the superiority of CSOIGO, we also present the results of IGO-PCA-NNC [19], IGO-PCA-CRC, and SOIGO-PCA-NNC. Table 1 summarizes the experimental results: CSOIGO achieves the highest recognition accuracy in all cases except the sunglasses scenario of Session 1. Because the test images are partially occluded by sunglasses or scarves, HQPAMI, NR, ProCRC, and LDMR appear not to be very robust to such contiguous occlusion. Owing to its preprocessing step, which separates outlier pixels and corruptions from the training samples, F-LR-IRNNLS achieves a higher overall classification accuracy than EGSNR. IGO-PCA-CRC ranks second among all methods and achieves 5.66% higher accuracy than IGO-PCA-NNC, which validates the efficacy of CRC for IGO features. GD-HASLR performs comparably to SOIGO-PCA-NNC. However, the overall accuracy gain of CSOIGO over GD-HASLR and IGO-PCA-CRC is 4.5% and 2.67%, respectively. These results indicate that the proposed CSOIGO is robust to real disguise even when only a single training sample per person is available.

**Figure 5.** Some example face images from the AR database: (**a**) the neutral image of a subject from Session 1; (**b**) face images with illumination and expression variations; (**c**) images occluded by sunglasses/scarf.

Next, we use two neutral face images per subject, from Sessions 1 and 2, for training, and the test sets are identical to those of the first experiment. The results are reported in Table 2. As can be seen from Table 2, CSOIGO yields the best overall recognition accuracy and outperforms GD-HASLR by 2.92%. Again, IGO-PCA-CRC ranks second among all methods. SOIGO-PCA-NNC outperforms IGO-PCA-NNC, and CSOIGO achieves higher accuracy than IGO-PCA-CRC, which indicates that SOIGO is more robust to occlusion than IGO.


**Table 1.** Recognition accuracy (%) of competing approaches on a subset of the AR database (test samples contain sunglasses occlusion or scarf occlusion) when only one neutral face image per subject from Session 1 is used as training sample. The dimension that leads to the best result for IGO- and SOIGO-based approaches is given in parentheses.

Bold values indicate the best recognition accuracy.

**Table 2.** Recognition accuracy (%) of competing approaches on a subset of the AR database (test samples contain sunglasses occlusion or scarf occlusion) when two neutral face images (from Sessions 1 and 2) per subject are used as training samples. The dimension that leads to the best result for IGO- and SOIGO-based approaches is given in parentheses.


Bold values indicate the best recognition accuracy.

#### *4.2. Comparison with CNN-Based Approaches*

In this subsection, we compare our proposed method with prevailing deep-learning-based approaches. The first is VGGFace [42], which is based on VGGNet [43] and has 13 convolutional layers, five max-pooling layers, and three fully connected layers (16 weight layers in total), followed by a softmax layer. In our experiments, we employ FC6 and FC7 for feature extraction. The second is Lightened CNN [44], which has low computational complexity. Lightened CNN consists of two different models, i.e., Model A and Model B. Model A is based on AlexNet [45] and contains four convolutional layers using the max feature map (MFM) activation function, four max-pooling layers, two fully connected layers, and a linear layer with softmax activation in the output. Model B is based on the Network in Network model [46] and consists of five convolutional layers using the MFM activation function, four convolutional layers for dimensionality reduction, five max-pooling layers, two fully connected layers, and a linear layer with softmax activation in the output. For Lightened CNN, FC1 is used for feature extraction. All features extracted by VGGFace and Lightened CNN are classified using the nearest neighbor classifier with cosine distance. When training VGGFace, the input image size is 224 × 224, and preprocessing consists of subtracting the mean RGB value, computed on the training set, from each pixel. The batch size, number of epochs, and optimizer are 256, 74, and *sgdm*, respectively. The learning rate is initially set to 1 × 10<sup>−2</sup> and then decreased by a factor of 10. For training Lightened CNN, the input image size is 144 × 144, and the input image is cropped to 128 × 128 and mirrored. The batch size, number of epochs, and optimizer are 20, 150, and *rmsprop*, respectively. The learning rate is initially set to 1 × 10<sup>−3</sup> and gradually reduced to 5 × 10<sup>−5</sup>.

As in Section 4.1, the first experiment uses one neutral face image per subject of the AR database for training; the results are summarized in Table 3. Table 4 lists the results when two neutral face images are used for training. Tables 3 and 4 show that VGGFace performs better in the scarf scenario than in the sunglasses scenario. This indicates that VGGFace has difficulty handling upper-face occlusion, a phenomenon also observed in [47]. Moreover, using more training samples does not improve the performance of VGGFace. Hence, VGGFace may need much more training data to become robust to upper-face occlusion. By comparison, our proposed CSOIGO achieves better results even with few training samples. Since training data may be insufficient in practical applications, CSOIGO is more appropriate than VGGFace for robust face recognition in such situations.

**Table 3.** Comparison with CNN-based approaches on a subset of the AR database (test samples contain sunglasses occlusion or scarf occlusion) when only one neutral face image per subject from Session 1 is used as the training sample. The dimension that leads to the best result for IGO- and SOIGO-based approaches is given in parentheses.


Bold values indicate the best recognition accuracy.

**Table 4.** Comparison with CNN-based approaches on a subset of the AR database (test samples contain sunglasses occlusion or scarf occlusion) when two neutral face images (from Sessions 1 and 2) per subject are used as training samples. The dimension that leads to the best result for IGO- and SOIGO-based approaches is given in parentheses.


Bold values indicate the best recognition accuracy.

Similar to the results of VGGFace, Lightened CNN performs worse in the sunglasses scenario than in the scarf scenario. Additionally, Model A outperforms Model B, and Model A also achieves higher accuracy than VGGFace. However, whether one or two neutral face images per subject are used for training, our proposed CSOIGO achieves the best overall recognition accuracy.

#### *4.3. Random Block Occlusion*

Here, we conduct further experiments using synthesized occluded face images as test data. For each subject, the seven nonoccluded face images from Session 1 of the AR dataset are used for training, and the other seven nonoccluded images from Session 2 are used for testing; the image size is 42 × 30 pixels. Block occlusion is simulated by placing a square baboon image on each test image. The location of the occlusion is chosen randomly and is unknown during training. We consider different occluder sizes, such that the occluding object covers 30% to 50% of the face area; some occluded face images are shown in Figure 6. Since the preceding experiments indicate that GD-HASLR is superior to the other competing approaches, in this subsection and the following one we report GD-HASLR as the representative competitor. Recognition results for different levels of occlusion are shown in Table 5. CSOIGO outperforms GD-HASLR by a large margin, and the performance gain remains significant as the percentage of occlusion increases. Moreover, SOIGO-PCA-NNC outperforms IGO-PCA-NNC and CSOIGO performs better than IGO-PCA-CRC, which demonstrates that SOIGO is more robust than IGO when dealing with artificial occlusion.
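The occlusion protocol can be sketched as follows. The function name is hypothetical, the baboon image is replaced by any array passed as `occluder`, and nearest-neighbor resizing of the occluder is our assumption; the side length is chosen so that the square covers the requested fraction of the image area, and its top-left corner is drawn uniformly so that the square stays inside the image.

```python
import numpy as np

def occlude_random_block(img, occluder, fraction, rng):
    """Paste a square occluder covering `fraction` of the image area at a random position.

    img:      2D grayscale image (H x W).
    occluder: 2D array standing in for the baboon image (any size).
    fraction: target share of the image area, e.g. 0.3 for 30%.
    rng:      numpy.random.Generator used to draw the position.
    """
    h, w = img.shape
    side = int(round(np.sqrt(fraction * h * w)))      # square side in pixels
    if side > min(h, w):
        raise ValueError("the occluding square does not fit inside the image")
    # Nearest-neighbor resize of the occluder to side x side (assumed resampling).
    rows = np.arange(side) * occluder.shape[0] // side
    cols = np.arange(side) * occluder.shape[1] // side
    patch = occluder[np.ix_(rows, cols)]
    # Top-left corner drawn uniformly so that the whole square stays in the image.
    top = rng.integers(0, h - side + 1)
    left = rng.integers(0, w - side + 1)
    out = np.array(img, copy=True)
    out[top:top + side, left:left + side] = patch
    return out
```

For 42 × 30 images, covering 30%, 40%, and 50% of the area gives squares with sides of 19, 22, and 25 pixels after rounding, all of which fit inside the image.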

**Figure 6.** Original face image and its occluded images with different occlusion percentages; from the second to the last, the percentage is 30%, 40%, and 50%, respectively.

**Table 5.** Recognition accuracy (%) of competing methods under different percentages of occlusion on a subset of the AR database (original training and test samples have no sunglasses occlusion or scarf occlusion). The dimension that leads to the best result for IGO- and SOIGO-based approaches is given in parentheses.


Bold values indicate the best recognition accuracy.

To illustrate the performance of the IGO- and SOIGO-based approaches with different numbers of features, Figure 7 plots recognition accuracy against the number of features when 30% of the image is occluded. As the number of features increases, CSOIGO consistently outperforms the other three competing approaches.

**Figure 7.** Recognition accuracy versus different numbers of features when the percentage of occlusion is 30%.

#### *4.4. Recognition with Mixed Variations*

In this subsection, we evaluate the proposed CSOIGO and the compared approaches under mixed variations. As shown in Figure 5a,b, the first seven images of each subject in Session 1 vary in expression and illumination; thus, the seven nonoccluded images from Session 1 of the AR database are selected for training, and the seven undisguised images from Session 2 are used for testing. The recognition accuracy and testing time of the compared methods are shown in Table 6, where the testing time is the time taken to classify all test samples. All experiments are performed on a laptop running Windows 10 with an Intel Core i9-8950HK CPU at 2.90 GHz and 32 GB of RAM, and the methods are implemented in MATLAB R2022a. Table 6 shows that CSOIGO has the best classification performance: it improves accuracy by 1.86% and 0.86% over GD-HASLR and IGO-PCA-CRC, respectively. Owing to its complex optimization process, GD-HASLR consumes much more time than the other approaches. The testing times of IGO-PCA-NNC and SOIGO-PCA-NNC are almost the same. NNC is a simple and efficient classifier, whereas CRC requires computing the coefficient vector and the class-wise residuals; as a result, CSOIGO takes slightly longer than SOIGO-PCA-NNC. However, CSOIGO is much faster than GD-HASLR.

**Table 6.** Recognition accuracy (%) and testing time (s) of compared approaches with mixed variations on a subset of the AR database (training and test samples have expression and illumination changes). The dimension that leads to the best result for IGO- and SOIGO-based approaches is given in parentheses.


Bold values indicate the best recognition accuracy.

As in the previous subsection, Figure 8 shows the recognition accuracy against the number of features. As the number of features increases, the recognition accuracies of IGO-PCA-NNC, SOIGO-PCA-NNC, and CSOIGO also increase. The recognition accuracy of IGO-PCA-CRC first increases, then decreases to some extent, and then increases again. When the number of features exceeds 108, CSOIGO always achieves higher accuracy than its competitors. This again demonstrates that CSOIGO is robust to mixed variations in face images.

**Figure 8.** Recognition accuracy versus different numbers of features under mixed variations.

#### **5. Conclusions**

In this paper, we present a new method for occluded face recognition, namely CSOIGO, that exploits second-order gradient information. SOIGO is robust to real disguise, synthesized occlusion, and mixed variations. By employing CRC as the final classifier, the proposed method achieves strong results in various scenarios and even outperforms some deep-neural-network-based approaches. In the real disguise experiments, for example, CSOIGO attains overall accuracies of 79.50% and 91.17% when one and two neutral face images per subject are used for training, respectively. These results show that CSOIGO outperforms its competing approaches in the evaluated scenarios.

A limitation of CSOIGO is that it requires registered images for training and testing; when classifying face images with pose changes, its recognition performance degrades. Consequently, CSOIGO is best suited to applications such as access control, automatic teller machines, and other security facilities, in which controlled training images can be obtained in advance and test images are collected under similar conditions. If registered face images cannot be collected during either the training or the test stage, image registration methods can be employed to mitigate this limitation to some extent.

In future work, we will introduce SOIGO into other popular subspace learning approaches, e.g., linear discriminant analysis (LDA), to extract more discriminative features. Moreover, other variants of CRC will also be investigated to further enhance the performance of recognition.

**Author Contributions:** Conceptualization, H.-F.Y.; methodology, H.-F.Y. and X.-J.W.; software, H.-F.Y.; validation, H.-F.Y., X.-J.W., C.H. and X.S.; formal analysis, H.-F.Y. and X.-J.W.; investigation, H.-F.Y. and X.-J.W.; resources, H.-F.Y.; data curation, H.-F.Y., X.-J.W. and X.S.; writing–original draft preparation, H.-F.Y.; writing–review and editing, X.-J.W., C.H. and X.S.; visualization, H.-F.Y.; supervision, X.-J.W.; project administration, X.-J.W. and X.S.; funding acquisition, X.-J.W., C.H. and X.S. All authors have read and agreed to the published version of the manuscript.

**Funding:** This research was funded in part by the National Natural Science Foundation of China (Grant 62020106012, Grant U1836218, Grant 61902153, Grant 61876072, Grant 62006097, Grant 61672265), in part by the Fundamental Research Funds for the Central Universities (Grant JUSRP121104), in part by the Major Project of National Social Science Foundation of China (Grant 21&ZD166), in part by the Natural Science Foundation of Jiangsu Province (Grant BK20200593), and in part by the 111 Project of Ministry of Education of China (Grant B12018).

**Institutional Review Board Statement:** Not applicable.

**Informed Consent Statement:** Not applicable.

**Data Availability Statement:** The data presented in this study are available on request from the corresponding author.

**Conflicts of Interest:** The authors declare no conflict of interest.

#### **References**


### *Article* **Theme-Aware Semi-Supervised Image Aesthetic Quality Assessment**

**Xiaodan Zhang 1,†, Xun Zhang 1,†, Yuan Xiao <sup>1</sup> and Gang Liu 2,\***


**Abstract:** Image aesthetic quality assessment (IAQA) has aroused considerable interest in recent years and is widely used in various applications, such as image retrieval, album management, chatbots, and social media. However, existing methods need an excessive amount of labeled data to train the model. Collecting an enormous quantity of human-scored training data is not always feasible due to a number of factors, such as the cost of the labeling process and the difficulty of labeling data correctly. Previous studies have evaluated the aesthetics of a photo based only on image features but have ignored the criterion bias associated with its theme. In this work, we present a new theme-aware semi-supervised image aesthetic quality assessment method to address these difficulties. Specifically, the proposed method consists of two steps: a representation learning step and a label propagation step. In the representation learning step, we propose a robust theme-aware attention network (TAAN) to cope with the theme criterion bias problem. In the label propagation step, we use the TAAN preliminarily trained in the first step to extract features and apply a label propagation with cumulative confidence (LPCC) algorithm to assign pseudo-labels to the unlabeled data. This enables the use of both labeled and unlabeled data to train the TAAN model. To the best of our knowledge, this is the first time that a semi-supervised learning method has been studied for image aesthetic assessment. We evaluate our approach on three benchmark datasets and show that it achieves almost the same performance as a fully supervised learning method while using only a small number of labeled samples. Furthermore, we show that our semi-supervised approach is robust to varying quantities of labeled data.

**Keywords:** image aesthetic assessment; semi-supervised learning; label propagation; deep learning; computer vision

**MSC:** 68T07

### **1. Introduction**

With the rapid development of the mobile Internet, images have become an indispensable part of our lives. Given the vast amount of image data, relying solely on humans for the aesthetic analysis of images cannot meet practical needs, so the design of automatic aesthetic assessment algorithms has aroused considerable interest in the research community.

Depending on how features are generated, existing image aesthetic quality assessment methods can be broadly divided into two categories. The first category includes shallow modeling methods that use hand-crafted features to infer image aesthetic quality [1–3]. These methods use global, local, and general features to represent aesthetic attributes. Among them, the Fisher vector (FV) [3] is used to construct aesthetic attributes and predict aesthetic quality. However, the representation ability of hand-crafted features is limited. The second category includes deep-learning-based methods. Because of their outstanding capability for efficient feature learning, convolutional neural networks (CNNs) have been used to infer composition information and learn new aesthetic representations (see, for example, [4–6]). Since the high-level features constructed by CNNs better express aesthetic quality, CNN-based methods outperform traditional methods based on hand-crafted features. Earlier CNN-based attempts [4–11] helped computers learn to evaluate images automatically. However, existing deep-learning-based methods have two major flaws. First, they require large labeled datasets to train the network, yet collecting enormous amounts of human-scored training data is not always feasible, since manual annotation of aesthetic quality is a time-consuming, expensive, and error-prone task. Thus, it is crucial to develop a method that uses only a small quantity of labeled training data to reduce the reliance on manual annotation. Second, most previous research has focused only on the aesthetic features of images and has ignored the criterion bias associated with their themes. Photographers shoot different scenes with different shooting methods. The scenes shot by each shooting method can be regarded as having a specific theme, and different shooting methods have different standards for assessing aesthetic quality; thus, different themes use different evaluation criteria. For example, a highly blurred image may obtain a high score under the theme "Motion Blur", because blurring is regarded as a good feature there; however, it will obtain a low aesthetic score under the theme "Landscape", since blurring is considered a drawback in landscape images. Thus, it is appropriate to take themes into account when making aesthetic decisions.

**Citation:** Zhang, X.; Zhang, X.; Xiao, Y.; Liu, G. Theme-Aware Semi-Supervised Image Aesthetic Quality Assessment. *Mathematics* **2022**, *10*, 2609. https://doi.org/10.3390/math10152609

Academic Editors: Jianping Gou, Weihua Ou, Shaoning Zeng and Lan Du

Received: 20 June 2022 Accepted: 22 July 2022 Published: 26 July 2022

**Publisher's Note:** MDPI stays neutral with regard to jurisdictional claims in published maps and institutional affiliations.

**Copyright:** © 2022 by the authors. Licensee MDPI, Basel, Switzerland. This article is an open access article distributed under the terms and conditions of the Creative Commons Attribution (CC BY) license (https://creativecommons.org/licenses/by/4.0/).

Therefore, we propose a theme-aware semi-supervised image aesthetic quality assessment method to solve the above-mentioned problems. To deal with the first problem, we employ a deep-learning-based label propagation method, which makes predictions on the entire dataset and uses them to generate pseudo-labels for the unlabeled data. To handle noisy labels during label propagation, we also propose a cumulative confidence algorithm that assigns different weights to different unlabeled examples: examples whose current predictions agree with previous prediction results receive a higher confidence weight, and examples whose predictions disagree receive a lower one. For the second problem, we propose a theme-aware attention network that considers the theme of an image when an aesthetic decision is made. This network consists of three components: an image feature extractor (backbone), a self-attention-based theme encoder and a residual connection module. The proposed network not only extracts visual features more effectively, but also leverages the theme information provided by tags and challenges to make aesthetic predictions more accurate.

The contributions of this paper are as follows:


The remainder of this paper is organized as follows: Section 2 summarizes related work. Section 3 introduces the methodology of the proposed theme-aware semi-supervised approach. Section 4 quantitatively analyses the effectiveness of the proposed method and compares it with state-of-the-art methods. Finally, Section 5 summarizes the paper and outlines plans for future work.

#### **2. Related Work**

#### *2.1. Image Aesthetics Quality Assessment*

Image aesthetic quality assessment is a branch of image quality assessment (IQA) [12–14]. A broad range of methods has been proposed in the last few years. Early image aesthetic assessment methods relied on hand-crafted features to extract the aesthetic attributes of images [1,2]. These hand-crafted features include global features, such as saturation, brightness and hue; local features, such as contrast; and general features, such as SIFT and the Fisher vector [3]. With the advent of deep convolutional neural networks, deep CNNs have been deployed in image aesthetic quality assessment and have proved to be effective. For instance, Lu et al. [4] proposed a double-column DNN architecture, RAPID-Net, which extracts global features from the whole image and local features from a randomly cropped patch. To capture more high-resolution, fine-grained details, Lu et al. [5] proposed a deep multi-patch aggregation network, DMA-Net, which extracts aesthetic features from a bag of randomly cropped patches and uses statistics and sorting network layers to aggregate these patches. Later, researchers found that transforming images in the data augmentation stage loses part of the original image information, which degrades network performance. Thus, Mai et al. [6] added an adaptive spatial pooling layer on top of the regular convolutions to handle images at their original sizes. In a similar vein, Ma et al. [15] proposed selecting multiple patches non-randomly, according to image saliency, to extract image features without any transformation. Jia et al. [10] combined padding with ROI pooling to handle batch inputs of arbitrary sizes.

Since previous work focused only on the aesthetic features of images and ignored image content, some researchers have used semantic information to improve the accuracy of aesthetic prediction. For example, Kao et al. [9] proposed using semantic labels to guide aesthetic assessment. Kong et al. [16] regularized the complicated photo aesthetics rating problem by jointly learning meaningful photographic attributes and image content information. However, these methods still cannot cope with the theme criterion bias problem. In the method of [16], photographic attributes cannot adequately address theme criterion bias, for two reasons. First, the same image can have multiple aesthetic attributes, so the theme of an image cannot be uniquely determined from its photographic attributes. Second, photographic attributes evaluate an image from different perspectives, such as light, color and depth of field (DOF), rather than its theme. In the method of Kao et al. [9], although semantic labels guide the aesthetic assessment, the semantic information is used simply as ground-truth labels and cannot fully interact with the image. In this paper, we take both tag and challenge themes into account. To fully utilize them, we encode the theme information and combine it with the extracted visual features via an attention mechanism. Our experiments demonstrate the effectiveness of the proposed module.

#### *2.2. Semi-Supervised Learning*

Supervised learning methods need labeled data to build models. However, labeling training data in the real world may be expensive or time-consuming. Semi-supervised learning (SSL) addresses this bottleneck by allowing a model to incorporate part or all of the unlabeled data into its training. The goal is to maximize the learning performance of the model using the information revealed by both a limited number of labeled images and a sufficient number of unlabeled images. SSL has a long history, and various models have been proposed. For example, Zhang et al. [17] proposed a simple learning principle, MixUp, to reduce the memorization of corrupt labels and the sensitivity to adversarial examples of large deep neural networks. Berthelot et al. [18] unified the mainstream approaches to semi-supervised learning in MixMatch, which guesses low-entropy labels for unlabeled examples and uses MixUp to mix labeled and unlabeled data. Laine et al. [19] introduced self-ensembling, in which the outputs of the network at different stages of training are combined to form a consensus prediction of the unknown labels. However, since the targets change only once per epoch, this temporal ensembling becomes unwieldy when learning large datasets. To overcome this problem, Tarvainen et al. [20] proposed Mean Teacher, in which the weights of the teacher model are an exponential moving average of the student model weights obtained at each training step. Iscen et al. [21] proposed a label propagation method based on transductive learning, which assigns pseudo-labels to unlabeled data using a k-nearest neighbor graph. Our method builds on this approach and improves it with a cumulative confidence weight. The experimental results demonstrate that this improvement mitigates the problems caused by label noise.

Although SSL has been evaluated on various tasks, few studies have considered its application to image aesthetic prediction. Image aesthetic prediction is highly subjective and complex, and annotating aesthetic labels is time-consuming and error-prone. To reduce the reliance on manual annotation, it is crucial to develop SSL methods that lessen the dependence on labeled data. Therefore, in this paper, we propose a theme-aware semi-supervised method whose performance is comparable to that of fully supervised methods.

#### **3. Methodology**

In this section, we first describe preliminary details and the overall architecture of our method. Then, we introduce each module in detail.

#### *3.1. Preliminaries*

In semi-supervised image aesthetic quality prediction, a dataset can be expressed as $X := (x_1, \dots, x_l, x_{l+1}, \dots, x_n)$. The dataset contains $l$ labeled examples and $u = n - l$ unlabeled examples. The labeled examples $x_i$ for $i \in L := \{1, \dots, l\}$, denoted by $X_L$, are labeled according to $Y_L := (y_1, \dots, y_l)$ with $y_i \in C$, where $C := \{1, \dots, c\}$ is a discrete label set for $c$ classes. The remaining unlabeled examples are denoted as $X_U := (x_{l+1}, \dots, x_n)$. The goal of semi-supervised learning (SSL) is to use all examples $X$ and the labels $Y_L$ to train a classifier that maps previously unseen samples to class labels.

In supervised learning, the network is trained by minimizing the following supervised loss term:

$$L\_s(X\_{L}, Y\_{L}; \theta) := \sum\_{i=1}^{l} loss(f\_{\theta}(x\_i), y\_i), \tag{1}$$

where $\theta$ denotes the network parameters and $f_\theta$ is the forward function of the network.

The supervised loss applies only to the labeled data in $X_L$. For classification, the standard choice is the cross-entropy (CE) loss, given by

$$\mathrm{loss}(p, y) := -\sum_{j=1}^{c} y_j \log p_j, \tag{2}$$

where $y$ is the one-hot label vector and $p$ is the vector of predicted class probabilities (the softmax of the network logits).

In semi-supervised learning, pseudo-labeling is the process of using a model trained on the labeled data to assign labels to the unlabeled data. The additional pseudo-label loss term is defined as follows:

$$L_p(X_U, Y_U; \theta) := \sum_{i=l+1}^{n} \mathrm{loss}(f_\theta(x_i), y_i), \tag{3}$$

where $Y_U := (y_{l+1}, \dots, y_n)$ denotes the collection of pseudo-labels for $X_U$, and $\mathrm{loss}$ can be any supervised loss function, such as cross-entropy.
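The loss terms in Equations (1) to (3) can be illustrated with a minimal NumPy sketch. It is not the authors' PyTorch implementation: it assumes the network already outputs class probabilities (after softmax), computes the per-example cross-entropy by summing over the $c$ classes, and uses a toy binary low/high-quality setting with made-up probabilities.

```python
import numpy as np


def cross_entropy(p, y):
    """Eq. (2): CE between predicted class probabilities p and a one-hot label y."""
    p = np.clip(p, 1e-12, 1.0)  # avoid log(0)
    return float(-np.sum(y * np.log(p)))


def supervised_loss(probs, labels):
    """Eq. (1): sum of per-example CE losses.

    probs:  (m, c) predicted class probabilities (after softmax)
    labels: (m,) integer class indices
    """
    onehot = np.eye(probs.shape[1])[labels]
    return sum(cross_entropy(p, y) for p, y in zip(probs, onehot))


def pseudo_label_loss(probs_u, pseudo_labels):
    """Eq. (3): same form as Eq. (1), applied to the unlabeled examples
    x_{l+1}, ..., x_n with their pseudo-labels y_{l+1}, ..., y_n."""
    return supervised_loss(probs_u, pseudo_labels)


# Toy binary setting (class 0 = low quality, class 1 = high quality; assumed).
probs_l = np.array([[0.9, 0.1], [0.2, 0.8]])
labels_l = np.array([0, 1])
probs_u = np.array([[0.6, 0.4]])
pseudo_u = np.array([0])
total = supervised_loss(probs_l, labels_l) + pseudo_label_loss(probs_u, pseudo_u)
print(f"L_s + L_p = {total:.4f}")
```

Because Equation (3) has the same form as Equation (1), the pseudo-label term can reuse the supervised loss unchanged; only the source of the labels differs.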

#### *3.2. Overall Architecture*

An overview of the proposed framework is illustrated in Figure 1. Training is divided into two steps, a representation learning step and a label propagation step, which are applied alternately. In the representation learning step, we train the theme-aware attention network in a fully supervised fashion on the $l$ labeled examples. The theme-aware attention network generates two outputs: an embedding output $\hat{f}_v$ and a category prediction output. In the label propagation step, we construct a k-nearest neighbor graph from the embedding output $\hat{f}_v$ and perform label propagation on the training set. The known labels $Y_L$ are propagated from $X_L$ to $X_U$, creating pseudo-labels $Y_U$. Then, we estimate confidence scores reflecting the uncertainty of each unlabeled example; these scores are used as loss weights in the representation learning step. Finally, we inject the obtained pseudo-labels into the representation learning step. By iteratively applying the label propagation and representation learning steps, our model builds a good underlying representation and trains an accurate classifier for the image aesthetic prediction task.

**Figure 1.** Overall architecture of our theme-aware semi-supervised image aesthetic quality assessment. In step 1, we train our theme-aware attention network (TAAN) on a small amount of labeled data in a supervised fashion. In step 2, we use label propagation with a cumulative confidence algorithm (LPCC) to infer pseudo-labels for the unlabeled data transductively: we extract the features of the entire training set, compute a k-nearest neighbor graph, propagate labels by transductive learning, and then train the TAAN on the entire training set. The two steps are applied iteratively. At test time, the input image is fed directly into the trained TAAN model to obtain the predicted aesthetic quality. Details of the LPCC algorithm are given in Algorithm 1.

#### *3.3. Theme-Aware Attention Network*

In recent years, the attention mechanism has been shown to be effective in capturing important information from raw features in both linguistic and visual representations [22]. Unlike previous aesthetic assessment approaches, we propose a theme-aware attention mechanism that encodes theme features jointly with visual features. Inspired by the success of self-attention, the proposed theme-aware attention module captures the complex interactions between the theme features and different spatial locations in the input image.

The pipeline of the proposed theme-aware attention network (TAAN), shown in Figure 2, consists of three parts: an image feature extractor (backbone), a self-attention-based theme encoder and a residual connection module. Given an input image, the feature extractor first extracts high-level features. These features are then sent to the self-attention-based theme encoder. Finally, the visual features and the theme-based features are combined by the residual connection module.

**Algorithm 1** Label propagation with cumulative confidence.

**Figure 2.** Details of our theme-aware attention network (TAAN). The TAAN consists of an image feature extractor (backbone), a self-attention-based theme encoder and a residual connection module.

The image feature extractor is an 18-layer residual network [23] pretrained on ImageNet [24]. Images in the AVA dataset have not only semantic tag information (such as Macro, Animals and Portraiture) but also challenge information (such as Fairy Tales, Flowers, Black and White, and Street Photography). Both the tag information and the challenge information encode the theme of an image. We therefore convert the tag and challenge information into one-hot codes and process these codes with a fully connected layer to extract the theme features. Given the extracted visual feature $f_v$ and the theme features $f_{tag}$ and $f_{challenge}$, the self-attention-based theme encoder first produces queries, keys and values by linear transformations: $q_1 = W_{q_1} f_{tag}$, $q_2 = W_{q_2} f_{challenge}$, $k = W_k f_v$ and $v = W_v f_v$, where $W_{q_1}$, $W_{q_2}$, $W_k$ and $W_v$ are learnable model parameters. The tag-theme-based attention and the challenge-theme-based attention are then computed as follows:

$$\begin{aligned} \alpha\_{tag} &= \operatorname{Softmax}(q\_1^T k) \\ \alpha\_{\text{challenge}} &= \operatorname{Softmax}(q\_2^T k), \end{aligned} \tag{4}$$

where $\alpha_{tag}$ and $\alpha_{challenge}$ denote the tag-theme-based attention and the challenge-theme-based attention, respectively. The final theme-attentive features $\hat{v}$ are then computed as follows:

$$
\hat{v} = \alpha_{tag} \times v + \alpha_{challenge} \times v \tag{5}
$$

We then combine the theme-attentive features with visual features via a residual connection. This allows the insertion of the proposed module into any backbone network without disrupting its initial behavior. The operations can be defined as follows:

$$\hat{f}_v = \hat{v} + f_v \tag{6}$$

where $\hat{v}$ denotes the theme-attentive features, $f_v$ is the extracted visual feature, and $\hat{f}_v$ denotes the theme-attentive features after the residual connection.
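The following NumPy sketch shows one possible reading of Equations (4) to (6) for a single image. It assumes that $f_v$ is the backbone feature map flattened to $N$ spatial positions with $d$ channels, that "×" in Equation (5) scales the value vector at each position by that position's attention weight, and that the theme features come from one-hot codes passed through a linear layer. The weight names (`W_tag`, `W_chal`, `W_q1`, and so on) and the toy sizes are hypothetical.

```python
import numpy as np


def softmax(x):
    # Numerically stable softmax over the last axis.
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)


def theme_aware_attention(f_v, tag_code, chal_code, P):
    """Sketch of Eqs. (4)-(6) for one image.

    f_v:       (N, d) backbone features at N spatial positions (assumed layout)
    tag_code:  (n_tags,) one-hot tag code
    chal_code: (n_chal,) one-hot challenge code
    P:         dict of weight matrices (hypothetical names)
    """
    # Theme features: one-hot codes passed through a fully connected layer.
    f_tag = tag_code @ P["W_tag"]        # (d,)
    f_chal = chal_code @ P["W_chal"]     # (d,)

    # Queries come from the theme features; keys and values from the image.
    q1 = f_tag @ P["W_q1"]               # (d,)
    q2 = f_chal @ P["W_q2"]              # (d,)
    k = f_v @ P["W_k"]                   # (N, d)
    v = f_v @ P["W_v"]                   # (N, d)

    # Eq. (4): one attention weight per spatial position.
    alpha_tag = softmax(k @ q1)          # (N,)
    alpha_chal = softmax(k @ q2)         # (N,)

    # Eq. (5): re-weight the value vector at each position.
    v_hat = alpha_tag[:, None] * v + alpha_chal[:, None] * v

    # Eq. (6): residual connection back to the visual features.
    return v_hat + f_v, alpha_tag, alpha_chal


# Toy sizes: a 7x7 feature map with d = 8 channels (assumed for illustration).
rng = np.random.default_rng(0)
N, d, n_tags, n_chal = 49, 8, 66, 10
shapes = {"W_tag": (n_tags, d), "W_chal": (n_chal, d),
          "W_q1": (d, d), "W_q2": (d, d), "W_k": (d, d), "W_v": (d, d)}
P = {name: rng.normal(scale=0.1, size=shape) for name, shape in shapes.items()}
f_v = rng.normal(size=(N, d))
tag_code = np.eye(n_tags)[3]
chal_code = np.eye(n_chal)[5]
f_hat, a_tag, a_chal = theme_aware_attention(f_v, tag_code, chal_code, P)
```

Setting `W_v` to zero turns this module into an identity map on $f_v$, which illustrates why a residual connection lets such a module be inserted into a backbone without disrupting its initial behavior.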

#### *3.4. Label Propagation with Cumulative Confidence Algorithm*

Label propagation is an iterative process for semi-supervised learning. More specifically, we first construct a nearest neighbor graph and perform label propagation on the whole training set. Then, we calculate an entropy weight reflecting the uncertainty of label propagation for each unlabeled example. Inspired by [25], we argue that the results of earlier propagation rounds should also act as a constraint, so we propose a cumulative confidence weight to improve traditional label propagation [21]. Finally, we inject the obtained pseudo-labels into the network training process. The method is described in detail below and summarized in Algorithm 1.

**K-nearest neighbor search for the graph.** Given an image feature matrix $\hat{f}_v$ of dimensions $(n, dim)$, we first calculate the similarity between every pair of points (either the Euclidean distance or the cosine similarity can be used).

**Create the adjacency matrix of the graph.** Each point is connected to its $k$ nearest neighbors, with the similarity as the edge weight; all other edge weights are set to 0. A sparse affinity matrix $A \in \mathbb{R}^{n \times n}$ is constructed as follows:

$$a_{ij} = \begin{cases} [f_{v_i}^{T} f_{v_j}]^{\gamma}, & \text{if } i \neq j \land f_{v_i} \in \mathrm{KNN}(f_{v_j});\\ 0, & \text{otherwise,} \end{cases} \tag{7}$$

where $\mathrm{KNN}$ denotes the set of the $k$ nearest neighbors in $X$, and $\gamma$ is a parameter, following work on manifold-based search [26]. This yields the adjacency matrix $A$.
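A minimal NumPy sketch of the affinity matrix in Equation (7) follows. For clarity it builds a dense $n \times n$ array and uses cosine similarity on L2-normalized features. Clipping negative similarities to zero before raising them to the power $\gamma$ is an assumption made here so that edge weights stay nonnegative, and the function name is hypothetical. The paper uses $k = 50$ and $\gamma = 3$; a practical implementation for hundreds of thousands of images would store $A$ as a sparse matrix and use a dedicated nearest-neighbor search library.

```python
import numpy as np


def knn_affinity(feats, k, gamma):
    """Sketch of Eq. (7): k-NN affinity matrix (dense here for clarity).

    feats: (n, dim) embeddings. Rows are L2-normalized so that the dot
           product equals the cosine similarity.
    Column j holds the weights of the k nearest neighbors of point j.
    """
    n = feats.shape[0]
    assert 0 < k < n, "k must be smaller than the number of points"
    x = feats / np.linalg.norm(feats, axis=1, keepdims=True)
    sim = x @ x.T                        # (n, n) pairwise cosine similarities
    np.fill_diagonal(sim, -np.inf)       # enforce i != j
    A = np.zeros((n, n))
    for j in range(n):
        # Indices of the k most similar points to point j.
        nbrs = np.argpartition(-sim[:, j], k)[:k]
        # Assumption: negative similarities are clipped to 0 before the power.
        A[nbrs, j] = np.maximum(sim[nbrs, j], 0.0) ** gamma
    return A


rng = np.random.default_rng(0)
feats = rng.normal(size=(20, 4))
A = knn_affinity(feats, k=5, gamma=3)
```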

**Normalize the graph.** Because the full affinity matrix is intractable for large $n$, we use the sparse k-NN graph defined above. This graph is not necessarily symmetric: node $a$ may be a k-nearest neighbor of node $b$ while node $b$ is not a k-nearest neighbor of node $a$. We therefore symmetrize it into a true undirected graph, as defined in Equation (8). We then normalize the adjacency matrix $A$ symmetrically, as in the normalized graph Laplacian, to obtain $A^*$, defined in Equation (9):

$$A = A + A^T,\tag{8}$$

$$A^\* = D^{-1/2} A D^{-1/2} ,\tag{9}$$

where $A$ is the adjacency matrix, $D := \mathrm{diag}(A\mathbf{1}_n)$ is the degree matrix of $A$, with $\mathbf{1}_n$ the all-ones $n$-vector, and $A^*$ is the normalized adjacency matrix.
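Equations (8) and (9) can be sketched in a few lines of NumPy. The helper below is hypothetical, not the authors' code; giving zero rows to isolated nodes (degree 0) is an assumption added to avoid division by zero.

```python
import numpy as np


def normalize_graph(A):
    """Sketch of Eqs. (8)-(9): symmetrize, then A* = D^{-1/2} A D^{-1/2}."""
    A = A + A.T                          # Eq. (8): undirected graph
    deg = A.sum(axis=1)                  # diagonal of D = diag(A 1_n)
    # Assumption: isolated nodes (degree 0) get zero rows and columns.
    inv_sqrt = np.where(deg > 0, 1.0 / np.sqrt(np.maximum(deg, 1e-12)), 0.0)
    return inv_sqrt[:, None] * A * inv_sqrt[None, :]   # Eq. (9)


# A directed 3-node chain 0 -> 1 -> 2 becomes the undirected path 0 - 1 - 2.
A = np.array([[0.0, 1.0, 0.0],
              [0.0, 0.0, 1.0],
              [0.0, 0.0, 0.0]])
A_star = normalize_graph(A)
print(np.round(A_star, 4))
```

For a symmetric nonnegative $A$, the eigenvalues of $A^*$ lie in $[-1, 1]$; this is what makes $I - \alpha A^*$ positive-definite for $0 \le \alpha < 1$ in the diffusion step.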

**Diffusion for transductive learning [27].** The $n \times c$ label matrix $Y$ is defined with elements:

$$Y_{ij} = \begin{cases} 1, & \text{if } i \in L \land y_i = j;\\ 0, & \text{otherwise,} \end{cases} \tag{10}$$

where $L$ is the index set of the labeled data. Thus, the rows of $Y$ corresponding to labeled examples are one-hot encoded labels, and all remaining elements are zero. The diffusion process is equivalent to solving the linear system

$$(I - \alpha A^*)Z = Y, \tag{11}$$

where $\alpha$ is a tunable diffusion parameter and $I$ is the identity matrix. Because the eigenvalues of the symmetrically normalized matrix $A^*$ lie in $[-1, 1]$, the matrix $(I - \alpha A^*)$ is positive-definite for $0 \le \alpha < 1$, so we can use the conjugate gradient (CG) method to solve the linear system. This solution is known to be faster than the iterative solution. Finally, we infer the pseudo-labels:

$$Z^* = \mathrm{normalize}\big((I - \alpha A^*)^{-1} Y\big) \tag{12}$$

$$Y_U = \operatorname{argmax}(Z^*) \tag{13}$$

where $Z^*$ is the row-wise normalized counterpart of $Z$, the argmax is taken over each row, and $Y_U$ denotes the predicted pseudo-labels.
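The diffusion step in Equations (10) to (13) can be sketched as follows: build the one-hot label matrix $Y$, solve $(I - \alpha A^*)Z = Y$ column by column with a plain conjugate gradient solver, normalize each row of $Z$ to sum to one, and take the row-wise argmax. The value $\alpha = 0.99$, the clipping of small negative CG values, the uniform fallback for rows that receive no label mass, and all function names are assumptions for illustration; the default of 20 CG iterations matches the setting reported in Section 4.2.

```python
import numpy as np


def conjugate_gradient(M, b, iters=20, tol=1e-10):
    """Plain CG for a symmetric positive-definite M; approximately solves M x = b."""
    x = np.zeros_like(b)
    r = b - M @ x
    p = r.copy()
    rs = r @ r
    for _ in range(iters):
        if rs < tol:                     # squared residual norm is small enough
            break
        Mp = M @ p
        step = rs / (p @ Mp)
        x += step * p
        r -= step * Mp
        rs_new = r @ r
        p = r + (rs_new / rs) * p
        rs = rs_new
    return x


def propagate_labels(A_star, labels, labeled_idx, c, alpha=0.99, iters=20):
    """Sketch of Eqs. (10)-(13); returns (Z_star, pseudo_labels) for all n nodes."""
    n = A_star.shape[0]
    Y = np.zeros((n, c))                 # Eq. (10): one-hot rows for labeled data
    Y[labeled_idx, labels] = 1.0
    M = np.eye(n) - alpha * A_star       # positive-definite for 0 <= alpha < 1
    # Eq. (11): one CG solve per class column.
    Z = np.stack([conjugate_gradient(M, Y[:, j], iters) for j in range(c)], axis=1)
    Z = np.maximum(Z, 0.0)               # assumption: clip tiny negative CG values
    # Eq. (12): row-wise normalization; rows with no label mass become uniform.
    row_sum = Z.sum(axis=1, keepdims=True)
    Z_star = np.where(row_sum > 0, Z / np.maximum(row_sum, 1e-12), 1.0 / c)
    return Z_star, Z_star.argmax(axis=1)  # Eq. (13)


# Toy graph: triangles {0, 1, 2} and {3, 4, 5} joined by a weak edge 2-3.
W = np.zeros((6, 6))
for i, j in [(0, 1), (0, 2), (1, 2), (3, 4), (3, 5), (4, 5)]:
    W[i, j] = 1.0
W[2, 3] = 0.1
W = W + W.T
deg = W.sum(axis=1)
A_star = W / np.sqrt(np.outer(deg, deg))
Z_star, pseudo = propagate_labels(A_star, labels=np.array([0, 1]),
                                  labeled_idx=np.array([0, 5]), c=2)
print(pseudo)
```

On this toy graph, the label on node 0 spreads through its own triangle and the label on node 5 through the other, so the weak bridge does not cause the classes to mix.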

**Entropy weight.** We need to evaluate the reliability of the predicted pseudo-labels. First, we consider the credibility within a single round. Each row of the normalized prediction matrix $Z^*$ gives the predicted probability of each class for the corresponding sample. We consider predictions with small entropy more credible and predictions with large entropy less credible, so the weight is calculated as follows:

$$
\omega = 1 - \frac{H(Z^\*)}{\log c} \tag{14}
$$

where $H(\cdot)$ denotes the entropy of each row of $Z^*$, $Z^*$ is the row-wise normalized counterpart of $Z$, and $c$ is the number of classes, so $\log c$ is the maximum possible entropy.
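Equation (14) translates directly into NumPy. The helper below is a hypothetical sketch that treats $0 \log 0$ as 0 by clipping probabilities before taking the logarithm.

```python
import numpy as np


def entropy_weight(Z_star):
    """Sketch of Eq. (14): w_i = 1 - H(z_i) / log(c) for each row z_i of Z*."""
    c = Z_star.shape[1]
    p = np.clip(Z_star, 1e-12, 1.0)      # avoid log(0); 0 * log(0) contributes 0
    H = -np.sum(Z_star * np.log(p), axis=1)
    return 1.0 - H / np.log(c)


Z_star = np.array([[1.0, 0.0],   # one-hot row: confident prediction
                   [0.5, 0.5],   # uniform row: maximally uncertain
                   [0.9, 0.1]])
w = entropy_weight(Z_star)
print(np.round(w, 3))
```

A one-hot row receives weight 1 and a uniform row receives weight 0, so confidently propagated labels dominate the pseudo-label loss.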

**Cumulative confidence weight.** To improve the fault tolerance and reliability of label propagation, we propose a second weight, the cumulative confidence weight $F_{conf}$. We maintain an array $F_{pre}$ that records the average of the previous predictions. $F_{pre}$ reflects the reliability of the prediction (a higher $F_{pre}$ means higher reliability). $F_{conf}$ measures the similarity between $F_{pre}$ and the pseudo-labels in each epoch and is multiplied directly with the entropy weight. We designed three similarity functions (Section 4.3), and the one deployed in the final architecture is selected manually. $F_{conf}$ is calculated as follows:

$$F_{conf} = \mathrm{similarity}(Y_U, F_{pre}) \tag{15}$$

$$F_{pre} = \frac{epoch \times F_{pre} + Y_U}{epoch + 1} \tag{16}$$

where $Y_U$ denotes the pseudo-labels of the unlabeled data. The final weighted pseudo-label loss is then calculated as follows:

$$L_p(X_U, Y_U; \theta) := \sum_{i=l+1}^{n} \mathrm{loss}(f_\theta(x_i), y_i) \times \omega_i \times F_{conf}^{i}, \tag{17}$$

where $X_U$ denotes the unlabeled examples, $Y_U$ denotes their pseudo-labels, and $\omega_i$ and $F_{conf}^{i}$ denote the entropy weight and the cumulative confidence weight of example $i$, respectively.
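The cumulative confidence weight (Equations (15) and (16)) and the weighted loss (Equation (17)) can be sketched for binary pseudo-labels in {0, 1} as follows. Several details are not specified in the text and are assumptions here: the distance is taken as the absolute difference $|Y_U - F_{pre}|$, the square similarity function is used, $F_{conf}$ is computed against the running mean before that mean is updated, and every example receives full confidence in the first epoch, when no history exists. The class and function names are hypothetical.

```python
import numpy as np


def square_similarity(y_u, f_pre):
    # Square similarity, 1 - d^2.
    # Assumption: d is the absolute difference so that it acts as a distance.
    d = np.abs(y_u - f_pre)
    return 1.0 - d ** 2


class CumulativeConfidence:
    """Sketch of Eqs. (15)-(16) for binary pseudo-labels in {0, 1}."""

    def __init__(self, n_unlabeled):
        self.f_pre = np.zeros(n_unlabeled)   # running mean of past pseudo-labels
        self.epoch = 0

    def step(self, y_u):
        y_u = np.asarray(y_u, dtype=float)
        if self.epoch == 0:
            # Assumption: no history yet, so every pseudo-label gets full confidence.
            f_conf = np.ones_like(y_u)
        else:
            f_conf = square_similarity(y_u, self.f_pre)                  # Eq. (15)
        self.f_pre = (self.epoch * self.f_pre + y_u) / (self.epoch + 1)  # Eq. (16)
        self.epoch += 1
        return f_conf


def weighted_pseudo_label_loss(probs_u, y_u, w, f_conf):
    """Eq. (17): per-example CE scaled by entropy weight and cumulative confidence."""
    p = np.clip(probs_u[np.arange(len(y_u)), y_u], 1e-12, 1.0)
    return float(np.sum(-np.log(p) * w * f_conf))


cc = CumulativeConfidence(n_unlabeled=3)
conf0 = cc.step(np.array([1, 0, 1]))      # epoch 0
y_u = np.array([1, 0, 0])                 # the third pseudo-label flips at epoch 1
conf1 = cc.step(y_u)
probs_u = np.array([[0.2, 0.8], [0.7, 0.3], [0.6, 0.4]])
w = np.array([1.0, 0.5, 1.0])             # entropy weights (toy values)
loss = weighted_pseudo_label_loss(probs_u, y_u, w, conf1)
```

In this toy run, the third example flips its pseudo-label between epochs, so its cumulative confidence drops to 0 and it contributes nothing to the weighted loss, while its running average moves to 0.5.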

#### **4. Experiments**

#### *4.1. Datasets*

**AVA.** The Aesthetic Visual Analysis (AVA) dataset [28] is a large-scale database for image aesthetic quality assessment. Its images were crawled from www.DPChallenge.com (accessed on 5 May 2022). It contains more than 255,000 images, each rated by 78 to 549 voters on a scale from 1 to 10. The AVA dataset provides 66 kinds of semantic tags and 1409 kinds of style tags. Each image has 0 to 2 semantic tags and belongs to one specific challenge theme. We follow the official dataset partition of [28], with 235,508 randomly selected images as the training set and 20,000 images as the testing set.

**Photo.net.** The Photo.net dataset [1] contains about 20,278 images. Unlike the AVA dataset, it contains only aesthetic labels. Each image is rated by at least 10 voters on a scale from 1 to 7. For some images, only the mean score and standard deviation are given, and the individual voting information is lost. Because the website has been updated several times, only 17,253 images can still be downloaded. The Photo.net dataset contains no theme information; thus, similar to previous work [10], we use Photo.net only as a test set.

**CUHK.** CUHK [2] is a small-scale dataset whose images are clearly divided into high-quality and low-quality classes. We only use photos for which there is a clear consensus on their quality. The images of this dataset were also crawled from www.DPChallenge.com (accessed on 5 May 2022). About 3000 images (half of the photos) are used for testing. For the same reason as for the Photo.net dataset, we only use the CUHK test set to evaluate our model.

#### *4.2. Implementation Details*

We implemented our method using the PyTorch framework. We used the Adam optimizer with $\beta_1 = 0.9$ and $\beta_2 = 0.999$ and a learning rate of $1 \times 10^{-5}$. Our experiments were run on an NVIDIA GeForce RTX 3080 Ti GPU.

**Networks.** We used several backbone networks in our experiments. For VGG, ResNet and DenseNet, we used the implementations provided in the Torchvision project [29]. For Swin-T, we used the implementation provided at https://github.com/WZMIAOMIAO/deep-learning-for-image-processing (accessed on 5 May 2022). The input image size was [3, 224, 224]. The output feature dimension was 512 for ResNet18, ResNet34 and VGG16; 2048 for ResNet50, ResNet101 and ResNet152; and 768 for Swin-T. The flattened feature was used as the image feature vector.

**Hyper-parameters.** We trained for 10 epochs in step one (i.e., the representation learning step) and 20 epochs in step two (i.e., the label propagation step). Step two uses the embedding output $\hat{f}_v$ of step one to infer the pseudo-labels. For step one, the mini-batch size depends on the depth of the backbone network (usually 32 or 64). For step two, mini-batches are drawn by a two-stream sampler consisting of a labeled data sampler and an unlabeled data sampler. The unlabeled data sampler guarantees that all unlabeled data are traversed, while the labeled data sampler iterates over the labeled data repeatedly. The total mini-batch size is $B = B_l + B_u$, where $B_l$ is the labeled mini-batch size and $B_u$ is the unlabeled mini-batch size; $B_l$ is usually half of $B$. In our TAAN network, we set the scale factor $\alpha = 1$. In our LPCC algorithm, the diffusion parameters were set as follows: $\gamma = 3$, $k = 50$ and 20 CG iterations.
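The two-stream sampling used in step two can be sketched in plain Python. This hypothetical generator assumes $B_l = \lfloor B/2 \rfloor$, reshuffles the labeled indices on every pass, and ends the epoch once every unlabeled index has been yielded exactly once; it only illustrates the indexing logic, not the authors' data loader.

```python
import itertools
import random


def two_stream_batches(labeled_idx, unlabeled_idx, batch_size, seed=0):
    """Yield (labeled, unlabeled) index batches for one pass over the unlabeled data.

    B = B_l + B_u with B_l = B // 2 (assumed). The unlabeled stream is traversed
    exactly once; the labeled stream cycles indefinitely, reshuffled per pass.
    """
    rng = random.Random(seed)
    b_l = batch_size // 2
    b_u = batch_size - b_l

    def endless_labeled():
        while True:
            order = list(labeled_idx)
            rng.shuffle(order)
            yield from order

    lab_stream = endless_labeled()
    unl = list(unlabeled_idx)
    rng.shuffle(unl)
    for start in range(0, len(unl), b_u):
        yield list(itertools.islice(lab_stream, b_l)), unl[start:start + b_u]


# 5 labeled and 20 unlabeled examples, B = 8 (B_l = B_u = 4).
batches = list(two_stream_batches(range(5), range(100, 120), batch_size=8))
```

Because the labeled set is much smaller than the unlabeled set, the labeled stream wraps around several times per epoch, which is why it is made endless while the unlabeled stream determines the epoch length. In a PyTorch pipeline the same logic would typically live in a custom sampler.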

#### *4.3. Ablation Studies*

**Effectiveness of the theme-aware attention network.** The proposed method employs themes as privileged information to improve the performance. To evaluate the performance of our proposed theme-aware attention network, we compared the proposed module with the following models:


The comparison results are shown in Table 1. To prove the effectiveness of the proposed module, we tested it in both a fully supervised setting and a semi-supervised setting. From the table, we make the following observations. First, the proposed ResNet18 + TAAN had the best performance. For example, ResNet18 + TAAN achieved 76.6% accuracy under full supervision, while the other two models achieved 76.28% and 76.32%, respectively. Similar results were found for the semi-supervised learning method. Second, ResNet18 + Theme outperformed ResNet18 under both the fully supervised and the semi-supervised method, which demonstrates the effectiveness of the theme information. Third, ResNet18 + TAAN performed better than ResNet18 + Theme, which demonstrates the superiority of the attention mechanism. We attribute this to the attention mechanism allowing the visual features and theme features to interact fully.

**Table 1.** Accuracy (%) of different modules. For the semi-supervised method, the value is the accuracy of step 2 (*δ* = 1).


**Effectiveness of cumulative confidence weight.** We propose a cumulative confidence weight to improve fault tolerance and estimate the reliability of the samples. We tested three different similarity functions for the cumulative confidence weight: a linear function, a square function and a sigmoid function. We first define the distance

$$d = Y_U - F_{pre} \tag{18}$$

where $Y_U$ are the pseudo-labels of all the data items, $F_{pre}$ is the average of the previous predictions, and $d$ is the distance between the current predicted pseudo-labels $Y_U$ and the average previous prediction $F_{pre}$. The linear function is defined as follows:

$$
\text{similarity}\_{linear} = 1 - d \tag{19}
$$

The square function is defined as follows:

$$
\text{similarity}\_{square} = 1 - d^2 \tag{20}
$$

The sigmoid function is defined as follows:

$$
\text{similarity}_{sigmoid} = 1 - \frac{1}{1 + e^{(0.5-d)\times\lambda}} \tag{21}
$$

where $\lambda$ controls the slope of the sigmoid function. To separate the predicted values sharply into two categories, we use $\lambda = 10$ in our final method. Table 2 presents the comparison results. The baseline model in Table 2 does not include a cumulative confidence weight. From the table, we draw the following conclusions. First, adding a cumulative confidence weight results in better performance. For example, the baseline model achieved 75.01%, whereas adding a cumulative confidence weight (using the linear similarity function) yielded at least 75.96%. Second, the square similarity function performed slightly better than the other two similarity functions. Thus, in this paper, we use the square function as the similarity function for the cumulative confidence weight.
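The three similarity functions can be compared numerically with the short sketch below. It reads the sigmoid variant as the logistic curve $1 - 1/(1 + e^{(0.5 - d)\lambda})$, which is an assumption about the intended form, and assumes $d \in [0, 1]$; the function names are hypothetical.

```python
import math


def similarity_linear(d):
    return 1.0 - d                       # Eq. (19)


def similarity_square(d):
    return 1.0 - d ** 2                  # Eq. (20)


def similarity_sigmoid(d, lam=10.0):
    # Eq. (21), read as a logistic curve centred at d = 0.5 (assumed form).
    return 1.0 - 1.0 / (1.0 + math.exp((0.5 - d) * lam))


for d in (0.0, 0.25, 0.5, 0.75, 1.0):
    print(f"d={d:.2f}  linear={similarity_linear(d):.3f}  "
          f"square={similarity_square(d):.3f}  sigmoid={similarity_sigmoid(d):.3f}")
```

Under these assumptions, the square function stays close to 1 for small disagreements (0.9375 at $d = 0.25$, versus 0.75 for the linear function) and falls off only as $d$ approaches 1, while the sigmoid with $\lambda = 10$ behaves almost like a hard threshold at $d = 0.5$.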


**Table 2.** Accuracy (%) of different similarity strategies in the cumulative confidence algorithm (*δ* = 1).

#### *4.4. Experiments on Different Label Rates*

To evaluate how well the proposed model exploits unlabeled images, we trained it under different labeling rates. As Table 3 shows, with 90% of the labels missing (i.e., a labeling rate of 10%), step one achieved 73.86% accuracy. With the help of unlabeled images, our model improved the accuracy to 76.12% in step two. This demonstrates that the proposed method consistently benefits from additional unlabeled images. Similar results were found for the other labeling rates, 5% and 2%. Figure 3 shows the t-SNE visualization of the embedding output $\hat{f}_v$ under different labeling rates. Purple dots represent unlabeled images, yellow dots represent labeled low-quality images and green dots represent labeled high-quality images. Two observations can be made from the figure. First, under all three labeling rates, our method clusters the unlabeled data (purple) together with the labeled data, which makes the graph-based LPCC algorithm straightforward to apply. Second, the learned features remain discriminative under all three labeling rates.

**Figure 3.** t-SNE visualization of the features on the test set for labeling rates of 0.02 (**left**), 0.05 (**middle**) and 0.1 (**right**). Purple dots represent unlabeled images, yellow dots represent labeled low-quality images and green dots represent labeled high-quality images.


**Table 3.** Accuracy (%) under different labeling rates.


#### *4.5. Extension to Different Backbones*

Our model can use a variety of feature extractors, so we tested different pretrained models as backbones: VGG16, ResNet18 [23], ResNet34, ResNet50, ResNet101, DenseNet121 [30] and Swin Transformer-T [31], at a labeling rate of 0.05 on the AVA dataset. All networks were pretrained on ImageNet [24]. The performance of the different feature extractors is given in Table 4.


**Table 4.** Accuracy (%) on different backbones. For the semi-supervised method, the value is the accuracy of step 2.

As the table shows, accuracy increases with the complexity of the model. Figure 4 illustrates the embedding features $\hat{f}_v$ obtained with different backbones; the features extracted by stronger backbones are clearly more discriminative.

**Figure 4.** t-SNE visualization of the fc-features of ResNet18 (**top left**), ResNet34 (**top right**), ResNet50 (**bottom left**) and ResNet101 (**bottom right**) on the test set. Purple dots represent low-quality images and yellow dots represent high-quality images.

#### *4.6. Performance Evaluation*

To demonstrate the effectiveness of our method, we performed a comparative evaluation with existing approaches on the AVA dataset. Note that the existing methods assume full supervision, whereas our method is semi-supervised. We selected several mainstream methods for comparison. The source code of [4–6,9] is unavailable and their experimental details are not reported, so a faithful reimplementation is infeasible; we therefore report the results given in the original papers. For methods with published code, such as [7,32,33], we evaluated the models on the same data (5% labeling rate) to obtain the results in Table 5.


**Table 5.** Comparison with state-of-the-art methods on AVA dataset.

The methods we compared were as follows:


The experimental results are shown in Table 5, from which we make two observations. First, the accuracy of our semi-supervised method can reach, or even exceed, that of some fully supervised models. For example, MTCNN [9] achieved 75.9% accuracy, while our method achieved 76.82% with a labeling rate of only 5%. Second, our method outperforms the other models when they are trained at the same labeling rate. For example, MPA [32] and NIMA [7] achieved 70.52% and 74.87% accuracy, respectively, while our method achieved 76.82%. We attribute this difference to the fact that the lack of labeled data degrades the performance of the other models, while our model improves its performance by exploiting a large quantity of unlabeled data.

#### *4.7. Experimental Results on Photo.net and CUHK Dataset*

Tables 6 and 7 show the comparison results for the Photo.net and CUHK datasets, respectively. As stated earlier, the Photo.net and CUHK datasets are both small and have no theme information. Thus, we trained the model on the AVA dataset and tested it on the Photo.net and CUHK datasets. We used the published PyTorch code of NIMA [7] and MUSIQ [33] to train these models at the same 5% labeling rate and compared them with our method on the Photo.net and CUHK datasets in Tables 6 and 7. The tables show that our method outperformed the compared methods, which we attribute to its use of a large quantity of unlabeled data. This also indicates that our model generalizes well to different datasets.

**Table 6.** Comparison with state-of-the-art methods on the Photo.net dataset.


**Table 7.** Comparison with state-of-the-art methods on the CUHK dataset.


#### *4.8. Discussion of Experiment on Labeled Data Sensitivity*

To examine whether the proposed method is sensitive to the choice of labeled data, we randomly divided the labeled data at a 5% labeling rate into five groups: splits 1, 2, 3, 4 and 5. We trained our model on each split and recorded the best accuracy. The results are shown in Table 8. Regardless of which split was used, the accuracy did not fluctuate significantly. We therefore conclude that our model is insensitive to the selection of labeled data.

**Table 8.** Experiment on sensitivity analysis. Split 1, 2, 3, 4 and 5 are random labeled data splits under labeling rate 5%. The best accuracy (%) of each split is recorded.


#### *4.9. Computational Complexity*

#### 4.9.1. Theoretical Analysis

Our training was divided into two steps. In the first step, we trained our theme-aware attention network (TAAN) on a small quantity of labeled data in a supervised fashion. In the second step, we used the label propagation with cumulative confidence algorithm (LPCC) to infer pseudo-labels for the unlabeled data transductively. The two steps were performed iteratively. Because label propagation is generally regarded as computationally expensive, our theoretical analysis focuses on the complexity of label propagation.

The computational complexity of traditional label propagation is dominated by the KNN search and the construction of the graph. Suppose the dataset contains n samples. Without any optimization, the KNN search has complexity O(*n* × *n*), because for each of the n query vectors the search must traverse all n features to find the k most similar ones. The number of floating-point operations required by a vector dot product is proportional to the vector dimension. Denoting the vector dimension by m, the computational cost of the KNN search is

$$FLOPs = n \times n \times m \tag{22}$$
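As a point of reference for Equation (22), the sketch below performs an exact (brute-force) KNN search with NumPy. It assumes that similarity is the inner product of feature vectors (cosine similarity when the rows are L2-normalised); the function name `brute_force_knn` is hypothetical and not taken from the authors' code.

```python
import numpy as np


def brute_force_knn(features: np.ndarray, k: int):
    """Exact k-nearest-neighbour search by inner product.

    Every one of the n vectors is compared with every vector, so the
    similarity matrix alone costs about n * n * m multiply-adds (Eq. (22)).
    """
    n, m = features.shape
    sims = features @ features.T          # (n, n) dot products
    np.fill_diagonal(sims, -np.inf)       # a vector is not its own neighbour
    # Sort each row by descending similarity and keep the k best indices.
    neighbours = np.argsort(-sims, axis=1)[:, :k]
    flops = n * n * m                     # cost model of Eq. (22)
    return neighbours, flops


rng = np.random.default_rng(0)
x = rng.standard_normal((1000, 128))
x /= np.linalg.norm(x, axis=1, keepdims=True)   # make inner product = cosine
idx, flops = brute_force_knn(x, k=10)
```

The quadratic `sims` matrix is exactly the term that the inverted file and product quantization below try to avoid.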

Because this cost is high, we use the inverted file index (IVF) and product quantization (PQ) in the Faiss library to reduce the computational complexity of label propagation.

**Using the inverted file index (IVF) to optimize the KNN search:** we index the entire dataset by clustering it into several subspaces. To query a vector, we first determine the subspace to which the query vector belongs and then search only within that subspace. Suppose that the average subspace size is $1/s$ of the original space size; the computational complexity of the KNN search is then reduced to:

$$FLOPs = \frac{n \times n \times m}{s} \tag{23}$$
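The inverted-file idea can be sketched in NumPy as follows. This is a simplified illustration, not Faiss's implementation (Faiss offers the idea as, for example, `IndexIVFFlat`): the coarse clustering is plain k-means, only the single nearest list is probed, and the helper names `build_ivf` and `ivf_search` are hypothetical. With s lists of roughly equal size, each query scans about n/s vectors instead of n, which is the saving expressed by Equation (23).

```python
import numpy as np


def build_ivf(features, n_lists, n_iter=10, seed=0):
    """Cluster the database into `n_lists` subspaces (inverted lists)."""
    rng = np.random.default_rng(seed)
    start = rng.choice(len(features), n_lists, replace=False)
    centroids = features[start].copy()
    for _ in range(n_iter):
        # Assign every vector to its nearest centroid (squared L2 distance).
        d = ((features[:, None, :] - centroids[None, :, :]) ** 2).sum(-1)
        assign = d.argmin(axis=1)
        for c in range(n_lists):
            members = features[assign == c]
            if len(members):            # keep the old centroid if a list empties
                centroids[c] = members.mean(axis=0)
    # Final assignment with the final centroids.
    d = ((features[:, None, :] - centroids[None, :, :]) ** 2).sum(-1)
    assign = d.argmin(axis=1)
    inverted_lists = [np.flatnonzero(assign == c) for c in range(n_lists)]
    return centroids, inverted_lists


def ivf_search(query, features, centroids, inverted_lists, k):
    """Search only the list whose centroid is closest to `query`."""
    c = ((centroids - query) ** 2).sum(-1).argmin()
    ids = inverted_lists[c]             # about n / s candidates instead of n
    d = ((features[ids] - query) ** 2).sum(-1)
    order = np.argsort(d)[:k]
    return ids[order], d[order]


rng = np.random.default_rng(0)
db = rng.standard_normal((2000, 64))
cents, lists = build_ivf(db, n_lists=10)
ids, dists = ivf_search(db[0], db, cents, lists, k=5)
```

Probing only one list is what makes the search approximate: a true neighbour that falls in an adjacent list is missed, which is why practical IVF indexes usually let several lists be probed.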

**Using product quantization (PQ) to further optimize the KNN search:** the details of product quantization are illustrated in Figure 5. As shown in Figure 5, we assume that the vector dimension m is 128 (so the whole dataset is *n* × 128). We split each vector into four 32-dimensional sub-vectors and cluster the n sub-vectors in each of the four columns into 256 classes. The sub-vectors of each data item are then represented by the indices of four class centers (e.g., [12, 45, 240, 48]), so each vector can be stored in four bytes (one byte per index). The distance table must be computed in advance; building it requires 4 × 256 × 32 floating-point operations, which is independent of *n*. Once the table is built, the *n* × 4 distance computations for a query become table lookups, which take much less time than floating-point multiplications; we account for this by dividing by a constant *c*. The final computational complexity can be reduced to:

$$FLOPs = \frac{n \times (n \times 4 + 4 \times 256 \times 32)}{s \times c} \tag{24}$$

**Figure 5.** Details of Product Quantization.
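A minimal NumPy sketch of the product-quantization scheme described above, with m = 128 split into four 32-dimensional sub-vectors and 256 centroids per sub-space, is given below. It uses asymmetric distance computation (the query stays uncompressed, the database is stored as PQ codes) and plain k-means codebooks; these choices and all function names are assumptions for illustration rather than the authors' Faiss configuration.

```python
import numpy as np


def sq_dists(a, b):
    """Pairwise squared L2 distances between the rows of a and of b."""
    return (a ** 2).sum(1)[:, None] - 2.0 * a @ b.T + (b ** 2).sum(1)[None, :]


def train_pq(x, n_sub=4, n_centroids=256, n_iter=5, seed=0):
    """Learn one k-means codebook per column of sub-vectors."""
    rng = np.random.default_rng(seed)
    sub_dim = x.shape[1] // n_sub
    codebooks = []
    for j in range(n_sub):
        xs = x[:, j * sub_dim:(j + 1) * sub_dim]
        cb = xs[rng.choice(len(xs), n_centroids, replace=False)].copy()
        for _ in range(n_iter):
            assign = sq_dists(xs, cb).argmin(1)
            for c in range(n_centroids):
                members = xs[assign == c]
                if len(members):
                    cb[c] = members.mean(0)
        codebooks.append(cb)
    return codebooks


def encode(x, codebooks):
    """Replace each sub-vector by the index of its nearest centroid."""
    sub_dim = codebooks[0].shape[1]
    codes = [sq_dists(x[:, j * sub_dim:(j + 1) * sub_dim], cb).argmin(1)
             for j, cb in enumerate(codebooks)]
    return np.stack(codes, axis=1).astype(np.uint8)   # one byte per sub-vector


def adc_distances(query, codes, codebooks):
    """Build the (n_sub, 256) distance table once per query
    (n_sub * 256 * sub_dim work), then only look values up and sum them."""
    sub_dim = codebooks[0].shape[1]
    table = np.stack([sq_dists(query[None, j * sub_dim:(j + 1) * sub_dim], cb)[0]
                      for j, cb in enumerate(codebooks)])
    # n * n_sub table lookups in total for this query
    return table[np.arange(len(codebooks)), codes].sum(1)


rng = np.random.default_rng(0)
data = rng.standard_normal((2000, 128))
books = train_pq(data)               # 4 codebooks of shape (256, 32)
codes = encode(data, books)          # (2000, 4) uint8, i.e. 4 bytes per vector
approx = adc_distances(data[0], codes, books)
```

The returned values are exact distances from the query to the PQ reconstructions of the database vectors, so they are approximations of the true distances whose quality depends on the codebooks.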

When *n* is particularly large (large-scale data), the term 4 × 256 × 32 can be ignored. With *s* = 10, *c* = 5 and *m* = 512, the IVF + PQ algorithm can in theory be 6400 times faster than brute-force search. To verify the theoretical analysis, we measured the running time of one epoch on the whole AVA dataset. The running times of each step of our method are reported in Table 9, which shows that the time required for label propagation is negligible compared with the time spent on network training.
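The 6400× figure can be reproduced with a few lines of arithmetic. The sketch assumes, as the cost model implies, that each of the n queries performs n × 4 table lookups in 1/s of the data, that a lookup is c times cheaper than a floating-point operation, and that the per-query table-building term is dropped for large n. Under these assumptions the speed-up reduces to m × s × c / 4 and does not depend on n. The function names are hypothetical.

```python
def brute_force_cost(n: int, m: int) -> float:
    """Equation (22): n queries, each compared with n vectors of dimension m."""
    return n * n * m


def ivf_pq_cost(n: int, s: float, c: float, n_sub: int = 4) -> float:
    """IVF + PQ: n queries x (n * n_sub lookups), restricted to 1/s of the
    data and made c times cheaper by table lookups; the constant
    table-building term is ignored for large n."""
    return n * n * n_sub / (s * c)


def speedup(n: int, m: int, s: float, c: float, n_sub: int = 4) -> float:
    return brute_force_cost(n, m) / ivf_pq_cost(n, s, c, n_sub)


print(speedup(n=100_000, m=512, s=10, c=5))   # -> 6400.0
```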

**Table 9.** Running time of one epoch for each step of our method. For training, we used the whole training set of the AVA dataset; for inference, we used the whole test set of the AVA dataset. Step 1 DNN denotes the deep neural network pass of step 1, and Step 2 LP denotes the label propagation of step 2. The items explain in more detail what is done in each step. FP denotes the forward pass and BP denotes back propagation.


#### 4.9.2. Inference Computational Cost Comparison

We measured running time to compare the computational cost of different methods. Because this requires running each model, we only compared against methods with published code, such as NIMA, MPA and MUSIQ. Table 10 reports the time of a single image forward pass for each method. Our PyTorch and TensorFlow inference implementations were tested on an Intel i7-11700H @ 2.5 GHz CPU (8 cores) with 32 GB of memory and an NVIDIA RTX 3080 Ti GPU. The table shows that our method has a running time similar to that of NIMA and MPA when using the same ResNet18 backbone.

**Table 10.** Comparison of image forward pass running time for different methods (ResNet18 backbone).


#### **5. Conclusions and Discussion**

In this paper, we propose a theme-aware semi-supervised architecture for image aesthetic quality assessment that aims to reduce the dependence on image label annotation and to make full use of the large number of unlabeled images available online. To address the noisy-label problem encountered during label propagation, we propose a cumulative confidence algorithm that improves on the traditional label propagation algorithm. Applied to the image aesthetic quality assessment task, it achieved satisfactory results. We also found that our theme-aware architecture addresses the problem of theme sensitivity in image aesthetic quality assessment. The experimental results show that our method is robust to different labeling rates, different selections of labeled data, and different datasets.

Although our method achieves promising results, several issues remain for future research. First, we will investigate how to use the EMD loss in the label propagation algorithm to improve the accuracy of semi-supervised learning. Second, to exploit the collaborative attention between images and other information, such as user comments, we will seek better solutions from a multimodal perspective. We will also explore new semi-supervised algorithms, such as curriculum learning, to improve the existing label propagation algorithms.

**Author Contributions:** Investigation, G.L.; Methodology, X.Z. (Xiaodan Zhang) and X.Z. (Xun Zhang); Resources, G.L.; Software, X.Z. (Xun Zhang) and Y.X.; Validation, G.L.; Writing—original draft, X.Z. (Xiaodan Zhang); Writing—review & editing, X.Z. (Xiaodan Zhang). All authors have read and agreed to the published version of the manuscript.

**Funding:** This work was supported in part by the National Natural Science Foundation of China under Grant Nos. 62001385 and 62002290, in part by the Key R&D Program of Shaanxi under Grant 2021ZDLGY15-03, and in part by the China Postdoctoral Science Foundation (Grant No. 2021MD703883).

**Institutional Review Board Statement:** Not applicable.

**Informed Consent Statement:** Not applicable.

**Data Availability Statement:** The data presented in this study are available in [1,2,28].

**Conflicts of Interest:** The authors declare that they have no known competing financial interests or personal relationships that could have appeared to influence the work reported in this paper.

#### **References**


### *Article* **An Improved Soft-YOLOX for Garbage Quantity Identification**

**Junran Lin, Cuimei Yang, Yi Lu, Yuxing Cai, Hanjie Zhan and Zhen Zhang \***

School of Computer Science and Engineering, Huizhou University, Huizhou 516007, China; hzurang@gmail.com (J.L.); meikoyoung@gmail.com (C.Y.); ly97264833@gmail.com (Y.L.); hzucyx@gmail.com (Y.C.); 1914080902532@stu.hzu.edu.cn or zhanhanjie123@gmail.com (H.Z.) **\*** Correspondence: zzsjbme@sjtu.edu.cn; Tel.: +86-182-1726-7715

**Abstract:** Urban waterlogging is mainly caused by garbage clogging sewer manhole covers. Detecting the amount of garbage at a sewer manhole cover, and issuing an early warning signal when the amount is large enough, is therefore of great significance for preventing urban waterlogging. Based on the YOLOX algorithm, this paper identifies manhole covers and garbage and builds a flood control system that automatically recognizes and monitors the accumulation of garbage. The system can also display statistical results and send early warning information. During garbage identification, occluded garbage can cause missed detections and inaccurate counts. To reduce missed detections as much as possible and improve the performance of the detection model, we use Soft-YOLOX, a new detection model for counting, which prevents missed detections by reasonably reducing, rather than eliminating, the scores of adjacent detection frames. Soft-YOLOX improves the accuracy of garbage counting: compared with the traditional YOLOX, its mAP for garbage identification increased from 89.72% to 91.89%.

**Keywords:** garbage quantity identification; YOLOX; NMS; Soft-NMS

**MSC:** 68T45

#### **1. Introduction**

As science and technology develop alongside the gradual improvement of urban construction, image recognition is increasingly applied to urban informatization and digital construction. Among recognition methods, the YOLOX detection framework is particularly well known. It is widely used in urban object detection [1], pedestrian detection [2], and other applications because of its fast response and high precision.

To improve the efficiency of urban management and sanitation, modern cities use a grid management method [3] that divides the city into regions and assigns sanitation workers to clean up garbage in each region. Workers only need to inspect and clean their assigned areas regularly to ensure basic hygiene throughout the city. However, this regionalized management is still insufficient. A fixed allocation of personnel cannot adjust the number of workers to the dynamic changes in the amount of garbage in each area. As a result, garbage may accumulate in some areas that lack sanitation workers, while workers elsewhere have little to do. Regular inspections cannot keep up with real-time changes in the amount of garbage, and if the workers do not clean in time, garbage accumulates. Over time, such accumulation becomes a hidden flood hazard, and dealing with urban floods has long remained a difficult problem.

With the help of target detection technology and urban public surveillance systems, real-time monitoring is increasingly used to maintain various areas, roads, and sewer manholes. The monitoring capability of an urban flood control system depends heavily on the timely identification of the type and quantity of garbage and on data collection. An efficient urban flood control system can help sanitation workers check and clean up potentially dangerous areas such as sewer manhole covers. To improve the city's ability to prevent floods and waterlogging while reducing the workload of sanitation workers, we designed a method based on the YOLOX detection framework that reflects garbage accumulation in flood control areas. By analyzing monitoring images and identifying the type and quantity of garbage on manhole covers, the method improves the efficiency of urban sanitation cleaning, flood control, and waterlogging prevention. However, when garbage is occluded by other garbage, detection suffers from inaccurate counting and missed detections, and a well-designed detection scheme can solve such problems. To further improve detection accuracy, reduce the missed detection rate, and issue early warnings more precisely, we propose a new detection method, Soft-YOLOX.

**Citation:** Lin, J.; Yang, C.; Lu, Y.; Cai, Y.; Zhan, H.; Zhang, Z. An Improved Soft-YOLOX for Garbage Quantity Identification. *Mathematics* **2022**, *10*, 2650. https://doi.org/10.3390/math10152650

Academic Editors: Jianping Gou, Weihua Ou, Shaoning Zeng and Lan Du

Received: 20 June 2022 Accepted: 25 July 2022 Published: 28 July 2022

**Publisher's Note:** MDPI stays neutral with regard to jurisdictional claims in published maps and institutional affiliations.

**Copyright:** © 2022 by the authors. Licensee MDPI, Basel, Switzerland. This article is an open access article distributed under the terms and conditions of the Creative Commons Attribution (CC BY) license (https://creativecommons.org/licenses/by/4.0/).

#### **2. Related Work**

Garbage counting is a basic application scenario of target detection [4], and many machine learning methods have been proposed to solve target detection and counting problems. Traditional machine learning examples include multi-vehicle counting algorithms based on Haar features [5], SVMs based on HOG [6] and LBP [7] features, and others. These traditional target detection algorithms rely mainly on manual feature extraction: features are first extracted from the image, a classifier is then built on them, and the desired target is finally obtained. However, most of these traditional algorithms lack high accuracy and good generalization ability.

With the continuous development of artificial intelligence, deep learning for image recognition [8] has become relatively mature [9]. Great achievements have been made in face recognition [10], medical image recognition [11], remote sensing image classification [12], ImageNet classification and recognition, traffic recognition, and character recognition. After large-scale training, deep learning models can extract image features and perform image classification [13] and recognition, which makes deep learning a very effective and general technology for target detection. Current deep learning target detection algorithms fall mainly into three categories, which differ in whether the algorithm uses region proposals. The first category is multi-stage algorithms, such as R-CNN [14] and SPPNet [15]. The second is two-stage algorithms, such as Fast R-CNN [16], Faster R-CNN [17], Mask R-CNN [18], and Light-Head R-CNN [19]. The third is single-stage algorithms, including YOLOV1 [20], YOLOV2 [21], YOLOV3 [22], SSD [23], Retina U-Net [24], CenterNet [25], FSAF [26], FCOS [27], YOLOV4 [28], and YOLOX [29]. Multi-stage and two-stage algorithms achieve outstanding detection performance, but their detection speed in practical applications is lower than that of single-stage algorithms. Single-stage algorithms recognize quickly, but their accuracy is lower, and they still miss targets that are occluded. Our goal is therefore to improve the YOLOX model so that it addresses these problems.

In traditional YOLOX, non-maximum suppression (NMS) sets to 0 the scores of detection frames that overlap the highest-scoring frame too strongly (the case where two or more adjacent detection frames contain similar targets), so the final output misses some target objects. This mechanism causes missed detections that reduce detection accuracy [30]. The Soft-NMS algorithm instead attenuates the scores of such adjacent detection frames rather than setting them directly to 0; as long as the final score of an adjacent frame remains above a certain threshold, the frame is retained in the output. We call the improved YOLOX, which uses Soft-NMS instead of NMS, Soft-YOLOX. With Soft-NMS, the mAP of YOLOX reached 91.89%, 2.17 percentage points higher than with the original NMS, and the real-time detection speed reached 15.46 FPS. To further verify the improvement, we also compared Soft-YOLOX with other target detection algorithms, such as YOLOV4, Fast R-CNN, SSD, and YOLOV5. The mAP values in this comparison show that Soft-YOLOX has better detection performance and a lower missed detection rate. We make the following contributions:


### **3. Methods**

#### *3.1. YOLOX and NMS Algorithms*

The core of the YOLOX target detection algorithm is the YOLOX-CSPDarknet53 network structure, shown in Figure 1. We divide the YOLOX-CSPDarknet53 network into four parts: input, backbone network, neck, and prediction.

**Figure 1.** YOLOX-CSPDarknet53 network structure diagram.

	- a. YOLOX uses a decoupled head. Unlike the heads of earlier YOLO-series detectors, the decoupled head of YOLOX consists of two branches that are implemented separately and integrated at the final prediction;

The original YOLOX model uses NMS to keep, among the detection frames of the same category in a given area, only the frame with the highest score. However, because NMS considers only the detection frames and their IOU (Intersection over Union), its elimination mechanism is very rigid and easily leads to missed detections. Figure 2 shows a missed detection of a target object.

**Figure 2.** The situation of missed detection using NMS.

Figure 2 shows an incorrect result. There are three leaves and a box on the drain cover, but after NMS processing one of the leaves is not detected. The prediction therefore does not match reality and does not meet our expectations.

Accurate counting depends on successfully detecting every target. When target objects occlude one another, missed detections occur easily. We therefore replaced NMS in the original YOLOX model with Soft-NMS to solve this problem.

#### *3.2. Principle of Soft-NMS Algorithm*

Mathematically, NMS removes redundant frames according to the following rule:

$$\text{score}_i = \begin{cases} 0, & \text{IOU}(M, b_i) \ge \text{threshold of IOU} \\ \text{score}_i, & \text{IOU}(M, b_i) < \text{threshold of IOU} \end{cases} \tag{1}$$

Here, $\text{score}_i$ is the score of the current detection frame $b_i$, and $M$ is the detection frame with the highest score. After multiple tests on our dataset, we found that the best IOU threshold is 0.5.
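Equation (1) can be sketched in plain Python as follows. Boxes are assumed to be (x1, y1, x2, y2) tuples of a single class, and the helper names `iou` and `nms` are hypothetical; library implementations such as `torchvision.ops.nms` perform a similar hard suppression in vectorized form.

```python
def iou(a, b):
    """Intersection over Union of two boxes given as (x1, y1, x2, y2)."""
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / union if union > 0 else 0.0


def nms(boxes, scores, iou_threshold=0.5):
    """Hard NMS, Eq. (1): zero the score of every box whose IoU with the
    current highest-scoring box M is >= iou_threshold. Returns kept indices."""
    scores = list(scores)
    remaining = list(range(len(boxes)))
    keep = []
    while remaining:
        m = max(remaining, key=lambda i: scores[i])   # current top box M
        keep.append(m)
        remaining.remove(m)
        for i in remaining:
            if iou(boxes[m], boxes[i]) >= iou_threshold:
                scores[i] = 0.0                        # suppressed outright
        remaining = [i for i in remaining if scores[i] > 0.0]
    return keep


# Two heavily overlapping leaves and one separate object
boxes = [(0, 0, 10, 10), (1, 0, 11, 10), (50, 50, 60, 60)]
scores = [0.9, 0.8, 0.7]
print(nms(boxes, scores))   # -> [0, 2]
```

The second leaf (IoU of about 0.82 with the first) is discarded entirely, which is the failure mode shown in Figure 2.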

During the experiment, we further found that when a detection frame overlaps the highest-scoring frame among all current detection frames with a high IOU, NMS reduces its score to 0 and deletes it from the candidate set. As in Figure 2, this is likely to cause missed detections. Soft-NMS solves this problem well; its mechanism for removing redundant frames is as follows:

$$\text{score}_i = \text{score}_i \, e^{-\frac{\text{IOU}(M, b_i)^2}{\theta}} \tag{2}$$

That is, when Soft-NMS encounters a detection frame with a high IOU relative to the highest-scoring frame, it does not set that frame's score directly to 0. Instead, Soft-NMS applies a penalty: it multiplies the score of the current detection frame by a weight function and assigns the product to the frame as its new score. We used a Gaussian function as the weight function (θ is its parameter; after repeated tuning, and following reference [30], we set θ to 0.1):

$$e^{-\frac{\text{IOU}(M, b_i)^2}{\theta}} \tag{3}$$

The more a detection frame overlaps the highest-scoring frame, the smaller its resulting score. Finally, only detection frames with scores greater than or equal to 0.5 are kept in the candidate set. In this way, Soft-NMS removes redundant detection frames while effectively reducing the missed detection rate. Figure 3 shows a flowchart of the Soft-NMS method.

**Figure 3.** Flowchart of Soft-NMS.

The main idea of Soft-NMS is as follows. First, find all detection frames in an image whose confidence exceeds a manually set threshold (frames below it are assumed to contain no target object). Then, group the detection frames belonging to the same class. Finally, put these detection frames into a set S.
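Putting Equations (2) and (3) and the steps above together, a minimal Gaussian Soft-NMS over the frames of one class (the set S) might look like the sketch below. The defaults θ = 0.1 and a final score threshold of 0.5 follow the values stated in the text; the box format and the helper names are assumptions for illustration.

```python
import math


def iou(a, b):
    """Intersection over Union of two boxes given as (x1, y1, x2, y2)."""
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / union if union > 0 else 0.0


def soft_nms(boxes, scores, theta=0.1, score_threshold=0.5):
    """Gaussian Soft-NMS over the frames of one class.

    Returns (index, rescored_value) pairs in the order they are selected.
    """
    scores = list(scores)                 # work on a copy
    remaining = list(range(len(boxes)))   # candidate set S
    kept = []
    while remaining:
        m = max(remaining, key=lambda i: scores[i])   # highest-scoring frame M
        remaining.remove(m)
        if scores[m] < score_threshold:
            break                         # every other score is lower still
        kept.append((m, scores[m]))
        for i in remaining:
            # Eq. (2): decay instead of zeroing; more overlap, stronger penalty.
            scores[i] *= math.exp(-iou(boxes[m], boxes[i]) ** 2 / theta)
    return kept


# Two partially overlapping leaves and a separate box
boxes = [(0, 0, 10, 10), (9, 0, 19, 10), (50, 50, 60, 60)]
result = soft_nms(boxes, [0.9, 0.8, 0.7])
```

With θ = 0.1 the decay is steep: a frame whose IoU with M is 0.5 keeps exp(-0.25/0.1) = exp(-2.5) ≈ 0.082 of its score, whereas a frame with IoU 0.05 keeps about 97% of it.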


Figure 4 shows the detection result after applying Soft-NMS to the image in Figure 2.

**Figure 4.** No missed detection after using Soft-NMS.

As Figure 4 shows, the result is now correct: the three leaves and the box on the drain cover are all detected, and the missed detection in Figure 2 has disappeared. The prediction matches reality and meets our expectations.

#### *3.3. System Framework*

Figure 5 shows the application hierarchy of the system built to address the problems studied in this paper.

**Figure 5.** Flowchart of the application deployed by YOLOX.


#### **4. Experimental Datasets and Evaluation Metrics**

The datasets in this paper came from a research group that used a camera to simulate a road surveillance camera in a specific road scene, capturing garbage near sewer manhole covers at different times and under various weather conditions from a roughly fixed angle. This allows the trained model to make predictions for different scenarios and improves its adaptability. Each video lasted from just over ten seconds to several minutes, and all videos had a resolution of 1365 × 1024. The videos were split into frames, which were divided into two parts, each containing data from different periods and circumstances. One part was used as the training and validation sets for training and validating the convolutional networks; the other was used to test the trained model.

The dataset was annotated in the Pascal VOC format, and each image was 1365 × 1024. Because the YOLOX model takes 640 × 640 input images, all images were preprocessed accordingly. The processed dataset contains 1800 images with different types of garbage, drainage covers, and road information across different periods, weather, and road sections. To make the training data more effective, the research team also extracted frames at different frame rates, which reduced the labeling workload and improved the learning efficiency of the model. Finally, we divided the dataset into training, validation, and test sets in a 7:2:1 ratio.
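The 7:2:1 division of the 1800 images can be expressed as a short helper. The text does not say whether the split was random; shuffling with a fixed seed is an assumption, and `split_dataset` and the file names are hypothetical.

```python
import random


def split_dataset(items, ratios=(7, 2, 1), seed=0):
    """Shuffle `items` and split them into train/val/test by integer ratios."""
    items = list(items)
    random.Random(seed).shuffle(items)       # reproducible shuffle
    total = sum(ratios)
    n_train = len(items) * ratios[0] // total
    n_val = len(items) * ratios[1] // total
    train = items[:n_train]
    val = items[n_train:n_train + n_val]
    test = items[n_train + n_val:]           # any remainder goes to the test set
    return train, val, test


images = [f"frame_{i:04d}.jpg" for i in range(1800)]   # hypothetical file names
train, val, test = split_dataset(images)
print(len(train), len(val), len(test))   # -> 1260 360 180
```

Because frames are extracted from videos, a purely random split can place near-identical consecutive frames in both training and test sets; splitting by video would avoid that, at the cost of a less even distribution of conditions.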

We used Precision, Recall, and mAP as the evaluation metrics for the model [33]. Precision and Recall are calculated by Formulas (4) and (5), respectively.

$$\text{Precision} = \frac{\text{TP}}{\text{TP} + \text{FP}} \tag{4}$$

$$\text{Recall} = \frac{\text{TP}}{\text{TP} + \text{FN}} \tag{5}$$

In these equations, TP (true positives) is the number of predictions correctly classified as positive, FP (false positives) is the number of predictions incorrectly classified as positive, and FN (false negatives) is the number of positive samples incorrectly classified as negative.
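For detection, TP, FP and FN are usually obtained by matching predicted boxes to ground-truth boxes at an IoU threshold. The sketch below uses a simple greedy matching at IoU ≥ 0.5 for one class in one image and then applies Formulas (4) and (5). The matching rule and the function names are assumptions, since the text does not specify them.

```python
def iou(a, b):
    """Intersection over Union of two boxes given as (x1, y1, x2, y2)."""
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / union if union > 0 else 0.0


def count_tp_fp_fn(preds, gts, iou_threshold=0.5):
    """preds: list of (score, box); gts: list of boxes. Greedy matching in
    descending score order; each ground-truth box can be matched only once."""
    matched = set()
    tp = fp = 0
    for _, box in sorted(preds, key=lambda p: p[0], reverse=True):
        best, best_iou = None, iou_threshold
        for j, gt in enumerate(gts):
            overlap = iou(box, gt)
            if j not in matched and overlap >= best_iou:
                best, best_iou = j, overlap
        if best is None:
            fp += 1                # no unmatched ground truth overlaps enough
        else:
            matched.add(best)
            tp += 1
    fn = len(gts) - len(matched)   # ground truths that were never detected
    return tp, fp, fn


def precision_recall(tp, fp, fn):
    precision = tp / (tp + fp) if tp + fp else 0.0   # Formula (4)
    recall = tp / (tp + fn) if tp + fn else 0.0      # Formula (5)
    return precision, recall


preds = [(0.9, (0, 0, 10, 10)), (0.8, (50, 50, 60, 60))]
gts = [(0, 0, 10, 10), (20, 20, 30, 30)]
print(precision_recall(*count_tp_fp_fn(preds, gts)))   # -> (0.5, 0.5)
```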

The evaluation dataset covers daytime and rainy conditions but not nighttime. Figure 6 shows the mAP values obtained with NMS and Soft-NMS.

**Figure 6.** The mAP values processed by NMS and Soft-NMS.

AP (average precision) combines Precision and Recall. Precision measures the proportion of predictions above the threshold that hit a real target, whereas Recall measures the proportion of real targets in the test set that are covered by the predictions. Combining the two gives a better evaluation of the model; AP is computed as the area under the precision-recall curve. The mAP is the mean of the per-category AP values. The higher the mAP, the better the prediction ability of the model.
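AP and mAP can be computed as in the sketch below, which uses the all-point interpolated area under the precision-recall curve (the PASCAL VOC 2010+ convention). The interpolation scheme, the per-detection TP flags as input, and the function names are assumptions; the text does not state which AP variant was used.

```python
def average_precision(scores, is_tp, n_gt):
    """All-point interpolated AP for one class.

    scores: confidence of each detection; is_tp: whether each detection was
    matched to a ground truth; n_gt: number of ground-truth objects.
    """
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    tp = fp = 0
    recall, precision = [0.0], [0.0]
    for i in order:                      # walk down the ranked detections
        if is_tp[i]:
            tp += 1
        else:
            fp += 1
        recall.append(tp / n_gt)
        precision.append(tp / (tp + fp))
    recall.append(1.0)
    precision.append(0.0)
    # Make precision non-increasing from right to left (upper envelope).
    for k in range(len(precision) - 2, -1, -1):
        precision[k] = max(precision[k], precision[k + 1])
    # Sum the rectangles under the envelope wherever recall increases.
    return sum((recall[k + 1] - recall[k]) * precision[k + 1]
               for k in range(len(recall) - 1))


def mean_average_precision(ap_per_class):
    """mAP: the mean of the per-category AP values."""
    return sum(ap_per_class.values()) / len(ap_per_class)


ap_leaf = average_precision([0.9, 0.8], [True, True], n_gt=2)   # perfect ranking
ap_box = average_precision([0.9, 0.8], [False, True], n_gt=1)   # FP ranked first
print(mean_average_precision({"leaf": ap_leaf, "box": ap_box}))   # -> 0.75
```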

#### **5. Results**

#### *5.1. Comparison of YOLOX and Soft-YOLOX*

To verify the effectiveness of our improvement, YOLOX and Soft-YOLOX were run with the same prediction parameters on the same datasets; the only difference is that YOLOX uses NMS, whereas Soft-YOLOX uses Soft-NMS. The detection performance of each model is measured by the evaluation metrics described above.

With NMS, the YOLOX model achieved an mAP of 89.72%, a Precision of 91.54%, and a Recall of 89.53%. With Soft-NMS, the Soft-YOLOX model achieved an mAP of 91.89%, a Precision of 92.93%, and a Recall of 88.42%. Table 1 compares the YOLOX model before and after the improvement.



Because Recall and Precision alone cannot comprehensively evaluate an algorithm, we used the mAP for analysis. As Table 1 shows, the mAP and Precision of Soft-YOLOX are higher than those of the original model, whereas its Recall is lower. Soft-NMS removes redundant detection frames through the penalty mechanism of the weight function, thereby reducing the missed detection rate. These results show that the Soft-NMS improvement was effective.

#### *5.2. Comparison with State-of-the-Art Methods*

The comparison methods were Fast R-CNN [16], a target detection algorithm based on YOLOV4 (abbreviated as YOLOV4 [34]), SSD [23], and a target detection algorithm based on YOLOV5 (abbreviated as YOLOV5 [35]). All methods were evaluated with the same metrics. Table 2 shows the detection results of each method on our dataset; Soft-YOLOX performs better than the other algorithms.



#### *5.3. Actual Application of the System*

The left side of Figure 7 shows that the system can detect the specific types and quantities of garbage in a complex garbage environment, which allows the system to send early warning signals. The right side of Figure 7 shows the system's real-time predictions on rainy days; the graphics card used for inference was the RTX 1060 Ti, and the detection speed was 15.46 FPS (frames per second). These results demonstrate the feasibility of the project and support our team's further research and development.

**Figure 7.** The application of the system.

Many YOLO-series algorithms have already been deployed on embedded devices, such as Fast YOLO [36], Efficient YOLO [37], and YOLO nano [38]. The YOLOX model in this paper could likewise be deployed on embedded devices by exporting it as an ONNX model, or by building a lightweight model through pruning and quantization.

#### **6. Conclusions**

Compared with other target detection models, the Soft-YOLOX detection model and counting method proposed in this paper have better detection performance and robustness and a lower missed detection rate. Garbage can be identified and counted accurately even under occlusion, which effectively avoids missed detections.

Using public surveillance cameras on urban roads, the system collects real-time images of sanitary conditions around urban sewer manhole covers. After the Soft-YOLOX model identifies, analyzes, and processes the data, the results are returned to and displayed on the client. As urban public facilities develop, the number of urban surveillance cameras and the area they cover will continue to increase. The resulting larger amount of image data can improve the accuracy of the model and the availability of the system. Better identification accuracy and processing capacity will in turn help urban sanitation construction and raise urban sanitation levels [39].

This paper proposed Soft-YOLOX, a new detection model based on YOLOX. Using Soft-NMS, it counts garbage accurately and achieves performance close to practical application requirements. Whereas the original YOLOX model removes redundant detection frames with the NMS algorithm, the proposed model penalizes the scores of detection frames with the Soft-NMS algorithm. The comparative analysis showed that Soft-YOLOX had higher accuracy and fewer missed detections in garbage detection applications. Its mAP was 91.89%, 2.17 percentage points higher than that of the YOLOX model. Soft-YOLOX is therefore better suited to detecting the quantity of accumulated garbage.

**Author Contributions:** Conceptualization, Z.Z. and J.L.; data curation, Y.L. and J.L.; validation, J.L., Y.L. and C.Y.; formal analysis, J.L., C.Y., Y.L., H.Z., Y.C. and Z.Z.; funding acquisition, Z.Z. and J.L.; investigation, J.L.; supervision, Z.Z. All authors have read and agreed to the published version of the manuscript.

**Funding:** This research was funded by [the Young Innovative Talents Project of colleges and universities in Guangdong Province] grant number [2021KQNCX092]; [Doctoral program of Huizhou University] grant number [2020JB028]; [Outstanding Youth Cultivation Project of Huizhou University] grant number [HZU202009]; [Innovation Training Program for Chinese College Students] grant number [S202110577044]; [Special Funds for the Cultivation of Guangdong College Students' Scientific and Technological Innovation. ("Climbing Program" Special Funds)] grant number [No. pdjh2022b0494].

**Institutional Review Board Statement:** Not applicable.

**Informed Consent Statement:** Not applicable.

**Data Availability Statement:** The dataset used in this paper was created by our team members, and part of it is available at the following link: [https://github.com/Hzurang/Dataset] (accessed on 1 July 2022).

**Conflicts of Interest:** The authors declare no conflict of interest.

#### **References**


### *Article* **Stability of Switched Systems with Time-Varying Delays under State-Dependent Switching**

**Chao Liu \*,† and Xiaoyang Liu †**

School of Computer Science and Engineering, Chongqing University of Technology, Chongqing 400054, China; lxy3103@cqut.edu.cn

**\*** Correspondence: 20140058@cqut.edu.cn or xiuwenzheng2000@163.com

† These authors contributed equally to this work.

**Abstract:** This paper studies the stability of linear switched systems with time-varying delays whose subsystems are all unstable. A state-dependent switching rule is designed based on the largest region function strategy. By introducing an integral inequality and multiple Lyapunov-Krasovskii functionals, stability results are derived for delayed switched systems with or without sliding motions under the designed state-dependent switching rule, for different assumptions on the time delay. Several numerical examples show the effectiveness and superiority of the proposed results.

**Keywords:** stability; switched system; state-dependent switching; time delay

**MSC:** 93D20; 93C10

#### **1. Introduction**

The dynamics of switched systems are affected by both the subsystems and the switching rule. For example, R. A. DeCarlo [1] pointed out that appropriate switching rules can make a switched system unstable (or asymptotically stable) even if all subsystems are asymptotically stable (or unstable). Therefore, both the subsystems and the switching rule must be taken into account when deriving stability results. In recent years, the stability of switched systems with unstable subsystems has been extensively investigated. For instance, in [2–7] stability results have been derived for switched systems with both stable and unstable subsystems. The main strategy in this literature is to ensure that the dwell time of the stable subsystems is long enough to compensate for the state divergence caused by the unstable subsystems and by switching behaviors. Clearly, if there is no stable subsystem to absorb the state divergence, the results in [2–7] are no longer applicable.

Because there are no stable subsystems, the stability analysis of switched systems whose subsystems are all unstable is more complicated. How to design switching rules that stabilize such systems has become an interesting and challenging problem. In general, switching rules are designed by one of two strategies: time-dependent switching or state-dependent switching. The first strategy exploits the stabilizing effect of switching behaviors to stabilize the switched system, and the resulting switching rules usually have both upper and lower bounds. In [8–12], time-dependent switching rules are designed to stabilize switched systems with or without time delay by using the discretized Lyapunov function approach or the bounded maximum average dwell time. The time-dependent switching strategy requires that switching behaviors have a stabilizing effect. Therefore, when no switching behavior has this stabilizing characteristic, the time-dependent switching strategy fails.

**Citation:** Liu, C.; Liu, X. Stability of Switched Systems with Time-Varying Delays under State-Dependent Switching. *Mathematics* **2022**, *10*, 2722. https://doi.org/10.3390/math10152722

Academic Editors: Jianping Gou, Weihua Ou, Shaoning Zeng, Lan Du and Asier Ibeas

Received: 22 June 2022 Accepted: 27 July 2022 Published: 1 August 2022

**Publisher's Note:** MDPI stays neutral with regard to jurisdictional claims in published maps and institutional affiliations.

**Copyright:** © 2022 by the authors. Licensee MDPI, Basel, Switzerland. This article is an open access article distributed under the terms and conditions of the Creative Commons Attribution (CC BY) license (https://creativecommons.org/licenses/by/4.0/).

In many instances, time-dependent switching rules that stabilize a switched system are hard to design or do not even exist, in which case state-dependent switching is the only way to stabilize the system. Up to now, state-dependent switching rules have been designed by two methods. The first method is based on a regional partition of the state space. Its basic idea can be summarized as follows: (a) divide the state space into different switching regions; (b) determine the index of the activated subsystem for each switching region; (c) derive stability conditions for the switched system under the designed switching rule. Under the assumption that there exists a Hurwitz convex combination of the system matrices, state-dependent switching rules have been designed via a regional partition of the state space, and significant stability results have been deduced by means of a common Lyapunov function (functional) in [13–19]. However, this assumption is a severe prerequisite. To relax it, a more flexible Hurwitz convex combination involving some free matrices is presented in [20]. In [21], the regional partition of the state space is implemented directly through the negative definiteness of the time derivative of a common Lyapunov functional, and one additional condition is introduced to ensure the strict completeness of the regional partition. Based on newly introduced symmetric matrices, Pettersson [22,23] defined switching rules via the largest region function strategy and established stability results by multiple Lyapunov functionals; some restrictions are also imposed to guarantee that the Lyapunov functional decreases when switching events occur. However, the largest region function strategy has not been generalized to switched systems with time delay. The second method defines the switching rule in terms of a set-valued function. A typical state-dependent switching rule of this kind is $\sigma(t) = \arg\min\{x^T(t)P\_1x(t), \dots, x^T(t)P\_mx(t)\}$, where $P\_i$ is a symmetric positive definite matrix and $m$ is the number of subsystems. In [24–27], switching rules are designed via the set-valued function, and stability conditions are given in terms of Lyapunov-Metzler inequalities.
Although there are numerous results on state-dependent switching, this issue still needs further study. Designing new state-dependent switching rules and obtaining less conservative stability results is the motivation of this research.

To date, the literature on the stability of delayed switched systems with state-dependent switching rules includes [15–21,27]. However, the assumption that there exists a Hurwitz convex combination of the system matrices is restrictive, which limits the applicability of the stability results presented in [15–20]. The additional condition on the strict completeness of the regional partition makes it difficult to obtain appropriate switching regions in [21]. Moreover, the results presented in [27] apply only to switched systems with constant delay. Therefore, the stability of switched systems with time-varying delays under state-dependent switching rules deserves further attention. The main objective of this paper is to derive new stability results for this problem. Based on the largest region function strategy, we design a state-dependent switching rule. By using an integral inequality and the Leibniz-Newton formula, novel asymptotic stability results under different assumptions on the time delay are presented in the form of bilinear matrix inequalities (BMIs). The effectiveness of the proposed results is shown via several numerical examples.

*Notations:* $A > 0$ ($A < 0$) means that $A$ is a symmetric positive (negative) definite matrix; $\mathbb{R}^n$ denotes the $n$-dimensional Euclidean space; $\arg\max S$ denotes the index of the maximum element of the ordered set $S$.

#### **2. Preliminaries**

This paper considers the following switched system with time-varying delay:

$$\begin{cases}
\dot{\mathbf{x}}(t) = A\_{\sigma(x(t))} \mathbf{x}(t) + B\_{\sigma(x(t))} \mathbf{x}(t - d(t)), t > 0, \\
\mathbf{x}(s) = \phi(s), s \in [-d, 0],
\end{cases} \tag{1}$$

where $x(t) \in \mathbb{R}^n$ is the state vector, $\sigma(x(t)) \in M = \{1, 2, \dots, m\}$ is the switching rule, $A\_p, B\_p \in \mathbb{R}^{n\times n}$, $p \in M$, are known matrices, $d(t)$ is the time-varying delay, and $\phi(s)$ is a piecewise continuous function. If $\sigma(t) = p$, we say that the $p$-th subsystem $\dot{x}(t) = A\_p x(t) + B\_p x(t - d(t))$ is activated.

**Remark 1.** *$\sigma(x(t))$ is a state-dependent switching rule generated by a switching device [13]. Similar to [13–23], we assume in this paper that the switching device introduces no delay. That is, the switching rule $\sigma(x(t))$ depends on the current state but not on the delayed state.*

We would like to design a state-dependent switching rule *σ*(*t*) such that switched system (1) is globally asymptotically stable. We employ the state-dependent switching strategy introduced in [22,23], which is based on the appropriate choice of symmetric matrices *Qp*, *p* ∈ *M*. Define the following regions

$$\begin{aligned} \Omega\_{p} &= \left\{ \mathbf{x} \in \mathbb{R}^n \mid \mathbf{x}^T Q\_{p} \mathbf{x} \ge 0 \right\}, \quad p \in M, \\ \Omega\_{pq} &= \left\{ \mathbf{x} \in \mathbb{R}^n \mid \mathbf{x}^T Q\_{p} \mathbf{x} = \mathbf{x}^T Q\_{q} \mathbf{x} \ge 0 \right\}, \quad p, q \in M,\ p \ne q. \end{aligned}$$

The intention is that the $p$-th subsystem is activated when $x(t) \in \Omega\_p$ and that switching events occur on the region $\Omega\_{pq}$. To ensure that the switched system (1) is well-defined, two properties should be satisfied [22]: the covering property and the switching property.


The covering property means that at least one subsystem can be activated in every region of the state space. The switching property means that a switch from subsystem $p$ to subsystem $q$ occurs only if the regions $\Omega\_p$ and $\Omega\_q$ are adjacent. According to [22,23], the covering property is satisfied if there exist $\theta\_p > 0$, $p \in M$, such that for any $x \in \mathbb{R}^n$,

$$\sum\_{p\in\mathcal{M}} \theta\_p \mathbf{x}^T Q\_p \mathbf{x} \ge 0. \tag{2}$$

The switching rule can be defined as the following largest region function strategy [22,23]

$$\sigma(\mathbf{x}(t)) = \arg\max \left\{ \mathbf{x}^T(t) Q\_1 \mathbf{x}(t), \dots, \mathbf{x}^T(t) Q\_m \mathbf{x}(t) \right\}. \tag{3}$$

It follows from [22] that if (2) holds and the switching rule (3) is used, the switching property is also satisfied.
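To make rule (3) concrete, the following minimal Python sketch, assuming NumPy, evaluates the largest region function at a given state together with the weighted sum in condition (2). The function names `switching_index` and `covering_margin` are illustrative, and ties on a switching surface $\Omega\_{pq}$ are broken toward the smallest index, which is a convention of this sketch rather than part of the rule.

```python
import numpy as np


def switching_index(x, Q_list):
    """Largest region function (3): 1-based index p maximizing x^T Q_p x.

    np.argmax returns the first maximizer, so ties (points on a switching
    surface) are resolved toward the smallest index.
    """
    values = [float(x @ Q @ x) for Q in Q_list]
    return int(np.argmax(values)) + 1


def covering_margin(x, Q_list, theta):
    """Weighted sum sum_p theta_p x^T Q_p x appearing in condition (2)."""
    return sum(t * float(x @ Q @ x) for t, Q in zip(theta, Q_list))


# Two-subsystem example with Q_1 = -Q_2, so (2) holds with theta = (1, 1).
Q1 = np.diag([1.0, -1.0])
Q2 = -Q1
x = np.array([2.0, 1.0])
p = switching_index(x, [Q1, Q2])                    # x^T Q1 x = 3 > -3, so p = 1
margin = covering_margin(x, [Q1, Q2], [1.0, 1.0])   # 0 because Q1 + Q2 = 0
```

When (2) holds with positive weights, the maximizing value $x^T Q\_p x$ cannot be negative, so the state selected by rule (3) always lies in the region $\Omega\_p$ of the activated subsystem.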

The main purpose of this work is to derive stability results under one of the following assumptions.

**Assumption 1.** *The time delay and its time derivative are bounded. Namely, there exist nonnegative constants $d$, $\bar{d}$ and a constant $\tilde{d}$ such that*

$$0 \le d(t) \le d, \quad \tilde{d} \le \dot{d}(t) \le \bar{d}. \tag{4}$$

**Assumption 2.** *The time delay is bounded. Namely, there exists a nonnegative constant d such that*

$$0 \le d(t) \le d.\tag{5}$$

The following lemma plays a key role in this research.

**Lemma 1** ([28])**.** *If the matrix $M > 0$ and the function $x : [a, b] \to \mathbb{R}^n$ is differentiable, then the following inequality holds:*

$$(b-a)\int\_{a}^{b} \dot{x}^T(s)M\dot{x}(s)\,ds \ge \beta^T \operatorname{diag}(M, 3M, 5M)\,\beta,$$

*where $\beta = \left[\beta\_1^T, \beta\_2^T, \beta\_3^T\right]^T$, $\beta\_1 = x(b) - x(a)$, $\beta\_2 = x(b) + x(a) - \frac{2}{b-a}\int\_a^b x(s)\,ds$, and $\beta\_3 = x(b) - x(a) + \frac{6}{b-a}\int\_a^b x(s)\,ds - \frac{12}{(b-a)^2}\int\_a^b\int\_\theta^b x(s)\,ds\,d\theta$.*
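Lemma 1 can be sanity-checked numerically. The following Python sketch, assuming NumPy, approximates both sides with the trapezoidal rule and returns their difference, which should be nonnegative up to quadrature error; the helpers `cumtrapz` and `lemma1_gap` are illustrative names, not part of the paper.

```python
import numpy as np


def cumtrapz(y, s):
    """Cumulative trapezoidal integral of y (shape (N, n)) from s[0] to each s[k]."""
    steps = 0.5 * (y[1:] + y[:-1]) * np.diff(s)[:, None]
    return np.vstack([np.zeros((1, y.shape[1])), np.cumsum(steps, axis=0)])


def lemma1_gap(x_fun, dx_fun, M, a, b, N=4001):
    """Return LHS - RHS of Lemma 1 for a differentiable x: [a, b] -> R^n."""
    s = np.linspace(a, b, N)
    X = np.array([x_fun(si) for si in s])    # samples of x(s), shape (N, n)
    dX = np.array([dx_fun(si) for si in s])  # samples of x'(s), shape (N, n)
    h = b - a
    quad = np.einsum("ki,ij,kj->k", dX, M, dX)[:, None]  # x'(s)^T M x'(s)
    lhs = h * cumtrapz(quad, s)[-1, 0]
    C = cumtrapz(X, s)                       # C[k] = int_a^{s_k} x(s) ds
    int_x = C[-1]                            # int_a^b x(s) ds
    inner = int_x - C                        # int_theta^b x(s) ds at theta = s_k
    dbl = cumtrapz(inner, s)[-1]             # int_a^b int_theta^b x(s) ds dtheta
    xa, xb = X[0], X[-1]
    b1 = xb - xa
    b2 = xb + xa - 2.0 / h * int_x
    b3 = xb - xa + 6.0 / h * int_x - 12.0 / h**2 * dbl
    rhs = b1 @ M @ b1 + 3.0 * b2 @ M @ b2 + 5.0 * b3 @ M @ b3
    return lhs - rhs
```

For example, with $x(s) = s^3$ on $[0, 1]$ and $M = 1$, both sides equal $9/5$, so `lemma1_gap(lambda s: np.array([s**3]), lambda s: np.array([3 * s**2]), np.eye(1), 0.0, 1.0)` is close to zero; for functions such as $\sin(3s)$ the gap is strictly positive.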

#### **3. Main Results**

This section presents the stability criteria for the switched system (1) under the state-dependent switching rule (3). Owing to the Leibniz-Newton formula, we have the following equation

$$
\mathbf{x}(t) - \mathbf{x}(t - d(t)) = \int\_{t - d(t)}^{t} \dot{\mathbf{x}}(s)\, ds. \tag{6}
$$

Some notations are given as follows

$$\upsilon\_1 = \frac{2}{d - d(t)} \int\_{t - d}^{t - d(t)} \mathbf{x}(s)\, ds, \quad \upsilon\_2 = \frac{12}{\left(d - d(t)\right)^2} \int\_{t - d}^{t - d(t)} \int\_{\theta}^{t - d(t)} \mathbf{x}(s)\, ds\, d\theta,$$

$$\eta(t) = \left(\mathbf{x}^T(t), \mathbf{x}^T(t - d(t)), \mathbf{x}^T(t - d), \dot{\mathbf{x}}^T(t - d(t)), \dot{\mathbf{x}}^T(t - d), \upsilon\_1^T, \upsilon\_2^T\right)^T.$$

**Theorem 1.** *Under Assumption 1, assume that for any $p \in M$ there exist $n \times n$ matrices $P\_p > 0$, $R\_i > 0$, $S\_i > 0$ $(i = 1, 2)$, $U > 0$, $Q\_p = Q\_p^T$, positive constants $\mu\_p$, $\theta\_p$, and constants $\eta\_{p,q}$, $q \in M$, $q \ne p$, such that*

$$
\begin{pmatrix}
\Lambda\_l^p + \mu\_p e\_1^T Q\_p e\_1 & \sqrt{d}\, e\_1^T P\_p B\_p \\
\sqrt{d}\, B\_p^T P\_p e\_1 & -U
\end{pmatrix} < 0, \quad l = 1, 2,
\tag{7}
$$

$$P\_p = P\_q + \eta\_{p,q} (Q\_q - Q\_p), \quad q \in M,\ q \neq p,\tag{8}$$

$$\sum\_{j \in M} \theta\_j Q\_j \ge 0,\tag{9}$$

*where*

$$\begin{aligned}
\Lambda\_1^p &= \Phi\_1^p + \Phi\_2 + \Phi\_3^p + \Phi\_4^p + (1 - \bar{d})(\Psi\_2 + \Psi\_3) - \frac{1}{d}\Xi\_4, \\
\Lambda\_2^p &= \Phi\_1^p + \Phi\_2 + \Phi\_3^p + \Phi\_4^p + (1 - \tilde{d})(\Psi\_2 + \Psi\_3) - \frac{1}{d}\Xi\_4, \\
\Phi\_1^p &= e\_1^T\left((A\_p + B\_p)^T P\_p + P\_p (A\_p + B\_p)\right)e\_1, \quad \Phi\_2 = e\_1^T R\_1 e\_1 - e\_3^T R\_2 e\_3, \\
\Phi\_3^p &= (A\_p e\_1 + B\_p e\_2)^T S\_1 (A\_p e\_1 + B\_p e\_2) - e\_5^T S\_2 e\_5, \\
\Phi\_4^p &= d\,(A\_p e\_1 + B\_p e\_2)^T U (A\_p e\_1 + B\_p e\_2), \\
\Psi\_2 &= e\_2^T (R\_2 - R\_1) e\_2, \quad \Psi\_3 = e\_4^T (S\_2 - S\_1) e\_4, \\
\Xi\_4 &= (e\_2 - e\_3)^T U (e\_2 - e\_3) + 3(e\_2 + e\_3 - e\_6)^T U (e\_2 + e\_3 - e\_6) \\
&\quad + 5(e\_2 - e\_3 + 3e\_6 - e\_7)^T U (e\_2 - e\_3 + 3e\_6 - e\_7), \\
e\_i &= \left[0\_{n\times(i-1)n},\ I,\ 0\_{n\times(7-i)n}\right], \quad i = 1, 2, \dots, 7.
\end{aligned}$$
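The block notation above can be assembled mechanically. The Python sketch below, assuming NumPy, builds $e\_i$, $\Xi\_4$ and $\Lambda\_l^p$ for given matrices; the names `selector`, `xi4` and `lambda_matrix` are illustrative, and the argument `ddot_bound` stands for $\bar{d}$ (for $l = 1$) or $\tilde{d}$ (for $l = 2$). It only evaluates the matrices for fixed data; it does not search for feasible decision variables.

```python
import numpy as np


def selector(i, n, k=7):
    """e_i = [0_{n x (i-1)n}, I_n, 0_{n x (k-i)n}] for 1-based i."""
    E = np.zeros((n, k * n))
    E[:, (i - 1) * n : i * n] = np.eye(n)
    return E


def xi4(U):
    """Xi_4 built from e_2, e_3, e_6, e_7 with weights 1, 3, 5 as in Lemma 1."""
    n = U.shape[0]
    e = {i: selector(i, n) for i in range(1, 8)}
    z1 = e[2] - e[3]
    z2 = e[2] + e[3] - e[6]
    z3 = e[2] - e[3] + 3.0 * e[6] - e[7]
    return z1.T @ U @ z1 + 3.0 * z2.T @ U @ z2 + 5.0 * z3.T @ U @ z3


def lambda_matrix(A, B, P, R1, R2, S1, S2, U, d, ddot_bound):
    """Lambda_l^p, with (1 - ddot_bound) multiplying Psi_2 + Psi_3."""
    n = A.shape[0]
    e = {i: selector(i, n) for i in range(1, 8)}
    F = A @ e[1] + B @ e[2]  # A_p e_1 + B_p e_2
    Phi1 = e[1].T @ ((A + B).T @ P + P @ (A + B)) @ e[1]
    Phi2 = e[1].T @ R1 @ e[1] - e[3].T @ R2 @ e[3]
    Phi3 = F.T @ S1 @ F - e[5].T @ S2 @ e[5]
    Phi4 = d * F.T @ U @ F
    Psi2 = e[2].T @ (R2 - R1) @ e[2]
    Psi3 = e[4].T @ (S2 - S1) @ e[4]
    return (Phi1 + Phi2 + Phi3 + Phi4
            + (1.0 - ddot_bound) * (Psi2 + Psi3) - xi4(U) / d)
```

With $R\_1 = R\_2$ and $S\_1 = S\_2$ the result no longer depends on `ddot_bound`, because $\Psi\_2 = \Psi\_3 = 0$; this is the same simplification that removes the derivative bound in the delay-only setting of Assumption 2.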

*Then, the switched system (1) is globally asymptotically stable under the state-dependent switching rule (3) if there is no sliding motion, or if sliding motions occur on the switching surfaces $\Omega\_{pq}$ with $\eta\_{p,q} > 0$.*

**Proof.** Condition (9) implies that (2) is true, which indicates that the covering property holds. Therefore, under the switching rule (3), the switched system (1) is well-defined.

Now we prove that the switched system (1) is globally asymptotically stable. Similar to [29,30], for each subsystem *p*, we choose the Lyapunov-Krasovskii functional as follows

$$V\_p(t) = V\_{p1}(t) + \sum\_{i=2}^{4} V\_i(t),\tag{10}$$

where

$$V\_{p1}(t) = \mathbf{x}^T(t)P\_p\mathbf{x}(t),\\V\_2(t) = \int\_{t-d(t)}^t \mathbf{x}^T(\mathbf{s})R\_1\mathbf{x}(\mathbf{s})ds + \int\_{t-d}^{t-d(t)} \mathbf{x}^T(\mathbf{s})R\_2\mathbf{x}(\mathbf{s})ds,$$

$$V\_3(t) = \int\_{t-d(t)}^t \dot{\mathbf{x}}^T(s) S\_1 \dot{\mathbf{x}}(s)\, ds + \int\_{t-d}^{t-d(t)} \dot{\mathbf{x}}^T(s) S\_2 \dot{\mathbf{x}}(s)\, ds,\quad V\_4(t) = \int\_{-d}^0 \int\_{t+\theta}^t \dot{\mathbf{x}}^T(s) U \dot{\mathbf{x}}(s)\, ds\, d\theta.$$

In each region $\Omega\_p$, the time derivatives of $V\_{p1}(t)$ and $V\_i(t)$, $i = 2, 3, 4$, along the trajectory of subsystem $p$ are computed as follows:

$$\begin{split} \dot{V}\_{p1}(t) &= \mathbf{x}^T(t) \Big( \left( A\_p + B\_p \right)^T P\_p + P\_p \left( A\_p + B\_p \right) \Big) \mathbf{x}(t) - \int\_{t-d(t)}^t \left( \dot{\mathbf{x}}^T(s) B\_p^T P\_p \mathbf{x}(t) + \mathbf{x}^T(t) P\_p B\_p \dot{\mathbf{x}}(s) \right) ds \\ & \le \mathbf{x}^T(t) \Big( \left( A\_p + B\_p \right)^T P\_p + P\_p \left( A\_p + B\_p \right) + d(t) P\_p B\_p U^{-1} B\_p^T P\_p \Big) \mathbf{x}(t) + \int\_{t-d(t)}^t \dot{\mathbf{x}}^T(s) U \dot{\mathbf{x}}(s)\, ds \\ &= \eta^T(t) \Big( \Phi\_1^p + d(t) \Theta\_1^p \Big) \eta(t) + \int\_{t-d(t)}^t \dot{\mathbf{x}}^T(s) U \dot{\mathbf{x}}(s)\, ds. \end{split} \tag{11}$$

$$\begin{split} \dot{V}\_{2}(t) &= \mathbf{x}^{T}(t)R\_{1}\mathbf{x}(t) + \left(1 - \dot{d}(t)\right)\mathbf{x}^{T}(t - d(t))(R\_{2} - R\_{1})\mathbf{x}(t - d(t)) - \mathbf{x}^{T}(t - d)R\_{2}\mathbf{x}(t - d) \\ &= \eta^{T}(t)\left(\Phi\_{2} + \left(1 - \dot{d}(t)\right)\Psi\_{2}\right)\eta(t). \end{split} \tag{12}$$

$$\begin{split} \dot{V}\_{3}(t) &= \dot{\mathbf{x}}^{T}(t) S\_{1} \dot{\mathbf{x}}(t) - \dot{\mathbf{x}}^{T}(t-d) S\_{2} \dot{\mathbf{x}}(t-d) + \left(1 - \dot{d}(t)\right) \dot{\mathbf{x}}^{T}(t-d(t)) (S\_{2} - S\_{1}) \dot{\mathbf{x}}(t-d(t)) \\ &= \left(A\_{p} \mathbf{x}(t) + B\_{p} \mathbf{x}(t-d(t))\right)^{T} S\_{1} \left(A\_{p} \mathbf{x}(t) + B\_{p} \mathbf{x}(t-d(t))\right) - \dot{\mathbf{x}}^{T}(t-d) S\_{2} \dot{\mathbf{x}}(t-d) \\ & \quad + \left(1 - \dot{d}(t)\right) \dot{\mathbf{x}}^{T}(t-d(t)) (S\_{2} - S\_{1}) \dot{\mathbf{x}}(t-d(t)) \\ &= \eta^{T}(t) \left(\Phi\_{3}^{p} + \left(1 - \dot{d}(t)\right) \Psi\_{3}\right) \eta(t). \end{split} \tag{13}$$

$$\begin{split} \dot{V}\_{4}(t) &= d\,\dot{\mathbf{x}}^{T}(t) U \dot{\mathbf{x}}(t) - \int\_{t-d(t)}^{t} \dot{\mathbf{x}}^{T}(s) U \dot{\mathbf{x}}(s)\, ds - \int\_{t-d}^{t-d(t)} \dot{\mathbf{x}}^{T}(s) U \dot{\mathbf{x}}(s)\, ds \\ &= d\left( A\_{p} \mathbf{x}(t) + B\_{p} \mathbf{x}(t - d(t)) \right)^{T} U \left( A\_{p} \mathbf{x}(t) + B\_{p} \mathbf{x}(t - d(t)) \right) - \int\_{t-d(t)}^{t} \dot{\mathbf{x}}^{T}(s) U \dot{\mathbf{x}}(s)\, ds \\ & \quad - \int\_{t-d}^{t-d(t)} \dot{\mathbf{x}}^{T}(s) U \dot{\mathbf{x}}(s)\, ds. \end{split} \tag{14}$$

In (11), $\Theta\_1^p = e\_1^T P\_p B\_p U^{-1} B\_p^T P\_p e\_1$. By Lemma 1, one can obtain

$$(d - d(t)) \int\_{t - d}^{t - d(t)} \dot{\mathbf{x}}^T(s) U \dot{\mathbf{x}}(s)\, ds \ge \xi\_1^T U \xi\_1 + 3 \xi\_2^T U \xi\_2 + 5 \xi\_3^T U \xi\_3 = \eta^T(t) \Xi\_4 \eta(t),$$

where $\xi\_1 = x(t - d(t)) - x(t - d)$, $\xi\_2 = x(t - d(t)) + x(t - d) - \upsilon\_1$, and $\xi\_3 = x(t - d(t)) - x(t - d) + 3\upsilon\_1 - \upsilon\_2$. The above inequality implies that (14) can be further bounded as

$$\dot{V}\_4(t) \le \eta^T(t) \left(\Phi\_4^{p} - \frac{1}{d - d(t)} \Xi\_4\right) \eta(t) - \int\_{t-d(t)}^{t} \dot{\mathbf{x}}^T(s) U \dot{\mathbf{x}}(s)\, ds. \tag{15}$$
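The bound in the second line of (11) relies on the matrix inequality $-(a^T b + b^T a) \le a^T U^{-1} a + b^T U b$ for $U > 0$, applied with $a = B\_p^T P\_p \mathbf{x}(t)$ and $b = \dot{\mathbf{x}}(s)$ and then integrated over $[t - d(t), t]$. The short NumPy sketch below (the name `young_gap` is illustrative) checks it on random data; the gap equals $\|U^{-1/2}a + U^{1/2}b\|^2$, so it vanishes exactly when $b = -U^{-1}a$.

```python
import numpy as np


def young_gap(a, b, U):
    """a^T U^{-1} a + b^T U b + (a^T b + b^T a); nonnegative whenever U > 0."""
    return float(a @ np.linalg.solve(U, a) + b @ U @ b + 2.0 * (a @ b))


rng = np.random.default_rng(0)
n = 3
G = rng.standard_normal((n, n))
U = G @ G.T + n * np.eye(n)  # random symmetric positive definite U
worst = min(young_gap(rng.standard_normal(n), rng.standard_normal(n), U)
            for _ in range(1000))
```

The same inequality is what introduces $U^{-1}$ into $\Theta\_1^p$, which is why a Schur complement is needed to write Condition (7) linearly in $U$.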

Then, it follows from (10)–(13), (15) that

$$\begin{split} \dot{V}\_{p}(t) & \leq \eta^{T}(t) \Big( \Phi\_{1}^{p} + \Phi\_{2} + \Phi\_{3}^{p} + \Phi\_{4}^{p} + d(t) \Theta\_{1}^{p} + \left( 1 - \dot{d}(t) \right) \left( \Psi\_{2} + \Psi\_{3} \right) - \frac{1}{d - d(t)} \Xi\_{4} \Big) \eta(t) \\ & = \frac{1}{d - d(t)} \eta^{T}(t) \Big( \left(d - d(t)\right) \big( \Phi\_{1}^{p} + \Phi\_{2} + \Phi\_{3}^{p} + \Phi\_{4}^{p} + d(t) \Theta\_{1}^{p} + \left( 1 - \dot{d}(t) \right) \left( \Psi\_{2} + \Psi\_{3} \right) \big) - \Xi\_{4} \Big) \eta(t). \end{split} \tag{16}$$

By the Schur complement [31], Condition (7) is equivalent to

$$
\Lambda\_l^p + d\Theta\_1^p + \mu\_p e\_1^T Q\_p e\_1 < 0, \quad l = 1, 2. \tag{17}
$$

Namely,

$$\begin{cases} \Phi\_{1}^{p} + \Phi\_{2} + \Phi\_{3}^{p} + \Phi\_{4}^{p} + \left(1 - \bar{d}\right) \left(\Psi\_{2} + \Psi\_{3}\right) + d\Theta\_{1}^{p} + \mu\_{p}e\_{1}^{T}Q\_{p}e\_{1} - \frac{1}{d}\Xi\_{4} < 0, \\ \Phi\_{1}^{p} + \Phi\_{2} + \Phi\_{3}^{p} + \Phi\_{4}^{p} + \left(1 - \tilde{d}\right) \left(\Psi\_{2} + \Psi\_{3}\right) + d\Theta\_{1}^{p} + \mu\_{p}e\_{1}^{T}Q\_{p}e\_{1} - \frac{1}{d}\Xi\_{4} < 0. \end{cases} \tag{18}$$

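Because $\tilde{d} \le \dot{d}(t) \le \bar{d}$, the factor $1 - \dot{d}(t)$ is a convex combination of $1 - \bar{d}$ and $1 - \tilde{d}$; since the left-hand sides of (18) are affine in this factor, the above inequalities imply that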

$$\Phi\_{1}^{p} + \Phi\_{2} + \Phi\_{3}^{p} + \Phi\_{4}^{p} + \left(1 - \dot{d}(t)\right)\left(\Psi\_{2} + \Psi\_{3}\right) + d\Theta\_{1}^{p} + \mu\_{p}e\_{1}^{T}Q\_{p}e\_{1} - \frac{1}{d}\Xi\_{4} < 0. \tag{19}$$

Since $0 \le d(t) \le d$ and $\Theta\_1^p \ge 0$, multiplying (19) by $d$ and replacing $d\Theta\_1^p$ with $d(t)\Theta\_1^p$ yields

$$d\left(\Phi\_1^p + \Phi\_2 + \Phi\_3^p + \Phi\_4^p + \left(1 - \dot{d}(t)\right)\left(\Psi\_2 + \Psi\_3\right) + d(t)\Theta\_1^p + \mu\_p e\_1^T Q\_p e\_1\right) - \Xi\_4 < 0. \tag{20}$$

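Noting that $0 \le d - d(t) \le d$ and $\Xi\_4 \ge 0$, the left-hand side of (21) below equals $\lambda$ times the left-hand side of (20) minus $(1 - \lambda)\Xi\_4$, where $\lambda = (d - d(t))/d \in [0, 1]$. Hence, (20) shows that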

$$(d - d(t))\left(\Phi\_1^p + \Phi\_2 + \Phi\_3^p + \Phi\_4^p + \left(1 - \dot{d}(t)\right)\left(\Psi\_2 + \Psi\_3\right) + d(t)\Theta\_1^p + \mu\_p e\_1^T Q\_p e\_1\right) - \Xi\_4 < 0. \tag{21}$$

Based on (16) and (21), one can derive that

$$
\dot{V}\_p(t) < -\mu\_p \mathbf{x}^T(t) Q\_p \mathbf{x}(t) \le 0,\tag{22}
$$

where we use the fact that $x^T(t)Q\_p x(t) \ge 0$ for $x(t) \in \Omega\_p$.
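Two matrix facts carry the argument from (7) to (22): the Schur complement equivalence used for (17), and the observation that a matrix inequality affine in $\dot{d}(t)$ holds on $[\tilde{d}, \bar{d}]$ once it holds at both endpoints, which takes (18) to (19). The NumPy sketch below checks both on small examples; `is_neg_def`, `schur_check` and `affine_family_neg_def` are illustrative names, and the last function only samples the interval, so it is a sanity check rather than a proof.

```python
import numpy as np


def is_neg_def(M, tol=1e-10):
    """True if the symmetric part of M has all eigenvalues below -tol."""
    return bool(np.linalg.eigvalsh(0.5 * (M + M.T)).max() < -tol)


def schur_check(X, Y, U):
    """Compare [[X, Y], [Y^T, -U]] < 0 with X + Y U^{-1} Y^T < 0 for U > 0."""
    block = np.block([[X, Y], [Y.T, -U]])
    reduced = X + Y @ np.linalg.solve(U, Y.T)
    return is_neg_def(block), is_neg_def(reduced)


def affine_family_neg_def(F0, F1, lo, hi, samples=50):
    """Sample F0 + delta * F1 < 0 for delta on a grid over [lo, hi]."""
    return all(is_neg_def(F0 + t * F1) for t in np.linspace(lo, hi, samples))


I2 = np.eye(2)
agree_feasible = schur_check(-3.0 * I2, 0.5 * I2, I2)   # (True, True)
agree_infeasible = schur_check(-0.1 * I2, I2, I2)       # (False, False)
```

In the notation of (7), `X` plays the role of $\Lambda\_l^p + \mu\_p e\_1^T Q\_p e\_1$ and `Y` that of $\sqrt{d}\, e\_1^T P\_p B\_p$, so the reduced matrix is exactly the left-hand side of (17).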

Note that for any $x \in \Omega\_{pq}$, $x^TQ\_px = x^TQ\_qx$. By Condition (8), $V\_p(t) - V\_q(t) = x^T(t)(P\_p - P\_q)x(t) = \eta\_{p,q}\, x^T(t)(Q\_q - Q\_p)x(t)$, so $V\_p(t) = V\_q(t)$ whenever $x(t) \in \Omega\_{pq}$. Therefore, when the trajectory $x(t)$ passes from $\Omega\_p$ to $\Omega\_q$, the Lyapunov functional $V\_{\sigma}(t)$ does not increase. In particular, if no sliding motion occurs, the Lyapunov functional $V\_{\sigma}(t)$ converges to zero, which shows that the switched system (1) is globally asymptotically stable.

Now consider the case of sliding motions. Assume that sliding motions occur along the switching surface $\Omega\_{pq}$ at the boundary of the regions $\Omega\_p$ and $\Omega\_q$. According to Filippov's definition [32], we have

$$\begin{split} \dot{\mathbf{x}}(t) &= \alpha\left(A\_{p}\mathbf{x}(t) + B\_{p}\mathbf{x}(t - d(t))\right) + \bar{\alpha}\left(A\_{q}\mathbf{x}(t) + B\_{q}\mathbf{x}(t - d(t))\right) \\ &= \alpha\left(\left(A\_{p} + B\_{p}\right)\mathbf{x}(t) - B\_{p}\int\_{t-d(t)}^{t} \dot{\mathbf{x}}(s)ds\right) + \bar{\alpha}\left(\left(A\_{q} + B\_{q}\right)\mathbf{x}(t) - B\_{q}\int\_{t-d(t)}^{t} \dot{\mathbf{x}}(s)ds\right), \end{split} \tag{23}$$

where $\alpha \in (0, 1)$ and $\bar{\alpha} = 1 - \alpha$. According to the analysis of sliding motions in [33], sliding motions on the surface $\Omega\_{pq}$ imply that

$$\begin{split} \mathbf{x}^{T}(t) \left( \left( A\_{p} + B\_{p} \right)^{T} Q\_{pq} + Q\_{pq} \left( A\_{p} + B\_{p} \right) \right) \mathbf{x}(t) - \mathbf{x}^{T}(t) Q\_{pq} B\_{p} \int\_{t-d(t)}^{t} \dot{\mathbf{x}}(s) ds \\ - \int\_{t-d(t)}^{t} \dot{\mathbf{x}}^{T}(s) ds B\_{p}^{T} Q\_{pq} \mathbf{x}(t) < 0, \end{split} \tag{24}$$

and

$$\begin{split} \mathbf{x}^{T}(t) \left( \left( A\_{q} + B\_{q} \right)^{T} Q\_{pq} + Q\_{pq} \left( A\_{q} + B\_{q} \right) \right) \mathbf{x}(t) - \mathbf{x}^{T}(t) Q\_{pq} B\_{q} \int\_{t-d(t)}^{t} \dot{\mathbf{x}}(s) ds \\ - \int\_{t-d(t)}^{t} \dot{\mathbf{x}}^{T}(s) ds B\_{q}^{T} Q\_{pq} \mathbf{x}(t) > 0 \end{split} \tag{25}$$

hold, where $Q\_{pq} = Q\_p - Q\_q$. Let $P\_{qp} = P\_q - P\_p$. By Condition (8), $P\_{qp} = \eta\_{p,q} Q\_{pq}$; since $\eta\_{p,q} > 0$, multiplying (24) and (25) by $\eta\_{p,q}$ gives

$$\begin{split} \mathbf{x}^T(t) \left( \left( A\_p + B\_p \right)^T P\_{qp} + P\_{qp} \left( A\_p + B\_p \right) \right) \mathbf{x}(t) - \mathbf{x}^T(t) P\_{qp} B\_p \int\_{t-d(t)}^t \dot{\mathbf{x}}(s) ds \\ - \int\_{t-d(t)}^t \dot{\mathbf{x}}^T(s) ds B\_p^T P\_{qp} \mathbf{x}(t) < 0, \end{split} \tag{26}$$

$$\begin{split} \mathbf{x}^T(t) \left( \left( A\_q + B\_q \right)^T P\_{qp} + P\_{qp} \left( A\_q + B\_q \right) \right) \mathbf{x}(t) - \mathbf{x}^T(t) P\_{qp} B\_q \int\_{t-d(t)}^t \dot{\mathbf{x}}(s) ds \\ - \int\_{t-d(t)}^t \dot{\mathbf{x}}^T(s) ds B\_q^T P\_{qp} \mathbf{x}(t) > 0, \end{split} \tag{27}$$

which are equivalent to

$$\begin{split} & \quad \mathbf{x}^T(t) \Big( \left( A\_p + B\_p \right)^T P\_q + P\_q \left( A\_p + B\_p \right) \Big) \mathbf{x}(t) - \mathbf{x}^T(t) P\_q B\_p \int\_{t-d(t)}^t \dot{\mathbf{x}}(s) ds \\ & \quad - \int\_{t-d(t)}^t \dot{\mathbf{x}}^T(s) ds B\_p^T P\_q \mathbf{x}(t) \\ & \quad < \mathbf{x}^T(t) \Big( \left( A\_p + B\_p \right)^T P\_p + P\_p \left( A\_p + B\_p \right) \Big) \mathbf{x}(t) - \mathbf{x}^T(t) P\_p B\_p \int\_{t-d(t)}^t \dot{\mathbf{x}}(s) ds \\ & \quad - \int\_{t-d(t)}^t \dot{\mathbf{x}}^T(s) ds B\_p^T P\_p \mathbf{x}(t), \end{split} \tag{28}$$

$$\begin{split} & \quad \mathbf{x}^T(t) \Big( \left( A\_q + B\_q \right)^T P\_p + P\_p \left( A\_q + B\_q \right) \Big) \mathbf{x}(t) - \mathbf{x}^T(t) P\_p B\_q \int\_{t-d(t)}^t \dot{\mathbf{x}}(s) ds \\ & \quad - \int\_{t-d(t)}^t \dot{\mathbf{x}}^T(s) ds B\_q^T P\_p \mathbf{x}(t) \\ & \quad < \mathbf{x}^T(t) \Big( \left( A\_q + B\_q \right)^T P\_q + P\_q \left( A\_q + B\_q \right) \Big) \mathbf{x}(t) - \mathbf{x}^T(t) P\_q B\_q \int\_{t-d(t)}^t \dot{\mathbf{x}}(s) ds \\ & \quad - \int\_{t-d(t)}^t \dot{\mathbf{x}}^T(s) ds B\_q^T P\_q \mathbf{x}(t) . \end{split} \tag{29}$$

Note that the switching signal is not unique on the sliding surface $\Omega\_{pq}$. If $\sigma(t) = p$, one can derive

$$\begin{aligned} \dot{V}\_{p1}(t) &= \alpha\mathbf{x}^T(t)\left(\left(A\_p + B\_p\right)^T P\_p + P\_p \left(A\_p + B\_p\right)\right) \mathbf{x}(t) - \alpha\mathbf{x}^T(t) P\_p B\_p \int\_{t-d(t)}^t \dot{\mathbf{x}}(s) ds - \alpha \int\_{t-d(t)}^t \dot{\mathbf{x}}^T(s) ds\, B\_p^T P\_p \mathbf{x}(t) \\ &\quad + \bar{\alpha} \mathbf{x}^T(t) \left(\left(A\_q + B\_q\right)^T P\_p + P\_p \left(A\_q + B\_q\right)\right) \mathbf{x}(t) - \bar{\alpha}\left(\mathbf{x}^T(t) P\_p B\_q \int\_{t-d(t)}^t \dot{\mathbf{x}}(s) ds + \int\_{t-d(t)}^t \dot{\mathbf{x}}^T(s) ds\, B\_q^T P\_p \mathbf{x}(t)\right) \\ &\leq \alpha \mathbf{x}^T(t) \left(\left(A\_p + B\_p\right)^T P\_p + P\_p \left(A\_p + B\_p\right)\right) \mathbf{x}(t) - \alpha \mathbf{x}^T(t) P\_p B\_p \int\_{t-d(t)}^t \dot{\mathbf{x}}(s) ds - \alpha \int\_{t-d(t)}^{t} \dot{\mathbf{x}}^{T}(s) ds\, B\_{p}^{T} P\_{p} \mathbf{x}(t) \\ &\quad + \bar{\alpha} \mathbf{x}^{T}(t) \left( \left( A\_{q} + B\_{q} \right)^{T} P\_{q} + P\_{q} (A\_{q} + B\_{q}) \right) \mathbf{x}(t) - \bar{\alpha} \left( \mathbf{x}^{T}(t) P\_{q} B\_{q} \int\_{t-d(t)}^{t} \dot{\mathbf{x}}(s) ds + \int\_{t-d(t)}^{t} \dot{\mathbf{x}}^{T}(s) ds\, B\_{q}^{T} P\_{q} \mathbf{x}(t) \right) \\ &\leq \alpha\, \eta^{T}(t) \left( \Phi\_{1}^{p} + d(t) \Theta\_{1}^{p} \right) \eta(t) + \bar{\alpha}\, \eta^{T}(t) \left( \Phi\_{1}^{q} + d(t) \Theta\_{1}^{q} \right) \eta(t) + \int\_{t-d(t)}^{t} \dot{\mathbf{x}}^{T}(s) U \dot{\mathbf{x}}(s)\, ds. \end{aligned} \tag{30}$$

Under (7), (10)–(13), (21) and (30), it is easy to deduce that

$$\dot{V}\_p(t) < -\eta^T(t) \left(\alpha \mu\_p e\_1^T Q\_p e\_1 + \bar{\alpha} \mu\_q e\_1^T Q\_q e\_1\right) \eta(t) \le 0.$$

Similarly, when *σ*(*t*) = *q*, we can also obtain

$$\begin{aligned} \dot{V}\_{q1}(t) &\leq \alpha\, \eta^T(t) \left(\Phi\_1^p + d(t)\Theta\_1^p\right)\eta(t) + \bar{\alpha}\, \eta^T(t) \left(\Phi\_1^q + d(t)\Theta\_1^q\right)\eta(t) \\ &\quad + \int\_{t-d(t)}^t \dot{\mathbf{x}}^T(s)U\dot{\mathbf{x}}(s)\,ds, \end{aligned}$$

which further yields $\dot{V}\_q(t) < 0$. Therefore, the Lyapunov-Krasovskii functional $V\_{\sigma}(t)$ is decreasing when sliding motions occur on the switching surface $\Omega\_{pq}$. Together with (22), this shows that the switched system (1) under the switching rule (3) is also globally asymptotically stable if sliding motions occur on switching surfaces $\Omega\_{pq}$ with $\eta\_{p,q} > 0$.
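The convex combination in (23) can be made concrete at a single time instant. The NumPy sketch below (the name `filippov_sliding` is illustrative) freezes $\mathbf{x}(t)$ and $\mathbf{x}(t - d(t))$, evaluates both subsystem vector fields, and chooses $\alpha$ so that $\dot{\mathbf{x}}(t)$ is tangent to the surface $\mathbf{x}^T(Q\_p - Q\_q)\mathbf{x} = 0$. This equivalent-vector-field computation is a standard way to realize Filippov's construction and is shown only as an illustration, not as part of the proof; it returns `None` when the instantaneous sign conditions behind (24) and (25) fail.

```python
import numpy as np


def filippov_sliding(x, x_delay, Ap, Bp, Aq, Bq, Qp, Qq):
    """Return (alpha, xdot) for sliding on x^T (Qp - Qq) x = 0, or None.

    Sliding requires the p-field to decrease and the q-field to increase
    x^T (Qp - Qq) x at the current instant (sign conditions behind (24)-(25)).
    """
    fp = Ap @ x + Bp @ x_delay          # vector field of subsystem p
    fq = Aq @ x + Bq @ x_delay          # vector field of subsystem q
    normal = 2.0 * (Qp - Qq) @ x        # gradient of x^T (Qp - Qq) x
    gp, gq = float(normal @ fp), float(normal @ fq)
    if not (gp < 0.0 < gq):
        return None
    alpha = gq / (gq - gp)              # solves alpha*gp + (1 - alpha)*gq = 0
    xdot = alpha * fp + (1.0 - alpha) * fq
    return alpha, xdot


# Example with B_p = B_q = 0 for simplicity; x lies on x^T (Qp - Qq) x = 0.
x = np.array([1.0, 1.0])
Z = np.zeros((2, 2))
Qp = np.diag([1.0, -1.0])
Qq = -Qp
Ap = np.diag([-1.0, 1.0])
Aq = np.array([[1.0, 0.0], [0.0, 0.0]])
result = filippov_sliding(x, np.zeros(2), Ap, Z, Aq, Z, Qp, Qq)  # alpha = 1/3
```

Swapping the roles of the two subsystems in this example violates the sign conditions, and the function then reports that no sliding motion occurs.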

**Remark 2.** *From the proof of Theorem 1, one can see that the chosen Lyapunov functional is a function of both $x(t)$ and $\dot{x}(t)$. Similar Lyapunov functionals have been employed to establish stability results for delayed systems [29,30], because such functionals make fuller use of the information about the system. Notably, the proposed Lyapunov functional can be viewed as a special case of those presented in [29,30].*

**Remark 3.** *Condition (7) ensures that the time derivative of the Lyapunov functional along the trajectory of the switched system is negative in each region $\Omega\_p$. Condition (8) guarantees that the Lyapunov functional does not increase when a switching event occurs in the absence of sliding motion. When sliding motions occur, Conditions (7) and (8) together guarantee that the time derivative of the Lyapunov functional along the trajectory is negative while the trajectory slides on the surfaces $\Omega\_{pq}$. Condition (9) ensures that the switched system is well-defined.*

**Remark 4.** *In [15–19], the stability of delayed switched systems under state-dependent switching has also been studied. However, these results assume that there exists a Hurwitz convex combination of $A\_p + B\_p$ (or $A\_p$). Generally speaking, this assumption is restrictive and may not be satisfied in some cases. Theorem 1 removes this restriction, which makes our results more flexible. Moreover, the integral inequality of Lemma 1 is employed in the proof of Theorem 1, which makes Theorem 1 less conservative.*

**Remark 5.** *If infinitely many switching events occur within a finite time interval, the system is said to exhibit Zeno behavior. The switching rule (3) cannot exclude Zeno behavior. However, Theorem 1 still ensures stability when Zeno behavior occurs, for the following reasons. (a) If no switching event occurs, the time derivative of the Lyapunov functional along the trajectory is negative. (b) If a switching event occurs, there are two cases. In the first case, no sliding motion occurs, and the Lyapunov functional does not increase. In the second case, sliding motions occur, and the time derivative of the Lyapunov functional along the trajectory is still negative. Although Zeno behavior may lead to an accumulation of switches in finite time, the Lyapunov functional along the trajectory is always decreasing.*

By restricting $R\_1 = R\_2 = R$ and $S\_1 = S\_2 = S$, one obtains $\Psi\_2 = \Psi\_3 = 0$, so no bound on $\dot{d}(t)$ is needed, and the following stability result holds under Assumption 2.

**Theorem 2.** *Under Assumption 2, assume that for any $p \in M$ there exist $n \times n$ matrices $P\_p > 0$, $R > 0$, $S > 0$, $U > 0$, $Q\_p = Q\_p^T$, positive constants $\mu\_p$, $\theta\_p$, and constants $\eta\_{p,q}$, $q \in M$, $q \ne p$, such that Conditions (8) and (9) and*

$$
\begin{pmatrix}
\Lambda^p + \mu\_p \boldsymbol{e}\_1^T \boldsymbol{Q}\_p \boldsymbol{e}\_1 & \sqrt{d} \boldsymbol{e}\_1^T \boldsymbol{P}\_p \boldsymbol{B}\_p \\
\sqrt{d} \boldsymbol{B}\_p^T \boldsymbol{P}\_p \boldsymbol{e}\_1 & -\boldsymbol{U}
\end{pmatrix} < 0,
\tag{31}
$$

*where* Λ<sup>*p*</sup> = Λ<sup>*p*</sup><sub>1</sub> *with R*<sub>1</sub> = *R*<sub>2</sub> = *R and S*<sub>1</sub> = *S*<sub>2</sub> = *S. Then, the switched system (1) is globally asymptotically stable under the state-dependent switching rule (3) if there is no sliding motion or if sliding motions occur on the switching surfaces* Ω*pq with η<sub>p,q</sub>* > 0*.*

Because they contain products of unknown scalars and unknown matrices, the conditions in Theorems 1 and 2 are bilinear matrix inequalities (BMIs). Therefore, standard semidefinite programming methods cannot be applied directly. Two strategies can be adopted to obtain a feasible solution. The first is to use a BMI solver (such as PENBMI) directly to obtain the undetermined scalars and matrices. The second, similar to [22], is to grid over the unknown scalars *μ<sub>p</sub>*, *θ<sub>p</sub>* and *η<sub>p,q</sub>*. Once these parameters are fixed, the BMIs in Theorems 1 and 2 reduce to ordinary linear matrix inequalities (LMIs), which can be solved by standard solvers such as lmilab and mosek.
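The second strategy amounts to a plain search loop over the frozen scalars. In the Python sketch below, `lmi_feasible` is a hypothetical placeholder for a call into an LMI solver (it is not an API of lmilab, mosek, or any other package), and the grid values and the toy oracle are illustrative assumptions, not the conditions of Theorems 1 and 2.

```python
from itertools import product
from typing import Callable, Iterable, Optional, Tuple

# Hypothetical oracle: given fixed scalars (mu, theta, eta), it should report
# whether the resulting LMIs are feasible. Wiring it to a real solver is not shown.
FeasibilityOracle = Callable[[float, float, float], bool]


def grid_search_scalars(
    mu_grid: Iterable[float],
    theta_grid: Iterable[float],
    eta_grid: Iterable[float],
    lmi_feasible: FeasibilityOracle,
) -> Optional[Tuple[float, float, float]]:
    """Return the first grid point (mu, theta, eta) whose LMIs are feasible."""
    for mu, theta, eta in product(mu_grid, theta_grid, eta_grid):
        # With the scalars frozen, bilinear terms such as mu * Q become
        # linear in the remaining matrix unknowns, so an LMI solver applies.
        if lmi_feasible(mu, theta, eta):
            return mu, theta, eta
    return None  # no grid point produced a feasible LMI problem


def toy_oracle(mu: float, theta: float, eta: float) -> bool:
    """Illustrative stand-in only; NOT the matrix inequalities of the paper."""
    return mu >= 0.5 and eta > 0


print(grid_search_scalars([0.1, 0.5, 1.0], [1.0], [-0.7, 0.01], toy_oracle))
```

Returning the first feasible point keeps the sketch short; in practice one might instead collect all feasible points and keep the one that yields the largest admissible delay bound.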

When the switched system (1) is composed of two subsystems, one can set *Q*<sub>1</sub> = −*Q*<sub>2</sub> = *Q* = *Q*<sup>*T*</sup>, *η*<sub>1,2</sub> = *η*<sub>2,1</sub> = *η*, *P*<sub>2</sub> = *P*, *P*<sub>1</sub> = *P* − 2*ηQ*, and *θ*<sub>1</sub> = *θ*<sub>2</sub> = 1. Then, Conditions (8) and (9) are always satisfied. The following corollaries follow readily from Theorems 1 and 2.

**Corollary 1.** *When M* = {1, 2}*, under Assumption 1, assume that there exist n* × *n matrices P* > 0*, Ri* > 0, *Si* > 0, *U* > 0*, (i* = 1, 2*), Q* = *QT, positive constants μ*1*, μ*2*, constant η, such that*

$$
\begin{pmatrix}
\Lambda\_l^{1\*} + \mu\_1 e\_1^T Q e\_1 & \sqrt{d} e\_1^T (P - 2\eta Q) B\_1 \\
\sqrt{d} B\_1^T (P - 2\eta Q) e\_1 & -U
\end{pmatrix} < 0, \quad l = 1, 2, \tag{32}
$$

$$
\begin{pmatrix}
\Lambda\_l^{2\*} - \mu\_2 \boldsymbol{e}\_1^T \boldsymbol{Q} \boldsymbol{e}\_1 & \sqrt{d} \boldsymbol{e}\_1^T \boldsymbol{P} \boldsymbol{B}\_2 \\
\sqrt{d} \boldsymbol{B}\_2^T \boldsymbol{P} \boldsymbol{e}\_1 & -\boldsymbol{U}
\end{pmatrix} < 0, l = 1, 2,
\tag{33}
$$

*where*

$$\begin{split} \Lambda\_{1}^{1\*} &= \Phi\_{1}^{1\*} + \Phi\_{2} + \Phi\_{3}^{1} + \Phi\_{4}^{1} + \left(1 - \bar{d}\right) (\Psi\_{2} + \Psi\_{3}) - \frac{1}{d} \Xi\_{4}, \\ \Lambda\_{1}^{2\*} &= \Phi\_{1}^{2\*} + \Phi\_{2} + \Phi\_{3}^{2} + \Phi\_{4}^{2} + \left(1 - \bar{d}\right) (\Psi\_{2} + \Psi\_{3}) - \frac{1}{d} \Xi\_{4}, \\ \Lambda\_{2}^{1\*} &= \Phi\_{1}^{1\*} + \Phi\_{2} + \Phi\_{3}^{1} + \Phi\_{4}^{1} + \left(1 - \bar{d}\right) (\Psi\_{2} + \Psi\_{3}) - \frac{1}{d} \Xi\_{4}, \\ \Lambda\_{2}^{2\*} &= \Phi\_{1}^{2\*} + \Phi\_{2} + \Phi\_{3}^{2} + \Phi\_{4}^{2} + \left(1 - \bar{d}\right) (\Psi\_{2} + \Psi\_{3}) - \frac{1}{d} \Xi\_{4}, \\ \Phi\_{1}^{1\*} &= e\_{1}^{T} \Big( \left(A\_{1} + B\_{1}\right)^{T} (P - 2\eta Q) + \left(P - 2\eta Q\right) \left(A\_{1} + B\_{1}\right) \Big) e\_{1}, \\ \Phi\_{1}^{2\*} &= e\_{1}^{T} \Big( \left(A\_{2} + B\_{2}\right)^{T} P + P \left(A\_{2} + B\_{2}\right) \Big) e\_{1}. \end{split}$$

*and the other notations are in agreement with the ones presented in Theorem 1. Then, the switched system (1) is globally asymptotically stable under the state-dependent switching rule (3) if there is no sliding motion or there exist sliding motions on switching surfaces with η* > 0*.*

**Corollary 2.** *When M* = {1, 2}*, under Assumption 2, assume that there exist n* × *n matrices P* > 0*, R* > 0, *S* > 0, *U* > 0*, Q* = *QT, positive constants μ*1*, μ*2*, constant η, such that*

$$
\begin{pmatrix}
\bar{\Lambda}^{1\*} + \mu\_1 e\_1^T Q e\_1 & \sqrt{d} e\_1^T (P - 2\eta Q) B\_1 \\
\sqrt{d} B\_1^T (P - 2\eta Q) e\_1 & -U
\end{pmatrix} < 0,\tag{34}
$$

$$
\begin{pmatrix}
\bar{\Lambda}^{2\*} - \mu\_2 e\_1^T Q e\_1 & \sqrt{d} e\_1^T P B\_2 \\
\sqrt{d} B\_2^T P e\_1 & -U
\end{pmatrix} < 0,
\tag{35}
$$

*where* Λ̄<sup>1∗</sup> = Λ<sup>1∗</sup><sub>1</sub>*,* Λ̄<sup>2∗</sup> = Λ<sup>2∗</sup><sub>1</sub> *with R*<sub>1</sub> = *R*<sub>2</sub> = *R and S*<sub>1</sub> = *S*<sub>2</sub> = *S. Then, the switched system (1) is globally asymptotically stable under the state-dependent switching rule (3) if there is no sliding motion or if sliding motions occur on the switching surfaces with η* > 0*.*

#### **4. Numerical Simulations**

In this section, two numerical examples are presented to illustrate the validity of the proposed results.

**Example 1.** *Consider the switched system (1) with M* = {1, 2} *and*

$$A\_1 = \begin{pmatrix} 0.8 & -4 \\ 0 & 0.8 \end{pmatrix}, B\_1 = \begin{pmatrix} 0.2 & -1 \\ 0 & 0.2 \end{pmatrix}, A\_2 = \begin{pmatrix} 0.8 & 0 \\ 4 & 0.9 \end{pmatrix}, B\_2 = \begin{pmatrix} 0.2 & 0 \\ 1 & 0.1 \end{pmatrix}.$$

*By choosing μ*<sub>1</sub> = *μ*<sub>2</sub> = 1*, η* = −0.7 *and letting d̄* = −*d̃* = *δ, we obtain from Corollaries 1 and 2 the upper bound d for different δ, as given in Table 1 (to avoid the zero solution, the matrix inequalities P*, *R<sub>i</sub>*, *S<sub>i</sub>*, *U* > *aI with a* = 10<sup>−7</sup> *are used in place of P*, *R<sub>i</sub>*, *S<sub>i</sub>*, *U* > 0*, i* = 1, 2*). For the numerical simulation, we choose d*(*t*) = 0.1 + 0.1 sin(10*t*)*, which gives d* = 0.2 *and d̄* = −*d̃* = 1*. By solving the matrix inequalities in Corollary 1, we get*

$$\begin{aligned} Q\_1 &= -Q\_2 = Q = \begin{pmatrix} -0.2567 & 0.1996 \\ 0.1996 & 0.2565 \end{pmatrix}, \quad P\_1 = P - 2\eta Q = \begin{pmatrix} 0.0935 & 0.1402 \\ 0.1402 & 0.4516 \end{pmatrix}, \\ P\_2 &= P = \begin{pmatrix} 0.4528 & -0.1393 \\ -0.1393 & 0.0925 \end{pmatrix}. \end{aligned}$$

*The stable dynamics and convergent time-response curves with φ*(*s*) = (−1, 2)<sup>*T*</sup>*, s* ∈ [−0.2, 0]*, are plotted in Figure 1. The corresponding switching signal generated by rule (3) is also shown in the sub-figure of Figure 1. Numerical simulations indicate that no sliding motion occurs.*
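The reported matrices can be cross-checked by hand: with η = −0.7, P − 2ηQ = P + 1.4Q, which should reproduce P<sub>1</sub> up to the four-decimal rounding used in the text (taking the (2,2) entry of P<sub>1</sub> as 0.4516). The short Python check below does this and also confirms, via Sylvester's criterion, that P<sub>1</sub> and P<sub>2</sub> = P are positive definite. It only reuses the printed numbers; it does not re-solve Corollary 1, and the helper names are illustrative.

```python
# Matrices reported for Example 1 (eta = -0.7), transcribed from the text.
ETA = -0.7
Q = [[-0.2567, 0.1996], [0.1996, 0.2565]]
P = [[0.4528, -0.1393], [-0.1393, 0.0925]]
P1_REPORTED = [[0.0935, 0.1402], [0.1402, 0.4516]]


def p_minus_2eta_q(p, q, eta):
    """Entry-wise P - 2*eta*Q for 2x2 matrices stored as nested lists."""
    return [[p[r][c] - 2.0 * eta * q[r][c] for c in range(2)] for r in range(2)]


def is_positive_definite_2x2(m):
    """Sylvester's criterion for a symmetric 2x2 matrix: both leading minors > 0."""
    return m[0][0] > 0 and m[0][0] * m[1][1] - m[0][1] * m[1][0] > 0


P1 = p_minus_2eta_q(P, Q, ETA)
max_gap = max(abs(P1[r][c] - P1_REPORTED[r][c]) for r in range(2) for c in range(2))
print(max_gap < 1e-3)  # expected: True (agreement up to rounding)
print(is_positive_definite_2x2(P1), is_positive_definite_2x2(P))  # expected: True True
```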

*We now compare our results with existing ones for this example to demonstrate their advantages.*


*(c) For the case of constant time delay, by solving the matrix inequalities in ([27] Theorem 5), one can get the upper bound d* = 0.2455*, which is also less than d* = 0.2489*. Therefore, the restriction on the time delay of our results is weaker than that proposed in ([27] Theorem 5).*


**Table 1.** The upper bound *d* of the time delay for different *d̄* = −*d̃* = *δ*.

**Figure 1.** The stable dynamics (**Left**) and convergent response curves (**Right**) of the system in Example 1 with *d*(*t*) = 0.1 + 0.1 sin(10*t*).

**Example 2.** *Consider the switched system (1) with M* = {1, 2}*, d*(*t*) = 0.01 + 0.01 sin(50*t*)*, and*

$$A\_1 = \begin{pmatrix} 0 & 1 \\ 2 & -8 \end{pmatrix}, B\_1 = \begin{pmatrix} 0 & 0 \\ 0 & -1 \end{pmatrix}, A\_2 = \begin{pmatrix} 0 & 0.5 \\ -2 & 1 \end{pmatrix}, B\_2 = \begin{pmatrix} 0 & 0.5 \\ 0 & 1 \end{pmatrix}.$$

*It is easy to derive that d* = 0.02*, d̄* = −*d̃* = 0.5*. By choosing η* = 0.01*, μ*<sub>1</sub> = *μ*<sub>2</sub> = 1*, we obtain the following feasible solution from Corollary 1:*

$$\begin{aligned} Q\_1 &= -Q\_2 = Q = \begin{pmatrix} -0.02252 & 0.0251 \\ 0.0251 & 0.0287 \end{pmatrix}, \quad P\_1 = P - 2\eta Q = \begin{pmatrix} 0.0202 & 0.0055 \\ 0.0055 & 0.0028 \end{pmatrix}, \\ P\_2 &= P = \begin{pmatrix} 0.0198 & 0.0060 \\ 0.0060 & 0.0034 \end{pmatrix}. \end{aligned}$$

*The sliding dynamics and stable response curves are shown in Figure 2. Numerical simulations indicate that there are sliding motions, which can make the trajectory approach the origin along the switching surfaces.*

*If we choose η* = −1*, μ*<sup>1</sup> = *μ*<sup>2</sup> = 1*, by solving the matrix inequalities in Corollary 1, we obtain*

$$\begin{aligned} Q\_1 &= -Q\_2 = Q = \begin{pmatrix} 0.0127 & -0.0964 \\ -0.0964 & 0.4049 \end{pmatrix}, P\_1 = P - 2\eta Q = \begin{pmatrix} 0.0397 & -0.1802 \\ -0.1802 & 0.8545 \end{pmatrix}, \\\ P\_2 = P &= \begin{pmatrix} 0.0143 & 0.0125 \\ 0.0125 & 0.0446 \end{pmatrix}. \end{aligned}$$

*Numerical simulations show that there are unstable sliding motions for this case (see Figure 3), which is due to η*1,2 = *η*2,1 = *η* < 0*. This demonstrates that ηp*,*<sup>q</sup>* > 0 *is essential for the stability of the switched system (1) when sliding motions occur.*
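The delay bounds quoted in Examples 1 and 2 follow directly from the sinusoidal form of *d*(*t*): its maximum is offset + amplitude, and its derivative amplitude·ω·cos(ωt) ranges over [−amplitude·ω, amplitude·ω]. The small Python helper below (its name is ours, for illustration) encodes this; it assumes offset ≥ |amplitude| so that the delay stays non-negative.

```python
def sinusoidal_delay_bounds(offset: float, amplitude: float, omega: float):
    """Bounds for d(t) = offset + amplitude * sin(omega * t), with omega > 0.

    Returns (d, d_tilde, d_bar): the upper bound of d(t) and the lower and
    upper bounds of its derivative d'(t) = amplitude * omega * cos(omega * t).
    """
    if omega <= 0:
        raise ValueError("omega must be positive")
    if offset < abs(amplitude):
        raise ValueError("d(t) would become negative")
    d_upper = offset + abs(amplitude)  # attained where sin(omega * t) = +1
    rate = abs(amplitude) * omega      # attained where |cos(omega * t)| = 1
    return d_upper, -rate, rate


print(sinusoidal_delay_bounds(0.1, 0.1, 10.0))    # Example 1
print(sinusoidal_delay_bounds(0.01, 0.01, 50.0))  # Example 2
```

Up to floating-point rounding, the two calls give (0.2, −1, 1) and (0.02, −0.5, 0.5), matching *d* and *d̄* = −*d̃* stated in the examples.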

**Figure 2.** The stable dynamics (**Left**) and convergent response curves (**Right**) of the system in Example 2.

**Figure 3.** The unstable dynamics (**Left**) and unstable response curves (**Right**) of the system in Example 2.

#### **5. Conclusions**

This paper has investigated the stability of delayed switched systems whose subsystems are all unstable. Under the designed state-dependent switching rule, several stability results for different assumptions on the time delay are derived via an integral inequality and multiple Lyapunov–Krasovskii functionals. Numerical simulations demonstrate that the proposed results are more effective and less conservative than those presented in [15–21,27].

The main limitation of this paper is that no condition determining whether sliding motions occur is incorporated into the stability results. In fact, similar to [21,22], we have derived some conditions to verify the existence or non-existence of sliding motions. Unfortunately, if these conditions are added to the stability results, it is difficult to obtain a feasible solution. Instead, we adopt the approach used in [34,35]: the condition determining whether sliding motions occur is not given, and the existence or non-existence of sliding motions is revealed via numerical simulation. We hope that more tractable conditions on sliding motions can be derived in the future.

**Author Contributions:** Investigation, C.L.; Writing—original draft, X.L. All authors have read and agreed to the published version of the manuscript.

**Funding:** This research was funded by the Research Foundation of the Natural Foundation of Chongqing City, Grant Number: cstc2019jcyj-msxmX0492.

**Conflicts of Interest:** The authors declare no conflict of interest.

#### **References**


### *Article* **A Novel Multi-Source Domain Adaptation Method with Dempster–Shafer Evidence Theory for Cross-Domain Classification**

**Min Huang \* and Chang Zhang \***

School of Software Engineering, South China University of Technology (SCUT), Guangzhou 510006, China **\*** Correspondence: minh@scut.edu.cn (M.H.); sezchang\_2020@mail.scut.edu.cn (C.Z.)

**Abstract:** In this era of big data, Multi-source Domain Adaptation (MDA) has become increasingly popular and is employed to make full use of available source data collected from several different, but related domains. Although multiple source domains provide much information, handling domain shifts becomes more challenging, especially learning a common domain-invariant representation for all domains. Moreover, it is counter-intuitive to treat multiple source domains equally, as most existing MDA algorithms do. Therefore, the domain-specific distribution of each source–target domain pair is aligned separately. Nevertheless, it is hard to combine the adaptation outputs of different domain-specific classifiers effectively because of ambiguity at the category boundary. Subjective Logic (SL) is introduced to measure the uncertainty (credibility) of each domain-specific classifier, so that MDA can be bridged with Dempster–Shafer evidence Theory (DST). Owing to its strength in information fusion, DST is utilized to reduce the category boundary ambiguity and output reasonable decisions by combining adaptation outputs based on uncertainty. Finally, extensive comparative experiments on three popular benchmark datasets for cross-domain image classification are conducted to evaluate the performance of the proposed method from various aspects.

**Keywords:** multi-source domain adaptation; Dempster–Shafer evidence theory; cross-domain classification

**MSC:** 68T07

#### **1. Introduction**

Recently, Deep Learning (DL) has made remarkable advances in various fields [1–7], especially in classification [8–10]. Despite excellent results, the success of deep methods relies heavily on: (1) large-scale labeled data for supervised learning and (2) training and test data that are Independent and Identically Distributed (IID). However, annotation is time-consuming and often unaffordable in practice. If a model is trained on one dataset (the source domain) but tested on another, non-IID dataset (the target domain), domain shift occurs and tends to severely degrade the performance of the learned model [11,12]. Therefore, it is necessary to develop models that are trained on given labeled datasets, but generalize well to a non-IID unlabeled dataset.

Domain Adaptation (DA) aims to learn a discriminative model by reducing domain shifts between training and test distributions [13]. DA transfers knowledge from the given labeled source domain to tackle the task in a different, but related target domain by learning domain-invariant representations. Most approaches focus on Single-source Domain Adaptation (SDA), where labeled data from only one source domain are considered, and many methods have emerged in the past decade [14–18]. For example, DDC [14] adds an adaptation layer to the pre-trained AlexNet model to confuse the feature representations of the single source domain and the target domain. DSAN [16] proposes a novel fine-grained metric function to align the distributions of the single source domain and the target domain. Most of them learn to map the data from both domains into a common feature space to learn domain-invariant representations by minimizing the domain distribution discrepancy, so that the source classifier can then be directly applied to target instances.

**Citation:** Huang, M.; Zhang, C. A Novel Multi-Source Domain Adaptation Method with Dempster–Shafer Evidence Theory for Cross-Domain Classification. *Mathematics* **2022**, *10*, 2797. https://doi.org/10.3390/math10152797

Academic Editors: Jianping Gou, Weihua Ou, Shaoning Zeng and Lan Du

Received: 23 June 2022 Accepted: 4 August 2022 Published: 6 August 2022

**Publisher's Note:** MDPI stays neutral with regard to jurisdictional claims in published maps and institutional affiliations.

**Copyright:** © 2022 by the authors. Licensee MDPI, Basel, Switzerland. This article is an open access article distributed under the terms and conditions of the Creative Commons Attribution (CC BY) license (https://creativecommons.org/licenses/by/4.0/).

However, in practice, multiple source domains are often available, and SDA cannot exploit these source data adequately. Hence, the more challenging setting of Multi-source Domain Adaptation (MDA) has been developed to utilize labeled data from multiple source domains with different distributions, and it has attracted extensive attention recently [19–21]. The most straightforward approach is to combine all source domains into a single source domain and then directly apply SDA methods to align the distributions. Because the dataset is expanded, such methods may improve performance. However, the improvement may be limited, and more accurate ways of making full use of the source domains need to be explored.

With the rapid progress of DL and SDA, MDA has gradually developed. However, most techniques [22–28] share two typical issues. (1) First, learning a common domain-invariant representation for all domains is more challenging in MDA, because the harm caused by domain shift cannot be fully eliminated even in SDA. Therefore, MDA is handled here by aligning the domain-specific distribution of each source–target domain pair. (2) Second, multiple source domains are treated as equivalent. However, in reality, different source domains benefit the target task to different degrees, so the final output should be closer to the adaptation outputs of source–target domain pairs with higher credibility. Some studies [29,30] add extra neural network components to measure this credibility (i.e., transferability). In this research study, we employed Subjective Logic (SL) [31] to obtain the uncertainty of every source domain without adding any neural network components. Regarding source–target domain pairs as witnesses with different credibility (uncertainty), we introduced Dempster–Shafer evidence Theory (DST) to combine all domain-specific adaptation outputs.

As an uncertainty reasoning method, DST can effectively and reliably deal with uncertainty. It relies on Basic Probability Assignment Functions (BPAFs) to measure the initial degree of belief in the occurrence of an event, which is similar to the concept of the "probability" of a random event in probability theory. To generate BPAFs, DST is bridged with MDA and DL by subjective logic.

Our contributions are summarized as follows:


The rest of this paper is organized as follows. Section 2 reviews the related work. In Section 3, the preliminaries are given. Section 4 describes the proposed method in detail. A series of experiments is reported in Section 5 and discussed in Section 6. Finally, Section 7 summarizes this research study.

#### **2. Related Work**

#### *2.1. Single-Source Domain Adaptation*

Single-source Domain Adaptation (SDA) is closely related to multi-source domain adaptation. SDA aims to generalize a model learned from a labeled source domain to a related unlabeled target domain with a different data distribution by reducing the domain shift. SDA can be roughly divided into three categories according to the alignment strategy. (1) Discrepancy-based approaches use different metrics to explicitly measure the distance between the source and target domains and diminish the domain shift. Commonly used discrepancy metrics for domain adaptation include Maximum Mean Discrepancy (MMD) [32–34], moment matching [35,36], Kullback–Leibler (KL) divergence [37], correlation alignment [38,39], and mixture distance [40]. (2) Adversarial-based approaches align different data distributions by confusing a well-trained discriminator (domain classifier). Many methods [41–46] are based on Generative Adversarial Networks (GANs), which align different data distributions by implicitly learning the metric function (i.e., the domain discriminator) between the source and target domains. (3) Reconstruction-based approaches assume that reconstructing the target domain from a latent representation with the source task model helps learn domain-invariant representations. The reconstruction is usually obtained via an autoencoder [47–49] or a GAN discriminator [50–52].

In our work, we adopt the first kind of approach and employ MMD, the most widely used discrepancy metric, to align the distributions.

#### *2.2. Multi-Source Domain Adaptation*

In practice, available source data often come from several different, but related domains. Multi-source Domain Adaptation (MDA) has been developed to make full use of these data. Although data from multiple source domains provide more information, they also make handling domain shifts more challenging. (1) Based on the assumption that the target domain distribution can be approximated by a mixture of the source domain distributions [53,54], some MDA methods focus on the weighted combination of source domains. For example, Sun and Shi [22] designed a method to weight the source domain classifiers based on the Bayesian learning principle. Xu et al. [23] proposed a voting method for multiple classifiers based on the outputs of domain discriminators. (2) In addition, some methods map all source domains and the target domain to a unified feature space. For instance, MDAN [24] aligns the distributions of the source domains with the target domain through multiple domain discriminators. M3SDA [25] employs moment matching to align the source–target and source–source domains in a common feature space. HoMM [26] exploits high-order statistics for domain alignment in a reproducing kernel Hilbert space. (3) Some other methods are based on reconstruction [27,28]; they reconstruct multiple source domains into an intermediate single source domain and then directly carry out SDA.

Since the harm caused by domain shift cannot be fully eliminated even in SDA, learning a common domain-invariant representation for all domains in MDA is even more difficult. Following MFSAN [55], we adopt its domain-specific distribution and classifier alignment architecture for cross-domain classification. However, MFSAN treats every source domain equally. This is counter-intuitive because different source domains help the target task to different degrees. Thus, regarding source–target domain pairs as witnesses with different credibility (uncertainty), DST is employed to combine all domain-specific adaptation results. Specifically, the uncertainty is captured, and BPAFs are generated, by using subjective logic.

#### *2.3. Dempster–Shafer Evidence Theory*

Dempster–Shafer evidence Theory (DST) was first introduced in the 1960s. Based on the investigation of statistical problems, Arthur P. Dempster introduced the concept of upper and lower probabilities and their combining rules [56]. Then, the form of probability that does not satisfy additivity was defined for the first time [57]. Later, Glenn Shafer reinterpreted the upper and lower probabilities based on the belief function and developed the theory into a general framework for modeling epistemic uncertainty [58]. DST allows beliefs from different sources to be fused with various operators to obtain new beliefs considering all available evidence [59]. Currently, generating the belief function through DL has proven to be successful and efficient [60]. These unique characteristics make DST particularly suitable for information fusion [61,62]. Similar to information fusion, the idea of our MDA method is to combine evidence from multiple sources.

#### **3. Preliminaries**

#### *3.1. Unsupervised Multi-Source Domain Adaptation*

In this research study, the unsupervised MDA problem is investigated. Let $\mathcal{D}\_s = \lbrace \mathcal{D}\_{s\_i} \rbrace\_{i=1}^{N}$ denote a collection of $N$ available source-domain datasets, where each labeled source dataset $\mathcal{D}\_{s\_i} = \lbrace (X\_{s\_i}^{(j)}, y\_{s\_i}^{(j)}) \rbrace\_{j=1}^{n\_{s\_i}}$ with $n\_{s\_i}$ samples is sufficient to train a model of the source domain distribution. Meanwhile, the target dataset $\mathcal{D}\_t = \lbrace X\_t^{(j)} \rbrace\_{j=1}^{n\_t}$ with $n\_t$ samples drawn from the target domain has no labels to support training a reasonable distribution model. Given $\mathcal{D}\_s \cup \mathcal{D}\_t$, the general goal is to train a cross-domain classifier $f\_\theta(x)$ with a low target risk $\epsilon\_t = \mathbb{E}\_{x \in \mathcal{D}\_t}[f\_\theta(x) \neq y\_t]$.

The domain-specific distribution and classifier alignment architecture of MFSAN [55] is adopted for cross-domain classification. Accordingly, the domain adaptation model involves the source domain task loss L*s*, the domain adaptation loss L*d*, and the classifier constraint loss L*r*, combined as in (1), where *λ* and *γ* are trade-off parameters.

$$
\mathcal{L} = \mathcal{L}\_s + \lambda \mathcal{L}\_d + \gamma \mathcal{L}\_r \tag{1}
$$

#### *3.2. Maximum Mean Discrepancy*

Maximum Mean Discrepancy (MMD), inspired by the two-sample test in statistics [63,64], is the most widely used discrepancy for aligning distributions in domain adaptation. In general, MMD is interpreted as the maximum value (upper bound) of the difference in expectations between two distributions mapped by any function *f* in a predefined function class F, taken to be the unit ball (i.e., ‖*f*‖<sub>H</sub> ≤ 1) of a reproducing kernel Hilbert space H:

$$\text{MMD}[\mathcal{F}, p, q] := \sup\_{f \in \mathcal{F}} \left( \mathbf{E}\_p[f(\mathbf{x})] - \mathbf{E}\_q[f(y)] \right) \tag{2}$$

In practice, MMD is estimated by the squared distance between the empirical kernel mean embeddings, as in (3), where H is the Reproducing Kernel Hilbert Space (RKHS) endowed with a characteristic kernel *k*, defined by *k*(**x**<sub>*s*</sub>, **x**<sub>*t*</sub>) = ⟨*φ*(**x**<sub>*s*</sub>), *φ*(**x**<sub>*t*</sub>)⟩. Here, ⟨·, ·⟩ denotes the inner product, and *φ*(·) is a feature map from the original sample space to the RKHS H.

$$\text{MMD}^2[\mathcal{F}, \mathbf{X}\_s, \mathbf{X}\_t] = \left\| \frac{1}{n\_s} \sum\_{\mathbf{x}\_i \in \mathcal{D}\_s} \phi(\mathbf{x}\_i) - \frac{1}{n\_t} \sum\_{\mathbf{x}\_j \in \mathcal{D}\_t} \phi(\mathbf{x}\_j) \right\|\_{\mathcal{H}}^2 \tag{3}$$
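Expanding the squared norm in (3) with the kernel trick gives an estimator that needs only kernel evaluations: the mean of *k* over source–source pairs, plus the mean over target–target pairs, minus twice the mean over source–target pairs. The Python sketch below assumes a single Gaussian kernel with a fixed bandwidth `sigma`; the text only requires a characteristic kernel, so this kernel choice and the function names are illustrative assumptions. It computes the biased (V-statistic) estimate.

```python
import math
from typing import Sequence

Vector = Sequence[float]


def gaussian_kernel(x: Vector, y: Vector, sigma: float = 1.0) -> float:
    """k(x, y) = exp(-||x - y||^2 / (2 sigma^2)); the Gaussian kernel is characteristic."""
    sq_dist = sum((a - b) ** 2 for a, b in zip(x, y))
    return math.exp(-sq_dist / (2.0 * sigma ** 2))


def mmd2_biased(xs: Sequence[Vector], xt: Sequence[Vector], sigma: float = 1.0) -> float:
    """Biased empirical MMD^2 between source samples xs and target samples xt.

    ||mean phi(xs) - mean phi(xt)||^2 expands, via k = <phi, phi>, into
    mean k(xs, xs') + mean k(xt, xt') - 2 * mean k(xs, xt).
    """
    ns, nt = len(xs), len(xt)
    k_ss = sum(gaussian_kernel(a, b, sigma) for a in xs for b in xs) / ns ** 2
    k_tt = sum(gaussian_kernel(a, b, sigma) for a in xt for b in xt) / nt ** 2
    k_st = sum(gaussian_kernel(a, b, sigma) for a in xs for b in xt) / (ns * nt)
    return k_ss + k_tt - 2.0 * k_st


print(mmd2_biased([[0.0], [1.0]], [[0.0], [1.0]]))  # identical samples give 0
print(mmd2_biased([[0.0]], [[1.0]]))                # equals 2 - 2 * exp(-0.5)
```

The double loops are quadratic in the sample sizes, which is acceptable for a mini-batch illustration; a deep learning implementation would vectorize the kernel matrices.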

#### *3.3. Basic Concepts of DST*

The Basic Probability Assignment Function (BPAF) is the fundamental unit of DST and expresses the initial degree of belief in a proposition. Let Θ be a frame of discernment, which specifies the range of propositions. A function *m* : 2<sup>Θ</sup> → [0, 1] is called a BPAF if it satisfies (4). If *m*(*A*) > 0, *m*(*A*) is called a belief mass, and *A* is called a focal element.

$$\begin{cases} m(\varnothing) = 0 \\ \sum\_{A \subseteq \Theta} m(A) = 1 \end{cases} \tag{4}$$

Dempster's rule ⊕ is at the core of DST, as it provides algorithmic rules for combining two pieces of evidence, as shown in (5). Besides, Dempster's rule is invoked *N* − 1 times to combine *N* sets of evidence.

$$(m\_1 \oplus m\_2)(X) = \begin{cases} 0, & X = \varnothing \\ \frac{1}{1 - K} \sum\_{A\_i \cap B\_j = X} m\_1(A\_i) m\_2(B\_j), & X \neq \varnothing \end{cases} \tag{5}$$

The conflict factor *K*, defined in (6), reflects the degree of conflict between *m*<sub>1</sub> and *m*<sub>2</sub>, and 1/(1 − *K*) is the normalization factor. Dempster's rule thus fuses the parts on which different sources agree and discards the conflicting mass, redistributing it through this normalization.

$$K = \sum\_{A\_i \cap B\_j = \varnothing} m\_1(A\_i) m\_2\left(B\_j\right) \tag{6}$$
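Equations (5) and (6) translate almost line by line into code. In this minimal Python sketch, a BPAF is represented (our assumption) as a dictionary from focal elements, stored as `frozenset`s, to belief masses; fusing *N* sources applies the binary rule *N* − 1 times with `functools.reduce`, which is order-independent because Dempster's rule is commutative and associative. The class labels are purely illustrative.

```python
from functools import reduce
from itertools import product
from typing import Dict, FrozenSet, Iterable

Mass = Dict[FrozenSet[str], float]  # focal element -> belief mass


def dempster_combine(m1: Mass, m2: Mass) -> Mass:
    """Combine two BPAFs with Dempster's rule, Eqs. (5)-(6)."""
    combined: Mass = {}
    conflict = 0.0  # K: total mass falling on empty intersections
    for (a, ma), (b, mb) in product(m1.items(), m2.items()):
        inter = a & b
        if inter:
            combined[inter] = combined.get(inter, 0.0) + ma * mb
        else:
            conflict += ma * mb
    if conflict >= 1.0:
        raise ValueError("total conflict (K = 1): Dempster's rule is undefined")
    # Normalize by 1 / (1 - K) so that the fused masses sum to 1 again.
    return {x: v / (1.0 - conflict) for x, v in combined.items()}


def dempster_combine_all(masses: Iterable[Mass]) -> Mass:
    """Fuse N BPAFs by applying the binary rule N - 1 times."""
    return reduce(dempster_combine, masses)


theta = frozenset({"cat", "dog"})  # illustrative frame of discernment
m1 = {frozenset({"cat"}): 0.6, theta: 0.4}
m2 = {frozenset({"dog"}): 0.5, theta: 0.5}
print(dempster_combine(m1, m2))
```

Here K = 0.6 × 0.5 = 0.3, so the fused masses are 0.3/0.7 on {cat} and 0.2/0.7 each on {dog} and Θ.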

#### *3.4. Dirichlet Distribution*

The Dirichlet distribution is involved in SL, which bridges DL, MDA, and DST. In the context of multi-class classification, SL converts the outputs of the neural networks (from DL and MDA) into the concentration parameters of the Dirichlet distribution and associates them with the belief masses (for DST). Accordingly, DST can combine multi-source evidence once the BPAFs are obtained and output the final decision.

If the probability density function of a multivariate continuous random variable *θ* = {*θ*1, *θ*2, ..., *θk*} is given by (7):

$$p(\boldsymbol{\theta} \mid \boldsymbol{\alpha}) = \frac{\Gamma\left(\sum\_{i=1}^{k} \alpha\_{i}\right)}{\prod\_{i=1}^{k} \Gamma(\alpha\_{i})} \prod\_{i=1}^{k} \theta\_{i}^{\alpha\_{i}-1} \tag{7}$$

where ∑<sup>*k*</sup><sub>*i*=1</sub> *θ<sub>i</sub>* = 1, *θ<sub>i</sub>* ≥ 0, *α<sub>i</sub>* > 0, *i* = 1, 2, ..., *k*, and Γ(·) is the Gamma function. Then, the random variable *θ* is said to follow the Dirichlet distribution with concentration parameter *α*, denoted as *θ* ∼ *Dir*(*α*).
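For concreteness, the density (7) can be evaluated directly. This Python sketch works in log space with `math.lgamma` for numerical stability and, as a simplifying assumption, restricts *θ* to the interior of the simplex (all *θ<sub>i</sub>* > 0) to avoid evaluating log 0; the function name is illustrative.

```python
import math
from typing import Sequence


def dirichlet_pdf(theta: Sequence[float], alpha: Sequence[float]) -> float:
    """Density of Dir(alpha) at an interior point theta of the simplex, Eq. (7)."""
    if len(theta) != len(alpha):
        raise ValueError("theta and alpha must have the same length")
    if any(a <= 0 for a in alpha):
        raise ValueError("concentration parameters must be positive")
    if any(t <= 0 for t in theta) or not math.isclose(sum(theta), 1.0):
        raise ValueError("theta must lie in the interior of the simplex")
    # log Gamma(sum a) - sum log Gamma(a_i): the normalizing constant
    log_norm = math.lgamma(sum(alpha)) - sum(math.lgamma(a) for a in alpha)
    # sum (a_i - 1) log theta_i: the unnormalized kernel
    log_kernel = sum((a - 1.0) * math.log(t) for t, a in zip(theta, alpha))
    return math.exp(log_norm + log_kernel)


print(dirichlet_pdf([1 / 3, 1 / 3, 1 / 3], [1.0, 1.0, 1.0]))  # uniform case: Gamma(3) = 2
print(dirichlet_pdf([0.25, 0.75], [2.0, 1.0]))                # Beta(2, 1) at 0.25: 2 * 0.25
```

With *α* = (1, 1, 1), the density is constant (equal to 2) on the simplex, which corresponds to the flat, uninformative case.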

The Dirichlet distribution of *θ* is supported on the (*k* − 1)-dimensional simplex, as shown in Figure 1.

**Figure 1.** Visualization of Dirichlet distribution, where *θ* = {*θ*1, *θ*2, *θ*3} and *θ*1, *θ*2, *θ*<sup>3</sup> ≥ 0, *θ*<sup>1</sup> + *θ*<sup>2</sup> + *θ*<sup>3</sup> = 1. (**a**) *α* = (10, 1, 1); (**b**) *α* = (1.001, 1.001, 1.001); (**c**) *α* = (10, 10, 10). Bright yellow represents high probability, and dark blue represents low probability. In the multi-classification problem, each vertex is regarded as a category.

The most important property of the Dirichlet distribution is that it is the conjugate prior of the multinomial distribution. If the prior distribution of *θ* is *p*(*θ*|*α*) = *Dir*(*θ*|*α*), then, after observing data *D*, the posterior distribution is *p*(*θ*|*D*, *α*) = *Dir*(*θ*|*α* + *n*), where *n* = (*n*1, *n*2, ..., *nk*) is the vector of observation counts of the multinomial distribution. The concentration parameters *α* = {*α*1, *α*2, ..., *αk*} of the prior are also called hyperparameters. Hence, the posterior distribution is conveniently obtained from the prior by adding the observed counts to the concentration parameters.
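The conjugate update therefore reduces to vector addition, and the posterior mean follows from the standard identity E[*θ<sub>k</sub>*] = *α<sub>k</sub>*/∑*α*. A minimal Python sketch with illustrative function names:

```python
from typing import List, Sequence


def dirichlet_posterior(alpha: Sequence[float], counts: Sequence[int]) -> List[float]:
    """Conjugate update: Dir(alpha) prior + multinomial counts n -> Dir(alpha + n)."""
    if len(alpha) != len(counts):
        raise ValueError("alpha and counts must have the same length")
    return [a + n for a, n in zip(alpha, counts)]


def dirichlet_mean(alpha: Sequence[float]) -> List[float]:
    """Mean of Dir(alpha): E[theta_k] = alpha_k / sum(alpha)."""
    s = sum(alpha)
    return [a / s for a in alpha]


post = dirichlet_posterior([1.0, 1.0, 1.0], [3, 0, 1])
print(post)                  # [4.0, 1.0, 2.0]
print(dirichlet_mean(post))  # 4/7, 1/7, 2/7
```

Starting from the flat prior *α* = (1, 1, 1), three observations of class 1 and one of class 3 shift the posterior mean toward class 1 without ever assigning zero probability to the unobserved class 2.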

#### **4. Research Methodology**

Following the two-stage alignment framework of MFSAN [55], a novel Multi-source domain Adaptation Network with Dempster–Shafer evidence theory (MAN-DS) for cross-domain classification is proposed. MAN-DS is trained on labeled samples from multiple source domains and adapts to classify target instances drawn from different distributions. As shown in Figure 2, the MAN-DS framework consists of four key components, i.e., a common feature extractor, domain-specific feature extractors, domain-specific classifiers, and Dempster's combination. Different source domains are mapped into different feature spaces; then, the distributions of each source–target domain pair and the outputs of all source classifiers are aligned. Finally, the domain-specific adaptation outputs are combined by Dempster's rule. In addition, the *softmax* layer of each classifier is replaced with an activation layer (e.g., ReLU).

**Figure 2.** The overall structure of MAN-DS.

#### *4.1. Common Feature Extractor*

Because the harm caused by domain shift cannot be fully eliminated even in SDA, learning a common domain-invariant representation for all domains in MDA is even more difficult. The simplest way to address this problem is to train multiple networks, each mapping one source–target domain pair into a specific feature space. However, this would cost too much time and memory. Thus, the feature extractor is divided into two parts: the first part extracts common features, and the second part extracts domain-specific features (see the next section). In the first part, a common convolutional subnetwork *f*(·) maps samples from all domains from the original feature space into a common feature space.

#### *4.2. Domain-Specific Feature Extractor*

Now, we come to the second part where domain-specific features are extracted by different extractors. For each pair of source and target domains, a specific subnetwork *hi*(·) aims to map *f*(*xsi*) and *f*(*xt*) into the same domain-specific feature space. The objective of domain adaptation is to find a domain-invariant representation between domains. In other words, an *hi*(·) is desired, which makes the distribution discrepancy between *hi*(*f*(*xsi*)) and *hi*(*f*(*xt*)) as small as possible. There are many explicit or implicit methods to achieve this goal. Here, the most widely used MMD is employed to reduce the distribution discrepancy between domains. The MMD loss is reformulated as:

$$\mathcal{L}\_{mmd} = \frac{1}{N} \sum\_{i=1}^{N} \text{MMD}^2[\mathcal{F}, h\_i(f(X\_{s\_i})), h\_i(f(X\_t))] \tag{8}$$

#### *4.3. Domain-Specific Classifier*

Traditionally, a series of *softmax* classifiers *c<sub>i</sub>*(·) is employed to classify the source domain samples after the domain-specific invariant features are extracted. However, the exponential in the *softmax* function tends to inflate the probability of the predicted category. In this research study, it is therefore replaced with an activation function (e.g., ReLU) so that the network outputs non-negative values. The multi-class classification problem can be viewed as fitting a multinomial distribution. As its conjugate prior, the Dirichlet distribution makes it convenient to obtain the posterior distribution from the prior distribution.

Subjective logic (SL) [31] defines a theoretical framework for obtaining the probabilities of different classes and the overall uncertainty of a multi-class classification problem from the *evidence* collected from the data. SL assigns an additional mass to the overall uncertainty, which allows the model to express a lack of evidence explicitly. In our model, SL provides the degree of overall uncertainty of each source, which matters for the final decision.

For the *K*-class classification problem, the non-negative activated output $e = (e\_1, e\_2, \ldots, e\_K)$ of the last fully connected layer of the classifier is regarded as *evidence* and is closely related to the concentration parameters $\alpha = (\alpha\_1, \alpha\_2, \ldots, \alpha\_K)$ of the Dirichlet distribution, as follows:

$$
\alpha\_k = e\_k + 1, \quad k = 1, 2, \dots, K \tag{9}
$$

With subjective logic, for each source–target domain pair *i*, the belief mass $b\_k^{(i)}$ of the *k*th category and the overall uncertainty $u^{(i)}$ are calculated by:

$$\begin{aligned} b\_k^{(i)} &= \frac{e\_k^{(i)}}{S^{(i)}} = \frac{\alpha\_k^{(i)} - 1}{S^{(i)}}\\ u^{(i)} &= \frac{K}{S^{(i)}} \end{aligned} \tag{10}$$

where $S^{(i)} = \sum\_{k=1}^{K}(e\_k^{(i)} + 1) = \sum\_{k=1}^{K}\alpha\_k^{(i)}$ is the Dirichlet strength. Because $\sum\_{k=1}^{K} e\_k^{(i)} = S^{(i)} - K$, it follows that $u^{(i)} + \sum\_{k=1}^{K} b\_k^{(i)} = 1$. Correspondingly, the less total evidence is observed, the greater the overall uncertainty. The mean of the corresponding Dirichlet distribution $\hat{P}\_{s\_i}$ gives the expected probability of the *k*th category as $\hat{p}\_i^{(k)} = \alpha\_k^{(i)} / S^{(i)}$.
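To make Eqs. (9) and (10) concrete, the sketch below maps a non-negative evidence vector (for instance, a ReLU or softplus output) to the Dirichlet parameters, belief masses, overall uncertainty, and expected class probabilities. The function name is hypothetical.

```python
import numpy as np


def opinion_from_evidence(evidence):
    """Subjective-logic opinion from non-negative evidence e of shape [K].

    Returns (b, u, alpha, p_hat): belief masses, overall uncertainty,
    Dirichlet parameters, and the Dirichlet mean alpha / S.
    """
    e = np.asarray(evidence, dtype=float)
    K = e.shape[-1]
    alpha = e + 1.0              # Eq. (9): alpha_k = e_k + 1
    S = alpha.sum()              # Dirichlet strength
    b = e / S                    # Eq. (10): belief masses
    u = K / S                    # Eq. (10): overall uncertainty
    p_hat = alpha / S            # expected class probabilities
    return b, u, alpha, p_hat


# Strong evidence for the first class: b = [0.75, 0, 0] and u = 0.25
b, u, alpha, p_hat = opinion_from_evidence([9.0, 0.0, 0.0])
```

With no evidence at all, the same function returns *u* = 1, i.e., complete uncertainty, and the belief masses plus *u* always sum to one.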

In addition, Figure 3 details how the outputs of multiple domain-specific classifiers are processed. The evidence of each source is obtained by a neural network (Step ①). Following subjective logic [31], the evidence parameterizes a Dirichlet distribution (Step ②), which induces the classification probabilities and the uncertainty (Step ③). The final classification probabilities and overall uncertainty are then inferred by combining the belief masses of the multiple sources with Dempster's rule (Step ④). Dempster's combination is discussed in Section 4.4.

**Figure 3.** The process of combining the outputs of multiple domain-specific classifiers.

The source domain task loss $\mathcal{L}\_{cls}$ is calculated here. To adapt it to the Dirichlet distribution [65], the cross-entropy loss is formulated as (11):

$$\begin{split} \mathcal{L}\_{acc}\left(\alpha\_{j}\right) &= \int \left[ \sum\_{k=1}^{K} -y\_{jk}\log\left(p\_{jk}\right) \right] \frac{1}{B\left(\alpha\_{j}\right)} \prod\_{k=1}^{K} p\_{jk}^{\alpha\_{jk}^{(i)}-1} d\mathbf{p}\_{j} \\ &= \sum\_{k=1}^{K} y\_{jk} \left(\psi\left(S^{(i)}\right) - \psi\left(\alpha\_{jk}^{(i)}\right)\right) \end{split} \tag{11}$$

where $\psi(\cdot)$ is the digamma function and $B(\cdot)$ is the multivariate beta function; $\alpha\_j$ is the parameter of the Dirichlet distribution that forms the multinomial opinion $D(\mathbf{p}\_j \mid \alpha\_j)$, where $\mathbf{p}\_j$ denotes the category assignment probabilities on a simplex and $p\_{jk}$ is the predicted probability of the *j*th sample for category *k*.
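The closed form on the right of Eq. (11) needs only digamma functions, so it is cheap to evaluate. The sketch below assumes SciPy is available and uses a hypothetical function name; it computes the per-sample loss and checks it against a Monte Carlo estimate of the integral on the left. For a true class with α = 5 and strength *S* = 8, the loss equals ψ(8) − ψ(5) = 1/5 + 1/6 + 1/7.

```python
import numpy as np
from scipy.special import digamma


def dirichlet_cross_entropy(alpha, y):
    """Per-sample loss of Eq. (11): sum_k y_jk * (psi(S_j) - psi(alpha_jk)).

    alpha: [n, K] Dirichlet parameters (alpha = evidence + 1).
    y:     [n, K] one-hot labels.
    """
    alpha = np.asarray(alpha, dtype=float)
    S = alpha.sum(axis=1, keepdims=True)         # Dirichlet strength per sample
    return np.sum(y * (digamma(S) - digamma(alpha)), axis=1)


# Monte Carlo check of the integral form: E_{p ~ Dir(alpha)}[-log p_true]
rng = np.random.default_rng(0)
alpha = np.array([[5.0, 2.0, 1.0]])
y = np.array([[1.0, 0.0, 0.0]])
closed_form = dirichlet_cross_entropy(alpha, y)[0]
samples = rng.dirichlet(alpha[0], size=200_000)
monte_carlo = -np.log(samples[:, 0]).mean()      # should be close to closed_form
```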

The above loss ensures that more evidence is generated for the correct label of each sample than for the other classes, but it does not guarantee that less evidence is generated for incorrect labels. In MAN-DS, the evidence of incorrect labels is expected to shrink to 0 [66]. To this end, the following KL divergence term is introduced:

$$\begin{split} KL\left[D\left(\mathbf{p}\_{j}\mid\bar{\alpha}\_{j}\right)\Vert D\left(\mathbf{p}\_{j}\mid\mathbf{1}\right)\right] &= \log\left(\frac{\Gamma\left(\sum\_{k=1}^{K}\bar{\alpha}\_{jk}\right)}{\Gamma(K)\prod\_{k=1}^{K}\Gamma\left(\bar{\alpha}\_{jk}\right)}\right) \\ &+ \sum\_{k=1}^{K} \left(\bar{\alpha}\_{jk} - 1\right) \left[\psi\left(\bar{\alpha}\_{jk}\right) - \psi\left(\sum\_{r=1}^{K}\bar{\alpha}\_{jr}\right)\right] \end{split} \tag{12}$$

Therefore, given the Dirichlet parameter $\alpha\_j$ for each sample *j*, the loss is:

$$\mathcal{L}\left(\alpha^{(i)}\right) = \sum\_{j=1}^{n\_{s\_i}} \mathcal{L}\left(\alpha\_{j}\right) = \sum\_{j=1}^{n\_{s\_i}} \left\{ \mathcal{L}\_{acc}\left(\alpha\_{j}\right) + \rho KL\left[D\left(\mathbf{p}\_{j} \mid \bar{\alpha}\_{j}\right) \| D\left(\mathbf{p}\_{j} \mid \mathbf{1}\right)\right] \right\} \tag{13}$$

where *ρ* > 0 is a balance factor. In practice, *ρ* is increased gradually from 0 to 1 so that the KL divergence term does not dominate in the early stage of learning.
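A sketch of the regularized per-source loss of Eqs. (12) and (13) follows. It assumes, as in evidential deep learning, that the KL term acts on modified parameters ᾱ<sub>*j*</sub> = *y*<sub>*j*</sub> + (1 − *y*<sub>*j*</sub>) ⊙ α<sub>*j*</sub>, which keep only the evidence assigned to incorrect labels. This construction, the linear warm-up used for *ρ*, and all function names are assumptions rather than details given in the text.

```python
import numpy as np
from scipy.special import digamma, gammaln


def kl_to_uniform_dirichlet(alpha_bar):
    """Eq. (12): KL[Dir(p | alpha_bar) || Dir(p | 1)] for each row of alpha_bar [n, K]."""
    a = np.asarray(alpha_bar, dtype=float)
    K = a.shape[1]
    S = a.sum(axis=1)
    log_norm = gammaln(S) - gammaln(K) - gammaln(a).sum(axis=1)
    digamma_term = ((a - 1.0) * (digamma(a) - digamma(S)[:, None])).sum(axis=1)
    return log_norm + digamma_term


def evidential_loss(alpha, y, rho):
    """Eq. (13) for one source domain: sum over samples of CE + rho * KL."""
    alpha = np.asarray(alpha, dtype=float)
    S = alpha.sum(axis=1, keepdims=True)
    ce = np.sum(y * (digamma(S) - digamma(alpha)), axis=1)     # Eq. (11)
    # Assumption: drop the true-class evidence before regularizing, so that
    # only evidence assigned to incorrect labels is pushed toward zero.
    alpha_bar = y + (1.0 - y) * alpha
    return float(np.sum(ce + rho * kl_to_uniform_dirichlet(alpha_bar)))


def kl_weight(step, warmup_steps):
    """Hypothetical linear schedule that raises rho from 0 to 1."""
    return min(1.0, step / warmup_steps)
```

With evidence only on the correct class, ᾱ<sub>*j*</sub> becomes the all-ones vector and the KL term vanishes; evidence on a wrong class makes it positive. Summing `evidential_loss` over the *N* sources gives the classification loss below.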

Finally, the classification loss over all *N* source domains is formulated as:

$$\mathcal{L}\_{cls} = \sum\_{i=1}^{N} \mathcal{L}\left(\alpha^{(i)}\right) \tag{14}$$

#### *4.4. Dempster's Combination*

With subjective logic, each source–target domain pair has a frame of discernment (FoD) Θ = {1, 2, ..., *K*} and *K* + 1 focal elements {{1}, {2}, ..., {*K*}, Θ} with belief masses {*b*<sub>1</sub>, *b*<sub>2</sub>, ..., *b*<sub>*K*</sub>, *u*}. To fuse the adaptation outputs of *N* sources, Dempster's rule (defined in (5)) only needs to be applied *N* − 1 times:

$$m\_{\oplus}(b\_k) = m\_1(b\_k) \oplus m\_2(b\_k) \oplus \dots \oplus m\_{N}(b\_k) \tag{15}$$
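Because each opinion has only the *K* singleton focal elements plus Θ, Dempster's rule reduces to a few vector operations, and the conflict term can be computed in *O*(*K*) time. Eq. (5) is not reproduced in this section, so the reduced form below is our assumption of how the rule specializes to these focal elements; the function names are hypothetical.

```python
import numpy as np


def dempster_combine(b1, u1, b2, u2):
    """Combine two opinions over focal elements {1}, ..., {K} and Theta.

    b1, b2: belief masses of shape [K]; u1, u2: masses on Theta (uncertainty).
    """
    b1 = np.asarray(b1, dtype=float)
    b2 = np.asarray(b2, dtype=float)
    # Conflict: total mass on pairs of different singletons, sum_{i != j} b1_i * b2_j
    conflict = b1.sum() * b2.sum() - np.dot(b1, b2)
    scale = 1.0 / (1.0 - conflict)
    b = scale * (b1 * b2 + b1 * u2 + b2 * u1)   # agreement, or one side undecided
    u = scale * u1 * u2                          # both sides undecided
    return b, u


def combine_sources(beliefs, uncertainties):
    """Eq. (15): fuse N opinions with N - 1 pairwise applications of the rule."""
    b, u = beliefs[0], uncertainties[0]
    for b_i, u_i in zip(beliefs[1:], uncertainties[1:]):
        b, u = dempster_combine(b, u, b_i, u_i)
    return b, u
```

Combining any opinion with a vacuous one (all beliefs 0, *u* = 1) returns the original opinion unchanged, and sources that agree drive the fused uncertainty down.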

In addition, the prediction results of multiple classifiers for the same target sample should be consistent. Dempster's combination could help to avoid ambiguity and large uncertainty on the category boundary, which is demonstrated in Figure 4.

Moreover, the Manhattan distance is used to measure the discrepancy among the classifiers toward the same goal. Denote $e^{(i)} = [e\_1^{(i)}, e\_2^{(i)}, \ldots, e\_K^{(i)}]$, with $e^{(i)} = \alpha^{(i)} - 1 = b^{(i)} S^{(i)}$, as the final output of the *i*th source–target domain pair. The Manhattan distance loss on the labels is formulated as:

$$\mathcal{L}\_{dist} = \frac{1}{N} \sum\_{i=1}^{N} \left\| e^{(i)} - m\_{\oplus}(e) \right\|\_1 \tag{16}$$
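Eq. (16) compares each source's evidence with the fused output. The sketch assumes that the fused opinion (*b*, *u*) is converted back to evidence through *S* = *K*/*u* and *e* = *bS*, which follows from inverting Eq. (10); this conversion and the helper names are our assumptions.

```python
import numpy as np


def evidence_from_opinion(b, u):
    """Invert Eq. (10): S = K / u and e = b * S (hypothetical helper)."""
    b = np.asarray(b, dtype=float)
    return b * (b.shape[-1] / u)


def manhattan_consistency_loss(source_evidence, fused_evidence):
    """Eq. (16): mean over N sources of the L1 distance to the fused evidence.

    source_evidence: [N, K]; fused_evidence: [K].
    """
    diff = np.asarray(source_evidence, dtype=float) - np.asarray(fused_evidence, dtype=float)
    return float(np.abs(diff).sum(axis=1).mean())
```

For instance, the opinion *b* = (0.75, 0, 0), *u* = 0.25 maps back to the evidence (9, 0, 0).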

#### *4.5. Objective Function and Algorithm*

The overall objective function of the proposed model is formulated as (17).

$$\underset{f,h,c}{\text{arg min}} (\mathcal{L}\_{cls} + \gamma \mathcal{L}\_{mmd} + \lambda \mathcal{L}\_{dist}) \tag{17}$$

In detail, $\mathcal{L}\_{cls}$ is minimized to accomplish the source domain task; $\mathcal{L}\_{mmd}$ is minimized to reduce the domain shift between each source domain and the target domain; and $\mathcal{L}\_{dist}$ is a consistency regularization term that is minimized to constrain the outputs of the domain-specific classifiers. In addition, *γ* and *λ* are trade-off parameters; refer to (1).

**Figure 4.** Demonstration of the prediction conflict among domain-specific classifiers.

The algorithm of MAN-DS is summarized in Algorithm 1; it can be trained with standard back-propagation.

#### **Algorithm 1** The algorithm of the proposed method

**Input:** source domain data $\mathcal{D}\_{s\_i}$ (*i* = 1, ..., *N*), target domain data $\mathcal{D}\_t$, the number of training iterations *T*, and batch size *M*;

**Output:** model parameters;


```
12: end for
```
#### **5. Experiment**

The effectiveness of our cross-domain classification method was verified by conducting comprehensive experiments on three well-known benchmarks: **ImageCLEF-DA**, **Office-31**, and **Office-Home**.

#### *5.1. Data Preparation*

**ImageCLEF-DA** [67] is a benchmark dataset for the ImageCLEF 2014 domain adaptation challenge. It is built from the 12 common categories shared by three public datasets, each of which is treated as a domain: Caltech-256 (**C**), ImageNet ILSVRC 2012 (**I**), and Pascal VOC 2012 (**P**). Each category contains 50 images, giving 600 images per domain. All domain combinations were used, and three transfer tasks were built: **C**, **I** → **P**; **C**, **P** → **I**; **I**, **P** → **C**.

**Office-31** [68] is a domain adaptation benchmark comprising 4110 images in 31 classes collected from three distinct domains: Amazon (**A**), which contains images downloaded from amazon.com, and Webcam (**W**) and DSLR (**D**), which contain images taken by a web camera and a digital SLR camera, respectively, under different photographic settings. The number of images is imbalanced across domains. To enable unbiased evaluation, all methods were evaluated on all three transfer tasks: **A**, **D** → **W**; **A**, **W** → **D**; **W**, **D** → **A**.

**Office-Home** [69] consists of 15,588 images, making it larger than Office-31 and ImageCLEF-DA. It contains images from four domains: Artistic images (**A**), Clip Art (**C**), Product images (**P**), and Real-World images (**R**). For each domain, the dataset contains images of 65 object categories collected in office and home settings. All domain combinations were used, and four transfer tasks were built: **A**, **P**, **R** → **C**; **A**, **P**, **C** → **R**; **A**, **R**, **C** → **P**; **P**, **R**, **C** → **A**.

#### *5.2. Compared Method*

Only a few MDA methods are based on a domain-specific distribution and classifier alignment architecture. To verify the effectiveness of MAN-DS, the Multiple Feature Spaces Adaptation Network (MFSAN) [55] was used as the multi-source baseline. In addition, the proposed method was compared with ResNet [70], Deep Domain Confusion (DDC) [14], the Deep Adaptation Network (DAN) [71], Deep CORAL (DCORAL) [72], and Reverse Gradient (RevGrad) [73].

Three comparison standards serve different purposes: (1) **Source combine**: all source domains are combined into a traditional single-source vs. target setting; (2) **Single best**: the best single-source transfer result among the candidate source domains, obtained with SDA methods; (3) **Multi-source**: the results of MDA methods. The first standard verifies whether multiple source domains benefit the target task or whether simply combining the source domains leads to negative transfer. The second standard evaluates whether the best SDA result can be further improved by introducing other source domains. The third standard demonstrates the effectiveness of the proposed approach.

Furthermore, ablation experiments were performed to verify the effectiveness of DST for fusing the adaptation outputs. Variant *V*<sub>1</sub> simply averages the outputs at the end instead of fusing them with DST. In addition, variant *V*<sub>2</sub> omits $\mathcal{L}\_{mmd}$, and variant *V*<sub>3</sub> omits $\mathcal{L}\_{dist}$.

#### *5.3. Implementation Details*

All methods were implemented in the PyTorch framework and deployed and tested on the same device. For a fair comparison, the same data pre-processing routines and model architecture were used in all experiments. The pre-trained ResNet50 [70] was employed as the common feature extractor and fine-tuned to save training time. All domain-specific feature extractors share the same structure (*conv*(1 × 1), *conv*(3 × 3), *conv*(1 × 1)). At the end of the network, the number of channels is reduced to 256, as in DDC [14]. Following subjective logic, the softmax layer was replaced with softplus to activate the outputs and avoid negative values. The optimizer was mini-batch stochastic gradient descent with a momentum of 0.9. The learning rate was gradually decreased as $\eta\_p = \eta\_0 / (1 + \alpha p)^{\beta}$, where *p* is the training progress, which changes linearly from 0 to 1, and $\eta\_0 = 0.01$, $\alpha = 10$, $\beta = 0.75$. This schedule promotes convergence and a low error on the source domain. For simplicity, the hyperparameters were set as *γ* = *ρ* = 100*λ*. Instead of being fixed throughout the experiments, they were increased from 0 to 1 by the progressive schedule $\gamma\_p = 2 / (1 + \exp(-\theta p)) - 1$ (*θ* = 10).
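The learning-rate annealing and the progressive trade-off schedule can be written as small functions of the training progress *p* ∈ [0, 1]. This sketch uses the stated constants (η<sub>0</sub> = 0.01, α = 10, β = 0.75, θ = 10); the function names and the loop that derives *λ* from *γ* are illustrative assumptions.

```python
import math


def learning_rate(p, eta0=0.01, alpha=10.0, beta=0.75):
    """Annealed learning rate eta_p = eta0 / (1 + alpha * p) ** beta."""
    return eta0 / (1.0 + alpha * p) ** beta


def tradeoff_weight(p, theta=10.0):
    """Progressive weight 2 / (1 + exp(-theta * p)) - 1, rising from 0 toward 1."""
    return 2.0 / (1.0 + math.exp(-theta * p)) - 1.0


for p in (0.0, 0.1, 0.5, 1.0):
    gamma = rho = tradeoff_weight(p)   # gamma = rho
    lam = gamma / 100.0                # gamma = 100 * lambda
    print(f"p={p:.1f}  lr={learning_rate(p):.5f}  gamma=rho={gamma:.4f}  lambda={lam:.6f}")
```

The sigmoid-shaped schedule keeps the adaptation and regularization terms small while the randomly initialized domain-specific layers are still noisy, then lets them approach their final weights.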

#### *5.4. Experimental Results*

MAN-DS was compared with the above-mentioned methods on three datasets, and the average results of five repeated experiments are reported in Tables 1–3, respectively. The maximum accuracy in a transfer task is marked in bold.


**Table 1.** Performance comparison of classification accuracy (%) on Office-31 dataset.

**Table 2.** Performance comparison of classification accuracy (%) on Image-CLEF dataset.


**Table 3.** Performance comparison of classification accuracy (%) on Office-Home dataset.


#### **6. Discussion**

#### *6.1. Result Observations*

From these experimental results, insightful observations are given:


#### *6.2. Ablation Experiment*

Ablation experiments were conducted with variants *V*<sub>1</sub>, *V*<sub>2</sub>, and *V*<sub>3</sub>, as shown in Tables 1–3. The results show that every component of MAN-DS contributes to improving performance.

To further verify the effectiveness of the DST fusion strategy, supplementary experiments were carried out, in which *S*<sub>*i*</sub> denotes the *i*th domain-specific classifier; the results are reported in Table 4. The maximum accuracy in each transfer task is marked in bold.


**Table 4.** Classification accuracy (%) with and without DST fusion strategy on Office-Home dataset.

#### *6.3. Feature Visualization*

Figure 5 presents the feature visualization. The category boundaries of the domain-specific classifiers learned by MAN-DS and MFSAN on task **D**,**W**→**A** are visualized with t-SNE embeddings. The visualization indicates that MAN-DS handles prediction conflicts more effectively, which suggests that the DST fusion is effective.

**Figure 5.** Domain-specific classifier feature visualization.

#### *6.4. Parameter Sensitivity*

Parameter sensitivity was tested by sampling the trade-off parameter (with *γ* = *ρ* = 100*λ* for simplicity) from {0.01, 0.02, 0.05, 0.1, 0.2, 0.5, 1, 2}. The experiments were conducted on tasks **D**,**W**→**A** and **A**,**C**→**R**, and the results are shown in Figure 6. As observed, the accuracy increases with *γ*, peaks at *γ* = 1, and then decreases. For *γ* in the range (0.1, 2), MAN-DS keeps relatively stable results that are higher than the baseline. In general, MAN-DS is not sensitive to parameter changes within a certain range, so setting *γ* within (0.1, 2) is recommended. In the reported experiments, the parameters {*γ*, *ρ*, *λ*} were set to {1, 1, 0.01}, respectively.

**Figure 6.** Accuracy with respect to *γ* = *ρ* = 100*λ*.

#### *6.5. Computational Complexity*

Floating point operations (FLOPs) were used to measure the number of operations in the forward propagation of the neural network; the fewer the FLOPs, the faster the computation. Likewise, the fewer the parameters (PARAMs) in the network, the smaller the model. Table 5 shows the FLOPs and PARAMs of MAN-DS, MFSAN, and ResNet50. Compared with ResNet50, the small increase in computational complexity comes mainly from the domain-specific feature extractors and classifiers. Compared with the baseline MFSAN, MAN-DS improves the accuracy without increasing the computational complexity.

**Table 5.** FLOPs and PARAMs.


Moreover, Dempster's combination does not increase the computational complexity of the algorithm. For a *K*-class classification task, MAN-DS always has *K* + 1 focal elements, {{1}, {2}, ..., {*K*}, Θ}, instead of 2<sup>*K*</sup>. That is, the computational complexity of Dempster's combination is *O*(*K*) rather than *O*(2<sup>*K*</sup>).

#### **7. Conclusions**

The core of MDA is making full use of the available source data collected from several different but related domains. However, the multiple domain shifts make this difficult and challenging. Following the domain-specific alignment architecture, this study proposed a novel multi-source domain adaptation network combining Dempster–Shafer evidence theory for cross-domain image classification, which reduces multiple domain shifts and enhances transfer accuracy. In addition, SL and the Dirichlet distribution were employed to bridge MDA with DST.

To evaluate the effectiveness of the proposed method, three popular benchmark datasets were used, and ten transfer tasks were devised to train and validate MAN-DS. Extensive experiments demonstrated that MAN-DS outperforms its competitors in cross-domain image classification. The main conclusions are as follows:


In this research study, the original and unimproved Dempster's rule was used. In the future, the combination rules will be optimized based on the improved information entropy method to take more evidence information into account. Besides, more effective MDA and DST bridging methods will be investigated.

**Author Contributions:** Conceptualization, M.H. and C.Z.; methodology, M.H. and C.Z.; software, C.Z.; validation, M.H.; formal analysis, M.H.; investigation, M.H. and C.Z.; resources, M.H.; data curation, M.H.; writing—original draft preparation, C.Z.; writing—review and editing, M.H. and C.Z.; visualization, C.Z.; supervision, M.H.; project administration, M.H.; funding acquisition, M.H. All authors have read and agreed to the published version of the manuscript.

**Funding:** The work described in this paper was partially funded by two Natural Science Foundation of Guangdong Province projects (Grant Numbers 2021A1515011496 and 2022A1515011370).

**Institutional Review Board Statement:** Not applicable.

**Informed Consent Statement:** Not applicable.

**Data Availability Statement:** Not applicable.

**Acknowledgments:** The authors would like to thank the Editor and the Reviewers for their valuable comments and suggestions.

**Conflicts of Interest:** The authors declare no conflict of interest.

#### **Abbreviations**

The following abbreviations are used in this manuscript:


#### **References**


### *Article* **An Improved Matting-SfM Algorithm for 3D Reconstruction of Self-Rotating Objects**

**Zinuo Li †, Zhen Zhang \*,†, Shenghong Luo, Yuxing Cai and Shuna Guo**

† These authors contributed equally to this work.

**Abstract:** 3D reconstruction can be performed accurately in most cases by combining the structure from motion (SfM) algorithm with a multi-view stereo (MVS) framework, using a video recorded around the object. However, the camera must be held by hand, and the recording must be kept as stable as possible to obtain good accuracy. To eliminate the inaccuracy caused by shaking during recording, we fixed the camera on a stand and placed the object on a motorized turntable. In this case, however, the background does not change while the camera stays still, and the many feature points from the background are useless for 3D reconstruction, causing the reconstruction of the target object to fail. To solve this problem, we performed video segmentation based on background matting to separate the object from the background, so that the original background no longer affects the 3D reconstruction. Using frames extracted from the background-free video as the input to the 3D reconstruction system, we obtained an accurate 3D reconstruction of an object that could not be reconstructed originally, with the PSNR and SSIM increased to 11.51 and 0.26, respectively. The results show that this algorithm can be applied to the display of online merchandise, providing merchants with an easy way to obtain an accurate model.

**Keywords:** 3D reconstruction; multi-view stereo; structure from motion; background matting

**MSC:** 68T45

#### **1. Introduction**

3D reconstruction refers to building mathematical models of three-dimensional objects that are suitable for computer processing; it is the foundation for processing, manipulating, and analyzing the properties of these objects in a computer environment and a key technology for building virtual realities that express the objective world in a computer. 3D reconstruction is generally vision-based: a 3D model of the target object is obtained by collecting images with a camera and computing 3D coordinates according to the triangulation principle [1]. 3D reconstruction plays a significant role in object recognition, scene understanding, 3D modeling and animation, industrial control, etc. [2]. The development of deep learning in recent years has also brought new impact and convenience to 3D reconstruction [3].

Applications of 3D reconstruction have appeared in many fields, such as the real-time reconstruction of nearby scenes in robot navigation and mobile robots [4], research on tumors in the medical field [5], and the reconstruction of artifacts at tourist attractions [6]. As computer vision has developed, 3D reconstruction techniques have become increasingly mature and many methods have been proposed, but some problems remain. For example, 3D reconstruction can fail in weakly textured regions and highlight regions. Both problems arise because these regions contain a lot of similar RGB information, which leads to failures in feature extraction and feature point matching. In addition, in some scenarios, a constant background or an overly complex background can significantly affect 3D reconstruction.

**Citation:** Li, Z.; Zhang, Z.; Luo, S.; Cai, Y.; Guo, S. An Improved Matting-SfM Algorithm for 3D Reconstruction of Self-Rotating Objects. *Mathematics* **2022**, *10*, 2892. https://doi.org/10.3390/math10162892

Academic Editor: Jakub Nalepa

Received: 26 June 2022 Accepted: 11 August 2022 Published: 12 August 2022

**Publisher's Note:** MDPI stays neutral with regard to jurisdictional claims in published maps and institutional affiliations.

**Copyright:** © 2022 by the authors. Licensee MDPI, Basel, Switzerland. This article is an open access article distributed under the terms and conditions of the Creative Commons Attribution (CC BY) license (https://creativecommons.org/licenses/by/4.0/).

School of Computer Science and Engineering, Huizhou University, Huizhou 516007, China

**<sup>\*</sup>** Correspondence: zzsjbme@sjtu.edu.cn; Tel.: +86-182-1726-7715

When selling products online, merchants often record an introductory video of a product and upload it together with the item details. Users can then look at the pictures or watch the pre-recorded video, but they cannot freely view the appearance and details of the product, which limits this approach. In this paper, we apply the SfM algorithm to commodity reconstruction to provide customers with multiple views. We use a fixed camera instead of a moving one to remove human influences such as shaking and thereby improve the accuracy of the SfM results. Many online merchants may also find it more convenient to put an object on a motorized turntable, since they may not know how to take pictures appropriately. When the background does not change and only the object rotates, we call this a "self-rotating" state, by analogy with the rotation of the Earth. In this state, most of the feature points come from the background rather than from the object itself and are useless for 3D reconstruction, which leads to a poorly reconstructed model. When performing SIFT (scale-invariant feature transform) feature extraction [8] and feature point matching with the SfM algorithm provided by OpenMVG [7], we found that the algorithm performs very poorly when the object is self-rotating while the background stays unchanged.

These issues show that the background degrades the accuracy of 3D reconstruction. We therefore propose a Matting-SfM algorithm that eliminates the background of the target object. This method removes the influence of the background on SfM and thus yields a good final result.

In summary, the conventional SfM algorithm cannot accurately reconstruct an object in a self-rotating state. To address this problem, we propose the Matting-SfM method, which makes the following contributions:


The rest of this paper is organized as follows. Section 2 introduces existing 3D reconstruction techniques. Section 3 presents our methods. Section 4 describes our experimental materials, Section 5 presents our experimental results, and the last section presents the conclusions.

#### **2. Related Work**

At present, the major 3D reconstruction methods include visual geometric 3D reconstruction and deep learning reconstruction. For visual geometric 3D reconstruction, there are classical open-source projects such as Colmap [9], OpenMVG [10], and VisualSfM [11]. There are also deep learning methods for 3D reconstruction, such as PatchMatchNet [12], MVSNet [13], R-MVSNet [14], PointMVS-Net [15], and the Cascade series [16]. All of these methods can be used with an MVS framework such as OpenMVS [17], CMVS [18], or PMVS [19] to obtain a good reconstruction result. Moreover, further developments such as real-time reconstruction [20] and applications on embedded devices and hardware [21,22] are also impressive.

Several 3D reconstruction algorithms are currently in wide use. For computing camera poses, there are two representative approaches. One is RGB-D-based (e.g., BundleFusion [23]); it requires a camera that captures depth information and is more accurate, but it is limited by the need for such a device. The other is RGB-based, such as SfM (structure from motion) [24] or SLAM [25]; it does not require the camera to capture depth information, but the depth must be calculated during processing. After the camera poses are obtained, these methods can be combined with an MVS system to produce a surface-optimized model with texture mapping. In classical visual geometric reconstruction, the SfM algorithm has been widely used, and recent advances have produced global SfM with highly efficient optimization [26]. Zhu et al. used a hybrid global and incremental SfM algorithm [27], and the following year, they scaled global SfM to millions of input images, more than any previous work [28]. Chen et al. proposed a tree-structured SfM algorithm [29] that greatly improved efficiency compared with the traditional SfM algorithm and handled outliers more reliably, making SfM more efficient and fault-tolerant.

Although previous algorithms have produced these exciting results, none of them specifically handles the 3D reconstruction of self-rotating objects well. To solve this problem, we used ResNet [30] to overcome the inability of existing algorithms to reconstruct rotating objects. We propose a Matting-SfM algorithm, in which the target object is segmented based on Background Matting v2 [31] to eliminate the background completely, which largely removes the influence of the background on SfM. Experiments showed that the Matting-SfM algorithm greatly improves on the traditional SfM algorithm when the object rotates while the background remains unchanged, laying a good foundation for subsequent MVS reconstruction.

#### **3. Methods**

#### *3.1. Video Segmentation and Background Replacement*

Conventionally, matting can be approached with several classical algorithms, among which the Canny algorithm [32] suits our task best. The problem is that, for some complex backgrounds, the results of such traditional algorithms are not always satisfactory (Figure 1). In Figure 1, the remaining background regions, such as the areas marked with red boxes, are useless, and the object in the yellow box is the only region of interest. That is, the traditional algorithm cannot completely eliminate the background.

**Figure 1.** The matting result of the Canny algorithm.

To solve this problem, in this paper, foreground segmentation and background replacement were performed based on Background Matting v2 (BGMv2 for short). This requires a video or an image dataset of the object with the background, plus a background image without the object; the more accurately the background image is aligned with the original video, the better (Figure 2).

Based on the approach of Lin's team [31], we trained our own model, which is better suited to our 3D reconstruction task. The backbone was ResNet-50, and the model was trained for 30 epochs with a batch size of 16. The base network was trained first, followed by the refinement network. Two datasets were used: VideoMatte240K and a dataset we made ourselves. VideoMatte240K contains 484 pairs of high-resolution alpha matte and foreground video clips extracted from green-screen stock footage, comprising 240,709 unique frames. Our own dataset contains 1000 pictures of different objects such as toys and models, whose alpha mattes *α* and foregrounds *F* were extracted manually in Photoshop.

**Figure 2.** The input video (*I*) and background image (*B*).

First, the original video frame or image *I* and the background image *B* are concatenated along the channel dimension into a tensor of size 6 × *H* × *W*, where *H* and *W* are the height and width of the image. The tensor is then downsampled by the multiplier *C* to produce an input of size 6 × *H*/*C* × *W*/*C*, which is fed to basenet. The inputs and outputs of the two networks are summarized in Figure 3. The model produces five results: Alpha, Foreground, Error Map, Refine, and Composite. For our purpose, Composite is the result of interest, since it is the core input of our 3D reconstruction system.

**Figure 3.** The two network structures of BGMv2 and the input and output results.
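The input preparation for basenet can be sketched in NumPy as follows. Average pooling over *C* × *C* blocks stands in for the downsampling step, which is an assumption made to keep the example dependency-free (BGMv2 itself may resample differently); the function name is hypothetical.

```python
import numpy as np


def prepare_base_input(image, background, c=4):
    """Concatenate I and B ([3, H, W] each) and downsample by the multiplier c.

    Returns an array of shape [6, H / c, W / c].
    """
    x = np.concatenate([image, background], axis=0)          # [6, H, W]
    channels, h, w = x.shape
    if h % c or w % c:
        raise ValueError("H and W must be divisible by c")
    # Average pooling over non-overlapping c x c blocks
    blocks = x.reshape(channels, h // c, c, w // c, c)
    return blocks.mean(axis=(2, 4))
```

With *C* = 4, a 1080 × 1920 frame would yield a 6 × 270 × 480 input.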

The architecture of basenet is based on DeepLabV3 [33] and DeepLabV3+ [34]; it consists of a backbone and an ASPP module [35], followed by a decoder module. The images of size 6 × *H*/*C* × *W*/*C* generated above are input to the backbone for feature extraction. The backbone is followed by the atrous spatial pyramid pooling (ASPP) module, which combines atrous (dilated) convolution [33] with spatial pyramid pooling (SPP) [36]. The decoder follows the ASPP; it concatenates the ASPP output with the features extracted by the backbone through a skip connection, performs bilinear upsampling, and produces a coarse result consisting of four parts: the coarse Alpha $\alpha\_C$, the Foreground Residual $F\_C^R$, the Error Map $E\_C$, and the Hidden features $H\_C$. The subscript *C* indicates the downsampling multiplier, and the superscript *R* stands for residual. We set the downsampling multiplier to 4, so each of the four coarse results generated by basenet has 1/4 of the original height and width, i.e., 1/16 of the original image area.

Unlike basenet, refinenet does not operate on the whole map but on the patches with the *K* highest prediction errors, which are extracted from the feature map with the help of the Error Map $E\_C$. Since the multiplier *C* in the previous step was set to 4, the input $E\_C$ of refinenet is 1/16 the size of the original image, so each pixel of $E\_C$ corresponds to a 4 × 4 patch of the original image. Refinenet concatenates the four coarse results output by basenet with the processed image *I* and background *B* as feature maps, selects patches in the feature maps according to $E\_C$, and then performs two 3 × 3 convolutions to output 4 × 4 patches. Next, it upsamples them to 8 × 8 patches, concatenates these with the corresponding 8 × 8 patches of the original image, performs two more 3 × 3 convolutions, and finally obtains 4 × 4 Alpha and Error Map patches. Finally, the coarse Alpha and coarse Error Map are upsampled to the original size, and the selected patches are replaced with the 4 × 4 Alpha and Error Map patches obtained by refinenet.
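The patch-selection step can be sketched as picking the *K* largest entries of the coarse error map and mapping each one to the 4 × 4 patch it covers at full resolution. The sketch only computes patch locations (the refinenet convolutions are omitted), and the function name is hypothetical.

```python
import numpy as np


def select_refinement_patches(error_map, k, c=4):
    """Return full-resolution top-left corners of the k worst coarse pixels.

    error_map: [H / c, W / c] coarse error map E_C. Each coarse pixel (y, x)
    covers the c x c patch whose top-left corner is (y * c, x * c).
    """
    flat = np.asarray(error_map).ravel()
    worst = np.argsort(flat)[::-1][:k]                 # k largest errors first
    ys, xs = np.unravel_index(worst, np.shape(error_map))
    return [(int(y) * c, int(x) * c) for y, x in zip(ys, xs)]
```

Restricting the expensive full-resolution convolutions to these few patches is what keeps refinement cheap on high-resolution frames.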

Given an input image *I* and its background map *B*, the predicted alpha matte *α* and foreground *F* can be used to composite the object onto a new background *B′*; we use a black background as *B′* to obtain the final composite image. The new image *I′* is synthesized as in Equation (1):

$$I' = \alpha F + (1 - \alpha)B' \tag{1}$$

The image *I* and the background map *B* were provided by us, and *B′* was set to a black background because our goal was to remove the background. The alpha matte *α* and the foreground *F* were predicted by the BGMv2 network as follows. After processing by the two networks, the final output foreground residual *F<sup>R</sup>* can be expressed as in Equation (2):

$$F^{R} = F - I \tag{2}$$

*F* can then be recovered by adding *F<sup>R</sup>* back to the image *I* and clamping the result to [0, 1], as in Equation (3); combining Equations (2) and (3) yields a more detailed foreground image *F*:

$$F = \max\left(\min\left(F^R + I, 1\right), 0\right) \tag{3}$$
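Equations (1)–(3) can be checked on a toy example. The NumPy sketch below assumes H × W × 3 arrays with values in [0, 1] (an assumption about layout); it recovers *F* from the residual with clamping and composites onto a black *B′*, in which case *I′* reduces to *αF*.

```python
import numpy as np


def recover_foreground(fg_residual: np.ndarray, image: np.ndarray) -> np.ndarray:
    """Equation (3): F = max(min(F^R + I, 1), 0)."""
    return np.clip(fg_residual + image, 0.0, 1.0)


def composite(alpha: np.ndarray, foreground: np.ndarray, new_background: np.ndarray) -> np.ndarray:
    """Equation (1): I' = alpha * F + (1 - alpha) * B'."""
    return alpha * foreground + (1.0 - alpha) * new_background


# Toy 1 x 2 RGB image with a predicted residual and alpha matte.
I = np.array([[[0.2, 0.4, 0.6], [0.9, 0.9, 0.9]]])
F_R = np.array([[[0.1, 0.0, -0.1], [0.3, 0.0, 0.0]]])
alpha = np.array([[[1.0], [0.0]]])   # first pixel: object, second pixel: background
F = recover_foreground(F_R, I)       # the second pixel's red channel clamps 1.2 to 1.0
black = np.zeros_like(I)             # B' is black, as in the text
I_new = composite(alpha, F, black)
print(I_new)
```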

An L1 loss is applied over the entire alpha matte and its (Sobel) gradient with respect to the ground truth, where *α\** is the ground-truth alpha obtained by manual processing:

$$L_{\alpha} = \|\alpha - \alpha^*\|_1 + \|\nabla\alpha - \nabla\alpha^*\|_1 \tag{4}$$

Using Equation (3), the foreground layer is computed from the predicted foreground residual *F<sup>R</sup>*. The L1 loss is computed only on pixels where *α\** > 0, where (*α\** > 0) is a Boolean mask and *F\** is the ground-truth foreground obtained by manual processing:

$$L_F = \left\| (\alpha^* > 0) \ast (F - F^*) \right\|_1 \tag{5}$$

For selecting the refinement regions, the ground-truth error map *E\** is defined as in Equation (6), following [28]. The error-map loss is then the mean squared error between the predicted error map *E* and the ground-truth error map *E\**:

$$E^* = |\alpha - \alpha^*| \tag{6}$$

$$L_E = \|E - E^*\|_2 \tag{7}$$

Accordingly, the base network *(α<sub>C</sub>, F<sub>C</sub><sup>R</sup>, E<sub>C</sub>, H<sub>C</sub>) = G<sub>base</sub>(I<sub>C</sub>, B<sub>C</sub>)* operates at 1/*C* of the original image resolution, and its loss function is:

$$L_{base} = L_{\alpha_C} + L_{F_C} + L_{E_C} \tag{8}$$

Similarly, for the refinement network *(α, F<sup>R</sup>) = G<sub>refine</sub>(α<sub>C</sub>, F<sub>C</sub><sup>R</sup>, E<sub>C</sub>, H<sub>C</sub>, I, B)*, the loss function is given by Equation (9):

$$L_{refine} = L_{\alpha} + L_{F} \tag{9}$$
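A minimal NumPy sketch of Equations (4)–(9) follows. Two assumptions are made for illustration: the norms are averaged over pixels (mean absolute or mean squared error) rather than summed, and the Sobel gradient uses edge padding with the x and y components penalized separately. It is not the BGMv2 training code.

```python
import numpy as np


def sobel_grad(a):
    """Sobel responses (gx, gy) of a 2-D array, with edge padding."""
    p = np.pad(a, 1, mode="edge")
    gx = (p[:-2, 2:] + 2 * p[1:-1, 2:] + p[2:, 2:]) - (p[:-2, :-2] + 2 * p[1:-1, :-2] + p[2:, :-2])
    gy = (p[2:, :-2] + 2 * p[2:, 1:-1] + p[2:, 2:]) - (p[:-2, :-2] + 2 * p[:-2, 1:-1] + p[:-2, 2:])
    return gx, gy


def alpha_loss(alpha, alpha_gt):
    """Equation (4): L1 on alpha plus L1 on its Sobel gradient."""
    gx, gy = sobel_grad(alpha)
    gx_t, gy_t = sobel_grad(alpha_gt)
    return (np.abs(alpha - alpha_gt).mean()
            + np.abs(gx - gx_t).mean() + np.abs(gy - gy_t).mean())


def foreground_loss(fg, fg_gt, alpha_gt):
    """Equation (5): L1 on foreground pixels where the ground-truth alpha is positive."""
    mask = (alpha_gt > 0)[..., None]          # H x W x 1 Boolean mask
    return np.abs(mask * (fg - fg_gt)).mean()


def error_loss(err_pred, alpha, alpha_gt):
    """Equations (6)-(7): E* = |alpha - alpha*|, compared with the prediction by MSE."""
    err_gt = np.abs(alpha - alpha_gt)
    return np.mean((err_pred - err_gt) ** 2)


def base_loss(alpha_c, fg_c, err_c, alpha_gt_c, fg_gt_c):
    """Equation (8), evaluated at the coarse resolution."""
    return (alpha_loss(alpha_c, alpha_gt_c)
            + foreground_loss(fg_c, fg_gt_c, alpha_gt_c)
            + error_loss(err_c, alpha_c, alpha_gt_c))


def refine_loss(alpha, fg, alpha_gt, fg_gt):
    """Equation (9), evaluated at full resolution."""
    return alpha_loss(alpha, alpha_gt) + foreground_loss(fg, fg_gt, alpha_gt)
```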

#### *3.2. Reconstructing Sparse Point Cloud*

In geometric 3D reconstruction, SfM algorithms fall into two families: incremental SfM and global SfM. We used incremental SfM in this paper, so global SfM is not discussed further. Before reconstruction, several pre-processing steps are required, such as SIFT feature extraction [8] and robust model fitting with AC-RANSAC [37]. The main pre-processing steps are shown in Figure 4.

**Figure 4.** The pre-processing process.

To find the correspondence between *p* and *p′*, we used the AC-RANSAC algorithm provided by OpenMVG to estimate the fundamental matrix in Equation (10); the parameters in the formula are illustrated in Figure 5.

$$F = K'^{-T} [T_X] R K^{-1} \tag{10}$$

where *F* is the fundamental matrix; *K* and *K′* are the intrinsic parameter matrices of the two cameras; *l* and *l′* are the rays of *p* and *p′*; *I* and *I′* are the two image planes; *O*<sub>1</sub> and *O*<sub>2</sub> are the two views (camera centers); *R* and *T* are the rotation matrix and translation vector between the two cameras; and [*T<sub>X</sub>*] is the skew-symmetric (cross-product) matrix of *T*, so that [*T<sub>X</sub>*]*R* is the essential matrix in Equation (11):

$$E = T \times R = [T_X] R \tag{11}$$

**Figure 5.** The parameters of the fundamental matrix.
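Equations (10) and (11) can be verified numerically. The sketch below assumes the convention *X*<sub>2</sub> = *R X*<sub>1</sub> + *T* (camera-2 coordinates from camera-1 coordinates) and uses made-up intrinsics and pose; it builds *E* and *F* and checks the epipolar constraint *p′*<sup>T</sup>*F p* = 0 for a synthetic point. It is not OpenMVG's estimator, which fits *F* from noisy matches with AC-RANSAC.

```python
import numpy as np


def skew(t: np.ndarray) -> np.ndarray:
    """[T_X]: the 3 x 3 matrix such that skew(t) @ v == np.cross(t, v)."""
    return np.array([[0.0, -t[2], t[1]],
                     [t[2], 0.0, -t[0]],
                     [-t[1], t[0], 0.0]])


def essential(R: np.ndarray, t: np.ndarray) -> np.ndarray:
    """Equation (11): E = [T_X] R."""
    return skew(t) @ R


def fundamental(K1: np.ndarray, K2: np.ndarray, R: np.ndarray, t: np.ndarray) -> np.ndarray:
    """Equation (10): F = K'^{-T} [T_X] R K^{-1}, with K for view 1 and K' for view 2."""
    return np.linalg.inv(K2).T @ essential(R, t) @ np.linalg.inv(K1)


K = np.array([[800.0, 0.0, 320.0], [0.0, 800.0, 240.0], [0.0, 0.0, 1.0]])
theta = np.deg2rad(10.0)
R = np.array([[np.cos(theta), 0.0, np.sin(theta)],
              [0.0, 1.0, 0.0],
              [-np.sin(theta), 0.0, np.cos(theta)]])
t = np.array([-0.5, 0.0, 0.1])
X = np.array([0.3, -0.2, 4.0])   # 3-D point in camera-1 coordinates
p = K @ X
p = p / p[2]                      # homogeneous pixel in view 1
p2 = K @ (R @ X + t)
p2 = p2 / p2[2]                   # the same point seen in view 2
F = fundamental(K, K, R, t)
residual = p2 @ F @ p             # should be ~0 up to rounding
print(residual)
```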

The pre-processing stage also requires the homography matrix (Figure 6). Let *K* and *K′* be the intrinsic parameter matrices of the first and second cameras, let the pose of the second camera relative to the first be given by the rotation *R* and the translation vector *t*, let *n* be the unit normal vector of the plane *π* in the coordinate system of the first camera, and let *d* be the distance from the coordinate origin to *π*. *P*, *p*, and *p′* are corresponding points: *P* lies on the plane *π*, and *p* and *p′* are its images in the two views. From these parameters, the homography matrix is obtained as in Equation (12). Once the fundamental and homography matrices are obtained, they are used in the subsequent triangulation, which is a key step of SfM.

$$H = K'\left(R + \frac{\vec{t}\,\vec{n}^{T}}{d}\right)K^{-1} \tag{12}$$

**Figure 6.** The parameters of the homography matrix.
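The plane-induced homography of Equation (12) can be checked the same way. This sketch assumes the plane satisfies *n*<sup>T</sup>*X* = *d* in the first camera's frame and that *X*<sub>2</sub> = *R X* + *t*; under the opposite sign convention for the plane, the plus sign becomes a minus. All numeric values are made up.

```python
import numpy as np


def plane_homography(K1, K2, R, t, n, d):
    """Equation (12): H = K' (R + t n^T / d) K^{-1} for the plane n^T X = d."""
    return K2 @ (R + np.outer(t, n) / d) @ np.linalg.inv(K1)


K = np.array([[800.0, 0.0, 320.0], [0.0, 800.0, 240.0], [0.0, 0.0, 1.0]])
a = np.deg2rad(5.0)
R = np.array([[np.cos(a), 0.0, np.sin(a)], [0.0, 1.0, 0.0], [-np.sin(a), 0.0, np.cos(a)]])
t = np.array([0.2, 0.0, 0.05])
n = np.array([0.0, 0.0, 1.0])    # unit normal of the plane z = d
d = 5.0
P = np.array([1.0, -0.5, 5.0])   # a point on the plane (n @ P == d)

p = K @ P
p = p / p[2]                     # image of P in view 1
p2 = K @ (R @ P + t)
p2 = p2 / p2[2]                  # image of P in view 2
H = plane_homography(K, K, R, t, n, d)
q = H @ p
q = q / q[2]                     # H maps p onto p2
print(q, p2)
```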

During pre-processing, we observed a problem: when all images share the same background, the same background feature points are observed in every image, whereas the feature points of the object to be reconstructed produce fewer matches than the background. Taking OpenMVG as an example, we used its SfM algorithm for reconstruction, which proceeds roughly as shown in Figure 7.

**Figure 7.** The steps of the SfM provided by OpenMVG for 3D reconstruction.

In the step "Reconstructing the Initial Point Cloud from Two Views" (see Figure 7), an edge (image pair) must be selected from the connectivity graph obtained in the previous step. The selected edge must satisfy the following condition: when all corresponding points are triangulated, the median angle between the two rays at the triangulated points must be no greater than 60° and no less than 3°. For a dataset in which a large number of feature points come from the background, the angle between the rays remains essentially unchanged, which is why no image pair meets this requirement. In the 2D images, there appears to be a certain angle between the two pairs of points (see Figure 8), but after triangulation their corresponding 3D points are completely unchanged, so the median ray angle does not exceed 3° and such pairs are not selected. All of the features extracted from the dataset and their matches are shown in Figures 9 and 10. Although there are indeed a large number of matches, most of them are useless because they come from the background.
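The acceptance test described above can be sketched as follows: compute, at each triangulated point, the angle between the rays to the two camera centers, and accept the pair only if the median lies in [3°, 60°]. The thresholds come from the text; the angle definition and function names are illustrative assumptions, not OpenMVG's code. The example shows why distant or static points fail: their ray angle shrinks toward zero.

```python
import numpy as np


def ray_angles_deg(center1: np.ndarray, center2: np.ndarray, points: np.ndarray) -> np.ndarray:
    """Angle (degrees) at each 3-D point between its rays to the two camera centers."""
    r1 = center1 - points
    r2 = center2 - points
    cos = np.sum(r1 * r2, axis=1) / (np.linalg.norm(r1, axis=1) * np.linalg.norm(r2, axis=1))
    return np.degrees(np.arccos(np.clip(cos, -1.0, 1.0)))


def pair_is_valid(center1, center2, points, lo=3.0, hi=60.0):
    """Accept an initial image pair only if the median ray angle lies in [lo, hi] degrees."""
    med = float(np.median(ray_angles_deg(center1, center2, points)))
    return lo <= med <= hi, med


c1 = np.zeros(3)
c2 = np.array([1.0, 0.0, 0.0])   # 1-unit baseline
near = np.array([[0.5, y, 5.0] for y in (-0.2, 0.0, 0.2)])
far = np.array([[0.5, y, 100.0] for y in (-0.2, 0.0, 0.2)])
print(pair_is_valid(c1, c2, near))   # median about 11.4 degrees: accepted
print(pair_is_valid(c1, c2, far))    # median about 0.57 degrees: rejected
```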

**Figure 8.** The angle of the ray at the triangulation point.

**Figure 9.** The features extracted from the image.

**Figure 10.** The useless matches between two images.

We then used the dataset of composite results after segmentation as input (Figure 2). Feature points are easy to extract from such a dataset: the large, uniform black background carries no texture that the SIFT extractor can detect as features, so almost all feature points come from the object to be reconstructed, and a sparse point cloud is built up step by step (Figure 11).

**Figure 11.** The sparse point cloud obtained from the SfM.

#### *3.3. Densifying the Sparse Point Cloud by MVS*

After the sparse point cloud and camera poses were obtained in the previous step, they were used as input to MVS to densify the sparse point cloud for better observation. OpenMVS [16] was used to densify all sparse point clouds in this paper except those from VisualSfM. Because OpenMVS dropped support for VisualSfM about two years ago, we used CMVS-PMVS [18,19] to densify the VisualSfM point clouds instead; we expect the difference to be small. The inputs to the MVS system are the image dataset and the camera poses, which are first processed with neighboring-view selection [38] and global best neighboring-view selection [39] during data preparation. In the DensifyPointCloud step, semi-global matching (SGM) [40] computes a depth map for each image, and the depth maps are fused by MVS into a dense point cloud that is more complete and denser than the sparse one.

Densification makes the point cloud denser and easier to observe, which gives us an intuitive way to compare the results obtained by different methods. The output is shown in Figure 12, and the difference between the sparse and dense point clouds can be seen in Figure 13.

**Figure 12.** The output of densifying the sparse point cloud by MVS.

**Figure 13.** The comparison between the sparse and dense point cloud.
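SGM is only named above, so the following sketch illustrates how it regularizes depth: matching costs are aggregated along scanlines with a small penalty *P*<sub>1</sub> for disparity changes of one and a larger penalty *P*<sub>2</sub> for larger jumps. For brevity, only the two horizontal path directions are used (full SGM [40] uses 8 or 16 directions); the cost volume, penalty values, and function names are illustrative assumptions, not OpenMVS code.

```python
import numpy as np


def aggregate_path(cost: np.ndarray, p1: float, p2: float, reverse: bool = False) -> np.ndarray:
    """Aggregate an H x W x D cost volume along horizontal scanlines
    (left-to-right, or right-to-left when reverse=True)."""
    c = cost[:, ::-1, :] if reverse else cost
    h, w, n_disp = c.shape
    agg = np.empty((h, w, n_disp), dtype=float)
    agg[:, 0, :] = c[:, 0, :]
    for x in range(1, w):
        prev = agg[:, x - 1, :]                        # H x D costs at the previous pixel
        prev_min = prev.min(axis=1, keepdims=True)     # H x 1
        # Shift by one disparity in each direction; out-of-range neighbours are inf.
        down = np.full_like(prev, np.inf)
        down[:, 1:] = prev[:, :-1]                     # value at disparity d - 1
        up = np.full_like(prev, np.inf)
        up[:, :-1] = prev[:, 1:]                       # value at disparity d + 1
        jump = np.broadcast_to(prev_min + p2, prev.shape)
        best = np.minimum(np.minimum(prev, np.minimum(down, up) + p1), jump)
        # Subtracting prev_min keeps values bounded without changing the argmin.
        agg[:, x, :] = c[:, x, :] + best - prev_min
    return agg[:, ::-1, :] if reverse else agg


def sgm_disparity(cost: np.ndarray, p1: float = 1.0, p2: float = 8.0) -> np.ndarray:
    """Winner-takes-all disparity after summing the two horizontal path costs."""
    total = aggregate_path(cost, p1, p2) + aggregate_path(cost, p1, p2, reverse=True)
    return total.argmin(axis=2)


cost = np.ones((1, 5, 4))
cost[0, :, 2] = 0.0                  # disparity 2 is best everywhere ...
cost[0, 2] = [0.0, 1.0, 0.5, 1.0]    # ... except one noisy pixel that prefers 0
print(cost.argmin(axis=2))           # [[2 2 0 2 2]]
print(sgm_disparity(cost))           # [[2 2 2 2 2]]
```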

#### **4. Experiment Materials and Evaluation Indices**

The dataset organization for the experiments is shown in Figure 14, and the details are given below. For each dataset, a single image is shown in detail and, for a better overview, a selection of consecutive images from the dataset is also shown. Four representative experiments were conducted, all in a well-lit environment. The three videos were all recorded at 4K resolution and 60 fps, and thirty frames from each video formed the three dataset groups that were input to Colmap, VisualSfM, and OpenMVG. Because we wanted to focus on the influence of the background, a single typical car model was chosen to test our algorithm. Our ultimate goal was to obtain a very accurate model of the object without any background residue, since merchants may not know how to remove residual 3D points from a point cloud. All of the experiments focused on the influence of the background.


**Figure 14.** The dataset organization.

The first experiment evaluated the traditional reconstruction method: the object was placed on the turntable, and a video was recorded with a camera moved around the target object. The target object was a car model with a large number of feature points available to the algorithm. The left side of Figure 15 shows one image from the object's image dataset extracted from the video, and the right side shows an overview of the 30-image dataset.

**Figure 15.** The image dataset used in experiment 1.

The second experiment investigated whether the traditional algorithms could reconstruct a self-rotating object. The object was the same car model, placed on a motorized turntable while the camera was fixed on a camera mount. In this experiment, the background was complex. The left side of Figure 16 shows one image from the object's image dataset extracted from the video, and the right side shows an overview of the 30-image dataset.

**Figure 16.** The image dataset used in experiment 2.

Next, to eliminate the impact of the background, we artificially simplified the background of the dataset as much as possible in a third experiment, using a smooth white background. As in the previous figures, the left side of Figure 17 shows one image from the object's image dataset extracted from the video, and the right side shows an overview of the 37-image dataset.

**Figure 17.** The image dataset used in experiment 3.

The fourth experiment explored the performance and accuracy of the Matting-SfM algorithm. The background of the second dataset was replaced with a black background containing no feature points, which completely eliminated the original background. Figure 18 shows the intermediate output of Matting-SfM, that is, the dataset after segmentation.

**Figure 18.** The image dataset used in experiment 4.

For evaluation, we chose three metrics: histogram similarity, peak signal-to-noise ratio (PSNR), and the structural similarity index (SSIM). All three measure the similarity between the original images and the models.

For histogram similarity, the histograms of the source image and the image to be compared were computed and normalized, and their correlation was then computed with the comparison function provided by OpenCV.
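In OpenCV, this corresponds to `cv2.calcHist`, `cv2.normalize`, and `cv2.compareHist(h1, h2, cv2.HISTCMP_CORREL)`. The dependency-free NumPy sketch below computes the same correlation formula on grey-level histograms; the choice of 256 grey-level bins is an assumption, since the text does not state the histogram configuration.

```python
import numpy as np


def hist_correlation(img_a: np.ndarray, img_b: np.ndarray, bins: int = 256) -> float:
    """Correlation of normalized grey-level histograms, using the same formula as
    OpenCV's HISTCMP_CORREL (Pearson correlation of the bin values)."""
    h1, _ = np.histogram(img_a, bins=bins, range=(0, 256))
    h2, _ = np.histogram(img_b, bins=bins, range=(0, 256))
    h1 = h1 / h1.sum()   # normalization; the correlation itself is scale-invariant
    h2 = h2 / h2.sum()
    d1, d2 = h1 - h1.mean(), h2 - h2.mean()
    return float(np.sum(d1 * d2) / np.sqrt(np.sum(d1 ** 2) * np.sum(d2 ** 2)))


rng = np.random.default_rng(0)
a = rng.integers(0, 256, size=(32, 32))
dark = rng.integers(0, 64, size=(32, 32))
bright = rng.integers(192, 256, size=(32, 32))
print(hist_correlation(a, a))          # identical images give 1.0
print(hist_correlation(dark, bright))  # disjoint histograms give a low score
```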

For PSNR, given a clean image *I* and a noisy image *K* of size *m* × *n*, the formulas are:

$$\text{MSE} = \frac{1}{mn} \sum_{i=0}^{m-1} \sum_{j=0}^{n-1} \left[ I(i,j) - K(i,j) \right]^2 \tag{13}$$

$$\text{PSNR} = 10 \times \log_{10} \left( \frac{\text{MAX}_I^2}{\text{MSE}} \right) \tag{14}$$

where MSE is the mean squared error and MAX<sub>*I*</sub> is the maximum possible pixel value of the image (255 for 8-bit images).
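Equations (13) and (14) translate directly into NumPy. This is a minimal sketch assuming 8-bit images (MAX<sub>*I*</sub> = 255); returning infinity for identical images is a common convention, not something the text specifies.

```python
import numpy as np


def psnr(clean, noisy, max_i: float = 255.0) -> float:
    """Equations (13)-(14): PSNR in dB between two m x n images."""
    clean = np.asarray(clean, dtype=float)
    noisy = np.asarray(noisy, dtype=float)
    mse = np.mean((clean - noisy) ** 2)      # Equation (13)
    if mse == 0:
        return float("inf")                  # identical images
    return float(10.0 * np.log10(max_i ** 2 / mse))


I = np.zeros((4, 4))
K = I + 5.0                                  # uniform error of 5 gray levels, so MSE = 25
print(psnr(I, K))                            # about 34.15 dB
```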

For SSIM, given two images *X* and *Y*, the formulas are as follows, where *L*(*X*,*Y*), *C*(*X*,*Y*), and *S*(*X*,*Y*) compare the luminance, contrast, and structure of the two images, respectively:

$$L(X,Y) = \frac{2\mu_X\mu_Y + C_1}{\mu_X^2 + \mu_Y^2 + C_1} \tag{15}$$

$$C(X,Y) = \frac{2\sigma_X\sigma_Y + C_2}{\sigma_X^2 + \sigma_Y^2 + C_2} \tag{16}$$

$$S(X,Y) = \frac{\sigma_{XY} + C_3}{\sigma_X\sigma_Y + C_3} \tag{17}$$

where *μ<sub>X</sub>* and *μ<sub>Y</sub>* are the means of images *X* and *Y*; *σ<sub>X</sub>* and *σ<sub>Y</sub>* are their standard deviations; *σ<sub>X</sub>*<sup>2</sup> and *σ<sub>Y</sub>*<sup>2</sup> are their variances; *σ<sub>XY</sub>* is their covariance; and *C*<sub>1</sub>, *C*<sub>2</sub>, and *C*<sub>3</sub> are constants that keep the denominators from being 0 and stabilize the computation. Typically, *C*<sub>1</sub> = (*K*<sub>1</sub> × *L*)<sup>2</sup>, *C*<sub>2</sub> = (*K*<sub>2</sub> × *L*)<sup>2</sup>, and *C*<sub>3</sub> = *C*<sub>2</sub>/2, with *K*<sub>1</sub> = 0.01, *K*<sub>2</sub> = 0.03, and *L* = 255.

Finally, SSIM can be expressed as:

$$\text{SSIM}(X, Y) = L(X, Y) \cdot C(X, Y) \cdot S(X, Y) \tag{18}$$
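Equations (15)–(18) can be sketched in NumPy using statistics over the whole image, i.e., a single window. This is an assumption: common implementations such as scikit-image's `structural_similarity` compute SSIM over local sliding windows and average the results, and the text does not say which variant was used.

```python
import numpy as np


def ssim_global(x, y, k1: float = 0.01, k2: float = 0.03, L: float = 255.0) -> float:
    """Equations (15)-(18) with whole-image statistics (single window)."""
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    c1 = (k1 * L) ** 2
    c2 = (k2 * L) ** 2
    c3 = c2 / 2
    mu_x, mu_y = x.mean(), y.mean()
    sd_x, sd_y = x.std(), y.std()
    cov = np.mean((x - mu_x) * (y - mu_y))
    lum = (2 * mu_x * mu_y + c1) / (mu_x ** 2 + mu_y ** 2 + c1)    # Equation (15)
    con = (2 * sd_x * sd_y + c2) / (sd_x ** 2 + sd_y ** 2 + c2)    # Equation (16)
    struct = (cov + c3) / (sd_x * sd_y + c3)                       # Equation (17)
    return float(lum * con * struct)                               # Equation (18)


rng = np.random.default_rng(0)
a = rng.random((16, 16)) * 255
print(ssim_global(a, a))          # 1.0 for identical images
print(ssim_global(a, a + 30.0))   # < 1: a brightness shift lowers the luminance term
```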

#### **5. Results**

For all of the experiments, every SfM system was run in its high-quality mode so that all systems were configured in the same way; in some cases, however, even the high-quality mode did not produce a good result.

The first experiment explored the performance of the traditional SfM algorithms. The results showed that the traditional SfM algorithms worked well when the camera was moved around the object (the object was kept still while the background changed; see Figure 12). Models were indeed obtained in the first experiment, but the background reduced their accuracy (see the blue edges around the model in Figure 19), making them look coarse and rough. Furthermore, shaking of the hand-held video also affected the result, so considerable effort was needed to keep the camera steady.

**Figure 19.** The results of experiment 1.

To solve the problems raised in experiment 1, we conducted experiment 2, in which the camera position was fixed and the target object was rotated. This setup avoided the negative impact of human factors on the reconstruction. However, no results could be obtained from this dataset at all: all of the methods failed to reconstruct the second dataset (Figure 20).

We attributed this failure to the unchanging background, so we next tested the third dataset. Although we tried to simplify the background, it still affected the accuracy of the results (Figure 21): results were obtained, but they were still unsatisfactory, so the influence of the background remained. Moreover, OpenMVG produced only a defective result, even in high-quality mode.


**Figure 20.** The results of experiment 2.


**Figure 21.** The results of experiment 3.

These experiments made it clear that the background of the datasets consistently degraded the traditional SfM algorithms, which motivated our Matting-SfM method. In experiment 4, Matting-SfM processed the second dataset given only an additional background image (Figure 2) and completely removed the background (Figure 22).


**Figure 22.** The dataset after segmentation.

The second dataset was fed directly into our Matting-SfM algorithm, which produced an intermediate dataset (Figure 22) from which the effect of the background had been removed. Finally, we ran the whole 3D reconstruction procedure; a 3D model with high accuracy and detailed texture was reconstructed (Figure 23).

**Figure 23.** The final reconstruction of Matting-SfM after segmentation.

All of the above datasets except the first were processed in the same way, and the comparison is given in Table 1. Note that Table 1 only includes the third dataset because Colmap, VisualSfM, and OpenMVG produced no results on the second dataset; the first dataset was excluded because Matting-SfM targets self-rotating objects.


**Table 1.** A comparison of all of the datasets using the mathematical method.

Three metrics were used to evaluate the results, all of which measure the similarity between the original images and the models. To make the comparison fair, each model was rotated to the same orientation as the corresponding original image, and a screenshot was compared with that image. We compared all such matching viewpoints in the datasets and report the mean value of each metric. For all three metrics, a higher value means that the model is more similar to the original image.

We emphasize that these metrics are only a numerical summary of the results. They do not always capture the differences between images and can be misleading, since the models may look more different than the numbers suggest. We therefore also compare the results visually, as shown in Figure 24. Taken together, the quantitative and visual comparisons clearly distinguish the four approaches, with Matting-SfM obtaining more accurate results than the others.


**Figure 24.** A comparison of all of the datasets using the visible method.

In addition, the performance of Matting-SfM was tested on other objects (Figure 25): a money jar, a dog doll, and a very small accessory. The backgrounds of these three datasets were chosen to be as complex as possible. The first two objects were reconstructed well by Matting-SfM but not by conventional SfM. For the last object, because of its tiny size and complex background, conventional SfM produced no result at all, whereas Matting-SfM did reconstruct the object, although the quality leaves some room for improvement.

We conclude that Matting-SfM works properly with a fixed camera position and a self-rotating object and produces good reconstructions. It solves the problem that traditional SfM algorithms cannot reconstruct a self-rotating object in front of an unchanging background, and in our experiments it greatly improved the results.

**Figure 25.** More objects used to test the performance of Matting-SfM.

#### **6. Conclusions**

In this paper, we proposed the Matting-SfM algorithm to solve the problem of reconstruction failure for self-rotating objects while maintaining high accuracy. Because SfM algorithms place certain requirements on camera stability and lighting, we chose an indoor environment. We fixed the camera on a camera support and placed the object on a motorized turntable so that the object rotated during shooting. This approach not only provided the stability needed for high-precision reconstruction, but also avoided the negative impact of hand-held recording.

We have embedded the algorithm in our system and deployed it on a server. After a user uploads an MP4 video and a background image, the server removes the background using deep learning methods and outputs a video of the segmented object on a black background. Key frames are then extracted from this video as an image dataset, which is reconstructed automatically once processing is complete. Finally, the results (glTF files and PNG texture images) are downloaded to the computer and can be viewed on a website built with WebGL.
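The key-frame step can be sketched with OpenCV's `VideoCapture`. The text does not say how key frames are chosen, so the evenly spaced sampling policy and the function names here are hypothetical; the default of 30 frames mirrors the dataset size used in Section 4. Note that `CAP_PROP_FRAME_COUNT` can be approximate for some containers.

```python
from pathlib import Path


def keyframe_indices(total_frames: int, n: int = 30) -> list[int]:
    """Evenly spaced frame indices (a hypothetical sampling policy)."""
    if total_frames <= 0 or n <= 0:
        return []
    n = min(n, total_frames)
    step = total_frames / n
    return [int(i * step) for i in range(n)]


def extract_keyframes(video_path: str, out_dir: str, n: int = 30) -> int:
    """Save n evenly spaced frames of a video as PNG images; returns how many were saved."""
    import cv2  # imported lazily so keyframe_indices() does not require OpenCV

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        raise IOError(f"cannot open {video_path}")
    total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    out = Path(out_dir)
    out.mkdir(parents=True, exist_ok=True)
    saved = 0
    for idx in keyframe_indices(total, n):
        cap.set(cv2.CAP_PROP_POS_FRAMES, idx)  # seek to the selected frame
        ok, frame = cap.read()
        if ok:
            cv2.imwrite(str(out / f"frame_{idx:05d}.png"), frame)
            saved += 1
    cap.release()
    return saved


print(keyframe_indices(600)[:4])  # [0, 20, 40, 60]
```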

Although this algorithm solves the problem that self-rotating objects in front of an unchanging background cannot be reconstructed, it still cannot match feature points well for objects with too few feature points, or for regions with highlights and weak textures, such as the back of the car in Figure 22. There, the depth information was computed incorrectly, leading to a dent in the back of the model. For some highly detailed and skeletal areas, a background image that is precisely aligned with the original video must be provided; otherwise, the results will be less accurate. Since we care more about accuracy than real-time performance, in future work we will try to handle highlights and weakly textured areas with a deep learning approach that replaces such areas with valid RGB information. In this way, feature extraction, feature point matching, and texture mapping will not be badly affected. Furthermore, replacing a traditional algorithm such as SIFT with a CNN (convolutional neural network) [41–45] may improve feature extraction.

**Author Contributions:** Conceptualization, Z.Z. and Z.L.; Data curation, Z.Z. and Z.L.; Formal analysis, Z.Z., Z.L., Y.C., S.L. and S.G.; Funding acquisition, Z.Z. and Z.L.; Investigation, S.L.; Supervision, Z.Z. All authors have read and agreed to the published version of the manuscript.

**Funding:** This research was funded by the Young Innovative Talents Project of Colleges and Universities in Guangdong Province (grant number 2021KQNCX092); the Doctoral program of Huizhou University (grant number 2020JB028); Outstanding Youth Cultivation Project of Huizhou University (grant number HZU202009); The Science and Technology Program of Guangzhou (grant number NO.202201020625).

**Institutional Review Board Statement:** Not applicable.

**Informed Consent Statement:** Not applicable.

**Data Availability Statement:** The VideoMatte240K dataset was used in this study, and the dataset can be found here. [https://grail.cs.washington.edu/projects/background-matting-v2/#/datasets, accessed on 16 January 2022].

**Conflicts of Interest:** The authors declare no conflict of interest.

#### **References**


### *Article* **Enhancing the Transferability of Adversarial Examples with Feature Transformation**

**Hao-Qi Xu <sup>1,2</sup>, Cong Hu <sup>1,2,\*</sup> and He-Feng Yin <sup>1,2</sup>**


**\*** Correspondence: conghu@jiangnan.edu.cn

**Abstract:** The transferability of adversarial examples allows an attacker to fool deep neural networks (DNNs) without knowing any information about the target models. Current input transformation-based methods generate adversarial examples by transforming the image in the input space, which implicitly ensembles a set of models by composing image transformations with the trained model. However, input transformation-based methods ignore the manifold embedding and can hardly extract intrinsic information from high-dimensional data. To this end, we propose a novel feature transformation-based method (FTM), which performs feature transformation in the feature space. FTM can improve the robustness of adversarial examples by transforming the features of the data. With FTM, the intrinsic features of adversarial examples are extracted to generate transferable adversarial examples. Experimental results on two benchmark datasets show that FTM can effectively improve the attack success rate (ASR) of state-of-the-art (SOTA) methods. For example, FTM improves the ASR of the Scale-Invariant Method on Inception\_v3 from 62.6% to 75.1% on ImageNet, a large margin of 12.5 percentage points.

**Keywords:** adversarial example; feature transformation; black-box attack; ensemble attack; deep neural network

**MSC:** 68T10

#### **1. Introduction**

DNNs have been shown to perform well in many fields, for example, image classification [1–3], human recognition [4], image segmentation [5], image fusion [6], visual object tracking [7,8], super-resolution [9], and others [10]. The ultimate goal of these studies is to make DNN-based applications more practical and efficient. However, the existence of adversarial examples raises security concerns for many applications, such as autonomous driving [11] and face recognition [12–14].

Adversarial examples [15], generated by adding imperceptible perturbations to raw images, can lead DNNs to make completely different predictions. They can even fool completely unknown models, a property called the transferability of adversarial examples. In addition, several studies have proposed universal adversarial perturbations [16,17], which are effective on any image. Other studies apply adversarial examples to real-world scenarios, such as face recognition and autonomous driving [18–22]. Studying both adversarial attack and defense [23–26] is significant, not only for revealing the vulnerability of DNNs but also for improving their robustness.

Many white-box attack methods have been proposed, such as the Fast Gradient Sign Method (FGSM) [27] and the Basic Iterative Method (BIM) [28]. However, in real-world situations it is difficult for an attacker to obtain the structure and parameters of the target model. Therefore, various approaches have emerged to enhance the transferability of adversarial examples for black-box attacks.

**Citation:** Xu, H.-Q.; Hu, C.; Yin, H.-F. Enhancing the Transferability of Adversarial Examples with Feature Transformation. *Mathematics* **2022**, *10*, 2976. https://doi.org/10.3390/math10162976

Academic Editors: Jianping Gou, Weihua Ou, Shaoning Zeng and Lan Du

Received: 26 July 2022 Accepted: 17 August 2022 Published: 18 August 2022

**Publisher's Note:** MDPI stays neutral with regard to jurisdictional claims in published maps and institutional affiliations.

**Copyright:** © 2022 by the authors. Licensee MDPI, Basel, Switzerland. This article is an open access article distributed under the terms and conditions of the Creative Commons Attribution (CC BY) license (https://creativecommons.org/licenses/by/4.0/).

The ensemble attack [29] is an effective way to enhance the transferability of adversarial examples. Lin et al. [30] proposed the Scale-Invariant Method (SIM), which uses input transformations to obtain new models: applying different transformations several times yields a set of models. With this approach, an ensemble attack can be performed with only one trained model, which is called an implicit ensemble attack. Other input transformation-based methods, such as the Diverse Input Method (DIM) [31], the Translation-Invariant Method (TIM) [32], and the Admix Attack Method (Admix) [33], have also been used successfully for adversarial attacks. However, these methods ignore the manifold structure of adversarial examples, and few works focus on feature transformation. To this end, this work proposes a feature transformation-based method (FTM) to improve the transferability of adversarial examples. In contrast to input transformation, our approach transforms the intrinsic features of the data instead of the input images. FTM is an implicit ensemble attack that can simultaneously attack multiple models that extract different features, and it improves the robustness of adversarial examples at the feature level. This work proposes several feature transformation strategies, and FTM can effectively improve the performance of SOTA adversarial attacks. Our contributions can be summarized as follows.


The structure of the paper is organized as follows. Section 2 introduces related work. Section 3 details the proposed FTM. Section 4 shows the experimental results. Section 5 gives a summary of this work.

#### **2. Related Work**

#### *2.1. Adversarial Example and Transferability*

Szegedy et al. [15] first pointed out that DNNs are vulnerable to adversarial examples, which are generated by adding imperceptible noise to raw images.

Let *x* be a clean image, *y* = *f*(*x*; *θ*) be the label predicted by a model with parameters *θ*, and ‖·‖<sub>*p*</sub> denote the *p*-norm. An adversarial example is an image *x*<sup>adv</sup> whose predicted label *f*(*x*<sup>adv</sup>; *θ*) ≠ *f*(*x*; *θ*), while the *L<sub>p</sub>* norm of the adversarial perturbation *x*<sup>adv</sup> − *x* is below a threshold *ε*, i.e., ‖*x*<sup>adv</sup> − *x*‖<sub>*p*</sub> ≤ *ε*. Here, *p* = ∞ is used to limit the distortion. Many methods have been proposed to improve the attack success rate (ASR) of adversarial examples. These methods can be divided into two branches: advanced gradient calculation and input transformations.

#### *2.2. Advanced Gradient Calculation*

This branch exploits better gradient calculation algorithms to enhance the performance of adversarial examples in both white-box settings and black-box settings.

**Fast Gradient Sign Method (FGSM)**: Goodfellow et al. [27] argue that linear behavior in high-dimensional spaces is sufficient to cause adversarial examples. Based on this observation, they propose FGSM, which generates an adversarial example *x*<sup>adv</sup> by maximizing the loss function *J*(*x*<sup>adv</sup>, *y*; *θ*) with a one-step update:

$$\mathbf{x}^{adv} = \mathbf{x} + \epsilon \cdot \operatorname{sign}(\nabla_{\mathbf{x}} J(\mathbf{x}, y; \theta)) \tag{1}$$

where *J*(*x*, *y*; *θ*) denotes the loss function of the classifier *f*(*x*; *θ*), ∇<sub>*x*</sub>*J*(*x*, *y*; *θ*) is the gradient of the loss function with respect to *x*, and sign(·) is the sign function, which makes the perturbation satisfy the *L*<sub>∞</sub> norm bound.

**Basic Iterative Method (BIM)**: Kurakin et al. [28] extend FGSM to an iterative version that applies the gradient update multiple times with a small step size *α*. BIM can be expressed as:

$$\mathbf{x}_{t+1}^{adv} = \text{Clip}_{x}^{\epsilon}\{\mathbf{x}_{t}^{adv} + \alpha \cdot \text{sign}(\nabla_{x} J(\mathbf{x}_{t}^{adv}, y; \theta))\} \tag{2}$$

where *x*<sub>0</sub><sup>adv</sup> = *x* and Clip<sub>*x*</sub><sup>*ε*</sup>(·) restricts the generated adversarial examples to the *ε*-ball of *x*.
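A minimal NumPy sketch of Equations (1) and (2) follows. It assumes pixel values in [0, 1] (so Clip<sub>*x*</sub><sup>*ε*</sup> also clips to that range) and takes a hypothetical `grad_fn` that returns ∇<sub>*x*</sub>*J*; in practice this gradient would come from an autodiff framework. A linear toy loss stands in for a real classifier.

```python
import numpy as np


def clip_eps(x_adv, x, eps):
    """Clip_x^eps: project onto the L_inf ball of radius eps around x, then onto [0, 1]."""
    return np.clip(np.clip(x_adv, x - eps, x + eps), 0.0, 1.0)


def fgsm(x, grad_fn, eps):
    """Equation (1): a single signed-gradient step of size eps."""
    return clip_eps(x + eps * np.sign(grad_fn(x)), x, eps)


def bim(x, grad_fn, eps, alpha, steps):
    """Equation (2): repeated small signed steps, re-projected after every step."""
    x_adv = x.copy()
    for _ in range(steps):
        x_adv = clip_eps(x_adv + alpha * np.sign(grad_fn(x_adv)), x, eps)
    return x_adv


# Toy loss J(x) = w . x, whose gradient with respect to x is simply w.
w = np.array([1.0, -1.0, 2.0, 0.0])


def grad_fn(z):
    return w


x = np.full(4, 0.5)
print(fgsm(x, grad_fn, eps=0.1))
print(bim(x, grad_fn, eps=0.1, alpha=0.03, steps=10))
```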

**Momentum Iterative Fast Gradient Sign Method (MI-FGSM)**: To stabilize the update direction and escape from poor local optima, Dong et al. [34] introduce momentum into BIM. The update procedure is formulated as follows:

$$\mathbf{g}_{t+1} = \mu \cdot \mathbf{g}_t + \frac{\nabla_x J(\mathbf{x}_t^{adv}, y; \theta)}{\|\nabla_x J(\mathbf{x}_t^{adv}, y; \theta)\|_1} \tag{3}$$

$$\mathbf{x}_{t+1}^{adv} = \text{Clip}_{x}^{\epsilon}\{\mathbf{x}_{t}^{adv} + \alpha \cdot \text{sign}(\mathbf{g}_{t+1})\} \tag{4}$$

where *g*<sub>*t*</sub> accumulates the gradients of the first *t* iterations with a decay factor *μ*, starting from *g*<sub>0</sub> = 0.

**Nesterov Iterative Fast Gradient Sign Method (NI-FGSM)**: NI-FGSM [30] adopts Nesterov's accelerated gradient to improve the transferability of MI-FGSM. Instead of computing the gradient in Equation (3) at *x*<sub>*t*</sub><sup>adv</sup>, this method computes it at the look-ahead point *x*<sub>nest</sub>, which is formulated as follows:

$$\mathbf{x}_{nest} = \mathbf{x}_t^{adv} + \alpha \cdot \mu \cdot \mathbf{g}_t \tag{5}$$
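MI-FGSM and NI-FGSM differ only in where the gradient is evaluated, so one sketch covers both. The step size *α* = *ε*/*T* is a common choice assumed here, pixels are assumed to lie in [0, 1], and `grad_fn` is a hypothetical stand-in for ∇<sub>*x*</sub>*J*.

```python
import numpy as np


def mi_fgsm(x, grad_fn, eps, steps, mu=1.0, nesterov=False):
    """MI-FGSM, Equations (3)-(4); nesterov=True adds the NI-FGSM look-ahead of Equation (5)."""
    alpha = eps / steps                      # step size: a common choice, assumed here
    x_adv = x.copy()
    g = np.zeros_like(x)                     # g_0 = 0
    for _ in range(steps):
        # NI-FGSM evaluates the gradient at x_nest = x_adv + alpha * mu * g.
        x_eval = x_adv + alpha * mu * g if nesterov else x_adv
        grad = grad_fn(x_eval)
        g = mu * g + grad / (np.sum(np.abs(grad)) + 1e-12)        # L1-normalized accumulation
        step = x_adv + alpha * np.sign(g)
        x_adv = np.clip(np.clip(step, x - eps, x + eps), 0.0, 1.0)  # Clip_x^eps and [0, 1]
    return x_adv


# Toy loss J(x) = w . x with constant gradient w.
w = np.array([1.0, -1.0, 2.0, 0.0])


def grad_fn(z):
    return w


x = np.full(4, 0.5)
print(mi_fgsm(x, grad_fn, eps=0.1, steps=10))
print(mi_fgsm(x, grad_fn, eps=0.1, steps=10, nesterov=True))
```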

#### *2.3. Input Transformations*

Various input transformation-based methods, such as DIM, TIM, SIM, and Admix, have been proposed to generate transferable adversarial examples.

**Diverse Input Method (DIM)**: Inspired by the fact that data augmentation effectively prevents networks from overfitting, Xie et al. [31] apply random resizing and random padding to the inputs to improve the transferability of adversarial examples.
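A minimal sketch of a DIM-style input transformation on a square grayscale image stored as nested lists is shown below. The nearest-neighbour resizing, the size range [*n*, `out_size`), and zero padding at a random offset are assumptions modeled on the description above rather than the exact implementation of [31]; in an attack, the gradient would be computed on the transformed input. In this sketch an untransformed input keeps its original size, so the (hypothetical) network consuming it must accept both sizes, or the output must be resized back.

```python
import random


def resize_nearest(img, new_size):
    """Nearest-neighbour resize of a square 2-D list to new_size x new_size."""
    n = len(img)
    return [[img[r * n // new_size][c * n // new_size] for c in range(new_size)]
            for r in range(new_size)]


def diverse_input(img, out_size, p=0.5, rng=random):
    """With probability p, resize to a random size in [n, out_size) and
    zero-pad at a random offset to out_size x out_size; otherwise copy."""
    n = len(img)
    if rng.random() >= p:
        return [row[:] for row in img]          # keep the original input
    rnd = rng.randrange(n, out_size)            # random intermediate size
    resized = resize_nearest(img, rnd)
    top = rng.randrange(0, out_size - rnd + 1)  # random padding offsets
    left = rng.randrange(0, out_size - rnd + 1)
    out = [[0.0] * out_size for _ in range(out_size)]
    for r in range(rnd):
        for c in range(rnd):
            out[top + r][left + c] = resized[r][c]
    return out


rng = random.Random(0)
img = [[float(4 * r + c) for c in range(4)] for r in range(4)]   # hypothetical 4 x 4 image
out = diverse_input(img, out_size=6, p=1.0, rng=rng)
print(len(out), len(out[0]))   # padded to 6 x 6
```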

**Translation-Invariant Method (TIM)**: Dong et al. [32] propose replacing the gradient on the original image with the average of the gradients on multiple translated copies of the image. Relying on the translation-invariant property, they approximate this average by convolving the gradient with a predefined kernel matrix, which avoids substantial extra computation.
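TIM's shortcut can be sketched as follows. If the gradient of a translated image is approximately the translated gradient, a weighted average of gradients over small translations becomes a convolution of the gradient map with the weighting kernel. The Gaussian kernel, its radius and *σ*, and zero padding at the borders are illustrative assumptions; because the kernel is symmetric, the cross-correlation computed below equals a convolution.

```python
import math


def gaussian_kernel(radius, sigma):
    """(2 * radius + 1)^2 Gaussian kernel, normalised so its weights sum to 1."""
    k = [[math.exp(-(i * i + j * j) / (2.0 * sigma * sigma))
          for j in range(-radius, radius + 1)]
         for i in range(-radius, radius + 1)]
    total = sum(map(sum, k))
    return [[v / total for v in row] for row in k]


def smooth_gradient(grad, kernel):
    """Convolve a 2-D gradient map with a symmetric kernel (zero padding, same size).
    Each output entry is a kernel-weighted average of nearby (translated) gradients."""
    h, w = len(grad), len(grad[0])
    rad = len(kernel) // 2
    out = [[0.0] * w for _ in range(h)]
    for r in range(h):
        for c in range(w):
            acc = 0.0
            for i in range(-rad, rad + 1):
                for j in range(-rad, rad + 1):
                    rr, cc = r + i, c + j
                    if 0 <= rr < h and 0 <= cc < w:
                        acc += kernel[i + rad][j + rad] * grad[rr][cc]
            out[r][c] = acc
    return out


kernel = gaussian_kernel(radius=1, sigma=1.0)
grad = [[0.0] * 5 for _ in range(5)]
grad[2][2] = 1.0                      # a single-pixel gradient impulse
smoothed = smooth_gradient(grad, kernel)
# The impulse is spread into the shape of the kernel around pixel (2, 2).
```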

**Scale-Invariant Method (SIM)**: Lin et al. [30] discover the scale-invariant property of deep learning models and introduce the notions of loss-preserving transformation and model augmentation. Accordingly, they present SIM, which computes the average gradient over scaled copies of the original image for the update.

**Admix Attack Method (Admix)**: Admix [33] enhances the transferability of adversarial examples by integrating gradient information from images of different categories. Specifically, Admix randomly samples a number of images from other categories, mixes each sampled image into the original input image with a small weight, and calculates the gradient on the mixed images for the update.
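Both SIM and Admix average the input gradient over a set of transformed copies. The sketch below assumes scale factors *γ*<sub>*i*</sub> = 1/2<sup>*i*</sup> and Admix inputs *γ*<sub>*i*</sub>·(*x* + *η*·*x*′), where *x*′ is an image sampled from another category; since each copy is *γ*<sub>*i*</sub> times a function of *x*, the chain rule multiplies the gradient at the copy by *γ*<sub>*i*</sub>. The `grad_fn` callable stands in for a network's gradient, and the quadratic loss in the usage example is hypothetical.

```python
import random


def sim_copies(x, m):
    """SIM: scaled copies gamma_i * x with gamma_i = 1 / 2**i, i = 0..m-1,
    returned as (gamma_i, copy) pairs so the chain-rule factor is kept."""
    return [(1.0 / 2 ** i, [v / 2 ** i for v in x]) for i in range(m)]


def admix_copies(x, pool, m1, m2, eta, rng=random):
    """Admix: mix x with each of m2 images sampled from other categories,
    x + eta * x_other, then take m1 SIM-style scaled copies of each mix."""
    copies = []
    for other in rng.sample(pool, m2):
        mixed = [a + eta * o for a, o in zip(x, other)]
        copies.extend(sim_copies(mixed, m1))
    return copies


def averaged_gradient(copies, grad_fn):
    """Average of grad_x J(gamma * z(x)) = gamma * grad_fn(copy) over all copies."""
    total = None
    for gamma, c in copies:
        g = [gamma * v for v in grad_fn(c)]
        total = g if total is None else [t + v for t, v in zip(total, g)]
    return [t / len(copies) for t in total]


def grad_fn(z):
    return list(z)   # hypothetical loss J(z) = 0.5 * ||z||^2, so grad = z


x = [1.0, 2.0]
print(averaged_gradient(sim_copies(x, m=2), grad_fn))   # [0.625, 1.25]

pool = [[0.0, 1.0], [1.0, 0.0], [1.0, 1.0], [2.0, 0.0]]   # hypothetical other-class images
copies = admix_copies(x, pool, m1=5, m2=3, eta=0.2, rng=random.Random(0))
g_admix = averaged_gradient(copies, grad_fn)
```

With *m*<sub>1</sub> = 5 and *m*<sub>2</sub> = 3, Admix averages over 15 copies, so each update costs roughly 15 gradient evaluations.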

#### *2.4. Adversarial Defense*

In addition to adversarial attacks, many works on adversarial defense have been proposed to improve the robustness of the classifiers. The current defense methods can be divided into two categories.

One category aims to improve the robustness of the classifier itself, such as adversarial training [27,35], which adds adversarial examples to the training set during training so that the model becomes more robust to them. Adversarial training is a popular and effective defense and has inspired many follow-up works [36,37]. However, its effectiveness depends largely on how the added adversarial examples are generated.

Another category of defense methods reduces the impact of adversarial perturbations by modifying the input images, for example by adding noise or compressing the images [38,39]. Xie et al. [40] propose performing randomized resizing and padding on inputs at inference time, which was the top-1 defense solution in the NIPS 2017 competition. NIPS-r3 fuses multiple adversarially trained models and performs several input transformations at inference time. These methods require no additional training and are effective against various attack approaches.

#### **3. Our Approach**

A DNN model can be formulated as *f*(*x*) = lin(con(*x*)), where con(·) and lin(·) denote the convolutional part and the fully connected part, respectively, and *p* = con(*x*) denotes the feature extracted by the convolutional part.

To obtain an ensemble of models that extract different features, we propose a feature transformation, denoted as FT(·). By introducing the feature transformation, we obtain a new model *f*′(*x*) = lin(*p*′) = lin(FT(con(*x*))), which extracts features different from those of the original model in every iteration. The feature transformation-based method (FTM) optimizes the adversarial perturbation over several different transformed features:

$$\mathop{\arg\max}_{\mathbf{x}^{adv}} \; \frac{1}{m} \sum_{i=1}^{m} J(\mathrm{lin}(\mathrm{FT}_{i}(\mathrm{con}(\mathbf{x}^{adv}))), y_{true}), \tag{6}$$

$$\text{s.t.}\;\; \|\mathbf{x}^{adv} - \mathbf{x}\|_{\infty} \le \epsilon, \tag{7}$$

where *m* denotes the number of feature transformation iterations (i.e., the number of randomly transformed features used in each update) and FT<sub>*i*</sub>(·) denotes the *i*-th randomly sampled feature transformation. Thus, FTM is an implicit ensemble attack that simultaneously attacks *m* models. Figure 1 illustrates FTM.

In this paper, we consider the following five feature transformation strategies.

Strategy I: Fixed-threshold random noise. Add a random vector *z* sampled from the uniform distribution U(−*r*, *r*) to feature *p*:

$$\text{FT}(p) = p + z \tag{8}$$

Strategy II: Mean-based threshold random noise. Let *z* be a random vector sampled from the uniform distribution U(−*r*, *r*) and let *p̄* be the mean value of feature *p*. Add *p̄* · *z* to feature *p*:

$$\text{FT}(p) = p + \overline{p} \cdot z \tag{9}$$

Strategy III: Overall feature scaling. Multiply feature *p* by a random number *k* sampled from the uniform distribution U(−*r*, *r*):

$$\text{FT}(p) = k \cdot p \tag{10}$$

Strategy IV: Element-wise feature scaling. Multiply each element of feature *p* by the corresponding element of a random vector *z* sampled from the uniform distribution U(−*r*, *r*):

$$\text{FT}(p) = \mathbf{z} \cdot \mathbf{p} \tag{11}$$

Strategy V: Offset-mean random noise. Add to feature *p* a random vector *z* sampled from the uniform distribution U(−*r* + *s*, *r* + *s*), where *s* is the offset of the mean:

$$\text{FT}(p) = p + z \tag{12}$$
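The five strategies can be written as one function over a feature vector. This Python sketch follows Equations (8)–(12) as stated; it assumes that *p̄* in Strategy II is the mean over all elements of *p* and treats features as flat lists, whereas a real implementation would operate on the convolutional feature tensor of a network.

```python
import random


def ft_strategy(p, strategy, r, s=0.0, rng=random):
    """Apply one of the five feature transformations FT(p) to a feature vector p."""
    n = len(p)
    if strategy == 1:   # I: fixed-threshold noise, p + z with z ~ U(-r, r)
        z = [rng.uniform(-r, r) for _ in range(n)]
        return [pi + zi for pi, zi in zip(p, z)]
    if strategy == 2:   # II: mean-based noise, p + mean(p) * z
        mean = sum(p) / n
        z = [rng.uniform(-r, r) for _ in range(n)]
        return [pi + mean * zi for pi, zi in zip(p, z)]
    if strategy == 3:   # III: scale the whole feature by one k ~ U(-r, r)
        k = rng.uniform(-r, r)
        return [k * pi for pi in p]
    if strategy == 4:   # IV: element-wise scaling by z ~ U(-r, r)
        z = [rng.uniform(-r, r) for _ in range(n)]
        return [zi * pi for zi, pi in zip(z, p)]
    if strategy == 5:   # V: offset noise, z ~ U(-r + s, r + s)
        z = [rng.uniform(-r + s, r + s) for _ in range(n)]
        return [pi + zi for pi, zi in zip(p, z)]
    raise ValueError("strategy must be 1, 2, 3, 4 or 5")


rng = random.Random(0)
p = [0.5, 1.0, 1.5, 2.0]             # hypothetical feature vector
for strategy in range(1, 6):
    print(strategy, ft_strategy(p, strategy, r=4.0, s=1.0, rng=rng))
```

Setting *r* = 0 makes the noise strategies the identity and the scaling strategies collapse the feature to zero, which matches the observation in Section 4.1.2 that Strategies III and IV are poor at small magnitudes.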

The feature transformation should also be an accuracy-preserving transformation. We define the accuracy-preserving feature transformation as follows:

**Definition 1** (Acc-preserving Feature Transformation)**.** *Given a test set X and a classifier f*(*x*) = lin(con(*x*))*, let Acc*(lin(con(*x*)), *X*) *denote the accuracy of model f*(*x*) *on data set X. If a feature transformation* FT(·) *satisfies Acc*(lin(con(*x*)), *X*) ≈ *Acc*(lin(FT(con(*x*))), *X*)*, we say that* FT(·) *is an accuracy-preserving feature transformation.*
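Definition 1 leaves the tolerance in "≈" unspecified. The sketch below makes it concrete with a hypothetical absolute tolerance `tol`, and uses a toy identity feature extractor with an argmax head to show that scaling features by a positive constant preserves accuracy while negating them does not. For a random FT, each call draws a new transformation, so in practice the check would be averaged over several draws.

```python
def accuracy(predict, data):
    """Fraction of (input, label) pairs classified correctly."""
    return sum(predict(x) == y for x, y in data) / len(data)


def is_acc_preserving(lin, con, ft, data, tol=0.01):
    """Definition 1 with "≈" read as |Acc(lin(con(x))) - Acc(lin(FT(con(x))))| <= tol."""
    base = accuracy(lambda x: lin(con(x)), data)
    transformed = accuracy(lambda x: lin(ft(con(x))), data)
    return abs(base - transformed) <= tol


def con(x):
    return list(x)                    # hypothetical identity feature extractor


def lin(p):
    return max(range(len(p)), key=lambda i: p[i])   # argmax classifier head


data = [([1.0, 0.0], 0), ([0.0, 2.0], 1), ([3.0, 1.0], 0)]   # hypothetical test set
print(is_acc_preserving(lin, con, lambda p: [2.0 * v for v in p], data))   # True
print(is_acc_preserving(lin, con, lambda p: [-v for v in p], data))        # False
```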

We experimentally study accuracy-preserving feature transformation strategies in Section 4.1.2, where we choose the magnitude *r* of the uniform distribution so that the feature transformations are accuracy-preserving. The FTM attack is summarized in Algorithm 1.

**Algorithm 1** Algorithm of FTM.

**Input:** Original image *x*, true label *y<sub>true</sub>*, a classifier *f*(*x*) = lin(con(*x*)), loss function *J*, feature transformation FT(·).

**Hyper-parameters:** Perturbation size *ε*, maximum number of iterations *T*, number of feature transformation iterations *m*.

**Output:** Adversarial example *x*<sup>*adv*</sup>.

1. Perturbation size in each iteration: *α* = *ε*/*T*
2. *x*<sub>0</sub><sup>*adv*</sup> = *x*
3. **for** *t* = 0, 1, …, *T* − 1 **do**
4. *g* = 0
5. **for** *i* = 0, 1, …, *m* − 1 **do**
6. Feature: *p* = con(*x*<sub>*t*</sub><sup>*adv*</sup>)
7. Transformed feature: *p*′ = FT(*p*)
8. Accumulate the gradient: *g* = *g* + ∇<sub>*x*</sub>*J*(lin(*p*′), *y<sub>true</sub>*)
9. **end for**
10. Average the gradients: *g* = *g*/*m*
11. Update *x*<sub>*t*+1</sub><sup>*adv*</sup> = Clip<sub>*x*</sub><sup>*ε*</sup>{*x*<sub>*t*</sub><sup>*adv*</sup> + *α* · sign(*g*)}
12. **end for**
13. **return** *x*<sup>*adv*</sup> = *x*<sub>*T*</sub><sup>*adv*</sup>
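Algorithm 1 can be sketched end to end in Python for a hypothetical two-stage model with a linear con(*x*) = *A x* and a logistic lin(*p*) = *σ*(*v*·*p* + *c*). The sketch is restricted to Strategy III, FT(*p*) = *k*·*p*, so that ∇<sub>*x*</sub>*J* has the closed form (*σ*(*s*) − *y*<sub>*true*</sub>)·*k*·*A*<sup>⊤</sup>*v*. The matrices, labels, and the `sample_k` callable are assumptions for illustration, not the authors' code, which would compute gradients through a deep network.

```python
import math
import random


def sigmoid(s):
    return 1.0 / (1.0 + math.exp(-s))


def sign(v):
    return float((v > 0) - (v < 0))


def ftm_attack(x, y_true, A, v, c, eps, T, m, sample_k, rng=random):
    """Algorithm 1 for a hypothetical model f(x) = lin(FT(con(x))) with
    con(x) = A x, FT(p) = k * p (Strategy III), lin(p) = sigmoid(v . p + c)."""
    alpha = eps / T                                   # step 1: alpha = eps / T
    x_adv = list(x)                                   # step 2: x_0^adv = x
    # For this model, grad_x J = (sigmoid(s) - y_true) * k * A^T v.
    At_v = [sum(A[i][j] * v[i] for i in range(len(A))) for j in range(len(x))]
    for _ in range(T):
        g = [0.0] * len(x)
        p = [sum(a * xj for a, xj in zip(row, x_adv)) for row in A]   # p = con(x_t^adv)
        for _ in range(m):
            k = sample_k(rng)                         # draw a new FT_i
            s = c + sum(vi * k * pi for vi, pi in zip(v, p))          # logit of lin(FT(p))
            coeff = (sigmoid(s) - y_true) * k
            g = [gj + coeff * aj for gj, aj in zip(g, At_v)]          # accumulate gradient
        g = [gj / m for gj in g]                      # average over the m transformations
        stepped = [a + alpha * sign(gj) for a, gj in zip(x_adv, g)]
        x_adv = [min(max(sv, xi - eps), xi + eps) for sv, xi in zip(stepped, x)]  # Clip
    return x_adv


A = [[1.0, 0.5], [-0.5, 1.0], [0.2, 0.3]]   # hypothetical con weights (2 inputs -> 3 features)
v, c = [1.0, -1.0, 0.5], 0.0                 # hypothetical lin weights
x, y_true = [0.3, 0.7], 1
x_adv = ftm_attack(x, y_true, A, v, c, eps=0.1, T=10, m=5,
                   sample_k=lambda g: g.uniform(-4.0, 4.0), rng=random.Random(0))
print(x_adv)
```

Each outer iteration costs *m* forward and backward passes, which is the overhead that the parameter analysis in Section 4.1.6 trades against attack success.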

#### **4. Experimental Results**

#### *4.1. Experiment on ImageNet*

#### 4.1.1. Experimental Setup

**Dataset.** We perform experiments on ImageNet, one of the most widely used and challenging image classification datasets. We select 1000 images from ImageNet [41] as our test set. These 1000 benign images belong to 1000 different categories and are correctly classified by all tested models.

**Networks.** We select four mainstream models: Inception\_v3 (Inc\_v3) [42], Inception\_v4 (Inc\_v4), Inception-Resnet\_v2 (IncRes\_v2) [43], and Xception (Xcep) [44].

**Attack setting.** We follow the setting of Lin et al. [30], with maximum perturbation *ε* = 16, number of iterations *T* = 10, and step size *α* = 1.6 (i.e., *α* = *ε*/*T*), which is a challenging attack setting. We adopt the decay factor *μ* = 1.0 for MI-FGSM. The transformation probability is set to 0.5 for DIM. The number of scale copies is set to *m* = 5 for SIM. For Admix, we set *m*<sub>1</sub> = 5 and randomly sample *m*<sub>2</sub> = 3 images with *η* = 0.2. The hyper-parameter settings of these attack methods are all consistent with the original papers.

#### 4.1.2. Accuracy-Preserving Transformation

To investigate accuracy-preserving transformations, we test the accuracy of the models integrated with each of the five strategies on the ImageNet dataset, varying the magnitude *r* of the uniform distribution over [0, 10].

The magnitude of the uniform distribution is an important hyper-parameter of FTM. A larger magnitude increases the diversity of the implicit ensemble models and thus improves the transferability of the adversarial examples. However, too large a magnitude renders the model invalid, so it can no longer guide the generation of adversarial examples. As shown in Figure 2, the accuracy curves of strategies I, II, and V are smooth and stable when the magnitude is in the range [0, 4] and drop significantly once the magnitude exceeds 4. In contrast, the accuracies of strategies III and IV are extremely low when the magnitude is close to 0 and remain stable once the magnitude exceeds 4. Hence, the scaling strategies (III and IV) are more sensitive to small magnitudes, whereas the noise-adding strategies (I, II, and V) are more sensitive to large magnitudes. Based on these results, we set the magnitude of the uniform distribution to 4 in the following experiments to ensure that the feature transformations are accuracy-preserving.

**Figure 2.** The average classification accuracy of Inc\_v3, Inc\_v4, IncRes\_v2, and Xcep integrated with five different feature transformation strategies on ImageNet. The horizontal axis is the magnitude of the uniform distribution, and the vertical axis is the accuracy of the model.

#### 4.1.3. Feature Transformation Strategies

In this section, we present the experimental results of the proposed FTM with the five feature transformation strategies. We set *m* = 1 and generate adversarial examples on Inc\_v3 with FT-FGSM, FT-MI-FGSM, and FT-SIM. The attack success rates (ASRs) against the other three black-box models are presented in Table 1.

**Table 1.** The black-box ASRs (%) of FT-FGSM, FT-MI-FGSM, and FT-SIM with five strategies on ImageNet. The adversarial examples are generated on Inc\_v3. The highest ASRs are shown in bold.


With FT-FGSM, Strategy III achieves the best overall attack performance, reaching ASRs of 35.9% and 37.5% against IncRes\_v2 and Xcep, respectively. With FT-MI-FGSM, Strategy V attains the best overall attack performance, reaching 57% and 53.3% against Inc\_v4 and IncRes\_v2, respectively. When FT-SIM is used to attack IncRes\_v2 and Xcep, Strategy III achieves ASRs of 35.9% and 37.5%, outperforming the other strategies. Overall, Strategy III performs best, particularly when combined with SIM, an input transformation-based method. Thus, we adopt Strategy III in the following experiments.

#### 4.1.4. Attack with Input Transformations

We test the ASRs of MI-FGSM, SIM, DIM, and Admix, respectively. Then we combine these methods with FTM as FT-MI-FGSM, FT-SIM, FT-DIM, and FT-Admix. Some adversarial examples are shown in Figure 3. We adopt Strategy III, set *m* = 1, set the magnitude of uniform distribution *r* = 4, and then use the generated adversarial examples to attack the four models. We compare the black-box ASRs of FT-MI-FGSM, FT-SIM, FT-DIM, and FT-Admix with MI-FGSM, SIM, DIM, and Admix in Tables 2–5. In the tables, the first columns are the local models, and the first rows are the target models. The values in the tables are the attack success rates (ASRs) on the target models using the adversarial examples generated from the local models. The higher ASRs are bolded.

When combined with MI-FGSM, FTM increases the ASR by up to 9.4%, from 55% to 64.4%, when attacking Xcep with adversarial examples generated on Inc\_v4. When FT-SIM is used to attack Inc\_v3 with adversarial examples generated on IncRes\_v2, the ASR improves from 62.6% to 75.1%, outperforming SIM by 12.5%. The adversarial examples generated by FT-DIM achieve about 55% ASR against all models. When FT-Admix is used to attack Inc\_v3 with adversarial examples generated on Xcep, the ASR reaches 72.2%.

According to the reported experimental results, FTM improves the ASRs of adversarial examples generated by the SOTA black-box attack methods. This confirms that feature transformation can improve the transferability and robustness of adversarial examples.

**Figure 3.** Adversarial examples generated by MI-FGSM, DIM, SIM, Admix, and the proposed FT-MI-FGSM, FT-DIM, FT-SIM, and FT-Admix on Inc\_v3.

**Table 2.** The black-box ASRs of MI-FGSM and FT-MI-FGSM on ImageNet. The first column is the local model, and the first row is the target model. The values in the table are the ASRs (%) on the target models using the adversarial examples generated with the local models. The higher ASRs are shown in bold.


**Table 3.** The black-box ASRs of SIM and FT-SIM on ImageNet. The first column is the local model, and the first row is the target model. The values in the table are the ASRs (%) on the target models using the adversarial examples generated with the local models. The higher ASRs are shown in bold.



**Table 4.** The black-box ASRs of DIM and FT-DIM on ImageNet. The first column is the local model, and the first row is the target model. The values in the table are the ASRs (%) on the target models using the adversarial examples generated with the local models. The higher ASRs are shown in bold.

**Table 5.** The black-box ASRs of Admix and FT-Admix on ImageNet. The first column is the local model, and the first row is the target model. The values in the table are the ASRs (%) on the target models using the adversarial examples generated with the local models. The higher ASRs are shown in bold.


#### 4.1.5. Attack against Defense Methods

In this section, we quantify the effectiveness of FTM against several defense methods, including random resizing and padding (RandP) [40], JPEG compression (JPEG) [39], randomized smoothing (RS) [38], and the rank-3 submission in the NIPS 2017 competition (NIPS-r3). RandP, the top-1 submission in the NIPS competition, mitigates the effect of adversarial perturbations by randomized resizing and padding. JPEG is a defensive compression framework that can rectify adversarial examples without reducing classification accuracy on benign data. RS constructs a more adversarially robust "smoothed" classifier from an arbitrary base classifier. NIPS-r3 fuses multiple adversarially trained models and performs several input transformations at inference time.

We choose SIM as the baseline and generate adversarial examples with Inc\_v3. The average ASRs on Inc\_v4, IncRes\_v2, and Xcep are shown in Table 6. FT-SIM improves the ASRs by a large margin of 9.5% on average, which indicates that the adversarial examples generated by FTM are more effective at fooling models equipped with defense mechanisms.

**Table 6.** The black-box ASRs of SIM and FT-SIM on ImageNet against four defense methods. The adversarial examples are generated with Inc\_v3. The values in the table are the average ASRs (%) on the Inc\_v4, IncRes\_v2, and Xcep. The higher ASRs are shown in bold.


#### 4.1.6. Parameter Analysis

In this section, we further analyze the effect of the number of feature transformation iterations *m*. The adversarial examples are generated by FT-DIM on Inc\_v3, with *m* ranging from 1 to 9.

As shown in Figure 4, the average black-box ASR increases from 59.2% with 1 iteration to 62.7% with 3 iterations, and further to 65.3% with 9 iterations. This indicates that the ASR of FTM increases with the number of feature transformation iterations, although the gain diminishes: going from 1 to 3 iterations adds 3.5%, whereas going from 3 to 9 iterations adds only 2.6%. Since more iterations incur a larger computational overhead, the trade-off between effectiveness and overhead should be made according to the specific scenario.

**Figure 4.** The black-box ASRs of the FT-DIM attack with different numbers of iterations on ImageNet. The adversarial examples are generated on Inc\_v3, and the ASRs are the average ASRs on Inc\_v4, IncRes\_v2, and Xcep.

#### *4.2. Experiment on Cifar10*

To further demonstrate the effectiveness of our approach, we also conduct experiments on the Cifar10 [45] dataset. Cifar10 contains 60,000 color images of 32 × 32 pixels divided into 10 categories. From the test set, we select 1000 images of the 10 categories that are correctly classified by all the experimental models. We compare FTM with MI-FGSM, SIM, and Admix using the ResNet [46] family of models. The maximum perturbation is *ε* = 4, the number of attack iterations is *T* = 4, and the step size is *α* = 1.

The experimental results for FT-MI-FGSM, FT-SIM, and FT-Admix are shown in Tables 7–9, where the first columns are the local models and the first rows are the target models. Our method improves the ASRs in all experiments. FT-MI-FGSM achieves an ASR of 83.8% when attacking Res152 with Res50. FT-SIM improves the ASR of SIM from 66.6% to 73.9%, and FT-Admix boosts the ASR of Admix from 43.1% to 55.1%, both when attacking Res101 with Res152.

The experimental results on Cifar10 validate that FTM is effective not only on large-scale image datasets but also on small-scale ones. Moreover, FTM significantly improves the transferability and robustness of adversarial examples generated by the SOTA black-box attack methods.


**Table 7.** The black-box ASRs of MIM (MI-FGSM) and FT-MIM (FT-MI-FGSM) on Cifar10. The first column is the local model, and the first row is the target model. The values in the table are the ASRs (%) on the target models using the adversarial examples generated with the local models. The higher ASRs are shown in bold.

**Table 8.** The black-box ASRs of SIM and FT-SIM on Cifar10. The first column is the local model, and the first row is the target model. The values in the table are the ASRs (%) on the target models using the adversarial examples generated with the local models. The higher ASRs are shown in bold.


**Table 9.** The black-box ASRs of Admix and FT-Admix on Cifar10. The first column is the local model, and the first row is the target model. The values in the table are the ASRs (%) on the target models using the adversarial examples generated with the local models. The higher ASRs are shown in bold.


#### **5. Conclusions**

We propose a novel feature transformation-based method (FTM) that effectively improves the transferability of adversarial examples. We propose five feature transformation strategies and comprehensively analyze their hyper-parameters. The experimental results on two benchmark datasets show that FTM significantly improves the transferability of adversarial examples, raising the ASRs of the SOTA methods by up to 12.5% on ImageNet. Our method can be combined with any gradient-based attack method and applied to any neural network that extracts features. However, hyper-parameter tuning is difficult, because different models and feature transformation strategies require a large number of experiments to choose the magnitude of the uniform distribution. In the future, we will explore more feature transformation strategies to improve the transferability of adversarial examples while reducing the difficulty of hyper-parameter tuning.

**Author Contributions:** Conceptualization, H.-Q.X. and C.H.; methodology, H.-Q.X., C.H. and H.-F.Y.; software, H.-Q.X.; data curation, H.-Q.X.; resources, C.H.; writing—original draft preparation, H.-Q.X., C.H. and H.-F.Y.; project administration, C.H.; funding acquisition, C.H. All authors have read and agreed to the published version of the manuscript.

**Funding:** This research was funded in part by the National Natural Science Foundation of China (Grant No. 62006097), in part by the Natural Science Foundation of Jiangsu Province (Grant No. BK20200593), in part by the China Postdoctoral Science Foundation (Grant No. 2021M701456), and in part by the Fundamental Research Funds for the Central Universities (Grant No. JUSRP121074).

**Institutional Review Board Statement:** Not applicable.

**Informed Consent Statement:** Not applicable.

**Data Availability Statement:** The ImageNet and Cifar10 datasets were analyzed in this study. The ImageNet dataset can be found at https://image-net.org/ (accessed on 10 July 2022). Cifar10 dataset can be found at https://www.cs.toronto.edu/~kriz/cifar.html (accessed on 10 July 2022).

**Conflicts of Interest:** The authors declare no conflict of interest.

#### **References**


### *Article* **Extension Design Pattern of Requirement Analysis for Complex Mechanical Products Scheme Design**

**Tichun Wang \*, Hao Li and Xianwei Wang**

College of Mechanical and Electrical Engineering, Nanjing University of Aeronautics and Astronautics, Nanjing 210016, China

**\*** Correspondence: wangtichun2010@nuaa.edu.cn

**Abstract:** In the configuration process of a complex product scheme, the design structure is often multi-level, multi-attribute, creative, and complex. To improve the efficiency and quality of product scheme design, it is of significant research value to reasonably organize, reason about, and reuse design knowledge. In this paper, the extension modeling problem under the extension design pattern of complex product schemes is studied, a modeling and expression model of multitype design knowledge elements for complex product scheme design is given, and the extension process model and the implication process model of requirement analysis for complex product scheme design are established. A new weight assignment method for demand elements based on extension distance is proposed to obtain accurate weights of the demand analysis indices by combining qualitative and quantitative analysis. On the basis of the extension correlation degree of demand primitives, this paper proposes an implementation method of the extension design pattern for the demand analysis of complex product scheme design and gives the specific implementation algorithm. Finally, a product design example is given to illustrate the method, and the results show its effectiveness and operability.

**Keywords:** intelligent design; data analysis; models and algorithms; extension theory; scheme design

**MSC:** 68T20

#### **1. Introduction**

Product scheme design in the aerospace or power generation equipment industry is creative, skilled work that combines theory with a large amount of practical experience. Its design process is a multi-level, multi-attribute, creative, and complex configuration process, in which the interaction of various design factors generates design constraints and design conflicts [1–4]. Accurately describing, analyzing, and transforming customer requirements is therefore very important for the smooth development of the product. Requirement analysis must consider not only the customer's requirement information but also information about the entire life cycle of the product, that is, the design's feasibility, manufacturability, reliability, maintainability, energy, and environmental protection; these are the design goals of the various activities in the product development process and allow requirement analysis to better guide subsequent design. Therefore, scheme design cannot describe the requirement model solely from the customer's point of view; it should consider the entire life cycle of the product so that the requirement model not only outputs the necessary design requirements but also facilitates the mapping between the product's functions and structures and lays the foundation for product design automation [5–7].

**Citation:** Wang, T.; Li, H.; Wang, X. Extension Design Pattern of Requirement Analysis for Complex Mechanical Products Scheme Design. *Mathematics* **2022**, *10*, 3132. https://doi.org/10.3390/math10173132

Academic Editor: Bo-Hao Chen

Received: 5 July 2022 Accepted: 12 August 2022 Published: 1 September 2022

**Publisher's Note:** MDPI stays neutral with regard to jurisdictional claims in published maps and institutional affiliations.

**Copyright:** © 2022 by the authors. Licensee MDPI, Basel, Switzerland. This article is an open access article distributed under the terms and conditions of the Creative Commons Attribution (CC BY) license (https://creativecommons.org/licenses/by/4.0/).

Currently, many scholars analyze and discuss customer requirements from different perspectives and propose corresponding requirement analysis methods [8–11], but these methods usually have shortcomings, such as insufficiently formalized requirement descriptions or product requirement information that lacks objectivity [12–14]. Extenics, founded by the Chinese scholar Professor Cai Wen, is an emerging discipline that uses formal models to study the possibility of extending objects; its pioneering rules and methods formally explore the laws of contradictory problems from both qualitative and quantitative perspectives [15–17]. Based on basic element theory and extension mathematics, extenics formalizes the problem-solving process, establishes the corresponding mathematical models, and on this basis develops new, more intelligent computational methods and techniques that formally address the storage, representation, and processing of deep knowledge in knowledge bases [18,19], making knowledge in knowledge engineering more formal, deeper, and more fundamental. At present, extenics has been successfully applied in the field of product design [15,20–23], but studies that systematically apply extenics to requirement analysis of complex mechanical products are rare and still in their infancy. Axiomatic design is a conceptual product design theory proposed by Suh of MIT in the early 1990s. Its purpose is to establish a scientific basis for complex product design and to improve design activities in product development by providing designers with logical and rational thinking methods and tools [24–27]. Different from the research methods for discrete products [28,29], this paper adopts the extension intelligent design method, which aims to study and analyze the extension and implication of design problems in the process of requirement analysis of complex product scheme design.
The formal modeling problem of design knowledge is solved by establishing a knowledge model, the extension reasoning problem of requirement analysis is solved by establishing a requirement analysis process model, and the extension design mode of requirement analysis is established to realize the extension requirement analysis of complex product scheme design.

Therefore, building on the integration of extension design and axiomatic design, together with related design methods and the concept of the optimal solution [30–33], this paper studies the extension process model and the implication process model of requirement analysis in complex product concept design. To address multi-attribute and multi-parameter requirement analysis, we propose a weight allocation model for requirement basic elements based on extension distance, calculate the extension correlation degree of requirement basic elements by combining qualitative and quantitative analysis, and construct the framework of the extension design pattern of requirement analysis for complex mechanical product scheme design. The specific process is illustrated with examples. Section 2 presents the extension modeling for the extension design pattern of concept design. Section 3 describes the extension design pattern of requirement analysis in conceptual design. Section 4 provides an extension design pattern of requirement analysis for complex mechanical product scheme design. Finally, the discussion and acknowledgments are given in Sections 5 and 6, respectively.

#### **2. Extension Modeling for Extension Design Pattern of Concept Design**

To handle the various complex design reasoning problems in product concept design, the storage, representation, and processing of deep knowledge must be addressed during concept design reasoning. To this end, extenics introduces basic element theory into product concept design. It takes the basic element as the logical cell of extensible design and gathers the quality, quantity, action, and relation of the represented design object into an ordered triple *J* = (*Γ*, *c*, *v*), consisting of the design object *Γ*, the object's characteristic *c*, and the value *v* of that characteristic. This formal modeling describes the information, actions, and relations in the design process and offers a new methodological system for understanding the world and resolving contradictions in the real world.

#### *2.1. The Basic Element Modeling of Multitype Design Knowledge*

Based on semantic segmentation, the multitype design information in the conceptual design process is analyzed and organized into minimal, complete, and independent units of design information that can represent the design characteristics. According to the forms of these units, corresponding design knowledge units can be established and formally modeled with basic elements. In this paper, the design information in conceptual design is divided into static design information, behavioral design information, and relational design information.

When modeling static design information, we can describe it with the matter element model *J*(*R*) of basic element theory. If the design object has *n* characteristics, its matter element model *J*(*R*) is as follows:

$$J(\mathbf{R}) = \begin{bmatrix} \Gamma(N) & \mathbf{C}(N)_1 & \left( V(\mathbf{C})_1, W(\mathbf{C})_1 \right) \\ & \mathbf{C}(N)_2 & \left( V(\mathbf{C})_2, W(\mathbf{C})_2 \right) \\ & \vdots & \vdots \\ & \mathbf{C}(N)_n & \left( V(\mathbf{C})_n, W(\mathbf{C})_n \right) \end{bmatrix} \tag{1}$$

Here, *Γ*(*N*) is the name of the object, *C*(*N*)<sub>*i*</sub> is its *i*-th design characteristic, *V*(*C*) is the value of a design characteristic, and *W*(*C*) is its weight. *V*(*C*) and *W*(*C*) can take many forms, such as precise point values, interval values with fuzzy information, membership functions, and qualitative semantic descriptions. Thus, for a more general expression, assuming that *V* = [*v*<sup>*L*</sup>, *v*<sup>*R*</sup>] and *W* = [*w*<sup>*L*</sup>, *w*<sup>*R*</sup>] are both interval values with fuzzy information, Formula (1) can be expressed as follows:

$$J(\mathbf{R}) = \begin{bmatrix} \Gamma(N) & \mathbf{C}(N)\_1 & \left( \left[ v(\mathbf{C})\_1^L, v(\mathbf{C})\_1^R \right], \left[ w(\mathbf{C})\_1^L, w(\mathbf{C})\_1^R \right] \right) \\ & \mathbf{C}(N)\_2 & \left( \left[ v(\mathbf{C})\_2^L, v(\mathbf{C})\_2^R \right], \left[ w(\mathbf{C})\_2^L, w(\mathbf{C})\_2^R \right] \right) \\ & \vdots & \vdots \\ & \mathbf{C}(N)\_n & \left( \left[ v(\mathbf{C})\_n^L, v(\mathbf{C})\_n^R \right], \left[ w(\mathbf{C})\_n^L, w(\mathbf{C})\_n^R \right] \right) \end{bmatrix} \tag{2}$$
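As an illustration of Formula (2), a matter element can be held in a small data structure: the object name *Γ*(*N*) plus a list of characteristics, each carrying an interval value and an interval weight. The Python sketch below is one possible encoding, not the authors' implementation; the class names and the gearbox example values are hypothetical.

```python
from dataclasses import dataclass, field
from typing import List, Tuple

Interval = Tuple[float, float]    # closed interval [left, right]


@dataclass
class Characteristic:
    name: str          # C(N)_i
    value: Interval    # [v(C)_i^L, v(C)_i^R]
    weight: Interval   # [w(C)_i^L, w(C)_i^R]

    def __post_init__(self):
        # Reject intervals whose left end exceeds the right end.
        for lo, hi in (self.value, self.weight):
            if lo > hi:
                raise ValueError(f"invalid interval in {self.name}: [{lo}, {hi}]")


@dataclass
class MatterElement:
    """Matter element J(R) of Formula (2): object name Gamma(N) plus n characteristics."""
    obj: str
    chars: List[Characteristic] = field(default_factory=list)

    def triples(self):
        """Expand into ordered triples (Gamma, c, v), one per characteristic."""
        return [(self.obj, ch.name, ch.value) for ch in self.chars]


# Hypothetical example; names and numbers are illustrative only.
gearbox = MatterElement("gearbox", [
    Characteristic("rated power (kW)", (90.0, 110.0), (0.3, 0.4)),
    Characteristic("service life (years)", (15.0, 20.0), (0.2, 0.3)),
])
for triple in gearbox.triples():
    print(triple)
```

The affair element *J*(*I*) and relational element *J*(*Q*) share the same shape (a name plus weighted characteristics), so the same structure could be reused with different field meanings.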

When modeling behavioral design information, we can describe it with the affair element model *J*(*I*) of basic element theory. If the design behavior has *m* operating characteristics, its affair element model *J*(*I*) is as follows:

$$J(I) = \begin{bmatrix} \Gamma(D) & B(D)_1 & \left( U(B)_1, \left[ w(B)_1^L, w(B)_1^R \right] \right) \\ & B(D)_2 & \left( U(B)_2, \left[ w(B)_2^L, w(B)_2^R \right] \right) \\ & \vdots & \vdots \\ & B(D)_m & \left( U(B)_m, \left[ w(B)_m^L, w(B)_m^R \right] \right) \end{bmatrix} \tag{3}$$

Here, *Γ*(*D*) is the name of the design behavior, *B*(*D*) is an operating characteristic of the design behavior, *U*(*B*) is its value, and *W*(*B*) = [*w*(*B*)<sup>*L*</sup>, *w*(*B*)<sup>*R*</sup>] is the weight of the operating characteristic.

Relational design information can be described by the relational element model *J*(*Q*), which covers the configuration, logical, implication, comparative, and assembly relationships in the design process. If a design constraint relationship has *k* characteristics, its relational element model *J*(*Q*) is:

$$J(\mathbf{Q}) = \begin{bmatrix} \Gamma(S) & A(S)\_1 & \left( G(A)\_1, \left[ w(A)\_1^L, w(A)\_1^R \right] \right) \\ & A(S)\_2 & \left( G(A)\_2, \left[ w(A)\_2^L, w(A)\_2^R \right] \right) \\ & \vdots & \vdots \\ & A(S)\_k & \left( G(A)\_k, \left[ w(A)\_k^L, w(A)\_k^R \right] \right) \end{bmatrix} \tag{4}$$

Here, *Γ*(*S*) is the name of the design constraint relationship, *A*(*S*) is a relational characteristic of that relationship, *G*(*A*) is its value, and *W*(*A*) is the weight of the relational characteristic.

In complex product conceptual design, design knowledge is often mixed, combining static, behavioral, and relational design information. We describe such knowledge with the compound element model *J*(*F*) of basic element theory. Compound elements are formed with a conjunction Θ, which expresses multilayer semantics and richer design information. The most frequently used conjunctions are "∧" (and) and "∨" (or), which form and-compound elements and or-compound elements, respectively, and together make up the overall design information of a design scheme. The compound element model *J*(*F*) can be expressed as follows:

$$J(\mathbf{F}) = \begin{bmatrix} \Gamma(F) & (\Theta)\Gamma(\mathbf{J}(\mathbf{R}\_i)) & (V(\mathbf{J}(\mathbf{R}\_i)), W(\mathbf{J}(\mathbf{R}\_i))) \\ & (\Theta)\Gamma(\mathbf{J}(\mathbf{I}\_j)) & (V(\mathbf{J}(\mathbf{I}\_j)), W(\mathbf{J}(\mathbf{I}\_j))) \\ & (\Theta)\Gamma(\mathbf{J}(\mathbf{Q}\_s)) & (V(\mathbf{J}(\mathbf{Q}\_s)), W(\mathbf{J}(\mathbf{Q}\_s))) \end{bmatrix} \tag{5}$$

Here, *i*, *j*, and *s* index the matter elements, affair elements, and relational elements, respectively.

It should be emphasized that when the above models are used only to represent design knowledge, they express the state of the design attributes rather than their degree of importance; in that case, the weights need not be included in the models.

#### *2.2. Construction of Extension Set of Basic Element*

In product design, customer requirements can be divided into two components: common requirements and personalized requirements. Common requirements reflect the customers' convergent understanding of and requirements for the product; this part of the design is generally accomplished with an existing classical structural model, or a variant of it. Personalized requirements are a customer's special understanding of and requirements for the product, and meeting them often requires innovating on or extending the structure or function of existing products. To meet customer requirements comprehensively, the design process is therefore dynamic, diverse, interrelated, and hierarchical, and the existing explicit design information may not fully satisfy the design requirements. Design knowledge therefore needs to be mined and organized into a set of design knowledge that improves the innovation capability of conceptual design.

According to the basic element theory of extenics, basic elements have the properties of implication and extensibility. Through extension transformation, richer tacit knowledge and the corresponding extension set can be obtained, which supports the smooth implementation of conceptual product design.

(1) Implication and the implication set of basic elements. For basic elements *J*<sub>1</sub> and *J*<sub>2</sub>, if the existence of *J*<sub>1</sub> guarantees the existence of *J*<sub>2</sub>, we say that *J*<sub>1</sub> implies *J*<sub>2</sub> and write @*J*<sub>1</sub> ⇒ @*J*<sub>2</sub>, where @ denotes existence. Because basic elements can be combined with the conjunction Θ, implication between compound elements can be written @*J*<sub>*i*</sub>Θ@*J*<sub>*j*</sub> ⇒ @*J*<sub>*s*</sub>Θ@*J*<sub>*t*</sub>, where *i*, *j*, *s*, and *t* index basic elements. The basic elements obtained through implication form the implication set of basic elements. Because implications can be chained and transformed, they can be used to carry out reasoning in the conceptual design process.

(2) Extensibility and the extension set of basic elements. The extensibility of basic elements has three aspects: divergence, expansion, and relevance. In the design field, applying extension transformations to basic element characteristics and their values can, on the one hand, open ways for design objects to diverge and expand outward, yielding extension design knowledge in the design field; on the other hand, it can establish relationships between design objects, yielding relational design knowledge in the design field. Extension transformation yields the extension set of basic elements *S*(*J*)<sub>*T*</sub>:

$$\mathcal{S}(J)\_T = \left\{ (J, \Phi, \Psi) \,\middle|\, J \in T\_{\Omega(J)} \Omega(J),\ \Phi = K(J) = k(X),\ \Psi = T\_K K(T\_J J) = T\_k k(X^\*),\ X = c(J),\ X^\* = c^\*(T\_J J) \right\} \tag{6}$$

Here, *T*<sub>Ω</sub>, *T*<sub>*K*</sub>, and *T*<sub>*J*</sub> denote the extension transformations of design object *J*'s domain, of its correlation function, and of its basic element characteristics and their values, respectively. *c* is the evaluation characteristic of *J*, with value *X* = *c*(*J*); *c*<sup>∗</sup> is the evaluation characteristic of the element obtained by the extension transformation *T*<sub>*J*</sub>, with value *X*<sup>∗</sup> = *c*<sup>∗</sup>(*T*<sub>*J*</sub>*J*); Φ = *k*(*X*) is the correlation function of the evaluation characteristic, and Ψ = *T*<sub>*k*</sub>*k*(*X*<sup>∗</sup>) is the correlation function of the evaluation characteristic after the extension transformation *T*<sub>*J*</sub>.

By Equation (6), objects in the existing basic element set can undergo extension transformations in several ways, namely of the domain, the correlation function, the basic element characteristics, and the characteristic values. This yields more extensive design knowledge within the design field and related design knowledge across design fields, thereby supporting subsequent extension reasoning.
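A minimal sketch of Equation (6) follows, assuming a one-dimensional evaluation characteristic: for each element it records Φ = k(c(J)) before and Ψ = k(c<sup>∗</sup>(T<sub>J</sub>J)) after the transformation. Several choices are assumptions for illustration only: T<sub>k</sub> is taken as the identity, the domain transformation T<sub>Ω</sub> is not modeled, and the example correlation function k(x) = −ρ(x, [a, b]) is built from the extension distance. All names and numbers are hypothetical.

```python
from typing import Callable, List, Tuple, TypeVar

E = TypeVar("E")


def extension_distance(x: float, a: float, b: float) -> float:
    """rho(x, [a, b]) = |x - (a + b)/2| - (b - a)/2: negative strictly inside, zero at the ends."""
    return abs(x - (a + b) / 2) - (b - a) / 2


def extension_set(
    elements: List[E],
    transform: Callable[[E], E],     # T_J: extension transformation of an element
    evaluate: Callable[[E], float],  # c / c*: evaluation characteristic value X
    k: Callable[[float], float],     # correlation function; T_k is assumed to be the identity
) -> List[Tuple[E, float, float]]:
    """Build triples (J, Phi, Psi) with Phi = k(c(J)) and Psi = k(c*(T_J J))."""
    return [(j, k(evaluate(j)), k(evaluate(transform(j)))) for j in elements]


# Hypothetical example: elements are scalar characteristic values, the transformation
# adds 10, and k(x) = -rho(x, [100, 150]) is positive only for x strictly inside the interval.
def k_example(x: float) -> float:
    return -extension_distance(x, 100.0, 150.0)


print(extension_set([95.0], lambda x: x + 10.0, lambda x: x, k_example))  # [(95.0, -5.0, 5.0)]
```

In the example, the sign change from Φ < 0 to Ψ > 0 indicates that the assumed transformation moves the element from outside to inside the satisfactory interval.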

#### **3. The Extension Design Pattern of Requirement Analysis of Conceptual Product Design**

In the conceptual design of complex products, customer requirements are generally abstract, ambiguous, variable, multilevel, and interrelated. This often makes it difficult for designers to understand the customer's design intent correctly and affects the design quality and design efficiency of products. Therefore, based on extension theory, we analyze customer requirements and transform them into a formal, model-based, and objective expression of product requirement information that clearly reflects the hierarchy and relationships among the requirements. The requirement information is then effectively transformed into technical requirement information that guides conceptual product design, making requirement analysis more reasonable, comprehensive, and standardized.

#### *3.1. The Extension Process Model of Requirement Analysis*

Because requirement analysis yields the initial design scheme for conceptual product design, the requirement analysis model directly affects the subsequent design, manufacturing, use, and maintenance of the product; requirement analysis is therefore an important part of product design. For requirement analysis in complex product conceptual design, the requirement model should not be built from customer requirements in isolation; instead, requirement transformation should be carried out from the perspective of the product life cycle, which involves life-cycle information such as design feasibility, manufacturability, assemblability, maintainability, and reliability. The aim is for the requirement model to support the association and mapping among the customer domain, functional domain, structural domain, and process domain in conceptual product design, thereby providing a theoretical foundation and practical methods for automating complex product design.

Based on basic element theory, a basic element model can be built for each item of requirement information in requirement analysis, giving the matter-type requirement basic element model *J*(*R*)<sub>*C*</sub>, the behavior-type requirement basic element model *J*(*I*)<sub>*C*</sub>, the relation-type requirement basic element model *J*(*Q*)<sub>*C*</sub>, and the compound-type requirement basic element model *J*(*F*)<sub>*C*</sub>. The matter-type model *J*(*R*)<sub>*C*</sub> describes static properties and design information such as characteristic, functional, structural, environmental, and performance requirements. The behavior-type model *J*(*I*)<sub>*C*</sub> describes design-behavior information related to requirement analysis in the product design process, such as problem solving, judgment knowledge, process planning, and reasoning. The relation-type model *J*(*Q*)<sub>*C*</sub> describes the constraints and dependencies between requirement characteristics in the product design process, such as configuration, comparative, and logical relationships. The compound-type model *J*(*F*)<sub>*C*</sub> combines the various requirement basic elements. From these models we obtain the set of requirement basic elements *S*(*J*)<sub>*CT*</sub> and the corresponding knowledge base of requirement basic elements. Based on extension theory, the requirement information is analyzed, evaluated, and transformed into design information for the subsequent product design, which better supports the rapid design of complex products. Figure 1 shows the extension process model of requirement analysis in complex product conceptual design.

**Figure 1.** The extension process model of requirement analysis of complex product conceptual design.

As Figure 1 shows, once the requirement information has been obtained from the relevant design requirements, it can be decomposed by semantic transformation and, combined with extension analysis and evaluation methods, modeled as basic elements to form the extension set of requirement basic elements. Extension reasoning and extension transformation are then used to map the requirements hierarchically and obtain design information that meets the design requirements. The resulting basic element models are stored in the basic element knowledge base.

#### *3.2. The Implication Process Model of Requirement Analysis*

In the extension process model of requirement analysis for complex product conceptual design, semantically transforming the requirement information and then analyzing and evaluating it with extension methods yields minimal, complete, and independent units of design information, which are modeled as the corresponding requirement basic elements. From the requirement analysis process, customer requirements in the product design field can generally be divided into common and individual customer requirements: common requirements are the customers' convergent understanding of and requirements for the product in the design field, whereas individual requirements are special understanding and requirements built on top of the common ones.

Because common customer requirements are represented by common design information in the design field, this common design information, namely the common requirement basic elements, must be reused effectively to better support rapid product design. Since basic elements have the implication property, domain experts can use analysis and evaluation or data mining to obtain the implication relationships in the extension set of requirement analysis. With these implication relationships, a new product design only needs to match the condition items of an implication to reuse its result items, thereby reusing existing design results effectively, shortening the design cycle, and improving design efficiency. The implication process model of requirement analysis oriented toward rapid product design is expressed as follows:

$$\begin{cases} \forall (J\_{Cm}, J\_{Cn}):\ J\_{Cm} \in \Omega \land J\_{Cn} \in \Omega \land (J\_{Cm} \Theta J\_{Cn}) \in \Omega \land \left( (J\_{Cm} \Theta J\_{Cn}) \Rightarrow (J\_s \Theta J\_t) \right) \in \Omega, & m \neq n \\ \text{if}\ @J\_{C0i} \in \Omega \land @J\_{C0j} \in \Omega \land @(J\_{C0i} \Theta J\_{C0j}) \in \Omega, & i \neq j \\ \exists \left( (J\_{C0i} \Theta J\_{C0j}) \,\Xi\, (J\_{Cm} \Theta J\_{Cn}) \right) \land K\left( (J\_{C0i} \Theta J\_{C0j}) \,\Xi\, (J\_{Cm} \Theta J\_{Cn}) \right) \geq K\_0(\Omega) & \\ \text{then}\ (J\_{C0i} \Theta J\_{C0j}) \in \mathcal{S}(J\_{Cm} \Theta J\_{Cn})\_{CT} & \end{cases} \tag{7}$$

Here, *J*<sub>*Cm*</sub> and *J*<sub>*Cn*</sub> denote existing requirement basic elements; *J*<sub>*s*</sub> and *J*<sub>*t*</sub> denote the design-result basic elements in the extension set; *J*<sub>*C*0*i*</sub> and *J*<sub>*C*0*j*</sub> denote requirement basic elements in the current requirement analysis; Ω denotes the design domain of discourse; Ξ denotes the matching operation between basic element models; *K*((*J*<sub>*C*0*i*</sub>Θ*J*<sub>*C*0*j*</sub>)Ξ(*J*<sub>*Cm*</sub>Θ*J*<sub>*Cn*</sub>)) denotes the matching degree between the basic element models; and *K*<sub>0</sub>(Ω) denotes the allowable matching threshold in the domain of discourse.

The implication process model shows that when the matching degree between a requirement basic element (or its compound element) and an existing requirement basic element (or its compound element) reaches the given threshold, the design results implied by the existing element can be applied to product scheme design as effective reusable objects. The basic element matching algorithm based on extension theory is described in the extension reuse method for fast-configuration conceptual design.
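The match-then-reuse logic of Model (7) can be sketched as a threshold filter over stored implications. Because the matching algorithm itself is described elsewhere, the sketch below substitutes a toy Jaccard overlap on sets of (characteristic, value) pairs as the matching degree K; that choice, the data layout, and the example labels are assumptions for illustration, not the paper's algorithm.

```python
from typing import Callable, FrozenSet, List, Tuple

# For illustration only, a requirement (compound) element is a frozen set of
# (characteristic, value) pairs joined by the "∧" conjunction.
Element = FrozenSet[Tuple[str, str]]


def jaccard(a: Element, b: Element) -> float:
    """Toy matching degree K in [0, 1] (an assumption, not the paper's algorithm)."""
    if not a and not b:
        return 1.0
    return len(a & b) / len(a | b)


def reuse_results(
    query: Element,                                 # J_C0i Θ J_C0j from the new analysis
    implications: List[Tuple[Element, List[str]]],  # (J_Cm Θ J_Cn, results J_s Θ J_t)
    k0: float,                                      # allowable threshold K_0(Ω)
    degree: Callable[[Element, Element], float] = jaccard,
) -> List[str]:
    """Return the design results of every implication whose condition matches with K >= K_0."""
    reusable: List[str] = []
    for condition, results in implications:
        if degree(query, condition) >= k0:
            reusable.extend(results)
    return reusable


# Hypothetical knowledge-base entry and query.
kb = [(frozenset({("head", "high"), ("silt", "high")}), ["result_A"])]
query = frozenset({("head", "high"), ("silt", "high"), ("flow", "large")})
print(reuse_results(query, kb, k0=0.6))  # ['result_A']  (K = 2/3)
```

Passing the matching degree as a parameter keeps the threshold rule separate from whichever matching algorithm is actually used.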

#### *3.3. The Weight Distribution Model of Requirement Basic Element Based on Extension Distance*

The extension process model of requirement analysis for complex product conceptual design can decompose and map the design requirements. However, because the requirement information in conceptual product design is fuzzy and interrelated, the weights of requirement characteristics and design parameters are usually hard to determine. This paper therefore proposes a new weight allocation method based on extension distance. Compared with existing weight allocation methods, it combines qualitative and quantitative analysis, better handles evaluation indicators that are difficult to quantify and to treat statistically in requirement analysis, and reduces the influence of human factors, making the weight allocation more scientific, objective, and accurate.

Assume that decomposing the requirements yields *P* requirement basic elements and that *Z* experts in the design field are invited to evaluate them. According to the importance of each customer requirement, ratings are given on a 0–9 ratio scale: expert *j* rates requirement basic element *J*<sub>*i*</sub> with the ratio-scale interval *u*<sub>*ij*</sub> = [*u*<sup>*l*</sup><sub>*ij*</sub>, *u*<sup>*r*</sup><sub>*ij*</sub>], where 0 ≤ *u*<sup>*l*</sup><sub>*ij*</sub> ≤ *u*<sup>*r*</sup><sub>*ij*</sub> ≤ 9. This gives the ratio-scale interval sequence of *J*<sub>*i*</sub>, *U*(*J*<sub>*i*</sub>) = ([*u*<sup>*l*</sup><sub>*i*1</sub>, *u*<sup>*r*</sup><sub>*i*1</sub>], [*u*<sup>*l*</sup><sub>*i*2</sub>, *u*<sup>*r*</sup><sub>*i*2</sub>], ⋯, [*u*<sup>*l*</sup><sub>*iZ*</sub>, *u*<sup>*r*</sup><sub>*iZ*</sub>]). From the sequences of the *P* requirement basic elements, construct the ideal ratio-scale interval sequence *U*(*J*<sub>0</sub>) = ([*u*<sup>*l*</sup><sub>01</sub>, *u*<sup>*r*</sup><sub>01</sub>], [*u*<sup>*l*</sup><sub>02</sub>, *u*<sup>*r*</sup><sub>02</sub>], ⋯, [*u*<sup>*l*</sup><sub>0*Z*</sub>, *u*<sup>*r*</sup><sub>0*Z*</sub>]), where [*u*<sup>*l*</sup><sub>0*j*</sub>, *u*<sup>*r*</sup><sub>0*j*</sub>] = [max<sub>1≤*i*≤*P*</sub> *u*<sup>*l*</sup><sub>*ij*</sub>, max<sub>1≤*i*≤*P*</sub> *u*<sup>*r*</sup><sub>*ij*</sub>].

Next, based on the extension distance, construct the extension relational coefficient *ρ*<sub>*ij*</sub> between *U*(*J*<sub>*i*</sub>) and *U*(*J*<sub>0</sub>) for the *j*-th scale value:

$$\begin{aligned} \rho\_{ij} &= \rho\left(\left[u\_{ij}^{l}, u\_{ij}^{r}\right], \left[u\_{0j}^{l}, u\_{0j}^{r}\right]\right) = \frac{\rho\left(u\_{ij}^{l}, \left[u\_{0j}^{l}, u\_{0j}^{r}\right]\right) + \rho\left(u\_{ij}^{r}, \left[u\_{0j}^{l}, u\_{0j}^{r}\right]\right)}{2} \\ &= \frac{\left(\left|u\_{ij}^{l} - \frac{u\_{0j}^{l} + u\_{0j}^{r}}{2}\right| - \frac{1}{2}\left(u\_{0j}^{r} - u\_{0j}^{l}\right)\right) + \left(\left|u\_{ij}^{r} - \frac{u\_{0j}^{l} + u\_{0j}^{r}}{2}\right| - \frac{1}{2}\left(u\_{0j}^{r} - u\_{0j}^{l}\right)\right)}{2} \end{aligned} \tag{8}$$
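Equation (8) is easy to check numerically. The Python sketch below implements the point-to-interval extension distance ρ(x, [a, b]) = |x − (a + b)/2| − (b − a)/2 and the averaged interval-to-interval form used for ρ<sub>ij</sub>; the function names and example intervals are illustrative choices, not part of the original method.

```python
def rho_point(x: float, a: float, b: float) -> float:
    """Extension distance from point x to interval [a, b] (negative strictly inside)."""
    return abs(x - (a + b) / 2) - (b - a) / 2


def rho_interval(left: float, right: float, a: float, b: float) -> float:
    """Extension distance between [left, right] and [a, b]: mean of the two end-point distances."""
    return (rho_point(left, a, b) + rho_point(right, a, b)) / 2


# An expert interval [5, 7] compared with an ideal interval [6, 9].
print(rho_point(5, 6, 9), rho_point(7, 6, 9), rho_interval(5, 7, 6, 9))  # 1.0 -1.0 0.0
```

A negative end-point distance means the point lies inside the ideal interval, so an expert interval that straddles the ideal interval's boundary can average out to zero, as in the example.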

The extension degree *λ*<sub>*i*</sub> between *U*(*J*<sub>*i*</sub>) and *U*(*J*<sub>0</sub>) is then:

$$\lambda\_i = \frac{1}{Z} \sum\_{j=1}^{Z} (9 - \rho\_{ij}) \tag{9}$$

The relative weight of requirement basic element *J*<sub>*i*</sub> is then:

$$w\_{Ui} = \lambda\_i \Big/ \sum\_{i=1}^{P} \lambda\_i \tag{10}$$

This yields the weight sequence of the requirement basic elements, *w*<sub>*U*</sub> = [*w*<sub>*U*1</sub>, *w*<sub>*U*2</sub>, ⋯, *w*<sub>*UP*</sub>]<sup>*T*</sup>, which satisfies ∑<sub>*i*=1</sub><sup>*P*</sup> *w*<sub>*Ui*</sub> = 1.
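Putting Equations (8)–(10) together, a minimal sketch of the requirement weight computation follows. It assumes the expert ratings arrive as a P × Z nested list of (left, right) intervals on the 0–9 scale, and it uses 9, the top of the scale as in Equation (11), as the constant from which ρ<sub>ij</sub> is subtracted in Equation (9). The example ratings are invented.

```python
from typing import List, Tuple

Interval = Tuple[float, float]
SCALE_MAX = 9.0  # top of the 0–9 ratio scale, subtracted from rho_ij in Equation (9)


def rho(x: Interval, ideal: Interval) -> float:
    """Equation (8): mean extension distance of both end points of x to the ideal interval."""
    a, b = ideal
    mid, half = (a + b) / 2, (b - a) / 2
    return ((abs(x[0] - mid) - half) + (abs(x[1] - mid) - half)) / 2


def requirement_weights(ratings: List[List[Interval]]) -> List[float]:
    """ratings[i][j] is expert j's interval for requirement basic element J_i (P x Z)."""
    num_experts = len(ratings[0])
    # Ideal sequence U(J_0): per expert, the maxima of the left and right end points.
    ideal = [
        (max(row[j][0] for row in ratings), max(row[j][1] for row in ratings))
        for j in range(num_experts)
    ]
    # Equation (9): extension degree lambda_i averaged over the Z experts.
    lambdas = [
        sum(SCALE_MAX - rho(row[j], ideal[j]) for j in range(num_experts)) / num_experts
        for row in ratings
    ]
    total = sum(lambdas)
    # Equation (10): normalize so that the weights sum to 1.
    return [lam / total for lam in lambdas]


# Invented ratings: two requirement basic elements, two experts.
ratings = [[(5, 7), (6, 8)], [(7, 9), (6, 8)]]
print(requirement_weights(ratings))  # [8.5/17.5, 9/17.5], i.e. about [0.486, 0.514]
```

Here the ideal intervals are [7, 9] and [6, 8]; the first element falls short of the first ideal interval, so it receives the slightly smaller weight.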

The weights of the design parameter basic elements obtained by mapping from requirement analysis are computed with each requirement basic element as the standard; that is, the scale interval of each requirement basic element serves as the ideal scale interval when the extension relational coefficients are computed. Assume there are *Q* design parameter basic elements and *Z* invited experts in the design field. According to the importance degree of the customer requirements, each expert rates on the same 0–9 ratio scale, giving the ratio-scale interval sequence of design parameter basic element *J*<sub>*k*</sub>: *V*(*J*<sub>*k*</sub>) = ([*v*<sup>*l*</sup><sub>*k*1</sub>, *v*<sup>*r*</sup><sub>*k*1</sub>], [*v*<sup>*l*</sup><sub>*k*2</sub>, *v*<sup>*r*</sup><sub>*k*2</sub>], ⋯, [*v*<sup>*l*</sup><sub>*kZ*</sub>, *v*<sup>*r*</sup><sub>*kZ*</sub>]), *k* = 1, 2, ⋯, *Q*. Applying the same procedure with requirement basic element *J*<sub>*i*</sub> as the evaluation standard, the extension relational degree *λ*<sub>*ik*</sub> between design parameter basic element *J*<sub>*k*</sub> and requirement basic element *J*<sub>*i*</sub> is:

$$\begin{aligned} \lambda\_{ik} &= \frac{1}{Z} \sum\_{j=1}^{Z} \left( 9 - \rho \left( \left[ v\_{kj}^{l}, v\_{kj}^{r} \right], \left[ u\_{ij}^{l}, u\_{ij}^{r} \right] \right) \right) = \frac{1}{Z} \sum\_{j=1}^{Z} \left( 9 - \frac{\rho \left( v\_{kj}^{l}, \left[ u\_{ij}^{l}, u\_{ij}^{r} \right] \right) + \rho \left( v\_{kj}^{r}, \left[ u\_{ij}^{l}, u\_{ij}^{r} \right] \right)}{2} \right) \\ &= \frac{1}{Z} \sum\_{j=1}^{Z} \left( 9 - \frac{\left( \left| v\_{kj}^{l} - \frac{u\_{ij}^{l} + u\_{ij}^{r}}{2} \right| - \frac{1}{2} \left( u\_{ij}^{r} - u\_{ij}^{l} \right) \right) + \left( \left| v\_{kj}^{r} - \frac{u\_{ij}^{l} + u\_{ij}^{r}}{2} \right| - \frac{1}{2} \left( u\_{ij}^{r} - u\_{ij}^{l} \right) \right)}{2} \right) \end{aligned} \tag{11}$$

On this basis, the extension relational degree matrix **A**<sub>*J*</sub> between the *P* requirement basic elements (rows) and the *Q* design parameter basic elements (columns) is:

$$\mathbf{A}\_J = \begin{bmatrix} \lambda\_{11} & \lambda\_{12} & \cdots & \lambda\_{1Q} \\ \lambda\_{21} & \lambda\_{22} & \cdots & \lambda\_{2Q} \\ \vdots & \vdots & \ddots & \vdots \\ \lambda\_{P1} & \lambda\_{P2} & \cdots & \lambda\_{PQ} \end{bmatrix}\_{P \times Q} \tag{12}$$

Weighting the extension relational degrees by the requirement basic element weight sequence *w*<sub>*U*</sub> gives the weighted extension relational degree sequence *w*<sub>*V*</sub> of the design parameter basic elements:

$$w\_V = w\_U^{T} \mathbf{A}\_J = [w\_{U1}, w\_{U2}, \dots, w\_{UP}]\_{1 \times P} \begin{bmatrix} \lambda\_{11} & \lambda\_{12} & \cdots & \lambda\_{1Q} \\ \lambda\_{21} & \lambda\_{22} & \cdots & \lambda\_{2Q} \\ \vdots & \vdots & \ddots & \vdots \\ \lambda\_{P1} & \lambda\_{P2} & \cdots & \lambda\_{PQ} \end{bmatrix}\_{P \times Q} \tag{13}$$

The absolute weight *w*<sub>*Vk*</sub> of design parameter basic element *J*<sub>*k*</sub> is then:

$$w\_{Vk} = \sum\_{i=1}^{P} w\_{Ui} \lambda\_{ik}, \quad 1 \le k \le Q \tag{14}$$

The relative (normalized) weight *w*<sup>∗</sup><sub>*Vk*</sub> of design parameter basic element *J*<sub>*k*</sub> is:

$$w\_{Vk}^{\*} = w\_{Vk} / \sum\_{k=1}^{Q} w\_{Vk} \tag{15}$$

This yields the normalized weight sequence of the design parameter basic elements, *w*<sup>∗</sup><sub>*V*</sub> = [*w*<sup>∗</sup><sub>*V*1</sub>, *w*<sup>∗</sup><sub>*V*2</sub>, ⋯, *w*<sup>∗</sup><sub>*VQ*</sub>]<sup>*T*</sup>, which satisfies ∑<sub>*k*=1</sub><sup>*Q*</sup> *w*<sup>∗</sup><sub>*Vk*</sub> = 1.
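The design parameter weighting of Equations (11)–(15) can be sketched in the same style. The sketch assumes requirement intervals u (P × Z), design parameter intervals v (Q × Z), and a requirement weight vector w_U that sums to 1; it builds the P × Q matrix of λ<sub>ik</sub>, forms the weighted sums of Equation (14), and normalizes them as in Equation (15). The data are invented for illustration.

```python
from typing import List, Tuple

Interval = Tuple[float, float]
SCALE_MAX = 9.0  # top of the 0–9 ratio scale, as in Equation (11)


def rho(x: Interval, ref: Interval) -> float:
    """Mean extension distance of both end points of x to the reference interval."""
    a, b = ref
    mid, half = (a + b) / 2, (b - a) / 2
    return ((abs(x[0] - mid) - half) + (abs(x[1] - mid) - half)) / 2


def relational_matrix(u: List[List[Interval]], v: List[List[Interval]]) -> List[List[float]]:
    """Equations (11)-(12): lambda_ik for requirement rows u[i] and design parameter rows v[k]."""
    z = len(u[0])
    return [
        [sum(SCALE_MAX - rho(vk[j], ui[j]) for j in range(z)) / z for vk in v]
        for ui in u
    ]


def parameter_weights(w_u: List[float], a_j: List[List[float]]) -> List[float]:
    """Equations (13)-(15): weighted column sums of A_J, then normalization to sum 1."""
    q = len(a_j[0])
    absolute = [sum(w_u[i] * a_j[i][k] for i in range(len(w_u))) for k in range(q)]
    total = sum(absolute)
    return [w / total for w in absolute]


# Invented data: P = 2 requirement elements, Q = 2 design parameters, Z = 1 expert.
u = [[(6, 8)], [(4, 6)]]
v = [[(6, 8)], [(4, 6)]]
a_j = relational_matrix(u, v)
print(a_j)  # [[9.0, 8.0], [8.0, 9.0]]
print(parameter_weights([0.75, 0.25], a_j))  # [8.75/17, 8.25/17]
```

Because the first requirement carries weight 0.75 and the first design parameter matches it exactly, that parameter ends up with the larger normalized weight.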

#### *3.4. The Implementation of Extension Design Pattern of Requirement Analysis*

The final result of requirement analysis in product conceptual design can be mapped effectively to the design parameters of the subsequent product, including functional, structural, and process design parameters. The essence of extension design for requirement analysis in conceptual product design is to transform requirement basic elements effectively into design parameter basic elements and to form an extensible design framework. Based on extension theory and axiomatic design, the traditional QFD is improved, and an extension design pattern of requirement analysis that transforms customer requirements into design parameters is proposed. Compared with traditional quality function deployment (QFD) [34,35], the extension design pattern of requirement analysis does not merely formulate product planning or improve the product structure; it also uses the improved QFD to acquire design information that guides and runs through the entire product life cycle. Figure 2 shows the implementation framework of the extension design pattern of requirement analysis.

**Figure 2.** The extension design pattern of requirement analysis of conceptual product design.

As Figure 2 shows, the extension design pattern divides requirement analysis for product scheme design into the customer domain, functional domain, structural domain, process domain, and so on. The extension set of requirement basic elements corresponds to extension modeling of the customer domain. The extension set of design parameter basic elements comprises the extension sets of functional, structural, and process characteristic basic elements, which correspond to extension modeling of the functional, structural, and process domains, respectively. In axiomatic design theory, adjacent design domains are linked by mapping relations realized through zigzag (Z-) mapping; likewise, the basic element extension sets of adjacent design domains are linked by the corresponding Z-mappings and extension correlation matrices. Using the implication process model of requirement analysis, extension analysis and extension transformation, and the Z-mapping of axiomatic design, the extension set of design parameter basic elements and the extension scheme set are generated, and the optimal design scheme is obtained with the extension optimization method, which will be discussed in detail in subsequent work. Note that the extension set of design parameter basic elements is constructed by combining Z-mapping across the design domains, the implication process model of requirement analysis, and extension analysis and transformation. Because Z-mapping in axiomatic design typically gives design parameters a hierarchical, associated structure, compound elements are used to model such design parameters formally, so that the extension design of the product can proceed smoothly.

In summary, the implementation steps of the extension design pattern of requirement analysis for conceptual product design are as follows:

Step 1: Acquire requirement information in the design field, decompose the requirements, construct requirement information units, and build the requirement basic element models.

Step 2: Construct the extension set of the requirement basic elements and build their implication process model.

Step 3: Based on axiomatic design, hierarchically Z-map the requirement basic elements in the customer domain to design parameter basic elements in the functional, structural, and process domains; transform the design parameters using the implication process model of the requirement basic elements and the related extension transformation methods; and obtain the extension set of design parameter basic elements.

Step 4: Construct the weight allocation model of the requirement basic elements and build the extension correlation matrices between the different design domains.

Step 5: Using the basic element knowledge base, rule base, and case base constructed for the design field, match the design parameter basic elements to obtain the set of conceptual design schemes.

#### **4. Application Example**

This paper takes the selection scheme of a large hydropower turbine as an example to illustrate the implementation of the extension design pattern for complex products. Because local geology, topography, water quality, water flow, and environment differ greatly between regions, the design requirements of hydropower stations in different regions are diverse and dynamic, and various types of turbines are needed to meet them. However, large-turbine design theory is still imperfect, the internal fluid motion of a turbine is complex, and turbines are produced as single pieces or in small batches of large sets; as a result, the large-turbine design process is multilevel, multi-attribute, multi-constrained, and multi-objective, and implementing the design scheme becomes very cumbersome. Therefore, based on the extension design pattern for complex products, this paper conducts an extension analysis of the design requirements for a large-turbine selection scheme, determines the design domains of the turbine product, and obtains the basic design parameters of the selection scheme in these design domains.

The large hydropower station considered here is located in a mountainous region with relatively steep terrain, high silt content in the water, relatively large flow, and a relatively high head. Table 1 gives the design requirement parameters of the hydropower station.


**Table 1.** The Design and Exploitation Requirement Parameters of a Hydropower Station.

To carry out the turbine selection design, the design direction must first be determined; that is, the turbine structure type must be chosen based on the actual conditions of the hydropower station, such as geology, topography, water quality, and water flow. Experience in the design field shows that large turbine types include Francis, axial-flow, diagonal-flow, tubular, and Pelton turbines, each suited to different conditions that are generally determined by head, power, load variation, sediment concentration, flow, and so on. Analysis of the station's requirement information shows a high head, medium power, medium load variation, high sediment concentration, and large flow, so a Francis turbine is suitable. Through expert analysis, discussion, and evaluation, the design requirement information of the station is decomposed into common and individual requirement information, as shown in Table 2.

**Table 2.** The Decomposition of Design Requirement Information.


For common requirement information, new product design parameters can be obtained from the common design requirement templates of the design field. Individual requirement information varies widely, so no corresponding templates are available; it is analyzed in the same way the common design requirement templates were built, namely by transformation and mapping among the customer, functional, structural, and process domains based on axiomatic design theory, to obtain the corresponding product design parameters.

Figure 3 shows the framework of the Francis turbine's common design requirements information template based on axiomatic design theory and extension theory.

**Figure 3.** The framework of the Francis turbine's common design requirements information template.

As Figure 3 shows, the common design requirement information template for the hydraulic turbine consists of three parts: the requirement domain, the functional domain, and the structural domain. Each design domain generates its own structural template and basic element set: the requirement domain corresponds to the requirement-domain structural template and basic element set, the functional domain to the functional-domain structural template and basic element set, and the structural domain to the structural-domain structural template and basic element set. According to axiomatic design theory, the structural templates of the requirement, functional, and structural domains are linked by the same mapping relationship; likewise, the same mapping relationship holds among the requirement-domain, functional-domain, and structural-domain basic element sets.

Quick selection design of a turbine amounts to determining the models of its critical flow path components, such as the runner, volute, draft tube, and guide vanes. As the framework of the common Francis turbine design requirement information template shows, the volute, draft tube, guide vanes, and runner are the key design components among the turbine's secondary components; determining their flow path models therefore yields the key design parameters of the turbine and supports the subsequent design work. To this end, within the framework of the common Francis turbine design requirement template, we use the extension process model of requirement analysis to model the design requirement information, and describe the design requirement items in Table 1 with two object-type basic elements: the fundamental design parameter requirement basic element *J*(*R*)*C*0-*D* and the auxiliary design parameter requirement basic element *J*(*R*)*C*0-*P*. The fundamental element *J*(*R*)*C*0-*D* is used to design the runner model and the flow path models of the volute, draft tube, and guide vanes; the auxiliary element *J*(*R*)*C*0-*P* assists and guides the selection design of the turbine.

After the design information is modeled, the corresponding extension set of basic elements and the basic element knowledge base can be built. The head characteristic of the fundamental design parameter requirement basic element *J*(*R*)*C*0-*D* is then used to query the basic element knowledge base for runner models that satisfy the head range requirement. The matched runner basic element models in the knowledge base are as follows:


Experience in turbine selection design shows that, once the head range requirement is met, power and efficiency are the main criteria for choosing the runner model. Among the matched runner basic element models, *JTurb-Runner*01 does not meet the power design requirement and *JTurb-Runner*03 does not meet the efficiency design requirement. *JTurb-Runner*02 and *JTurb-Runner*04 both meet the power and efficiency design requirements, and for both the rated power is close to the maximum power, but *JTurb-Runner*04 has the higher efficiency, so *JTurb-Runner*04 is the best-matching runner in the basic element knowledge base.
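The screening rule above (head range first, then power and efficiency, then the highest efficiency among the survivors) can be sketched as a simple filter. This is a minimal illustration under assumptions: the `Runner` fields, the threshold semantics, and all names and numbers below are hypothetical placeholders, not values from the knowledge base.

```python
from dataclasses import dataclass
from typing import Optional, Sequence, Tuple


@dataclass
class Runner:
    name: str
    head_min: float    # lower end of the applicable head range (m)
    head_max: float    # upper end of the applicable head range (m)
    power: float       # achievable output (MW)
    efficiency: float  # efficiency as a fraction


def select_runner(candidates: Sequence[Runner], head_range: Tuple[float, float],
                  power_req: float, eff_req: float) -> Optional[Runner]:
    """Return the most efficient runner that covers the head range and meets
    the power and efficiency requirements, or None if none qualifies."""
    h_lo, h_hi = head_range
    feasible = [
        r for r in candidates
        if r.head_min <= h_lo and r.head_max >= h_hi  # head range requirement
        and r.power >= power_req                      # power requirement
        and r.efficiency >= eff_req                   # efficiency requirement
    ]
    return max(feasible, key=lambda r: r.efficiency, default=None)


# Hypothetical candidates that mirror the pattern described above.
candidates = [
    Runner("R1", 80, 140, 500, 0.940),  # fails the power requirement
    Runner("R2", 80, 140, 720, 0.935),
    Runner("R3", 80, 140, 720, 0.900),  # fails the efficiency requirement
    Runner("R4", 80, 140, 720, 0.945),
]
best = select_runner(candidates, head_range=(90, 130), power_req=700, eff_req=0.92)
print(best.name)  # R4
```

Returning `None` rather than raising keeps the "no runner matches" case explicit, which is the situation in which an extension transform of an existing runner would be needed.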

Taking the name of the basic element model *JTurb-Runner*04 as the condition item of the extension process model of requirement analysis, we apply the classic frequent pattern growth algorithm (FP-growth) with condition item *A*364 as the root node of the frequent pattern tree, and mine frequent patterns between the runner models and the flow path models of the volute, draft tube, and guide vanes in the basic element knowledge base and rule base. If a mined frequent pattern meets the support and confidence requirements, it yields a strong implication relationship between the runner model and the corresponding flow path model, namely *JTurb*-*Runner*04|*A*364 ⇒ *JWK*-94|*A*364, *JTurb*-*Runner*04|*A*364 ⇒ *JDY*-43|*A*364, and *JTurb*-*Runner*04|*A*364 ⇒ *JWSG*-51|*A*364. Based on these extension implication relationships, the flow paths of the volute *JWK*-94|*A*364, the draft tube *JWSG*-51|*A*364, and the guide vanes *JDY*-43|*A*364 associated with the target runner can then be selected.
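The acceptance test applied to the mined patterns can be illustrated with plain support and confidence counting. This sketch does not build an FP-tree (FP-growth is an efficient way to enumerate frequent itemsets); it simply counts co-occurrences in a small set of past design instances. The instances, model labels, and thresholds are hypothetical.

```python
from typing import Iterable, List, Set, Tuple


def support(transactions: List[Set[str]], itemset: Iterable[str]) -> float:
    """Fraction of transactions that contain every item of itemset."""
    items = set(itemset)
    return sum(items <= t for t in transactions) / len(transactions)


def confidence(transactions: List[Set[str]], antecedent: Set[str],
               consequent: Set[str]) -> float:
    """conf(A => C) = support(A | C) / support(A)."""
    s_a = support(transactions, antecedent)
    if s_a == 0:
        return 0.0
    return support(transactions, antecedent | consequent) / s_a


def strong_rules(transactions: List[Set[str]], antecedent: Set[str],
                 candidates: List[str], min_sup: float,
                 min_conf: float) -> List[Tuple[str, float, float]]:
    """Keep rules antecedent => {c} that pass both thresholds."""
    rules = []
    for c in candidates:
        sup = support(transactions, antecedent | {c})
        conf = confidence(transactions, antecedent, {c})
        if sup >= min_sup and conf >= min_conf:
            rules.append((c, sup, conf))
    return rules


# Hypothetical design instances: each set lists the models used together.
instances = [
    {"RunnerA", "VoluteX", "DraftTubeX", "GuideVaneX"},
    {"RunnerA", "VoluteX", "DraftTubeX", "GuideVaneX"},
    {"RunnerA", "VoluteX", "DraftTubeY", "GuideVaneX"},
    {"RunnerB", "VoluteY", "DraftTubeX", "GuideVaneY"},
]
rules = strong_rules(instances, {"RunnerA"},
                     ["VoluteX", "DraftTubeX", "DraftTubeY", "GuideVaneX"],
                     min_sup=0.5, min_conf=0.6)
for consequent, sup, conf in rules:
    print(f"RunnerA => {consequent}: support={sup:.2f}, confidence={conf:.2f}")
```

Here `DraftTubeY` co-occurs with `RunnerA` only once, so its rule falls below the support threshold and is discarded, while the other three rules are kept as strong implications.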

Although *JTurb-Runner*04 is the best-matching runner in the basic element knowledge base, its property parameters do not fully comply with the design requirements. An extension transform must therefore be applied to its characteristic values: turbine design expertise is used to analyze and optimize the characteristic values of *JTurb-Runner*04 to obtain a nominal runner diameter, rotational speed, and flow path combination that meet the output requirement with high efficiency, from which the remaining design parameters of the target runner are determined. Combining this with domain knowledge of turbine design and the comprehensive characteristic curve of the existing runner, we first take the maximum head, design head, and head range as the characteristics of the extension transform and the output and efficiency as its constraints, and carry out multi-level reasoning on *JTurb-Runner*04 to find a combination of unit speed and unit flow on the runner's comprehensive characteristic curve. If this combination meets the design requirements, it is taken as the object of extension reuse; otherwise, unit speed and unit flow become the characteristics of the next level of extension transform, and so on, until a runner basic element model that meets the requirements is obtained:


Taking efficiency and reliability as the main evaluation characteristics for the extension design of runner selection, and structural compactness (that is, the diameter and height dimensions), runaway speed, rated speed, and rated flow as reference evaluation characteristics, *JTurb-TA*364-1 is the best runner object of the extension transform, that is:


Because an extension implication relationship links the runner *JTurb-Runner*04 with the volute *JWK*-94|*A*364, the draft tube *JWSG*-51|*A*364, and the guide vane *JDY*-43|*A*364, applying the extension transform to the characteristic values of *JTurb-Runner*04 also changes the design requirement parameters of the volute, draft tube, and guide vane. The corresponding design requirement intervals can be obtained from the comprehensive characteristic curve of the turbine runner. Table 3 gives the design requirement parameters of part of the flow path models.
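The unit speed and unit flow that locate an operating point on a runner's comprehensive characteristic curve follow the standard similarity relations n11 = n·D1/√H and Q11 = Q/(D1²·√H). The sketch below screens candidate combinations of nominal diameter and rotational speed against an acceptable (n11, Q11) window. It is an illustration under assumptions: the window, head, flow, and candidate values are hypothetical stand-ins for what would be read off the actual characteristic curve, and it covers only a single level of the multi-level extension transform.

```python
import math


def unit_speed(n_rpm: float, d1_m: float, head_m: float) -> float:
    """Unit speed n11 = n * D1 / sqrt(H)."""
    return n_rpm * d1_m / math.sqrt(head_m)


def unit_flow(q_m3s: float, d1_m: float, head_m: float) -> float:
    """Unit flow Q11 = Q / (D1**2 * sqrt(H))."""
    return q_m3s / (d1_m ** 2 * math.sqrt(head_m))


def feasible_combinations(diameters, speeds, head_m, q_m3s, n11_window, q11_window):
    """List (D1, n, n11, Q11) combinations whose unit parameters fall inside
    the given windows (assumed to come from the characteristic curve)."""
    found = []
    for d1 in diameters:
        for n in speeds:
            n11 = unit_speed(n, d1, head_m)
            q11 = unit_flow(q_m3s, d1, head_m)
            if (n11_window[0] <= n11 <= n11_window[1]
                    and q11_window[0] <= q11 <= q11_window[1]):
                found.append((d1, n, round(n11, 1), round(q11, 3)))
    return found


# Hypothetical design point and acceptance windows.
combos = feasible_combinations(diameters=[5.0, 6.0], speeds=[125, 150],
                               head_m=100.0, q_m3s=200.0,
                               n11_window=(60.0, 80.0), q11_window=(0.5, 0.9))
print(combos)
```

If the returned list were empty, the next level of the transform would change which characteristics are varied, as the text describes; that outer loop is not modeled here.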

**Table 3.** The Design Requirement Parameters of Partial Flow Path Model.


The extension configuration design and extension adaptive design of the volute, draft tube, and guide vanes must therefore be carried out according to the new design requirement parameters to complete the scheme design for the large turbine selection. Both designs require a weight to be assigned to every characteristic. Because extension configuration design mainly aims to match design objects quickly and extensibly against existing design instances or design results, the common requirement characteristics have the strongest influence on the weight of each characteristic parameter; this paper therefore uses the common requirement characteristics as the evaluation standard for allocating the design characteristic parameter weights. Specifically, the common requirement characteristics of head, output, efficiency, cavitation performance, and runaway performance are taken as requirement basic element items, and the volute, draft tube, and guide vane as design parameter basic element items. Six experts in turbine design were invited to score the requirement basic element items and the design parameter basic element items on the ratio scale interval [1–9]; the scores are shown in Tables 4 and 5.


**Table 4.** The Scoring Results of Requirement Basic Element Items.

**Table 5.** The Scoring Results of Design Parameter Items.


The ideal ratio scale interval sequence of the requirement basic elements is *U*(0) = ([9, 9], [9, 9], [9, 9], [9, 9], [9, 9], [9, 9]). Based on Formula (7), the extension correlation coefficient matrix *ρ* between the requirement basic element items and *U*(0) is built:


The extension correlation sequence *λ* between the requirement basic element items and *U*(0) is then built based on Formula (8):

$$\lambda = \begin{bmatrix} 7.875, \ 9.000, \ 8.083, \ 7.042, \ 7.958 \end{bmatrix}^T$$

The weight sequence of the requirement basic element items is obtained from Formula (9):

$$w\_{\rm{U}} = [0.197, \; 0.226, \; 0.202, \; 0.176, \; 0.199]^T$$
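Formula (9) is not reproduced in this section. Assuming it is the usual normalization, in which each λ value is divided by the sum of all λ values, the reported weights can be checked against λ:

```python
def normalize(values):
    """Divide each value by the total so the weights sum to 1
    (assumed form of Formula (9))."""
    total = sum(values)
    return [v / total for v in values]


lam = [7.875, 9.000, 8.083, 7.042, 7.958]       # λ from Formula (8)
reported = [0.197, 0.226, 0.202, 0.176, 0.199]  # w_U as reported
computed = normalize(lam)
print([round(w, 3) for w in computed])  # [0.197, 0.225, 0.202, 0.176, 0.199]
print(max(abs(c - r) for c, r in zip(computed, reported)))
```

Under this assumption the computed weights agree with the reported values to within 0.001; the small difference in the second entry may reflect rounding chosen so that the reported weights sum to 1.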

For the design parameter items, each requirement basic element item is taken in turn as the ideal ratio scale interval sequence, and the extension correlation degree matrix *AJ* between the design parameters and the requirement basic element items is obtained from Formulas (10) and (11):


Formulas (12)–(14) then give the design parameter weight sequence *wV* = [0.364, 0.321, 0.333]<sup>T</sup>. These weights are consistent with turbine design practice. The volute, draft tube, and guide vanes are each the core component of a functional unit, so their weights differ little in fast configuration design. The volute's weight is slightly higher because, as the diversion component, it must deliver water to the hydraulic components with minimal hydraulic loss and uniform flow, which in turn allows the guide vanes of the guiding apparatus to regulate the flow and the draft tube of the draining section to handle the outflow. In addition, because the design requirement parameters of the volute, draft tube, and guide vanes are all key control parameters, the design attributes within each component are weighted equally: each volute attribute has weight *wV-WK* = 0.250, each draft tube attribute has weight *wV-WSG* = 0.500, and each guide vane attribute has weight *wV-DY* = 0.250.

#### **5. Discussion**

The theoretical discussion and application case above show that the proposed method rests on a solid theoretical foundation. Covering topological knowledge modeling, extension analysis, implication analysis, acquisition of requirement-analysis index weights, and extension pattern generation, it forms a complete extension requirement analysis method for complex product scheme design with good engineering applicability.

By establishing the basic element model of complex product scheme design requirement information and the corresponding extension set of requirement basic elements, the method can formally represent various kinds of deep design requirement information. It establishes both the extension process model and the implication process model of requirement analysis for complex product scheme design. Exploiting the inherent implication and relevance of design requirement information, it performs rapid transformation and hybrid reasoning on design requirements, which makes the mapping of complex product scheme design requirements more intuitive and effective. Moreover, the method provides a weight distribution model for requirement basic elements based on extension distance, which derives requirement-analysis weights by combining qualitative and quantitative perspectives and accounts for the influence of design constraints and design characteristics on the weights of requirement attributes. Finally, based on the extension correlation degree of requirement basic elements, the paper establishes the implementation framework and algorithm of the extension design pattern for requirement analysis in complex product scheme design, which comprehensively reflects the design requirements and design intent of the scheme and supports the implementation of complex product design. The application example further verifies the effectiveness and feasibility of the algorithm.

In addition, applying knowledge extension reuse technology to complex product scheme design not only makes product design standardized and systematic but also extends the application field of expert systems, provides a theoretical basis for computer-aided conceptual product design, and supports the development of complex product design schemes.

#### **6. Conclusions**

In view of the multi-level, multi-attribute, and creative structure configuration process of complex products, this paper studies the extension design pattern for requirement analysis in complex product scheme design, whose requirements are abstract, fuzzy, variable, diverse, hierarchical, and interrelated. The specific results and conclusions are as follows: (1) The basic element model of complex product scheme design requirement information and the corresponding extension set of requirement basic elements are established, realizing formal modeling of the requirement analysis and design information in product scheme design. (2) The extension process model and the implication process model of requirement analysis for complex product scheme design are established, which support the generation of richer requirement-analysis knowledge. (3) A weight distribution model for requirement basic elements based on extension distance is established, which supports stronger reasoning in product requirement analysis. (4) The framework and algorithm of the extension design pattern for requirement analysis of complex product scheme design are proposed, and extension requirement analysis of complex product scheme design is realized. Building on the results of extension requirement analysis, how to carry out extension knowledge reasoning and extension knowledge reuse effectively in complex product scheme design is an important topic for future research and will provide important support for the rapid configuration design of complex products.

**Author Contributions:** Conceptualization, T.W.; Data curation, T.W.; Formal analysis, T.W., H.L. and X.W.; Funding acquisition, T.W.; Methodology, T.W.; Validation, H.L. and X.W.; Writing—original draft, T.W.; Writing—review & editing, T.W. All authors have read and agreed to the published version of the manuscript.

**Funding:** This research was funded by the Natural Science Foundation of Jiangsu Province of China (No. BK20221481) and the National Natural Science Foundation of China (No. 51775272, No. 51005114).

**Institutional Review Board Statement:** Not applicable.

**Informed Consent Statement:** Not applicable.

**Data Availability Statement:** The labeled dataset used to support the findings of this study is available from the corresponding author upon request.

**Conflicts of Interest:** The authors declare no conflict of interest.



### *Article* **Research on Adversarial Domain Adaptation Method and Its Application in Power Load Forecasting**

**Min Huang \* and Jinghan Yin**

Department of Software Engineering, South China University of Technology (SCUT), Guangzhou 510006, China **\*** Correspondence: minh@scut.edu.cn

**Abstract:** Domain adaptation transfers knowledge from a source domain to a target domain that lacks sufficient training data, so it can effectively overcome the data shortage problem in power load forecasting. Inspired by Generative Adversarial Networks (GANs), adversarial domain adaptation transfers knowledge through adversarial learning. Existing adversarial domain adaptation methods suffer from adversarial disequilibrium and a lack of transferability quantification, both of which eventually reduce prediction accuracy. To address these issues, a novel adversarial domain adaptation method is proposed. First, based on an analysis of the causes of adversarial disequilibrium, an initial state fusion strategy is proposed to improve the reliability of the domain discriminator and thus maintain adversarial equilibrium. Second, domain similarity is calculated from information entropy to quantify the transferability of source domain samples; by weighting samples during domain alignment, knowledge is transferred selectively and negative transfer is suppressed. Finally, the Building Data Genome Project 2 (BDGP2) dataset is used to validate the proposed method. The experimental results show that the proposed method alleviates adversarial disequilibrium and quantifies transferability reasonably, improving the accuracy of power load forecasting.

**Keywords:** domain adaptation; adversarial learning; adversarial equilibrium; transferability quantification; power load forecasting

**MSC:** 68T07

#### **1. Introduction**

Power load forecasting aims to predict the future load of a power system by mining the characteristics of users' power consumption behavior hidden in historical records, weather, dates, and other data. By forecast horizon, power load forecasting can be divided into long-term, medium-term, and short-term forecasting. Short-term power load forecasting predicts the load several hours or days ahead and is an important basis for the power system's rapid response to load changes.

Recently, machine learning has achieved remarkable success in computer vision [1], semantic segmentation [2], regression prediction [3], natural language processing [4], and other areas. However, two problems of traditional machine learning have gradually become apparent. First, it requires a large amount of labeled data, and collecting and labeling data is expensive, so it is difficult to apply in fields that lack the data needed to train models. Second, it is effective only when the training and test data satisfy the assumption of being independent and identically distributed (IID); this condition is rarely satisfied in the real world, which reduces accuracy and generalization. Correspondingly, because power consumption behavior is highly individual, the distributions of power load data differ among users. Due to the difficulty in collecting historical data, there is also a lack of labeled data for training. These factors hinder the application of traditional machine learning methods to short-term power load forecasting.

**Citation:** Huang, M.; Yin, J. Research on Adversarial Domain Adaptation Method and Its Application in Power Load Forecasting. *Mathematics* **2022**, *10*, 3223. https://doi.org/10.3390/math10183223

Academic Editors: Jianping Gou, Weihua Ou, Shaoning Zeng and Lan Du

Received: 22 June 2022 Accepted: 31 August 2022 Published: 6 September 2022

**Publisher's Note:** MDPI stays neutral with regard to jurisdictional claims in published maps and institutional affiliations.

**Copyright:** © 2022 by the authors. Licensee MDPI, Basel, Switzerland. This article is an open access article distributed under the terms and conditions of the Creative Commons Attribution (CC BY) license (https://creativecommons.org/licenses/by/4.0/).

Domain adaptation has received extensive attention as one of the effective methods to overcome the difficulties of few-shot learning [5–7]. Domain adaptation aims to transfer knowledge from related labeled data by reducing the distribution difference between the source domain and the target domain. Domain adaptation reduces the number of labeled samples required to achieve the target task and does not strictly require the data to satisfy the condition of IID.

The key aim of domain adaptation is to align the feature distributions of the source domain and target domain data, a process also called domain alignment. Depending on the alignment strategy, domain adaptation methods can be roughly divided into three types: discrepancy-based, adversarial-based, and reconstruction-based.

Discrepancy-based methods use a metric to measure the distance between the source domain and the target domain and align the distributions by minimizing that metric, typically by adding the corresponding distance term to the loss function of the neural network. The most widely used metrics include Maximum Mean Discrepancy (MMD) [8–10], KL (Kullback–Leibler) divergence [11], JS (Jensen–Shannon) divergence [12], Wasserstein distance [13–15], CORAL (CORrelation ALignment) [16,17], etc.
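As an illustration of a discrepancy-based term, the squared MMD with a Gaussian (RBF) kernel can be estimated from two feature batches as the mean within-source kernel value plus the mean within-target value minus twice the mean cross-domain value. This is a minimal NumPy sketch of the biased estimator; the bandwidth `sigma` is an assumed hyperparameter, and in a network this value would be added to the task loss.

```python
import numpy as np


def rbf_kernel(a: np.ndarray, b: np.ndarray, sigma: float) -> np.ndarray:
    """Gaussian kernel matrix k(a_i, b_j) = exp(-||a_i - b_j||^2 / (2 sigma^2))."""
    sq_dist = (np.sum(a ** 2, axis=1)[:, None]
               + np.sum(b ** 2, axis=1)[None, :]
               - 2.0 * a @ b.T)
    sq_dist = np.maximum(sq_dist, 0.0)  # clip tiny negatives caused by rounding
    return np.exp(-sq_dist / (2.0 * sigma ** 2))


def mmd2(xs: np.ndarray, xt: np.ndarray, sigma: float = 1.0) -> float:
    """Biased estimate of squared MMD between source xs and target xt (rows = samples)."""
    return float(rbf_kernel(xs, xs, sigma).mean()
                 + rbf_kernel(xt, xt, sigma).mean()
                 - 2.0 * rbf_kernel(xs, xt, sigma).mean())


rng = np.random.default_rng(0)
source = rng.normal(0.0, 1.0, size=(200, 4))
target_same = rng.normal(0.0, 1.0, size=(200, 4))
target_shifted = rng.normal(1.0, 1.0, size=(200, 4))
print(mmd2(source, target_same), mmd2(source, target_shifted))
```

A batch drawn from the same distribution gives a value near zero, while the mean-shifted batch gives a clearly larger one; minimizing this quantity is what pulls the two feature distributions together.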

Adversarial-based methods [18–25] are inspired by GANs and use a neural network module instead of a fixed metric to measure the distance. The key components of an adversarial domain adaptation model are a feature extractor and a domain discriminator. The feature extractor extracts domain-invariant features of the source and target domains to confuse the domain discriminator, while the domain discriminator tries to tell whether a sample comes from the source domain or the target domain. The discriminator minimizes the domain discrimination loss and the feature extractor maximizes it; this min-max game sets the two against each other and achieves domain alignment during adversarial training.
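The min-max game can be made concrete through the domain discrimination loss. In the sketch below (an assumption-based illustration, not the authors' code), the discriminator outputs the probability that a sample comes from the source domain; it is trained to minimize this binary cross-entropy, and the feature extractor is trained to maximize it. When the features fully confuse the discriminator (every output equals 0.5), the loss equals ln 2.

```python
import numpy as np


def domain_loss(p_src_on_source, p_src_on_target, eps=1e-12):
    """Binary cross-entropy of a domain discriminator.

    p_src_on_source: predicted P(source) for source samples (true label 1)
    p_src_on_target: predicted P(source) for target samples (true label 0)
    The discriminator minimizes this value; the feature extractor maximizes it.
    """
    p_s = np.clip(np.asarray(p_src_on_source, dtype=float), eps, 1.0 - eps)
    p_t = np.clip(np.asarray(p_src_on_target, dtype=float), eps, 1.0 - eps)
    total = -(np.sum(np.log(p_s)) + np.sum(np.log(1.0 - p_t)))
    return float(total / (p_s.size + p_t.size))


# A confident, correct discriminator versus a fully confused one.
print(domain_loss([0.99] * 4, [0.01] * 4))  # small (about 0.01)
print(domain_loss([0.5] * 4, [0.5] * 4))    # ln 2, about 0.693
```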

Reconstruction-based methods [26–29] reconstruct the data of all domains while preserving domain-specific features, which helps the model learn domain-invariant features. The encoder–decoder is a typical implementation of reconstruction-based methods: the shared encoder encodes the input data as hidden features and learns domain-invariant features, and the decoder reconstructs the hidden features and preserves domain-specific features.

Because domain adaptation methods realize cross-domain transfer and reuse of knowledge, many researchers have used them to overcome the data shortage in power load forecasting. Ref. [30] proposes a general framework of adversarial domain adaptation for time series prediction; Ref. [31] introduces a contrastive evaluation module to protect the task-specific features of the target domain during domain alignment; Ref. [32] builds adversarial feature capture networks to achieve reliable energy prediction; Ref. [33] proposes an electricity load forecasting algorithm based on bidirectional generative adversarial networks and validates it on user data with different behavior patterns, improving flexibility and accuracy; Ref. [34] constructs a time-independent model by maximizing the segmentation of time series differences, which suppresses the unstable prediction accuracy caused by temporal distribution shift. These studies address the reliance of traditional machine learning on large amounts of labeled data and its inability to learn from non-IID data. However, they do not consider the lack of transferability quantification, and the adversarial-based methods [30,33] do not consider adversarial disequilibrium. Both problems reduce the accuracy of domain adaptation and the robustness of the model, so this paper focuses on analyzing these two problems and developing solutions.

The main contributions of this paper include:


The rest of this paper is organized as follows: Section 2 analyzes two problems and summarizes the current solutions; Section 3 details the framework of the proposed method; Section 4 shows the experimental content and the analysis of the results; Section 5 concludes this article.

#### **2. Related Work**

This section briefly summarizes current solutions to adversarial disequilibrium and existing approaches to designing transferability metrics.

#### *2.1. Adversarial Disequilibrium Problem*

In adversarial-based methods, the domain discriminator determines, from the features generated by the feature extractor, whether a sample originates from the source domain or the target domain, and its discrimination results have a key impact on the parameter updates of the model. However, the feature extractor can easily win the competition by retaining only shallow feature representations and discarding deep ones, in which case the domain discriminator no longer reflects the distance between the distributions accurately. Methods for solving the adversarial disequilibrium problem can be divided into two categories according to their enhancement strategies.

One way to address this problem is to combine adversarial training with a discrepancy metric, so that the training goal is both to confuse the discriminator and to reduce the metric. When adversarial disequilibrium occurs and the domain discriminator fails, the model can still optimize its parameters according to the metric, which effectively improves training stability. Discrepancy metrics are mature, but they suit different scenarios because they differ in measurement dimensions, time overhead, gradient information, etc., so choosing the right metric from the many available is key to the feasibility of this approach. Ref. [35] adopts Maximum Density Divergence (MDD) to minimize inter-domain distance and maximize intra-domain density, and embeds MDD into an adversarial domain adaptation framework to overcome the adversarial disequilibrium problem. Ref. [36] incorporates Multi-Kernel Maximum Mean Discrepancy (MK-MMD) to reduce fluctuation during training and maintain adversarial equilibrium; Ref. [37] integrates MK-MMD into a partial adversarial domain adaptation network to deal with the adversarial disequilibrium problem.
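Combining a metric with adversarial training amounts to adding the metric to the feature extractor's objective. A minimal sketch, assuming an equal-weight average of RBF kernels with hypothetical bandwidths (MK-MMD as usually defined also learns the kernel weights, which is omitted here) and a hypothetical trade-off coefficient `beta`:

```python
import numpy as np


def multi_kernel_mmd2(xs, xt, sigmas=(0.5, 1.0, 2.0, 4.0)):
    """Biased squared MMD using an equal-weight average of RBF kernels."""
    def kernel(a, b):
        sq = np.sum(a ** 2, 1)[:, None] + np.sum(b ** 2, 1)[None, :] - 2.0 * a @ b.T
        sq = np.maximum(sq, 0.0)  # guard against tiny negative distances
        return sum(np.exp(-sq / (2.0 * s ** 2)) for s in sigmas) / len(sigmas)
    return float(kernel(xs, xs).mean() + kernel(xt, xt).mean()
                 - 2.0 * kernel(xs, xt).mean())


def extractor_objective(task_loss, disc_loss, xs_feat, xt_feat, beta=1.0):
    """Quantity the feature extractor minimizes: fit the task, fool the
    discriminator (hence -disc_loss), and shrink the MMD term."""
    return task_loss - disc_loss + beta * multi_kernel_mmd2(xs_feat, xt_feat)


rng = np.random.default_rng(1)
xs = rng.normal(size=(50, 3))
print(extractor_objective(1.0, 0.5, xs, xs))        # MMD term is 0 here
print(extractor_objective(1.0, 0.5, xs, xs + 2.0))  # larger: the domains differ
```

Because the MMD term is computed directly from the features, it would still supply a training signal when the discriminator saturates, which is the motivation given above.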

The other way is domain discriminator augmentation, which increases the domain information contained in the input features of the domain discriminator. From the perspective of the adversarial game, the added information keeps the domain discriminator from being at a disadvantage in the confrontation; the stronger the domain discriminator, the better it can guide the feature extractor to learn domain-invariant features. Ref. [38] proposes a conditional adversarial domain adaptation method that supplements the input features of the domain discriminator with category information, using a multilinear mapping to form the joint representation of feature information and category information. Ref. [39] combines features and labels to help the model learn discriminative features, and proposes using entropy minimization to assign reliable pseudo-labels to the target domain. Ref. [40] normalizes the conditional information so that it has the same norm as the features, expands the norm of the conditional output, and improves the prototype-based conditioning strategy. Ref. [41] proposes sample-level adversarial domain adaptation, which converts the non-central sample distribution into a central one to improve the class separability of the feature distribution, and indirectly adds category information to the input of the feature extractor through clustering.
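The multilinear conditioning described for Ref. [38] can be sketched as a per-sample outer product between the feature vector and the predicted class probabilities, flattened before it is passed to the domain discriminator. This is an illustrative NumPy version under that assumption; the shapes and values are hypothetical.

```python
import numpy as np


def multilinear_map(features: np.ndarray, class_probs: np.ndarray) -> np.ndarray:
    """Per-sample outer product of features and class probabilities, flattened.

    features:    (batch, d) feature vectors
    class_probs: (batch, c) predicted class probabilities
    returns:     (batch, d * c) joint representation for the domain discriminator
    """
    joint = np.einsum("bi,bj->bij", features, class_probs)
    return joint.reshape(features.shape[0], -1)


f = np.array([[1.0, 2.0]])
g = np.array([[0.25, 0.75]])
print(multilinear_map(f, g).tolist())  # [[0.25, 0.75, 0.5, 1.5]]
```

Every feature element is scaled by every class probability, so the discriminator sees which features co-occur with which predicted classes rather than the features alone.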

#### *2.2. Lack of Transferability Quantification Problem*

Domain adaptation learns domain-invariant features by reducing the distribution distance between the source domain and the target domain and then transfers knowledge from the source domain to the target domain. However, not all source domain knowledge helps achieve the target task. Traditional domain adaptation methods do not distinguish how much different source domain knowledge contributes; useless information and noise in the source domain hinder the model on the target task, eventually degrading performance and causing negative transfer. Similarity-based transferability quantification is currently an effective way to alleviate this problem.

Similarity-based transferability quantification assumes that the higher the domain similarity, the higher the transferability: it distinguishes the contributions of the source domain to the target task according to domain similarity and selectively transfers the knowledge that helps achieve the target task. The key to this approach is how to quantify domain similarity. Ref. [42] proposes an attention mechanism to quantify domain similarity, enhancing semantic information with high transferability between and within domains and improving the generalization ability and robustness of the algorithm. Ref. [43] proposes a weighted moment distance to quantify domain similarity and strengthens the influence of data with high domain similarity on the transfer process. Ref. [44] adds a batch spectral penalty to an adversarial domain adaptation network to suppress the forced alignment of features with low transferability and to enhance the transferability and discriminability of the method.

#### **3. Proposed Method**

This section mainly introduces the novel method: Section 3.1 proposes an initial state fusion strategy to maintain the adversarial equilibrium, Section 3.2 designs a selective transfer method based on information entropy, and Section 3.3 details the architecture of models.

#### *3.1. Adversarial Equilibrium Strategy Based on Initial State Fusion*

The key to domain discriminator augmentation is to supply domain structure information to the features, which improves the reliability of domain discrimination and avoids adversarial disequilibrium; the information introduced into the features therefore has a crucial impact on the effectiveness of the method.

The initial state refers to the original data before feature extraction and distribution alignment. It contains the most complete domain structure information, and the statistical features of the source and target domain data are highly distinguishable in it, so it meets the requirements for the information used in domain discriminator augmentation. This paper therefore proposes fusing the initial state into the input features of the domain discriminator. Supplementing the input features with domain structure information improves the reliability of the domain discrimination results, keeps the domain discriminator from being at a disadvantage in adversarial training, and ultimately lets the domain discriminator reflect the distance between distributions implicitly and more accurately.

Because the intermediate features and the initial state differ greatly in dimension, conventional fusion operations such as concatenation and addition tend to fail. We therefore propose splitting the features first and then fusing them; the critical steps are shown in Figure 1. First, the domain features of the data (yellow in Figure 1) are extracted by the feature extractor. Second, the domain features are split into several subfeatures with the same dimension as the initial state (pink in Figure 1), and each subfeature is combined with the initial state by a dot product, given by

$$a \bullet b = \sum\_{i=1}^{n} a\_i b\_i = a\_1 b\_1 + a\_2 b\_2 + \dots + a\_n b\_n \tag{1}$$

where *a* and *b* represent the subfeature and the initial state, respectively, and $a\_i$ and $b\_i$ represent their *i*-th elements.

Each subfeature is combined with the initial state by Equation (1), and the resulting scalars are merged to form the fused feature (red in Figure 1). Finally, the fused feature is fed into the domain discriminator for domain discrimination.

**Figure 1.** Initial state fusion strategy.
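The split-then-fuse step can be sketched as follows. Assumptions not fixed by the text: the feature length is an exact multiple of the initial-state length (otherwise an error is raised), subfeatures are taken as consecutive slices, and both inputs are batched NumPy arrays.

```python
import numpy as np


def initial_state_fusion(features: np.ndarray, initial_state: np.ndarray) -> np.ndarray:
    """Split-then-fuse using the dot product of Eq. (1).

    features:      (batch, k * n) domain features from the feature extractor
    initial_state: (batch, n) raw inputs before feature extraction
    returns:       (batch, k) fused feature, one dot product per subfeature
    """
    batch, n = initial_state.shape
    if features.shape[1] % n != 0:
        raise ValueError("feature length must be a multiple of the initial-state length")
    k = features.shape[1] // n
    subfeatures = features.reshape(batch, k, n)  # consecutive slices of length n
    return np.einsum("bkn,bn->bk", subfeatures, initial_state)


feat = np.array([[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]])
x0 = np.array([[1.0, 0.0, 1.0]])
print(initial_state_fusion(feat, x0).tolist())  # [[4.0, 10.0]]
```

For a 6-dimensional feature and a 3-dimensional initial state, the fused feature has two elements, one per subfeature, so its size no longer depends on the mismatch between the two dimensions.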

#### *3.2. Transferability Quantification Based on Information Entropy*

Transferability quantification rests on the premise that domain similarity and transferability are positively correlated. In adversarial domain adaptation, the information entropy of the domain discrimination objectively reflects domain similarity. We therefore propose a transferability quantification method based on information entropy, which transfers source domain samples selectively and suppresses negative transfer to a certain extent.

In information theory, the information content of an event grows as its probability decreases, and information entropy measures the expected information content over all possible outcomes. Let *p*(*x<sub>i</sub>*) denote the probability of outcome *x<sub>i</sub>* ∈ *X*, *i* = 1, 2, . . . , *n*; the information entropy of *X* is then calculated by

$$H(X) = -\sum\_{\mathbf{x}\_i \in X} p(\mathbf{x}\_i) \ln p(\mathbf{x}\_i) \tag{2}$$

Domain discrimination is the basis on which adversarial domain adaptation reflects the degree of feature distribution alignment. In essence, it is a binary classification task that predicts whether a sample belongs to the source domain or the target domain. When the output layer of the domain discriminator uses the Softmax activation, the output is a pair of predicted values that sum to 1, denoted as [*p<sub>s</sub>*, *p<sub>t</sub>*], which are the probabilities the domain discriminator assigns to the sample belonging to the source domain and the target domain, respectively. The Softmax activation is calculated by

$$S\_i = \frac{e^{z\_i}}{\sum\_{j=1}^{n} e^{z\_j}} \tag{3}$$

where *z<sub>i</sub>* is the *i*-th input to the Softmax function and *n* is the number of classes (*n* = 2 for domain discrimination).

The information entropy of the domain prediction is used to reflect domain similarity. The closer the outputs *p<sub>s</sub>* and *p<sub>t</sub>* of the domain discriminator are, the more successfully the features of the source domain sample confuse the domain discriminator, leaving it unable to discriminate the domain accurately. High domain similarity therefore corresponds to maximal information entropy of the domain prediction, and the source domain samples that generate such features should be given higher weights during transfer. The weight is calculated by

$$
\omega\_i = \exp\left[-p\_s \ln(p\_s) - p\_t \ln(p\_t)\right] - 1 \tag{4}
$$

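where the exponent is the information entropy of [*p<sub>s</sub>*, *p<sub>t</sub>*]. Because this entropy ranges from 0 to ln 2, *ω<sub>i</sub>* ranges from 0 (a fully confident discriminator) to 1 (*p<sub>s</sub>* = *p<sub>t</sub>* = 0.5).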
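The chain from discriminator outputs to sample weight (Equations (2)–(4)) can be checked numerically with a short sketch. It assumes a two-output discriminator whose raw outputs pass through Softmax, treats 0 · ln 0 as 0, and uses hypothetical function names.

```python
import math
from typing import Tuple


def softmax2(z_s: float, z_t: float) -> Tuple[float, float]:
    """Equation (3) for two discriminator outputs; the max is subtracted for stability."""
    m = max(z_s, z_t)
    e_s, e_t = math.exp(z_s - m), math.exp(z_t - m)
    return e_s / (e_s + e_t), e_t / (e_s + e_t)


def entropy(p_s: float, p_t: float) -> float:
    """Equation (2) with the natural log; 0 * ln 0 is taken as 0."""
    return -sum(p * math.log(p) for p in (p_s, p_t) if p > 0.0)


def transfer_weight(p_s: float, p_t: float) -> float:
    """Equation (4): exp(entropy) - 1, which lies in [0, 1] for two classes."""
    return math.exp(entropy(p_s, p_t)) - 1.0


print(round(transfer_weight(*softmax2(0.3, 0.3)), 4))  # 1.0 (discriminator fully confused)
print(round(transfer_weight(0.99, 0.01), 4))           # 0.0576 (discriminator confident)
```

A source sample that the discriminator cannot tell apart from target samples therefore receives a weight close to 1 in Equation (6), while a sample the discriminator confidently assigns to the source domain contributes little.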

To address the lack of a transferability quantification method, we quantify transferability based on information entropy and weight the source domain samples according to the quantification results, so that knowledge is transferred selectively. The process of transferability quantification is shown in Figure 2. First, the features of the samples are extracted. Samples with high domain similarity exhibit more domain-invariant features in the feature space, where the feature distributions of the source and target domains overlap to a high degree. Next, domain discrimination is performed: the smaller the difference between the discriminator outputs *p<sub>s</sub>* and *p<sub>t</sub>*, the more similar the sample is to the target domain, the richer the transferable knowledge it contains, and the higher the information entropy of the domain discrimination. Finally, the weights are calculated, so that samples with higher transferability have a greater influence on the transfer.

**Figure 2.** Transferability quantification process.

#### *3.3. A Novel Adversarial Domain Adaptation Method*

#### 3.3.1. Model Structure

The one-dimensional convolutional neural network and bidirectional long short-term memory network (1DCNN-BiLSTM) combines the efficient feature extraction ability of 1DCNN with the ability of BiLSTM to describe dependencies in a time series [45,46]. We use 1DCNN to build the feature extractor and BiLSTM to build the predictor; the model structure is shown in Figure 3. The model consists of three basic modules: a feature extractor, a predictor, and a domain discriminator. In addition, the initial state fusion module (light blue in Figure 3) is added before the domain discriminator, and the transferability quantification module (light green in Figure 3) is added after it.

The model hyperparameters are shown in Table 1. The Hyperparameter column lists the properties required to build the model, followed by the corresponding values. The first row indicates that the feature extractor has three 1DCNN layers, and the values in brackets in the second row are the kernel sizes of these three layers. The source domain and target domain data are convolved by the 1DCNN layers to generate domain-invariant features. Dropout [47] is applied in the BiLSTM layer of the predictor, randomly suppressing neurons to avoid overfitting. The features are fused with the initial state before domain discrimination, and the domain discrimination results are used to calculate the total loss.

**Figure 3.** Model structure. C represents 1DCNN, L represents BiLSTM, F represents fully connected layer.



The domain discrimination loss is the cross-entropy between the predicted domain labels and the true domain labels, calculated by

$$Loss\_{dcls} = \frac{1}{n\_s} \sum\_{i=1}^{n\_s} L\_{ce}(d\_s^i, y\_s^{di}) + \frac{1}{n\_t} \sum\_{i=1}^{n\_t} L\_{ce}(d\_t^i, y\_t^{di}) \tag{5}$$

The prediction loss consists of two parts: the weighted source domain prediction loss and the target domain prediction loss, which is calculated by

$$Loss\_{pred} = \frac{1}{n\_s} \sum\_{i=1}^{n\_s} \omega\_i (y\_s^i - y\_s^{pi})^2 + \frac{1}{n\_t} \sum\_{i=1}^{n\_t} (y\_t^i - y\_t^{pi})^2 \tag{6}$$

The total loss of the model is composed of the domain discrimination loss and the prediction loss, which is calculated by

$$Loss = Loss\_{dcls} + Loss\_{pred} \tag{7}$$

where the subscript *s* indicates that a variable belongs to the source domain and the subscript *t* that it belongs to the target domain; *n* is the number of samples in the domain; *d<sup>i</sup>* is the domain label; *y<sup>di</sup>* is the predicted domain label; *y<sup>i</sup>* is the true value; *y<sup>pi</sup>* is the prediction; *ω<sub>i</sub>* is the weight; and *L<sub>ce</sub>* is the cross-entropy loss function.
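Equations (5)–(7) can be sketched for a single batch as below. This plain-Python version is illustrative only: encoding the source domain as label 0 and the target domain as label 1, the clipping constant inside the logarithm, and the function names are assumptions. How the adversarial objective is applied to the feature extractor (for example, through a gradient reversal layer as in DANN) is not specified in this section and is therefore omitted.

```python
import math
from typing import List, Sequence


def cross_entropy(label: int, probs: Sequence[float]) -> float:
    """L_ce for one sample: negative log of the probability given to the true label."""
    return -math.log(max(probs[label], 1e-12))  # clip to avoid log(0)


def domain_loss(src_probs: List[Sequence[float]], tgt_probs: List[Sequence[float]],
                src_label: int = 0, tgt_label: int = 1) -> float:
    """Equation (5): mean cross-entropy over source samples plus mean over target samples."""
    ls = sum(cross_entropy(src_label, p) for p in src_probs) / len(src_probs)
    lt = sum(cross_entropy(tgt_label, p) for p in tgt_probs) / len(tgt_probs)
    return ls + lt


def prediction_loss(y_s: Sequence[float], yp_s: Sequence[float], w_s: Sequence[float],
                    y_t: Sequence[float], yp_t: Sequence[float]) -> float:
    """Equation (6): weighted source squared error plus unweighted target squared error."""
    ls = sum(w * (y - yp) ** 2 for y, yp, w in zip(y_s, yp_s, w_s)) / len(y_s)
    lt = sum((y - yp) ** 2 for y, yp in zip(y_t, yp_t)) / len(y_t)
    return ls + lt


def total_loss(dcls: float, pred: float) -> float:
    """Equation (7): sum of the domain discrimination and prediction losses."""
    return dcls + pred


# Toy batch: each probability pair is [p_s, p_t] from the domain discriminator.
dcls = domain_loss([[0.5, 0.5]], [[0.5, 0.5]])
pred = prediction_loss([1.0, 2.0], [1.5, 2.0], [1.0, 0.0], [3.0], [2.0])
print(round(total_loss(dcls, pred), 4))  # 2.5113
```

In the toy batch the second source sample has weight 0, so its error does not contribute, which is how the entropy-based weights of Equation (4) filter low-transferability samples out of the prediction loss.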

#### 3.3.2. The Critical Steps of the Algorithm

The algorithm flow is shown in Figure 4. The critical steps of each epoch during training include:


**Figure 4.** Algorithm flow chart.

#### **4. Experimental Setup and Results**

In this section, we extensively evaluate our approach and compare it with state-of-the-art domain adaptation methods. We also provide a detailed analysis of the proposed framework, demonstrating empirically the effect of our contributions.

#### *4.1. Datasets*

We evaluate the proposed approach on the BDGP2 dataset [48], which covers the period from 2016 to 2017 with a sampling interval of 1 h. The recorded values include power load, heating, cooling water, steam, and other meter data; in addition, the dataset integrates outdoor temperature, humidity, cloud cover, and other climatic factors that can affect power consumption.

Four residential buildings are selected for analysis, namely Bear\_lodging\_Evan (domain A), Robin\_lodging\_Renea (domain B), Rat\_lodging\_Ardell (domain C), and Fox\_lodging\_Angla (domain D). Their loads are periodic, following the occupants' living habits, as shown in Figure 5. We use the Augmented Dickey–Fuller (ADF) test to check that the time series is stationary; the *p*-value is 0.00000218. The absence of missing values is another important reason for selecting these buildings. The input variables are shown in Table 2.

**Figure 5.** Power load for the four buildings. (**a**) Building A; (**b**) building B; (**c**) building C; (**d**) building D.



The experiment adopts single-step time series forecasting: the input is the variables in Table 2 over the first 24 h of each sliding window, and the true value is the load of the following hour. To verify the effectiveness and accuracy of the proposed method, we construct 12 transfer tasks for each method (one for every ordered pair of the four buildings), each denoted as S→T, where S is the source domain and T is the target domain. When a building is selected as the source domain, all of its samples form the source domain training set. When another building is selected as the target domain, 10% of its samples form the target domain training set and 20% form the target domain test set; the remaining 70% are not used. Using samples from two different buildings creates a non-IID condition, and retaining only a few samples of the target building simulates the lack of data in the target domain.
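The sliding-window sample construction and the enumeration of the 12 transfer tasks can be sketched as follows. The 24 h window and one-hour-ahead target come from the text; the list-based data layout and the function name are assumptions, and the 10%/20% target split is not shown because the text does not state whether it is chronological or random.

```python
from typing import List, Sequence, Tuple


def make_windows(features: Sequence[Sequence[float]], load: Sequence[float],
                 window: int = 24) -> Tuple[List[List[Sequence[float]]], List[float]]:
    """Single-step samples: each input holds `window` consecutive hourly feature
    vectors, and the target is the load in the hour immediately after the window."""
    xs: List[List[Sequence[float]]] = []
    ys: List[float] = []
    for start in range(len(load) - window):
        xs.append(list(features[start:start + window]))  # hours start .. start+window-1
        ys.append(load[start + window])                  # next-hour load
    return xs, ys


# Every ordered (source, target) pair of distinct buildings gives 4 * 3 = 12 tasks.
TASKS = [(s, t) for s in "ABCD" for t in "ABCD" if s != t]

hours = 30  # toy series of 30 hourly records
x, y = make_windows([[float(h)] for h in range(hours)], [float(h) for h in range(hours)])
print(len(x), len(TASKS))  # 6 12
```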

#### *4.2. Implementation Details*

All experiments in this paper are implemented under the same framework: Python 3.7.11, PyTorch 1.10.1, CUDA 11.3, and cuDNN 8.2 on Windows 10. The CPU is an Intel i5-11400H with a base frequency of 2.7 GHz and 16 GB of memory, and the GPU is an RTX 3050 Ti with 4 GB of memory.

All experiments adopt the same training setting: the optimizer is Adam, the maximum number of epochs is 50, the batch size is 32, the initial parameters use the default initialization of PyTorch 1.10.1, and the learning rate is calculated as

$$LR = \frac{0.01}{(1 + 10 \ast p)^{0.75}} \tag{8}$$

where *LR* is the learning rate of the current epoch and *p* is the ratio of the current epoch to the maximum number of epochs.
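Equation (8) can be written as a function of the epoch index. This is a minimal sketch; counting epochs from 0 is an assumption, since the text does not say whether *p* starts at 0 or 1/50.

```python
def learning_rate(epoch: int, max_epochs: int = 50, lr0: float = 0.01) -> float:
    """Equation (8): LR = lr0 / (1 + 10 p)^0.75, where p = epoch / max_epochs."""
    p = epoch / max_epochs  # assumption: epochs are counted from 0
    return lr0 / (1.0 + 10.0 * p) ** 0.75


for e in (0, 25, 50):
    print(e, round(learning_rate(e), 5))
```

With these settings the learning rate decays from 0.01 at *p* = 0 to 0.01/11<sup>0.75</sup> (about 0.00166) at *p* = 1. In PyTorch, the factor 1/(1 + 10*p*)<sup>0.75</sup> could be supplied as `lr_lambda` to `torch.optim.lr_scheduler.LambdaLR` wrapped around an Adam optimizer created with `lr=0.01`.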

#### *4.3. Results*

The objective indicators for the experimental evaluation of prediction accuracy are Root Mean Square Error (RMSE), Mean Absolute Error (MAE), and Mean Absolute Percentage Error (MAPE).

RMSE is sensitive to outliers; a small RMSE indicates that the method produces few predictions with large deviations. MAE describes the absolute error between the predicted and true values and is the most intuitive. MAPE converts the error into an error rate, so it can evaluate method performance regardless of the order of magnitude of the data.

$$\text{RMSE} = \sqrt{\frac{1}{n} \sum\_{i=1}^{n} \left( y\_i - y\_{pi} \right)^2} \tag{9}$$

$$\text{MAE} = \frac{1}{n} \sum\_{i=1}^{n} |y\_i - y\_{pi}| \tag{10}$$

$$\text{MAPE} = \frac{100\%}{n} \sum\_{i=1}^{n} \left| \frac{y\_i - y\_{pi}}{y\_i} \right| \tag{11}$$

where *n* is the number of test samples, *y<sub>i</sub>* is the true value, and *y<sub>pi</sub>* is the prediction.
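Equations (9)–(11) translate directly into code. The sketch below assumes that no true load value is zero, since MAPE is undefined in that case.

```python
import math
from typing import Sequence


def rmse(y: Sequence[float], yp: Sequence[float]) -> float:
    """Equation (9): root mean square error."""
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(y, yp)) / len(y))


def mae(y: Sequence[float], yp: Sequence[float]) -> float:
    """Equation (10): mean absolute error."""
    return sum(abs(a - b) for a, b in zip(y, yp)) / len(y)


def mape(y: Sequence[float], yp: Sequence[float]) -> float:
    """Equation (11), in percent; assumes every true value is nonzero."""
    return 100.0 * sum(abs((a - b) / a) for a, b in zip(y, yp)) / len(y)


true_load = [100.0, 200.0]
pred_load = [110.0, 190.0]
print(rmse(true_load, pred_load), mae(true_load, pred_load))  # 10.0 10.0
print(round(mape(true_load, pred_load), 2))                   # 7.5
```

The example shows why MAPE complements the other two metrics: the same 10-unit error counts as 10% on the smaller load but only 5% on the larger one.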

The proposed method was compared with FineTune (FT) [49], Wasserstein Distance Guided Representation Learning (WDGRL) [50], Deep Adaptation Networks (DAN) [51], Domain Adversarial Neural Networks–Long Short Term Memory Networks (DANN-LSTM) [52], and Deep CORAL (DCORAL) [53].

FT is the lightest and most widely used method for knowledge transfer. DAN and DCORAL use MMD and CORAL, respectively, to measure the distance between domains and are widely used discrepancy-based methods. The proposed method, WDGRL, and DANN-LSTM are adversarial methods; the difference is that the proposed method explicitly attempts to alleviate the adversarial disequilibrium problem. The RMSE, MAE, and MAPE results are shown in Tables 3–5. The last row gives the average performance of each method across tasks, and the best performance for each task is highlighted in bold.

The prediction error of the proposed method is smaller than that of the other methods in most adaptation tasks. The proposed method reduces RMSE by 1.53, MAE by 1.29, and MAPE by 1.53%. The reduction in RMSE indicates that the method produces fewer outlying predictions and is more stable. MAE measures the absolute error and MAPE the error rate; the reduction in both indicates that the proposed method effectively improves the generalization ability and prediction accuracy of the model.


**Table 3.** RMSE Performance. The best performance of each task is highlighted in bold.

**Table 4.** MAE Performance. The best performance of each task is highlighted in bold.


**Table 5.** MAPE Performance. The best performance of each task is highlighted in bold.


In adaptation tasks that share the same target domain but use different source domains, such as B→A, C→A, and D→A, the prediction error of the proposed method fluctuates the least as the source domain changes. This indicates that the transferability quantification based on information entropy successfully transfers source domain knowledge selectively and mitigates the negative transfer caused by low-correlation source domain samples.

The difference between the proposed method and the other adversarial domain adaptation methods (DANN-LSTM and WDGRL) is the addition of the initial state fusion module to maintain adversarial equilibrium. The proposed method has advantages in multiple tasks and reduces RMSE by 1.57, MAE by 1.42, and MAPE by 2.2% relative to these methods, indicating that the adversarial equilibrium strategy based on initial state fusion effectively alleviates the adversarial disequilibrium problem. Supplementing the intermediate features with domain structure information increases the reliability of domain discrimination, so the domain discriminator supervises the feature extractor to align the feature distributions more effectively, thereby improving prediction accuracy.

The power load forecasting curves of the proposed method for the week from 0:00 on 14 March 2016 to 0:00 on 21 March 2016 are shown in Figure 6. The predictions fit the true values closely, showing that the proposed method effectively improves load prediction accuracy. However, the prediction error at local peaks and valleys is relatively large for all four buildings. The load changes abruptly most often in building C, meaning that its users' personalized behavior is the most pronounced, and its peak prediction error is accordingly the largest; this indicates that the prediction is easily affected by users' personalized behavior and that the transfer is not yet precise enough. It is therefore necessary to enhance the method's ability to learn domain-specific features, achieve more fine-grained selective transfer, suppress negative transfer more effectively, and further improve prediction accuracy.

**Figure 6.** The power load forecasting curves for four buildings. (**a**) Task B→A; (**b**) task C→B; (**c**) task A→C; (**d**) task C→D.

Feature visualization is an important tool for assessing the degree of feature distribution alignment, and t-SNE [54] is widely used to visualize high-dimensional data distributions in domain adaptation. The feature visualization results are shown in Figure 7, where red points correspond to the source domain and blue points to the target domain. The more similar the source and target domain features are, the more effective the method is. With the proposed method, the source and target domain features have the smallest deviation, and a large proportion of them overlap. Further analysis shows that the features extracted and aligned by the proposed method form clusters whose boundaries are sharper than those of the baseline methods. Since the clusters represent features that the method extracts from different aspects, this indicates that the initial state fusion strategy improves the discrimination ability of the domain discriminator, which in turn supervises the feature extractor to extract domain-invariant features effectively during adversarial training. Moreover, compared with the baseline methods, the proposed method leaves fewer features unaligned, indicating that it effectively suppresses low-correlation information in the source domain while retaining information that can be transferred to the target domain.

**Figure 7.** Feature visualization for different methods. Red points correspond to the source domain, while blue ones correspond to the target domain. (**a**) WDGRL; (**b**) DAN; (**c**) DANN-LSTM; (**d**) DCORAL; (**e**) Ours.

#### **5. Conclusions**

This paper focuses on the adversarial domain adaptation method and its application to power load forecasting. Domain adaptation alleviates the dependence of traditional machine learning methods on large amounts of labeled data and on the IID condition, which is of great significance for intelligent power load forecasting systems. Adversarial domain adaptation faces two problems: adversarial disequilibrium and the lack of transferability quantification. This paper proposes a solution to each problem and verifies them through extensive experiments. The experimental results on the BDGP2 dataset show that the proposed method achieves high power load prediction accuracy. This paper provides a research reference for addressing adversarial disequilibrium and the lack of transferability quantification, as well as an application reference for power load forecasting based on adversarial domain adaptation. However, because users' electricity consumption behavior is highly personalized, the method does not perform well at local peaks and valleys. It is therefore necessary to enhance the method's ability to learn domain-specific features to achieve more refined selective transfer. Our future work will explore how to better suppress negative transfer and further improve prediction accuracy.

**Author Contributions:** Conceptualization, M.H. and J.Y.; methodology, M.H. and J.Y.; software, J.Y.; validation, M.H. and J.Y.; writing—original draft preparation, J.Y.; writing—review and editing, M.H.; funding acquisition, M.H. All authors have read and agreed to the published version of the manuscript.

**Funding:** This research was funded by two Guangdong Natural Science Foundation Projects (Grant No. 2021A1515011496 and Grant No. 2022A1515011370).

**Institutional Review Board Statement:** Not applicable.

**Informed Consent Statement:** Not applicable.

**Data Availability Statement:** Not applicable.

**Conflicts of Interest:** The authors declare no conflict of interest.

#### **Abbreviations**

The following abbreviations are used in this manuscript:


#### **References**


### *Article* **Deep Reinforcement Learning-Based RMSA Policy Distillation for Elastic Optical Networks**

**Bixia Tang <sup>1</sup>, Yue-Cai Huang <sup>2,\*</sup>, Yun Xue <sup>1,2</sup> and Weixing Zhou <sup>1,2</sup>**


**\*** Correspondence: huangyuecai@scnu.edu.cn

**Abstract:** Reinforcement learning-based routing, modulation, and spectrum assignment (RMSA) has been regarded as an emerging paradigm for resource allocation in elastic optical networks. One limitation is that the learning process depends heavily on the training environment, such as the traffic pattern or the optical network topology. Therefore, re-training is required whenever the network topology or traffic pattern changes, which consumes a great amount of computing power and time. To ease the requirement of re-training, we propose a policy distillation scheme, which distills knowledge from a well-trained teacher model and transfers it to a to-be-trained student model, so that the training of the latter can be accelerated. Specifically, the teacher model is trained in one training environment (e.g., one topology and traffic pattern) and the student model in another. The simulation results indicate that our proposed method effectively speeds up the training of the student model and even leads to a lower blocking probability, compared with training the student model without knowledge distillation.

**Keywords:** routing, modulation and spectrum assignment; elastic optical networks; deep reinforcement learning; knowledge distillation

**MSC:** 68T07

#### **1. Introduction**

With the rapid development of Internet technology, services such as audio and video conferencing, webcasting, and cloud computing have become popular. The growing demand for these services leads to an exponential increase in data traffic and poses great challenges to the underlying communication networks [1]. Elastic optical networks (EONs) have been regarded as a promising candidate for next-generation optical communications [2,3]. In EONs, the spectrum is divided into narrow frequency slots, and traffic requests can be served by different numbers of frequency slots according to their data rate requirements and the quality of the connection. This flex-grid scheme greatly increases the flexibility of network resource allocation compared with traditional wavelength-division multiplexing (WDM)-based networks [4]. At the same time, it makes network resource management more difficult.

The routing, modulation, and spectrum assignment (RMSA) [5] problem is a key problem in EON resource management. Because of its complexity, the RMSA problem is generally divided into two sub-problems, routing and spectrum assignment [6], each of which is tackled by heuristic solutions [7–10]. For the routing sub-problem, representative approaches include fixed routing, fixed alternative routing [11,12], and adaptive routing [4]. For the spectrum assignment sub-problem, there are the first-fit [13] and random-fit schemes, among other methods. However, these rule-based heuristics, which rely mostly on the designers' experience, cannot comprehensively capture the effects of complex network conditions.

**Citation:** Tang, B.; Huang, Y.-C.; Xue, Y.; Zhou, W. Deep Reinforcement Learning-Based RMSA Policy Distillation for Elastic Optical Networks. *Mathematics* **2022**, *10*, 3293. https://doi.org/10.3390/math10183293

Academic Editors: Jianping Gou, Weihua Ou, Shaoning Zeng and Lan Du

Received: 14 August 2022 Accepted: 8 September 2022 Published: 11 September 2022

**Publisher's Note:** MDPI stays neutral with regard to jurisdictional claims in published maps and institutional affiliations.

**Copyright:** © 2022 by the authors. Licensee MDPI, Basel, Switzerland. This article is an open access article distributed under the terms and conditions of the Creative Commons Attribution (CC BY) license (https://creativecommons.org/licenses/by/4.0/).

To overcome the above limitation, deep reinforcement learning (DRL) has recently been introduced to the RMSA problem [14–19]: the RMSA policies are parameterized by deep neural networks and improved through interactions with the optical network environment. Many of these approaches have achieved better performance than heuristic methods. However, the policies learned by these DRL-based approaches are highly dependent on the training environment, such as the traffic patterns and network topologies, which are likely to change in a practical network. For example, the traffic volume from commercial and residential areas differs between working hours and off-duty hours, and the network topology changes in the case of a network failure or disaster. Once the environment changes, the effectiveness of the learned RMSA policies deteriorates significantly [20], so re-training is required, which consumes a lot of computing power and time. To ease the requirement of re-training, Chen et al. [20] investigated transfer learning (TL) between different network topologies. They first trained a model on source tasks and then copied the parameters of the trained model as the starting point for training on the target task. The limitation is that the target task must use the same neural network architecture as the source task. Moreover, the effect of traffic variation was not investigated.

In this paper, we extend our previously published conference paper [19] and apply policy distillation [21] to the RMSA problem, combining knowledge distillation [22] with reinforcement learning (RL). First, a teacher model is trained for one task with a specific traffic pattern and network topology. Then, the well-trained policy of the teacher model is distilled, and the knowledge is transferred to a student model with a different traffic pattern and network topology to assist its training. A major difference between our work and the transfer learning in [20] is that the student model (target) and the teacher model (source) can have different neural network architectures, which allows knowledge transfer in a broader context. We apply the proposed design in three application scenarios that consider different traffic patterns and different topologies. The simulation results demonstrate that policy distillation can accelerate the training of the student model and improve its performance.

The rest of this paper is organized as follows. Section 2 surveys the related work. In Section 3, we briefly introduce some basics of RL. In Section 4, we introduce the proposed policy distillation architecture, including the problem formulation and the training of the teacher model and the student model. Then, we present the simulation results in Section 5. Lastly, we conclude the paper in Section 6.

#### **2. Related Work**

#### *2.1. Deep Reinforcement Learning in RMSA of EONs*

In recent years, a growing body of research has exploited DRL to solve the routing and spectrum assignment problem in optical networks. Chen et al. [23] proposed a DRL framework, namely DeepRMSA, for optical network management and resource allocation; DeepRMSA is trained with the deep Q-learning algorithm. Because the input-state representation has a significant impact on performance, a series of works have explored different state representations. Chen et al. [14] defined a list of features of the candidate paths. Yan et al. [24] introduced the concept of a multi-modal optical network, using a topology modality and a routing modality to represent different features of the optical network, and used the actor–critic (AC) algorithm for training. Suárez-Varela et al. [25] captured the key relationships between links in the input-state representation, allowing DRL agents to learn more easily and quickly. The same team subsequently introduced graph neural networks [26] to further capture network-state features. Xu et al. [18] introduced a link–path relationship matrix to capture the path information of elastic optical networks.

Other works have explored various aspects of applying DRL to optical network management. Huang et al. [15] proposed a DRL-based self-learning routing scheme for WDM-based networks, which allows the agent to continuously improve its performance by self-comparison. Koch et al. [27] adopted an RL algorithm for parameter optimization in EONs. In addition, a cost-efficient routing, modulation, wavelength, and port assignment algorithm based on DRL was developed in [28]. Moreover, Li et al. [29] investigated collaborative DRL agents for multi-domain provisioning in multi-area optical networks.

#### *2.2. Transfer Learning in EONs*

Transfer learning in EONs has recently attracted research interest. Yao et al. [30] proposed a TL-based resource optimization strategy for predicting the spectrum defragmentation time in space-division multiplexing EONs. Liu et al. [31] applied a TL approach to implement scalable quality-of-transmission estimation in EONs. To our knowledge, the work most relevant to this paper is [20], where the authors propose a knowledge transfer design that alleviates scalability issues by transferring knowledge between RMSA agents with different tasks through a modular DRL agent structure. As mentioned in Section 1, its limitation is that the target task must use the same neural network architecture as the source task. In our previously published conference paper [19], we proposed a DRL-based knowledge distillation scheme to achieve RMSA policy scalability in EONs. This paper extends [19] in three aspects: (1) [19] only considers different traffic patterns, while this paper considers both different traffic patterns and different topologies; (2) the training algorithm is updated to the asynchronous advantage actor–critic (A3C) algorithm; and (3) many more simulation results are provided to verify our proposal.

#### **3. Preliminaries**

As this work is based on RL, we first explain some RL basics for the convenience of readers.

#### *3.1. Reinforcement Learning*

Reinforcement learning is an important branch of machine learning. Many RL tasks can be modeled as Markov decision processes (MDPs), expressed as tuples {*S*, *A*, *R*, *P*}, where *S* is the state space of the environment, *A* is the action space of the agent, *R* is the reward function, and *P* represents the state transition probabilities. In the RL framework, the agent interacts with the environment. Specifically, given a state *s<sub>t</sub>* ∈ *S*, the agent performs an action *a<sub>t</sub>* ∈ *A* according to a *policy*; the environment then emits a reward *r<sub>t</sub>* and changes its state from *s<sub>t</sub>* to a new state *s*<sub>*t*+1</sub> according to the state transition probabilities *P*. In this process, the agent influences the environment by taking actions, and the environment feeds back the reward *r<sub>t</sub>* to the agent, guiding it to choose better actions. The goal of the agent is to improve its policy so as to maximize the cumulative future reward.

#### *3.2. Asynchronous Advantage Actor–Critic*

The RL agent must be trained with a training algorithm; in this work, we use the A3C algorithm [32], the asynchronous multi-threaded version of the AC algorithm [33]. The AC algorithm uses a policy network (the actor) to select actions and a value network (the critic) to evaluate them. The actor updates its policy (i.e., its action selection probabilities) according to the critic's evaluation. Through agent–environment interaction, the critic gradually improves its evaluation accuracy and the actor gradually improves its policy.

A3C makes the AC algorithm easier and faster to converge. It adopts a multi-threaded method in which each thread has an independent actor–critic pair interacting with its own copy of the environment. Each thread collects exploration experience from its environment copy and regularly uses it to update a shared global actor–critic pair.
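The "advantage" in A3C is typically computed from the short rollout each thread collects: discounted returns are accumulated backwards from the critic's estimate of the last state, and the critic's value of each visited state is then subtracted. The sketch below shows that computation for one rollout segment; the discount factor γ and the function name are assumptions, as neither is given in this section.

```python
from typing import List, Tuple


def nstep_returns_and_advantages(rewards: List[float], values: List[float],
                                 bootstrap_value: float,
                                 gamma: float = 0.95) -> Tuple[List[float], List[float]]:
    """n-step discounted returns R_t and advantages A_t = R_t - V(s_t) for one rollout.

    bootstrap_value is the critic's V(s_{t+n}) for the state after the rollout
    (0.0 if the episode ended); gamma is an assumed discount factor."""
    returns: List[float] = []
    running = bootstrap_value
    for r in reversed(rewards):          # accumulate backwards: R_t = r_t + gamma * R_{t+1}
        running = r + gamma * running
        returns.append(running)
    returns.reverse()
    advantages = [R - v for R, v in zip(returns, values)]
    return returns, advantages


print(nstep_returns_and_advantages([1.0, -1.0], [0.5, 0.0], 0.0, gamma=0.5))
# ([0.5, -1.0], [0.0, -1.0])
```

In the standard A3C update, the actor's gradient is ∇ log *π*(*a<sub>t</sub>*|*s<sub>t</sub>*) scaled by *A<sub>t</sub>*, and the critic is trained to reduce *A<sub>t</sub>*<sup>2</sup>; each thread applies these gradients to the shared global parameters.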

#### **4. Policy Distillation Design with EONs**

*4.1. Elastic Optical Networks*

In EONs, the RMSA problem is to establish end-to-end paths and allocate appropriate frequency slots (FSs) to traffic requests according to their data rate requirements. Furthermore, an RMSA algorithm [6] must satisfy the spectrum contiguity constraint and the spectrum continuity constraint. The topology of an EON can be denoted by a graph *G*(*V*, *E*), where *V* and *E* represent the sets of nodes and links, respectively. When a traffic request *TR*(*v<sub>s</sub>*, *v<sub>d</sub>*, *b*) arrives, RMSA is needed from the source node *v<sub>s</sub>* ∈ *V* to the destination node *v<sub>d</sub>* ∈ *V* with the required bandwidth *b*. The routing algorithm first calculates the candidate paths from the source to the destination and then selects one path *P<sub>v<sub>s</sub>,v<sub>d</sub></sub>* from the *K*-shortest paths. The number of FSs *n* required on the selected path *P<sub>v<sub>s</sub>,v<sub>d</sub></sub>* can then be calculated by Equation (1) and Table 1.

$$n = \left\lceil \frac{b}{W \cdot m(P\_{v\_s, v\_d})} \right\rceil + 1 \tag{1}$$

*W* denotes the spectrum width of each FS; *m*(*P<sub>v<sub>s</sub>,v<sub>d</sub></sub>*) ∈ {1, 2, 3, 4} corresponds to the modulation format selected according to the physical length of *P<sub>v<sub>s</sub>,v<sub>d</sub></sub>* [34]; and one FS is used as the guard band. The *n* allocated FSs must be contiguous (spectrum contiguity constraint), and each link along the path *P<sub>v<sub>s</sub>,v<sub>d</sub></sub>* must be assigned the same *n* contiguous FSs (spectrum continuity constraint).
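Equation (1) can be evaluated with a short function. Because Table 1 (transmission reach per modulation format) is not reproduced here, the sketch takes the modulation level *m* directly rather than a path length. The 12.5 GHz slot width in the example is an illustrative value, not one stated in the text, and *b* is assumed to be expressed in units consistent with *W* · *m* (e.g., Gb/s for *b* and GHz for *W*).

```python
import math


def required_slots(b: float, slot_width: float, modulation_level: int) -> int:
    """Equation (1): n = ceil(b / (W * m)) + 1, the extra FS being the guard band."""
    if modulation_level not in (1, 2, 3, 4):
        raise ValueError("m(P) must be in {1, 2, 3, 4}")
    return math.ceil(b / (slot_width * modulation_level)) + 1


# Hypothetical request: 100 units of bandwidth, 12.5-unit slots, modulation level 2.
print(required_slots(100, 12.5, 2))  # 5
```

A higher modulation level, which the text ties to shorter paths, packs more capacity into each FS and therefore lowers *n*: the same request needs 9 FSs at *m* = 1.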


**Table 1.** Transmission reach for different modulation formats [35].

#### *4.2. Policy Distillation Scheme*

We propose to integrate policy distillation into the RMSA problem of optical networks. The overall architecture is shown in Figure 1. Two models, namely the teacher model and the student model, are trained for different tasks. First, the teacher model is trained for one task with a specific traffic pattern and network topology. Then, the well-trained policy of the teacher model is distilled, and the knowledge is transferred to the student model, whose task has a different traffic pattern and network topology, to assist its training. There are three steps in the training process:


The RMSA policy for the student task is learned by the student model via Steps 2 and 3. Step 2 distills the knowledge from the well-trained policy network of the teacher model and transfers the knowledge to the student model to assist its training.

#### *4.3. State, Action, and Reward*

The optical network RMSA problem can be modeled as an MDP and solved in an RL-based framework. In the RL framework, the three essential elements are the state, the action, and the reward. We consider the state only when a new traffic request arrives. The state $s\_t$ is a $1 \times 5K$ vector containing spectrum utilization information on the *K*-shortest candidate paths of the traffic request [14]. For each candidate path, we consider five elements of spectrum utilization as follows:


In addition, the action of the RMSA problem is to choose one path from the *K* candidate paths and allocate spectrum on the selected path using the first-fit strategy. Therefore, the action $a\_t \in \{1, 2, \cdots, K\}$. The reward $r\_t$ is defined to be 1 when the traffic request is accepted, and −1 otherwise.
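The action and reward can be sketched as follows: given the path chosen by the agent and the $n$ FSs from Equation (1), first-fit picks the lowest-indexed block of $n$ adjacent FSs that is free on every link of the path, which enforces both the contiguity and the continuity constraints. This is a minimal illustration; the occupancy representation (one boolean list per directed link, keyed in path order) and the function names are assumptions, and the chosen path is passed in directly rather than as the index $a\_t$.

```python
from typing import Dict, List, Optional, Tuple

Link = Tuple[str, str]


def first_fit(path: List[str], n: int, occupied: Dict[Link, List[bool]]) -> Optional[int]:
    """Lowest start index of n contiguous FSs that are free on every link of the path."""
    links = list(zip(path[:-1], path[1:]))
    num_fs = len(occupied[links[0]])
    for start in range(num_fs - n + 1):
        slots = range(start, start + n)  # contiguity: n adjacent FSs
        # continuity: the same FSs must be free on all links of the path
        if all(not occupied[link][i] for link in links for i in slots):
            return start
    return None


def allocate(path: List[str], n: int, occupied: Dict[Link, List[bool]]) -> int:
    """Serve one request and return the reward r_t (+1 accepted, -1 blocked)."""
    start = first_fit(path, n, occupied)
    if start is None:
        return -1
    for link in zip(path[:-1], path[1:]):
        for i in range(start, start + n):
            occupied[link][i] = True
    return 1


# Toy two-link path with 10 FSs per link (the simulations use 100 FSs per link).
demo = {("A", "B"): [False] * 10, ("B", "C"): [False] * 10}
demo[("A", "B")][0] = True
demo[("B", "C")][3] = True
print(allocate(["A", "B", "C"], 2, demo))  # 1: FSs 1-2 are free on both links
```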

#### *4.4. Teacher Model*

According to Step 1 in Figure 1, a teacher model is first trained; this step is illustrated in more detail in Step 1 of Figure 2. We use DRL to train the teacher model and obtain an RMSA policy that optimizes EON resource management. The A3C algorithm is adopted for training: multiple local actor–critic pairs are trained in parallel by interacting with copies of the environment, and they periodically update the global actor–critic pair. The actor and critic are parameterized by two neural networks: the policy network $\pi(a\_t|s\_t; \theta\_{p,\mathrm{T}})$ and the value network $V(s\_t; \theta\_{v,\mathrm{T}})$. The policy network generates the RMSA policy, which is represented by a probability distribution over actions. The value network estimates the value of $s\_t$ and is used to evaluate the RMSA policy. $\mathrm{T}$ denotes the teacher model, and $\theta\_{p,\mathrm{T}}$ and $\theta\_{v,\mathrm{T}}$ are the parameters of the policy and value networks, respectively. The global parameters maintained by the A3C algorithm are denoted by $\theta^{\ast}\_{p,\mathrm{T}}$ and $\theta^{\ast}\_{v,\mathrm{T}}$.

**Figure 1.** Overview of the policy distillation design with EONs.

**Figure 2.** Detailed illustration of policy distillation design with EONs.

The details of the training process for the teacher model are given in Algorithm 1. First, we initialize the experience buffer $D$ to empty and set the initial exploration rate $\varepsilon$ to 1. In line 3, the parameters of each actor–critic thread are first synchronized with the global parameters. Note that for a general DRL task that can be modeled as a Markov decision process $\{S, A, R, P\}$, as mentioned in Section 3, the state transition from $s\_t$ to $s\_{t+1}$ follows a probability distribution $P$. However, for the RMSA task in this paper, the state space is extremely large, so state transitions are difficult to model. Therefore, the RMSA task here is treated as a model-free MDP and can only be optimized through samples. In lines 6–10, during sampling, we first input the $1 \times 5K$-dimensional state $s\_t$ into the policy and value networks. The policy network then outputs a $1 \times K$-dimensional probability distribution $\pi(a\_t|s\_t; \theta\_{p,\mathrm{T}})$, in which each probability lies between 0 and 1 and the $K$ probabilities sum to 1. The value network outputs a real-valued estimate $V(s\_t; \theta\_{v,\mathrm{T}})$. Finally, we store the sample $(s\_t, a\_t, r\_t, V(s\_t; \theta\_{v,\mathrm{T}}))$ generated by the interaction between the agent and the environment in the experience buffer $D$. When the experience buffer holds $2N - 1$ samples, we perform training on the first $N$ samples (lines 13–19). For each sample at time $t$, the advantage function is calculated in line 15. To obtain it, we first compute the cumulative discounted reward for this sample, considering only the $N$ consecutive samples starting from it and ignoring the discounted reward beyond them:

$$Q\_{\pi}(s\_t, a\_t; \theta\_{p,\mathrm{T}}) = \sum\_{i=0}^{N-1} \gamma^i r\_{t+i}, \quad t \in \{t\_0, \dots, t\_0 + N - 1\}, \tag{2}$$

where *γ* is the discount factor, 0 < *γ* < 1. Then, the advantage of each action taken can be obtained by,

$$A(s\_t, a\_t; \theta\_{p,\mathrm{T}}, \theta\_{v,\mathrm{T}}) = Q\_{\pi}(s\_t, a\_t; \theta\_{p,\mathrm{T}}) - V(s\_t; \theta\_{v,\mathrm{T}}). \tag{3}$$

Equation (3) indicates how much better the selected action is than the average, as estimated by the state value $V(s\_t; \theta\_{v,\mathrm{T}})$. Note that an episode is defined to consist of $N$ consecutive samples, where $N$ equals the batch size. In this way, all samples needed to calculate the advantage function can be found in the experience buffer [14].
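Equations (2) and (3) can be sketched in Python as follows. The sketch also makes explicit why training waits for $2N - 1$ samples: the truncated return of the $N$-th sample needs rewards up to index $2N - 2$. The function names are illustrative, and the values $V(s\_t)$ are assumed to be the ones stored with each sample in the buffer.

```python
from typing import List, Sequence


def nstep_returns(rewards: Sequence[float], n: int, gamma: float) -> List[float]:
    """Equation (2): Q_t = sum_{i=0}^{N-1} gamma^i * r_{t+i} for the first N samples."""
    # The N-th sample needs rewards up to index 2N - 2, hence a buffer of 2N - 1.
    if len(rewards) < 2 * n - 1:
        raise ValueError("the buffer must hold at least 2N - 1 samples")
    return [sum(gamma ** i * rewards[t + i] for i in range(n)) for t in range(n)]


def advantages(returns: Sequence[float], values: Sequence[float]) -> List[float]:
    """Equation (3): A_t = Q_t - V(s_t), using the values stored with each sample."""
    return [q - v for q, v in zip(returns, values)]


q = nstep_returns([1, -1, 1], n=2, gamma=0.5)
print(q)                            # [0.5, -0.5]
print(advantages(q, [0.25, -0.5]))  # [0.25, 0.0]
```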

Then, the objective function of the policy network $L\_{\theta\_{p,\mathrm{T}}}$ and the loss function of the value network $L\_{\theta\_{v,\mathrm{T}}}$ are used to calculate the gradients of the policy and value networks, and the global parameters $\theta^{\ast}\_{p,\mathrm{T}}$ and $\theta^{\ast}\_{v,\mathrm{T}}$ are updated according to these gradients (line 18). $L\_{\theta\_{p,\mathrm{T}}}$ and $L\_{\theta\_{v,\mathrm{T}}}$ can be expressed as follows:

$$\begin{split} L\_{\theta\_{p,\mathrm{T}}} &= -\sum\_{t=t\_0}^{t\_0+N-1} A(s\_t, a\_t; \theta\_{p,\mathrm{T}}, \theta\_{v,\mathrm{T}}) \log \pi(a\_t|s\_t; \theta\_{p,\mathrm{T}}) \\ &- \alpha \sum\_{t=t\_0}^{t\_0+N-1} \sum\_{a\_t \in \{1,2,\cdots,K\}} \pi(a\_t|s\_t; \theta\_{p,\mathrm{T}}) \log \pi(a\_t|s\_t; \theta\_{p,\mathrm{T}}), \end{split} \tag{4}$$

$$L\_{\theta\_{v,\mathrm{T}}} = \sum\_{t=t\_0}^{t\_0+N-1} A(s\_t, a\_t; \theta\_{p,\mathrm{T}}, \theta\_{v,\mathrm{T}})^2. \tag{5}$$

To increase the diversity of actions, the second term of $L\_{\theta\_{p,\mathrm{T}}}$ introduces the policy entropy to improve the agent's ability to explore the environment, and $\alpha$ controls the strength of this entropy regularization term. $\beta$ and $\eta$ are the learning rates of the policy and value networks, respectively.
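A plain-Python transcription of Equations (4) and (5) for one batch is given below. It evaluates the expressions exactly as printed, with actions indexed from 0 rather than 1; in an autodiff framework the advantage would usually be treated as a constant inside Equation (4). The function names and the input layout (one probability list per sample) are assumptions.

```python
import math
from typing import Sequence


def policy_loss(probs: Sequence[Sequence[float]], actions: Sequence[int],
                adv: Sequence[float], alpha: float) -> float:
    """Equation (4) as printed, summed over the batch."""
    # First term: -sum_t A_t * log pi(a_t | s_t)
    pg = -sum(a * math.log(p[act]) for p, act, a in zip(probs, actions, adv))
    # Second term: sum_t sum_a pi(a | s_t) * log pi(a | s_t), weighted by alpha
    plogp = sum(pi * math.log(pi) for p in probs for pi in p if pi > 0)
    return pg - alpha * plogp


def value_loss(adv: Sequence[float]) -> float:
    """Equation (5): sum of squared advantages over the batch."""
    return sum(a * a for a in adv)


print(value_loss([1.0, -2.0]))  # 5.0
```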

The stopping criterion is convergence of the model. Specifically, we track the average blocking probability; if the difference between consecutive average blocking probabilities is smaller than a predefined threshold, we regard the model as converged and the stopping criterion as satisfied. Through the above steps of Algorithm 1, we train a teacher model that learns an RMSA policy for a given task.
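The stopping rule can be sketched as a check on the sequence of average blocking probabilities. The threshold is not given in the text, so the value used here is a hypothetical example.

```python
from typing import Sequence


def has_converged(avg_blocking: Sequence[float], threshold: float) -> bool:
    """True when the last two average blocking probabilities differ by less than threshold."""
    if len(avg_blocking) < 2:
        return False
    return abs(avg_blocking[-1] - avg_blocking[-2]) < threshold


HYPOTHETICAL_THRESHOLD = 1e-3  # illustrative only; not specified in the paper
print(has_converged([0.12, 0.0500, 0.0504], HYPOTHETICAL_THRESHOLD))  # True
```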

#### **Algorithm 1** Training algorithm of the teacher model.

- $\theta\_{p,\mathrm{T}} \leftarrow \theta^{\ast}\_{p,\mathrm{T}}$, $\theta\_{v,\mathrm{T}} \leftarrow \theta^{\ast}\_{v,\mathrm{T}}$
- $\theta^{\ast}\_{p,\mathrm{T}} \leftarrow \theta^{\ast}\_{p,\mathrm{T}} - \beta \, d\theta\_{p,\mathrm{T}}$ and $\theta^{\ast}\_{v,\mathrm{T}} \leftarrow \theta^{\ast}\_{v,\mathrm{T}} - \eta \, d\theta\_{v,\mathrm{T}}$.

21: **end while**

#### *4.5. Student Model*

Owing to the similarities between tasks, we use the well-trained teacher model to "teach" the student model to learn the optimal RMSA policy for the student task, as shown in Step 2 of Figure 1. This process is described in more detail in Step 2 of Figure 2. In this way, the student model adjusts its training according to the experience knowledge of the teacher model, with the aim of faster training or better final performance.

Distillation is a method of transferring experience knowledge from a teacher model $\mathrm{T}$ to a student model $\mathrm{S}$. A straightforward way to transfer the knowledge is to minimize the distance between the outputs of the student and teacher models. Because the action probability distribution output by the policy network reflects the learned RMSA policy, we use cross-entropy to fit the outputs of the two models' policy networks. To transfer more knowledge, the teacher model can use a softer (higher-temperature) softmax than the one used during training [21]. Given a temperature $\tau$, the outputs of the teacher's and student's policy networks are passed through softmax functions to obtain the distributions $q\_{\tau}(s\_t, \theta\_{p,\mathrm{T}})$ and $q\_{\tau}(s\_t, \theta\_{p,\mathrm{S}})$:

$$q\_{\tau}(s\_t, \theta\_{p,\mathrm{T}}) = \mathrm{softmax}\left(\frac{\pi(a\_t|s\_t; \theta\_{p,\mathrm{T}})}{\tau}\right), \tag{6}$$

$$q\_{\tau}(s\_t, \theta\_{p,\mathrm{S}}) = \mathrm{softmax}\left(\frac{\pi(a\_t|s\_t; \theta\_{p,\mathrm{S}})}{\tau}\right). \tag{7}$$

The softmax(·) is defined by:

$$\mathrm{softmax}(\mathbf{z})\_i = \frac{e^{z\_i}}{\sum\_{j} e^{z\_j}}. \tag{8}$$
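Equations (6)–(8) amount to dividing the network outputs by $\tau$ before the softmax. The sketch below subtracts the maximum before exponentiating, a standard numerical-stability step that leaves the result of Equation (8) unchanged; the output values in the usage lines are made up.

```python
import math
from typing import List, Sequence


def softmax(z: Sequence[float]) -> List[float]:
    """Equation (8); shifting by max(z) avoids overflow without changing the result."""
    z_max = max(z)
    exps = [math.exp(v - z_max) for v in z]
    total = sum(exps)
    return [e / total for e in exps]


def soften(outputs: Sequence[float], tau: float) -> List[float]:
    """Equations (6)-(7): q_tau = softmax(outputs / tau)."""
    return softmax([o / tau for o in outputs])


outputs = [0.7, 0.2, 0.1]  # made-up policy-network outputs for K = 3
print(soften(outputs, 1.0))
print(soften(outputs, 5.0))  # higher tau gives a flatter distribution
```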

Algorithm 2 describes the training process of the student model in detail. The sampling part is the same as for the teacher model. When the training conditions are met, we first calculate the cumulative discounted reward for each sample, considering only the $N$ consecutive samples starting from it and ignoring the discounted reward beyond them:

$$Q\_{\pi}(s\_t, a\_t; \theta\_{p,\mathrm{S}}) = \sum\_{i=0}^{N-1} \gamma^i r\_{t+i}, \quad t \in \{t\_0, \dots, t\_0 + N - 1\}. \tag{9}$$

The advantage of each action can be calculated by:

$$A(s\_t, a\_t; \theta\_{p,\mathrm{S}}, \theta\_{v,\mathrm{S}}) = Q\_{\pi}(s\_t, a\_t; \theta\_{p,\mathrm{S}}) - V(s\_t; \theta\_{v,\mathrm{S}}). \tag{10}$$

Let $H(\cdot, \cdot)$ be the cross-entropy function. The similarity between the policy networks of the student and teacher models can then be increased by minimizing the following objective function:

$$L^{PD}\_{\theta\_{p,\mathrm{S}}} = \sum\_{t=t\_0}^{t\_0+N-1} H\left(q\_{\tau}(s\_t, \theta\_{p,\mathrm{T}}), q\_{\tau}(s\_t, \theta\_{p,\mathrm{S}})\right). \tag{11}$$
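A self-contained sketch of Equation (11) follows, with the teacher's softened distribution as the target in $H(p, q) = -\sum\_i p\_i \log q\_i$. The small epsilon guarding the logarithm, the function names, and the example outputs are assumptions for illustration.

```python
import math
from typing import List, Sequence


def soften(outputs: Sequence[float], tau: float) -> List[float]:
    """q_tau = softmax(outputs / tau), as in Equations (6)-(8)."""
    scaled = [o / tau for o in outputs]
    z_max = max(scaled)
    exps = [math.exp(v - z_max) for v in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(p: Sequence[float], q: Sequence[float], eps: float = 1e-12) -> float:
    """H(p, q) = -sum_i p_i log q_i."""
    return -sum(pi * math.log(max(qi, eps)) for pi, qi in zip(p, q))


def distillation_loss(teacher_out, student_out, tau: float) -> float:
    """Equation (11): sum over the batch of H(q_tau(teacher), q_tau(student))."""
    return sum(cross_entropy(soften(t, tau), soften(s, tau))
               for t, s in zip(teacher_out, student_out))


teacher = [[0.7, 0.2, 0.1]]
print(distillation_loss(teacher, [[0.7, 0.2, 0.1]], 5.0))  # minimum for this teacher
print(distillation_loss(teacher, [[0.1, 0.2, 0.7]], 5.0))  # larger
```

Because $H(p, q) \ge H(p, p)$, the loss is smallest when the student's softened outputs match the teacher's, which is the sense in which minimizing Equation (11) pulls the two policies together.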

During the distillation stage, although the student's value network does not obtain the teacher's experience knowledge directly through cross-entropy fitting, the output of the student's policy network, trained via policy distillation, affects the generation of samples and therefore indirectly affects the training of the value network.

The loss function $L\_{\theta\_{v,\mathrm{S}}}$ of the student model's value network during distillation is given by:

$$L\_{\theta\_{v,\mathrm{S}}} = \sum\_{t=t\_0}^{t\_0+N-1} A(s\_t, a\_t; \theta\_{p,\mathrm{S}}, \theta\_{v,\mathrm{S}})^2. \tag{12}$$

By optimizing the objective and the loss function above, we can transfer knowledge from the teacher model to the student model.

When the student model is initialized, its DRL agents start from tabula rasa: they have no prior knowledge of the optical network environment of the task and therefore need to explore the state and action spaces for a long time to learn the optimal RMSA policy. We therefore transfer the knowledge of the teacher model to the initially poorly performing student model through distillation, to reduce ineffective exploration by the student model.

However, although the teacher model is well trained for the teacher task, its policy is of limited use in guiding the training of the student model for the student task. Therefore, we apply policy distillation only for the first $M$ requests $TR(v\_s, v\_d, b)$ and then let the student model learn by itself, as shown in Step 3 of Figure 2. The objective and loss functions for the first $M$ traffic requests are given by Equations (11) and (12), and those for the subsequent requests are given by:

$$\begin{split} L\_{\theta^{-}\_{p,\mathrm{S}}} &= -\sum\_{t=t\_0}^{t\_0+N-1} A(s\_t, a\_t; \theta^{-}\_{p,\mathrm{S}}, \theta^{-}\_{v,\mathrm{S}}) \log \pi(a\_t|s\_t; \theta^{-}\_{p,\mathrm{S}}) \\ &- \alpha \sum\_{t=t\_0}^{t\_0+N-1} \sum\_{a\_t \in \{1,2,\cdots,K\}} \pi(a\_t|s\_t; \theta^{-}\_{p,\mathrm{S}}) \log \pi(a\_t|s\_t; \theta^{-}\_{p,\mathrm{S}}), \end{split} \tag{13}$$

$$L\_{\theta^{-}\_{v,\mathrm{S}}} = \sum\_{t=t\_0}^{t\_0+N-1} A(s\_t, a\_t; \theta^{-}\_{p,\mathrm{S}}, \theta^{-}\_{v,\mathrm{S}})^2, \tag{14}$$

where $\theta^{-}\_{p,\mathrm{S}}$ and $\theta^{-}\_{v,\mathrm{S}}$ are the parameters of the student model's policy and value networks, respectively, during self-learning.
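The switch between distillation and self-learning can be sketched as a simple rule on the number of requests served so far, with $M = 100{,}000$ as in Section 5.1. Counting requests from 0 and the function name are assumptions.

```python
M_DISTILL = 100_000  # M from Section 5.1


def student_phase(requests_served: int, m: int = M_DISTILL) -> str:
    """Which pair of losses applies to the request with this (0-based) index."""
    if requests_served < m:
        return "distill"     # Equations (11) and (12)
    return "self-learn"      # Equations (13) and (14)


print(student_phase(0), student_phase(99_999), student_phase(100_000))
# distill distill self-learn
```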


In this section, we present the simulation results of the proposed policy distillation design for EONs. We apply the proposed method to three scenarios: (1) policy distillation between different traffic patterns, (2) policy distillation between different topologies, and (3) policy distillation between different traffic patterns and topologies.

#### *5.1. Parameter Settings*

The common parameters used in the simulations are described below. These common parameters are used in the simulations of Sections 5.2–5.5 unless otherwise specified. For convenience, the symbols of the key common parameters, together with their meanings and values, are listed in Table 2.


**Table 2.** Key parameters and their corresponding meanings and values.

All the topologies used in the simulations are shown in Figure 3, where the weight of each edge represents the physical length of the corresponding link; these lengths determine the modulation format and hence the number of FSs in Equation (1). We set the capacity of each fiber link to 100 FSs. Traffic requests are generated according to independent Poisson processes. To ensure that the blocking probabilities of the different topologies fall within a reasonable range, we set a different traffic load for each topology. The traffic patterns and loads for the different simulation scenarios are described in detail later. In addition, the bandwidth requirement of each traffic request is uniformly distributed within [25, 100] Gb/s. The number of candidate shortest paths $K$ is set to 5, which means that the DRL agent selects a path from 5 candidate paths.
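A traffic generator matching this description could look like the sketch below: exponential inter-arrival times give a Poisson arrival process, and bandwidths are drawn from [25, 100] Gb/s. The exponential holding-time distribution, the continuous bandwidth draw, the uniform choice of node pairs, and all names are assumptions not stated in the text.

```python
import random
from dataclasses import dataclass
from typing import List, Sequence


@dataclass
class Request:
    arrival: float         # arrival time
    holding: float         # service (holding) time
    src: int
    dst: int
    bandwidth_gbps: float


def generate_requests(num: int, arrival_rate: float, mean_service: float,
                      nodes: Sequence[int], seed: int = 0) -> List[Request]:
    rng = random.Random(seed)
    t = 0.0
    requests = []
    for _ in range(num):
        t += rng.expovariate(arrival_rate)             # Poisson arrivals
        holding = rng.expovariate(1.0 / mean_service)  # assumed exponential
        src, dst = rng.sample(list(nodes), 2)          # distinct, uniformly chosen
        bw = rng.uniform(25.0, 100.0)                  # assumed continuous
        requests.append(Request(t, holding, src, dst, bw))
    return requests


# Uniform NSFNET setting of Section 5.2: 12 arrivals per time unit, mean service 16.
reqs = generate_requests(5, arrival_rate=12, mean_service=16, nodes=range(14))
```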

In terms of neural network architecture, for the teacher model, the policy and value networks both have five hidden layers with 256 neurons per layer. For the student model, the policy and value networks both have five hidden layers with 128 neurons per layer. ReLU is used as the activation function for the hidden layers. We set the discount factor $\gamma$, the learning rates $\beta$ and $\eta$, the coefficient of the entropy regularization term $\alpha$, and the distillation temperature $\tau$ to 0.95, $1 \times 10^{-5}$, $1 \times 10^{-5}$, 0.01, and 5, respectively. In addition, the number of traffic requests for distillation $M$ is 100,000. During training, mini-batch gradient descent with the Adam optimizer is used, with a mini-batch size $N$ of 200. The exploration rate $\varepsilon$ is set to 1 at the beginning and decreases by $\varepsilon\_0 = 10^{-5}$ in each training step until it reaches $\varepsilon\_{min} = 0.05$.
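With these settings, the exploration rate follows a linear schedule. Assuming one decrement per training step, the sketch below computes $\varepsilon$ after a given number of steps; the floor of 0.05 is reached after $(1 - 0.05)/10^{-5} = 95{,}000$ steps.

```python
def epsilon_after(steps: int, eps_start: float = 1.0, eps_decay: float = 1e-5,
                  eps_min: float = 0.05) -> float:
    """Linear decay by eps_decay per training step, floored at eps_min."""
    return max(eps_min, eps_start - eps_decay * steps)


steps_to_floor = round((1.0 - 0.05) / 1e-5)
print(steps_to_floor)  # 95000
```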

**Figure 3.** Optical network topologies: (**a**) 8-node, (**b**) 14-node NSFNET, (**c**) 11-node COST 239, and (**d**) 24-node US Backbone.

#### *5.2. Policy Distillation for Different Traffic Patterns*

We first evaluate the performance of the proposed scheme for different traffic patterns on the same network topology. In this subsection, both the teacher and the student models are trained on the 14-node NSFNET topology, but with different traffic patterns. The model trained under a uniformly distributed traffic pattern serves as the teacher model, and the models trained under non-uniformly distributed traffic patterns serve as the student models.

The traffic pattern is denoted by a $|V| \times |V|$ matrix $TP$, where $|V| = 14$ is the number of nodes of the NSFNET. The element $TP\_{ij}$ represents the traffic load ratio from node $i$ to node $j$, with $TP\_{ij} = 0$ when $i = j$. If $TP\_{ij}$ is the same for all pairs with $i \neq j$, the traffic pattern is uniformly distributed; otherwise, it is non-uniformly distributed. For the student models, we designed three different non-uniform traffic patterns, namely pattern A, pattern B, and pattern C, as shown in Figure 4a,c,e. They correspond to the following three settings:


For the uniform traffic pattern, the arrival rate is 12 arrivals per time unit and the average service time is 16 time units, while for the non-uniform traffic patterns, the arrival rate is 16 arrivals per time unit and the average service time is 25 time units. Table 3 lists the traffic loads for all the traffic patterns in Section 5.2.
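One way to realize a traffic pattern matrix $TP$ in a simulator is to draw each request's source–destination pair with probability proportional to $TP\_{ij}$. This interpretation of the "traffic load ratio", the helper names, and the 3-node example matrices are assumptions for illustration.

```python
import random
from typing import Sequence, Tuple


def sample_pair(tp: Sequence[Sequence[float]], rng: random.Random) -> Tuple[int, int]:
    """Draw (i, j) with i != j and probability proportional to TP[i][j]."""
    n = len(tp)
    pairs = [(i, j) for i in range(n) for j in range(n) if i != j]
    weights = [tp[i][j] for i, j in pairs]
    return rng.choices(pairs, weights=weights, k=1)[0]


def is_uniform(tp: Sequence[Sequence[float]]) -> bool:
    """Uniform pattern: TP[i][j] is identical for every pair with i != j."""
    n = len(tp)
    return len({tp[i][j] for i in range(n) for j in range(n) if i != j}) == 1


uniform_tp = [[0 if i == j else 1 for j in range(3)] for i in range(3)]
skewed_tp = [[0, 3, 1], [1, 0, 1], [1, 1, 0]]  # pair 0 -> 1 gets 3x the share of any other pair
rng = random.Random(0)
print(is_uniform(uniform_tp), is_uniform(skewed_tp))  # True False
print(sample_pair(skewed_tp, rng))
```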

Figure 4b,d,f show how the blocking probability evolves as the number of requests increases, with the blocking probability calculated every 1000 $TR(v\_s, v\_d, b)$ requests. The blue lines represent the blocking probabilities of agents learning from scratch without policy distillation ("w/o PD"), while the red lines represent those of agents that learn with the policy distilled from the teacher model trained under the uniform traffic pattern ("PD-14-Node-uniform"). The green lines represent the blocking probabilities of the baseline algorithm, K-shortest-path routing with first-fit spectrum allocation (KSP-FF) [36]. The "KSP-FF" curves in Figure 4b,d,f are the results of applying the KSP-FF algorithm to pattern A, pattern B, and pattern C of the 14-node NSFNET topology, respectively. With policy distillation ("PD-14-Node-uniform"), the agent converges faster and achieves lower blocking probabilities than without policy distillation ("w/o PD"). Specifically, the blocking probability is reduced by 10%, 10.7%, and 3.6% for pattern A, pattern B, and pattern C, respectively. These results indicate that policy distillation is effective when the traffic pattern changes.
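The blocking-probability curves can be computed by splitting the sequence of request outcomes into consecutive windows of 1000 requests. A minimal sketch, assuming non-overlapping windows and dropping any incomplete final window; the outcome sequence is made up.

```python
from typing import List, Sequence


def windowed_blocking(accepted: Sequence[bool], window: int = 1000) -> List[float]:
    """Blocking probability of each complete, non-overlapping window of requests."""
    result = []
    for start in range(0, len(accepted) - window + 1, window):
        chunk = accepted[start:start + window]
        result.append(sum(1 for ok in chunk if not ok) / window)
    return result


outcomes = [True] * 900 + [False] * 100 + [True] * 1000 + [False] * 10
print(windowed_blocking(outcomes))  # [0.1, 0.0]; the last 10 requests are dropped
```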

**Figure 4.** (**a**,**c**,**e**): The non-uniform traffic patterns for the student models. (**b**,**d**,**f**): Blocking probabilities under different traffic patterns ((**b**) pattern A, (**d**) pattern B, and (**f**) pattern C) for student model with policy distillation, student model without policy distillation, and the baseline KSP-FF algorithm.

**Table 3.** Traffic loads for all traffic patterns in Section 5.2.


#### *5.3. Policy Distillation for Different Topologies*

We have also conducted simulations with different topologies to evaluate the performance of the policy distillation scheme. In this case, we train two teacher models, on the 8-node topology and on the 14-node NSFNET topology, while the other two topologies (the 11-node COST 239 topology and the 24-node US Backbone topology) are used for training the student models. All the teacher and student models use the same (uniform) traffic distribution. For the 8-node, 11-node COST 239, 14-node NSFNET, and 24-node US Backbone topologies, the arrival rates are 14, 16, 12, and 12 arrivals per time unit, and the average service times are 25, 25, 16, and 14 time units, respectively. Table 4 lists the traffic loads for all the traffic patterns in Section 5.3.
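If the offered load is taken in Erlangs, that is, arrival rate times mean service time (the standard definition; Table 4's values are not reproduced here, and whether it uses this convention is an assumption), the four uniform-traffic settings correspond to the loads computed below.

```python
# (arrival rate per time unit, mean service time in time units), Section 5.3
SETTINGS = {
    "8-node": (14, 25),
    "11-node COST 239": (16, 25),
    "14-node NSFNET": (12, 16),
    "24-node US Backbone": (12, 14),
}

loads_erlang = {name: rate * service for name, (rate, service) in SETTINGS.items()}
for name, load in loads_erlang.items():
    print(f"{name}: {load} Erlang")
# 8-node: 350, COST 239: 400, NSFNET: 192, US Backbone: 168
```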


**Table 4.** Traffic loads for all traffic patterns in Section 5.3.

Figure 5a,b show the evolution of the blocking probability of the student models trained on different topologies. We denote the agents that learn with the policy distilled from the teacher models for the 8-node topology and the 14-node NSFNET as "PD-Eight-Node" and "PD-14-Node", respectively. The KSP-FF algorithm is adopted as the baseline; the "KSP-FF" results in Figure 5a,b are obtained by applying it to the uniform-traffic training environments of the 11-node COST239 and 24-node US Backbone topologies, respectively. Figure 5a shows that, for the student model trained in the 11-node COST239 environment, the cases with policy distillation ("PD-Eight-Node" and "PD-14-Node") reach the performance level of "KSP-FF" faster than the case without policy distillation ("w/o PD"). Specifically, the blocking performance of "PD-Eight-Node" and "PD-14-Node" matches that of "KSP-FF" after about 150,000 and 244,000 traffic requests, respectively, whereas "w/o PD" remains worse than "KSP-FF" throughout the first 1,000,000 traffic requests.

Similar results are observed in Figure 5b, where the student model is trained in the 24-node US Backbone environment. Moreover, Figure 5a,b show that the cases with policy distillation ("PD-Eight-Node" and "PD-14-Node") have lower blocking probabilities after convergence than the case without policy distillation ("w/o PD"). These results show that when the topology changes, policy distillation can assist policy learning in the new environment. Figure 5c,d show the complementary cumulative distribution function (CCDF) of the blocking reduction relative to "KSP-FF" for the different schemes after training with 750,000 traffic requests. For the COST 239 topology, "PD-Eight-Node" and "PD-14-Node" outperform "KSP-FF" in around 54% and 52% of the cases, respectively, while "w/o PD" outperforms "KSP-FF" in only around 33% of the cases. For the US Backbone topology, "PD-Eight-Node" and "PD-14-Node" outperform "KSP-FF" in around 55.8% and 46.3% of the cases, respectively, while "w/o PD" does so in around 29.5% of the cases. These results indicate the effectiveness of policy distillation.
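The percentages above can be read as values of the CCDF at zero, i.e., the fraction of cases whose blocking reduction relative to KSP-FF is positive. A minimal empirical-CCDF sketch follows; what constitutes a "case" (here, one evaluation window) and the sample values are hypothetical, not the paper's data.

```python
from typing import Sequence


def ccdf(samples: Sequence[float], x: float) -> float:
    """Empirical CCDF: fraction of samples strictly greater than x."""
    return sum(1 for s in samples if s > x) / len(samples)


# Hypothetical per-case blocking reductions (KSP-FF blocking minus scheme blocking).
reductions = [-0.004, 0.002, 0.010, 0.0, 0.006, -0.001]
print(ccdf(reductions, 0.0))  # 0.5: three of six cases beat KSP-FF
```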

**Figure 5.** (**a**,**b**): Blocking probability in training with different topologies, and (**c**,**d**): complementary cumulative distribution function (CCDF) of the blocking reduction relative to the KSP-FF algorithm after training with 750,000 traffic requests.

#### *5.4. Policy Distillation for Different Traffic Patterns and Topologies*

In this subsection, we change both the traffic patterns and the network topologies for policy distillation. Similar to Section 5.3, two teacher models are trained on the 8-node topology and the 14-node NSFNET topology, while the student models are trained on the 11-node COST 239 topology and the 24-node US Backbone topology. In addition, the teacher models are trained under uniform traffic patterns, while the student models are trained under non-uniform traffic patterns. We have conducted four sets of simulations, denoted as Simulations T-1 to T-4. The detailed simulation settings of the student models are shown in Table 5, and the traffic loads of all the traffic patterns in Section 5.4 are shown in Table 6.

The simulation results are shown in Figure 6a–d. First, compared with the case without policy distillation ("w/o PD"), distilling the policy from a teacher trained on the eight-node topology with uniform traffic ("PD-Eight-Node") or from a teacher trained on the NSFNET topology with uniform traffic ("PD-14-Node") effectively accelerates the training of the student models and yields lower blocking probabilities in all simulations. Specifically, "PD-Eight-Node" achieves blocking reductions of 8.3%, 11.9%, 7.8%, and 9.8% in Simulations T-1 to T-4, respectively, and "PD-14-Node" achieves blocking reductions of 7.5%, 11%, 3.9%, and 2.4%, respectively. Meanwhile, Table 7 lists the approximate training time each scheme needs for its blocking performance to reach the level of "KSP-FF" in Simulations T-1 to T-4. In this section, the "KSP-FF" curves in Figure 6a–d are the results of applying the KSP-FF algorithm to the training environments of Simulations T-1 to T-4, respectively. "PD-Eight-Node" and "PD-14-Node" learn faster: in Simulations T-1 to T-4, the training time that "PD-Eight-Node" needs to reach the blocking performance of KSP-FF is reduced by 31.4%, 14%, 57%, and 60.3%, respectively, compared with "w/o PD". A similar trend can be seen between "PD-14-Node" and "w/o PD".


**Table 5.** Simulation settings for the student models in Section 5.4.

**Table 6.** Traffic loads for all traffic patterns in Section 5.4.


**Table 7.** Training duration when performance reaches KSP-FF (in seconds).


For all of the above simulations, we use only the KSP-FF heuristic algorithm as the baseline. As the figures show, some DRL-based approaches achieve only performance comparable to that of KSP-FF. We believe that the performance of such DRL-based approaches is limited by the reward design. Our work [37] investigated reward design and achieved significantly lower blocking probabilities than KSP-FF. However, the focus of this paper is not reward design; we focus on comparing performance before and after introducing knowledge distillation. The above simulations show that blocking performance can be improved by integrating the knowledge distillation method.

**Figure 6.** Blocking probability of different topologies with different non-uniform traffic patterns.

#### *5.5. Policy Distillation with Different Neural Network Size of the Teacher Model*

We have also investigated the effect of the size of the teacher model's neural network on the performance of the proposed policy distillation design. Specifically, we design three different neural network settings for the teacher model: (1) three hidden layers with 64 neurons per layer (3 × 64), (2) five hidden layers with 128 neurons per layer (5 × 128), and (3) eight hidden layers with 256 neurons per layer (8 × 256). The teacher model is trained under the uniform traffic pattern over the 14-node NSFNET, and the student models are trained under the uniform traffic pattern over the COST239 topology. The arrival rate and average service time are the same as in Section 5.2. The resulting blocking probabilities are shown in Figure 7.
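To give a sense of how different these teacher networks are, the sketch below counts the weights and biases of a fully connected policy network with input dimension $5K = 25$ and $K = 5$ outputs. Plain dense layers with biases and a policy head of size $K$ are assumptions, since the text does not detail the output layers.

```python
from typing import List


def mlp_params(in_dim: int, width: int, depth: int, out_dim: int) -> int:
    """Weights plus biases of an MLP with `depth` hidden layers of size `width`."""
    sizes: List[int] = [in_dim] + [width] * depth + [out_dim]
    return sum(a * b + b for a, b in zip(sizes[:-1], sizes[1:]))


K = 5
for depth, width in [(3, 64), (5, 128), (8, 256)]:
    print(f"{depth} x {width}: {mlp_params(5 * K, width, depth, K)} parameters")
# 3 x 64: 10309, 5 x 128: 70021, 8 x 256: 468485
```

The hidden layers differ by more than an order of magnitude in size, while the input layer (5*K) stays fixed, consistent with the requirement that teacher and student share the same state dimension.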

**Figure 7.** Blocking probability in training with different size of teacher model's neural network.

The results show that teacher models with different neural network sizes (PD-14-Node (3 × 64), PD-14-Node (5 × 128), and PD-14-Node (8 × 256)) can all carry out policy distillation to the student models. This shows that the proposed policy distillation scheme does not require the teacher network to have a particular size. Policy learning with policy distillation can also be carried out when the neural network architectures of the teacher and student models differ, which allows knowledge transfer in a broader context.

#### **6. Conclusions**

This paper proposes a deep reinforcement learning-based RMSA policy distillation design for elastic optical networks. It allows knowledge to be transferred from a well-trained teacher model in one training environment to a student model in a different environment, so that the student's training is accelerated and its final performance is improved. One highlight is that the student and teacher models can have different neural network architectures, which allows knowledge transfer in a broader context. The method is validated by simulations of policy distillation across different traffic patterns and network topologies.

One limitation of our proposal is that the input dimensions of the teacher and student models must be the same. Since the input represents the state of the elastic optical network, this limitation constrains the state representation. Removing this limitation is left for future work. In addition, the performance of the learned RMSA policy in real optical networks should be studied experimentally in future work.

**Author Contributions:** B.T.: Conceptualization, Methodology, Software, Writing—original draft. Y.-C.H.: Conceptualization, Validation, Writing—review & editing. Y.X.: Conceptualization, Writing review. W.Z.: Supervision, Writing—review. All authors have read and agreed to the published version of the manuscript.

**Funding:** This research was funded by the National Natural Science Foundation of China (62006084), the Basic and Applied Basic Research Foundation of Guangdong Province (2020A1515111110), and the Guangdong Science and Technology Department (2016A010101020, 2016A010101021, and 2016A010101022).

**Institutional Review Board Statement:** Not applicable.

**Informed Consent Statement:** Not applicable.

**Data Availability Statement:** Not applicable.

**Acknowledgments:** The authors would like to express their sincere thanks to the Editors and Referees for their enthusiastic guidance and help.

**Conflicts of Interest:** The authors declare no conflict of interest.

#### **References**


### *Article* **Syntactically Enhanced Dependency-POS Weighted Graph Convolutional Network for Aspect-Based Sentiment Analysis**

**Jinjie Yang <sup>1</sup>, Anan Dai <sup>1</sup>, Yun Xue <sup>1</sup>, Biqing Zeng <sup>2</sup> and Xuejie Liu <sup>1,\*</sup>**


**Abstract:** Aspect-based sentiment analysis (ABSA) is a fine-grained sentiment analysis task that offers great benefits to real-world applications. Recently, methods applying graph neural networks over dependency trees have become popular, but most of them consider only whether dependencies exist between words and ignore the types of these dependencies, which carry important information, since dependencies of different types have different effects. In addition, they neglect the correlations between dependency types and part-of-speech (POS) labels, which are helpful for utilizing dependency information. To address these limitations and the insufficient mining of syntactic and semantic features, we propose a novel model containing three modules, which aims to leverage dependency trees more reasonably by distinguishing different dependencies and extracting beneficial syntactic and semantic features to further enhance model performance. To enrich word embeddings, we design a syntactic feature encoder (SynFE). In particular, we design a Dependency-POS Weighted Graph Convolutional Network (DPGCN) that weights different dependencies with a graph attention mechanism we propose. Additionally, to capture aspect-oriented semantic information, we design a semantic feature extractor (SemFE). Extensive experiments on five popular benchmark datasets validate that our model can better employ dependency information and effectively extract favorable syntactic and semantic features to achieve new state-of-the-art performance.

**Keywords:** aspect-based sentiment analysis; graph neural networks; dependency trees; dependency types; graph attention mechanism; syntactic; semantic

**MSC:** 18C50

#### **1. Introduction**

Aspect-based sentiment analysis (ABSA) is a popular topic in natural language processing whose purpose is to identify the sentiment polarities (i.e., positive, neutral and negative) toward specific aspects in given sentences. Take the review "*The food in this restaurant is delicious but the service is terrible*" as an example. For the aspect "*food*", the polarity is positive, while it is negative for the aspect "*service*". An ABSA model aims to infer the sentiment polarities of the given aspects accurately at a fine-grained level.

The key to solving the ABSA task is to properly identify the relations between aspects and their corresponding opinion words. Early works combined recurrent neural networks (RNNs) with the attention mechanism [1–5] to capture aspect-related semantic information and generate aspect-specific sentence representations. However, these methods are vulnerable to noise introduced by unrelated words. They also ignore the syntactic dependency information in sentences, which makes it difficult for them to link aspects to opinion words that are far away in the word sequence. More recent works [6–10] applied graph-based networks such as graph convolutional networks (GCNs) and graph attention networks (GATs) over dependency trees to explicitly exploit syntactic dependency information, achieving better performance. However, these methods treat all dependencies equally, without weighting them according to their types.

**Citation:** Yang, J.; Dai, A.; Xue, Y.; Zeng, B.; Liu, X. Syntactically Enhanced Dependency-POS Weighted Graph Convolutional Network for Aspect-Based Sentiment Analysis. *Mathematics* **2022**, *10*, 3353. https://doi.org/10.3390/math10183353

Academic Editors: Jianping Gou, Weihua Ou, Shaoning Zeng and Lan Du

Received: 2 August 2022 Accepted: 13 September 2022 Published: 15 September 2022

**Publisher's Note:** MDPI stays neutral with regard to jurisdictional claims in published maps and institutional affiliations.

**Copyright:** © 2022 by the authors. Licensee MDPI, Basel, Switzerland. This article is an open access article distributed under the terms and conditions of the Creative Commons Attribution (CC BY) license (https://creativecommons.org/licenses/by/4.0/).

Linguistically, dependencies of different types have individual significance: some dependencies between words benefit the ABSA task, while others introduce noise that hurts model performance. As shown in Figure 1, the aspect "*food*" has dependencies with two other words, "*the*" and "*delicious*". Clearly, the opinion word "*delicious*" is more important for the sentiment analysis of "*food*", while "*the*" carries no explicit sentiment information. Thus, the dependency type "*nsubj*", which means that "*food*" is the nominal subject of "*delicious*", should receive more attention weight than "*det*", which only means that "*the*" is a determiner of "*food*". Modeling dependency types properly is therefore necessary for advancing the ABSA task. Meanwhile, previous GCN-based models omitted the correlations between dependency types and POS labels. For example, after a comprehensive investigation of the datasets, we find that the important dependency type "*nsubj*" frequently connects words with noun and adjective labels such as "*JJ*", "*NN*" and "*NNS*". A sentence usually contains many words, but only a few of them are valuable for the sentiment analysis of a given aspect. From this investigation, we conclude that opinion words are typically adjectives or verbs, because words with these POS labels usually carry clear sentiment information. Therefore, incorporating both dependency information and POS information is a promising way to improve ABSA.

To tackle the above limitations and improve model performance, we propose a novel model with three effective modules: SynFE, DPGCN and SemFE. In particular, the main module, DPGCN, incorporates dependency information and POS information to weight different dependencies, while SynFE and SemFE supplement the aspects with syntactic and semantic features to further improve model performance.

Our contributions are summarized as follows:


**Figure 1.** An example sentence with its dependency tree and POS labels.

#### **2. Related Work**

With the rapid development of deep learning, deep neural models have been applied to this task. Many early attention-based neural network models [1–5] achieved promising performance and attracted considerable attention. Ref. [1] proposed an attention-based LSTM that focuses on the key parts of sentences to obtain contextual representations. Refs. [2,3] introduced a memory network with an attention mechanism to extract aspect-related sentiment information. Ref. [4] utilized a multi-grained attention mechanism to capture word-level interactions between aspects and contexts. Ref. [5] exploited an attention-over-attention network to learn aspect representations and sentence representations jointly. In addition, the pre-trained language model BERT [11] has achieved remarkable performance in a number of NLP tasks, including ABSA. Ref. [12] transformed the ABSA task into sentence–aspect pair classification and achieved excellent performance by fine-tuning the BERT model.

Early works overlooked the usefulness of syntactic knowledge for the ABSA task. To explicitly exploit syntactic dependency information, ref. [13] proposed an attention model that uses syntactic dependency information to obtain attention weights, and ref. [14] introduced the syntactic relative distance to reduce the negative effects of words that are weakly related to aspects. Graph convolutional networks (GCNs) [15] have achieved strong performance in many NLP tasks, including ABSA, and applying GCN-based models over dependency trees became a new trend that produced several outstanding models. Refs. [6,7] applied GCNs over dependency trees to capture syntactic dependency information for aspects. Ref. [8] exploited word co-occurrence information, building a hierarchical syntactic graph and lexical graph for graph convolution. Ref. [16] proposed a relational graph attention network (R-GAT) to encode reshaped, aspect-oriented dependency trees for sentiment analysis. Ref. [17] designed DualGCN, which includes SynGCN and SemGCN to extract syntactic and semantic information for aspects, respectively.

#### **3. Method**

In this section, we describe our proposed model in detail. The overall structure of our model is shown in Figure 2, and the details of SynFE and DPGCN are depicted in Figure 3.

Our model mainly consists of three modules: (1) **SynFE**, which encodes the dependency information and POS information of the sentences to enrich word-level vector representations; (2) **DPGCN**, which captures the correlations between dependency types and POS labels to weight dependencies of different types; and (3) **SemFE**, which extracts semantic features from the overall sentence to supplement sentiment features for aspect representations. Each component is presented in detail below, and its contribution is analyzed in the experiments.

**Figure 2.** The overall structure of our proposed model.

#### *3.1. Problem Definition (ABSA)*

Given an $n$-word review sentence $S = \{w\_1, w\_2, \cdots, w\_{a+1}, \cdots, w\_{a+m}, \cdots, w\_n\}$ containing an $m$-word aspect $A = \{w\_{a+1}, \cdots, w\_{a+m}\}$, ABSA aims to identify the sentiment polarity (i.e., positive, neutral or negative) of the given aspect in the sentence. If a sentence contains more than one aspect, our model processes the sentence once per aspect, outputting the sentiment polarity of one aspect at a time.

**Figure 3.** The details of (**a**) SynFE and (**b**) DPGCN.

#### *3.2. Initial Embedding Module*

The pre-trained language model BERT provides word embeddings with rich feature information; thus, we construct a sentence–aspect pair $(S, A)$ as the input of BERT to initialize aspect-aware word vectors, using the input form "*[CLS] sentence [SEP] aspect [SEP]*", where '*CLS*' is a special token for encoding the overall sentence-level representation and '*SEP*' is a separator between the sentence and the aspect. The calculations in BERT are as follows:

$$\{h^{CLS}, H^S, H^A\} = \mathrm{BERT}(\{[CLS], S, [SEP], A, [SEP]\})\tag{1}$$

where $h^{CLS} \in \mathbb{R}^{d\_a}$ is the overall sentence-level representation; $H^S = \{h\_1, h\_2, \cdots, h\_n\} \in \mathbb{R}^{n \times d\_a}$ are the word-level representations of the sentence, where $n$ is the number of words in the sentence and $d\_a$ is the dimension of each word vector; and $H^A = \{h\_{a+1}, \cdots, h\_{a+m}\} \in \mathbb{R}^{m \times d\_a}$ are the aspect representations. We use only $H^S$, which contains both the aspect representations and the contextual representations of the sentence.
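Because a sentence with several aspects is processed once per aspect and each pass feeds BERT the pair form of Equation (1), the input construction can be sketched as below. This is a minimal illustration under stated assumptions: the helper `build_pair_inputs` is hypothetical, and whitespace splitting stands in for BERT's WordPiece tokenizer, so real token positions would differ.

```python
def build_pair_inputs(sentence, aspects):
    """Build one "[CLS] sentence [SEP] aspect [SEP]" token sequence per aspect.

    Words are split on whitespace, a simplification: a real BERT tokenizer
    would further split them into WordPiece sub-tokens.
    Returns (tokens, sentence_span, aspect_span) tuples; spans are half-open
    index ranges into `tokens`.
    """
    words = sentence.split()
    samples = []
    for aspect in aspects:
        aspect_words = aspect.split()
        tokens = ["[CLS]"] + words + ["[SEP]"] + aspect_words + ["[SEP]"]
        sentence_span = (1, 1 + len(words))        # positions that form H^S
        aspect_start = len(words) + 2              # just after the first [SEP]
        aspect_span = (aspect_start, aspect_start + len(aspect_words))
        samples.append((tokens, sentence_span, aspect_span))
    return samples


review = "The food in this restaurant is delicious but the service is terrible"
for tokens, s_span, a_span in build_pair_inputs(review, ["food", "service"]):
    print(" ".join(tokens))
```

Only the tokens inside `sentence_span` would be kept as $H^S$, matching the choice above to discard $h^{CLS}$ and $H^A$.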

#### *3.3. Syntactic Feature Encoder (SynFE)*

The quality of textual representations is critical to NLP tasks. To enrich the word representations of aspects and contexts, we encode syntactic information (i.e., dependency information and POS information) and fuse it into the word representations.

Based on the structure of dependency trees, we construct a key–value network to learn syntax-aware representations. In detail, our module obtains the dependency trees of the given sentences from an off-the-shelf NLP toolkit (StanfordCoreNLP). Dependencies are mapped to a key set $K$, while dependency types and POS labels are mapped to a dependency value set $V^D$ and a POS value set $V^P$, respectively. As illustrated in Figure 1, each word has dependencies with other words. For $w\_i$ (i.e., the $i$-th word in the sentence), we map its dependencies and the corresponding dependency types to a key set $K\_i = \{k\_{i,1}, k\_{i,2}, \cdots, k\_{i,n}\}$ in $K$ and a value set $V^D\_i = \{v^d\_{i,1}, v^d\_{i,2}, \cdots, v^d\_{i,n}\} \in \mathbb{R}^{n \times d\_b}$ in $V^D$, respectively. Each element of $K\_i$ is the weight of the corresponding dependency, and $k\_{i,j} = 0$ if there is no dependency between $w\_i$ and $w\_j$. Each element of $V^D\_i$ is the embedding vector of the corresponding dependency type; for example, $v^d\_{i,j} \in \mathbb{R}^{d\_b}$ is a $d\_b$-dimensional embedding vector for the dependency type between $w\_i$ and $w\_j$. In particular, the type is denoted as "*none*" if there is no dependency between two words, and as "*self*" between a word and itself. $V^P = \{v^p\_1, v^p\_2, \cdots, v^p\_n\} \in \mathbb{R}^{n \times d\_c}$ are the embedding vectors of the POS labels of the words in the sentence. The embedding vectors in $V^D$ and $V^P$ are randomly initialized and trainable, whereas the weights in $K$ are calculated by:

$$k\_{i,j} = \frac{a\_{i,j} \* \exp\left(h\_i h\_j^T\right)}{\sum\_{k=1}^n a\_{i,k} \* \exp\left(h\_i h\_k^T\right)}\tag{2}$$

where $T$ denotes vector transpose; $a\_{i,j} = 1$ if a dependency exists between $w\_i$ and $w\_j$ and $a\_{i,j} = 0$ otherwise; $\ast$ is element-wise multiplication; and $h\_i \in \mathbb{R}^{d\_a}$ and $h\_j \in \mathbb{R}^{d\_a}$ are the word representations of $w\_i$ and $w\_j$ from BERT. The syntactic feature representations of the words are learned by:

$$
o\_{i}^{d} = \sum\_{j=1}^{n} k\_{i,j} \* v\_{i,j}^{d} \quad o\_{i}^{p} = \sum\_{j=1}^{n} k\_{i,j} \* v\_{j}^{p} \tag{3}
$$

where $v^d\_{i,j} \in \mathbb{R}^{d\_b}$ and $v^p\_j \in \mathbb{R}^{d\_c}$ are the embedding vectors of the dependency type between $w\_i$ and $w\_j$ and of the POS label of $w\_j$, respectively, and $o^d\_i \in \mathbb{R}^{d\_b}$ and $o^p\_i \in \mathbb{R}^{d\_c}$ are the dependency representation and POS representation of $w\_i$, which provide beneficial syntactic features for the word representation. Afterward, we incorporate them into the word representations from BERT by:

$$\mathbf{x}\_{i} = h\_{i} \oplus o\_{i}^{d} \oplus o\_{i}^{p} \tag{4}$$

where $x\_i$ is the output of the key–value network for $w\_i$, and $\oplus$ denotes vector concatenation. $X = \{x\_1, x\_2, \cdots, x\_n\} \in \mathbb{R}^{n \times (d\_a + d\_b + d\_c)}$ are the updated word representations with syntactic features output by the key–value network.
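Equations (2)–(4) amount to a masked attention over each word's dependency neighbors, followed by weighted sums of dependency-type and POS embeddings. The NumPy sketch below shows the shapes involved; it is not the authors' code. The function name, the integer encodings of dependency types and POS labels, and the assumption that the adjacency matrix contains self-loops (so that no row of Equation (2) is empty) are all illustrative choices.

```python
import numpy as np


def synfe(H, A, type_ids, pos_ids, dep_emb, pos_emb):
    """Sketch of the SynFE key-value network, Eqs. (2)-(4).

    H        : (n, d_a) word representations from BERT
    A        : (n, n) 0/1 dependency adjacency matrix; assumed to include
               self-loops so every row has at least one nonzero entry
    type_ids : (n, n) integer dependency-type ids (including "none"/"self")
    pos_ids  : (n,) integer POS-label ids
    dep_emb  : (num_types, d_b) dependency-type embedding table (V^D)
    pos_emb  : (num_pos, d_c) POS-label embedding table (V^P)
    """
    masked = np.where(A > 0, H @ H.T, -np.inf)          # h_i h_j^T on tree edges only
    e = np.exp(masked - masked.max(axis=1, keepdims=True))  # exp(-inf) = 0 off-tree
    K = e / e.sum(axis=1, keepdims=True)                # Eq. (2): k_{i,j}
    V_d = dep_emb[type_ids]                             # (n, n, d_b): v^d_{i,j}
    V_p = pos_emb[pos_ids]                              # (n, d_c): v^p_j
    O_d = np.einsum("ij,ijd->id", K, V_d)               # Eq. (3): o^d_i
    O_p = K @ V_p                                       # Eq. (3): o^p_i
    return np.concatenate([H, O_d, O_p], axis=1)        # Eq. (4): x_i


# Toy 4-word example with random (untrained) embeddings.
rng = np.random.default_rng(0)
n, d_a, d_b, d_c = 4, 8, 3, 2
edges = np.array([[0, 1, 0, 0], [1, 0, 1, 1], [0, 1, 0, 0], [0, 1, 0, 0]])
A = np.eye(n) + edges                  # dependency edges plus self-loops
type_ids = np.where(edges > 0, 1, 0)   # 0 = "none", 1 = one generic type
np.fill_diagonal(type_ids, 2)          # 2 = "self"
H = rng.normal(size=(n, d_a))
X = synfe(H, A, type_ids, np.array([0, 1, 2, 1]),
          rng.normal(size=(3, d_b)), rng.normal(size=(3, d_c)))
print(X.shape)  # (4, 13)
```

Subtracting the row maximum before exponentiating leaves the ratio in Equation (2) unchanged and only guards against overflow.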

#### *3.4. Dependency-POS Weighted Graph Convolutional Network (DPGCN)*

In a traditional GCN-based model, the structure of a dependency tree is represented by an adjacency matrix $A = \{a\_{i,j}\}\_{n \times n}$, a 0–1 matrix in which $a\_{i,j} = 1$ if there exists a dependency between $w\_i$ and $w\_j$ and $a\_{i,j} = 0$ otherwise. To make better use of syntactic dependency information, we propose a graph attention mechanism that constructs a DPGCN graph $G = \{g\_{i,j}\}\_{n \times n}$ by combining dependency types and POS labels, where $g\_{i,j} \in [0, 1]$ whereas $a\_{i,j} \in \{0, 1\}$. First, let $t\_{i,j}$ denote the dependency type between $w\_i$ and $w\_j$; each type $t\_{i,j}$ has a trainable embedding vector $\alpha\_{i,j} \in \mathbb{R}^{2 d\_a}$. Then, to alleviate noise introduced by POS labels, the POS labels of all words are divided into five categories (i.e., Nouns, Verbs, Adjectives, Adverbs and Others), and each POS category has its own POS mapping matrix $w^p \in \mathbb{R}^{d\_a \times d\_a}$. DPGCN is an $L$-layer GCN-based module in which each layer $l$ has its own DPGCN graph $G^l = \{g^l\_{i,j}\}\_{n \times n}$, and all $g^l\_{i,j}$ are calculated by:

$$r\_{i,j}^l = \operatorname{ReLU}\left(\alpha\_{i,j}^T \left[h\_i^{l-1} w\_i^p \oplus h\_j^{l-1} w\_j^p\right]\right) \tag{5}$$

$$g\_{i,j}^l = \operatorname{softmax}(r\_{i,j}^l) = \frac{a\_{i,j} \* \exp(r\_{i,j}^l)}{\sum\_{k=1}^n a\_{i,k} \* \exp(r\_{i,k}^l)}\tag{6}$$

where $a\_{i,j} \in A$; $h^{l-1}\_i \in \mathbb{R}^{d\_a}$ and $h^{l-1}\_j \in \mathbb{R}^{d\_a}$ are the hidden state representations of $w\_i$ and $w\_j$ output by the $(l-1)$-th DPGCN layer; $w^p\_i$ and $w^p\_j$ are the mapping matrices of the POS categories of $w\_i$ and $w\_j$; $l$ indexes the DPGCN layer; and $g^l\_{i,j}$ denotes the dependency weight between $w\_i$ and $w\_j$ at the $l$-th layer.
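Equations (5)–(6) can be sketched in NumPy as follows. Two details are assumptions rather than statements from the paper: the mapping from Penn Treebank tags to the five POS categories (the hypothetical `pos_category`, by tag prefix), and the presence of self-loops in $A$. Splitting the type embedding $\alpha\_{i,j}$ into two halves is just an equivalent way of taking its dot product with the concatenation $[h\_i^{l-1} w\_i^p \oplus h\_j^{l-1} w\_j^p]$.

```python
import numpy as np


def pos_category(tag):
    """Map a Penn Treebank tag to one of the five categories (hypothetical mapping)."""
    for prefix, category in (("NN", 0), ("VB", 1), ("JJ", 2), ("RB", 3)):
        if tag.startswith(prefix):
            return category        # 0 Nouns, 1 Verbs, 2 Adjectives, 3 Adverbs
    return 4                       # Others


def dpgcn_graph(H_prev, A, type_ids, pos_tags, type_emb, pos_mats):
    """Sketch of Eqs. (5)-(6): the weighted graph G^l of one DPGCN layer.

    H_prev   : (n, d_a) hidden states from layer l-1
    A        : (n, n) 0/1 adjacency matrix, assumed to include self-loops
    type_ids : (n, n) integer dependency-type ids t_{i,j}
    pos_tags : list of n POS tags
    type_emb : (num_types, 2 * d_a) type embeddings alpha
    pos_mats : (5, d_a, d_a) one mapping matrix w^p per POS category
    """
    d_a = H_prev.shape[1]
    cats = np.array([pos_category(t) for t in pos_tags])
    mapped = np.einsum("nd,nde->ne", H_prev, pos_mats[cats])   # h_i w^p_i
    alpha = type_emb[type_ids]                                 # (n, n, 2*d_a)
    # alpha_{i,j}^T [h_i w^p_i ; h_j w^p_j], split into its two halves
    r = (np.einsum("ijd,id->ij", alpha[..., :d_a], mapped)
         + np.einsum("ijd,jd->ij", alpha[..., d_a:], mapped))
    r = np.maximum(r, 0.0)                                     # ReLU, Eq. (5)
    masked = np.where(A > 0, r, -np.inf)                       # keep tree edges only
    e = np.exp(masked - masked.max(axis=1, keepdims=True))
    return e / e.sum(axis=1, keepdims=True)                    # Eq. (6)


rng = np.random.default_rng(0)
n, d_a = 4, 6
A = np.eye(n) + np.array([[0, 1, 0, 0], [1, 0, 1, 1], [0, 1, 0, 0], [0, 1, 0, 0]])
G = dpgcn_graph(rng.normal(size=(n, d_a)), A, rng.integers(0, 3, size=(n, n)),
                ["DT", "NN", "VBZ", "JJ"], rng.normal(size=(3, 2 * d_a)),
                rng.normal(size=(5, d_a, d_a)))
print(G.round(3))
```

Each row of `G` is a distribution over the word's dependency neighbors, which is what distinguishes $g^l\_{i,j} \in [0, 1]$ from the binary $a\_{i,j}$.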

Based on the DPGCN graph $G^l$, DPGCN learns refined word representations (i.e., both aspect representations and contextual representations). First, the representations $X = \{x\_1, x\_2, \cdots, x\_n\} \in \mathbb{R}^{n \times (d\_a + d\_b + d\_c)}$ from SynFE are fed into a feedforward network to obtain 768-dimensional hidden state representations $H^0 = \{h^0\_1, h^0\_2, \cdots, h^0\_n\} \in \mathbb{R}^{n \times d\_a}$ as the input of DPGCN:

$$H^0 = X \mathcal{W}^S + b^S \tag{7}$$

where $\mathcal{W}^S \in \mathbb{R}^{(d\_a + d\_b + d\_c) \times d\_a}$ and $b^S \in \mathbb{R}^{d\_a}$ are a trainable weight matrix and bias. Then, each DPGCN layer performs graph convolution as follows:

$$h\_i^l = \sigma\left(\sum\_{j=1}^n g\_{i,j}^l \* \left(h\_j^{l-1}\mathcal{W}^l + b^l\right)\right) \tag{8}$$

$$H^l = \sigma\left(G^l(H^{l-1}\mathcal{W}^l + b^l)\right) \tag{9}$$

where $\mathcal{W}^l \in \mathbb{R}^{d\_a \times d\_a}$ and $b^l \in \mathbb{R}^{d\_a}$ are the trainable weight matrix and bias of the $l$-th layer, $\sigma$ is the ReLU activation function, and $H^l = \{h^l\_1, h^l\_2, \cdots, h^l\_n\} \in \mathbb{R}^{n \times d\_a}$ output by DPGCN represents the refined word representations of the aspect and contexts. By leveraging dependency trees more reasonably, DPGCN learns high-quality word representations.
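Equation (9) is the matrix form of Equation (8), so each layer reduces to one graph-weighted matrix product. The sketch below follows the text in projecting the SynFE output $X$ down to $d\_a$ dimensions before the first layer, and recomputes $G^l$ from the previous layer's states. The row-normalized adjacency passed as `graph_fn` in the example is only a stand-in for the attention-based graph of Equations (5)–(6); the function names are hypothetical.

```python
import numpy as np


def dpgcn_forward(X, W_s, b_s, layers, graph_fn):
    """Sketch of Eqs. (7)-(9): project the SynFE output, then run L layers.

    X        : (n, d_a + d_b + d_c) output of SynFE
    W_s, b_s : projection to d_a dimensions (Eq. (7))
    layers   : list of (W_l, b_l) pairs, one per DPGCN layer
    graph_fn : callable returning the (n, n) graph G^l from H^{l-1}
    """
    H = X @ W_s + b_s                              # Eq. (7): H^0
    for W_l, b_l in layers:
        G = graph_fn(H)                            # G^l depends on H^{l-1}
        H = np.maximum(G @ (H @ W_l + b_l), 0.0)   # Eq. (9) with ReLU as sigma
    return H


rng = np.random.default_rng(1)
n, d_in, d_a = 4, 13, 6
A = np.eye(n) + np.array([[0, 1, 0, 0], [1, 0, 1, 1], [0, 1, 0, 0], [0, 1, 0, 0]])
G_fixed = A / A.sum(axis=1, keepdims=True)         # stand-in for Eqs. (5)-(6)
layers = [(rng.normal(size=(d_a, d_a)), rng.normal(size=d_a)) for _ in range(2)]
X = rng.normal(size=(n, d_in))
W_s, b_s = rng.normal(size=(d_in, d_a)), rng.normal(size=d_a)
H_out = dpgcn_forward(X, W_s, b_s, layers, lambda H: G_fixed)
print(H_out.shape)  # (4, 6)
```

Two layers are used in the example because the layer-number study reports $L = 2$ as the best setting.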

#### *3.5. Semantic Feature Extractor (SemFE)*

To retrieve more sentiment features from the overall sentence for sentiment classification, we extract aspect-related semantic features from the contexts to generate an aspect-oriented sentence representation.

**Position Encoding**: Linguistically, contexts closer to the aspect generally have a greater influence on the sentiment expressed toward the aspect, so the position weights are defined as follows:

$$d\_t = \begin{cases} 1, & \text{dis} = 0, \\\ 1 - \frac{\text{dis}}{n}, & 1 \le \text{dis} \le d, \\\ 0, & \text{dis} > d, \end{cases} \quad 1 \le t \le n \tag{10}$$

where $\text{dis}$ denotes the distance (in words) between the context word $w\_t$ and the aspect, with $\text{dis} = 0$ for the aspect words themselves, and $d$ is the window size; word representations with $\text{dis} > d$ are masked to avoid introducing noise. Take the simple sentence "*The served food is delicious*" as an example: when the aspect is "*served food*" and $d = 1$, the position weights are $d\_t = [0.8, 1, 1, 0.8, 0]$. The position-encoded word representations $P$ are obtained by:

$$p\_t = d\_t \* h\_t^l, \quad 1 \le t \le n \tag{11}$$

$$P = \{p\_1, p\_2, \dots, p\_n\} \tag{12}$$

where $h^l\_t \in \mathbb{R}^{d\_a}$ is the word representation of $w\_t$ from DPGCN. The final aspect-oriented sentence representation $s$, which contains abundant semantic information, is generated by:

$$\delta\_t = \sum\_{i=1}^{m} h\_{a+i}^l p\_t^T, \quad \gamma\_t = \frac{\exp(\delta\_t)}{\sum\_{i=1}^{n} \exp(\delta\_i)} \tag{13}$$

$$s = \sum\_{t=1}^{n} \gamma\_t \* p\_t \tag{14}$$

where $h^l\_{a+i} \in \mathbb{R}^{d\_a}$ is the DPGCN representation of the $i$-th aspect word, and $m$ and $n$ are the lengths of the aspect and the sentence, respectively.
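Equations (10)–(14) combine a hard position window with dot-product attention between the summed aspect vectors and the position-weighted words. The sketch below first reproduces the worked example in the text ($d\_t = [0.8, 1, 1, 0.8, 0]$) and then computes $s$. Measuring the distance as the number of words to the nearest aspect word is an assumption consistent with that example; indices are 0-based, and the function names are hypothetical.

```python
import numpy as np


def position_weights(n, aspect_start, aspect_len, window):
    """Position weights d_t of Eq. (10) for an n-word sentence.

    aspect_start : 0-based index of the first aspect word
    aspect_len   : number of aspect words m
    window       : threshold d; words farther than d from the aspect are masked
    """
    aspect_end = aspect_start + aspect_len - 1
    weights = []
    for t in range(n):
        # distance in words to the nearest aspect word (0 inside the aspect)
        dis = max(aspect_start - t, t - aspect_end, 0)
        if dis == 0:
            weights.append(1.0)
        elif dis <= window:
            weights.append(1.0 - dis / n)
        else:
            weights.append(0.0)
    return np.array(weights)


def semfe(H_l, d, aspect_start, aspect_len):
    """Sketch of Eqs. (11)-(14): aspect-oriented sentence representation s."""
    P = d[:, None] * H_l                                      # Eqs. (11)-(12)
    aspect_sum = H_l[aspect_start:aspect_start + aspect_len].sum(axis=0)
    delta = P @ aspect_sum                                    # Eq. (13): delta_t
    gamma = np.exp(delta - delta.max())
    gamma = gamma / gamma.sum()                               # Eq. (13): gamma_t
    # masked words (d_t = 0) still get some gamma_t, but p_t = 0 adds nothing to s
    return gamma @ P, gamma                                   # Eq. (14)


# "The served food is delicious" with aspect "served food" and d = 1
d = position_weights(5, 1, 2, 1)
print(d.tolist())   # [0.8, 1.0, 1.0, 0.8, 0.0]
rng = np.random.default_rng(0)
H = rng.normal(size=(5, 4))
s, gamma = semfe(H, d, 1, 2)
```

Because $\delta\_t$ is linear in the aspect vectors, summing them once and taking a single dot product per word is equivalent to the sum in Equation (13).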

#### *3.6. Model Training*

We obtain the final aspect representation $h\_a \in \mathbb{R}^{d\_a}$ by average pooling and combine it with the aspect-oriented sentence representation $s \in \mathbb{R}^{d\_a}$ to obtain the final representation $z \in \mathbb{R}^{d\_a}$:

$$h\_a = \frac{1}{m} \sum\_{i=1}^{m} h\_{a+i}^{l} \tag{15}$$

$$z = \varepsilon \ast h\_a + (1 - \varepsilon) \ast s \tag{16}$$

where $\varepsilon \in (0, 1)$ is a trainable coefficient. The final representation $z$ is passed through a fully connected layer with a softmax activation function to obtain the probability distribution over sentiment polarities, and the label with the highest probability is chosen as the sentiment polarity of the given aspect:

$$u = \operatorname{softmax}(z w\_u + b\_u) \tag{17}$$

where $u$ is the probability distribution, and $w\_u \in \mathbb{R}^{d\_a \times 3}$ and $b\_u \in \mathbb{R}^{3}$ are the trainable weight matrix and bias.
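Equations (15)–(17) can be sketched as below. The paper states only that $\varepsilon \in (0, 1)$ is trainable; passing an unconstrained scalar through a sigmoid is one common way to enforce that range and is an assumption here, as is the order of the three polarity labels. The function name is hypothetical.

```python
import numpy as np

LABELS = ("negative", "neutral", "positive")   # label order is an assumption


def classify(aspect_states, s, eps_logit, w_u, b_u):
    """Sketch of Eqs. (15)-(17).

    aspect_states : (m, d_a) DPGCN outputs of the aspect words
    s             : (d_a,) aspect-oriented sentence representation
    eps_logit     : unconstrained scalar; the sigmoid keeps epsilon in (0, 1)
    w_u, b_u      : (d_a, 3) weight matrix and (3,) bias
    """
    h_a = aspect_states.mean(axis=0)                 # Eq. (15): average pooling
    eps = 1.0 / (1.0 + np.exp(-eps_logit))           # epsilon in (0, 1)
    z = eps * h_a + (1.0 - eps) * s                  # Eq. (16)
    logits = z @ w_u + b_u
    u = np.exp(logits - logits.max())
    u = u / u.sum()                                  # Eq. (17): softmax
    return u, LABELS[int(np.argmax(u))]


rng = np.random.default_rng(0)
d_a = 4
u, label = classify(rng.normal(size=(2, d_a)), rng.normal(size=d_a), 0.0,
                    rng.normal(size=(d_a, 3)), rng.normal(size=3))
print(u.round(3), label)
```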

The model is trained with mini-batch gradient descent using the Adam optimizer (Section 4.2) to minimize the cross-entropy loss of sentiment classification. The loss function is formulated as:

$$L(\theta) = -\sum\_{i=1}^{N} \dot{u}\_i \log u\_i + \lambda \lVert \theta \rVert\_2^2 \tag{18}$$

where $N$ is the number of training samples, $u\_i$ is the predicted sentiment polarity distribution for the aspect, $\dot{u}\_i$ is the target label, $\theta$ represents all trainable parameters, and $\lambda$ is the L2 regularization coefficient.
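As a sketch, the objective can be computed as below, with the gold label selecting the log of the predicted probability as in standard cross-entropy. Using the squared L2 norm for the penalty (as in common weight-decay practice) is an assumption, and in a real training loop the gradients would come from an autodiff framework rather than from this NumPy function, whose name is hypothetical.

```python
import numpy as np


def absa_loss(probs, targets, params, lam=0.001):
    """Sketch of Eq. (18): summed cross-entropy plus a squared L2 penalty.

    probs   : (N, 3) predicted distributions u_i
    targets : (N,) integer gold labels (index form of the one-hot target)
    params  : iterable of parameter arrays making up theta
    lam     : L2 coefficient (0.001 in the experimental settings)
    """
    probs = np.asarray(probs)
    targets = np.asarray(targets)
    picked = probs[np.arange(len(targets)), targets]   # u_i at the gold class
    ce = -np.sum(np.log(picked + 1e-12))               # small constant avoids log(0)
    l2 = sum(np.sum(p ** 2) for p in params)           # ||theta||_2^2
    return ce + lam * l2


probs = [[0.1, 0.2, 0.7], [0.6, 0.3, 0.1]]
loss = absa_loss(probs, [2, 0], [np.array([1.0, 2.0])])
print(round(float(loss), 4))   # equals -ln 0.7 - ln 0.6 + 0.001 * 5
```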

#### **4. Experiments**

*4.1. Datasets*

Five benchmark ABSA datasets are adopted to evaluate our model: the Twitter dataset constructed by [18] and four datasets (*Lap*14, *Res*14, *Res*15, *Res*16) from the SemEval tasks [19–21], which consist of user reviews of laptops and restaurants. Each sample contains a sentence, a specific aspect and the sentiment polarity (i.e., positive, neutral or negative) toward the aspect. The statistics of the five datasets are shown in Table 1.


**Table 1.** The statistics of datasets.

#### *4.2. Experimental Settings*

In our experiments, the syntactic dependency trees and POS labels of the given sentences are obtained with the StanfordCoreNLP toolkit (https://stanfordnlp.github.io/CoreNLP/, accessed on 1 August 2022). The given sentences are encoded by the pre-trained BERT model (obtained from https://github.com/huggingface/pytorch-pretrained-BERT, accessed on 1 August 2022) to obtain a 768-dimensional initial word-embedding vector for each word. In addition, the dimensions of the dependency-embedding vectors and POS-embedding vectors are set to 100 and 50, respectively. BERT is initialized with pre-trained parameters, while the other trainable parameters are initialized by Xavier initialization [22]. The model is trained with the cross-entropy loss and the Adam optimizer. For the key hyper-parameters, the batch size is 16, the learning rate is $1 \times 10^{-5}$, the L2 regularization coefficient $\lambda$ is 0.001, and the dropout rate is 0.1. Our model is evaluated by both accuracy and macro-averaged F1 score.
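For reference, accuracy and macro-averaged F1 over the three polarity classes can be computed as below; macro-averaging gives each class equal weight regardless of how many samples it has. This plain-Python sketch is illustrative and its function name is hypothetical; in practice a library implementation such as scikit-learn's `f1_score` with `average="macro"` would typically be used.

```python
def accuracy_and_macro_f1(gold, pred, labels=("negative", "neutral", "positive")):
    """Accuracy and macro-averaged F1 (unweighted mean of per-class F1).

    When precision or recall has a zero denominator it is set to 0, a common
    convention that is assumed here rather than taken from the paper.
    """
    if len(gold) != len(pred) or not gold:
        raise ValueError("gold and pred must be non-empty and equally long")
    pairs = list(zip(gold, pred))
    accuracy = sum(g == p for g, p in pairs) / len(pairs)
    f1_scores = []
    for c in labels:
        tp = sum(g == c and p == c for g, p in pairs)
        fp = sum(g != c and p == c for g, p in pairs)
        fn = sum(g == c and p != c for g, p in pairs)
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        denom = precision + recall
        f1_scores.append(2 * precision * recall / denom if denom else 0.0)
    return accuracy, sum(f1_scores) / len(f1_scores)


gold = ["positive", "negative", "neutral", "positive"]
pred = ["positive", "negative", "positive", "positive"]
print(accuracy_and_macro_f1(gold, pred))   # accuracy 0.75, macro-F1 about 0.6
```

In this toy example the missed neutral sample costs only one point of accuracy but pulls macro-F1 down sharply, which is why the two metrics are reported together.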

#### *4.3. Baselines*

To verify the effectiveness of our model, we compare it with the following state-of-the-art models:

**IAN** [23] proposes an interactive attention network based on LSTM and attention mechanism to generate representations for aspects and sentences.

**MGAN** [4] designs a novel multi-grained attention network model to capture interactions between the aspects and contexts.

**ASGCN** [6] first proposes leveraging GCNs over dependency trees to learn aspect-specific representations for the ABSA task.

**CDT** [7] utilizes GCN over dependency trees to extract syntactic information for aspect representations.

**BIGCN** [8] constructs a syntactic graph and lexical graph for GCN to leverage word co-occurrence information and syntactic information.

**BERT** [11] is the vanilla BERT, which adopts "[CLS] sentence [SEP] aspect [SEP]" as input and uses the representation of [CLS] for predictions.

**TD** [24] implements a target-dependent BERT-based model with positioned output at the target terms and an optional sentence for the ABSA task.

**R-GAT** [16] employs a relational graph attention network to exploit syntactic structure information.

**DGEDT** [25] considers the flat representations and graph-based representations jointly to alleviate the noise and instability of dependency trees.

**LCFS** [14] models contextual and syntactical features for the ABSA task.

**TGCN** [26] encodes dependency types and integrates the representations from all GCN layers to learn aspect representations.

**DualGCN** [17] designs SynGCN and SemGCN to learn aspect representations by considering syntax structures and semantic correlations simultaneously.

#### *4.4. Results and Analysis*

Table 2 compares the performance of all models. Our model outperforms the previous models on the four SemEval datasets (*Lap*14, *Res*14, *Res*15, *Res*16), but its performance on the Twitter dataset is lower than that of DualGCN and DGEDT, which we attribute to tweets being informal and colloquial and therefore less sensitive to syntactic information. The comparisons show that our model surpasses many state-of-the-art GCN-based models, indicating that it utilizes syntactic dependency information better and that extracting sufficient syntactic and semantic features for sentiment analysis benefits model performance.

Compared with the attention-based models IAN and MGAN, our model reduces the noise introduced by the ordinary attention mechanism. Compared with GCN-based models that utilize dependency information, such as ASGCN, CDT, R-GAT, TGCN and DualGCN, our model combines dependency types and POS labels to weight dependencies of different types according to their contribution to the ABSA task, which exploits syntactic dependency information better.


**Table 2.** The results of comparisons.

Models using BERT are marked by "-", and models using GCN and dependency information are marked by "!".

#### *4.5. Ablation Study*

To explore the significance of each module, we conduct a series of ablation experiments, where the results are shown in Table 3. The findings are listed as follows:

(1) The results decrease after removing the SynFE module, which indicates that SynFE effectively encodes the syntactic information in sentences to enrich textual features. (2) If DPGCN removes dependency information (D) or POS information (P) separately, the results decrease slightly, but they drop sharply when both are removed simultaneously. This shows that weighting dependencies of different types in DPGCN has a great impact on model performance and that both (D) and (P) contribute to it. (3) Similarly, for the SemFE module, the results demonstrate that position encoding helps to avoid introducing noise and that the aspect-oriented sentence representation supplements semantic features for better performance. Overall, each component contributes to the model performance.



#### *4.6. Case Study*

To illustrate the effectiveness of our model, we randomly select samples from the test set for a case analysis.

#### 4.6.1. Weights Visualization

As can be seen in Figure 4a, for the aspect "*food*", the opinion word "*delicious*" linguistically determines the sentiment polarity of "*food*", and our model has learned to assign a higher weight to the dependency type "*nsubj*" between "*food*" and "*delicious*". The dependency type "*conj*" between "*delicious*" and "*terrible*" may have a negative effect on aspect "*food*", so our model reduces the weight of "*conj*", which denotes the relation between two elements connected by a coordinating conjunction. In addition, other secondary dependencies are assigned lower weights. In Figure 4b, under the guidance of DPGCN, aspect "*food*" has a higher semantic-based attention weight on the opinion word "*delicious*".

Similarly, as seen in Figure 4c,d, vital dependencies are assigned higher weights; thus, aspect "*service*" aggregates feature information from important contexts, which explains why "*service*" has higher semantic-based attention weights on them. In particular, our model captures the semantic reversal expressed by "*but*", which is attached with the dependency type "*cc*" denoting a coordination relation.

#### 4.6.2. Probability Distribution Visualization

As shown in Figure 5, the SemFE module increases the probabilities of the true sentiment polarities for aspects "*food*" and "*service*". Because BERT is pre-trained on large amounts of text, it provides text embeddings containing rich semantic features; thus, "*delicious*" is encoded with positive sentiment features and "*terrible*" with negative sentiment features. By generating the aspect-oriented sentence representation *s*, SemFE supplements the corresponding sentiment features for aspects. For aspect "*food*", the context words "*service*" and "*terrible*", which carry negative sentiment features, are masked by position encoding, and the positive sentiment features from "*delicious*" are incorporated into *s*, which makes our model predict the sentiment polarity of "*food*" more accurately. Similarly, although "*service*" is disturbed by "*delicious*", it has a higher semantic weight on "*terrible*", so *s* also provides correct sentiment features for "*service*".

**Figure 5.** Visualization results of probability distributions: (**a**) probability distribution of aspect 'food'; (**b**) probability distribution of aspect 'service'.

#### *4.7. Impacts of DPGCN Layer Number*

An appropriate number of layers *L* in DPGCN is also beneficial to model performance, so we explore the impact of different values of *L*; the results are shown in Figure 6. The accuracy and F1 score on all datasets are highest when *L* = 2, and we analyze this phenomenon below. As described in Figure 7a, aspect "*place*" has different syntactic distances (SD) to its context words. Through one DPGCN layer, each word node aggregates feature information from its neighboring word nodes (i.e., SD = 1), so when *L* = 1, aspect "*place*" only captures feature information from contexts at SD = 1. Here, "*good*" misleads the sentiment analysis of "*place*"; a two-layer DPGCN enables "*place*" to capture vital sentiment features from the SD = 2 context word "*not*". However, with *L* = 3 or more, irrelevant feature information is introduced from words such as "*Thai*" and "*wonderful*", which carry positive sentiment features.

As described in Figure 7b, when *L* = 3, aspect "*place*" has high semantic weights (from SemFE) on itself and "*good*" after the 1st DPGCN layer. After the 2nd layer, the semantic weight on "*good*" decreases and that on "*not*" increases. After the 3rd layer, the semantic weights are dispersed across the contexts, so it becomes difficult for aspect "*place*" to focus on vital information, and much irrelevant information is introduced.

In conclusion, our model cannot capture sufficient contextual feature information for aspects when *L* = 1, but it may introduce irrelevant information that damages model performance when *L* = 3 or more. Therefore, *L* = 2 is favorable for our model.
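The argument above rests on the fact that each DPGCN layer passes information only along dependency edges: by Equation (6), $g^l\_{i,j} = 0$ whenever $a\_{i,j} = 0$. Assuming self-loops (common in GCNs, though not stated explicitly for DPGCN), after $L$ layers a word can only draw on words within syntactic distance $L$. The breadth-first search below makes this concrete on a hypothetical toy tree, not the Figure 7 sentence; the function names are illustrative.

```python
from collections import deque


def syntactic_distances(edges, n, source):
    """Hop distance (SD) from word `source` to every word of an undirected tree."""
    neighbors = [[] for _ in range(n)]
    for u, v in edges:
        neighbors[u].append(v)
        neighbors[v].append(u)
    dist = [None] * n
    dist[source] = 0
    queue = deque([source])
    while queue:                       # breadth-first search
        u = queue.popleft()
        for v in neighbors[u]:
            if dist[v] is None:
                dist[v] = dist[u] + 1
                queue.append(v)
    return dist


def reachable_within(edges, n, source, num_layers):
    """Words whose features can reach `source` after num_layers graph layers."""
    dist = syntactic_distances(edges, n, source)
    return [i for i, d in enumerate(dist) if d is not None and d <= num_layers]


# Hypothetical 5-word tree: 0-1, 1-2, 1-3, 3-4 (word 0 plays the aspect)
toy_edges = [(0, 1), (1, 2), (1, 3), (3, 4)]
for L in (1, 2, 3):
    print(L, reachable_within(toy_edges, 5, 0, L))
```

Each extra layer widens the receptive field by one hop, which is the mechanism behind both the missing "*not*" at *L* = 1 and the extra noise at *L* = 3.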

**Figure 6.** Impacts of the layer number *L*.

**Figure 7.** Syntactic distances and semantic weights for aspect 'place'.

#### **5. Conclusions**

In this paper, we propose a novel model containing three effective modules (i.e., SynFE, DPGCN and SemFE) for the ABSA task. To address the limitations of previous works in utilizing syntactic dependency information, we propose to weight dependencies of different types according to their contribution to the ABSA task, and the DPGCN module is designed to combine dependency information and POS information to leverage syntactic dependency information more reasonably and comprehensively. To further improve model performance, SynFE is designed to encode syntax-aware features that enrich word representations, and SemFE is designed to extract aspect-oriented semantic information that supplements sentiment features for sentiment classification.

Extensive experiments on five benchmark datasets demonstrate the validity of our model, which achieves new state-of-the-art performance on the four SemEval datasets, with higher accuracy and F1 scores than other outstanding models on these datasets. The ablation experiments show that each module in our model contributes to the model performance, and the case study demonstrates that DPGCN has learned to assign appropriate weights to dependency types, which overcomes the shortcomings of previous models in utilizing dependency trees. In addition, we have analyzed how SemFE assists sentiment prediction and explored the effect of the number of DPGCN layers on model performance.

**Author Contributions:** Conceptualization, J.Y. and Y.X.; methodology, J.Y.; formal analysis, J.Y. and A.D.; writing—original draft preparation, J.Y.; writing—review and editing, X.L. and B.Z.; supervision, Y.X. and B.Z.; funding acquisition, X.L. All authors have read and agreed to the published version of the manuscript.

**Funding:** This work was funded by the Science and Technology Plan Project of Guangzhou under Grant Nos. 202102080258 and 202102020902.

**Institutional Review Board Statement:** Not applicable.

**Informed Consent Statement:** Not applicable.

**Data Availability Statement:** Not applicable.

**Conflicts of Interest:** The authors declare no conflict of interest.

#### **References**


### *Article* **Joint Semantic Intelligent Detection of Vehicle Color under Rainy Conditions**

**Mingdi Hu <sup>1,\*,†</sup>, Yi Wu <sup>1</sup>, Jiulun Fan <sup>1</sup> and Bingyi Jing <sup>2,\*,†</sup>**


**Abstract:** Color is an important feature of vehicles, and it plays a key role in intelligent traffic management and criminal investigation. Existing algorithms for vehicle color recognition are typically trained on data captured under good weather conditions and have poor robustness in outdoor visual tasks. Fine-grained vehicle color recognition under rainy conditions is still a challenging problem. In this paper, an algorithm for jointly deraining and recognizing vehicle color (*JADAR*) is proposed, in which three layers of *UNet* are embedded into *RetinaNet*-50 to obtain joint semantic fusion information. More precisely, the *UNet* subnet is used for deraining, and the feature maps of the recovered clean image and the feature maps extracted from the input image are cascaded into the Feature Pyramid Net (*FPN*) module to achieve joint semantic learning. The joint feature maps are then fed into the class and box subnets to classify and locate objects. The *Rain Vehicle Color*-24 dataset is used to train *JADAR* for vehicle color recognition under rainy conditions, and extensive experiments are conducted. Since the deraining and detection modules share the feature extraction layers, our algorithm maintains the test time of *RetinaNet*-50 while improving its robustness. On self-built and public real datasets, the mean average precision (*mAP*) of vehicle color recognition reaches 72.07%, which beats both state-of-the-art algorithms for vehicle color recognition and popular object detection algorithms.

**Keywords:** vehicle color recognition; low–high level joint task; object detection; joint semantic learning; deep neural network; rainy image recovery

**MSC:** 54H30; 68U10; 94A08

### **1. Introduction**

Vehicle information recognition has been applied in intelligent traffic management and criminal investigation. License plate, model, and color constitute the main vehicle information. Although license plate recognition is a commonly used vehicle information recognition technology [1], it faces many challenges in criminal investigation and intelligent traffic law enforcement, as license plates can easily be obscured (partially or fully) or faked/duplicated by criminals. Because vehicle color can still be identified despite partial occlusion or viewpoint changes, vehicle color recognition is widely applied in video surveillance [2], vehicle detection [3], vehicle tracking [4], automatic driving [5,6], criminal investigation [7], etc. All the above-mentioned tasks inevitably encounter adverse weather conditions, especially rain. Rain in turn degrades the performance of object recognition/retrieval, because it can significantly reduce scene contrast and visibility, compromising image quality. Many researchers have studied how to improve the performance of object detection under rainy conditions.

**Citation:** Hu, M.; Wu, Y.; Fan, J.; Jing, B. Joint Semantic Intelligent Detection of Vehicle Color under Rainy Conditions. *Mathematics* **2022**, *10*, 3512. https://doi.org/10.3390/math10193512

Academic Editors: Jianping Gou, Weihua Ou, Shaoning Zeng and Lan Du

Received: 9 August 2022 Accepted: 20 September 2022 Published: 26 September 2022

**Publisher's Note:** MDPI stays neutral with regard to jurisdictional claims in published maps and institutional affiliations.

**Copyright:** © 2022 by the authors. Licensee MDPI, Basel, Switzerland. This article is an open access article distributed under the terms and conditions of the Creative Commons Attribution (CC BY) license (https://creativecommons.org/licenses/by/4.0/).

Vehicle color recognition methods are typically classified as traditional model-driven [8–11] or data-driven deep learning [12–18] methods. Traditional methods usually use handcrafted feature descriptors to extract visual features and train a classifier to recognize vehicle color. For example, Chen et al. [8] select the region of interest (ROI) of the vehicle to recognize its dominant color and then train a linear support vector machine to classify it. Jeong et al. [9] adopt *AdaBoost* to classify an *HSV* histogram of the vehicle's homogeneous patches into seven color categories.

Deep neural networks have been employed to learn effective feature representations from raw pixels and have proven more powerful than traditional methods. These deep learning methods fall into two groups: general object detection algorithms [19–24] applied to vehicle color recognition [16–18], and algorithms designed specifically for color recognition. All of these algorithms are trained on datasets with 7–24 colors [8–10,18] collected under normal weather conditions. Some research does address object recognition under rainy conditions, but these methods generally adopt two-stage rather than end-to-end procedures, which inevitably increases the running time of the entire task.

On the other hand, a number of scholars have studied the joint processing of low-level and high-level tasks. Generally, they improve the robustness of object detection [25–28] by embedding domain adaptation, image restoration, style transfer, or other modules into the object detection framework, or by using a few-shot transfer learning mechanism [29–32]. These methods have explored the robustness of downstream tasks in many harsh environments, but not under rainy weather conditions, which motivates the present work.

In this paper, a Joint Algorithm for Deraining And Recognition (*JADAR*) is proposed for fine recognition of vehicle color under rainy conditions. The network architecture is shown in Figure 1. To be more specific, we embed the three-layer decoder of *UNet*-3 into the last three layers of the feature extraction submodule of *RetinaNet*-50. The main contributions are as follows:


**Figure 1.** Framework overview for our *JADAR*; detailed explanations are in the text.

Next, related work is introduced in Section 2, and *JADAR* is constructed in Section 3. Section 4 shows quantitatively and qualitatively that our method is superior to state-of-the-art methods, and Section 5 concludes the paper.

#### **2. Related Work**

Existing research on vehicle color recognition under normal weather and on object detection under adverse weather conditions is reviewed below.

#### *2.1. Vehicle Color Recognition under Normal Weather Conditions*

Vehicle color recognition methods generally fall into traditional model-based methods [8–11] and data-based deep learning methods [12–18]. Regarding traditional model-based methods, Chen et al. [8] train a linear support vector machine classifier on the region of interest (*ROI*) in vehicle images based on eight color types; Jeong et al. [9] adopt the multi-class AdaBoost algorithm to classify the color of front-of-vehicle images into seven types; Dule et al. [11] train three classifiers (KNN, ANN, and SVM) for two *ROI*s (smooth hood section and semi-front of vehicle).

Data-based methods have been receiving increasing attention for vehicle color recognition. Hu et al. [12] were the first to apply a convolutional neural network (CNN) with a spatial pyramid strategy to boost the accuracy of vehicle color recognition. Zhang et al. [15] proposed a lightweight *CNN* for vehicle color recognition. Fu et al. [16] designed *MCFF*-*CNN* (Multiscale Comprehensive Feature Fusion Convolutional Neural Network) to recognize eight vehicle colors. Hu et al. [18] proposed vehicle color recognition based on a Smooth Modulation Neural Network with Multi-Scale Feature Fusion (*SMNN*-*MSFF*).

To the best of our knowledge, there has been no research on vehicle color detection under bad weather conditions, which is the focus of this paper.

#### *2.2. Object Detection under Adverse Weather Conditions*

Bad weather includes rain, snow, haze, etc. The quality of outdoor images or videos collected in such weather is severely degraded, so object detection models trained on high-quality images have difficulty handling it. Many scholars have investigated this challenge and proposed various solutions: embedding a domain-adaptation module [25–28,31,33,34] into an object detection backbone such as *YOLO*, *Faster RCNN*, or *RetinaNet*; two-stage methods consisting of preprocessing followed by object detection [35–40]; or few-shot transfer learning mechanisms [29–32].

For example, Chen et al. [33] embedded two domain adaptation modules into *Faster RCNN* to reduce the domain discrepancy on image level and instance level. Sindagi et al. [31] proposed an unsupervised domain-adapting method to improve generalization of object detection under hazy and rainy conditions. Style transfer is considered in [27], in which the authors construct a cross-domain representation learning method including domain diversification and a multi-domain invariant. Huang et al. [41] combine dual subnet frameworks for object detection under foggy conditions.

Apart from the two-stage methods, the methods above do not pay special attention to rainy conditions, while two-stage methods such as [35–39,42] focus on image deraining rather than object detection. In other words, the joint task of deraining and object detection has not been seriously addressed. Motivated by these considerations, we propose *JADAR* for joint semantic intelligent detection of vehicle color in rainy scenes.

#### **3.** *JADAR* **Algorithm**

#### *3.1. Fusion Network Design*

In this paper, a Joint Algorithm for Deraining And Recognition (*JADAR*) based on *RetinaNet*-50 is designed for vehicle color recognition in inclement weather conditions, as shown in Figure 1. In Figure 1, *O* is the input rainy image, *B* is the corresponding clean background image, and $\hat{B}$ and $y^o$ are the deraining and detection outputs, respectively. For clarity, the recognition result for each car is enlarged in $y^o\_i$ ($i = 1, 2, 3$): $y^o\_1$ shows the first car, recognized as silver-gray with a confidence of 0.91; $y^o\_2$ shows the second car, recognized as black with a confidence of 0.58; and $y^o\_3$ shows the third car, recognized as dark gray with a confidence of 0.81. The green, blue, purple, and orange boxes represent the feature extraction module, *UNet-3*, the information fusion module, and the *class+bbox subnets*, respectively. $L\_{reg}$ is the regression loss, computed with the smooth *L1* loss; $L\_{cls}$ is the classification loss, computed with the focal loss; and the deraining loss is the *MSE* loss. *JADAR* is trained with a weighted sum of these three losses (see Equation (7)).

The *JADAR* framework is designed by embedding the three-layer decoder of *UNet-3* [43] into the last three sub-blocks of the feature extraction module, as illustrated by the green-tinted box in Figure 1. The whole framework consists of four main modules: image feature extraction module, deraining module, information fusion module, and *class* + *box* subnets. The rain removal and feature extraction modules share three layers, avoiding extra computational burden. In fact, Section 4.5 shows that *JADAR* has the same testing time as *RetinaNet*-50. The last three feature maps and the corresponding recovered feature maps are cascaded together and then fed into their respective *class* + *box* subnets, which can learn multi-scale joint semantic representations to improve object detection accuracy under rainy conditions. The feature fusion sub-module setting is illustrated in Figure 2.

The overall objective function is back-propagated to train the deraining module and iteratively improve its deraining performance. The object detection backbone uses three-scale *class* + *box* subnets, which leverage multi-scale fused color feature maps to classify 24 car color types and regress the bounding boxes.

**Figure 2.** Architecture and weights of the proposed network in detail.

#### *3.2. Model Formulation and Model Optimization*

We model the physical mechanism of rainy image corruption as

$$\mathbf{x} = \mathbf{y} + \mathbf{z} \tag{1}$$

where *x*, *y*, and *z* denote the rainy image *O*, the clean background image *B*, and the rain layer *R*, respectively. To tackle supervised vehicle object detection by color in inclement weather conditions, a joint network is proposed to learn a joint semantic representation from an input rainy image *x*. The clean background *y* serves as the corresponding label of the rainy image *x*.

As indicated by the green box in Figure 1, the last three feature maps $f\_1(x)$, $f\_2(x)$, and $f\_3(x)$ are taken from the feature extraction sub-blocks of *RetinaNet*. First, $f\_1(x)$ is fed into the first layer of the *UNet*-3 decoder, which outputs $g\_1(x)$. Next, $g\_1(x)$ and $f\_2(x)$ are cascaded and fed into the penultimate decoder layer of *UNet*-3, which outputs $g\_2(x)$; then $g\_2(x)$ and $f\_3(x)$ are cascaded and fed into the last decoder layer of *UNet*-3, which outputs $g\_3(x)$. The output of the deraining module, $\hat{y}$, is $g\_3(x)$. The mean square error (*MSE*) is then used as the deraining objective function $L\_{der}$:

$$L\_{der} = \frac{1}{n} \sum\_{i=1}^{n} \left\| \hat{y}\_i - y\_i \right\|^2 \tag{2}$$

where *n* is the number of rainy images, and $\hat{y}\_i$ and $y\_i$ are the recovered and clean images of the *i*-th sample. Finally, $f\_1(x)$ and $g\_1(x)$, $f\_2(x)$ and $g\_2(x)$, and $f\_3(x)$ and $g\_3(x)$ are cascaded pairwise and input into the differently scaled *class* + *box* subnets, where joint semantic information is fused, 24 vehicle colors are classified, and bounding-box regression is performed; the final detection output image is denoted $y^o$.
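The additive rain model of Equation (1) and the deraining loss of Equation (2) can be illustrated with a minimal pure-Python sketch. It assumes, for illustration only, that each image is a flattened list of pixel intensities; a real implementation would operate on tensors, and the function names are hypothetical.

```python
from typing import List

Image = List[float]  # hypothetical stand-in: a flattened image of pixel intensities


def add_rain(clean: Image, rain: Image) -> Image:
    """Equation (1): rainy image x = clean background y + rain layer z."""
    return [y + z for y, z in zip(clean, rain)]


def squared_error(pred: Image, target: Image) -> float:
    """Squared L2 norm ||y_hat - y||^2 between a recovered and a clean image."""
    return sum((p - t) ** 2 for p, t in zip(pred, target))


def deraining_loss(preds: List[Image], targets: List[Image]) -> float:
    """Equation (2): squared reconstruction error averaged over the n images."""
    n = len(preds)
    return sum(squared_error(p, t) for p, t in zip(preds, targets)) / n


# Usage: a "network" that returns the rainy input unchanged is penalized by the rain energy.
clean = [0.2, 0.4]
rain = [0.5, 0.0]
rainy = add_rain(clean, rain)
print(deraining_loss([rainy], [clean]))  # approximately 0.25
```

PyTorch's `torch.nn.MSELoss` with its default `reduction='mean'` averages over every pixel rather than over images, so it differs from Equation (2) by a constant factor (the number of pixels per image); that factor can be absorbed into the weight *λ* of Equation (7).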

The classification loss function is

$$L\_{cls}(p\_{it}) = -\sum\_{i=1}^{C} \left( \alpha\_t (1 - p\_{it})^{\gamma} \log(p\_{it}) + (1 - \alpha\_t)(p\_{it})^{\gamma} \log(p\_{it}) \right) \tag{3}$$

where $\alpha\_t$ is a factor that balances the uneven proportion of positive and negative examples in each vehicle color category, $C = 24$ is the number of vehicle color categories, $\gamma \geq 0$ is a tunable focusing parameter (we take $\gamma = 2.0$ in Section 4, following [24]), and $t \in \{0, 1\}$ indicates a negative ($t = 0$) or positive ($t = 1$) sample. $p\_{i1} \in [0, 1]$ denotes the predicted probability of a positive sample of the *i*-th vehicle color class, and $1 - p\_{i1}$ denotes the predicted probability of a negative sample, for every vehicle color category $i \in \{1, 2, \cdots, 24\}$; i.e.,

$$p\_{it} = \begin{cases} p\_i & \text{if } t = 1 \\ 1 - p\_i & \text{otherwise} \end{cases} \tag{4}$$
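The classification loss of Equations (3) and (4) can be sketched in plain Python. Because the extracted form of Equation (3) is hard to read unambiguously, this sketch assumes the standard per-class focal loss from [24], $-\alpha\_t (1 - p\_{it})^{\gamma} \log(p\_{it})$, with $\alpha\_t = \alpha$ for positive samples and $1 - \alpha$ for negative ones. It is an illustration under that assumption, not the authors' implementation; the defaults *α* = 0.25 and *γ* = 2 follow Section 4.1.

```python
import math
from typing import Sequence


def focal_loss(probs: Sequence[float], labels: Sequence[int],
               alpha: float = 0.25, gamma: float = 2.0, eps: float = 1e-12) -> float:
    """RetinaNet-style sigmoid focal loss summed over the C color classes.

    probs[i]:  predicted probability p_i for color class i
    labels[i]: t = 1 if the object has color i, else t = 0
    """
    total = 0.0
    for p, t in zip(probs, labels):
        p_t = p if t == 1 else 1.0 - p                # Equation (4)
        alpha_t = alpha if t == 1 else 1.0 - alpha    # class-balancing weight
        p_t = min(max(p_t, eps), 1.0)                 # guard against log(0)
        total += -alpha_t * (1.0 - p_t) ** gamma * math.log(p_t)
    return total


# Usage: a confidently correct prediction contributes far less than an uncertain one.
easy = focal_loss([0.9], [1])
hard = focal_loss([0.6], [1])
print(easy < hard)  # True
```

With `gamma=0` the modulating factor disappears and the loss reduces to an *α*-weighted binary cross-entropy, which is a quick way to sanity-check the sketch.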

The bounding-box regression loss is

$$L\_{reg} = \frac{1}{n} \sum\_{i=1}^{n} L\_{reg}(i), \tag{5}$$

with

$$L\_{reg}(i) = \begin{cases} 0.5a^2 & \text{if } |a| < 1 \\ |a| - 0.5 & \text{otherwise} \end{cases} \tag{6}$$

where $a = t\_i - t\_i^{\ast}$, with $t\_i = \{t\_x, t\_y, t\_w, t\_h\}$ and $t\_i^{\ast} = \{t\_x^{\ast}, t\_y^{\ast}, t\_w^{\ast}, t\_h^{\ast}\}$. Here, $(x, y)$ denotes the center coordinates of the bounding box, $w$ and $h$ denote its width and height, and $t\_i$ ($t\_i^{\ast}$) represents the offsets of the predicted box (the ground-truth box).

Here, $L\_{reg}(i)$ is the regression loss for the *i*-th image, and $L\_{reg}$ is the total regression loss over all images. The total loss function is then given by

$$L\_{total} = L\_{cls}(p\_{it}) + L\_{reg} + \lambda L\_{der}, \tag{7}$$

where $\lambda \in [0, 1]$ is a hyperparameter that controls how strongly the deraining loss influences detection performance in rainy weather. Across our ablation experiments, $\lambda = 0.5$ gives the best detection *mAP*; see Section 4.3 for details.
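The regression loss of Equations (5) and (6) and the weighted sum of Equation (7) can be sketched as follows. The sketch assumes one box per image and sums the smooth L1 term over the four offsets $(t\_x, t\_y, t\_w, t\_h)$; both are illustrative assumptions, and the function names are hypothetical.

```python
from typing import Sequence


def smooth_l1(a: float) -> float:
    """Equation (6): quadratic near zero, linear for |a| >= 1."""
    return 0.5 * a * a if abs(a) < 1.0 else abs(a) - 0.5


def box_regression_loss(pred: Sequence[float], gt: Sequence[float]) -> float:
    """Smooth L1 summed over the offsets (t_x, t_y, t_w, t_h) of one box."""
    return sum(smooth_l1(t - t_star) for t, t_star in zip(pred, gt))


def total_loss(l_cls: float, l_reg: float, l_der: float, lam: float = 0.5) -> float:
    """Equation (7): L_cls + L_reg + lambda * L_der, with lambda = 0.5 by default."""
    return l_cls + l_reg + lam * l_der


# Usage: one box whose x-offset is off by 0.5 and height-offset by 2.0.
print(box_regression_loss([0.5, 0.0, 0.0, 2.0], [0.0, 0.0, 0.0, 0.0]))  # 1.625
```

Both branches of Equation (6) equal 0.5 at |*a*| = 1, so the loss is continuous, and its linear tail keeps gradients bounded for badly misplaced boxes.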

#### **4. Experiments**

#### *4.1. Experimental Setup*

**Implementation Details**. *JADAR* is trained end-to-end on the *Rain Vehicle Color*-24 image set with the *Adam* optimizer [44] to simultaneously learn image deraining, color classification, and object localization, using the *PyTorch* platform. All experiments are run on the *AutoDL* platform with a *Tesla P*40 GPU. The hyper-parameters *α* and *γ* of the classification loss function $L\_{cls}$ are set to 0.25 and 2, respectively. We divide *Rain Vehicle Color*-24 into training, validation, and testing sets at a ratio of 8:1:1. The batch size is 4, the number of epochs is 100, and the confidence threshold is 0.5. The learning rate is $10^{-4}$ for the first 50 epochs, $10^{-5}$ for the next 30 epochs, and $10^{-6}$ for the last 20 epochs.
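The step learning-rate schedule and the 8:1:1 split described above can be written down as a short sketch. The dataset size used in the test is hypothetical (the size of *Rain Vehicle Color*-24 is not stated here), and assigning rounding leftovers to the training set is an assumption.

```python
def learning_rate(epoch: int) -> float:
    """Step schedule from Section 4.1, with epochs counted from 0 to 99."""
    if epoch < 50:
        return 1e-4
    if epoch < 80:
        return 1e-5
    return 1e-6


def split_sizes(n_images: int, ratios=(8, 1, 1)):
    """Train/validation/test sizes for an 8:1:1 split; rounding leftovers go to training."""
    total = sum(ratios)
    n_val = n_images * ratios[1] // total
    n_test = n_images * ratios[2] // total
    return n_images - n_val - n_test, n_val, n_test


schedule = [learning_rate(e) for e in range(100)]
print(schedule.count(1e-4), schedule.count(1e-5), schedule.count(1e-6))  # 50 30 20
```

In PyTorch, one way to obtain the same schedule is `torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[50, 80], gamma=0.1)` with a base rate of 1e-4, stepped once per epoch.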

**Evaluation Metric**. Object detection is generally evaluated with *IoU* (Intersection over Union) [21], *Precision* [45], *Recall* [45], *AP* (Average Precision) [18], *mAP* (mean Average Precision) [41], or other metrics; since these concepts are well known, we list their formulas only briefly:

$$IoU = \frac{|A \cap B|}{|A \cup B|} \tag{8}$$

where *A* is the ground-truth (*GT*) bounding box of the object, *B* is the predicted bounding box, and the ratio is computed over box areas. The mathematical definitions of *Precision* and *Recall* are as follows:

$$precision = \frac{TP}{TP + FP} \tag{9}$$

$$recall = \frac{TP}{TP + FN} \tag{10}$$

where *TP* is true positives (correctly predicted as positive), *FP* is false positives (incorrectly predicted as positive), and *FN* is false negatives (failed to predict a positive).
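Equation (8) and the way IoU typically decides whether a detection counts as a *TP* in Equations (9) and (10) can be sketched as below. The corner box format (x1, y1, x2, y2) and the IoU threshold of 0.5 are assumptions based on common practice (e.g., PASCAL VOC); the paper's 0.5 in Section 4.1 is a confidence threshold, and its IoU matching threshold is not stated.

```python
def iou(box_a, box_b) -> float:
    """Equation (8) for axis-aligned boxes in corner format (x1, y1, x2, y2)."""
    ix1, iy1 = max(box_a[0], box_b[0]), max(box_a[1], box_b[1])
    ix2, iy2 = min(box_a[2], box_b[2]), min(box_a[3], box_b[3])
    inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)   # intersection area
    area_a = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1])
    area_b = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1])
    union = area_a + area_b - inter                      # union area
    return inter / union if union > 0 else 0.0


def is_true_positive(pred_box, gt_box, threshold: float = 0.5) -> bool:
    """Common convention: a prediction is a TP when its IoU with the GT box reaches the threshold."""
    return iou(pred_box, gt_box) >= threshold


print(iou((0, 0, 2, 2), (1, 1, 3, 3)))  # 1/7, about 0.143
```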

*AP* is calculated by

$$AP = \int\_0^1 p(r) dr\tag{11}$$

where *p* is *Precision*, and *r* is *Recall*.

The *mAP* (mean Average Precision) is the mean of the per-category *AP* values:

$$mAP = \frac{1}{N} \sum\_{j=1}^{N} AP\_j \tag{12}$$

where *N* is the number of categories.
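Equations (9)–(12) can be combined into a short sketch that computes *AP* for one color class from ranked detections and then averages over classes. The paper does not state how the integral in Equation (11) is discretized; this sketch assumes all-point interpolation (precision made non-increasing from right to left, then a step-wise area under the curve), which is one common variant.

```python
def average_precision(is_tp, num_gt) -> float:
    """AP for one color class (Equation (11)), all-point interpolation.

    is_tp:  detections of this class sorted by descending confidence;
            True if the detection matched a ground-truth box (TP), else FP
    num_gt: number of ground-truth boxes of this class (TP + FN)
    """
    if num_gt == 0:
        return 0.0
    recalls, precisions = [], []
    tp = fp = 0
    for hit in is_tp:
        if hit:
            tp += 1
        else:
            fp += 1
        precisions.append(tp / (tp + fp))   # Equation (9)
        recalls.append(tp / num_gt)         # Equation (10)
    # Make precision non-increasing from right to left.
    for k in range(len(precisions) - 2, -1, -1):
        precisions[k] = max(precisions[k], precisions[k + 1])
    # Step-wise area under the p(r) curve.
    ap, prev_r = 0.0, 0.0
    for r, p in zip(recalls, precisions):
        ap += (r - prev_r) * p
        prev_r = r
    return ap


def mean_average_precision(ap_per_class) -> float:
    """Equation (12): mean of the per-class AP over N categories."""
    return sum(ap_per_class) / len(ap_per_class)


print(average_precision([True, False, True], num_gt=2))  # about 0.833
```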

#### *4.2. Datasets*

4.2.1. Synthetic Dataset *Rain Vehicle Color*-24

Few datasets are available for vehicle color recognition under rainy weather conditions. All our models are trained on the enhanced *Rain Vehicle Color*-24 dataset [46], some examples of which are illustrated in Figure 3.

**Figure 3.** Examples from *Rain Vehicle Color*-24.

4.2.2. Real Rain Vehicle Datasets: *RID* and *RIS*

Li et al. collected two real rainy vehicle image datasets, *RID* and *RIS* [38], for testing object detection. *RID* consists of rainy images collected from in-vehicle cameras while driving on rainy days, and *RIS* consists of rainy surveillance images collected from networked traffic surveillance cameras in rainy weather. The two datasets differ in many aspects, including rainfall type, image quality, target size, and viewing angle, and they represent real-world application scenarios where deraining may be required. *RID* includes 2495 images, and its rain effect is closest to "raindrops" on the camera lens. *RIS* includes 2048 images, and its rain effect is closest to "rain and fog" (many cameras develop fog condensation when it rains, and lower resolutions also cause more fog-like effects) [47]. Because both datasets contain highly complex scenes and are therefore challenging, we use them for testing to better illustrate the effectiveness of our algorithm. Examples from these two datasets are given in Figure 4.

**Figure 4.** Examples of RID and RIS images [38].

#### *4.3. Ablation Study*

To determine the optimal design of our framework, we train four variants on the *Rain Vehicle Color*-24 dataset: *RetinaNet*, *JADAR*1, *JADAR*2, and *JADAR*. These models are trained and tested on *Rain Vehicle Color*-24 with the deraining-loss weight set to *λ* = 0, 0.1, 1.0, and 0.5, respectively. Figure 5 shows that, relative to the *RetinaNet* model, the testing *mAP* of the *JADAR*1, *JADAR*2, and *JADAR* models changes by +2.92%, −3.99%, and +4.3%, respectively, which indicates that joint semantic feature extraction with a suitably weighted deraining loss improves vehicle color recognition under rainy weather conditions. As Table 1 shows, when *λ* is 0.1, the rain removal module provides a weak assisting effect on vehicle color recognition in rain; when *λ* is 1.0, it has the opposite effect; and when *λ* is 0.5, *JADAR* performs best, so we choose this value in our method.

**Figure 5.** *AP*s of different models on the Rain Vehicle Color-24 test set. The *x*-axis represents the average precision, and the *y*-axis represents the color categories (24 categories in total).


**Table 1.** The mAPs using different values for the loss function coefficient *λ* of the deraining module in *JADAR* on the *RVC*-24 test set.

#### *4.4. Experiments and Analysis*

4.4.1. Results on Synthetic Datasets

In this section, our algorithm is compared with a vehicle color recognition method, target detection methods, two-stage methods combining rain removal with target detection, and transfer learning methods.

To evaluate vehicle color recognition performance, *JADAR* and *SMNN*-*MSFF* [18] are compared. Both are trained on the *Rain Vehicle Color*-24 training subset and tested on its test subset. The quantitative results in the second column of Table 2 show that the *mAP* of our method reaches 72.07%, which is 23.49% higher than that of *SMNN*-*MSFF*. The qualitative results in Figure 6 show that *JADAR* outperforms *SMNN*-*MSFF* under rainy conditions. For example, *JADAR* recognizes five vehicles, whereas *SMNN*-*MSFF* recognizes three, and a white vehicle is recognized by *JADAR* with a confidence score of 0.79 versus 0.62 for *SMNN*-*MSFF*.

To compare object detection performance, *JADAR*, *RetinaNet* [24], *Faster RCNN* [19], *SSD* [20], and *YOLO V*3 [21] are compared qualitatively and evaluated quantitatively by *mAP*. In our experiments, the loss function and settings (i.e., scale, anchor or default box, backbone network, classifier, etc.) of each compared method remain unchanged from the original work. All methods are trained on the *Rain Vehicle Color*-24 dataset and tested on its test set. The qualitative results of *JADAR*, *Faster RCNN*, *YOLO V*3, *SSD*, and *RetinaNet* for vehicle color recognition in rain are shown in Figures 7 and 8; as the figures show, *JADAR* outperforms the other models for fine vehicle color recognition. Quantitatively, Table 2 shows that *JADAR* is 11.42%, 22.19%, 5.74%, and 4.3% better than *Faster RCNN*, *YOLO V*3, *SSD*, and *RetinaNet*, respectively.

**Figure 6.** Test results of *JADAR* and *SMNN*-*MSFF* on the Rain Vehicle Color-24 test set. Each subtitle gives the detection method, with the corresponding confidence value in parentheses.

To compare the recognition performance of different joint methods, three state-of-the-art rain removal methods, *LPNet* [35], *PReNet* [48], and *RCDNet* [49], are first used to derain the images, after which *RetinaNet* recognizes the vehicle colors. These two-stage methods are denoted *LR*, *PR*, and *RR*, respectively. Figures 9 and 10 give qualitative comparisons of *JADAR* and the three two-stage methods for vehicle color recognition under rainy weather conditions, in which *JADAR* performs better than the other models. As Table 3 shows, *JADAR* scores 15.56%, 20.37%, and 2.06% higher than *LR*, *PR*, and *RR*, respectively.

To compare with transfer learning methods, two domain-adaptation methods, *Da-faster* [33] and *ATF* [50], are compared with *JADAR*. Here, *VC*-24 is the source domain and *Rain Vehicle Color*-24 is the target domain; both are used to train these algorithms. As the 5th and 6th columns of Table 3 show, our method is 25.95% and 9.14% better than *Da-faster* and *ATF*, respectively. The qualitative results in Figures 9 and 10 show that *JADAR* identifies more vehicles with higher confidence than these two methods.

**Figure 7.** Example 1 of test results of JADAR and object detection methods on the Rain Vehicle Color-24 test set.

**Figure 8.** Example 2 of test results of JADAR and object detection methods and domain adaptation methods on the Rain Vehicle Color-24 test set.

**Figure 9.** Example 1 of test results of JADAR and two-stage methods on the Rain Vehicle Color-24 test set.


**Table 2.** Comparison of recognition accuracy for 24 colors across different networks: SM, SMNN-MSFF; FR, Faster RCNN; Yolo, Yolo V3; RN, RetinaNet.

**Figure 10.** Example 2 of test results of JADAR and two-stage methods and domain-adaptation methods on the Rain Vehicle Color-24 test set.


**Table 3.** Comparison of recognition accuracy for 24 colors across different networks.

4.4.2. Results on Real Datasets

We train *JADAR*, *RetinaNet*, *Faster RCNN*, *SSD*, *YOLO V*3, *LR*, *PR*, *RR*, *Da-faster*, and *ATF* on *Rain Vehicle Color*-24 and test them on the real rainy vehicle datasets *RID* and *RIS*. The test results, shown in Figures 11–14, indicate that *JADAR* generally performs better than the other methods on these real datasets. In Figure 11, *JADAR* and *SSD* correctly identify the two cars in the picture; *Yolo V*3 also detects both cars but misidentifies the black one as silver-gray, while the other three algorithms detect hardly any vehicles. In Figure 12, *Faster RCNN* and *SSD* perform better than the other methods, which reveals a limitation of *JADAR* in recognizing small targets. In Figure 13, all algorithms identify the color of the vehicle, but with different confidence values; *ATF* has the highest confidence for the blue vehicle, at 0.98. However, Figure 14 shows that only *JADAR* and *ATF* identify a certain white vehicle.

**Figure 11.** Test results of JADAR and color recognition and object detection methods on the RID.

**Figure 12.** Test results of JADAR and color recognition and object detection methods on the RIS.

**Figure 13.** Test results of JADAR and two-stage and domain-adaptation methods on the RID.

**Figure 14.** Test results of *JADAR* and two-stage and domain-adaptation methods on the *RIS*.

#### *4.5. Inference Time*

To compare the test time of all methods, every network model is run on a test subset of 1920 × 1080 images. The test times are shown in Table 4. *JADAR* takes 1.7 s per image on a single Tesla *P*40 GPU, the same as *RetinaNet*, and is 21.8, 1.1, 4.4, 0.8, and 0.9 s faster than *LR*, *PR*, *RR*, *Da-faster*, and *ATF*, respectively. Therefore, although *JADAR* has one more decoder module than *RetinaNet*, it maintains the detection speed of *RetinaNet*.
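As a quick arithmetic check, the per-image times of the other methods can be recovered from the differences stated above. Table 4 itself is not reproduced in this text, so the values below are only what those stated differences imply, not figures read from the table.

```python
jadar_seconds = 1.7  # per 1920 x 1080 image on one Tesla P40 (Section 4.5)
slower_by = {"LR": 21.8, "PR": 1.1, "RR": 4.4, "Da-faster": 0.8, "ATF": 0.9}

# Implied per-image time of each method = JADAR's time + the stated gap.
implied = {name: round(jadar_seconds + gap, 1) for name, gap in slower_by.items()}
print(implied)
print(round(1 / jadar_seconds, 2), "images per second for JADAR")
```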

**Table 4.** Comparison of different network recognition speeds (GPU).


#### **5. Conclusions**

In this paper, we study vehicle color recognition under rainy conditions and propose a joint semantic learning method, *JADAR*, which embeds *UNet*-3 into *RetinaNet*. The *UNet* module removes rain from the image and restores the clean background. The recovered background image and the rainy image are fed together into the *class* + *bbox* sub-module of the *RetinaNet* network to extract joint semantic vehicle color feature maps. *JADAR* outperforms other methods for fine vehicle color recognition under rainy as well as normal conditions. Extensive experimental results show that the *mAP* of the proposed method reaches 72.07% in identifying 24 colors under rainy conditions. Because our algorithm is trained on synthetic datasets, its generalization is not guaranteed. In future work, we plan to use semi-supervised or few-shot learning to further improve the generalization and practicality of the algorithm. As a further research topic, one could fuse overlap functions and fuzzy (rough) sets (see [51–55]) into the method of this paper.

**Author Contributions:** Writing—original draft preparation, M.H.; Experiments and editing, Y.W.; review, J.F., writing—review, B.J. All authors have read and agreed to the published version of the manuscript.

**Funding:** This study was funded by the National Natural Science Foundation of China (no. 62071378), the Shaanxi Province International Science and Technology Cooperation Program (no. 2022KW-04), and the Xi'an Science and Technology Plan Project (no. 21XJZZ0072).

**Data Availability Statement:** The data that support the findings of this study are openly available at humingdi2005@github.com.

**Conflicts of Interest:** The authors declare that they have no conflict of interest.

#### **References**


### *Article* **Resolving Cross-Site Scripting Attacks through Fusion Verification and Machine Learning**

**Jiazhong Lu †, Zhitan Wei †, Zhi Qin, Yan Chang and Shibin Zhang \***

**\*** Correspondence: cuitzsb@cuit.edu.cn

† These authors contributed equally to this work as co-first authors.

**Abstract:** Frequent variations in XSS (cross-site scripting) payloads make it difficult for static and dynamic analysis to detect attacks effectively. In this paper, we propose a fusion verification method that combines traffic detection with XSS payload detection, using machine learning to detect XSS attacks. We also propose seven new payload features to improve detection efficiency. To verify the effectiveness of our method, we simulated and tested 20 public CVE (Common Vulnerabilities and Exposures) XSS attacks. The experimental results show that our method is more accurate than the single traffic detection model: the recall rate increased by an average of 48%, the F1 score increased by an average of 27.94%, the accuracy rate increased by 9.29%, and the accuracy rate increased by 3.81%. Moreover, the seven new features proposed in this paper account for 34.12% of the classifier's total contribution rate.

**Keywords:** XSS attack; traffic detection; payloads; fusion verification

**MSC:** 68T09

**Citation:** Lu, J.; Wei, Z.; Qin, Z.; Chang, Y.; Zhang, S. Resolving Cross-Site Scripting Attacks through Fusion Verification and Machine Learning. *Mathematics* **2022**, *10*, 3787. https://doi.org/10.3390/math10203787

Academic Editors: Jianping Gou, Weihua Ou, Shaoning Zeng and Lan Du

Received: 9 September 2022 Accepted: 8 October 2022 Published: 14 October 2022


**Copyright:** © 2022 by the authors. Licensee MDPI, Basel, Switzerland. This article is an open access article distributed under the terms and conditions of the Creative Commons Attribution (CC BY) license (https://creativecommons.org/licenses/by/4.0/).

#### **1. Introduction**

XSS (cross-site scripting) attacks have caused enormous damage to the economy and to individual privacy [1]. Moreover, in the newly released 2021 edition of the OWASP (Open Web Application Security Project) Top 10, XSS moved from seventh to third place [2], as part of the Injection category.

Normally, there are three types of XSS attack: reflected, stored, and DOM-based. These attack types usually use the GET or POST methods of the HTTP protocol to inject malicious code into the URL or the POST body. Reflected XSS usually injects malicious code into the URL; the code is triggered only in the current browser and is not stored permanently. In stored XSS, the malicious code is injected into the server-side database through vulnerabilities, which can cause long-term information leakage and other hazards. DOM-based XSS can be regarded as a special kind of reflected XSS: its malicious code is triggered only in the current browser, when client-side script runs to render the front-end page.
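The reflected case can be made concrete with a small sketch of a hypothetical server-side handler that echoes a query parameter back into HTML. It only illustrates the attack mechanism and one standard mitigation (HTML-escaping with Python's `html.escape`); it is not the detection method of this paper, and because DOM-based XSS executes in client-side script, server-side escaping alone does not address that type.

```python
import html
from urllib.parse import parse_qs, urlparse


def query_param(url: str, name: str) -> str:
    """Return the first (percent-decoded) value of a query-string parameter, or ''."""
    return parse_qs(urlparse(url).query).get(name, [""])[0]


def render_results_unsafe(url: str) -> str:
    """Hypothetical handler that reflects the 'q' parameter verbatim: reflected XSS."""
    return f"<p>Results for {query_param(url, 'q')}</p>"


def render_results_safe(url: str) -> str:
    """Same handler, but HTML-escapes the reflected value before writing it out."""
    return f"<p>Results for {html.escape(query_param(url, 'q'))}</p>"


attack = "http://example.com/search?q=%3Cscript%3Ealert(1)%3C%2Fscript%3E"
print(render_results_unsafe(attack))  # <p>Results for <script>alert(1)</script></p>
print(render_results_safe(attack))    # <p>Results for &lt;script&gt;alert(1)&lt;/script&gt;</p>
```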

In general, there are two popular methods to defend against XSS attacks: static analysis and dynamic analysis. Static analysis finds vulnerabilities by scanning the source code and analyzing its lexical and grammatical structure, control flow, data flow, and other information. It is applied during the development and coding phase and requires developers to master a great deal of security-related knowledge. Dynamic analysis feeds test data to the program during execution and analyzes the output to determine whether there are vulnerabilities. However, this method relies on the completeness of the test data.

In the face of frequent variations in XSS payloads, it is hard for traditional XSS detection to achieve satisfactory results. Several factors have a significant impact on the results; for example, traditional XSS detection requires extensive manual effort and depends on the completeness of the attack vectors.

School of Cybersecurity, Chengdu University of Information Technology, Chengdu 610225, China

Recently, machine learning techniques have been widely used in XSS attack detection with good results. However, most machine-learning-based detection approaches focus on only one of traffic or XSS payloads. On the one hand, traffic detection is timely, but it has difficulty detecting and identifying XSS attacks accurately. On the other hand, XSS payload detection is relatively accurate, but it lacks timeliness. Another reason for this single focus may be that there is currently no public dataset that includes both normal traffic and XSS attack traffic (in which the only type of attack is XSS).

As a result, many XSS detection methods cannot meet the dual requirements of timeliness and accuracy in real environments, and the strengths and weaknesses of a single model directly affect the performance of the entire detection model. This leads to low accuracy and a high false negative rate for single-model detection.
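Payload detection of the kind discussed above usually starts by turning a raw payload into numeric features for a classifier. The sketch below counts a few suspicious lexical patterns after URL-decoding. These features are hypothetical illustrations chosen for this example and are NOT the seven features proposed in this paper; the resulting vectors could then be fed to any standard classifier, such as scikit-learn's `SVC`.

```python
import re
from urllib.parse import unquote

# Hypothetical lexical features; these are NOT the seven features proposed in this paper.
PATTERNS = {
    "script_tags": re.compile(r"<\s*script", re.IGNORECASE),
    "event_handlers": re.compile(r"\bon[a-z]+\s*=", re.IGNORECASE),
    "js_schemes": re.compile(r"javascript\s*:", re.IGNORECASE),
    "angle_brackets": re.compile(r"[<>]"),
}


def payload_features(raw: str) -> dict:
    """Count suspicious patterns in a URL-decoded payload."""
    text = unquote(raw)  # decode %XX escapes once; nested encodings are not handled
    feats = {name: len(p.findall(text)) for name, p in PATTERNS.items()}
    feats["length"] = len(text)
    return feats


print(payload_features("%3Cimg%20src%3Dx%20onerror%3Dalert(1)%3E"))
```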

The primary contribution of this paper is to propose a fusion verification method that combines traffic detection and XSS payload detection. Previously, both traffic detection and XSS payload detection have been separately applied to XSS detection. However, to the best of our knowledge, fusion verification methods combining the two methods have not been reported in the literature for detecting XSS attacks. The main contributions of this paper are:


#### **2. Related Work**

For the study of XSS attack detection, network security researchers have successively put forward some effective detection methods and preventive measures.

In terms of static analysis, Medeiros et al. [3] proposed a cross-site scripting vulnerability detection method combining static source code analysis and data mining in 2015. This method improved the accuracy of XSS vulnerability detection and the effect of code fixing, but it produced false positives. Choi et al. [4] proposed an HXD (Hybrid XSS Detection) system, which used both static string analysis and dynamic browser rendering with a black-box detection approach. Experimental results showed that HXD had a low false positive rate. Mohammadi et al. [5] detected XSS vulnerabilities through automatic unit testing. They automatically constructed an XSS vulnerability unit test for each web page; test inputs were then generated automatically by a grammar-based attack generator and evaluated. The proposed method reduced the error rate in detecting XSS vulnerabilities. In 2019, Yan et al. [6] proposed a PHP code vulnerability detection method based on sensitive paths and taint analysis. The method first converted the back-end code of the web application into an intermediate representation, such as an abstract syntax tree, then located the sinks (dangerous functions), determined the sensitive paths through these sinks, and finally performed taint analysis on each path to determine whether a vulnerability exists. However, the disadvantages of static analysis are obvious: it relies on a lot of manual work by human experts with knowledge of both programming and security, and the source code is often not open-source.
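The idea of taint analysis along a sensitive path can be illustrated with a toy sketch: user input from a source marks variables as tainted, taint propagates through assignments, a sanitizer clears it, and any dangerous function (sink) that receives tainted data is reported. The statement format below is a hypothetical intermediate representation invented for this example; it is not the representation or tool used by Yan et al.

```python
from typing import List, Tuple

# A statement is (kind, target, args); the format is a hypothetical toy IR.
#   ("source", var, [])        var receives user input, e.g. $_GET['q']
#   ("assign", var, [a, b])    var is computed from a, b
#   ("sanitize", var, [a])     var = sanitizer(a), e.g. htmlspecialchars
#   ("sink", func, [a])        dangerous function called on a, e.g. echo
Stmt = Tuple[str, str, List[str]]


def tainted_sinks(path: List[Stmt]) -> List[int]:
    """Return indices of sink calls that receive tainted data along one path."""
    tainted = set()
    hits = []
    for i, (kind, target, args) in enumerate(path):
        if kind == "source":
            tainted.add(target)
        elif kind == "assign":
            # Taint propagates if any operand is tainted; otherwise the write clears it.
            if any(a in tainted for a in args):
                tainted.add(target)
            else:
                tainted.discard(target)
        elif kind == "sanitize":
            tainted.discard(target)
        elif kind == "sink" and any(a in tainted for a in args):
            hits.append(i)
    return hits


path = [
    ("source", "q", []),
    ("assign", "msg", ["q"]),
    ("sink", "echo", ["msg"]),    # tainted data reaches the sink: reported
    ("sanitize", "safe", ["q"]),
    ("sink", "echo", ["safe"]),   # sanitized value: not reported
]
print(tainted_sinks(path))  # [2]
```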

In terms of dynamic analysis, Parameshwaran et al. [7] designed a DOM-based XSS testing platform based on taint analysis in 2015. The platform included a vulnerability generator and a detection engine, and experiments showed that it was highly effective at detecting DOM-based XSS attacks. Wang et al. [8] proposed the TT-XSS framework, which detects DOM-based XSS using dynamic taint analysis. The collected URLs were dynamically analyzed and sent to a taint tracking analysis module; the resulting taint trajectories were sent to an automatic vulnerability verification module, which performed verification by generating attack vectors from those trajectories. In 2021, Khalaf et al. [9] proposed an algorithm that detected and prevented attacks using an input validation mechanism. This approach supported web security testing by providing an easy-to-use and accurate vulnerability prediction model and validation method, with a very low false positive rate. However, the method relied on the completeness of the testing dataset: if the testing dataset was incomplete or faced deformation attacks, the false negative rate would be high. This limitation is common to dynamic analysis in general.

In recent years, zero-day attacks and deformation attacks have become common, and traditional static and dynamic analysis struggle to detect XSS effectively. Therefore, many researchers have introduced machine learning into XSS detection and achieved good results. Zuhair et al. [10] extracted features from web pages and URLs, divided them into mixed feature subsets in combination with phishing attacks, and finally used the SVM algorithm for training and testing. In 2017, Rathore et al. [11] proposed a machine learning method based on URLs, web pages, and SNSs to detect XSS attacks; they extracted twenty-five XSS attack features and used ten classifiers for detection. To achieve better performance, Hosseini et al. [12] proposed a model for detecting malicious crawler behavior using machine learning and compared several algorithms, such as Bayesian networks, SVM, and decision trees. They found that the SVM-based model achieved higher detection accuracy for malicious crawlers than the other models and that extracting effective features improved detection accuracy. In 2021, Hu et al. [13] designed and implemented an XSS attack detection model for web applications. The model added a verification code (CAPTCHA) recognition function so that it could submit data to forms protected by a verification code, and it achieved a low false positive rate. Malviya et al. [14] developed a web browser that uses machine learning classification to mitigate XSS attacks; experimental results showed that their method outperformed other methods in classification accuracy, recall, precision, and F1-score. Mokbal et al. [15] proposed a novel XSS attack detection framework for web applications based on ensemble learning, which used the XGBoost (Extreme Gradient Boosting) algorithm together with a parameter optimization method. The framework passed multiple tests on the testing dataset, with an accuracy of up to 99.59%. Soltani et al. [16] proposed a framework for a DID (Deep Intrusion Detection) system and deployed and evaluated an offline IDS (Intrusion Detection System) following this framework. Experiments showed that its precision and recall reached 0.992 and 0.998, respectively. In addition, the shortage of high-quality data has always been a key problem in machine learning. Multi-fidelity classification algorithms [17–19] address this problem by incorporating information from other sources that can be obtained at low cost while remaining well correlated with the target data. Such algorithms could also be applied to XSS attack detection models in the future to improve their generalization ability.

Our previous work [20] detected XSS attacks more accurately by using machine learning to jointly analyze traffic and logs, and it could trace the progress of an XSS attack across the entire network; however, it required collecting a large number of network device logs for analysis.

To sum up, the current XSS attack detection approaches still have the following problems:


Therefore, this paper focuses on developing a fusion verification method. We obtain a real-world experimental dataset by reproducing XSS vulnerabilities listed in CVE (Common Vulnerabilities and Exposures) and capturing network traffic on the web server side. We then combine traffic detection with XSS payload detection to form a fusion verification method that defends against XSS attacks. This method combines the timeliness of traffic detection with the accuracy of payload detection. We expect it to improve the performance of detection models and to address the shortcomings that make existing solutions difficult to apply in practice.

#### **3. Proposed Methodology**

Figure 1 shows the overall framework of the XSS detection method proposed in this paper. First, the original dataset of the experiment is obtained by reproducing the CVE vulnerabilities. We then use the rdpcap function of the Scapy library in Python to read the pcap files of the original dataset and group the data by the upstream and downstream directions of each bidirectional communication. The framework is divided into two detection modules: traffic detection and XSS payload detection. The two modules extract the traffic dataset and the payload dataset, respectively, perform preprocessing and feature extraction to produce input in a format that a machine learning model can use, and send it to the classifier for detection. Because of the nature of network traffic, each flow in the pcap file corresponds to multiple payloads, so each result of the traffic detection module may correspond to several results of the payload detection module. We therefore combine the results of the two modules by matching the source port feature (src\_port) shared by both modules. Finally, the final detection result is obtained through fusion verification of the two detection models, which improves the detection performance of the overall model.
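The grouping step can be sketched in plain Python. This is a minimal sketch, not the authors' code: it assumes the relevant header fields have already been read from the packets that Scapy's rdpcap returns, models them with a hypothetical Packet record (IP addresses omitted, as in the paper), assumes the web server listens on port 80, and keys each bidirectional flow by the client-side port (the src\_port later used to match the two modules), which presumes client ports do not collide within one capture.

```python
from collections import defaultdict
from typing import Dict, List, NamedTuple


class Packet(NamedTuple):
    """Hypothetical simplified packet record (fields taken from the TCP layer)."""
    src_port: int
    dst_port: int
    size: int
    payload: bytes


def group_bidirectional(packets: List[Packet], server_port: int = 80) -> Dict[int, Dict[str, List[Packet]]]:
    """Group packets into bidirectional flows keyed by the client-side port.

    Upstream (client -> server) and downstream (server -> client) packets of
    the same connection are collected under one key.
    """
    flows: Dict[int, Dict[str, List[Packet]]] = defaultdict(lambda: {"up": [], "down": []})
    for pkt in packets:
        if pkt.dst_port == server_port:      # upstream: client -> server
            flows[pkt.src_port]["up"].append(pkt)
        elif pkt.src_port == server_port:    # downstream: server -> client
            flows[pkt.dst_port]["down"].append(pkt)
    return dict(flows)


packets = [
    Packet(51000, 80, 420, b"GET /?s=<script>alert(1)</script> HTTP/1.1"),
    Packet(80, 51000, 1500, b"HTTP/1.1 200 OK"),
    Packet(52000, 80, 310, b"GET / HTTP/1.1"),
]
flows = group_bidirectional(packets)
for port, flow in flows.items():
    print(port, len(flow["up"]), len(flow["down"]))
```

Keying on the client port keeps the flow identifier identical to the src\_port value that appears in each extracted payload, which is what makes the later one-to-many join possible.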

**Figure 1.** The framework of the proposed method.

#### *3.1. CVE Vulnerability Set*

This paper targets WordPress [21], a widely used content management system (43.0% of websites worldwide use WordPress). From the NVD [22] (National Vulnerability Database), we selected 10 recent XSS vulnerabilities each for reflected XSS and stored XSS. The specific CVE list is shown in Table 1. The original dataset was then formed by reproducing these vulnerabilities locally and using Wireshark [23] to capture the traffic packets of the reproduction process; the dataset is stored in pcap format.


#### *3.2. Traffic Features Extraction*

Our analysis shows that XSS attack traffic differs from normal traffic. XSS attack traffic must load not only normal web pages but also malicious js files or external malicious links, which consumes extra network and system resources. Thus, the packets of XSS attack traffic are generally larger.

A total of 1947 flows were extracted for analysis in this paper. Two types of features are used for learning, traffic-related features and time-related features, which help the classifier distinguish between normal traffic and XSS attack traffic. The traffic-related features include the five-tuple features of the communication process, as shown in Table 2; the source and destination IP addresses are omitted because of the special format of IP addresses. The captured traffic is intended to reflect a real network environment and real traffic features.
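The two feature groups can be illustrated for a single flow. In this sketch the feature names (pkt\_count, total\_bytes, duration, and so on) are hypothetical stand-ins for the entries of Table 2, each packet is reduced to a (timestamp, src\_port, dst\_port, size) tuple, and the protocol number defaults to 6 (TCP); it shows the kind of computation involved, not the paper's exact feature set.

```python
from statistics import mean
from typing import Dict, List, Tuple

# (timestamp in seconds, src_port, dst_port, size in bytes); first packet from the client
FlowPacket = Tuple[float, int, int, int]


def flow_features(packets: List[FlowPacket], protocol: int = 6) -> Dict[str, float]:
    """Illustrative traffic- and time-related features for one time-ordered flow.

    The feature names are hypothetical stand-ins for the entries of Table 2.
    """
    times = [p[0] for p in packets]
    sizes = [p[3] for p in packets]
    gaps = [later - earlier for earlier, later in zip(times, times[1:])]
    return {
        # traffic-related: five-tuple without the source/destination IPs
        "src_port": packets[0][1],
        "dst_port": packets[0][2],
        "protocol": protocol,            # 6 = TCP
        "pkt_count": len(packets),
        "total_bytes": sum(sizes),
        "mean_pkt_size": mean(sizes),
        # time-related
        "duration": times[-1] - times[0],
        "mean_inter_arrival": mean(gaps) if gaps else 0.0,
    }


flow = [(0.00, 51000, 80, 420), (0.05, 80, 51000, 1500), (0.15, 51000, 80, 60)]
print(flow_features(flow))
```

Size-based entries such as total\_bytes capture the observation above that XSS attack traffic tends to carry larger packets.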



#### *3.3. Payload Features Extraction*

In this section, after an in-depth study of XSS attack methods and their causes, we summarize three representative attack methods from the attacker's point of view. We then extract seven attribute features of the payloads based on these attack methods.

#### 3.3.1. XSS Attack Methods


**Table 3.** Some on-events and descriptions.


Table 4 shows examples of three XSS attack methods:

#### **Table 4.** XSS attack examples.


#### 3.3.2. Attribute Features

Experienced attackers usually change the encoding or capitalization of malicious code to carry out deformation attacks. Therefore, we preprocess the extracted statements to restore them to their original form. The preprocessing includes lowercase conversion, URL decoding, HTML decoding, JavaScript decoding, ASCII decoding, Unicode decoding, and a second round of URL decoding. Feature values are then extracted from the processed dataset according to the features proposed in this paper.
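The decoding chain can be sketched with the Python standard library. This is an illustrative sketch rather than the authors' implementation: the order follows the list above, JavaScript decoding is interpreted here as \xNN escapes and Unicode decoding as \uNNNN escapes, HTML numeric character references are handled by html.unescape, and no separate ASCII decoding step is modeled.

```python
import html
import re
from urllib.parse import unquote

_JS_HEX = re.compile(r"\\x([0-9a-f]{2})")   # JavaScript escapes such as \x3c
_JS_UNI = re.compile(r"\\u([0-9a-f]{4})")   # Unicode escapes such as \u003c


def _decode_escapes(pattern: re.Pattern, text: str) -> str:
    """Replace every hex escape matched by `pattern` with its character."""
    return pattern.sub(lambda m: chr(int(m.group(1), 16)), text)


def normalize_payload(text: str) -> str:
    """Undo common encodings before feature extraction (illustrative order)."""
    s = text.lower()                   # lowercase conversion
    s = unquote(s)                     # URL decoding: %3c -> <
    s = html.unescape(s)               # HTML decoding: &lt; -> <
    s = _decode_escapes(_JS_HEX, s)    # JavaScript \xNN decoding
    s = _decode_escapes(_JS_UNI, s)    # Unicode \uNNNN decoding
    s = unquote(s)                     # second round of URL decoding
    return s.lower()                   # decoded characters may be upper case


print(normalize_payload("%3CScRiPt%3Ealert(1)%3C/sCrIpT%3E"))  # <script>alert(1)</script>
```

The second unquote pass is what recovers double-encoded input such as %253Cscript%253E, and the final lower() catches letters that only appear after decoding.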

Through extensive research on XSS attack methods and analysis of their lexical features, we found that text characters common in malicious code are often combined with certain fixed symbols. Therefore, matching the combined form can reduce the false positive rate compared with matching the text characters alone. The following seven attribute features are summarized:

(1) HTML\_Tags

HTML tags typically appear more frequently in XSS attacks than in the text payloads of normal traffic. An HTML tag starts with a left angle bracket; for example, <script, <iframe, and <img in Table 5 consist of a left angle bracket followed by the characters script, iframe, and img. Therefore, the combination of a left angle bracket and a tag name is treated as one class of features.

**Table 5.** Seven new attribute features.


#### (2) JavaScript

The JavaScript pseudo-protocol is usually combined with HTML tags to form malicious code, such as <iframe src="javascript:alert('xss')">, where the code feature that will always appear is "javascript:".

(3) On\_Event

HTML5 allows browsers to trigger scripts through various events. For example, in the malicious code "<img src=# onerror="alert(document.cookie)">", the attacker deliberately sets an invalid src attribute on the img tag and then uses the onerror event (which runs a script when an error occurs) to trigger the malicious script. As a result, the alert function is executed and the cookie is exposed. The feature of the event attribute is the on-event name followed by an equal sign, such as "onerror=".

(4) Function\_Body

Attackers can use certain "dangerous functions" in JavaScript to steal sensitive information. For example, the "alert()" function is often used to pop up a dialog box, and an attacker can combine it with the document object to access cookies. The code feature of the JavaScript function body is "alert()", which is clearly distinguishable from ordinary text.

(5) Document\_Object

The document object is the root node of the HTML document. An attacker can use the "document.write" method to write JavaScript code into the document or use "document.cookie" to return all cookies associated with the current document. Its code feature is "document.".

(6) Third\_Party\_Links

In order to better conceal cross-site scripting attacks, experienced attackers will build an XSS attack server to receive and store the stolen sensitive information. As a result, there will be third-party links in the attack traffic, which are mostly characterized by a combination of src or href and third-party links.

(7) Delimiter

Because of the grammar of HTML, delimiters such as spaces are unavoidable within HTML tags. Therefore, attackers must use delimiters to construct attack statements when exploiting cross-site scripting vulnerabilities. The space, "/", and "+" characters are known to be used as delimiters in malicious code.
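A minimal regex-based sketch of the seven attribute features follows, assuming the payload has already been normalized as described in Section 3.3.2. The keyword lists (which tag names and which function names count) are illustrative assumptions rather than the contents of Table 5, and each feature is counted as the number of matches.

```python
import re
from typing import Dict

# Illustrative patterns for the seven attribute features; the paper's exact
# keyword lists (Table 5) may differ. Input is assumed to be normalized
# (lower case, decoded).
FEATURE_PATTERNS: Dict[str, re.Pattern] = {
    "HTML_Tags": re.compile(r"<(?:script|iframe|img|svg|body|input|a)\b"),
    "JavaScript": re.compile(r"javascript:"),
    "On_Event": re.compile(r"\bon[a-z]+\s*="),
    "Function_Body": re.compile(r"\b(?:alert|prompt|confirm|eval)\s*\("),
    "Document_Object": re.compile(r"\bdocument\."),
    "Third_Party_Links": re.compile(r"\b(?:src|href)\s*=\s*[\"']?(?:https?:)?//"),
    "Delimiter": re.compile(r"[ /+]"),
}


def attribute_features(payload: str) -> Dict[str, int]:
    """Count how often each combined pattern occurs in a normalized payload."""
    return {name: len(p.findall(payload)) for name, p in FEATURE_PATTERNS.items()}


print(attribute_features('<img src=# onerror="alert(document.cookie)">'))
```

Each pattern pairs a keyword with its fixed symbol (for example "on" + name + "=", or a function name + "("), so an ordinary word such as "onion" or "alert" in plain text does not match on its own.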

In this paper, we add 7 new features to the 30 features extracted by Zhou and Wang (2019) [1], giving a total of 37 attribute features for the XSS payloads. The seven new attribute features are listed in Table 5.

#### *3.4. Fusion Verification*

Both the traffic detection module and the XSS payload detection module output a binary decision (malicious or normal) for the current flow or payload. In this paper, we adopt a fusion verification method: if either of the two detection modules declares that the current sample is malicious, it is considered malicious; only if both declare it normal is it considered normal.

In this paper, Boolean variables *Fv* and *Pv* are used to represent the detection results of the traffic detection module and the detection results of the XSS payload detection module, respectively. The Boolean variable *Rs* is used to represent the final result of fusion verification, and its calculation formula is as follows:

$$Rs = Fv \lor Pv \tag{1}$$

Formula (1) covers four cases in total. The final result is normal only when both the traffic detection and the payload detection determine that the sample is normal; in the other three cases, the result is judged to be malicious.
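A minimal sketch of the fusion step, assuming each module's output has already been reduced to a Boolean per source port (True = malicious). Because one flow can carry several payloads, the sketch treats a flow as payload-positive if any of its payloads is flagged; this aggregation rule and the names fuse, flow\_results, and payload\_results are assumptions, not details given in the paper.

```python
from itertools import product
from typing import Dict, Iterable, Tuple


def fuse(flow_results: Dict[int, bool], payload_results: Iterable[Tuple[int, bool]]) -> Dict[int, bool]:
    """Fusion verification Rs = Fv OR Pv, matched on the source port.

    flow_results:    {src_port: Fv} from the traffic detection module.
    payload_results: (src_port, Pv) pairs from the payload detection module;
                     a flow may contribute several payloads.
    True means malicious. A flow counts as Pv = True if any of its payloads
    is flagged (an assumption); flows without payloads fall back to Fv.
    """
    pv_by_port: Dict[int, bool] = {}
    for port, pv in payload_results:
        pv_by_port[port] = pv_by_port.get(port, False) or pv
    return {port: fv or pv_by_port.get(port, False) for port, fv in flow_results.items()}


# The four cases of Formula (1): the result is normal only if both are normal.
for fv, pv in product([False, True], repeat=2):
    print(fv, pv, fuse({1: fv}, [(1, pv)])[1])

flows = {51000: False, 52000: True, 53000: False}
payloads = [(51000, False), (51000, True), (53000, False)]
print(fuse(flows, payloads))   # {51000: True, 52000: True, 53000: False}
```

Using OR rather than AND is what pushes false negatives down at the cost of more false positives, the trade-off noted in the conclusions.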

#### *3.5. Random Forest*

Many studies have shown that ensemble methods perform well in classification and are robust against overfitting, which makes them very popular in machine learning. In this paper, the random forest algorithm is used as the classification technique. Random forest is an ensemble algorithm based on decision trees that scales well and is easy to use. Its principle is to build a strong model with better generalization and less overfitting by averaging multiple decision trees, each of which has high variance.

Random forest requires relatively little hyperparameter tuning, and pruning is usually unnecessary because averaging many trees makes the ensemble robust to noise in any single decision tree. In this experiment, the size, n, of each bootstrap sample is set to the size of the training dataset to obtain a good bias-variance tradeoff. The number of features, d, considered at each split is set to a value smaller than the total number of features in the training dataset. We use the random forest classifier implemented in scikit-learn with reasonable parameter settings; its default is *d* = √*m*, where *m* is the total number of features in the training dataset. We choose entropy as the criterion for splitting nodes and set the number of decision trees, n\_estimators, to 100, because the accuracy of the model no longer increases once n\_estimators reaches 100. Finally, we set the number of parallel jobs, n\_jobs, to 10 to exploit multi-core parallel computation.
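The settings described above map onto scikit-learn's RandomForestClassifier. The sketch below uses synthetic data from make\_classification in place of the real feature matrices (an assumption); max\_samples=None makes each bootstrap sample as large as the training set, and random\_state is added only for reproducibility.

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier

# Synthetic stand-in data with 37 features (the number of payload features).
X, y = make_classification(n_samples=500, n_features=37, random_state=0)

clf = RandomForestClassifier(
    n_estimators=100,      # 100 trees; no accuracy gain was observed beyond this
    criterion="entropy",   # split quality measured by information gain
    max_features="sqrt",   # d = sqrt(m) features considered at each split
    bootstrap=True,        # each tree is trained on a bootstrap sample
    max_samples=None,      # bootstrap sample size n = training-set size
    n_jobs=10,             # parallel jobs across CPU cores
    random_state=0,
)
clf.fit(X, y)
print(clf.score(X, y))
```

The training-set score printed here is only a sanity check on synthetic data; the paper's evaluation uses the cross-validation protocol of Section 4.2.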

#### **4. Experiments and Discussions**

#### *4.1. Experimental Dataset*

This paper forms a traffic dataset containing normal traffic and XSS attack traffic by reproducing the CVE vulnerabilities. This dataset, called "CVE traffic", contains 1747 normal flows and 200 XSS attack flows. We then used Scapy's rdpcap function to extract the XSS payload dataset, referred to as "CVE payloads", which contains 10,083 normal records and 231 XSS payloads.

XSS payloads [24] collected from GitHub are used as the training dataset, with a total of 151,658 records, including 135,507 normal records and 16,151 XSS payloads. The testing dataset is extracted from the traffic dataset above using the rdpcap function of the Scapy library.

Details of the experimental datasets are shown in Table 6.


**Table 6.** Experimental dataset.

#### *4.2. Experimental Results*

This experiment uses twenty-fold cross-validation to assess the performance of the model. In each fold, 19 of the 20 CVE traffic datasets are used for training and the remaining one is used for testing, so that each of the 20 subsets is used exactly once as the test dataset. The cross-validation process is repeated 20 times, and the average of the 20 results for each CVE is taken as the result of this experiment. We then apply the fusion verification method described in Section 3.4 to the traffic detection and XSS payload detection results of each CVE and take the average of the 20 fused results as the final result of our method.
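Read as holding out one CVE per fold, the protocol above is leave-one-group-out cross-validation with the 20 CVEs as groups. The sketch below is that interpretation, with hypothetical CVE labels; it only generates the index splits. scikit-learn's LeaveOneGroupOut class provides the same splitting.

```python
from typing import Iterator, List, Sequence, Tuple


def leave_one_cve_out(cve_of_sample: Sequence[str]) -> Iterator[Tuple[str, List[int], List[int]]]:
    """Yield (held_out_cve, train_indices, test_indices), one fold per CVE."""
    for held_out in sorted(set(cve_of_sample)):
        train = [i for i, cve in enumerate(cve_of_sample) if cve != held_out]
        test = [i for i, cve in enumerate(cve_of_sample) if cve == held_out]
        yield held_out, train, test


# Hypothetical labels: which CVE reproduction each flow was captured from.
cve_of_sample = ["cve_a", "cve_a", "cve_b", "cve_c", "cve_c", "cve_c"]
for cve, train, test in leave_one_cve_out(cve_of_sample):
    print(cve, train, test)
```

With 20 distinct labels this yields 20 folds in which every CVE is tested exactly once and never appears in its own training split.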

This experiment aimed to solve a typical binary classification problem. As shown in Table 7, we use a confusion matrix to represent the results.

#### **Table 7.** Confusion matrix.


The confusion matrix has four categories. TP (True Positive) is the number of attack samples correctly classified as attacks, and FP (False Positive) is the number of normal samples misclassified as attacks. TN (True Negative) is the number of normal samples correctly classified as normal, and FN (False Negative) is the number of attack samples misclassified as normal. This paper evaluates the accuracy, precision, recall, and F1 score of the experimental results, calculated as follows:

$$Accuracy = \frac{TP + TN}{TP + TN + FP + FN} \tag{2}$$

$$Precision = \frac{TP}{TP + FP} \tag{3}$$

$$Recall = \frac{TP}{TP + FN} \tag{4}$$

$$F1 = \frac{2 \times Precision \times Recall}{Precision + Recall} \tag{5}$$
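Equations (2)–(5) translate directly into code. The counts in the usage line are hypothetical; they illustrate how a recall of 100% can coexist with reduced precision when false positives occur. Returning 0.0 on a zero denominator is a convention chosen here, not one stated in the paper.

```python
from typing import Dict


def classification_metrics(tp: int, fp: int, tn: int, fn: int) -> Dict[str, float]:
    """Accuracy, precision, recall and F1 from confusion-matrix counts (Eqs. 2-5)."""
    accuracy = (tp + tn) / (tp + tn + fp + fn)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return {"accuracy": accuracy, "precision": precision, "recall": recall, "f1": f1}


m = classification_metrics(tp=10, fp=5, tn=85, fn=0)
print({k: round(v, 4) for k, v in m.items()})
# {'accuracy': 0.95, 'precision': 0.6667, 'recall': 1.0, 'f1': 0.8}
```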

The experimental results are shown in Figure 2.

**Figure 2.** Experimental results (**a**–**t**).

As Figure 2 shows, the recall of our fusion verification method reaches 100% in 17 of the 20 CVE experiments. Figure 3 shows that, at this high recall, the average accuracy remains high at 94.9%. Therefore, the fusion verification method can effectively defend against XSS attacks. In addition, as shown in Figure 3, the average accuracy, precision, recall, and F1 score of this method are significantly higher than those of the single traffic detection model: the average improvement is 48% in recall, 27.94% in F1 score, 9.29% in precision, and 3.81% in accuracy. These results show that the proposed fusion verification model outperforms the single traffic detection model. However, the number of samples in the payload detection experiments is relatively small, so in a few experiments the fusion verification model performs slightly worse than a single detection model. To address the problem of limited training samples, we will consider multi-fidelity classification algorithms in future research and experiments.

In addition, using the XSS payloads [24] as the dataset with a 7:3 split between the training and test sets, we use random forest to evaluate the importance of the 37 features used in the payload detection stage. Of these, 30 features have a contribution rate greater than 0.01%, as shown in Figure 4. The most important is the feature "Function\_Body" proposed in this paper, with a contribution rate of 23.95%, and the seven features proposed in this paper together account for 34.12% of the total. This suggests that deriving detection features by summarizing XSS attack methods is feasible and that these features may generalize better and detect attack variants more effectively.
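The feature-importance analysis can be sketched with the same classifier. Here synthetic data stand in for the XSS payload dataset and the last seven columns are arbitrarily designated as the new features, so the printed numbers say nothing about the paper's results. The point is that feature\_importances\_ sums to 1, so the combined share of a feature group is simply the sum of its entries.

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split

# Synthetic stand-in for the 37 payload features: columns 0-29 play the role of
# Zhou and Wang's features and columns 30-36 the role of the seven new ones.
X, y = make_classification(n_samples=1000, n_features=37, n_informative=10, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=0)  # 7:3 split

clf = RandomForestClassifier(n_estimators=100, criterion="entropy", random_state=0)
clf.fit(X_train, y_train)

importances = clf.feature_importances_           # non-negative, sums to 1
new_share = importances[30:].sum()               # combined share of the 7 "new" columns
top_feature = int(importances.argmax())          # index of the most important feature
above_threshold = int((importances > 1e-4).sum())  # features contributing more than 0.01%
print(f"new-feature share: {new_share:.2%}, top feature: {top_feature}, above 0.01%: {above_threshold}")
```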

**Figure 3.** Average performance comparison.

**Figure 4.** Importance of the features of Zhou and Wang (2019) [1] and of the newly proposed features in this paper.

#### **5. Conclusions**

We propose a fusion verification method that combines traffic detection with XSS payload detection to detect XSS attacks effectively. The results show that the proposed method significantly reduces the false negative rate of the model; when samples are evenly distributed, false negatives almost disappear. Therefore, the fusion verification method can effectively defend against XSS attacks. Compared with the traditional single traffic detection model, the method increases the average recall, F1 score, precision, and accuracy by 48%, 27.94%, 9.29%, and 3.81%, respectively. Furthermore, the seven new XSS payload features proposed in this paper account for 34.12% of the total contribution rate of the 37 features.

However, the proposed method has certain limitations: keeping the false negative rate low comes at the cost of a higher false positive rate for the overall model. We will address this problem in future research.

**Author Contributions:** Data curation, Z.Q.; Resources, Y.C.; Software, Z.W.; Writing—original draft, S.Z.; Writing—review & editing, J.L. All authors have read and agreed to the published version of the manuscript.

**Funding:** This work was supported by the National Natural Science Foundation of China (Grant No. 62102049) and the Key Research and Development Project of Sichuan Province (No. 2022YFS0571, No. 2021YFSY0012, No. 2020YFG0307).

**Institutional Review Board Statement:** Not applicable.

**Informed Consent Statement:** Not applicable.

**Data Availability Statement:** Not applicable.

**Conflicts of Interest:** The authors declare that they have no known competing financial interests or personal relationships that could appear to influence the work reported in this paper.

#### **References**


### *Article* **Tensor Affinity Learning for Hyperorder Graph Matching**

**Zhongyang Wang, Yahong Wu and Feng Liu \***

School of Communications and Information Engineering, Nanjing University of Posts and Telecommunications, Nanjing 210003, China

**\*** Correspondence: liuf@njupt.edu.cn

**Abstract:** Hypergraph matching has attracted increasing attention in computer vision applications in recent years. External interference, such as squeezing, pulling, occlusion, and noise, causes the same target to display different image characteristics. After extracting image feature point descriptions, traditional methods measure the descriptions directly with distance metrics such as the Euclidean, cosine, and Manhattan distances, which lack sufficient generalization ability and degrade the accuracy and effectiveness of matching. This paper proposes a metric-learning-based hypergraph matching (MLGM) approach that employs metric learning to express the similarity relationship between high-order image descriptors and learns a new metric function based on scene requirements and target characteristics. The experimental results show that our proposed method performs better than state-of-the-art algorithms on both synthetic and natural images.

**Keywords:** hypergraph matching; similarity metric; information-theoretic metric learning

**MSC:** 68T20

#### **1. Introduction**

Graph matching has been applied in a variety of fields, including biological applications [1], remote sensing image recognition [2], and image retrieval [3]. The key to graph matching is finding correspondences between the visual features of images using suitable algorithms. Graph matching is typically viewed as a quadratic assignment problem (QAP) [4], and since the quadratic objective function is non-convex, obtaining the global optimum is challenging [5]. Various approximation algorithms have been developed to solve it under relatively relaxed conditions. Ref. [6] proposed a matching approach based on linear programming. In [7], semidefinite programming was used to solve this problem, and [8] adopted a similar strategy. However, these algorithms are locally optimal in the discrete domain, and discretization can introduce extra errors. Other methods are based on tree search and address this suboptimality; for instance, Sanfeliu improved the previous method by considering the joint probability of points and edges in [9]. In [10], it is shown that random-walk-based models greatly enhance graph topological features. A. Robles-Kelly [11] introduced a novel algorithm based on the relationship between the adjacency matrices of the two graphs and their stationary distributions.

Matching-based techniques have been adopted in a variety of research fields. Early sparse-representation-based classification (SRC) [12] does not handle occlusion satisfactorily. Multiview non-negative matrix factorization (NMF) methods [13] preserve local geometry while obtaining a global representation under a global alignment strategy. However, these methods are still affected by various kinds of noise and cannot highlight the target characteristics. In [14], Ou et al. proposed a method that used adaptively estimated occlusion information and robustly selected features to improve the performance of facial recognition. The K-nearest neighbor (KNN) classifier is a nonparametric classifier that is widely used in pattern recognition. However, the performance of KNN-based classification is highly sensitive to the neighborhood size, especially when the sample size is small and outliers are present. Refs. [15,16] alleviate this problem through weighting and averaging and achieve robust and effective classification performance.

**Citation:** Wang, Z.; Wu, Y.; Liu, F. Tensor Affinity Learning for Hyperorder Graph Matching. *Mathematics* **2022**, *10*, 3806. https://doi.org/10.3390/math10203806

Academic Editor: Catalin Stoean

Received: 17 July 2022 Accepted: 12 October 2022 Published: 15 October 2022

**Publisher's Note:** MDPI stays neutral with regard to jurisdictional claims in published maps and institutional affiliations.

**Copyright:** © 2022 by the authors. Licensee MDPI, Basel, Switzerland. This article is an open access article distributed under the terms and conditions of the Creative Commons Attribution (CC BY) license (https://creativecommons.org/licenses/by/4.0/).

In recent years, high-order graph matching algorithms have focused on better fusing structural similarities to improve matching accuracy. Zass and Shashua [17] proposed a hypergraph matching approach based on a probabilistic setting that uses an iterative successive projection process to find the global optimal solution. Lee et al. [18] extended the reweighted random walk approach to hypergraph matching and reinterpreted random walks on hypergraphs probabilistically. However, these matching algorithms use the Euclidean distance to generate an affinity matrix, and each feature attribute is regarded as independent of the others. Traditional methods lack a specific metric for feature descriptions, so their performance varies greatly across types of data and their overall accuracy is low. In this paper, we present an improved hypergraph matching algorithm based on metric learning theory. By learning from the training dataset, it obtains a Mahalanobis matrix that is used to refine the affinity formulation. Since the Mahalanobis distance accounts for the correlations between feature descriptions as well as scaling relations, the assignment matrix obtained by this method is closer to the ground truth; experiments show that our algorithm improves matching accuracy. The main contributions of this manuscript are as follows:


The rest of this article is structured as follows. Section 2 begins with an overview of work related to graph matching. Section 3 presents the proposed model as well as the generic formulation of graph matching. Section 4 develops the MLGM approach for optimizing the proposed graph matching model. Section 5 evaluates and analyzes the experimental results of the proposed method on synthetic and natural image benchmarks. The final section concludes the paper.

#### **2. Related Works**

In the last few years, spectral methods have become one of the most representative classes of graph matching algorithms. The eigenvalues of a matrix do not change when its rows and columns are permuted simultaneously (that is, when the nodes of a graph are relabeled); this fact can be used to find adjacency matrices with the same eigenvalues for similar images. Early on, the spectral method was applied to feature matching [19]. Ref. [20] introduced a method incorporating the grey-level information around the feature points to improve matching accuracy. Leordeanu et al. [8] proposed a matching algorithm that builds an affinity matrix of the feature points and considered the effect of different weight functions on point matching. In another direction, grouping methods can also improve the effectiveness of matching. Egozi et al. [21] proposed a probabilistic interpretation of spectral matching schemes and developed a unique probabilistic matching (PM) scheme that outperforms earlier methods. Feature matching by alternating between embedding and matching of the adjacency spectrum was introduced in [22], and a relaxation scheme with matching constraints was proposed in [23]. Duchenne et al. [9] proposed a class of algorithms that uses a tensor to represent similarities between higher-order features; the graphs can then be matched after a rank-one decomposition of the similarity tensor. This algorithm extends the spectral method to hypergraphs and was further improved in [24]. Adjacency spectrum optimization of undirected weighted graphs [25] and approximation of the proximal matrix spectrum of undirected weighted graphs [26] have been developed in recent years and have produced results in image processing applications.
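A quick numerical check of the permutation-invariance property with NumPy: relabeling the nodes of a graph corresponds to B = PAP^T for a permutation matrix P, and the sorted eigenvalues of A and B coincide. The random symmetric matrix is a toy example, not data from the paper.

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.random((5, 5))
A = (A + A.T) / 2                   # symmetric weighted adjacency matrix
P = np.eye(5)[rng.permutation(5)]   # permutation matrix (relabels the nodes)
B = P @ A @ P.T                     # same graph with the nodes in a different order

eig_A = np.linalg.eigvalsh(A)       # eigenvalues in ascending order
eig_B = np.linalg.eigvalsh(B)
print(np.allclose(eig_A, eig_B))    # True
```

Because the spectrum ignores node order, comparing spectra (or eigenvectors derived from them) gives a matching criterion that does not depend on how the feature points happen to be indexed.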

The graph edit distance (GED), which encodes the correspondence between the nodes and edges of two graphs, has also been used to solve graph matching. For example, Ref. [27] proposed a self-organizing map algorithm to learn the distance so that the distance between similar images becomes smaller, and this method was improved in [28]. Serratosa proposed a method based on an adaptive learning paradigm, which was improved in [29]. Andreas Fischer and Kaspar Riesen presented an algorithm [30] combining Hausdorff matching with greedy assignment to improve the quadratic-time approximation of GED.

Metric learning has been widely used in face recognition [31,32], image retrieval [33], person re-identification [34], and other fields. Traditional metrics such as the Euclidean distance struggle to capture the structure of diverse data sets. To improve the performance of classification models, it is important to learn a specific metric for each data set, which is the objective of metric learning. Metric learning algorithms based on the Mahalanobis distance remain the primary focus of metric learning research. Bohne et al. [35] proposed dividing the data and learning a metric for each cluster, and Wang et al. [36] suggested learning a set of basis metrics and a set of weights for each sample.
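For reference, the Mahalanobis distance with a positive semidefinite matrix M is d\_M(x, y) = sqrt((x - y)^T M (x - y)). In MLGM the matrix is learned from training data (the keywords point to information-theoretic metric learning); in the sketch below M is fixed by hand purely for illustration. With M = I the distance reduces to the Euclidean distance, and any M = L^T L amounts to measuring the Euclidean distance after the linear map v to Lv.

```python
import numpy as np


def mahalanobis(x, y, M) -> float:
    """d_M(x, y) = sqrt((x - y)^T M (x - y)) for a positive semidefinite M."""
    d = np.asarray(x, dtype=float) - np.asarray(y, dtype=float)
    return float(np.sqrt(d @ M @ d))


x = np.array([1.0, 2.0])
y = np.array([4.0, 6.0])
print(mahalanobis(x, y, np.eye(2)))   # 5.0, identical to the Euclidean distance

L = np.array([[2.0, 0.0],
              [0.0, 0.5]])            # hand-picked linear map (not a learned metric)
M = L.T @ L                           # any M = L^T L is positive semidefinite
print(mahalanobis(x, y, M))           # equals the Euclidean norm of L @ (x - y)
```

Stretching the first axis and shrinking the second changes which descriptor differences dominate the distance; this reweighting is what a learned metric adapts to the data.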

#### **3. Problem Statement**

As a mathematical expression of relationships, a graph model [37] is composed of a node set and an edge set. The objective of graph matching is to find the correspondence between two graphs. Generally, it seeks correspondences between nodes, where a node may be a pixel, an image region, or a feature point. In graph matching and hypergraph matching algorithms, the structural information of the graph model is used to formulate the problem so as to express the relationships between graph features more comprehensively. Figure 1 depicts an example of graph matching.

**Figure 1.** Graph matching schematic diagram.

We now consider two sets of feature points $P = \{p\_1, p\_2, \ldots, p\_m\}$ and $Q = \{q\_1, q\_2, \ldots, q\_n\}$, extracted from graphs *A* and *B*, respectively. The numbers of points *m* and *n* may be equal or different. Unlike earlier methods, which match points one to one or pair to pair, the high-order graph matching algorithm [9] matches tuples of points; *k* denotes the number of points in each tuple. High-order graph matching is robust under unfavorable conditions such as noise, deformation, outliers, and rotation [38], at the cost of higher time and space complexity. Third-order potentials can encode invariance to similarity transformations, which is important in computer vision, and, as the smallest higher-order structure, triplets can capture subtle differences between graphs. For convenience, only third-order graph matching is discussed in this paper; the approach generalizes straightforwardly to higher-order potentials.

Matching two graphs means computing an optimal assignment between their points. Mathematically, this is equivalent to finding an *m* × *n* assignment matrix *X*: if feature point $p\_i$ in *P* matches $q\_j$ in *Q*, then $X\_{i,j} = 1$; otherwise, $X\_{i,j} = 0$. In this paper, we assume that each feature point in *P* matches exactly one feature point in *Q*, whereas a point in *Q* may be matched by zero or more points in *P*. The set of admissible assignment matrices is therefore

$$\mathcal{X} = \left\{ X \mid X \in \{0, 1\}^{m \times n}, \forall i, \sum\_{j=1}^{n} X\_{i,j} = 1 \right\} \tag{1}$$

where *i* ∈ [1, *m*].
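As a quick check of the constraint set in Eq. (1), the following Python sketch (a hypothetical helper, not part of the authors' code) tests whether a 0/1 matrix is admissible. Eq. (1) constrains only the rows, so a single point of *Q* may be chosen by several rows.

```python
from typing import List


def is_feasible_assignment(X: List[List[int]], n: int) -> bool:
    """True if X belongs to the set defined in Eq. (1).

    X is a list of m rows; row i corresponds to point p_i of P.
    Eq. (1) requires binary entries and exactly one 1 per row; columns
    are unconstrained, so a point of Q may be chosen by several rows.
    """
    for row in X:
        if len(row) != n:
            return False
        if any(v not in (0, 1) for v in row):
            return False
        if sum(row) != 1:
            return False
    return True


X_ok = [[0, 1, 0, 0],
        [0, 1, 0, 0],   # q_2 is chosen twice: still allowed by Eq. (1)
        [0, 0, 0, 1]]
X_bad = [[1, 1, 0, 0],  # p_1 has two matches: violates the row-sum constraint
         [0, 0, 1, 0],
         [0, 0, 0, 1]]
print(is_feasible_assignment(X_ok, 4))   # True
print(is_feasible_assignment(X_bad, 4))  # False
```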

The general second-order graph matching model solves for *X* as

$$\max\_{X} \text{Score}(X) = \sum\_{i\_1, i\_2, j\_1, j\_2} M\_{i\_1, i\_2, j\_1, j\_2} X\_{i\_1, i\_2} X\_{j\_1, j\_2} \tag{2}$$

where *M* is the affinity tensor, which gives the affinity between point pairs $(i\_1, j\_1)$ and $(i\_2, j\_2)$. $\text{Score}(X)$ is the sum of the affinity values of all matched pairs; the higher the score, the better the matching result. Constructing the affinity tensor *M* for graphs *A* and *B* requires taking into account the similarity between pairs of nodes and pairs of edges.

$$M\_{i\_1, i\_2, j\_1, j\_2} = \exp\left(-\gamma \left| \left| f\_{i\_1 j\_1} - f\_{i\_2 j\_2} \right| \right| \right) \tag{3}$$

where *f* is the feature of each point pair, given in second-order graph matching by the Euclidean distance between the two points, and *γ* is a scaling parameter [9].
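A minimal sketch of the pairwise affinity in Eq. (3), assuming 2-D points, *γ* = 1, and reading $f\_{i\_1 j\_1}$ as the length of edge $(i\_1, j\_1)$ in *P* and $f\_{i\_2 j\_2}$ as the length of edge $(i\_2, j\_2)$ in *Q* (the function names are illustrative). Because the feature is one-dimensional, the norm reduces to an absolute value. The scaled example also shows why a pure distance feature is sensitive to scale changes.

```python
import math
from typing import List, Tuple

Point = Tuple[float, float]


def edge_length(points: List[Point], a: int, b: int) -> float:
    """Second-order feature f: Euclidean distance between two points."""
    return math.dist(points[a], points[b])


def pairwise_affinity(P: List[Point], Q: List[Point],
                      i1: int, j1: int, i2: int, j2: int,
                      gamma: float = 1.0) -> float:
    """Eq. (3) with f_{i1 j1} = edge (i1, j1) of P and f_{i2 j2} = edge (i2, j2) of Q."""
    return math.exp(-gamma * abs(edge_length(P, i1, j1) - edge_length(Q, i2, j2)))


P = [(0.0, 0.0), (3.0, 0.0), (0.0, 4.0)]
Q_shifted = [(x + 1.0, y + 1.0) for x, y in P]   # translated copy of P
Q_scaled = [(2.0 * x, 2.0 * y) for x, y in P]     # scaled copy of P

print(pairwise_affinity(P, Q_shifted, 0, 1, 0, 1))  # 1.0: equal edge lengths
print(pairwise_affinity(P, Q_scaled, 0, 1, 0, 1))   # exp(-3), about 0.0498
```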

However, the second-order model can express only pairwise relations, which are not robust to scale changes and cannot capture higher-order feature information. To exploit the high-order relations among feature points, we describe the similarity between feature point sets through the relations between point tuples. Given two point sets *P* and *Q*, the affinity tensor can be expressed as

$$M\_{i\_1, i\_2, j\_1, j\_2, k\_1, k\_2} = \exp\left(-\zeta \left| \left| f\_{i\_1, j\_1, k\_1} - f\_{i\_2, j\_2, k\_2} \right| \right|^2\right) \tag{4}$$

where $(i\_1, j\_1, k\_1)$ is a point triplet in point set *P* of graph *A*, $(i\_2, j\_2, k\_2)$ is a candidate triplet to be matched in point set *Q* of graph *B*, *ζ* is a constant that controls the spread of the affinity tensor values, and $f\_{i\_1,j\_1,k\_1}$ and $f\_{i\_2,j\_2,k\_2}$ are the feature vectors constructed from the two triplets in *P* and *Q*, respectively. Feature information can be represented in many ways; for ease of computation, we use the sines of the interior angles of the triangle formed by each triplet.
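The triplet feature described above can be sketched as follows, assuming 2-D points and *ζ* = 1 (both choices are illustrative, not the paper's settings). Because the sines of the interior angles do not change under rotation, translation, or uniform scaling, a triangle and a rotated, scaled copy of it receive an affinity of about 1 under Eq. (4).

```python
import math
from typing import List, Tuple

Point = Tuple[float, float]


def interior_angle(p: Point, q: Point, r: Point) -> float:
    """Angle at vertex p of triangle (p, q, r), in radians."""
    v1 = (q[0] - p[0], q[1] - p[1])
    v2 = (r[0] - p[0], r[1] - p[1])
    cross = v1[0] * v2[1] - v1[1] * v2[0]
    dot = v1[0] * v2[0] + v1[1] * v2[1]
    return math.atan2(abs(cross), dot)


def triplet_feature(a: Point, b: Point, c: Point) -> List[float]:
    """Feature vector of a point triplet: sines of its three interior angles."""
    return [math.sin(interior_angle(a, b, c)),
            math.sin(interior_angle(b, c, a)),
            math.sin(interior_angle(c, a, b))]


def third_order_affinity(f1: List[float], f2: List[float], zeta: float = 1.0) -> float:
    """Eq. (4): exp(-zeta * ||f1 - f2||^2)."""
    return math.exp(-zeta * sum((x - y) ** 2 for x, y in zip(f1, f2)))


tri = [(0.0, 0.0), (3.0, 0.0), (0.0, 4.0)]
# Rotate by 90 degrees and scale by 2: a similarity transformation.
tri_sim = [(-2.0 * y, 2.0 * x) for x, y in tri]
f1 = triplet_feature(*tri)
f2 = triplet_feature(*tri_sim)
print([round(v, 6) for v in f1])     # [1.0, 0.8, 0.6]
print(third_order_affinity(f1, f2))  # close to 1.0
```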

Similar to model (2), the high-order graph matching model can be formulated as

$$\max\_{X} Score(X) = \sum\_{i\_1, i\_2, j\_1, j\_2, k\_1, k\_2} M\_{i\_1, i\_2, j\_1, j\_2, k\_1, k\_2} X\_{i\_1, i\_2} X\_{j\_1, j\_2} X\_{k\_1, k\_2} \tag{5}$$

In (5), the product $X\_{i\_1,i\_2} X\_{j\_1,j\_2} X\_{k\_1,k\_2}$ equals 1 only when points $i\_1$, $j\_1$, and $k\_1$ are matched to $i\_2$, $j\_2$, and $k\_2$, respectively. This is an optimization problem: the assignment matrix that maximizes $\text{Score}(X)$ yields the matching between tuples. Using the notation of tensor–vector multiplication, (5) can also be written as (6):
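The score in Eq. (5) involves only stored (nonzero) tensor entries, so it can be evaluated by iterating over a sparse dictionary. The sketch below assumes entries keyed by a triplet of *P* and a triplet of *Q*; the affinity values are made up for illustration.

```python
from typing import Dict, List, Tuple

Triplet = Tuple[int, int, int]
# Sparse affinity tensor: ((i1, j1, k1) in P, (i2, j2, k2) in Q) -> M value.
SparseAffinity = Dict[Tuple[Triplet, Triplet], float]


def score(M: SparseAffinity, X: List[List[int]]) -> float:
    """Eq. (5): sum of M * X[i1][i2] * X[j1][j2] * X[k1][k2]; missing entries count as 0."""
    total = 0.0
    for ((i1, j1, k1), (i2, j2, k2)), value in M.items():
        total += value * X[i1][i2] * X[j1][j2] * X[k1][k2]
    return total


M = {((0, 1, 2), (0, 1, 2)): 0.9,   # triplet (p0, p1, p2) vs (q0, q1, q2)
     ((0, 1, 2), (1, 0, 2)): 0.4}   # triplet (p0, p1, p2) vs (q1, q0, q2)
X_identity = [[1, 0, 0], [0, 1, 0], [0, 0, 1]]
X_swapped = [[0, 1, 0], [1, 0, 0], [0, 0, 1]]
print(score(M, X_identity))  # 0.9
print(score(M, X_swapped))   # 0.4
```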

$$\max\_{\tilde{X}} \text{Score}(\tilde{X}) = \tilde{M} \otimes\_{3} \tilde{X} \otimes\_{2} \tilde{X} \otimes\_{1} \tilde{X} \tag{6}$$

where $\tilde{X}$ is the vector obtained by stacking the columns of *X*, $\tilde{M}$ is the symmetric tensor obtained by expanding *M* accordingly, and $\otimes\_l$ denotes tensor–vector multiplication along the *l*-th of the three modes (*I*, *J*, and *K*) of $\tilde{M}$.
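This section does not state how Eq. (6) is maximized. One common choice, used by tensor matching [9], is to relax $\tilde{X}$ to a unit-norm real vector, apply a tensor power iteration, and then round each row back to a single match. The following is a hedged sketch of that relaxation under these assumptions, not the authors' solver; the flattening index `i * n + j` is our own convention (any fixed ordering works), and the toy affinities are hypothetical.

```python
import math
from typing import Dict, List, Tuple

Triplet = Tuple[int, int, int]
SparseAffinity = Dict[Tuple[Triplet, Triplet], float]


def power_iteration_match(M: SparseAffinity, m: int, n: int, iters: int = 20) -> List[int]:
    """Relaxed maximization of Eq. (6) by tensor power iteration, then rounding.

    x is the relaxed assignment vector; correspondence (i, j) is stored at i * n + j.
    """
    def flat(i: int, j: int) -> int:
        return i * n + j

    x = [1.0 / math.sqrt(m * n)] * (m * n)          # uniform start with unit norm
    for _ in range(iters):
        y = [0.0] * (m * n)
        for ((i1, j1, k1), (i2, j2, k2)), w in M.items():
            a, b, c = flat(i1, i2), flat(j1, j2), flat(k1, k2)
            # Contract the tensor with x on two modes; the three lines
            # treat each stored entry as if the tensor were symmetric.
            y[a] += w * x[b] * x[c]
            y[b] += w * x[a] * x[c]
            y[c] += w * x[a] * x[b]
        norm = math.sqrt(sum(v * v for v in y)) or 1.0
        x = [v / norm for v in y]
    # Rounding: each point of P keeps its highest-scoring point of Q (one 1 per row, as in Eq. (1)).
    return [max(range(n), key=lambda j: x[flat(i, j)]) for i in range(m)]


M = {((0, 1, 2), (0, 1, 2)): 1.0,   # strong evidence for the identity matching
     ((0, 1, 2), (1, 0, 2)): 0.3}   # weaker evidence for swapping points 0 and 1
print(power_iteration_match(M, 3, 3))  # [0, 1, 2]
```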

In the traditional affinity measure of Equation (4), every feature is treated as equally important. However, because features may correlate differently with sample categories, their weights need to be reconsidered; that is, the difference between samples should be measured with a distance or similarity measure suited to the feature space of the samples. Because it accounts for correlations between features and is independent of the measurement scale, the Mahalanobis distance [39] is an effective measure for image processing and computer vision. In this paper, we use the Mahalanobis distance to measure the affinity of feature vectors and build the corresponding metric learning model.

#### **4. Tensor Affinity Learning for Hyperorder Graph Matching**

#### *4.1. A Short Introduction to Metric Learning*

The study of metric learning has significant theoretical implications. Metric learning builds a function model for input feature vectors and obtains an accurate similarity measure by learning the model's parameters; the resulting high-accuracy similarity relationships can improve classifier performance [40]. Classical learning algorithms, however, overlook how to measure accurately the similarity of samples affected by different external factors: they preprocess the samples with simple normalization and then measure similarity with the Euclidean distance. Such normalization and measurement are crude, and the resulting classifier is easily affected by noise and interference.

Euclidean distance is a representative distance metric function, defined as

$$d(\mathbf{x}\_1, \mathbf{x}\_2) = \sqrt{(\mathbf{x}\_1 - \mathbf{x}\_2)^T (\mathbf{x}\_1 - \mathbf{x}\_2)} \tag{7}$$

where $\mathbf{x}\_1$ and $\mathbf{x}\_2$ are a pair of samples. Although the Euclidean distance is simple to understand, it has several flaws. First, it weights all feature dimensions equally, which does not meet the requirements of high-order graph matching. Second, it cannot handle correlations (coupling) between features: when computing the similarity of point tuples, for example, one must consider how points and edges are related through the global structure they form.

To improve the deficiencies of traditional distance measurement, the (squared) Mahalanobis distance was used to measure similarity in this paper, which is defined as

$$d\_M(\mathbf{x}\_1, \mathbf{x}\_2) = (\mathbf{x}\_1 - \mathbf{x}\_2)^T \mathcal{W}(\mathbf{x}\_1 - \mathbf{x}\_2) \tag{8}$$

where $\mathcal{W}$ is the Mahalanobis matrix [39], a symmetric positive semidefinite matrix. Metric learning seeks such a matrix *W* for a given training dataset in order to measure the similarity between sample features: under the learned metric, similar features should be close to each other and dissimilar features far apart.
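A small sketch of Eq. (8) in plain Python. With *W* = *I* it reduces to the squared Euclidean distance of Eq. (7); a non-diagonal *W* (here a hypothetical positive-definite matrix, not a learned one) re-weights and couples the feature dimensions.

```python
from typing import List, Sequence


def mahalanobis_sq(x1: Sequence[float], x2: Sequence[float], W: List[List[float]]) -> float:
    """Eq. (8): (x1 - x2)^T W (x1 - x2) for a symmetric positive semidefinite W."""
    diff = [a - b for a, b in zip(x1, x2)]
    W_diff = [sum(w_rc * d_c for w_rc, d_c in zip(row, diff)) for row in W]
    return sum(d * wd for d, wd in zip(diff, W_diff))


x1, x2 = [1.0, 2.0], [4.0, 6.0]
identity = [[1.0, 0.0], [0.0, 1.0]]
W = [[2.0, 0.5], [0.5, 1.0]]   # hypothetical metric; off-diagonal terms couple the features
print(mahalanobis_sq(x1, x2, identity))  # 25.0, the squared Euclidean distance
print(mahalanobis_sq(x1, x2, W))         # 46.0
```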

#### *4.2. The Establishment of Training Constraints*

In supervised graph matching, the assignment matrix represents the correspondence between the points of two graphs. In hypergraph matching, the matching relation between tuples follows directly from the point correspondences encoded in the assignment matrix.

To account for the correlation of feature vectors in hypergraph matching and to learn the Mahalanobis matrix, we use pairwise (binary tuple) constraints [41] to represent the similarity relations of the training samples:

$$\{(x\_i, x\_j), w\_{ij}\} \tag{9}$$

where $w\_{ij}$ indicates whether the training samples $x\_i$ and $x\_j$ are similar. If $w\_{ij} = 1$, the pair $(x\_i, x\_j)$ belongs to the set of similar pairs and should be close under the learned distance function; if $w\_{ij} = -1$, the pair belongs to the set of dissimilar pairs and should be far apart. For a training dataset, $w\_{ij}$ is easily obtained from the assignment matrix. Each tuple is stored as a feature vector; to shrink the distances between similar pairs and enlarge those between dissimilar pairs, we further constrain the similar pair set *S* and the dissimilar pair set *D* with thresholds:

$$\begin{aligned} S &= \{ (f\_i, f\_j) : d\_M(f\_i, f\_j) \le g \} \\ D &= \{ (f\_i, f\_j) : d\_M(f\_i, f\_j) \ge h \} \end{aligned} \tag{10}$$

where *g* and *h* are threshold constants, and $(f\_i, f\_j)$ is a pair of feature vectors.

#### *4.3. Metric Learning Algorithm*

Given a set of distance constraints as described in (10), our aim is to learn a positive-definite matrix *W* that parameterizes the corresponding Mahalanobis distance. To improve computational efficiency, we adopt information-theoretic metric learning (ITML) [42], which handles the constraints on the distance function in a natural information-theoretic way by minimizing the relative entropy between two multivariate Gaussians. There is a simple bijection between the set of Mahalanobis distances and the set of equal-mean multivariate Gaussian distributions; the multivariate Gaussian corresponding to the Mahalanobis distance parameterized by *W* is

$$p(\mathbf{x}; \mathcal{W}) = \frac{1}{z} \exp\left(-\frac{1}{2} d\_M(\mathbf{x}, \mu)\right) \tag{11}$$

where *μ* is the mean and *z* is the normalization factor. The distance between two Mahalanobis distance functions parameterized by $\mathcal{W}\_0$ and $\mathcal{W}$ is measured by the relative entropy between the corresponding multivariate Gaussians:

$$KL(p(\mathbf{x};\mathcal{W}\_0)||p(\mathbf{x};\mathcal{W})) = \int p(\mathbf{x};\mathcal{W}\_0) \log \frac{p(\mathbf{x};\mathcal{W}\_0)}{p(\mathbf{x};\mathcal{W})} d\mathbf{x} \tag{12}$$

*KL*(·) denotes the relative entropy, also known as the Kullback–Leibler divergence [43], which measures the difference between two probability distributions. Given a similar pair set *S* and a dissimilar pair set *D*, the metric learning problem becomes

$$\begin{aligned} \min\_{\mathcal{W} \succeq 0} \; & KL(p(\mathbf{x}; \mathcal{W}\_0) \| p(\mathbf{x}; \mathcal{W})) \\ \text{s.t.} \; & d\_M(f\_i, f\_j) \le g, \quad (f\_i, f\_j) \in S \\ & d\_M(f\_i, f\_j) \ge h, \quad (f\_i, f\_j) \in D \end{aligned} \tag{13}$$

where *g* and *h* are constants.

It has been shown that the differential relative entropy between two multivariate Gaussians can be expressed as a convex combination of the Mahalanobis distance between their mean vectors and the LogDet divergence between their covariance matrices [44]. For equal-mean Gaussians, only the LogDet term remains, so the objective can be computed with the LogDet divergence $D\_{ld}(\cdot, \cdot)$ between matrices:

$$\begin{aligned} KL(p(\mathbf{x}; \mathcal{W}\_0) \| p(\mathbf{x}; \mathcal{W})) &= \frac{1}{2} D\_{ld}(\mathcal{W}\_0^{-1}, \mathcal{W}^{-1}) = \frac{1}{2} D\_{ld}(\mathcal{W}, \mathcal{W}\_0) \\ D\_{ld}(\mathcal{W}, \mathcal{W}\_0) &= \operatorname{tr}(\mathcal{W} \mathcal{W}\_0^{-1}) - \log \det(\mathcal{W} \mathcal{W}\_0^{-1}) - d \end{aligned} \tag{14}$$

where *d* is the number of rows in *W*.
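The LogDet divergence of Eq. (14) can be evaluated directly for small matrices. The sketch below uses hand-written Gaussian elimination to stay dependency-free (the helper names are ours). Since $D\_{ld}(\mathcal{W}\_0^{-1}, \mathcal{W}^{-1}) = D\_{ld}(\mathcal{W}, \mathcal{W}\_0)$ (trace and determinant are unchanged by cyclic reordering of the product), half of the returned value is the KL divergence between the corresponding equal-mean Gaussians; it vanishes when $\mathcal{W} = \mathcal{W}\_0$.

```python
import math
from typing import List

Matrix = List[List[float]]


def mat_inverse(A: Matrix) -> Matrix:
    """Gauss-Jordan inverse with partial pivoting (adequate for small d)."""
    n = len(A)
    aug = [[float(v) for v in row] + [1.0 if i == j else 0.0 for j in range(n)]
           for i, row in enumerate(A)]
    for col in range(n):
        pivot = max(range(col, n), key=lambda r: abs(aug[r][col]))
        if abs(aug[pivot][col]) < 1e-12:
            raise ValueError("matrix is singular")
        aug[col], aug[pivot] = aug[pivot], aug[col]
        p = aug[col][col]
        aug[col] = [v / p for v in aug[col]]
        for r in range(n):
            if r != col:
                factor = aug[r][col]
                aug[r] = [a - factor * b for a, b in zip(aug[r], aug[col])]
    return [row[n:] for row in aug]


def determinant(A: Matrix) -> float:
    """Determinant by Gaussian elimination with partial pivoting."""
    M = [[float(v) for v in row] for row in A]
    n, det = len(M), 1.0
    for col in range(n):
        pivot = max(range(col, n), key=lambda r: abs(M[r][col]))
        if M[pivot][col] == 0.0:
            return 0.0
        if pivot != col:
            M[col], M[pivot] = M[pivot], M[col]
            det = -det
        det *= M[col][col]
        for r in range(col + 1, n):
            factor = M[r][col] / M[col][col]
            for c in range(col, n):
                M[r][c] -= factor * M[col][c]
    return det


def logdet_divergence(W: Matrix, W0: Matrix) -> float:
    """Eq. (14): tr(W W0^-1) - log det(W W0^-1) - d, for positive-definite W and W0."""
    d = len(W)
    W0_inv = mat_inverse(W0)
    trace = sum(W[i][k] * W0_inv[k][i] for i in range(d) for k in range(d))
    log_det = math.log(determinant(W)) - math.log(determinant(W0))
    return trace - log_det - d


I2 = [[1.0, 0.0], [0.0, 1.0]]
W = [[2.0, 0.5], [0.5, 1.0]]
print(logdet_divergence(W, W))                          # about 0.0
print(logdet_divergence([[2.0, 0.0], [0.0, 2.0]], I2))  # 2 - log(4), about 0.614
print(0.5 * logdet_divergence(W, I2))                   # KL between the two Gaussians
```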

To allow solutions in a wider feasible region, ITML introduces slack variables *ξ*, initialized to $\xi\_0$, and rewrites (13) as follows:

$$\begin{aligned} \min\_{\mathcal{W} \succeq 0, \boldsymbol{\xi}} \; & D\_{ld}(\mathcal{W}, \mathcal{W}\_0) + \rho D\_{ld}(\operatorname{diag}\{\boldsymbol{\xi}\}, \operatorname{diag}\{\boldsymbol{\xi}\_0\}) \\ \text{s.t.} \; & \operatorname{tr}(\mathcal{W} (f\_i - f\_j)(f\_i - f\_j)^T) \le \xi\_{i,j}, \quad (f\_i, f\_j) \in S \\ & \operatorname{tr}(\mathcal{W} (f\_i - f\_j)(f\_i - f\_j)^T) \ge \xi\_{i,j}, \quad (f\_i, f\_j) \in D \end{aligned} \tag{15}$$

where *ρ* is a trade-off parameter. Following the LogDet optimization via Bregman projections in [45], the iterative update is

$$\mathcal{W}\_{t+1} = \mathcal{W}\_t + \beta \mathcal{W}\_t (f\_i - f\_j)(f\_i - f\_j)^T \mathcal{W}\_t \tag{16}$$

where $\mathcal{W}\_t$ is the metric matrix after the *t*-th iteration, *β* is the projection parameter, and $(f\_i, f\_j)$ is the constraint pair being projected.
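Eq. (16) is a rank-one update. Since $c^T \mathcal{W}\_t c = p$ for $c = f\_i - f\_j$, the projected pair's new distance is $p + \beta p^2$, which is how a suitable *β* moves a single constraint toward its threshold. A minimal sketch with a hypothetical *β* value and toy data:

```python
from typing import List

Matrix = List[List[float]]


def rank_one_update(W: Matrix, fi: List[float], fj: List[float], beta: float) -> Matrix:
    """Eq. (16): W + beta * W c c^T W with c = fi - fj; for symmetric W this is W + beta z z^T, z = W c."""
    c = [a - b for a, b in zip(fi, fj)]
    z = [sum(w * v for w, v in zip(row, c)) for row in W]   # z = W c
    n = len(z)
    return [[W[r][s] + beta * z[r] * z[s] for s in range(n)] for r in range(n)]


def quad(W: Matrix, c: List[float]) -> float:
    """c^T W c, the squared Mahalanobis length of c."""
    return sum(c[r] * sum(W[r][s] * c[s] for s in range(len(c))) for r in range(len(c)))


W = [[2.0, 0.5], [0.5, 1.0]]
fi, fj = [1.0, 2.0], [4.0, 6.0]
c = [a - b for a, b in zip(fi, fj)]
p = quad(W, c)                            # 46.0
beta = -0.01                              # hypothetical step; beta < 0 shrinks the distance
W_next = rank_one_update(W, fi, fj, beta)
print(quad(W_next, c), p + beta * p * p)  # both about 24.84
```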

#### *4.4. Parallel Learning Algorithm*

Applying ITML directly to high-dimensional training data is impractical. Its complexity grows with the square of the data dimension, because it learns a full-rank metric matrix whose size scales quadratically with the input dimension; this imposes a significant computational overhead on the learning process.

Real high-dimensional datasets are typically contaminated by noise or contain redundant information, which prevents the algorithm from learning an effective metric matrix. Therefore, when the dimension of the training samples is large, the matrix learned by ITML cannot effectively suppress noise, and the algorithm also suffers from low efficiency and sensitivity to insufficient training data. To address these challenges, we improve the metric learning algorithm through parallel computing.

We now describe the parallel computing process. Using the principle in [46] that every positive semidefinite matrix can be decomposed into a linear combination of rank-one matrices, we rewrite the Mahalanobis matrix as $\mathcal{W} = I + \sum\_i \alpha\_i z\_i z\_i^T$, where $z\_i \in \mathbb{R}^d$. Clearly, $\mathcal{W}\_t (f\_i - f\_j)$ is *d*-dimensional, and $\mathcal{W}\_t (f\_i - f\_j)(f\_i - f\_j)^T \mathcal{W}\_t^T$ is a rank-one matrix. Accumulating the Bregman projections [47] over all constraint pairs gives an explicit expression for *W*:

$$\mathcal{W}\_{t+1} = I + \sum\_{i=1}^{C} \beta\_i(t) z\_i(t) z\_i(t)^T \tag{17}$$

where *I* is the *d*-dimensional identity matrix, *C* is the number of constraint pairs, $\beta\_i(t)$ is the projection parameter (learning rate) of the *i*-th constraint, $z\_i(t) = \mathcal{W}\_t c\_i$, and $c\_i = f\_j - f\_k$ for constraint pair $(f\_j, f\_k)$. The algorithm stores only the vectors *z* instead of $\mathcal{W}\_t$, which avoids keeping a *d* × *d* matrix whose size grows quadratically with *d*. The iteration therefore becomes an update of *z*:

$$z\_k(t+1) = \mathcal{W}\_{t+1} c\_k = \left(I + \sum\_{i=1}^{C} \beta\_i(t) z\_i(t) z\_i(t)^T\right) c\_k \tag{18}$$

In the original ITML algorithm, $\beta\_i(t)$ is determined by whether the *i*-th constraint imposes an upper bound (similar pair) or a lower bound (dissimilar pair) on the measured distance. The key step in each update is computing the actual distance. Combined with Equation (17), the actual distance of the *k*-th constraint $c\_k$ is given by Equation (19):

$$\begin{aligned} p\_k(t) &= c\_k^T \mathcal{W}\_{t+1} c\_k \\ &= c\_k^T \left( I + \sum\_{i=1}^C \beta\_i(t) z\_i(t) z\_i^T(t) \right) c\_k \\ &= c\_k^T c\_k + \sum\_{i=1}^C \beta\_i(t) c\_k^T z\_i(t) z\_i^T(t) c\_k \end{aligned} \tag{19}$$

Because *W* decomposes in this way, the updates of *z* and *p* can be distributed over *C* workers, one per constraint, which enables parallel execution. In our framework, worker *k* receives all *z* vectors produced by the other workers in the previous iteration and then performs its next update. Each worker only needs to send its own vector *z* and receive the other *C* − 1 vectors, rather than the entire metric matrix, so the amount of data transferred per step is reduced from $O(d^2)$ to $O(d)$. When *d* exceeds the number of constraints *C*, the communication requirements drop substantially, which greatly reduces the overall cost.
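The point of Eqs. (17)–(19) is that each worker only needs the scalars $\beta\_i$ and the *d*-dimensional vectors $z\_i$. The sketch below (hypothetical values, helper names ours) checks that the factored formulas give the same $z\_k$ and $p\_k$ as the explicitly assembled matrix $I + \sum\_i \beta\_i z\_i z\_i^T$, which is built here only for comparison.

```python
from typing import List

Vec = List[float]


def dot(a: Vec, b: Vec) -> float:
    return sum(x * y for x, y in zip(a, b))


def z_update(c_k: Vec, betas: List[float], zs: List[Vec]) -> Vec:
    """Eq. (18): z_k = (I + sum_i beta_i z_i z_i^T) c_k, using only d-dimensional vectors."""
    out = list(c_k)
    for beta, z in zip(betas, zs):
        coeff = beta * dot(z, c_k)
        out = [o + coeff * zi for o, zi in zip(out, z)]
    return out


def distance_from_factors(c_k: Vec, betas: List[float], zs: List[Vec]) -> float:
    """Eq. (19): p_k = c_k^T c_k + sum_i beta_i (c_k^T z_i)^2."""
    return dot(c_k, c_k) + sum(beta * dot(c_k, z) ** 2 for beta, z in zip(betas, zs))


def assemble_W(betas: List[float], zs: List[Vec]) -> List[Vec]:
    """Explicit d x d matrix I + sum_i beta_i z_i z_i^T (for checking only)."""
    d = len(zs[0])
    W = [[1.0 if r == s else 0.0 for s in range(d)] for r in range(d)]
    for beta, z in zip(betas, zs):
        for r in range(d):
            for s in range(d):
                W[r][s] += beta * z[r] * z[s]
    return W


# Hypothetical state shared by C = 2 workers in d = 3 dimensions.
betas = [-0.05, 0.2]
zs = [[1.0, 0.0, 2.0], [0.5, -1.0, 0.0]]
c_k = [1.0, 1.0, -1.0]
W = assemble_W(betas, zs)
p_explicit = dot(c_k, [dot(row, c_k) for row in W])
print(distance_from_factors(c_k, betas, zs), p_explicit)  # equal up to rounding
print(z_update(c_k, betas, zs))                           # equals W c_k
```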

Instead of (4), we define the affinity Ω in terms of the learned Mahalanobis distance, which better accounts for the correlation between tuples learned from the training set:

$$\begin{aligned} \Omega\_{i\_1, i\_2, j\_1, j\_2, k\_1, k\_2} &= \exp\left(-\frac{(f\_1 - f\_2)^T W (f\_1 - f\_2)}{\gamma}\right) \\ f\_1 &= f\_{i\_1, j\_1, k\_1} \\ f\_2 &= f\_{i\_2, j\_2, k\_2} \end{aligned} \tag{20}$$

Then, *M* is defined as the following:

$$M\_{i\_1, i\_2, j\_1, j\_2, k\_1, k\_2} = \begin{cases} \Omega\_{i\_1, i\_2, j\_1, j\_2, k\_1, k\_2}, & \text{if } \|f\_1 - f\_2\| \le \sigma \\ 0, & \text{otherwise} \end{cases} \tag{21}$$

The parameter *σ* controls the admissible degree of triplet deformation; a larger *σ* makes the matching less sensitive to deformation. The resulting algorithm is given as Algorithm 1.
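Eqs. (20) and (21) combine into a single function that returns one tensor entry. The sketch assumes that the triplet features and *W* are already available; the identity matrix stands in for a learned *W*, and *γ* = 1 and *σ* = 0.2 are illustrative values, not the paper's settings.

```python
import math
from typing import List


def tensor_entry(f1: List[float], f2: List[float], W: List[List[float]],
                 gamma: float, sigma: float) -> float:
    """One entry of M: Eq. (20) affinity, set to 0 by Eq. (21) when ||f1 - f2|| > sigma."""
    diff = [a - b for a, b in zip(f1, f2)]
    if math.sqrt(sum(d * d for d in diff)) > sigma:
        return 0.0                                   # pruned entry keeps the tensor sparse
    W_diff = [sum(w * d for w, d in zip(row, diff)) for row in W]
    return math.exp(-sum(d * wd for d, wd in zip(diff, W_diff)) / gamma)


I3 = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]  # stand-in for a learned W
f1 = [1.0, 0.8, 0.6]
f2 = [1.0, 0.7, 0.6]     # slightly deformed triplet
f3 = [0.5, 0.5, 0.5]     # very different triplet
print(tensor_entry(f1, f2, I3, gamma=1.0, sigma=0.2))  # exp(-0.01), about 0.990
print(tensor_entry(f1, f3, I3, gamma=1.0, sigma=0.2))  # 0.0
```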

#### **Algorithm 1** Parallel Metric Learning

**Input:** *S*: similar pairs; *D*: dissimilar pairs; *u*, *l*: distance thresholds; *γ*: slack parameter

**Output:** *W*: Mahalanobis matrix

1. $W = I$, $C = |S| + |D|$
2. **for** each constraint $(x\_p, x\_q)\_k$, $k \in \{1, 2, \ldots, C\}$ **do**
3. $\lambda\_k \leftarrow 0$
4. $d\_k \leftarrow u$ if $(x\_p, x\_q)\_k \in S$, otherwise $d\_k \leftarrow l$
5. $c\_k \leftarrow (x\_p - x\_q)\_k$, $z\_k \leftarrow c\_k$
6. **end for**
7. **while** *β* has not converged **do**
8. **for all** workers $k \in \{1, 2, \ldots, C\}$ **do in parallel**
9. $z\_k = c\_k + \sum\_{i=1}^{C} \beta\_i z\_i z\_i^T c\_k$
10. $p \leftarrow c\_k^T z\_k$
11. **if** $(x\_p, x\_q)\_k \in S$ **then**
12. $\alpha \leftarrow \min\left(\lambda\_k, \frac{1}{2}\left(\frac{1}{p} - \frac{\gamma}{d\_k}\right)\right)$
13. $\beta\_k \leftarrow \frac{\alpha}{1 - \alpha p}$
14. $d\_k \leftarrow \frac{\gamma d\_k}{\gamma + \alpha d\_k}$
15. **else**
16. $\alpha \leftarrow \min\left(\lambda\_k, \frac{1}{2}\left(\frac{\gamma}{d\_k} - \frac{1}{p}\right)\right)$
17. $\beta\_k \leftarrow \frac{-\alpha}{1 + \alpha p}$
18. $d\_k \leftarrow \frac{\gamma d\_k}{\gamma - \alpha d\_k}$
19. **end if**
20. $\lambda\_k \leftarrow \lambda\_k - \alpha$
21. $z\_k \leftarrow \left(I + \sum\_{i=1}^{C} \beta\_i z\_i z\_i^T\right) c\_k$
22. send $z\_k$ to the other workers
23. **end for**
24. **end while**
25. $W = I + \sum\_{i=1}^{C} \beta\_i z\_i z\_i^T$
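For small *d*, the projection loop that Algorithm 1 distributes can be run sequentially with an explicit matrix. The sketch below follows the ITML update rules [42] on which Algorithm 1 is based, written with a sign δ = +1 for similar pairs and δ = −1 for dissimilar pairs, so that β = δα/(1 − δαp) and $d\_k \leftarrow \gamma d\_k / (\gamma + \delta \alpha d\_k)$. It is a single-process illustration, not the parallel implementation; the toy pairs and thresholds are hypothetical.

```python
from typing import List, Sequence, Tuple

Vec = List[float]
Pair = Tuple[Vec, Vec]


def itml_sequential(similar: Sequence[Pair], dissimilar: Sequence[Pair],
                    u: float, l: float, gamma: float = 1.0,
                    max_sweeps: int = 100, tol: float = 1e-9) -> List[Vec]:
    """Cyclic ITML Bregman projections with an explicit d x d matrix W (small d only).

    delta = +1 marks a similar pair (upper bound u), delta = -1 a dissimilar pair
    (lower bound l). Each pair must consist of distinct points, so that p > 0.
    """
    constraints = ([(x, y, 1.0) for x, y in similar]
                   + [(x, y, -1.0) for x, y in dissimilar])
    d = len(constraints[0][0])
    W = [[1.0 if r == s else 0.0 for s in range(d)] for r in range(d)]  # W = I
    lam = [0.0] * len(constraints)                                     # lambda_k
    target = [u if delta > 0 else l for _, _, delta in constraints]    # d_k
    for _ in range(max_sweeps):
        moved = 0.0
        for k, (xp, xq, delta) in enumerate(constraints):
            c = [a - b for a, b in zip(xp, xq)]
            z = [sum(w * v for w, v in zip(row, c)) for row in W]       # z = W c
            p = sum(ci * zi for ci, zi in zip(c, z))                     # current distance
            alpha = min(lam[k], 0.5 * delta * (1.0 / p - gamma / target[k]))
            beta = delta * alpha / (1.0 - delta * alpha * p)
            target[k] = gamma * target[k] / (gamma + delta * alpha * target[k])
            lam[k] -= alpha
            for r in range(d):                                          # Eq. (16): W += beta z z^T
                for s in range(d):
                    W[r][s] += beta * z[r] * z[s]
            moved += abs(alpha)
        if moved < tol:                                                 # no constraint moved: stop
            break
    return W


def pair_distance(W: List[Vec], x: Vec, y: Vec) -> float:
    """Squared Mahalanobis distance (x - y)^T W (x - y)."""
    c = [a - b for a, b in zip(x, y)]
    return sum(c[r] * sum(W[r][s] * c[s] for s in range(len(c))) for r in range(len(c)))


similar = [([0.0, 0.0], [3.0, 0.0])]      # starts at distance 9, wanted <= u
dissimilar = [([0.0, 0.0], [0.0, 0.5])]   # starts at distance 0.25, wanted >= l
W = itml_sequential(similar, dissimilar, u=1.0, l=4.0)
d_sim = pair_distance(W, *similar[0])
d_dis = pair_distance(W, *dissimilar[0])
print(d_sim, d_dis)
```

With *γ* = 1, each projection moves the pair distance and its slack target to the same harmonic-mean value, so this toy example converges after one sweep: the similar pair ends at a distance of 1.8 (down from 9) and the dissimilar pair at 8/17 ≈ 0.47 (up from 0.25).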

#### **5. Experiments**

In the following, we compare our method with advanced hypergraph matching algorithms on benchmark datasets in which each sample graph is originally given as a feature point set. For brevity, we refer to the proposed method as MLGM. We compare MLGM with the following methods: spectral matching (SM) [8], max-pooling matching (MPM) [22], IPFP [29], probabilistic graph matching (HGM) [17], tensor matching (TM) [9], reweighted random walk hypergraph matching (RRWHM) [18], block coordinate ascent graph matching (BCAGM) [23], and alternating direction graph matching (ADGM) [48]. We introduced noise and distortion into several datasets to distinguish the performance of our method from that of the other approaches, and we compared all algorithms in terms of accuracy and matching score. Accuracy is the ratio of the number of correct matches to the total number of points, and the score is computed by Equation (6). The parameters of all state-of-the-art algorithms were set as suggested in their respective articles.
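The accuracy measure used in the experiments is the fraction of points whose predicted correspondence agrees with the ground truth. A sketch with correspondences stored as dictionaries (an illustrative representation, not the authors' evaluation code):

```python
from typing import Dict


def matching_accuracy(predicted: Dict[int, int], ground_truth: Dict[int, int]) -> float:
    """Fraction of model points whose predicted match equals the ground-truth match."""
    correct = sum(1 for i, j in ground_truth.items() if predicted.get(i) == j)
    return correct / len(ground_truth)


ground_truth = {0: 0, 1: 1, 2: 2, 3: 3}
predicted = {0: 0, 1: 2, 2: 2, 3: 3}       # point 1 is mismatched
print(matching_accuracy(predicted, ground_truth))  # 0.75
```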

In our method, the feature vector of each point tuple is three-dimensional. We used Equation (21) to compute the affinity tensor *M*, with *γ* set as in [9]. In the computation, we randomly sampled *N* × *m* triplets from the graph model, where *N* is a user-defined parameter (set to 50 in this paper). For the best empirical performance, only the *K* nearest candidate triplets in the target image were retained for each sampled triplet, with *K* = 300 in this paper.
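The triplet sampling and *K*-nearest-neighbor pruning described above can be sketched as follows. This is a hedged illustration, not the authors' implementation: it assumes the sine-of-angles feature, enumerates all ordered triplets of *Q* as candidates (feasible only for small point sets), and uses brute-force search rather than a spatial index. *N* = 2 and *K* = 5 are toy values, far below the paper's *N* = 50 and *K* = 300, and the point sets are synthetic.

```python
import heapq
import itertools
import math
import random
from typing import Dict, List, Tuple

Point = Tuple[float, float]
Triplet = Tuple[int, ...]


def sine_feature(a: Point, b: Point, c: Point) -> List[float]:
    """Sines of the interior angles at a, b and c."""
    def angle(p: Point, q: Point, r: Point) -> float:
        v1, v2 = (q[0] - p[0], q[1] - p[1]), (r[0] - p[0], r[1] - p[1])
        return math.atan2(abs(v1[0] * v2[1] - v1[1] * v2[0]), v1[0] * v2[0] + v1[1] * v2[1])
    return [math.sin(angle(a, b, c)), math.sin(angle(b, c, a)), math.sin(angle(c, a, b))]


def candidate_triplets(P: List[Point], Q: List[Point], N: int, K: int,
                       seed: int = 0) -> Dict[Triplet, List[Triplet]]:
    """Sample N * m triplets of P; keep the K triplets of Q with the closest features."""
    rng = random.Random(seed)
    m = len(P)
    sampled = {tuple(rng.sample(range(m), 3)) for _ in range(N * m)}  # duplicates collapse
    feat_Q = {t: sine_feature(*(Q[i] for i in t))
              for t in itertools.permutations(range(len(Q)), 3)}
    result = {}
    for t in sampled:
        f = sine_feature(*(P[i] for i in t))
        result[t] = heapq.nsmallest(K, feat_Q, key=lambda tq: math.dist(f, feat_Q[tq]))
    return result


rng_pts = random.Random(1)
P = [(rng_pts.random(), rng_pts.random()) for _ in range(6)]
theta = math.radians(30.0)
Q = [(x * math.cos(theta) - y * math.sin(theta) + 0.3,
      x * math.sin(theta) + y * math.cos(theta) - 0.1) for x, y in P]  # rotated, shifted copy
res = candidate_triplets(P, Q, N=2, K=5)
print(len(res), [len(c) for c in res.values()])
```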

#### *5.1. Synthetic Dataset*

In this section, we use the popular benchmark datasets Blessing and Fish [49], which are well established for evaluating graph matching algorithms.

Figures 2 and 3 show examples from the existing synthetic datasets (the Chinese character "blessing" and a tropical fish). The model shapes are shown in the first column; the Blessing and Fish models consist of 105 and 98 points, respectively. To validate the robustness of the proposed algorithm under noise, deformation, outliers, and rotation, we conducted four sets of experiments. Column b contains examples of noisy targets produced by adding Gaussian random noise. Column c contains examples of deformed targets created by applying nonrigid deformation to the model points. Column d contains examples of targets with outliers, created by adding random points drawn from a normal distribution with unit variance together with a moderate rotation. Column e contains examples of targets with large rotations and moderate Gaussian noise. We ran each group of experiments and evaluated the robustness of all methods. The results are shown in Figures 4–7.

**Figure 2.** (**a**) shows model fish point sets, and (**b**–**e**) show point sets added with deformation, noise, outliers, and rotation, respectively.

**Figure 3.** (**a**) shows model blessing point sets, and (**b**–**e**) show point sets added with deformation, noise, outliers, and rotation, respectively.

**Figure 4.** Accuracy comparison on the Fish dataset. (**a**) Accuracy under different degrees of deformation. (**b**) Accuracy under different noise levels. (**c**) Accuracy with different numbers of outliers. (**d**) Accuracy under different rotation angles.

**Figure 5.** Assignment matrices obtained by MLGM on the Fish dataset under different degrees of deformation (**a**–**e**).

**Figure 6.** Accuracy comparison on the Blessing dataset. (**a**) Accuracy under different degrees of deformation. (**b**) Accuracy under different noise levels. (**c**) Accuracy with different numbers of outliers. (**d**) Accuracy under different rotation angles.

**Figure 7.** Assignment matrices obtained by MLGM on the Blessing dataset under different degrees of deformation (**a**–**e**).

In the deformation experiments, the degree of deformation was varied from 0.02 to 0.1; as it increases, the matching accuracy of all algorithms decreases accordingly. For each graph pair, we fixed the points of one image and used each algorithm to find the corresponding points in the other image. Each reported result is averaged over multiple independent runs to ensure reliability. The results show that our ITML-based method has a clear advantage in this setting. In the noise experiments, the target points were obtained by adding Gaussian random noise with *σ* ranging from 0.01 to 0.05; Figures 4 and 6 show that high-order graph matching methods achieve higher accuracy because they incorporate the internal topology of the point set into the feature description. Across these groups of experiments on the synthetic datasets, our algorithm obtains more accurate matching results; in particular, it achieves 100% accuracy on the datasets with outliers. Figures 8 and 9 show the matching scores under the various experimental conditions. As interference increases, the matching score of our method remains steady at a high level, indicating that the learned feature affinity is robust to large affine deformations and strong Gaussian noise. For rotation, the results show that the rotation angle has little effect on the matching results, although the accuracy of our method declines slightly when the rotation angle reaches 90 degrees. We attribute this mainly to the effect of such large rotations on the feature correlations that the learned metric relies on. Overall, the experiments show that learning a metric from the dataset improves the matching results.

**Figure 8.** Matching score comparison on the Fish dataset. (**a**) Matching score under different degrees of deformation. (**b**) Matching score under different noise levels. (**c**) Matching score with different numbers of outliers. (**d**) Matching score under different rotation angles.

**Figure 9.** Matching score comparison on the Blessing dataset. (**a**) Matching score under different degrees of deformation. (**b**) Matching score under different noise levels. (**c**) Matching score with different numbers of outliers. (**d**) Matching score under different rotation angles.

#### *5.2. Face Dataset and Duck Dataset*

In this section, we compare our method with the other methods on the Face and Duck datasets, which are subsets of Caltech-256 [50] containing 109 face images and 50 duck images, respectively. The ground truth is known for each graph pair. We randomly chose 70 pairs of faces from the dataset for testing, manually selected 10 feature points in each image, and randomly chose 20 images from each class as the training set for metric learning. The baseline was varied from 10 to 80 frames; we tested all algorithms and report the accuracy and matching scores averaged over these baselines. Figures 10 and 11 show several example matching results, in which our algorithm performs better. Figures 12 and 13 show that MLGM achieves the highest score and more accurate matching results than the other methods, and that the compared approaches are easily affected by noise and distortion. Overall, MLGM obtains better matching results on the entire dataset.

**Figure 10.** Example results of experiments on the Face dataset, in which red and yellow lines denote correct and incorrect matches, respectively. Panel (**c**) shows RRWHM, BCAGM, and ADGM.

**Figure 11.** Example results of experiments on the Duck dataset: (**a**) MLGM, SM, and MPM; (**b**) IPFP, HGM, and TM; (**c**) RRWHM, BCAGM, and ADGM.

**Figure 12.** Trend chart of matching accuracy and score of the Face dataset. (**a**) Accuracy of the Face dataset. (**b**) Matching score of the Face dataset.

**Figure 13.** Trend chart of matching accuracy and score of the Duck dataset. (**a**) Accuracy of the Duck dataset. (**b**) Matching score of the Duck dataset.

#### **6. Conclusions**

In this paper, we proposed a tensor graph matching model based on metric learning that uses the Mahalanobis distance as the affinity measure and makes full use of the distribution and geometric information of hypergraphs. To solve the proposed model, we used a parallel distance metric learning approach that can learn appropriate metrics from high-dimensional data without low-rank approximation. Experimental results on several datasets, namely the synthetic Blessing and Fish datasets and the Face and Duck datasets, indicate that the proposed method outperforms the compared methods. In future work, we may consider combining this strategy with deep learning.

**Author Contributions:** Methodology, Z.W. and Y.W.; Project administration, F.L.; Supervision, F.L.; Writing—original draft, Z.W.; Writing—review & editing, Y.W. All authors have read and agreed to the published version of the manuscript.

**Funding:** This work is supported by the National Natural Science Foundation of China under Grant No. 62072256, the Natural Science Foundation of Nanjing University of Posts and Telecommunications (Grant No. NY221057 and NY220003) and the Postgraduate Research & Practice Innovation Program of Jiangsu Province, China (Grant No. SJCX19\_0248).

**Data Availability Statement:** Not applicable.

**Conflicts of Interest:** The authors declare no conflict of interest.

#### **References**


### *Article* **Plug-and-Play-Based Algorithm for Mixed Noise Removal with the Logarithm Norm Approximation Model**

**Jinhua Liu <sup>1,\*</sup>, Jiayun Wu <sup>1</sup>, Mulian Xu <sup>1</sup> and Yuanyuan Huang <sup>2</sup>**

	- Chengdu 610225, China

**Abstract:** During imaging and transmission, images are easily affected by several factors, including sensors, camera motion, and transmission channels. In practice, images are commonly corrupted by a mixture of Gaussian and impulse noise, which further complicates the denoising problem. In this work, we therefore propose a novel mixed noise removal model that combines a deterministic low-rankness prior with an implicit regularization scheme. In the optimization model, we apply the matrix logarithm norm approximation model to characterize the global low-rankness of the original image. We further adopt the plug-and-play (PnP) scheme to formulate an implicit regularizer by plugging in an image denoiser, which preserves image details. These two building blocks complement each other and together form the mixed noise removal algorithm. Within the PnP framework, we solve the proposed optimization model with the alternating direction method of multipliers (ADMM). Finally, we perform extensive experiments to demonstrate the effectiveness of the proposed algorithm. The simulation results show that our algorithm recovers both the global structure and the detailed information of images well and outperforms competing methods in terms of quantitative evaluation and visual inspection.

**Keywords:** mixed noise removal; matrix nuclear norm; logarithm norm; ADMM; plug-and-play

**MSC:** 68U10

### **1. Introduction**

Image denoising is widely used in many applications, such as hyperspectral imaging (HSI) [1], scene recognition [2], and image restoration [3]. However, due to imaging conditions, natural images inevitably suffer from various kinds of noise, e.g., Gaussian, random, salt-and-pepper (S&P), and stripe noise, which critically affect subsequent applications. In particular, many images are contaminated by mixed noise, such as Gaussian plus random noise or Gaussian plus stripe noise. Restoring a clean image from its corrupted version is therefore the central issue in image denoising. Mathematically, the denoising problem is ill-posed and irreversible, so prior knowledge of the image is of great importance.

**Citation:** Liu, J.; Wu, J.; Xu, M.; Huang, Y. Plug-and-Play-Based Algorithm for Mixed Noise Removal with the Logarithm Norm Approximation Model. *Mathematics* **2022**, *10*, 3810. https://doi.org/10.3390/math10203810

Academic Editor: Radu Tudor Ionescu

Received: 2 September 2022 Accepted: 12 October 2022 Published: 15 October 2022

**Copyright:** © 2022 by the authors. Licensee MDPI, Basel, Switzerland. This article is an open access article distributed under the terms and conditions of the Creative Commons Attribution (CC BY) license (https://creativecommons.org/licenses/by/4.0/).

In the past decade, researchers have proposed numerous image denoising models, such as bivariate probability models [4], the Gaussian–Hermite distribution [5], total variation [6], autoregressive models [7], Block-Matching 3D (BM3D) [8], and sparse representation-based image modeling [9–11]. Among these, the sparse representation model has been extensively studied and applied. It represents a natural image as a linear combination of a small number of basis or dictionary atoms, so that the transform coefficients are sparse and compressible: only a few coefficients are nonzero. Common examples include the discrete cosine, wavelet, and Fourier bases. However, such denoising methods can only address white Gaussian noise, whereas in actual applications images are often affected by many types of noise, such as Gaussian, S&P, or random noise. Traditional denoising methods cannot easily remove impulse noise, because impulse noise points at edges tend to be retained [12,13].
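To make the sparse representation idea concrete, the following minimal NumPy sketch (our illustration, not part of the proposed method) builds an orthonormal DCT-II basis, one of the bases mentioned above, and keeps only the largest-magnitude coefficients. The function names `dct_matrix` and `sparse_approx` are hypothetical.

```python
import numpy as np

def dct_matrix(n: int) -> np.ndarray:
    """Orthonormal DCT-II matrix D whose rows are basis atoms, so D @ D.T = I."""
    k = np.arange(n)[:, None]                  # frequency index
    i = np.arange(n)[None, :]                  # sample index
    D = np.cos(np.pi * (2 * i + 1) * k / (2 * n))
    D[0, :] *= np.sqrt(1.0 / n)
    D[1:, :] *= np.sqrt(2.0 / n)
    return D

def sparse_approx(x: np.ndarray, keep: int) -> np.ndarray:
    """Keep the `keep` largest-magnitude DCT coefficients of x and zero the rest."""
    D = dct_matrix(x.size)
    coeffs = D @ x                               # analysis: signal -> coefficients
    order = np.argsort(np.abs(coeffs))           # ascending magnitude
    coeffs[order[: max(x.size - keep, 0)]] = 0.0 # hard-threshold the small ones
    return D.T @ coeffs                          # synthesis: coefficients -> signal

D = dct_matrix(16)
x = 3.0 * D[2] + 0.5 * D[7]        # exactly 2-sparse in the DCT basis
x_hat = sparse_approx(x, keep=2)   # recovers x up to rounding
```

Because the signal is a combination of two atoms, two coefficients suffice to reconstruct it; this is the compressibility property the paragraph describes.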

There are two typical types of impulse noise, namely S&P noise and random-valued noise. Conventional methods use two approaches to remove mixed Gaussian–impulse noise: detection-based methods and modeling-based methods. Detection-based denoising methods [14–17] first detect the locations of corrupted pixels and then handle the mixed noise, so the accuracy of this detection is crucial. Detection-based methods are generally effective at removing impulse noise. However, their fidelity terms do not account for Gaussian noise, so they cannot remove it effectively.

The second approach treats impulse noise as a sparse signal and builds a statistical model of it. The method in [18] adopts Laplacian scale mixture (LSM) modeling to characterize impulse noise, jointly estimates the hidden variables and the impulse noise from the noisy image, and regularizes the denoising model with a nonlocal low-rank regularizer. Liu et al. [19] proposed a mixed noise removal algorithm based on weighted dictionary learning; although it can handle mixed noise, its training process is time-consuming. Jiang et al. [20] developed an image denoising method that combines weighted encoding with nonlocal self-similarity. This method removes Gaussian and impulse noise jointly, but its performance relies on the design of the diagonal weight matrix.

Recently, low-rank matrix recovery has attracted considerable attention in image restoration [4,21–24]. Its fundamental problem is how to find and exploit the low-dimensional structure of images. Unlike traditional mixed noise denoising methods, low-rank matrix recovery can handle different noise types without prior information about the noise. Therefore, many researchers have applied low-rank matrix recovery models to image reconstruction. Zhang et al. [25] proposed a hyperspectral image denoising method based on a low-rank matrix recovery model, and a noise-adjustable low-rank matrix approximation model was subsequently applied to hyperspectral image denoising [26]. However, both methods [25,26] require an upper bound on the rank of the matrix to be set in advance. To resolve this issue, the nuclear norm was used to design the rank approximation function in [27] for hyperspectral image denoising. This nuclear norm-based function treats all singular values equally, ignoring the fact that different nonzero singular values contribute differently. As a result, nonconvex low-rank approaches have been exploited for hyperspectral image restoration [28,29]. In addition, total variation-regularized low-rank restoration methods have been developed to remove mixed noise from HSI [30,31]. Low-rank tensor-based HSI restoration algorithms have also been proposed, including weighted group sparsity-regularized low-rank tensor decomposition (LRTDGS) [35] and fibered rank-constrained tensor restoration with PnP [36].

In recent years, deep learning-based image denoising has been studied extensively. Instead of constructing a mathematical model, learning-based methods directly learn a mapping from a noisy image to a clean image. Examples include convolutional neural network-based CT denoising [32], autonomous illumination systems [33], and deep plug-and-play (PnP) image restoration [34].

In this work, inspired by PnP-based [34,36–39] and low-rank-based [40,41] methods, we propose a mixed noise removal algorithm based on a PnP-regularized logarithm norm approximation model (LNAM). First, the LNAM characterizes the global low-rankness of the original image. Second, PnP regularization is adopted to preserve image details. Finally, simulation experiments on test images confirm the effectiveness of the proposed denoising method. The contributions of the proposed method can be summarized as follows:

First, instead of the nuclear norm-based low-rank approximation function, we introduce a smooth rank surrogate based on the logarithm norm and propose the LNAM. Compared with the nuclear norm, it provides a tighter approximation of the rank function and can therefore exploit the global low-rank structure of HSI more effectively.

Second, low-rankness priors are often limited in their ability to preserve local image details. To overcome this limitation, we incorporate the PnP framework into the LNAM model and plug in the classic BM3D denoiser [8], which extensively exploits the nonlocal self-similarity of images.

Third, to solve the LNAM optimization problem efficiently, we use the alternating direction method of multipliers (ADMM) to decompose it into several simple subproblems.

The remainder of this article is organized as follows. Section 2 reviews low-rank-based mixed noise denoising models for hyperspectral images. Section 3 proposes the LNAM model and solves it with an ADMM-based optimization algorithm. Section 4 presents the experimental results on the test images and discusses the effect of several parameters on the proposed algorithm. Finally, Section 5 concludes the paper.

#### **2. Background of the Low-Rank-Based Hyperspectral Image Denoising Method**

Mixed noise removal techniques based on low-rank matrix recovery are mainly inspired by robust principal component analysis (RPCA) [42], which aims to recover the underlying low-dimensional subspace structure of high-dimensional signals from corrupted observations. The RPCA model can be expressed as

$$\begin{array}{l}\min\_{\mathbf{X}, \mathbf{S}} \operatorname{rank}(\mathbf{X}) + \lambda \|\mathbf{S}\|\_{0} \\ \text{s.t. } \mathbf{Y} = \mathbf{X} + \mathbf{S}\end{array} \tag{1}$$

where *λ* is the regularization parameter; *Y* is the corrupted observation; *X* and *S* denote the unknown low-rank matrix and the sparse matrix, respectively; and $\|\cdot\|\_0$ denotes the $\ell\_0$-norm, which counts the nonzero entries and thus promotes sparsity. Although the RPCA model can remove sparse noise, it does not work well when the hyperspectral image is corrupted by mixed noise, e.g., Gaussian noise plus sparse noise. An improved model has therefore been proposed that explicitly accounts for the Gaussian noise *E*:

$$\begin{array}{ll}\min\_{\mathbf{X}, \mathbf{S}, \mathbf{E}} \operatorname{rank}(\mathbf{X}) + \lambda \|\mathbf{S}\|\_{0} + \frac{\eta}{2} \|\mathbf{E}\|\_{F}^{2} \\ \text{s.t. } \mathbf{Y} = \mathbf{X} + \mathbf{S} + \mathbf{E} \end{array} \tag{2}$$

where *λ* and *η* are regularization parameters. Problems (1) and (2) are NP-hard. A common remedy is to replace the rank function with the nuclear norm and the $\ell\_0$-norm with the $\ell\_1$-norm [43], which gives the convex problem

$$\begin{array}{ll}\min\_{\mathbf{X}, \mathbf{S}, \mathbf{E}} \|\mathbf{X}\|\_{\*} + \lambda \|\mathbf{S}\|\_{1} + \frac{\eta}{2} \|\mathbf{E}\|\_{F}^{2} \\\text{s.t. } \mathbf{Y} = \mathbf{X} + \mathbf{S} + \mathbf{E} \end{array} \tag{3}$$

The low-rank matrix approximation model has been widely used in hyperspectral image denoising. However, it has the following drawbacks. First, in the rank function every nonzero singular value contributes equally, whereas in practice different singular values carry different amounts of information. Under the nuclear norm, large singular values are penalized more heavily than small ones, which easily leads to overshrinkage of the dominant singular values. Second, the rank function itself is difficult to optimize in practice. Third, low-rank matrix approximation approaches require numerous iterations, which results in low computational efficiency.
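The uniform shrinkage behind the first drawback is visible in the proximal operator of the nuclear norm, singular value thresholding (SVT), which solves $\min\_{\mathbf{X}} \tau\|\mathbf{X}\|\_{\*} + \frac{1}{2}\|\mathbf{X} - \mathbf{G}\|\_F^2$. The NumPy sketch below is our illustration, not code from the paper; `svt` is a hypothetical helper name.

```python
import numpy as np

def svt(G: np.ndarray, tau: float) -> np.ndarray:
    """Singular value thresholding: proximal operator of tau * nuclear norm.

    Every singular value of G is reduced by the same amount tau (clipped at 0),
    so large, informative singular values are shrunk as much as small ones.
    """
    U, s, Vt = np.linalg.svd(G, full_matrices=False)
    s_shrunk = np.maximum(s - tau, 0.0)
    return (U * s_shrunk) @ Vt          # U @ diag(s_shrunk) @ Vt

G = np.diag([5.0, 1.0])
print(np.linalg.svd(svt(G, 0.5), compute_uv=False))  # singular values 4.5 and 0.5
```

The dominant singular value 5 loses the same 0.5 as the small one, rather than being kept intact as a rank-counting penalty would allow.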

Recently, nonconvex relaxations have been used in place of the nuclear norm to approximate the rank function [44]. In particular, the well-known weighted Schatten *p*-norm model [45] was introduced for hyperspectral image denoising:

$$\begin{array}{ll}\min\_{\mathbf{X},\mathbf{S}} C \|\mathbf{X}\|\_{w,S\_{p}}^{p} + \lambda \|\mathbf{S}\|\_{1} \\ \text{s.t. } \mathbf{Y} = \mathbf{X} + \mathbf{S} + \mathbf{E}, \ \|\mathbf{E}\|\_{F} \le \xi \end{array} \tag{4}$$

where *C* is the weight of the low-rank term, *λ* is the regularization parameter, and *ξ* denotes the noise level. In $\|\mathbf{X}\|\_{w,S\_p}^{p} = \sum\_i w\_i \sigma\_i^p(\mathbf{X})$, $w\_i$ is the *i*-th non-negative weight and $\sigma\_i(\mathbf{X})$ is the *i*-th singular value of *X*; $\|\mathbf{E}\|\_F$ denotes the Frobenius norm of *E*.
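As a quick illustration of the weighted Schatten *p*-norm term in (4), the sketch below evaluates $\sum\_i w\_i \sigma\_i^p(\mathbf{X})$ from an SVD. It is our own NumPy illustration with a hypothetical helper name; with all weights equal to 1 and *p* = 1 it reduces to the nuclear norm.

```python
import numpy as np

def weighted_schatten_p(X: np.ndarray, weights, p: float) -> float:
    """Return ||X||_{w,S_p}^p = sum_i w_i * sigma_i(X)**p."""
    s = np.linalg.svd(X, compute_uv=False)   # singular values, descending
    w = np.asarray(weights, dtype=float)
    if w.shape != s.shape:
        raise ValueError("need exactly one weight per singular value")
    return float(np.sum(w * s ** p))

X = np.diag([4.0, 1.0])
print(weighted_schatten_p(X, [1.0, 1.0], 0.5))  # 2 + 1 = 3
print(weighted_schatten_p(X, [1.0, 1.0], 1.0))  # nuclear norm: 5
```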

Although the weighted Schatten *p*-norm model can effectively remove noise, it is sensitive to initial parameters such as the noise level and the weights, and it is difficult to adapt to mixed noise removal. Therefore, inspired by previous works [40,41], we use the matrix LNAM to remove mixed noise from images.

#### **3. Proposed Mixed Denoising Algorithm**

As mentioned above, hyperspectral images are often contaminated by mixed noise, and strong structural correlations exist among image blocks, which motivates a rank-based approach. In this section, we propose a PnP-based LNAM for mixed noise removal from hyperspectral images, solve it with the ADMM optimization algorithm within the PnP framework, and develop the corresponding hyperspectral image denoising algorithm.

#### *3.1. PnP-Based LNAM Model*

Assuming that the different noise components in natural images are independent, we propose the following mixed noise removal model based on a logarithm norm rank approximation:

$$\begin{array}{ll}\min\_{\mathbf{X},\mathbf{S}} \|\mathbf{X}\|\_{L} + \lambda \|\mathbf{S}\|\_{1} + \rho \phi(\mathbf{X})\\\text{s.t. } \|\mathbf{Y} - \mathbf{X} - \mathbf{S}\|\_{F}^{2} \leq \zeta\end{array} \tag{5}$$

where *λ* and *ρ* are regularization parameters, *Y* is the corrupted image, *S* denotes the sparse noise, and *ζ* > 0 bounds the energy of the remaining (Gaussian) residual. $\|\mathbf{X}\|\_L$ denotes the logarithm norm-based low-rank function (the subscript "L" stands for logarithm), defined as

$$\|\mathbf{X}\|\_{L} = \sum\_{i=1}^{\min\{m\_1, m\_2\}} \log(\sigma\_i^p(\mathbf{X}) + \delta),\tag{6}$$

where *X* denotes a clean image of size $m\_1 \times m\_2$, $\sigma\_i(\mathbf{X})$ is its *i*-th singular value, 0 < *p* ≤ 1, and *δ* > 0 is a small constant that keeps the logarithm finite when a singular value is zero.
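Evaluating (6) takes a single SVD. The NumPy sketch below is our illustration (`log_norm` and `nuclear_norm` are hypothetical names). Note that, despite its name, $\|\mathbf{X}\|\_L$ is not a norm in the strict mathematical sense: it is not homogeneous and can be negative when singular values are small.

```python
import numpy as np

def log_norm(X: np.ndarray, p: float = 1.0, delta: float = 1e-3) -> float:
    """Logarithm-based rank surrogate of Eq. (6): sum_i log(sigma_i(X)**p + delta)."""
    s = np.linalg.svd(X, compute_uv=False)   # min(m1, m2) singular values
    return float(np.sum(np.log(s ** p + delta)))

def nuclear_norm(X: np.ndarray) -> float:
    """Sum of singular values, for comparison."""
    return float(np.sum(np.linalg.svd(X, compute_uv=False)))

X = np.diag([1.0, 0.0])
print(log_norm(X), nuclear_norm(X))   # log(1.001) + log(0.001), and 1.0
```

For *p* = 1, the slope of log(σ + δ) is 1/(σ + δ): steep for small singular values and flat for large ones, whereas the nuclear norm has slope 1 everywhere. This is the sense in which the surrogate penalizes small singular values strongly while barely shrinking large ones.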

In model (5), *φ*(*X*) denotes an implicit regularizer that exploits certain priors of natural images; it can be instantiated by well-known denoisers such as BM3D [8], DnCNN [46], and FFDNET [47]. In this work, the BM3D denoiser is selected as the embedded regularization module. In summary, $\|\mathbf{X}\|\_L$ characterizes the global information (low-rankness) of the original image, while plugging the regularization module *φ*(*X*) into the PnP framework preserves image details. These two complementary modules are combined to preserve both the global structure and the detailed information of the image.

Compared with the nuclear norm, the logarithm norm-based low-rank function promotes stronger sparsity of the singular values of real images. Following [48], suppose that the feasible set is bounded by a constant *M*, i.e., $\|\mathbf{x}\|\_\infty \le M$, where $\mathbf{x}$ is the vector of singular values of *X*. On this set, the convex envelope of $\operatorname{rank}(\mathbf{X})$ is $\frac{1}{M}\|\mathbf{X}\|\_{\*} = \frac{1}{M}\|\mathbf{x}\|\_1$. As the positive constant *δ* → 0, the logarithmic function is closer to $\operatorname{rank}(\mathbf{X})$ than this convex envelope, so it can achieve stronger sparsity than the nuclear norm.

#### *3.2. Optimization Method*

To solve model (5), we introduce an auxiliary variable *L* and rewrite the PnP-based logarithm norm approximation model as

$$\begin{array}{ll}\min\_{\mathbf{X},\mathbf{S},\mathbf{L}} \|\mathbf{X}\|\_{L} + \lambda \|\mathbf{S}\|\_{1} + \rho\phi(\mathbf{L})\\\text{s.t.} \ \|\mathbf{Y} - \mathbf{X} - \mathbf{S}\|\_{F}^{2} \leq \zeta, \ \mathbf{X} = \mathbf{L} \end{array} \tag{7}$$

Furthermore, the augmented Lagrangian function of (7) is constructed as

$$\begin{aligned} \ell(\mathbf{X}, \mathbf{L}, \mathbf{S}, \Lambda\_1, \Lambda\_2, \lambda, \rho, \beta\_1, \beta\_2) &= \|\mathbf{X}\|\_{\boldsymbol{L}} + \lambda \|\mathbf{S}\|\_{\boldsymbol{1}} \\ &+ \langle \Lambda\_1, \mathbf{Y} - \mathbf{X} - \mathbf{S} \rangle + \frac{\beta\_1}{2} \|\mathbf{Y} - \mathbf{X} - \mathbf{S}\|\_F^2 + \rho \phi(\mathbf{L}) + \langle \Lambda\_2, \mathbf{X} - \mathbf{L} \rangle + \frac{\beta\_2}{2} \|\mathbf{X} - \mathbf{L}\|\_F^2 \end{aligned} \tag{8}$$

where $\Lambda\_1$ and $\Lambda\_2$ are the Lagrangian multipliers and $\beta\_1$ and $\beta\_2$ are the penalty parameters. Within the ADMM framework, we minimize (8) alternately: at the (*k* + 1)-th iteration, each variable is updated while the others are held fixed. The proposed mixed noise removal method thus reduces to the following three subproblems plus a multiplier update, as summarized in Algorithm 1.

(1) X-Subproblem

Given $\mathbf{S}^{k}$ and $\mathbf{L}^{k}$, we compute $\mathbf{X}^{k+1}$ as

$$\begin{aligned}\mathbf{X}^{k+1} &= \operatorname\*{argmin}\_{\mathbf{X}} \Big\{ \|\mathbf{X}\|\_{L} + \langle \Lambda\_{1}, \mathbf{Y} - \mathbf{X} - \mathbf{S}^{k} \rangle + \frac{\beta\_{1}}{2} \|\mathbf{Y} - \mathbf{X} - \mathbf{S}^{k}\|\_{F}^{2} \\ &\qquad + \langle \Lambda\_{2}, \mathbf{X} - \mathbf{L}^{k} \rangle + \frac{\beta\_{2}}{2} \|\mathbf{X} - \mathbf{L}^{k}\|\_{F}^{2} \Big\} \\ &= \operatorname\*{argmin}\_{\mathbf{X}} \Big\{ \|\mathbf{X}\|\_{L} + \frac{\beta\_{1}}{2} \Big\|\mathbf{X} - \Big(\mathbf{Y} - \mathbf{S}^{k} + \frac{\Lambda\_{1}}{\beta\_{1}}\Big)\Big\|\_{F}^{2} + \frac{\beta\_{2}}{2} \Big\|\mathbf{X} - \mathbf{L}^{k} + \frac{\Lambda\_{2}}{\beta\_{2}}\Big\|\_{F}^{2} \Big\} \\ &= \operatorname\*{argmin}\_{\mathbf{X}} \Big\{ \|\mathbf{X}\|\_{L} + \frac{\beta\_{1} + \beta\_{2}}{2} \Big\|\mathbf{X} - \frac{\beta\_{1}\mathbf{A} + \beta\_{2}\mathbf{B}}{\beta\_{1} + \beta\_{2}}\Big\|\_{F}^{2} \Big\}\end{aligned} \tag{9}$$

where $\mathbf{A} = \mathbf{Y} - \mathbf{S}^{k} + \Lambda\_1/\beta\_1$ and $\mathbf{B} = \mathbf{L}^{k} - \Lambda\_2/\beta\_2$. Dividing the objective of (9) by $\beta\_1 + \beta\_2$ puts it in the form of problem (10) below, with $\alpha = 1/(\beta\_1 + \beta\_2)$ and $\mathbf{G} = (\beta\_1\mathbf{A} + \beta\_2\mathbf{B})/(\beta\_1 + \beta\_2)$, so its solution is given by the following theorem.
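The last equality in (9) is a standard completing-the-square step; written out for completeness:

$$\begin{aligned} \frac{\beta\_1}{2}\|\mathbf{X}-\mathbf{A}\|\_F^2+\frac{\beta\_2}{2}\|\mathbf{X}-\mathbf{B}\|\_F^2 &= \frac{\beta\_1+\beta\_2}{2}\|\mathbf{X}\|\_F^2-\langle \mathbf{X},\beta\_1\mathbf{A}+\beta\_2\mathbf{B}\rangle+c\_1 \\ &= \frac{\beta\_1+\beta\_2}{2}\Big\|\mathbf{X}-\frac{\beta\_1\mathbf{A}+\beta\_2\mathbf{B}}{\beta\_1+\beta\_2}\Big\|\_F^2+c\_2, \end{aligned}$$

where $c\_1$ and $c\_2$ do not depend on $\mathbf{X}$ and therefore do not change the minimizer.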

**Theorem 1 (Logarithmic Singular Value Thresholding [40]).** *Let* $\mathbf{G} \in \mathbb{R}^{m\_1 \times m\_2}$ *be a given matrix with SVD* $\mathbf{G} = \mathbf{U}\_G \Sigma\_G \mathbf{V}\_G^{T}$, *where* $\Sigma\_G$ *is the diagonal matrix of singular values. For any* α > 0, *the closed-form solution of the problem*

$$\min\_{\mathbf{X}} \alpha \|\mathbf{X}\|\_{L} + \frac{1}{2} \|\mathbf{X} - \mathbf{G}\|\_{F}^{2} \tag{10}$$

*is given by* $\mathbf{X} = \mathbf{U}\_G \mathcal{T}\_{\alpha,\xi}(\Sigma\_G)\mathbf{V}\_G^{T}$, *where* $\mathcal{T}\_{\alpha,\xi}(\cdot)$ *is the logarithmic singular value thresholding function, applied to each singular value:*

$$\mathcal{T}\_{\alpha,\xi}(x) = \begin{cases} 0, & \Delta \le 0, \\ \operatorname\*{argmin}\_{y \in \{0,\ (x - \xi + \sqrt{\Delta})/2\}} \varphi(y), & \Delta > 0, \end{cases} \tag{11}$$

*where* $\Delta = (x - \xi)^2 - 4(\alpha - x\xi)$ *and* $\varphi(y) = \alpha \log(y + \xi) + (y - x)^2/2$.
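The NumPy sketch below implements (11) and the resulting X-update. It is our own illustration of Theorem 1 under stated assumptions: (11) uses log(*y* + *ξ*), i.e., *p* = 1 with *ξ* playing the role of *δ* in (6); the check that the candidate root is positive is an added feasibility guard, since singular values are non-negative; the function names are hypothetical.

```python
import numpy as np

def log_threshold(x: float, alpha: float, xi: float) -> float:
    """Scalar thresholding T_{alpha,xi}(x) of Eq. (11) for a singular value x >= 0.

    Minimizes phi(y) = alpha*log(y + xi) + (y - x)**2 / 2 over y >= 0.
    """
    disc = (x - xi) ** 2 - 4.0 * (alpha - x * xi)      # Delta in Eq. (11)
    if disc <= 0:
        return 0.0                                     # phi nondecreasing: y = 0
    root = (x - xi + np.sqrt(disc)) / 2.0              # larger stationary point
    if root <= 0:
        return 0.0                                     # added guard: root infeasible
    phi = lambda y: alpha * np.log(y + xi) + 0.5 * (y - x) ** 2
    return float(root) if phi(root) < phi(0.0) else 0.0

def log_svt(G: np.ndarray, alpha: float, xi: float) -> np.ndarray:
    """Theorem 1: X = U_G T_{alpha,xi}(Sigma_G) V_G^T."""
    U, s, Vt = np.linalg.svd(G, full_matrices=False)
    t = np.array([log_threshold(si, alpha, xi) for si in s])
    return (U * t) @ Vt

def x_update(A, B, beta1, beta2, xi):
    """X-subproblem (9): alpha = 1/(beta1 + beta2), G = weighted mean of A and B."""
    G = (beta1 * A + beta2 * B) / (beta1 + beta2)
    return log_svt(G, alpha=1.0 / (beta1 + beta2), xi=xi)
```

Because *φ* is increasing for *y* > *x*, its minimizer over *y* ≥ 0 lies in [0, *x*], so a brute-force grid search over that interval is an easy way to sanity-check `log_threshold`.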

(2) L-Subproblem

Given $\mathbf{X}^{k+1}$ and $\Lambda\_2$, we compute $\mathbf{L}^{k+1}$ as

$$\mathbf{L}^{k+1} = \operatorname\*{argmin}\_{\mathbf{L}} \rho \phi(\mathbf{L}) + \frac{\beta\_2}{2} \Big\| \mathbf{X}^{k+1} - \mathbf{L} + \frac{\Lambda\_2}{\beta\_2} \Big\|\_F^2. \tag{12}$$

Let $\hat{\sigma}^2 = \rho/\beta\_2$. Dividing (12) by *ρ*, we can rewrite it as a proximal problem:

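$$\mathbf{L}^{k+1} = \operatorname{prox}\_{\hat{\sigma}^{2}\phi}\Big(\mathbf{X}^{k+1} + \frac{\Lambda\_2}{\beta\_2}\Big) = \operatorname\*{argmin}\_{\mathbf{L}} \phi(\mathbf{L}) + \frac{1}{2\hat{\sigma}^2} \Big\| \mathbf{X}^{k+1} - \mathbf{L} + \frac{\Lambda\_2}{\beta\_2} \Big\|\_{F}^{2} \tag{13}$$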

where the proximal operator $\operatorname{prox}(\cdot)$ of the regularizer is replaced in the PnP framework by an embedded denoiser. BM3D [8] and FFDNET [47] are two well-known image denoisers. The main advantage of BM3D is that it captures the piecewise smoothness and nonlocal self-similarity of images in a 3D transform domain. Deep learning-based denoisers have recently shown promising performance, but they require massive amounts of training data, which are difficult to obtain. We therefore select the BM3D denoiser [8] as the module within the PnP framework. With BM3D plugged in, the solution becomes

$$\mathbf{L}^{k+1} = \operatorname{BM3D} \Big( \mathbf{X}^{k+1} + \frac{\Lambda\_2}{\beta\_2}, \hat{\sigma} \Big). \tag{14}$$
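In code, the PnP step amounts to calling a denoiser on $\mathbf{X}^{k+1} + \Lambda\_2/\beta\_2$ with noise level $\hat{\sigma} = \sqrt{\rho/\beta\_2}$. The sketch below keeps the denoiser as a plain callable acting on a 2D matrix. No BM3D implementation is assumed here, so `smooth_denoiser` is a hypothetical stand-in (3×3 binomial smoothing), not BM3D; in practice a real BM3D implementation would be passed instead.

```python
import numpy as np
from typing import Callable

# A denoiser maps (noisy image, noise level) -> denoised image.
Denoiser = Callable[[np.ndarray, float], np.ndarray]

def l_update(X_next: np.ndarray, Lam2: np.ndarray, beta2: float, rho: float,
             denoiser: Denoiser) -> np.ndarray:
    """PnP L-step, Eqs. (13)-(14): denoise X^{k+1} + Lam2/beta2 at level sqrt(rho/beta2)."""
    sigma_hat = np.sqrt(rho / beta2)           # from sigma_hat^2 = rho / beta2
    return denoiser(X_next + Lam2 / beta2, sigma_hat)

def smooth_denoiser(img: np.ndarray, sigma: float) -> np.ndarray:
    """Hypothetical stand-in for BM3D: 3x3 binomial smoothing; sigma is ignored."""
    k = np.array([1.0, 2.0, 1.0]) / 4.0
    p = np.pad(img, 1, mode="edge")
    h, w = img.shape
    return sum(k[i] * k[j] * p[i:i + h, j:j + w] for i in range(3) for j in range(3))
```

Keeping the denoiser as an argument reflects the plug-and-play idea: the update rule stays the same whichever denoiser is plugged in.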

(3) S-Subproblem

Given $\mathbf{X}^{k+1}$ and $\mathbf{L}^{k+1}$, we compute $\mathbf{S}^{k+1}$ as

$$\begin{split} \mathbf{S}^{k+1} &= \operatorname\*{argmin}\_{\mathbf{S}} \Big\{ \lambda \|\mathbf{S}\|\_{1} + \langle \Lambda\_{1}, \mathbf{Y} - \mathbf{X}^{k+1} - \mathbf{S} \rangle + \frac{\beta\_{1}}{2} \|\mathbf{Y} - \mathbf{X}^{k+1} - \mathbf{S}\|\_{F}^{2} \Big\} \\ &= \operatorname\*{argmin}\_{\mathbf{S}} \Big\{ \lambda \|\mathbf{S}\|\_{1} + \frac{\beta\_{1}}{2} \Big\| \mathbf{Y} - \mathbf{X}^{k+1} - \mathbf{S} + \frac{\Lambda\_{1}}{\beta\_{1}} \Big\|\_{F}^{2} \Big\} \end{split} \tag{15}$$

We solve subproblem (15) with the soft thresholding operator, defined elementwise as $\operatorname{soft}\_{\tau}(x) = \operatorname{sgn}(x)\max(|x| - \tau, 0)$ with threshold *τ* > 0. The solution of (15) is therefore

$$\mathbf{S}^{k+1} = \operatorname{soft}\_{\lambda/\beta\_1} \Big( \mathbf{Y} - \mathbf{X}^{k+1} + \frac{\Lambda\_1}{\beta\_1} \Big). \tag{16}$$
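The S-update (16) is a single elementwise operation. A minimal NumPy sketch (helper names are ours):

```python
import numpy as np

def soft(x: np.ndarray, tau: float) -> np.ndarray:
    """Elementwise soft thresholding: sgn(x) * max(|x| - tau, 0)."""
    return np.sign(x) * np.maximum(np.abs(x) - tau, 0.0)

def s_update(Y, X_next, Lam1, lam: float, beta1: float) -> np.ndarray:
    """S-subproblem (16): soft-threshold Y - X^{k+1} + Lam1/beta1 at lam/beta1."""
    return soft(Y - X_next + Lam1 / beta1, lam / beta1)

print(soft(np.array([-3.0, 0.5, 2.0]), 1.0))   # [-2.  0.  1.]
```

Entries whose magnitude falls below *λ*/*β*<sub>1</sub> are set to zero, which is what makes *S* sparse.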

(4) Update Multipliers

The Lagrangian multipliers are updated as follows:

$$\begin{cases} \Lambda\_1 = \Lambda\_1 + \beta\_1 (\mathbf{Y} - \mathbf{X}^{k+1} - \mathbf{S}^{k+1}),\\ \Lambda\_2 = \Lambda\_2 + \beta\_2 (\mathbf{X}^{k+1} - \mathbf{L}^{k+1}). \end{cases} \tag{17}$$

**Algorithm 1**

**Input**: The noisy image *Y*, parameters *λ* and *ρ*, and stopping tolerance *ε*.

**Initialization**: *t* = 0; set *X*, *L*, *S* and the Lagrangian multipliers $\Lambda\_1$, $\Lambda\_2$ to zero matrices; penalty parameters $\beta\_1 = 1.1$ and $\beta\_2 = 1.2$.

**Step 1**: Calculate *X* via (9).

**Step 2**: Calculate *L* via (14).

**Step 3**: Calculate *S* via (16).

**Step 4**: Update the multipliers $\Lambda\_1$, $\Lambda\_2$ via (17).

**Step 5**: Check the convergence criterion $\|\mathbf{X}^{t+1} - \mathbf{X}^{t}\|\_F / \|\mathbf{X}^{t}\|\_F \le \varepsilon$.

**Step 6**: If the criterion is not met, set *t* = *t* + 1 and go to **Step 1**.

**Output**: The restored HSI *X*.
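Putting the steps together, the sketch below mirrors Algorithm 1 on a single 2D matrix. It is a minimal illustration under explicit assumptions, not the authors' MATLAB implementation. The default values of *λ*, *ρ*, *ξ*, *ε*, and the iteration cap are hypothetical; only *β*<sub>1</sub> = 1.1 and *β*<sub>2</sub> = 1.2 come from Algorithm 1. It uses (11) with *p* = 1, and a 3×3 smoothing filter stands in for BM3D, so it does not reproduce the patch-wise BM3D processing of the HSI described in Section 4.3.

```python
import numpy as np

def soft(x, tau):
    """Elementwise soft thresholding (S-step, Eq. (16))."""
    return np.sign(x) * np.maximum(np.abs(x) - tau, 0.0)

def log_svt(G, alpha, xi):
    """Logarithmic singular value thresholding (Theorem 1, p = 1), vectorized."""
    U, s, Vt = np.linalg.svd(G, full_matrices=False)
    disc = (s - xi) ** 2 - 4.0 * (alpha - s * xi)
    root = np.maximum((s - xi + np.sqrt(np.maximum(disc, 0.0))) / 2.0, 0.0)
    root = np.where(disc > 0, root, 0.0)
    phi = lambda y: alpha * np.log(y + xi) + 0.5 * (y - s) ** 2
    t = np.where(phi(root) < phi(np.zeros_like(s)), root, 0.0)
    return (U * t) @ Vt

def smooth_denoiser(img, sigma):
    """Hypothetical stand-in for BM3D: 3x3 binomial smoothing (sigma unused)."""
    k = np.array([1.0, 2.0, 1.0]) / 4.0
    p = np.pad(img, 1, mode="edge")
    h, w = img.shape
    return sum(k[i] * k[j] * p[i:i + h, j:j + w] for i in range(3) for j in range(3))

def pnp_lnam(Y, denoiser, lam=0.05, rho=0.01, beta1=1.1, beta2=1.2,
             xi=1e-2, eps=1e-4, max_iter=100):
    """Sketch of Algorithm 1 on a 2D matrix Y; returns (X, iterations used)."""
    X, L, S = (np.zeros_like(Y) for _ in range(3))
    Lam1, Lam2 = np.zeros_like(Y), np.zeros_like(Y)
    for t in range(1, max_iter + 1):
        X_old = X
        # Step 1: X-subproblem (9), solved by Theorem 1
        A = Y - S + Lam1 / beta1
        B = L - Lam2 / beta2
        G = (beta1 * A + beta2 * B) / (beta1 + beta2)
        X = log_svt(G, 1.0 / (beta1 + beta2), xi)
        # Step 2: L-subproblem (14), denoiser at level sqrt(rho / beta2)
        L = denoiser(X + Lam2 / beta2, np.sqrt(rho / beta2))
        # Step 3: S-subproblem (16)
        S = soft(Y - X + Lam1 / beta1, lam / beta1)
        # Step 4: multiplier updates (17)
        Lam1 = Lam1 + beta1 * (Y - X - S)
        Lam2 = Lam2 + beta2 * (X - L)
        # Step 5: relative change of X (skipped at t = 1, where X_old is zero)
        if t > 1 and np.linalg.norm(X - X_old) <= eps * np.linalg.norm(X_old):
            break
    return X, t

rng = np.random.default_rng(0)
clean = rng.random((40, 3)) @ rng.random((3, 40))    # rank-3 test matrix
clean /= clean.max()                                  # scale to [0, 1]
noisy = clean + 0.05 * rng.standard_normal(clean.shape)
hit = rng.random(clean.shape) < 0.05                  # 5% S&P pixels
noisy[hit] = rng.choice([0.0, 1.0], size=int(hit.sum()))
X, iters = pnp_lnam(noisy, smooth_denoiser)
```

The stopping test is written as ‖*X*<sup>t+1</sup> − *X*<sup>t</sup>‖<sub>F</sub> ≤ *ε*‖*X*<sup>t</sup>‖<sub>F</sub>, which is equivalent to Step 5 but avoids dividing by zero.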

#### **4. Experimental Results**

Simulated and real HSI datasets are used to evaluate the performance of the proposed method. On these datasets, we compare it with four mixed noise removal algorithms: a modified BM3D method [8], low-rank matrix recovery (LRMR) [25], low-rank global total variation (LRGTV) [31], and weighted group sparsity-regularized low-rank tensor decomposition (LRTDGS) [35]. In all experiments, each band of the HSI data is normalized to [0, 1], and the parameters of the compared methods are set to the values suggested in the original articles. Because BM3D [8] targets Gaussian noise, we modify it as follows: the sparse noise is first detected and removed by adaptive median filtering, and BM3D then removes the Gaussian noise. We call this modified method A-BM3D.
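The text does not specify which adaptive median filter A-BM3D uses. As an assumption, the sketch below implements the classic adaptive median filter, which enlarges the window until its median is not itself an extreme value and replaces a pixel only if the pixel is an extreme value in that window. It is a plain-loop NumPy illustration for a single band; `adaptive_median` and `max_window` are hypothetical names.

```python
import numpy as np

def adaptive_median(img: np.ndarray, max_window: int = 7) -> np.ndarray:
    """Classic adaptive median filter for one band (assumed A-BM3D pre-filter)."""
    r_max = max_window // 2
    padded = np.pad(img, r_max, mode="edge")
    out = img.astype(float).copy()
    h, w = img.shape
    for i in range(h):
        for j in range(w):
            z = img[i, j]
            for r in range(1, r_max + 1):              # windows 3x3, 5x5, ...
                win = padded[i + r_max - r:i + r_max + r + 1,
                             j + r_max - r:j + r_max + r + 1]
                zmin, zmed, zmax = win.min(), np.median(win), win.max()
                if zmin < zmed < zmax:                 # median is not an impulse
                    out[i, j] = z if zmin < z < zmax else zmed
                    break
            else:                                      # largest window reached
                out[i, j] = zmed
    return out

img = np.full((9, 9), 0.5)
img[4, 4], img[2, 6] = 1.0, 0.0                        # one salt, one pepper pixel
restored = adaptive_median(img, max_window=5)
```

In this example both impulses are replaced by the local median 0.5, while all other pixels are left unchanged.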

All algorithms were run in MATLAB R2018 on a 64-bit Windows 10 operating system with a 2.6 GHz CPU and 16 GB of memory. The configuration of the experimental environment is summarized in Table 1.



#### *4.1. Simulated Data Experiment*

In this study, the ground truth of the Simu–Indian data [50] and the Pavia City Center data [51] is used to generate the synthetic data for our experiments. The sizes of the Simu–Indian and Pavia data are 145 × 145 × 224 and 200 × 200 × 80, respectively. We normalize each band of the HSI data to [0, 1] and treat the synthetic HSI data as clean data. The mean peak signal-to-noise ratio (MPSNR) and the mean structural similarity (MSSIM) over all bands are used to assess the different mixed noise removal algorithms. To generate noisy images, Gaussian and S&P noise are added to all bands of the clean HSI data in the following two cases:

Case 1: The noise intensity is the same in all bands. First, zero-mean Gaussian noise with standard deviation G = 0.025, 0.05, 0.075, or 0.10 is added to all bands. Second, S&P noise with proportion S&P = 0.05, 0.10, 0.15, or 0.20 is added to all bands.

Case 2: Unlike Case 1, the noise intensity differs across bands. Zero-mean Gaussian noise is added to each band, with the noise variance selected randomly from 0.02 to 0.10 for each band. Then S&P noise is added to each band, with a proportion selected randomly from [0.10, 0.20]. In addition, five selected bands of the Simu–Indian data and 10 selected bands of the Pavia City Center data are corrupted with 10 and 15 stripes, respectively.
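The noise settings of Cases 1 and 2 can be simulated band by band, as sketched below. This is our illustration, not the authors' code, and it rests on several assumptions: the Case 2 Gaussian level is treated as a standard deviation (the text calls it a variance), corrupted S&P pixels are split evenly between 0 and 1, values are not clipped, and stripes are omitted.

```python
import numpy as np

def add_mixed_noise(cube, gauss_std, sp_ratio, rng):
    """Add zero-mean Gaussian and S&P noise to each band of a (rows, cols, bands) cube.

    gauss_std[b] is used as the standard deviation for band b; sp_ratio[b] is the
    fraction of pixels in band b set to 0 or 1. Stripes are not simulated.
    """
    noisy = cube.astype(float).copy()
    rows, cols, bands = cube.shape
    for b in range(bands):
        band = noisy[:, :, b]                                   # view into noisy
        band += rng.normal(0.0, gauss_std[b], size=(rows, cols))
        hit = rng.random((rows, cols)) < sp_ratio[b]            # corrupted pixels
        salt = rng.random((rows, cols)) < 0.5                   # half salt, half pepper
        band[hit & salt] = 1.0
        band[hit & ~salt] = 0.0
    return noisy

rng = np.random.default_rng(0)
clean = rng.random((32, 32, 6))
n = clean.shape[2]
case1 = add_mixed_noise(clean, np.full(n, 0.10), np.full(n, 0.10), rng)       # G = S = 0.10
case2 = add_mixed_noise(clean, rng.uniform(0.02, 0.10, n),
                        rng.uniform(0.10, 0.20, n), rng)                      # per-band levels
```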

Tables 2 and 3 report the comparison results of the different denoising methods on the Simu–Indian and Pavia datasets for these two cases, evaluated by MPSNR and MSSIM. The two tables show that, overall, the proposed algorithm achieves satisfactory PSNR and SSIM values in most cases compared with the other methods, which confirms its advantages for mixed noise denoising. For the Simu–Indian data, the performance of the proposed algorithm is close to that of LRTDGS when the mixed noise intensity is low. For the Pavia data, LRGTV gives the best results, likely because it processes all patches together and uses spatial–spectral total variation regularization to recover the whole 3D HSI. The restoration quality of LRMR is relatively unsatisfactory when the Gaussian noise is strong. Although A-BM3D uses an adaptive filter to remove S&P noise, its denoising effect is not ideal when the S&P noise density is high. Surprisingly, Table 3 shows that LRTDGS performs poorly on the Pavia data.
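For reference, MPSNR is the per-band PSNR averaged over all bands. The sketch below assumes a (rows, cols, bands) cube scaled to [0, 1], so the peak value is 1; it is our illustration, not the evaluation code used in the paper. MSSIM would be computed analogously by averaging a per-band SSIM (for example from scikit-image's `structural_similarity`), which is not shown here.

```python
import numpy as np

def mpsnr(clean: np.ndarray, restored: np.ndarray, peak: float = 1.0) -> float:
    """Mean PSNR (dB) over the bands of (rows, cols, bands) cubes."""
    mse = np.mean((clean - restored) ** 2, axis=(0, 1))    # one MSE per band
    return float(np.mean(10.0 * np.log10(peak ** 2 / mse)))

clean = np.zeros((4, 4, 3))
restored = np.full((4, 4, 3), 0.1)       # constant error 0.1 -> MSE 0.01
print(mpsnr(clean, restored))            # about 20 dB
```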


**Table 2.** Quantitative evaluation of the different methods on the Simu–Indian dataset.

**Table 3.** Quantitative evaluation of the different methods on the Pavia dataset.


Figures 1 and 2 visualize the restoration results of the different methods on the Simu–Indian dataset. In Figure 1, the standard deviation of the zero-mean Gaussian noise is 0.10 and the S&P noise intensity is 0.10; in Figure 2, the Gaussian intensity is the same, but the S&P noise intensity is 0.20. The same subregion of each subfigure is marked with a red box and enlarged. Figures 1 and 2 show that all compared algorithms remove mixed noise to some extent. A-BM3D tends to blur the image. Although the two LRMR algorithms can remove noise and preserve spectral information, they cannot remove the Gaussian noise completely. By exploiting the whole 3D structure and spatial–spectral total variation regularization, LRGTV obtains satisfactory denoising results, but it fails to recover local details well. The proposed method performs comparably to LRTDGS, which we attribute to the use of the logarithm norm and the PnP prior to describe the global structure and nonlocal similarity of the HSI.

**Figure 1.** Restored results of band 35 on Simu–Indian. From top to bottom: the results under a subcase (the standard deviation of zero-mean Gaussian noise is G = 0.10, and the noise proportion of S&P noise is S = 0.10).

**Figure 2.** Restored results of band 57 on Simu–Indian. From top to bottom: the results under a subcase (the standard deviation of zero-mean Gaussian noise is G = 0.10, and the noise proportion of S&P noise is S = 0.20).

The visual results of the different denoising methods on the Pavia dataset are presented in Figures 3 and 4, with the same noise intensities as in Figures 1 and 2. Figures 3 and 4 show that the proposed method achieves satisfactory denoising performance. However, Figure 4 shows that LRGTV performs best, mainly because it exploits the global structure and spectral information in its low-rank constraint. Compared with LRGTV, the proposed method is more sensitive to S&P noise when the noise level is high. We will address this issue in future work.

**Figure 3.** Restored results of band 35 on Pavia. From top to bottom: the results under a subcase (the standard deviation of zero-mean Gaussian noise is G = 0.10, and the noise proportion of S&P noise is S = 0.10).

**Figure 4.** Restored results of band 57 on Pavia. From top to bottom: the results under a subcase (the standard deviation of zero-mean Gaussian noise is G = 0.10, and the noise proportion of S&P noise is S = 0.20).

Figures 5–8 show the PSNR and SSIM values of each band for the Simu–Indian and Pavia datasets. As shown in Figures 5 and 6, the proposed algorithm achieves satisfactory PSNR and SSIM values for almost all bands of the Simu–Indian dataset, indicating that it outperforms the compared algorithms in mixed noise removal on these data. As mentioned above, and as illustrated in Figures 7 and 8, LRGTV achieves the best PSNR and SSIM values for each band of the Pavia dataset, whereas the proposed method is relatively weak. The main reason for this result is not yet clear and will be investigated in future work.

**Figure 5.** PSNR and SSIM values of restored results by different methods on Simu–Indian data (G = 0.10, S = 0.10).

**Figure 6.** PSNR and SSIM values of restored results by different methods on Simu–Indian data (G = 0.10, S = 0.20).

**Figure 7.** PSNR and SSIM values of restored results by different methods on Pavia data (G = 0.10, S = 0.10).

**Figure 8.** PSNR and SSIM values of restored results by different methods on Pavia data (G = 0.10, S = 0.20).

#### *4.2. Real Experiments*

Due to space limitations, only the Hyperspectral Digital Imagery Collection Experiment (HYDICE) urban dataset, which can be downloaded online [52], is used and described in this experiment. The size of the urban image is 307 × 307 × 210. Figure 9 shows the real-world urban data.

**Figure 9.** Real-world urban data.

Figures 10 and 11 present bands 83 and 205 of the restored images. As shown in Figure 10, the result of A-BM3D is oversmoothed, which distorts local details. Most other methods, such as LRMR and LRGTV, can effectively remove noise from the urban image, and overall the proposed algorithm performs satisfactorily. However, in bands 199–210, the stripes are treated as part of the low-rank component, which is assumed to be clean data, in the low-rank decomposition. Although we use PnP-based regularization to exploit the spatial information of the real urban image, the proposed method cannot completely remove the stripes in Figure 11. We will investigate the reason for this problem in future work.

**Figure 10.** Restoration results on HYDICE urban image set: slight noise band.

**Figure 11.** Restoration results on HYDICE urban image set: moderate noise band.

Figure 12 shows the vertical mean profiles of band 205 before and after restoration. In this figure, the horizontal axis represents the column index, and the vertical axis represents the mean digital number (DN) value of each column. Because of the mixed noise, the curve of the noisy band fluctuates rapidly; after restoration, these fluctuations are suppressed to varying degrees. The proposed method appears to perform satisfactorily, consistent with the visual results in Figure 11. Overall, the observations in Figure 12 indicate that the proposed algorithm achieves satisfactory mixed noise removal and fine detail preservation. We attribute this to the logarithm norm-based rank function, which exploits global information, and to the PnP regularization module, which preserves image details. The logarithm norm rank function tends to eliminate small singular values, which helps reconstruct the global structure; because this elimination also loses some image details, the BM3D regularization module is used to restore them.
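A vertical mean profile, as commonly used in HSI destriping studies, averages each column of a single band, so stripes and residual noise appear as jumps between neighboring columns. The sketch below computes such a profile. `profile_roughness` is a hypothetical summary statistic added for illustration and is not used in the paper. Band indices are zero-based, so band 205 would correspond to index 204.

```python
import numpy as np

def vertical_mean_profile(cube: np.ndarray, band: int) -> np.ndarray:
    """Mean value of every column of one band of a (rows, cols, bands) cube."""
    return cube[:, :, band].mean(axis=0)

def profile_roughness(profile: np.ndarray) -> float:
    """Mean absolute jump between neighboring columns (illustrative only)."""
    return float(np.mean(np.abs(np.diff(profile))))

cube = np.zeros((4, 3, 2))
cube[:, 1, 1] = 1.0                       # a one-column stripe in band 1
prof = vertical_mean_profile(cube, 1)     # [0., 1., 0.]
```

A flat profile (roughness 0) indicates no column-wise artifacts; the stripe above produces a spike in the profile.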

**Figure 12.** The vertical mean profiles of band 205 on a real urban image.

#### *4.3. Performance Analysis*

Generally speaking, HSI mixed noise removal is a highly ill-posed problem. In this work, we introduce a PnP prior to regularize it so that feasible solutions can be obtained. Because the proposed model is nonconvex and is solved with auxiliary variables within an ADMM scheme, its convergence is a natural concern.

We therefore plot the quality index PSNR against the iteration number in Figure 13 to verify the stability of the proposed algorithm empirically. The curves are computed on the Simu–Indian and Pavia datasets, with Gaussian and S&P noise intensities of 0.10 and 0.20, respectively. As Figure 13 shows, the PSNR value stabilizes once the iteration number exceeds 60. These results provide empirical evidence that the proposed algorithm converges stably in practice.
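The PSNR trace in Figure 13 relies on the standard definition PSNR = 10 log10(peak² / MSE). The minimal sketch below works on flattened pixel lists and assumes data scaled to [0, 1], so `peak = 1.0`. The stopping check `has_stabilized` is a hypothetical illustration of "the PSNR value stabilizes". The paper does not state the exact criterion, and the window and tolerance values are assumptions.

```python
import math

def psnr(reference, estimate, peak=1.0):
    """PSNR in dB between two equally sized flat lists of pixel values."""
    mse = sum((r - e) ** 2 for r, e in zip(reference, estimate)) / len(reference)
    if mse == 0:
        return math.inf
    return 10.0 * math.log10(peak ** 2 / mse)

def has_stabilized(psnr_trace, window=5, tol=0.01):
    """Assumed rule: PSNR varied by less than `tol` dB over the last `window` iterations."""
    if len(psnr_trace) < window + 1:
        return False
    tail = psnr_trace[-(window + 1):]
    return max(tail) - min(tail) < tol

clean = [0.0, 0.0, 0.0, 0.0]
noisy = [0.1, 0.1, 0.1, 0.1]
print(round(psnr(clean, noisy), 6))  # 20.0
```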

**Figure 13.** PSNR values with respect to the iterations for different datasets.

Finally, Table 4 reports the computational time of the different methods. All methods were run in MATLAB R2018, with the Gaussian and S&P noise intensities again set to 0.10 and 0.20, respectively. As Table 4 shows, most of the denoising methods are computationally efficient, and A-BM3D has the shortest running time. The proposed algorithm, however, is relatively slow. The main cause is the PnP-based BM3D module used to restore the HSI, which is highly time-consuming: the whole HSI is divided into image patches, and each patch is restored by the BM3D module separately.


**Table 4.** Computational times of different methods (unit: s).

#### **5. Conclusions**

We propose an HSI mixed noise removal algorithm based on a logarithm-norm nonconvex approximation. Specifically, the logarithm-norm-based nonconvex low-rank term characterizes the global spatial–spectral correlation among all hyperspectral image bands, and PnP-based regularization is introduced to further exploit the local detailed information of the HSI. We then develop an ADMM optimization scheme to solve the proposed model. Simulated and real experiments, together with the accompanying discussion, demonstrate quantitatively and qualitatively that the proposed algorithm achieves satisfactory performance. The logarithm-norm-based low-rank term helps restore the global information of the target hyperspectral image, while the embedded BM3D denoiser helps preserve image details and remove structured noise. In future work, we will investigate novel mixed noise removal algorithms that apply other techniques, such as LSM modeling, deep convolutional neural networks, attention mechanisms, and transformer frameworks.

**Author Contributions:** J.L. conceived the idea, designed the experiments, and wrote the paper; J.W. and M.X. helped to analyze the experimental data; and Y.H. helped to review this manuscript. All authors have read and agreed to the published version of the manuscript.

**Funding:** This research was funded by the Science and Technology Research Program of Shangrao (No. 2021J005). This work was also supported by the Natural Science Foundation of Sichuan (No. 2022NSFSC0557).

**Data Availability Statement:** The ground truth of the Simu–Indian data used in this study can be downloaded from https://engineering.purdue.edu/~biehl/MultiSpec/hyperspectral.html (accessed on 12 March 2022) [50], and the Pavia City center data can be downloaded from http://www.ehu.eus/ccwintco/index.php?title=Hyperspectral\_Remote\_Sensing\_Scenes (accessed on 19 March 2022) [51]. The Hyperspectral Digital Imagery Collection Experiment (HYDICE) urban dataset can be downloaded online from [52].

**Conflicts of Interest:** The authors declare no conflict of interest.

#### **References**


### *Article* **A KGE Based Knowledge Enhancing Method for Aspect-Level Sentiment Classification**

**Haibo Yu <sup>1,2</sup>, Guojun Lu <sup>1</sup>, Qianhua Cai <sup>2,\*</sup> and Yun Xue <sup>2</sup>**


**\*** Correspondence: caiqianhua@m.scnu.edu.cn

**Abstract:** ALSC (Aspect-level Sentiment Classification) is a fine-grained task in NLP (Natural Language Processing) that aims to identify the sentiment toward a given aspect. In addition to exploiting sentence semantics and syntax, current ALSC methods introduce external knowledge as a supplement to the sentence information. However, integrating these three categories of information remains challenging. In this paper, a novel method is devised to effectively combine semantic and syntactic information with external knowledge. The proposed model contains a sentence encoder, a semantic learning module, a syntax learning module, a knowledge enhancement module, an information fusion module, and a sentiment classifier. Semantic and syntactic information are extracted via a self-attention network and a graph convolutional network, respectively. In particular, KGE (Knowledge Graph Embedding) is employed to enhance the feature representation of the aspect. An attention-based gating mechanism then fuses the three types of information. We evaluated the proposed model on three benchmark datasets, and the experimental results show that it achieves high accuracy.

**Keywords:** aspect-level sentiment classification; external knowledge; KGE; GCN

**MSC:** 18C50

### **1. Introduction**

Aspect-level sentiment classification (ALSC), a fine-grained sentiment analysis task, is a major focus in the field of natural language processing. In ALSC, the sentiment polarity toward a given aspect in a given text is classified as positive, neutral, or negative [1]. For example, in the sentence 'the ambience was nice, but service wasn't so great', the sentiments toward the two aspects 'ambience' and 'service' are positive and negative, respectively. In practice, ALSC has become an effective approach to identifying opinions and preferences toward products, stocks, and many other targets.

Currently, most ALSC methods proceed in the following steps:

1. sentence encoding;
2. syntax dependency tree construction;
3. syntactic information extraction via a graph convolutional network (GCN) [2];
4. semantic information extraction based on an attention mechanism;
5. information fusion;
6. sentiment classification.

Because attention networks are effective at distributing attention weights, a number of studies have demonstrated their strength in ALSC tasks [3–5]. However, when an aspect is far from its dependent words, more weight may be assigned to irrelevant words. To address this, relations between an aspect and its opinion words have been established by exploiting the syntax dependency tree of the sentence [6]. Figure 1 shows the syntax dependency tree of a given sentence. The words syntactically related to the aspects, such as 'nice' and 'great', clearly have a strong influence on sentiment polarity prediction. Despite the significance of syntactic structure, ALSC for informal writing styles (e.g., colloquial comments and slang) remains challenging. In these cases, the connection between an aspect and its opinion words can be unclear. The extracted syntax can then even become noise, which leads to misinterpretation of the sentiment.

**Citation:** Yu, H.; Lu, G.; Cai, Q.; Xue, Y. A KGE Based Knowledge Enhancing Method for Aspect-Level Sentiment Classification. *Mathematics* **2022**, *10*, 3908. https://doi.org/10.3390/math10203908

Academic Editors: Jianping Gou, Weihua Ou, Shaoning Zeng and Lan Du

Received: 21 September 2022 Accepted: 17 October 2022 Published: 21 October 2022

**Publisher's Note:** MDPI stays neutral with regard to jurisdictional claims in published maps and institutional affiliations.

**Copyright:** © 2022 by the authors. Licensee MDPI, Basel, Switzerland. This article is an open access article distributed under the terms and conditions of the Creative Commons Attribution (CC BY) license (https://creativecommons.org/licenses/by/4.0/).

**Figure 1.** Syntax dependency tree.

Encouragingly, recent publications have also employed external knowledge to enhance aspect information for ALSC [7]. External knowledge is generally exploited by retrieving information related to the given aspect: the aspect is taken as the central node of the knowledge graph, and subgraphs are built from its neighboring nodes. In this setting, the selection of neighboring nodes is critical, because the selection method largely limits how distinctive the retrieved knowledge can be. Furthermore, when the retrieved knowledge varies substantially, the selected nodes must be revised extensively. Moreover, when handling the knowledge graph, most previous methods used graph neural networks such as the GCN to process the knowledge graph nodes, which is inefficient.
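The aspect-centered subgraph construction described above can be sketched as a breadth-first search over knowledge triples. In this minimal illustration, the function name `aspect_subgraph`, the hop limit, the undirected traversal and the toy triples are all hypothetical; they do not reflect the schema of any real knowledge base. The sketch also shows why the result depends on the selection rule: changing `hops` changes which triples are retrieved.

```python
from collections import deque

def aspect_subgraph(triples, aspect, hops=1):
    """Collect (head, relation, tail) triples within `hops` steps of `aspect`.

    Edges are treated as undirected when searching for neighbors.
    """
    neighbors = {}
    for head, rel, tail in triples:
        neighbors.setdefault(head, []).append((head, rel, tail, tail))
        neighbors.setdefault(tail, []).append((head, rel, tail, head))
    seen = {aspect}
    frontier = deque([(aspect, 0)])
    selected = set()
    while frontier:
        node, depth = frontier.popleft()
        if depth == hops:
            continue  # do not expand beyond the hop limit
        for head, rel, tail, other in neighbors.get(node, []):
            selected.add((head, rel, tail))
            if other not in seen:
                seen.add(other)
                frontier.append((other, depth + 1))
    return selected

# Hypothetical toy triples (not real knowledge-base relations).
triples = [("rock", "is_a", "songs"),
           ("songs", "is_a", "music"),
           ("rock", "played_with", "guitar"),
           ("lychee", "is_a", "fruit")]
print(sorted(aspect_subgraph(triples, "songs", hops=1)))
```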

In consideration of the aforementioned issues, we propose a method that integrates sentence semantics and syntax with external knowledge about the aspect. To fully extract the sentence information, we build the semantic relation between the aspect and its context, and we likewise set up the connection between the aspect and its opinion words. For external knowledge, knowledge graph embedding (KGE) [8] is employed to obtain knowledge embeddings of the aspect, which makes handling the knowledge graph more efficient. In addition, a fusion module is devised to incorporate the relevant external information and the sentence information for sentiment classification. The contributions of this paper are threefold and are summarized as follows:


The rest of this paper is organized as follows. Section 2 reviews recent studies on ALSC methods and KGE applications. Section 3 presents the proposed model in detail. Section 4 reports experiments that evaluate the performance of our model. Finally, Section 5 gives concluding remarks.

#### **2. Related Work**

#### *2.1. Aspect-Level Sentiment Classification*

Early deep-learning-based ALSC methods generally concentrate on extracting contextual semantics by integrating an RNN (Recurrent Neural Network) with an attention mechanism [9]. When a sentence contains multiple aspects, however, semantic information alone becomes insufficient to determine sentiment polarity. Exploiting sentence syntax is a complementary approach to these semantic-based models. The relation between an aspect and its opinion words can be conveyed by a syntax dependency tree. Because dependency trees are graph-structured, graph neural networks [10] are employed to process the syntactic information. Among them, the graph convolutional network is the most prominent for processing graph-structured data in a variety of tasks. In ALSC, GCN-based models can not only aggregate and propagate information among neighboring nodes but also extract features and syntactic information from the graph.

Several studies build on this idea:

- Zhao [11] uses a GCN to model the sentiment dependencies between aspect words, thereby capturing the sentiment relationships among multiple aspects in a sentence.
- Zhang [12] represents the sentence with a syntax dependency tree and extracts syntactic information via a GCN.

To distinguish the importance of each node in the graph, attention mechanisms have also been integrated into GCN-based methods:

- Tian [13] aims to comprehensively understand the relation between an aspect and its opinion words. To do so, an attention mechanism assigns a weight to each word syntactically connected with the aspect word, so that the GCN can extract syntactic information precisely.
- Wang [14] constructs an aspect-centered syntax dependency tree and applies graph attention to weight each node and aggregate information from its neighboring nodes.

#### *2.2. Semantics and Syntax*

Since both semantics and syntax have their own advantages and disadvantages, several recent studies address ALSC by combining the two.

- Zhang et al. [15] propose an aspect-aware attention mechanism combined with self-attention to obtain the attention score matrices of a sentence, which learns both aspect-related semantic correlations and the global semantics of the sentence.
- Bie et al. [16] propose an end-to-end ABSA model that fuses syntactic structure information with lexical semantic information, addressing the limitation that existing end-to-end methods do not fully exploit textual information.
- Zhang et al. [17] also analyze sentences both syntactically and semantically and propose a simple and effective fusion mechanism that integrates aspect information and context information more adequately.

Other researchers also use GCNs to capture neighbor information [18–20]. However, these studies generally overlook that sentences may be ill-formed and that slang and informal writing are common in user-generated content. As a result, additional information is required in these situations.

#### *2.3. Knowledge Graph*

A knowledge graph contains a large number of entities and the relations between them. Knowledge graphs are applied in a variety of domains, such as education [21], medicine [22], and cybersecurity [23]. Recent work validates the significance of knowledge graphs in natural language processing [24], and their use is currently a major focus in NLP tasks. This also creates new opportunities for ALSC. Zhou [25] devised a GCN-based method that combines syntactic information and external knowledge. Liang [26] introduced knowledge from the SenticNet knowledge base, thereby enhancing the sentiment information of aspect words. However, these approaches generally ignore the inefficiency of GCN-based methods when dealing with the knowledge graph.

Knowledge graph embedding (KGE) is a practical method for incorporating knowledge graphs. KGE aims to represent complex and sparse entity relations with low-dimensional, continuous embeddings, which makes computation with the introduced knowledge easier. KGE is currently widely used in question answering [27], semantic retrieval [28], and recommendation systems [29]. Early KGE methods, such as TransE [30] and TransH [31], interpret a relation as a translation between the head and tail entities. Advances in deep neural networks have since improved the performance of KGE. State-of-the-art KGE methods such as ConvE [32] and CapsE [33] are built on convolutional and capsule neural networks, respectively. They extract triplet features through convolutional layers to compute the credibility of a triplet.
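The translation view used by TransE can be shown in a few lines. A triple (h, r, t) is plausible when h + r ≈ t, so the score is the negative distance ‖h + r − t‖, which TransE commonly measures with the L1 or L2 norm. The 2-d embeddings below are hypothetical toy values; real models learn much higher-dimensional vectors.

```python
import math

def transe_score(h, r, t, norm=2):
    """Negative distance -||h + r - t||; values closer to 0 mean a more plausible triple."""
    diff = [hi + ri - ti for hi, ri, ti in zip(h, r, t)]
    if norm == 1:
        return -sum(abs(d) for d in diff)
    return -math.sqrt(sum(d * d for d in diff))

head, relation = [1.0, 0.0], [0.0, 1.0]
good_tail, bad_tail = [1.0, 1.0], [0.0, 0.0]   # hypothetical embeddings
print(transe_score(head, relation, good_tail) > transe_score(head, relation, bad_tail))  # True
```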

#### **3. Methodology**

Figure 2 shows the architecture of the proposed model. It has six main components: the sentence encoder, the semantic learning module, the syntax learning module, the knowledge enhancement module, the information fusion module, and the sentiment classifier. Each component is detailed below.

**Figure 2.** Model architecture.

#### *3.1. Sentence Encoder*

Let $x = \{w^s_1, w^s_2, \ldots, w^t_m, \ldots, w^t_{m+l}, \ldots, w^s_n\}$ be an $n$-word sentence containing the aspect, where the superscript $t$ marks the aspect words $w^t_m, \ldots, w^t_{m+l}$. Each word is mapped to a low-dimensional vector by looking it up in a pretrained word embedding matrix, which yields the sentence embedding.

Then, the hidden states of the sentence are extracted with a Bidirectional Gated Recurrent Unit (Bi-GRU), which is effective at capturing long-term dependencies in a sentence; we therefore use a Bi-GRU to encode the sentence for further processing. The forward and backward hidden states of the sentence are $\overrightarrow{H}^{GRU} = \{\overrightarrow{h}^s_1, \overrightarrow{h}^s_2, \ldots, \overrightarrow{h}^t_m, \ldots, \overrightarrow{h}^t_{m+l}, \ldots, \overrightarrow{h}^s_n\}$ and $\overleftarrow{H}^{GRU} = \{\overleftarrow{h}^s_1, \overleftarrow{h}^s_2, \ldots, \overleftarrow{h}^t_m, \ldots, \overleftarrow{h}^t_{m+l}, \ldots, \overleftarrow{h}^s_n\}$, respectively. The sentence representation is the concatenation of $\overrightarrow{H}^{GRU}$ and $\overleftarrow{H}^{GRU}$, i.e.,

$$H^{GRU} = \left[\overrightarrow{H}^{GRU}, \overleftarrow{H}^{GRU}\right] \tag{1}$$
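The concatenation in Eq. (1) can be made concrete with a toy Bi-GRU written in plain Python. This is a minimal sketch rather than the authors' encoder, and it makes several simplifying assumptions: weights are random, biases are omitted, and the embedding size (4) and hidden size (3) are arbitrary. A practical implementation would use a deep learning framework's bidirectional GRU layer instead. The backward cell reads the reversed sentence, and its outputs are reversed again so that each token's forward and backward states line up before concatenation.

```python
import math
import random

def sigmoid(v):
    return 1.0 / (1.0 + math.exp(-v))

def matvec(W, v):
    return [sum(w * x for w, x in zip(row, v)) for row in W]

class GRUCell:
    """Minimal GRU cell (no biases, random weights) for illustration only."""

    def __init__(self, input_size, hidden_size, rng):
        def mat(rows, cols):
            return [[rng.uniform(-0.5, 0.5) for _ in range(cols)] for _ in range(rows)]
        self.hidden_size = hidden_size
        # Update gate (z), reset gate (r) and candidate state (n) weights.
        self.Wz, self.Uz = mat(hidden_size, input_size), mat(hidden_size, hidden_size)
        self.Wr, self.Ur = mat(hidden_size, input_size), mat(hidden_size, hidden_size)
        self.Wn, self.Un = mat(hidden_size, input_size), mat(hidden_size, hidden_size)

    def step(self, x, h):
        z = [sigmoid(a + b) for a, b in zip(matvec(self.Wz, x), matvec(self.Uz, h))]
        r = [sigmoid(a + b) for a, b in zip(matvec(self.Wr, x), matvec(self.Ur, h))]
        reset_h = [ri * hi for ri, hi in zip(r, h)]
        n = [math.tanh(a + b) for a, b in zip(matvec(self.Wn, x), matvec(self.Un, reset_h))]
        # New state interpolates between the previous state and the candidate.
        return [(1.0 - zi) * ni + zi * hi for zi, ni, hi in zip(z, n, h)]

    def run(self, inputs):
        h = [0.0] * self.hidden_size
        states = []
        for x in inputs:
            h = self.step(x, h)
            states.append(h)
        return states

def bi_gru(embeddings, forward_cell, backward_cell):
    """Eq. (1): concatenate forward and backward states token by token."""
    fwd = forward_cell.run(embeddings)
    bwd = backward_cell.run(embeddings[::-1])[::-1]  # re-align backward states with tokens
    return [f + b for f, b in zip(fwd, bwd)]

rng = random.Random(0)
sentence = [[rng.uniform(-1.0, 1.0) for _ in range(4)] for _ in range(5)]  # 5 words, 4-d embeddings
H_gru = bi_gru(sentence, GRUCell(4, 3, rng), GRUCell(4, 3, rng))
print(len(H_gru), len(H_gru[0]))  # 5 6
```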

#### *3.2. Semantic Learning Module*

The semantic learning module establishes the semantic relation between the aspect and its context. Given the input sentence representation, we use two attention mechanisms to capture this relation. A self-attention mechanism is first applied to obtain the contextual dependencies of the sentence. An aspect-specific attention mechanism is then applied to determine the relation between the aspect and its context. Concretely, the attention weights of the context words are computed as:

$$SelfAtt = \frac{\left(H^{GRU}W^k\right)\left(H^{GRU}W^q\right)^T}{\sqrt{d_k}}\tag{2}$$

where $W^k$ and $W^q$ are trainable parameter matrices and $d_k$ is the dimension of the input vector.

Based on the attention weights, the aspect-related hidden state is derived as:

$$H^{se} = Att\left(H^{SelfAtt}, H^{a}\right) \tag{3}$$

where $H^{SelfAtt}$ is the output of the self-attention network and $H^{a}$ is the hidden state of the aspect words output by the Bi-GRU. We take $H^{se}$ as the semantic representation for further processing.
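Eqs. (2) and (3) can be sketched with plain lists. Eq. (2) gives only the scaled score matrix. Three steps beyond it are assumptions made for illustration: turning the scores into weights with a row-wise softmax, applying those weights to $H^{GRU}$, and implementing the aspect attention $Att$ as dot-product similarity to the mean aspect vector. The identity projections and toy states are hypothetical values.

```python
import math

def matmul(A, B):
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]

def transpose(A):
    return [list(col) for col in zip(*A)]

def softmax(v):
    m = max(v)
    e = [math.exp(x - m) for x in v]
    s = sum(e)
    return [x / s for x in e]

def self_attention(H, Wk, Wq):
    """Scores (H Wk)(H Wq)^T / sqrt(d_k) as in Eq. (2); softmax + weighted sum of H are assumed."""
    K, Q = matmul(H, Wk), matmul(H, Wq)
    d_k = len(Wk[0])
    scores = [[s / math.sqrt(d_k) for s in row] for row in matmul(K, transpose(Q))]
    weights = [softmax(row) for row in scores]
    return matmul(weights, H), weights

def aspect_attention(H_ctx, H_aspect):
    """Assumed form of Att in Eq. (3): pool rows by similarity to the mean aspect vector."""
    a = [sum(col) / len(H_aspect) for col in zip(*H_aspect)]
    alpha = softmax([sum(x * y for x, y in zip(row, a)) for row in H_ctx])
    pooled = [sum(w * row[j] for w, row in zip(alpha, H_ctx)) for j in range(len(a))]
    return pooled, alpha

H = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]      # 3 tokens, 2-d states (toy values)
I2 = [[1.0, 0.0], [0.0, 1.0]]                 # identity projections for W^k and W^q
H_self, A = self_attention(H, I2, I2)
H_se, alpha = aspect_attention(H_self, [H[2]])  # token 2 plays the aspect
```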

#### *3.3. Syntax Learning Module*

Syntax can be seen as a supplement to semantics and has been shown to be helpful in sentiment classification, so syntactic information is necessary to fully extract sentence information. To obtain it, the syntax dependency tree of the given sentence is built in advance. In the syntax learning module, the tree is transformed into the graph $G^{sy} = (H^{GRU}, A^{sy})$ to facilitate processing, where $H^{GRU}$ is the feature matrix derived from the Bi-GRU and $A^{sy}$ is the adjacency matrix of the syntax dependency tree.

We employ GCN to extract the syntactic information of the sentence, which can be written as:

$$H^{sy(l+1)} = \text{GCN}\left(H^{sy(l)}, \tilde{A}^{sy}, W^{sy(l+1)}\right) \tag{4}$$

$$\text{GCN}\left(H^{l}, \tilde{A}, W^{(l+1)}\right) = \text{ReLU}\left(\tilde{A} H^{l} W^{(l+1)}\right) \tag{5}$$

with

$$
\tilde{A}^{sy} = \tilde{D}^{-\frac{1}{2}} \left( A^{sy} + I \right) \tilde{D}^{-\frac{1}{2}} \tag{6}
$$

where $H^{sy(l+1)}$ is the output of the $(l+1)$-th GCN layer, and the initial $H^{sy(0)}$ is the output of the Bi-GRU. $\tilde{A}^{sy}$ is the normalized adjacency matrix with self-loops, $I$ is the identity matrix, and $\tilde{D}$ is the degree matrix of $A^{sy} + I$. $W^{sy(l+1)}$ is the learnable parameter matrix of the $(l+1)$-th layer.

At each convolution layer, every node aggregates information from its neighboring nodes and updates its representation accordingly. The syntax representation $H^{sy}$ is the output of the last GCN layer.
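The following minimal sketch implements the normalized adjacency of Eq. (6) and stacks two GCN layers. It assumes each row of $H$ is one word node, so the normalized adjacency multiplies from the left: ReLU(Ã H W). The dependency edges, states and weights are hypothetical toy values, not taken from the paper.

```python
import math

def normalize_adjacency(A):
    """Eq. (6): D^-1/2 (A + I) D^-1/2, where D is the degree matrix of A + I."""
    n = len(A)
    a_hat = [[A[i][j] + (1.0 if i == j else 0.0) for j in range(n)] for i in range(n)]
    deg = [sum(row) for row in a_hat]
    return [[a_hat[i][j] / math.sqrt(deg[i] * deg[j]) for j in range(n)] for i in range(n)]

def matmul(X, Y):
    """Plain-list matrix product X @ Y."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*Y)] for row in X]

def gcn_layer(H, A_norm, W):
    """One GCN layer, ReLU(A_norm @ H @ W), with one row of H per word node."""
    return [[max(0.0, v) for v in row] for row in matmul(matmul(A_norm, H), W)]

# Hypothetical dependency edges for a 3-word sentence: 0-1 and 1-2.
A = [[0, 1, 0], [1, 0, 1], [0, 1, 0]]
H0 = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]   # toy Bi-GRU states, one row per word
W1 = [[0.5, -0.2], [0.1, 0.4]]              # toy weights, layer 1
W2 = [[1.0, 0.0], [0.0, 1.0]]               # toy weights, layer 2
A_norm = normalize_adjacency(A)
H_sy = gcn_layer(gcn_layer(H0, A_norm, W1), A_norm, W2)  # two stacked layers
```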

#### *3.4. Knowledge Enhancement Module*

To supplement the aspect features, external knowledge is leveraged to enhance the aspect information. Specifically, we use Freebase [34] as the external knowledge base; it contains a large number of entities together with various semantic relations among them.

When a word is hard to understand, one can look up known information about it to understand it better. In the same way, external knowledge can complement the information related to the aspect during learning.

Informal writing, such as spelling and grammatical errors and slang, is common in user-generated content. In such cases, external knowledge helps determine the sentiment polarity. For instance, the sentence 'check out these songs! Especially that amazing rock one' contains the aspect word 'songs'. Syntactically, no explicit opinion word is directly related to the aspect 'songs'. External knowledge can therefore be introduced to establish the relation between 'songs' and 'rock': 'rock' denotes a type of song and is thus a subordinate of 'songs'. Since the opinion word toward 'rock' is 'amazing', its sentiment polarity is positive, and the sentiment polarity of the aspect 'songs' can be inferred to follow that of 'rock'.

In the knowledge enhancement module, we introduce the knowledge graph and use KGE to handle the external knowledge from Freebase. Most state-of-the-art methods employ a GCN to encode external knowledge. However, many external knowledge bases contain heterogeneous graphs, which are difficult for a GCN to handle. In our model, the external knowledge is instead mapped into a continuous vector space using KGE, which is more efficient. The aspect is then enhanced by computing attention weights between the aspect words and the knowledge embeddings.

We select DistMult [35] as the KGE of the proposed model. Each entity in the knowledge base is represented as:

$$y_e = f(W x_e) \tag{7}$$

where $f$ is either a linear or a nonlinear function, $W$ is the parameter matrix, and $x_e$ is the vector representing an entity. The relation representation is typically obtained from the score function; DistMult uses the basic bilinear score function:

$$g_r^b(y_{e_1}, y_{e_2}) = y_{e_1}^T M_r y_{e_2} \tag{8}$$

where the relation matrix $M_r$ is diagonal, and $y_{e_1}$ and $y_{e_2}$ are the vector representations of entities $x_{e_1}$ and $x_{e_2}$, respectively. The aspect-based knowledge embedding $H^{kg}$ is obtained by computing the attention weights between the aspect and its knowledge embedding:

$$H^{kg} = Att\left(DistMult(x_e), H^{a}\right) \tag{9}$$
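Because $M_r$ in Eq. (8) is diagonal, the bilinear score reduces to a three-way element-wise product, $\sum_i y_{e_1,i}\, r_i\, y_{e_2,i}$, where $r$ holds the diagonal of $M_r$. The sketch below shows this reduction together with the entity projection of Eq. (7). The choice of `tanh` for $f$ is an assumption, and the vectors are toy values. One consequence of the diagonal form, visible in the usage lines, is that the score is symmetric in the two entities.

```python
import math

def project_entity(W, x, f=math.tanh):
    """Eq. (7): y_e = f(W x_e); tanh is an assumed choice of f."""
    return [f(sum(w * v for w, v in zip(row, x))) for row in W]

def distmult_score(y1, r_diag, y2):
    """Eq. (8) with diagonal M_r: y1^T diag(r) y2 = sum_i y1_i * r_i * y2_i."""
    return sum(a * b * c for a, b, c in zip(y1, r_diag, y2))

y1, r, y2 = [1.0, 2.0], [0.5, 1.0], [3.0, 1.0]   # toy vectors
print(distmult_score(y1, r, y2))                  # 3.5
print(distmult_score(y2, r, y1))                  # 3.5 (symmetric in the entities)
```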

#### *3.5. Information Fusion Module*

Having obtained syntactic information, semantic information, and external knowledge, we must combine these three kinds of information effectively, and the information fusion module is devised for this purpose. Syntax and semantics can both be considered sentence information, while the external knowledge is supplementary. During fusion, the contribution of each type of information must be kept within bounds to avoid introducing noise. We therefore compute the attention weights of the syntactic information toward the other two types of information. The attention between $H^{sy}$ and $H^{se}$ is expressed as:

$$Att\left(H^{sy}, H^{se}\right) = \sum_{i=1}^{N} \alpha_{(i)} \cdot H^{sy}_{(i)} \tag{10}$$

$$\alpha_{(i)} = \frac{\exp\left(H_{(i)}^{sy\,T} H_{(i)}^{se}\right)}{\sum_{j=1}^{N} \exp\left(H_{(j)}^{sy\,T} H_{(j)}^{se}\right)} \tag{11}$$

Likewise, the attention between $H^{sy}$ and $H^{kg}$ is:

$$Att\left(H^{sy}, H^{kg}\right) = \sum_{i=1}^{N} \alpha_{(i)} \cdot H^{kg}_{(i)}\tag{12}$$

Two gating units are then used to filter noise from the input information:

$$H_i^L = \tanh\left[Att\left(H_i^{sy}, H_i^{se}\right) \cdot W_s + b_s\right] \tag{13}$$

$$H_i^K = \text{ReLU}\left[Att\left(H_i^{sy}, H_i^{kg}\right) \cdot W_k + b_k\right] \tag{14}$$

where $W_k$, $W_s$, $b_k$, and $b_s$ are trainable parameters of the proposed model. The aspect-related sentence representation is computed using the cross product operation:

$$H = H^L \times H^K \tag{15}$$
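The fusion steps in Eqs. (10)–(15) can be sketched as follows, under explicit assumptions:

- $H^{sy}$, $H^{se}$ and $H^{kg}$ are all $N \times d$ matrices.
- Eq. (11) is read as a softmax over row-wise dot products.
- The weighted sum is taken over $H^{sy}$ for Eq. (10) and over $H^{kg}$ for Eq. (12), as written.
- The "×" in Eq. (15) is treated as an element-wise product. This is our reading, not something the text confirms.

All weights and inputs are hypothetical toy values.

```python
import math

def row_softmax_weights(A, B):
    """alpha_i = exp(A_i . B_i) / sum_j exp(A_j . B_j) (assumed reading of Eq. (11))."""
    scores = [sum(a * b for a, b in zip(ra, rb)) for ra, rb in zip(A, B)]
    m = max(scores)
    e = [math.exp(s - m) for s in scores]
    total = sum(e)
    return [x / total for x in e]

def pool(alpha, H):
    """Weighted sum of the rows of H."""
    return [sum(a * row[j] for a, row in zip(alpha, H)) for j in range(len(H[0]))]

def dense(W, x, b):
    return [sum(w * v for w, v in zip(row, x)) + bi for row, bi in zip(W, b)]

def fuse(H_sy, H_se, H_kg, W_s, b_s, W_k, b_k):
    att_se = pool(row_softmax_weights(H_sy, H_se), H_sy)   # Eq. (10)
    att_kg = pool(row_softmax_weights(H_sy, H_kg), H_kg)   # Eq. (12)
    h_l = [math.tanh(v) for v in dense(W_s, att_se, b_s)]  # Eq. (13), tanh gate
    h_k = [max(0.0, v) for v in dense(W_k, att_kg, b_k)]   # Eq. (14), ReLU gate
    return [a * b for a, b in zip(h_l, h_k)]               # Eq. (15), element-wise (assumed)

I2 = [[1.0, 0.0], [0.0, 1.0]]
zeros = [0.0, 0.0]
H = fuse(H_sy=[[1.0, 0.0], [0.0, 1.0]],
         H_se=[[1.0, 0.0], [0.0, 0.0]],
         H_kg=[[0.5, 0.5], [1.0, 0.0]],
         W_s=I2, b_s=zeros, W_k=I2, b_k=zeros)
```

The tanh gate bounds the sentence signal in (−1, 1), while the ReLU gate can switch the knowledge signal off entirely, so their product lets the knowledge branch suppress dimensions it considers noisy.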

#### *3.6. Sentiment Classifier*

The sentence representation $H$ is fed to the sentiment classifier for sentiment polarity classification. A fully connected layer produces a score for each sentiment polarity, and a softmax classifier determines the final sentiment probability distribution of the aspect:

$$
\tilde{H} = \text{ReLU}\left(W_1^T H + b_1\right) \tag{16}
$$

$$
\tilde{y} = \text{softmax}(\tilde{H})\tag{17}
$$

where $W_1$ and $b_1$ are trainable parameters, and $\tilde{y}$ is the predicted sentiment polarity distribution.

The proposed model is trained with the cross-entropy loss plus regularization, where the cross-entropy term is:

$$L = -\sum_{i} \sum_{j=1}^{N} y_i^j \log \tilde{y}_i^j \tag{18}$$

where $i$ indexes the samples and $j$ indexes the sentiment polarities, $N$ is the number of sentiment polarities, $y$ is the ground-truth sentiment distribution, and $\tilde{y}$ is the predicted one.
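Eqs. (16)–(18) can be checked numerically with a short sketch. It follows the equations literally: ReLU is applied to the fully connected output before the softmax. The regularization term mentioned above is omitted because its form is not given. The weights, the two-dimensional input and the class order (positive first) are hypothetical.

```python
import math

def softmax(logits):
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def classify(W1, b1, h):
    """Eqs. (16)-(17): y~ = softmax(ReLU(W1^T h + b1)); W1 has one row per input dimension."""
    logits = [sum(W1[d][c] * h[d] for d in range(len(h))) + b1[c] for c in range(len(b1))]
    hidden = [max(0.0, z) for z in logits]  # Eq. (16)
    return softmax(hidden)                  # Eq. (17)

def cross_entropy(batch_true, batch_pred, eps=1e-12):
    """Eq. (18) summed over a batch; eps guards against log(0)."""
    return -sum(t * math.log(p + eps)
                for y, y_hat in zip(batch_true, batch_pred)
                for t, p in zip(y, y_hat))

probs = classify(W1=[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], b1=[0.0, 0.0, 0.0], h=[1.0, 2.0])
loss = cross_entropy([[1, 0, 0]], [probs])  # about 1.0986 (= ln 3) for a uniform prediction
```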

#### **4. Experiment**

#### *4.1. Dataset*

Three publicly available benchmark datasets are used for performance evaluation: Laptop14 and Restaurant14 from SemEval2014 [36], and Twitter [37]. Every sample is labeled with one of three polarities (positive, neutral, or negative), and each sample is a review sentence containing a tagged aspect. Details of each dataset are given in Table 1.

**Table 1.** Statistics of datasets.


#### *4.2. Implementation Details*

Sentence embeddings are initialized with either GloVe [38] or BERT [39], giving a GloVe-based and a BERT-based variant of the model. The batch sizes for Restaurant14, Laptop14, and Twitter are 32, 64, and 32, respectively. The learning rates of the GloVe-based and BERT-based models are set to 1e-3 and 2e-5, respectively. The Adam optimizer is used during training.
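The reported training settings can be collected in a small configuration object, sketched below. The field and helper names are hypothetical. Settings the text does not report, such as epochs, dropout and hidden sizes, are deliberately left out rather than guessed.

```python
from dataclasses import dataclass

# Values as reported in Section 4.2.
BATCH_SIZES = {"Restaurant14": 32, "Laptop14": 64, "Twitter": 32}
LEARNING_RATES = {"glove": 1e-3, "bert": 2e-5}

@dataclass(frozen=True)
class TrainConfig:
    embedding: str          # "glove" or "bert"
    dataset: str
    batch_size: int
    learning_rate: float
    optimizer: str = "adam"

def make_config(embedding, dataset):
    """Build the (hypothetical) config for one embedding/dataset combination."""
    return TrainConfig(embedding, dataset, BATCH_SIZES[dataset], LEARNING_RATES[embedding])

print(make_config("bert", "Laptop14"))
```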

#### *4.3. Baseline Methods*

To evaluate the proposed model, we compare it with seven state-of-the-art methods.

Syntax- and semantic-based methods:


KG-based methods:


#### *4.4. Experiment Results*

Table 2 shows the experimental results on all datasets. The proposed model outperforms the state-of-the-art methods on Restaurant14 and Twitter by a considerable margin. The minimum accuracy gaps of the GloVe-based and BERT-based models are 3.57% (versus SK-GCN) and 3.15% (versus RGAT+BERT), respectively. We attribute this mainly to the external knowledge from Freebase, which provides a large amount of semantic information and relations; enhancing the aspect with this external information improves sentiment classification.

On Laptop14, the Sentic-GCN model performs slightly better than the proposed method. One possible explanation is that syntactic structure plays a more important role in determining sentiment in Laptop14 sentences. Sentic-GCN uses SenticNet [43] to add information to the adjacency matrices, so that the syntactic information can be extracted via graph convolution. Moreover, BERT pre-training further improves the ALSC results. The proposed model integrates sentence semantics, sentence syntax, and external knowledge, and these sources complement one another. It can therefore be expected to yield better sentiment classification results.


**Table 2.** Experimental results.

#### *4.5. Impact of GCN Layer Number*

The GCN is a key component of the syntax learning module for encoding syntactic information. We therefore explore the optimal number of GCN layers for ALSC, setting it to 1, 2, 3, 4, and 5. According to Table 3, two GCN layers obtain the best result in all evaluation settings. The number of GCN layers determines how much contextual information is aggregated toward the aspect. A one-layer GCN clearly fails to capture sufficient syntactic information from the sentence. From 3 to 5 layers, the performance of our model declines as the number of layers increases, for two main reasons. First, the number of connected context words grows with the number of layers, which introduces syntactic noise. Second, after multi-layer graph convolution, node representation vectors tend to become similar and less distinguishable; this is the over-smoothing problem of multi-layer GCNs.
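The over-smoothing effect can be illustrated numerically. The sketch below makes a simplifying assumption: weights and ReLU are dropped, so only the neighborhood averaging of Eq. (6) remains. It repeatedly applies the normalized adjacency to a toy path graph and tracks how far apart the node features stay. Features are divided by the square root of each node's degree, which removes the degree scaling of the fixed point. The gap then shrinks toward zero as layers are stacked, meaning the nodes become indistinguishable. The graph and features are hypothetical.

```python
import math

def normalize_adjacency(A):
    """Symmetric normalization D^-1/2 (A + I) D^-1/2, as in Eq. (6); also returns degrees."""
    n = len(A)
    a_hat = [[A[i][j] + (1.0 if i == j else 0.0) for j in range(n)] for i in range(n)]
    deg = [sum(row) for row in a_hat]
    norm = [[a_hat[i][j] / math.sqrt(deg[i] * deg[j]) for j in range(n)] for i in range(n)]
    return norm, deg

def smoothing_spread(A, x, layers):
    """Spread (max - min) of degree-rescaled node features after 0..layers propagations."""
    A_norm, deg = normalize_adjacency(A)
    spreads = []
    for _ in range(layers + 1):
        scaled = [v / math.sqrt(d) for v, d in zip(x, deg)]
        spreads.append(max(scaled) - min(scaled))
        x = [sum(a * v for a, v in zip(row, x)) for row in A_norm]  # one propagation step
    return spreads

# Path graph 0-1-2 with a one-hot feature on node 0.
path = [[0, 1, 0], [1, 0, 1], [0, 1, 0]]
spreads = smoothing_spread(path, [1.0, 0.0, 0.0], layers=10)
print([round(s, 4) for s in spreads[:3]])
```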


**Table 3.** ALSC accuracy in line with GCN layer numbers.

#### *4.6. Impact of KGE*

We employ four different KGE methods and investigate their effectiveness for external knowledge enhancement. Table 4 reports the ALSC results of the GloVe-based model with each KGE method.

**Table 4.** ALSC results of different KGE methods.


TransE, TransR, and TransH achieve lower accuracy than DistMult. We attribute this to the fact that these three translation-based models determine word relations from the head and tail entities rather than from semantic information. By contrast, DistMult uses a bilinear formulation that can compute the semantic credibility of entities and relations in the vector space. In other words, the semantic information captured by DistMult is carried into the incorporated external knowledge, which yields better sentiment classification accuracy.

#### *4.7. Run Time and Parametric Amount*

To further evaluate the efficiency of the proposed model, Table 5 compares the training and testing run times and the number of parameters of different methods. Both SK-GCN and our model use a knowledge graph. Our model requires less run time and fewer parameters, which shows its advantage over the GCN-based method in handling knowledge graphs. In addition, the run time of BiGCN is comparable to that of the proposed model, yet our model achieves considerably higher test accuracy than RGAT and BiGCN, indicating higher overall efficiency.


**Table 5.** Results of run time and the parameter amount of different methods.

#### *4.8. Case Study*

Figure 3 visualizes the attention weight distribution for a given sentence; darker words receive greater weight. The first visualization is produced by combining the semantic learning module and the syntax learning module, while the second also incorporates external knowledge. According to Figure 3, when only sentence-related information is used, more attention is given to words close to the aspect. For example, the opinion word 'love' for the aspect 'drinks' receives a high attention weight, as does 'great' for 'food'. However, for the aspect 'lychee martini', the semantic and syntax learning modules identify few syntactically or semantically related words. Introducing external knowledge helps determine the sentiment words for 'lychee martini', which contributes to the sentiment classification.

#### **5. Conclusions**

In this work, we propose a model that integrates semantics, syntax, and external knowledge for ALSC. To sufficiently incorporate external information into the aspect words, we employ KGE and an aspect-specific attention mechanism to enhance the aspect features. A semantic learning module and a syntax learning module are devised to extract sentence information, and an information fusion module integrates the three types of information for sentiment classification. Experiments on three benchmark datasets show that our model outperforms the baselines on Restaurant14 and Twitter and remains competitive on Laptop14.

Future work will focus on the details of knowledge graph processing, since the loss of graph structural information remains an open question.

**Author Contributions:** Conceptualization, H.Y. and Y.X.; methodology, H.Y.; formal analysis, H.Y. and G.L.; writing—original draft preparation, H.Y.; writing—review and editing, Q.C.; supervision, Y.X. and Q.C.; funding acquisition, Q.C. All authors have read and agreed to the published version of the manuscript.

**Funding:** This work was funded by the Characteristic Innovation Projects of Guangdong Colleges and Universities (No. 2018KTSCX049) and the Science and Technology Plan Project of Guangzhou under Grant Nos. 202102080258 and 201903010013.

**Institutional Review Board Statement:** Not applicable.

**Informed Consent Statement:** Not applicable.

**Data Availability Statement:** Not applicable.

**Conflicts of Interest:** The authors declare no conflict of interest.

#### **References**


### *Article* **Pairwise Constraints Multidimensional Scaling for Discriminative Feature Learning**

**Linghao Zhang <sup>1</sup>, Bo Pang <sup>1</sup>, Haitao Tang <sup>2,3</sup>, Hongjun Wang <sup>2,3,\*</sup>, Chongshou Li <sup>2,3</sup> and Zhipeng Luo <sup>2,3</sup>**


**\*** Correspondence: wanghongjun@swjtu.edu.cn

**Abstract:** As an important data analysis method in machine learning and data mining, feature learning has a wide range of applications across industries. Traditional multidimensional scaling (MDS) preserves the topology of data points in the low-dimensional embeddings obtained during feature learning, but ignores the discriminability between classes of the embedded data. This paper therefore proposes a discriminative multidimensional scaling model based on pairwise constraints for feature learning (pcDMDS). The model enhances discriminability in two ways: first, it increases the compactness of the new data representation within each cluster through fuzzy *k*-means; second, it extends the pairwise constraint information between samples. Throughout feature learning, the model considers both the topology of samples in the original space and the cluster structure in the new space, and it incorporates the extended pairwise constraint information, which further improves its ability to obtain discriminative features. Experimental results on twelve datasets show that pcDMDS outperforms the PMDS model by 10.31% in accuracy and 8.31% in purity.

**Keywords:** discriminative feature learning; multidimensional scaling; fuzzy *k*-means; pairwise constraint propagation; iterative majorization algorithm

**MSC:** 62P25

#### **1. Introduction**

The high dimensionality of large volumes of image, text, and video data is inevitable in today's big data era. Although image and text data are simple and intuitive for humans, machine learning models suffer from the curse of dimensionality: using raw data directly not only increases the processing time of downstream models, but may also degrade classification and clustering performance because of redundancy and noise in the data. Obtaining more discriminative features from raw data has therefore become a research objective for many scholars.

Feature learning methods are classified as supervised, semi-supervised, or unsupervised according to whether, and to what extent, annotation information is used during learning. Classical methods for unsupervised, semi-supervised, and supervised feature learning are principal component analysis (PCA) [1], semi-supervised dimensionality reduction (SSDR) [2], and linear discriminant analysis (LDA) [3], respectively. PCA, SSDR, and LDA are all linear feature learning methods; they are fast to compute and, when a new sample arrives, can quickly compute its data representation through the projection matrix.

**Citation:** Zhang, L.; Pang, B.; Tang, H.; Wang, H.; Li, C.; Luo, Z. Pairwise Constraints Multidimensional Scaling for Discriminative Feature Learning. *Mathematics* **2022**, *10*, 4059. https://doi.org/10.3390/math10214059

Academic Editor: Jianping Gou

Received: 17 September 2022 Accepted: 24 October 2022 Published: 1 November 2022

**Publisher's Note:** MDPI stays neutral with regard to jurisdictional claims in published maps and institutional affiliations.

**Copyright:** © 2022 by the authors. Licensee MDPI, Basel, Switzerland. This article is an open access article distributed under the terms and conditions of the Creative Commons Attribution (CC BY) license (https://creativecommons.org/licenses/by/4.0/).

<sup>3</sup> Manufacturing Industry Chains Collaboration and Information Support Technology Key Laboratory of Sichuan Province, Chengdu 611731, China

In contrast, nonlinear feature learning based on manifold learning, such as locally linear embedding (LLE) [4], multidimensional scaling (MDS) [5], and Laplacian eigenmaps (LE) [6], lets the low-dimensional data representation preserve the local topology of the original data as much as possible. Nonlinear feature learning can discover the latent manifold structure of the data well, but it faces the new-sample (out-of-sample) problem [7]. Several algorithms therefore preserve the local topology as much as possible while learning an explicit projection; for example, locality preserving projection (LPP) [8] and neighborhood preserving embedding [9] add projection matrices to LE and LLE, respectively.

Feature learning is an important research topic because of its many applications, such as data visualization [10], information retrieval [11], and clustering [12]. MDS, a commonly used manifold learning method, considers the distance information between samples during feature learning but ignores the discriminability between data categories. To obtain more discriminative features, this paper proposes a discriminative multidimensional scaling model based on pairwise constraints for feature learning (pcDMDS).

The main contributions of this paper are shown below.


The remainder of this paper is organized as follows. Section 2 reviews related work. Section 3 introduces preliminaries. Section 4 describes the proposed model in detail, including its objective function and inference. Section 5 presents the experiments and results, and Section 6 concludes the paper.

#### **2. Related Work**

As a feature learning method that preserves the dissimilarity (generally the distance) between samples, MDS is widely used because of its simplicity and efficiency. MDS-based feature learning can be divided into two categories [13]: metric multidimensional scaling (MMDS) and non-metric multidimensional scaling (NMDS). In MMDS, the learned low-dimensional representation preserves the distances of the original data as closely as possible, whereas in NMDS it preserves only the ordering of those distances. Since the model proposed in this paper is a metric MDS, metric MDS is described in detail below, and MDS hereafter refers to MMDS.

Different MDS methods have been proposed successively. Classical MDS takes the distances between samples as given and finds a suitable low-dimensional embedding in which the distances between embedded points match the corresponding original distances as closely as possible. It is a nonlinear feature learning method and therefore faces the new-sample problem. Webb [14] introduced a set of basis functions for feature mapping and then achieved dimensionality reduction through a projection matrix; the new data representation preserves the Euclidean distances of the original samples as much as possible, and an iterative update method was proposed to optimize the projection matrix. Isometric feature mapping [15], an important manifold learning method, uses the geodesic distance between samples to represent their dissimilarity and then applies classical MDS to obtain a low-dimensional embedding. Bronstein et al. [16] proposed generalized multidimensional scaling (GMDS), which uses a non-Euclidean distance to represent the dissimilarity of samples, and applied it to 3D face matching. To enhance the discriminability of the features learned by MDS, Biswas [17] required not only that the low-dimensional embedding points preserve the distances between the original images, but also that the embedding points corresponding to the same face be as close as possible.

Clustering, an unsupervised machine learning method, is widely used in many fields [18–20]; its purpose is to divide data into different clusters or subsets according to some criterion. To efficiently discover latent cluster structures in data, various clustering algorithms have been proposed, such as the *k*-means (KM) algorithm [21], the affinity propagation (AP) algorithm [22], and the density peak (DP) algorithm [23]. With the development of fuzzy set theory, fuzzy clustering was proposed [24]. Unlike hard clustering algorithms such as *k*-means, soft clustering algorithms such as fuzzy clustering not only discover the cluster structure of the data efficiently, but also give the degree of affiliation between samples and clusters, which allows overlapping cluster structures to be discovered.

Fuzzy *k*-means (FKM) was proposed by Bezdek et al. [24] based on the idea of fuzzy sets: each sample belongs to each cluster with a degree of affiliation ranging from 0 to 1. To improve the clustering performance of fuzzy *k*-means, Wang et al. [25] proposed a fuzzy *k*-means model based on a feature-weighted Euclidean distance. Traditional FKM fails when the feature vectors of the samples are unknown and only their pairwise dissimilarities are available; Hathaway et al. [26] therefore proposed non-Euclidean relational fuzzy clustering, which performs fuzzy clustering given only the dissimilarities between samples. To adapt clustering to noisy data, Nie et al. [27] combined fuzzy *k*-means with principal component analysis, so that fuzzy *k*-means is performed in the low-dimensional subspace obtained by principal component analysis. To discover the latent cluster structure of multi-view data, Zhu et al. [28] proposed an adaptive weighted multi-view clustering method, which automatically learns the importance, dispersion, and other information of each view and combines the information shared across views to accomplish the clustering task.

Pairwise constraint information is widely used in feature learning to enhance the discriminability of the learned features, because it describes similarity relationships between samples. Zhang et al. [2] proposed a semi-supervised dimensionality reduction method based on pairwise constraints, which obtains new sample points through a transformation matrix such that points with must-connect constraints are close together after the transformation, while points with do-not-connect constraints are far apart. Du et al. [29] applied constraint propagation to dimensionality reduction and proposed a new semi-supervised feature learning method. The method starts from a pairwise constraint matrix whose entries take only the values 1, 0, and −1, where 1 denotes a must-connect constraint, −1 a do-not-connect constraint, and 0 an unknown constraint. A constraint propagation algorithm then extends the constraint information to other samples, the extended constraint matrix is used to construct a new weight matrix, and finally the LPP algorithm is applied to obtain the new data representation.

#### **3. Preliminaries**

#### *3.1. Multidimensional Scaling*

Classical MDS is a nonlinear feature learning method. Given only the dissimilarity between every pair of points, it directly obtains a new data representation whose pairwise Euclidean distances match the given dissimilarities as closely as possible; however, it faces the new-sample problem. Webb [14] proposed projective MDS (PMDS), in which the new data representation is obtained from the original data by a projection transformation. The feature learning method proposed in this paper is based on PMDS, whose principles are described in detail below.
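To make the idea of classical MDS concrete, the following minimal sketch implements Torgerson's classical scaling (double centring of the squared distances followed by an eigendecomposition), which is one standard way of computing classical MDS. The function name `classical_mds` is ours, and samples are stored as columns to match the notation used in this section. The output is a fixed embedding of the given points with no mapping for unseen data, which is exactly the new-sample problem mentioned above.

```python
import numpy as np

def classical_mds(Dist, l):
    """Torgerson's classical MDS: embed N points in l dimensions from an
    N x N Euclidean distance matrix. Returns an l x N matrix (columns = points)."""
    N = Dist.shape[0]
    J = np.eye(N) - np.ones((N, N)) / N               # centring matrix
    B = -0.5 * J @ (Dist ** 2) @ J                    # double-centred Gram matrix
    evals, evecs = np.linalg.eigh(B)                  # eigenvalues in ascending order
    top = np.argsort(evals)[::-1][:l]                 # indices of the l largest
    scale = np.sqrt(np.clip(evals[top], 0.0, None))   # clip tiny negative values
    return (evecs[:, top] * scale).T

# Usage: the corners of a 3 x 4 rectangle are recovered in 2-D
# (up to rotation and reflection), so all pairwise distances are preserved.
pts = np.array([[0.0, 3.0, 0.0, 3.0],
                [0.0, 0.0, 4.0, 4.0]])
Dist = np.linalg.norm(pts[:, :, None] - pts[:, None, :], axis=0)
Y = classical_mds(Dist, 2)
Dist_hat = np.linalg.norm(Y[:, :, None] - Y[:, None, :], axis=0)
print(np.allclose(Dist, Dist_hat))  # True
```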

Let $X = [x\_1, \dots, x\_N] \in \mathbb{R}^{n \times N}$ be the original data matrix, where $n$ and $N$ denote the dimensionality and the number of the original samples, respectively, and let $Y = [y\_1, \dots, y\_N] \in \mathbb{R}^{l \times N}$ be the learned low-dimensional data representation, where $l$ denotes its dimensionality. The loss function of MDS, a feature learning method that preserves sample distances, is [30]:

$$\mathcal{O}\_{mds}(Y) = 1/2 \sum\_{i=1}^{N} \sum\_{j=1}^{N} s\_{ij} \left( d\_{ij} - \hat{d}\_{ij} \right)^2. \tag{1}$$

Here $d\_{ij}$ denotes the distance between the original data points $x\_i$ and $x\_j$, and $\hat{d}\_{ij}$ denotes the distance between the corresponding low-dimensional representations $y\_i$ and $y\_j$. $S = [s\_{ij}] \in \mathbb{R}^{N \times N}$ is a non-negative symmetric weight matrix; a larger $s\_{ij}$ expresses a stronger requirement that $\hat{d}\_{ij}$ be close to $d\_{ij}$. Two ways of constructing the weights are given in [6].


The MDS in Equation (1) is a nonlinear feature learning method that directly obtains the low-dimensional representation $Y$. When new data arrive, their low-dimensional representation cannot be obtained directly; this is the so-called new-sample problem. Webb incorporated a projection matrix into MDS, using pre-specified radial basis functions to achieve nonlinear transformations, and proposed PMDS, whose objective is [14]:

$$\begin{split} \text{O}\_{pmds}(\mathcal{W}) &= \sum\_{i=1}^{N} \sum\_{j=1}^{N} s\_{ij} \left( d\_{ij} - \hat{d}\_{ij} \right)^{2} \\ &= \sum\_{i=1}^{N} \sum\_{j=1}^{N} s\_{ij} \left( d\_{ij} - \left\| \mathcal{W}^{\mathrm{T}} (\mathbf{x}\_{i} - \mathbf{x}\_{j}) \right\|\_{2} \right)^{2} . \end{split} \tag{2}$$

Here $\|\cdot\|\_2$ denotes the $\ell\_2$ norm of a vector and $W \in \mathbb{R}^{n \times l}$ denotes the projection matrix, so that $Y = W^{T}X$ is obtained directly by projecting $X$.
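A minimal sketch of evaluating the PMDS objective in Equation (2) for a given projection W, assuming samples are stored as columns of X and the weights S are supplied by the caller (the function name `pmds_stress` is hypothetical):

```python
import numpy as np

def pmds_stress(W, X, Dist, S):
    """Weighted stress of Eq. (2): sum_ij s_ij (d_ij - ||W^T (x_i - x_j)||_2)^2.
    W: n x l projection matrix; X: n x N data with samples as columns;
    Dist, S: N x N original distances and non-negative symmetric weights."""
    Y = W.T @ X                                              # l x N projected data
    Dist_hat = np.linalg.norm(Y[:, :, None] - Y[:, None, :], axis=0)
    return float(np.sum(S * (Dist - Dist_hat) ** 2))

rng = np.random.default_rng(0)
X = rng.standard_normal((3, 6))
Dist = np.linalg.norm(X[:, :, None] - X[:, None, :], axis=0)
S = np.ones((6, 6))
full = pmds_stress(np.eye(3), X, Dist, S)          # identity keeps every distance: ~0
reduced = pmds_stress(np.eye(3)[:, :2], X, Dist, S)  # dropping a coordinate: > 0
```

Because Y = WᵀX, a new sample x is mapped simply as Wᵀx, which is how PMDS sidesteps the new-sample problem of Equation (1).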

#### *3.2. Fuzzy k-Means Clustering*

Fuzzy clustering can give the degree of affiliation of samples with clusters, and the objective formula for fuzzy *k*-means is:

$$\begin{aligned} \mathcal{O}\_{fkm}(U, V) &= \sum\_{k=1}^{C} \sum\_{i=1}^{N} u\_{ik}^{m} \|\mathbf{x}\_{i} - v\_{k}\|\_{2}^{2}, \\ \text{s.t. } & \sum\_{k=1}^{C} u\_{ik} = 1,\ \forall i = 1, 2, \dots, N, \quad u\_{ik} \ge 0,\ \forall i = 1, \dots, N,\ \forall k = 1, \dots, C. \end{aligned} \tag{3}$$

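Here $U = [u\_{ik}] \in \mathbb{R}^{N \times C}$ is the affiliation matrix, $u\_{ik}$ denotes the affiliation of $x\_i$ with cluster $C\_k$, $v\_k$ is the center of cluster $C\_k$ (the centers form the matrix $V$), $C$ is the number of clusters, and $m > 1$ is the fuzzy index.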
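The objective in Equation (3) can be evaluated directly. The sketch below assumes X stores samples as columns and V stores the cluster centers as columns; the helper name `fkm_objective` is ours.

```python
import numpy as np

def fkm_objective(X, U, V, m=2.0):
    """Eq. (3): sum_k sum_i u_ik^m ||x_i - v_k||_2^2.
    X: n x N samples as columns; U: N x C affiliations (rows sum to 1);
    V: n x C cluster centres as columns; m > 1 is the fuzzy index."""
    sq = np.sum((X[:, :, None] - V[:, None, :]) ** 2, axis=0)  # N x C squared distances
    return float(np.sum(U ** m * sq))

X = np.array([[0.0, 2.0]])   # two 1-D samples
V = X.copy()                 # each sample is exactly a centre
print(fkm_objective(X, np.eye(2), V))            # 0.0: hard, correct assignment
print(fkm_objective(X, np.full((2, 2), 0.5), V))  # 2.0: maximally fuzzy assignment
```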

#### *3.3. Pairwise Constraint Propagation*

Given samples $X = [x\_1, \dots, x\_N] \in \mathbb{R}^{n \times N}$ and a pairwise constraint matrix $P = [p\_{ij}] \in \mathbb{R}^{N \times N}$, $p\_{ij} = 1$ if there is a must-connect constraint between samples $x\_i$ and $x\_j$, $p\_{ij} = -1$ if there is a do-not-connect constraint between them, and $p\_{ij} = 0$ if the constraint between them is unknown.

The constraint propagation algorithm extends the constraint matrix $P$ to obtain more pairwise constraint information. The resulting matrix is $F = [f\_{ij}] \in \mathbb{R}^{N \times N}$, whose entries lie in $[-1, 1]$; the sign of $f\_{ij}$ gives the type of constraint and its absolute value gives the confidence (see Section 4.1).
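Section 4.3.2 names E²CP as the propagation algorithm, but its details are not given here. The sketch below is a hedged, minimal illustration assuming the closed form F = (1 − α)²(I − αL̄)⁻¹P(I − αL̄)⁻¹ that we understand E²CP to use, where L̄ = D^(−1/2) G D^(−1/2) is the normalized affinity matrix of a symmetric k-nearest-neighbour graph G. The graph construction, the Gaussian width `sigma`, and the function names are our own assumptions, not the paper's; α = 0.1 follows Section 5.2.

```python
import numpy as np

def knn_affinity(X, k=3, sigma=1.0):
    """Symmetric k-NN graph with Gaussian weights (assumed construction).
    X: n x N samples as columns."""
    sq = np.sum((X[:, :, None] - X[:, None, :]) ** 2, axis=0)
    far = sq.copy()
    np.fill_diagonal(far, np.inf)                 # a point is not its own neighbour
    G = np.zeros_like(sq)
    for i in range(sq.shape[0]):
        nbrs = np.argsort(far[i])[:k]
        G[i, nbrs] = np.exp(-sq[i, nbrs] / (2.0 * sigma ** 2))
    return np.maximum(G, G.T)                     # symmetrise

def propagate_constraints(X, P, alpha=0.1, k=3, sigma=1.0):
    """E2CP-style propagation (hedged sketch):
    F = (1 - a)^2 (I - a L)^-1 P (I - a L)^-1 with L = D^-1/2 G D^-1/2."""
    G = knn_affinity(X, k, sigma)
    d = 1.0 / np.sqrt(np.maximum(G.sum(axis=1), 1e-12))
    L_bar = d[:, None] * G * d[None, :]
    M = (1.0 - alpha) * np.linalg.inv(np.eye(G.shape[0]) - alpha * L_bar)
    return M @ P @ M

# Two tight groups of three points; one must-connect and one do-not-connect pair.
X = np.array([[0.0, 0.2, 0.1, 5.0, 5.2, 5.1],
              [0.0, 0.1, 0.3, 5.0, 5.1, 4.9]])
P = np.zeros((6, 6))
P[0, 1] = P[1, 0] = 1.0     # must-connect
P[0, 3] = P[3, 0] = -1.0    # do-not-connect
F = propagate_constraints(X, P, alpha=0.1, k=2)
```

In this toy run the must-connect constraint spreads to the unconstrained pair (2, 1), giving F[2, 1] > 0, and the do-not-connect constraint spreads to (2, 4), giving F[2, 4] < 0.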


#### **4. Proposed Method**

#### *4.1. Discriminative Multidimensional Scaling Based on Pairwise Constraints for Feature Learning*

The overall process of pcDMDS is shown in Figure 1. After some pairwise constraint information is obtained from the data $X$, more constraint information is first derived by the constraint propagation algorithm. Each extended constraint value lies in $[-1, 1]$: a value greater than 0 indicates a must-connect constraint, a value less than 0 indicates a do-not-connect constraint, and the larger the absolute value, the higher the confidence in the constraint. With the extended constraint information in hand, each iteration of the model pursues three goals: first, to preserve the topology of the samples; second, to find the cluster structure within the samples and pull the data representations of samples in the same cluster toward their cluster centers; and third, to use the pairwise constraints to bring the representations of samples with must-connect constraints close together and push those with do-not-connect constraints apart. Over the iterations, the model reaches a balance among these three goals, which further improves the discriminability of the learned features. After the model converges or reaches the maximum number of iterations, the new data representation is obtained through the learned transformation matrix $W$.

Following this idea, the model minimizes the objective $\mathcal{O}\_{pcdmds}(W, U, V)$, defined as:

$$\begin{split} \mathcal{O}\_{pcdmds}(W, U, V) &= \sum\_{i=1}^{N} \sum\_{j=1}^{N} s\_{ij} \Big( d\_{ij} - \left\| W^{T} (\mathbf{x}\_{i} - \mathbf{x}\_{j}) \right\|\_{2} \Big)^{2} \\ &\quad + \beta \sum\_{i=1}^{N} \sum\_{k=1}^{C} u\_{ik}^{m} \left\| W^{T} \mathbf{x}\_{i} - v\_{k} \right\|\_{2}^{2} \\ &\quad + \lambda \Bigg( \frac{1}{2N\_{\text{ML}}} \sum\_{(i,j) \in \text{ML}} \phi\_{ij} \left\| W^{T} (\mathbf{x}\_{i} - \mathbf{x}\_{j}) \right\|\_{2}^{2} \\ &\qquad - \frac{1}{2N\_{\text{CL}}} \sum\_{(i,j) \in \text{CL}} \phi\_{ij} \left\| W^{T} (\mathbf{x}\_{i} - \mathbf{x}\_{j}) \right\|\_{2}^{2} \Bigg) \\ &= \mathcal{O}\_{1}(W) + \beta \mathcal{O}\_{2}(W, U, V) + \lambda \mathcal{O}\_{pcloss}(W), \\ \text{s.t. } & \sum\_{k=1}^{C} u\_{ik} = 1,\ i = 1, 2, \dots, N, \quad u\_{ik} \ge 0,\ i = 1, \dots, N,\ k = 1, \dots, C. \end{split} \tag{4}$$

In Equation (4), ML denotes the set of index pairs of samples with must-connect constraints, and CL denotes the set of index pairs of samples with do-not-connect constraints; $N\_{ML} = |\mathrm{ML}|$ and $N\_{CL} = |\mathrm{CL}|$ are the numbers of such pairs. $\Phi = [\phi\_{ij}]$ is a symmetric matrix whose entry $\phi\_{ij} \in [0, 1]$ is the confidence of the pairwise constraint between samples $x\_i$ and $x\_j$; a larger value indicates a higher confidence.

**Figure 1.** The overall process of discriminative multidimensional scaling based on pairwise constraints.

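As Equation (4) shows, the objective of pcDMDS consists of three terms: the distance-preserving term $\mathcal{O}\_1(W)$, the fuzzy clustering term $\mathcal{O}\_2(W, U, V)$, which pulls projected samples toward their cluster centers, and the pairwise constraint term $\mathcal{O}\_{pcloss}(W)$. The parameters $\beta$ and $\lambda$ weight the last two terms, and the model seeks a balance among all three.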


To simplify Equation (4) for subsequent optimization, define the matrix $\Psi = [\psi\_{ij}] \in \mathbb{R}^{N \times N}$ with elements:

$$\psi\_{ij} = \begin{cases} \frac{1}{N\_{\text{ML}}} \phi\_{ij} & (i,j) \in \text{ML}, \\ -\frac{1}{N\_{\text{CL}}} \phi\_{ij} & (i,j) \in \text{CL}, \\ 0 & \text{otherwise}. \end{cases} \tag{5}$$
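The sketch below builds Ψ from an extended constraint matrix F according to Equation (5). The post-processing that turns F into Φ is not spelled out in this section, so the following are our assumptions: ML and CL are the off-diagonal positions with f_ij > 0 and f_ij < 0, φ_ij = min(|f_ij|, 1), and N_ML and N_CL count ordered index pairs. `build_psi` is a hypothetical name.

```python
import numpy as np

def build_psi(F):
    """Eq. (5) sketch: psi_ij = phi_ij / N_ML on ML, -phi_ij / N_CL on CL, else 0.
    Assumes ML = {f_ij > 0}, CL = {f_ij < 0} (i != j) and phi_ij = min(|f_ij|, 1)."""
    N = F.shape[0]
    off_diag = ~np.eye(N, dtype=bool)
    ml = (F > 0) & off_diag
    cl = (F < 0) & off_diag
    phi = np.clip(np.abs(F), 0.0, 1.0)          # confidence in [0, 1]
    Psi = np.zeros((N, N))
    if ml.any():
        Psi[ml] = phi[ml] / ml.sum()             # N_ML counts ordered pairs here
    if cl.any():
        Psi[cl] = -phi[cl] / cl.sum()            # N_CL counts ordered pairs here
    return Psi

F = np.array([[1.0, 0.8, -0.5],
              [0.8, 1.0, 0.0],
              [-0.5, 0.0, 1.0]])
Psi = build_psi(F)   # Psi[0, 1] = 0.8 / 2, Psi[0, 2] = -0.5 / 2
```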

Since $\Phi = \Phi^{T}$, it follows that $\Psi = \Psi^{T}$. Equation (4) can then be rewritten as:

$$\begin{split} \mathcal{O}\_{pcdmds}(W, U, V) &= \sum\_{i=1}^{N} \sum\_{j=1}^{N} s\_{ij} \Big( d\_{ij} - \left\| W^{T} (\mathbf{x}\_{i} - \mathbf{x}\_{j}) \right\|\_{2} \Big)^{2} \\ &\quad + \beta \sum\_{i=1}^{N} \sum\_{k=1}^{C} u\_{ik}^{m} \left\| W^{T} \mathbf{x}\_{i} - v\_{k} \right\|\_{2}^{2} \\ &\quad + \frac{\lambda}{2} \sum\_{i=1}^{N} \sum\_{j=1}^{N} \psi\_{ij} \left\| W^{T} (\mathbf{x}\_{i} - \mathbf{x}\_{j}) \right\|\_{2}^{2}, \\ \text{s.t. } & \sum\_{k=1}^{C} u\_{ik} = 1,\ i = 1, 2, \dots, N, \quad u\_{ik} \ge 0,\ i = 1, \dots, N,\ k = 1, \dots, C. \end{split} \tag{6}$$

#### *4.2. The Inference of Discriminative Multidimensional Scaling Based on Pairwise Constraints for Feature Learning*

In Equation (6), the variables to be solved are the transformation matrix $W$, the sample–cluster affiliation matrix $U$, and the cluster center matrix $V$. Since no closed-form solution of Equation (6) with respect to $W$, $U$, and $V$ jointly is available, an alternating iterative approach is used: each variable is updated in turn while the other two are held fixed.

(1) Fix $U$ and $V$, and update $W$. The objective in Equation (6) is then a function of $W$ only and can be expressed as:

$$L\_1(W) = \mathcal{O}\_1(W) + \beta \mathcal{O}\_2(W) + \lambda \mathcal{O}\_{pcloss}(W). \tag{7}$$

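To facilitate the solution, we first rewrite $\mathcal{O}\_{pcloss}(W)$ in matrix form, where $Y = W^{T}X$, $D\_{\Psi}$ is the diagonal matrix with $(D\_{\Psi})\_{ii} = \sum\_{j} \psi\_{ij}$, $D\_{\Psi^{T}}$ is defined analogously from the column sums, and $L\_{\Psi} = D\_{\Psi} - \Psi$: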

$$\begin{split} \mathcal{O}\_{pcloss}(W) &= \frac{1}{2} \sum\_{i=1}^{N} \sum\_{j=1}^{N} \psi\_{ij} \left\| y\_{i} - y\_{j} \right\|\_{2}^{2} \\ &= \frac{1}{2} \operatorname{Tr}\left(Y D\_{\Psi} Y^{T}\right) - \operatorname{Tr}\left(Y \Psi Y^{T}\right) + \frac{1}{2} \operatorname{Tr}\left(Y D\_{\Psi^{T}} Y^{T}\right) \\ &= \operatorname{Tr}\left(Y (D\_{\Psi} - \Psi) Y^{T}\right) \\ &= \operatorname{Tr}\left(Y L\_{\Psi} Y^{T}\right) \\ &= \operatorname{Tr}\left(W^{T} X L\_{\Psi} X^{T} W\right). \end{split} \tag{8}$$
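The trace identity in Equation (8) can be checked numerically. The sketch below compares ½ΣΣ ψ_ij‖y_i − y_j‖² with Tr(WᵀXL_ΨXᵀW) for a random symmetric Ψ, which is all the derivation requires; the sizes and random data are arbitrary choices for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
n, N, l = 4, 6, 2
X = rng.standard_normal((n, N))
W = rng.standard_normal((n, l))
Psi = rng.standard_normal((N, N))
Psi = (Psi + Psi.T) / 2              # any symmetric matrix, as Eq. (5) guarantees
np.fill_diagonal(Psi, 0.0)

Y = W.T @ X
# Left-hand side: (1/2) sum_ij psi_ij ||y_i - y_j||^2
lhs = 0.5 * sum(Psi[i, j] * np.sum((Y[:, i] - Y[:, j]) ** 2)
                for i in range(N) for j in range(N))
# Right-hand side: Tr(W^T X L_Psi X^T W) with L_Psi = D_Psi - Psi
L_psi = np.diag(Psi.sum(axis=1)) - Psi
rhs = np.trace(W.T @ X @ L_psi @ X.T @ W)
print(np.isclose(lhs, rhs))  # True
```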

Since $\|A\|\_F^{2} = \operatorname{Tr}(AA^{T}) = \operatorname{Tr}(A^{T}A)$, where $\operatorname{Tr}(\cdot)$ denotes the trace of a matrix, the second term $\mathcal{O}\_2(W)$ in Equation (7) simplifies to:

$$\mathcal{O}\_2(W) = \operatorname{Tr}\left(W^{T} X D\_{\tilde{U}} X^{T} W\right) - 2 \operatorname{Tr}\left(W^{T} X \tilde{U} V^{T}\right) + \operatorname{Tr}\left(V D\_{\tilde{U}^{T}} V^{T}\right). \tag{9}$$

Here $\tilde{U} = [u\_{ik}^{m}] \in \mathbb{R}^{N \times C}$, and $D\_{\tilde{U}}$ and $D\_{\tilde{U}^{T}}$ are diagonal matrices defined as:

$$D\_{\tilde{U}} = \begin{bmatrix} (D\_{\tilde{U}})\_{11} & & \\ & \ddots & \\ & & (D\_{\tilde{U}})\_{NN} \end{bmatrix}, \quad D\_{\tilde{U}^{T}} = \begin{bmatrix} (D\_{\tilde{U}^{T}})\_{11} & & \\ & \ddots & \\ & & (D\_{\tilde{U}^{T}})\_{CC} \end{bmatrix}, \quad (D\_{\tilde{U}})\_{ii} = \sum\_{k=1}^{C} u\_{ik}^{m}, \quad (D\_{\tilde{U}^{T}})\_{kk} = \sum\_{i=1}^{N} u\_{ik}^{m}. \tag{10}$$
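Similarly, a quick numerical check of Equation (9), assuming V stores the cluster centers as columns (l × C) and Ũ = [u_ik^m]; the random sizes are illustrative only.

```python
import numpy as np

rng = np.random.default_rng(1)
n, N, l, C, m = 4, 7, 2, 3, 2.0
X = rng.standard_normal((n, N))
W = rng.standard_normal((n, l))
V = rng.standard_normal((l, C))            # cluster centres as columns
U = rng.dirichlet(np.ones(C), size=N)      # rows sum to one
U_t = U ** m                               # tilde U = [u_ik^m]
D_U = np.diag(U_t.sum(axis=1))             # N x N, row sums of tilde U
D_Ut = np.diag(U_t.sum(axis=0))            # C x C, column sums of tilde U

Y = W.T @ X
direct = sum(U_t[i, k] * np.sum((Y[:, i] - V[:, k]) ** 2)
             for i in range(N) for k in range(C))
trace_form = (np.trace(W.T @ X @ D_U @ X.T @ W)
              - 2 * np.trace(W.T @ X @ U_t @ V.T)
              + np.trace(V @ D_Ut @ V.T))
print(np.isclose(direct, trace_form))  # True
```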

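The objective function in Equation (7) can be optimized with the iterative majorization algorithm (IMA) [5,14,17]. The constructed auxiliary function $\sigma\_{pcdmds}(W, Z)$ majorizes $L\_1(W)$, that is, it is never smaller than $L\_1(W)$ and equals it at $W = Z$, and is defined as: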

$$\begin{split} \sigma\_{pcdmds}(W, Z) &= \operatorname{Tr}\left(W^{T} A W\right) + \sum\_{i=1}^{N} \sum\_{j=1}^{N} s\_{ij} d\_{ij}^{2} - 2 \operatorname{Tr}\left(Z^{T} D(Z) W\right) \\ &\quad + \beta \Big( \operatorname{Tr}\left(W^{T} X D\_{\tilde{U}} X^{T} W\right) - 2 \operatorname{Tr}\left(W^{T} X \tilde{U} V^{T}\right) + \operatorname{Tr}\left(V D\_{\tilde{U}^{T}} V^{T}\right) \Big) \\ &\quad + \lambda \operatorname{Tr}\left(W^{T} X L\_{\Psi} X^{T} W\right). \end{split} \tag{11}$$

*A* in Equation (11) is defined as:

$$A = \sum\_{i=1}^{N} \sum\_{j=1}^{N} s\_{ij} \left(\mathbf{x}\_i - \mathbf{x}\_j\right) \left(\mathbf{x}\_i - \mathbf{x}\_j\right)^T. \tag{12}$$

The definition of *D*(*Z*) in Equation (11) is:

$$\begin{array}{l} D(Z) = \sum\_{i=1}^{N} \sum\_{j=1}^{N} c\_{ij}(Z) \left( \mathbf{x}\_{i} - \mathbf{x}\_{j} \right) \left( \mathbf{x}\_{i} - \mathbf{x}\_{j} \right)^{T}, \\ c\_{ij}(Z) = \begin{cases} s\_{ij} d\_{ij} / \hat{d}\_{ij}(Z) & \hat{d}\_{ij}(Z) > 0, \\ 0 & \hat{d}\_{ij}(Z) = 0. \end{cases} \end{array} \tag{13}$$

In Equation (13), $\hat{d}\_{ij}(Z) = \left\| Z^{T}(x\_i - x\_j) \right\|\_2$, where $Z$ is the current estimate of $W$.

Setting the gradient of Equation (11) with respect to $W$ to zero gives the update equation of $W$:

$$W = \left(A + \beta X D\_{\tilde{U}} X^{T} + \lambda X L\_{\Psi} X^{T}\right)^{-1} \left(D(Z) Z + \beta X \tilde{U} V^{T}\right). \tag{14}$$
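A minimal NumPy sketch of one W update, following Equations (12)–(14) with Z taken as the current W. The function names are hypothetical, and, as in Section 4.3.2, the Moore–Penrose inverse stands in for the plain inverse. With β = λ = 0 the step reduces to a plain majorization step on the stress of Equation (2), so the stress should not increase, which the usage lines check.

```python
import numpy as np

def pairwise_diffs(X):
    return X[:, :, None] - X[:, None, :]            # n x N x N, [:, i, j] = x_i - x_j

def update_W(X, Z, Dist, S, U, V, Psi, beta, lam, m=2.0):
    """One IMA step, Eq. (14). Z: current W (n x l); Dist, S: N x N;
    U: N x C affiliations; V: l x C centres; Psi: N x N symmetric."""
    diff = pairwise_diffs(X)
    A = np.einsum('ij,aij,bij->ab', S, diff, diff)               # Eq. (12)
    dhat = np.linalg.norm(np.einsum('al,aij->lij', Z, diff), axis=0)
    safe = np.where(dhat > 0, dhat, 1.0)                         # avoid division by 0
    c = np.where(dhat > 0, S * Dist / safe, 0.0)                 # c_ij(Z), Eq. (13)
    DZ = np.einsum('ij,aij,bij->ab', c, diff, diff)              # D(Z), Eq. (13)
    U_t = U ** m
    lhs = (A + beta * X @ np.diag(U_t.sum(axis=1)) @ X.T
           + lam * X @ (np.diag(Psi.sum(axis=1)) - Psi) @ X.T)
    rhs = DZ @ Z + beta * X @ U_t @ V.T
    return np.linalg.pinv(lhs) @ rhs                             # Moore-Penrose inverse

def stress(W, X, Dist, S):
    Y = W.T @ X
    Dist_hat = np.linalg.norm(Y[:, :, None] - Y[:, None, :], axis=0)
    return float(np.sum(S * (Dist - Dist_hat) ** 2))

rng = np.random.default_rng(0)
n, N, l, C = 5, 12, 2, 3
X = rng.standard_normal((n, N))
Dist = np.linalg.norm(pairwise_diffs(X), axis=0)
S = 1.0 - np.eye(N)                       # assumed uniform weights
U = np.full((N, C), 1.0 / C)
V = np.zeros((l, C))
Psi = np.zeros((N, N))
W = rng.standard_normal((n, l))
before = stress(W, X, Dist, S)
W = update_W(X, W, Dist, S, U, V, Psi, beta=0.0, lam=0.0)
after = stress(W, X, Dist, S)
print(after <= before)  # True: with beta = lam = 0 the step cannot increase stress
```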

(2) Fix the matrices *W* and *V*, and solve for *U*. At this point, the first and third terms in Equation (6) are constant terms, and the optimization of Equation (6) is equivalent to the optimization of:

$$\begin{split} \mathcal{L}\_2(U) &= \sum\_{i=1}^{N} \sum\_{k=1}^{C} u\_{ik}^{m} \| y\_i - v\_k \|\_2^2 \\ &= \sum\_{i=1}^{N} \sum\_{k=1}^{C} u\_{ik}^{m} d^2(y\_i, v\_k), \\ \text{s.t. } & \sum\_{k=1}^{C} u\_{ik} = 1,\ i = 1, 2, \dots, N, \quad u\_{ik} \ge 0,\ i = 1, \dots, N,\ k = 1, \dots, C. \end{split} \tag{15}$$

Using the method of Lagrange multipliers [31], with a multiplier $\eta\_i$ for the $i$-th equality constraint (written $\eta\_i$ to avoid confusion with the weight $\lambda$):

$$\mathcal{L}\_{\eta}(U) = \sum\_{i=1}^{N} \sum\_{k=1}^{C} u\_{ik}^{m} d^{2}(y\_{i}, v\_{k}) - \sum\_{i=1}^{N} \eta\_{i} \left(\sum\_{k=1}^{C} u\_{ik} - 1\right). \tag{16}$$

Setting the partial derivatives to zero gives:

$$\begin{aligned} \frac{\partial \mathcal{L}\_{\eta}(U)}{\partial u\_{ik}} &= m\, u\_{ik}^{m-1} d^2(y\_i, v\_k) - \eta\_i = 0, \\ -\frac{\partial \mathcal{L}\_{\eta}(U)}{\partial \eta\_i} &= \sum\_{k=1}^{C} u\_{ik} - 1 = 0, \end{aligned} \tag{17}$$

and solving for $u\_{ik}$ yields the update:

$$u\_{ik} = \frac{\left(\frac{1}{d(y\_i, v\_k)}\right)^{\frac{2}{m-1}}}{\sum\_{j=1}^{C} \left(\frac{1}{d(y\_i, v\_j)}\right)^{\frac{2}{m-1}}} = \frac{1}{\sum\_{j=1}^{C} \left(\frac{d(y\_i, v\_k)}{d(y\_i, v\_j)}\right)^{\frac{2}{m-1}}}. \tag{18}$$

Equation (18) is undefined when $y\_i$ coincides with one or more cluster centers. Taking this case into account, the iterative update of $U$ is:

$$u\_{ik} = \begin{cases} 1 \Big/ \sum\_{j=1}^{C} \left( \frac{d(y\_i, v\_k)}{d(y\_i, v\_j)} \right)^{\frac{2}{m-1}} & \mathcal{I}\_i = \emptyset, \\ \frac{1}{|\mathcal{I}\_i|} & \mathcal{I}\_i \neq \emptyset,\ k \in \mathcal{I}\_i, \\ 0 & \mathcal{I}\_i \neq \emptyset,\ k \notin \mathcal{I}\_i. \end{cases} \tag{19}$$

Here $\mathcal{I}\_i = \{ r \in \mathbb{N}\_{\le C} \mid y\_i = v\_r \}$, where $\mathbb{N}\_{\le C}$ denotes the set of positive integers less than or equal to $C$, and $|\mathcal{I}\_i|$ denotes the number of elements in $\mathcal{I}\_i$. That is, when a sample point $y\_i$ coincides with the centers of several clusters, it has equal affiliation $1/|\mathcal{I}\_i|$ with each of these clusters and zero affiliation with all others.
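The affiliation update of Equation (19), including the special case of a sample coinciding with one or more centers, can be sketched as follows. The numerical tolerance used to detect coincidence is our assumption, since Equation (19) uses exact equality, and `update_U` is a hypothetical name.

```python
import numpy as np

def update_U(Y, V, m=2.0, tol=1e-12):
    """Eq. (19). Y: l x N projected samples; V: l x C centres. Returns N x C."""
    d = np.linalg.norm(Y[:, :, None] - V[:, None, :], axis=0)   # N x C, d(y_i, v_k)
    N, C = d.shape
    U = np.zeros((N, C))
    for i in range(N):
        hits = d[i] <= tol                      # index set I_i (tolerance assumed)
        if hits.any():
            U[i, hits] = 1.0 / hits.sum()       # equal share 1/|I_i|, zero elsewhere
        else:
            ratio = (d[i][:, None] / d[i][None, :]) ** (2.0 / (m - 1.0))
            U[i] = 1.0 / ratio.sum(axis=1)      # Eq. (18)
    return U

Y = np.array([[0.0, 1.0, 5.0]])   # three 1-D samples
V = np.array([[0.0, 5.0]])        # two centres
U = update_U(Y, V)
# Sample 0 sits on centre 0, sample 2 on centre 1; sample 1 gets [16/17, 1/17].
```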

(3) Fix $W$ and $U$, and update $V$. As in step (2), only the second term of Equation (6) depends on $V$, so we minimize:

$$\begin{split} \mathcal{L}\_{3}(V) &= \sum\_{i=1}^{N} \sum\_{k=1}^{C} u\_{ik}^{m} \left\| W^{T} \mathbf{x}\_{i} - v\_{k} \right\|\_{2}^{2} \\ &= \sum\_{i=1}^{N} \sum\_{k=1}^{C} u\_{ik}^{m} \operatorname{Tr} \left( y\_{i} y\_{i}^{T} - y\_{i} v\_{k}^{T} - v\_{k} y\_{i}^{T} + v\_{k} v\_{k}^{T} \right). \end{split} \tag{20}$$

Taking the partial derivative of Equation (20) with respect to $v\_k$ gives:

$$\frac{\partial \mathcal{L}\_3(V)}{\partial v\_k} = \sum\_{i=1}^N u\_{ik}^m (-y\_i - y\_i + 2v\_k) = \sum\_{i=1}^N u\_{ik}^m (2v\_k - 2y\_i). \tag{21}$$

Setting Equation (21) to zero yields the iterative update of each cluster center:

$$v\_k = \sum\_{i=1}^{N} u\_{ik}^m y\_i \Big/ \sum\_{i=1}^{N} u\_{ik}^m. \tag{22}$$
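Equation (22) is a weighted mean of the projected samples. A vectorized sketch (hypothetical helper name), returning V as an l × C matrix:

```python
import numpy as np

def update_V(Y, U, m=2.0):
    """Eq. (22): v_k = sum_i u_ik^m y_i / sum_i u_ik^m.
    Y: l x N projected samples; U: N x C affiliations. Returns l x C."""
    U_t = U ** m
    return (Y @ U_t) / U_t.sum(axis=0)

Y = np.array([[0.0, 2.0, 10.0, 12.0]])
U = np.array([[1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0]])  # hard assignment
V = update_V(Y, U)   # reduces to the cluster means: [[1.0, 11.0]]
```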

#### *4.3. Algorithm*

#### 4.3.1. Algorithm Description

As Algorithm 1 shows, pcDMDS consists of two stages. The first stage expands the pairwise constraint information through constraint propagation; the second stage iteratively applies the update formulas for $W$, $U$, and $V$ and outputs the transformation matrix once the iteration ends. Specifically, in the first stage, the pairwise constraint matrix $P$ is constructed from the set of sample pairwise constraints. The extended pairwise constraint matrix $F$ is then obtained by the constraint propagation algorithm, post-processed, and assigned to $\Psi$. Next, the distance matrix $D$, the weight matrix $S$, and the matrix $A$ are computed, and $W$, $V$, and $U$ are initialized. The second stage updates $W$, $U$, and $V$ in turn and stops when $W$ and $U$ become stable or the maximum number of iterations is reached. Finally, the transformation matrix $W$ is returned.
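Algorithm 1 itself is not reproduced in this section. As a hedged sketch of the second stage described above, the hypothetical function `pcdmds_fit` below chains the updates of Equations (14), (18), and (22), taking Ψ as already computed. Uniform weights s_ij = 1 for i ≠ j, random initialization, and a small distance floor used in place of the special case of Equation (19) are our assumptions.

```python
import numpy as np

def pcdmds_fit(X, Psi, C, l, beta=1.0, lam=0.8, m=2.0, T=100, tol=1e-6, seed=0):
    """Alternating W/U/V updates (sketch). X: n x N data; Psi: N x N from Eq. (5).
    Returns (W, U, V); the new representation is W.T @ X."""
    rng = np.random.default_rng(seed)
    n, N = X.shape
    diff = X[:, :, None] - X[:, None, :]                  # x_i - x_j
    Dist = np.linalg.norm(diff, axis=0)                   # d_ij
    S = 1.0 - np.eye(N)                                   # s_ij (assumed uniform)
    A = np.einsum('ij,aij,bij->ab', S, diff, diff)        # Eq. (12)
    XLX = X @ (np.diag(Psi.sum(axis=1)) - Psi) @ X.T      # X L_Psi X^T
    W = rng.standard_normal((n, l))                       # assumed initialisation
    U = rng.dirichlet(np.ones(C), size=N)
    Y = W.T @ X
    V = (Y @ U ** m) / (U ** m).sum(axis=0)
    for _ in range(T):
        W_old, U_old = W, U
        # (1) W-step, Eq. (14), with Z = current W
        dhat = np.linalg.norm(np.einsum('al,aij->lij', W, diff), axis=0)
        c = np.where(dhat > 0, S * Dist / np.where(dhat > 0, dhat, 1.0), 0.0)
        DZ = np.einsum('ij,aij,bij->ab', c, diff, diff)
        U_t = U ** m
        lhs = A + beta * X @ np.diag(U_t.sum(axis=1)) @ X.T + lam * XLX
        W = np.linalg.pinv(lhs) @ (DZ @ W + beta * X @ U_t @ V.T)
        # (2) U-step, Eq. (18); the distance floor approximates Eq. (19)
        Y = W.T @ X
        d = np.linalg.norm(Y[:, :, None] - V[:, None, :], axis=0)
        inv = np.maximum(d, 1e-12) ** (-2.0 / (m - 1.0))
        U = inv / inv.sum(axis=1, keepdims=True)
        # (3) V-step, Eq. (22)
        U_t = U ** m
        V = (Y @ U_t) / U_t.sum(axis=0)
        if np.linalg.norm(W - W_old) < tol and np.linalg.norm(U - U_old) < tol:
            break
    return W, U, V

rng = np.random.default_rng(0)
X = np.hstack([rng.normal(0.0, 0.3, (4, 10)), rng.normal(3.0, 0.3, (4, 10))])
W, U, V = pcdmds_fit(X, np.zeros((20, 20)), C=2, l=2)
Y = W.T @ X   # learned low-dimensional representation
```

A zero Ψ is used in the usage lines only to keep the example small; in practice Ψ would come from the propagated constraints.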


#### 4.3.2. Study on Computational Complexity

We first discuss the time complexity of the model. According to Algorithm 1, pcDMDS calls the E²CP constraint propagation algorithm, whose time complexity is $O(N^3)$. Computing the matrix $D(Z)$ takes $O(n^2N + nN^2)$ time. The $n \times n$ symmetric matrix inverted in Equation (14) and its Moore–Penrose inverse can be obtained by singular value decomposition, whose time complexity is $O(n^3)$ [32]; hence, by Equation (14), one update of $W$ takes $O(n^2N + nN^2 + n^3)$ time. By Equation (19), one update of $U$ takes $O(NC^2l)$ time, and by Equation (22), one update of the cluster center matrix $V$ takes $O(NCl)$ time. Since $W$, $U$, and $V$ are updated sequentially, combining the cost of the three updates with that of constraint propagation gives a total time complexity of $O\left(N^3 + T\left(nN^2 + n^2N + n^3 + NC^2l\right)\right)$, where $T$ is the maximum number of iterations.

Next, we discuss the space complexity of the model. The input data matrix $X$ has size $nN$. The matrices $P$, $F$, $D$, and $S$ each require $O(N^2)$ space, and computing $A$ requires $O(Nn + N^2 + n^2)$ space. $W$, $V$, and $U$ have sizes $nl$, $lC$, and $NC$, respectively. During the iteration, $\tilde{U}$ and $D\_{\tilde{U}}$ require $O(NC)$ and $O(N)$ space, and $D(Z)$ requires $O(Nl)$ space. Updating $W$ requires $O(n^2 + nN + N^2 + nl + NC + Cl)$ space, and computing $Y$ requires $O(lN + ln + nN)$ space. Therefore, the total space complexity is $O(Nn + N^2 + n^2 + nl + lC + NC + Nl)$.

#### 4.3.3. Visualization

Figure 2 shows the visualization results on the wine dataset, which contains 178 samples from 3 categories, each with 13 attributes. As Figure 2a shows, the boundaries between categories in the 2D representation learned by MDS are blurred: the discriminability between categories is not improved, and because MDS only preserves the distances between samples, samples of the same category do not become more compact. To show more intuitively that pcDMDS learns more discriminative features, the visualization of pcDMDS is given in Figure 2b. Comparing Figure 2a,b shows that, relative to MDS, pcDMDS yields a more compact distribution of samples within each category and clearer boundaries between categories, making the learned features more discriminative.

**Figure 2.** Visualization of wine dataset after dimensionality reduction using MDS and pcDMDS.

#### **5. Experiments**

#### *5.1. Datasets*

The experiments on the pairwise-constraint-based discriminative multidimensional scaling feature learning algorithm use 12 publicly available datasets from the MSRA-MM [33] database. Table 1 describes the details of these 12 datasets.



**Table 1.** *Cont.*

#### *5.2. Experimental Setting*

The weight of the pairwise constraint loss term in pcDMDS is controlled by the parameter *λ*. In the experiments, the initial pairwise constraints are derived directly from ten percent of the label information; the constraint transferring algorithm then extends them, and the extended constraints are used as the final pairwise constraint information. For pcDMDS, *λ* is set to 0.8, and the parameter *α* of the constraint transferring algorithm is set to 0.1. To reduce the variance of the experimental results, every feature learning algorithm is run 10 times, and the average over the 10 runs is reported as the final result.
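One plausible way to derive the initial pairwise constraints from ten percent of the labels is sketched below. The uniform random choice of the labelled subset, the +1/-1/0 encoding of must-link, cannot-link and unknown pairs, and the function name are all assumptions; the subsequent propagation by the constraint transferring algorithm (with *α* = 0.1) is not shown.

```python
import numpy as np

def pairwise_constraints_from_labels(labels, fraction=0.1, seed=0):
    """Hypothetical helper: constraint matrix Z from a random labelled subset.

    Z[i, j] = +1 (must-link) if i and j are labelled with the same class,
              -1 (cannot-link) if they are labelled with different classes,
               0 otherwise (no constraint known).
    """
    labels = np.asarray(labels)
    n = labels.shape[0]
    rng = np.random.default_rng(seed)
    m = max(2, int(round(fraction * n)))            # size of the labelled subset
    idx = rng.choice(n, size=m, replace=False)
    same = labels[idx][:, None] == labels[idx][None, :]
    Z = np.zeros((n, n))
    Z[np.ix_(idx, idx)] = np.where(same, 1.0, -1.0)
    np.fill_diagonal(Z, 0.0)                        # no self-constraints
    return Z, idx

labels = [0] * 10 + [1] * 10 + [2] * 10
Z, idx = pairwise_constraints_from_labels(labels, fraction=0.1, seed=0)
```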

The experiments evaluate the ability of pcDMDS to learn discriminative features. Several clustering algorithms are run on three data representations: the original data, the low-dimensional representation obtained by PMDS, and the representation obtained by pcDMDS. The more discriminative a representation is, the better the clustering algorithms perform on it. The selected clustering algorithms are KM, AP and DP.

#### *5.3. Evaluation Metric*

Since discriminative features tend to improve the performance of subsequent machine learning tasks, the discriminability of the learned features can be assessed through the performance of those tasks. Because the subsequent tasks include clustering and classification, the learned features are evaluated with clustering and classification metrics.

#### 5.3.1. Accuracy

Accuracy, a common metric for clustering, measures the agreement between the cluster assignments produced by a clustering model and the true labels of the samples. Clustering accuracy and classification accuracy are computed slightly differently. For clustering, the accuracy is computed as [34]:

$$Acc = \frac{1}{N} \sum_{i=1}^{N} \delta(l_i, \mathrm{map}(r_i)). \tag{23}$$

Here, *N* denotes the number of sample points, and $\mathrm{map}(\cdot)$ is a function that maps each cluster index to a category label. $l_i$ and $r_i$ denote the category label and the cluster index of sample point $x_i$, respectively. $\delta(a, b)$ equals 1 when $a = b$ and 0 otherwise. For the classification task, $r_i$ denotes the category label predicted by the classifier, so $\mathrm{map}(\cdot)$ reduces to the identity mapping, whose output equals its input.
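A minimal sketch of Equation (23) follows. The text does not say how $\mathrm{map}(\cdot)$ is chosen; the sketch assumes the common choice of the one-to-one cluster-to-class mapping that maximizes the number of matches, found by brute force over permutations. This is only practical for a handful of clusters; for larger problems the Hungarian algorithm (e.g. `scipy.optimize.linear_sum_assignment`) is the usual choice.

```python
from collections import Counter
from itertools import permutations

def clustering_accuracy(labels, clusters):
    """Eq. (23) with map() taken as the best one-to-one cluster-to-class assignment."""
    labels, clusters = list(labels), list(clusters)
    if len(labels) != len(clusters) or not labels:
        raise ValueError("labels and clusters must be non-empty and of equal length")
    classes = sorted(set(labels))
    cluster_ids = sorted(set(clusters))
    # Pad with None so that surplus clusters map to no class (and score no hits).
    targets = classes + [None] * max(0, len(cluster_ids) - len(classes))
    pair_counts = Counter(zip(clusters, labels))    # (cluster, class) -> count
    best = max(
        sum(pair_counts[(c, t)] for c, t in zip(cluster_ids, perm))
        for perm in permutations(targets, len(cluster_ids))
    )
    return best / len(labels)

# One sample of class 0 falls into the cluster dominated by class 1.
print(clustering_accuracy([0, 0, 0, 1, 1, 1], [0, 0, 1, 1, 1, 1]))  # 5/6, about 0.833
```

For classification, passing predicted labels as `clusters` gives the identity mapping whenever predictions use the same label set.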

#### 5.3.2. Purity

Purity is a common metric used to measure the performance of clustering algorithms and is defined as [35]:

$$\text{Purity} = \frac{1}{N} \sum\_{k=1}^{C} \max\_{1 \le r \le q} n\_k^r. \tag{24}$$

Here, *N* denotes the number of sample points, *C* the number of clusters, *k* the cluster index, and *q* the number of classes; in general, *q* equals *C*. $n_k^r$ denotes the number of samples with class label *r* in the *k*-th cluster.
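Equation (24) amounts to a majority vote per cluster, as in the plain-Python sketch below (the function name is illustrative).

```python
from collections import Counter, defaultdict

def purity(labels, clusters):
    """Eq. (24): each cluster contributes the size of its majority class."""
    if len(labels) != len(clusters) or len(labels) == 0:
        raise ValueError("labels and clusters must be non-empty and of equal length")
    members = defaultdict(list)                     # cluster index -> class labels
    for label, cluster in zip(labels, clusters):
        members[cluster].append(label)
    majority = sum(Counter(ls).most_common(1)[0][1] for ls in members.values())
    return majority / len(labels)

print(purity([0, 0, 1, 1, 2, 2], [0, 0, 0, 1, 1, 1]))  # (2 + 2) / 6
```

Purity reaches 1 when every sample forms its own cluster, which is one reason it is reported together with accuracy rather than alone.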

#### 5.3.3. Friedman Test

The Friedman test is a non-parametric statistical test for assessing the overall difference in performance among a set of algorithms across multiple datasets. It first ranks the algorithms on each dataset: the best-performing algorithm receives rank 1, the second best rank 2, and so on, and tied algorithms receive the average of the ranks they span. This ranking value is also called the rank value. The Friedman statistic is defined as [36]:

$$\chi_F^2 = \frac{12a}{b(b+1)} \left[ \sum_{j=1}^{b} R_j^2 - \frac{b(b+1)^2}{4} \right]. \tag{25}$$

Here, *a* denotes the number of datasets, *b* the number of algorithms, and $R_j = \frac{1}{a}\sum_{i=1}^{a} r_j^i$, where $r_j^i$ is the rank value of the *j*-th algorithm on the *i*-th dataset; thus, $R_j$ is the average rank of the *j*-th algorithm over all datasets. $\chi_F^2$ follows the chi-square distribution with $b - 1$ degrees of freedom.

Iman and Davenport addressed the shortcomings of the Friedman statistic $\chi_F^2$ by proposing a better statistic, defined as [37]:

$$F_F = \frac{(a-1)\chi_F^2}{a(b-1) - \chi_F^2}. \tag{26}$$

$F_F$ follows the *F* distribution with $b - 1$ and $(b - 1)(a - 1)$ degrees of freedom. The *p*-value is obtained from this distribution (e.g., by table lookup), and the significance of the differences among all algorithms is assessed based on the *p*-value.
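Equations (25) and (26) can be evaluated from the average ranks alone, as in the sketch below (function names are illustrative). For the *p*-value it uses the closed-form survival function $P(F > x) = \left(\frac{d_2}{d_2 + 2x}\right)^{d_2/2}$, which holds only when the numerator has 2 degrees of freedom, as in the three-representation comparison of Section 5.4; other cases would need a library routine such as `scipy.stats.f.sf`.

```python
def friedman_iman_davenport(avg_ranks, a):
    """Eqs. (25)-(26): chi2_F and F_F from the average ranks R_j over a datasets."""
    b = len(avg_ranks)
    sum_sq = sum(r * r for r in avg_ranks)
    chi2_f = 12 * a / (b * (b + 1)) * (sum_sq - b * (b + 1) ** 2 / 4)
    # Undefined when chi2_f == a * (b - 1), i.e. identical rankings on every dataset.
    f_f = (a - 1) * chi2_f / (a * (b - 1) - chi2_f)
    return chi2_f, f_f, (b - 1, (b - 1) * (a - 1))

def f_sf_two_numerator_df(x, d2):
    """P(F > x) for F(2, d2); this closed form is valid only for 2 numerator d.o.f."""
    return (d2 / (d2 + 2 * x)) ** (d2 / 2)

# Average ranks for Table 2, taken as 59/24 and 61/24 (about 2.4583 and 2.5417), and 1.
chi2_f, f_f, dof = friedman_iman_davenport([59 / 24, 61 / 24, 1.0], a=12)
p = f_sf_two_numerator_df(f_f, dof[1])
print(dof, f"{p:.3e}")
```

With these ranks the sketch gives a *p*-value of about $2.2 \times 10^{-7}$ on F(2, 22), in line with the value reported in Section 5.4.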

#### *5.4. Results*

Tables 2 and 3 report the accuracy and purity, respectively, obtained by clustering the 12 datasets with KM, AP and DP under three different data representations. Specifically, in Table 2, columns KM, AP and DP give the clustering accuracies of the three algorithms on the original data representation; PMDS-KM, PMDS-AP and PMDS-DP give their accuracies on the low-dimensional representation obtained by PMDS; and pcDMDS-KM, pcDMDS-AP and pcDMDS-DP give their accuracies on the low-dimensional representation obtained by pcDMDS. Column Avg is the mean of columns KM, AP and DP, column PMDS-Avg is the mean of columns PMDS-KM, PMDS-AP and PMDS-DP, and column pcDMDS-Avg is the mean of columns pcDMDS-KM, pcDMDS-AP and pcDMDS-DP. The headers of Table 3 have the same meaning as those of Table 2, except that the entries are purity rather than accuracy.

Table 2 shows that, on 10 of the 12 datasets, the highest accuracy is achieved on the data representation learned by pcDMDS (bold in the table), and on the remaining 2 it is achieved on the original data representation, which indicates that pcDMDS improves the discriminability of the data representation. Moreover, for the same clustering algorithm, performance on the pcDMDS representation is better than on the original and PMDS representations in the vast majority of cases. In terms of average accuracy, pcDMDS-Avg is the highest on all 12 datasets, and the average accuracy on the pcDMDS representation is 10.31% and 7.41% higher than that on the PMDS representation and the original space, respectively. This further shows that, compared with PMDS and the original data representation, the representation learned by pcDMDS improves the performance of subsequent machine learning tasks.


**Table 2.** Accuracy of clustering with different data representations.

**Table 3.** Purity of clustering with different data representations.


The overall performance is then evaluated with the Friedman statistic. From the last three columns of Table 2, the rank of each data representation on each dataset is first derived; the resulting average ranks of Avg, PMDS-Avg and pcDMDS-Avg over the 12 datasets are 2.4583, 2.5416 and 1, respectively. With 12 datasets and three averaged representations, $F_F$ follows the *F* distribution with $3 - 1 = 2$ and $(12 - 1)(3 - 1) = 22$ degrees of freedom. From the F(2, 22) distribution, the *p*-value is $2.2082 \times 10^{-7}$, so the null hypothesis is rejected at a high significance level: overall, pcDMDS outperforms PMDS, and the data representation obtained by pcDMDS is more discriminative than both the PMDS representation and the original data representation.

Table 3 lists the purity of the clustering results on the different data representations. The highest average purity on all 12 datasets is achieved on the pcDMDS representation. Overall, clustering on the pcDMDS representation performs better than on the PMDS representation and the raw space, and its average purity is 8.31% and 9.18% higher than that of PMDS and the original space, respectively.

Similarly, the Friedman statistic is used to evaluate the overall performance. From the last three columns of Table 3, the average ranks of Avg, PMDS-Avg and pcDMDS-Avg are 2.5, 2.5 and 1, respectively. The Friedman statistic is $\chi_F^2 = 13.0833$, and the Iman-Davenport statistic is $F_F = 13.1832$. From the F(2, 22) distribution, the *p*-value is $1.7245 \times 10^{-4}$, so the null hypothesis is rejected at a high significance level, and overall pcDMDS performs better than PMDS and the original space.

In terms of both accuracy and purity, the data representation obtained by pcDMDS yields better performance for subsequent clustering algorithms than the original data representation and the PMDS representation, i.e., pcDMDS learns more discriminative features. On large datasets, pcDMDS enhances discriminability by considering both the topology of the samples in the original space and the cluster structure in the new space, while also incorporating the extended pairwise constraint information.

#### **6. Conclusions**

In this paper, a feature learning algorithm named pcDMDS is proposed, which enhances discriminability in two ways. First, fuzzy *k*-means is used to automatically discover clusters in the samples, so that during feature learning the new representations of samples in the same cluster are pulled toward the cluster center. Second, a constraint transferring algorithm extends the pairwise constraints given for a subset of samples to more sample pairs, yielding extended pairwise constraint information. Because pcDMDS considers the topological structure of the samples in the original space and the cluster structure in the new space, and additionally incorporates the extended pairwise constraint information, it further improves the ability of the original model to learn discriminative features. Although the effect of the parameter *λ* on the clustering performance of pcDMDS was analyzed, the other parameter values were fixed, so the joint effect of *β* and *λ* can be studied in future work. In addition, the model does not support incremental learning, which is also left for future work.

**Author Contributions:** Conceptualization, L.Z., B.P., H.T. and H.W.; methodology, L.Z., B.P. and H.T.; software, L.Z., B.P., H.T. and H.W.; validation, C.L., Z.L. and H.T.; formal analysis, L.Z.; investigation, Z.L.; resources, B.P., H.T., H.W. and C.L.; data curation, L.Z.; writing—original draft, L.Z., B.P., H.T., H.W. and C.L.; writing—review and editing, L.Z., B.P., H.T., H.W., C.L. and Z.L.; visualization, L.Z., B.P. and H.W.; supervision, H.W. and C.L.; project administration, Z.L. All authors have read and agreed to the published version of the manuscript.

**Funding:** This research work was supported by Science and Technology Project of State Grid Sichuan Electric Power Company (52199722000Y), and by the National Natural Science Foundation of China under Grant Nos (62276216, 62202395).

**Data Availability Statement:** Data are freely available from MSRA.

**Conflicts of Interest:** The authors declare no conflict of interest.

#### **References**


### *Article* **Triplet Contrastive Learning for Aspect Level Sentiment Classification**

**Haoliang Xiong <sup>1,†</sup>, Zehao Yan <sup>1,†</sup>, Hongya Zhao <sup>2</sup>, Zhenhua Huang <sup>3</sup> and Yun Xue <sup>1,\*</sup>**


**Abstract:** Aspect Level Sentiment Classification, which analyzes the sentiment toward a given aspect, attracts much attention in NLP. Recent state-of-the-art Aspect Level Sentiment Classification methods use Graph Convolutional Networks to process both the semantics and the syntax of the sentence. However, parsing the syntactic structure inevitably introduces information that is irrelevant to the aspect. Moreover, the alignment and uniformity of syntactic and semantic features, which contribute to sentiment delivery, are currently neglected. In this work, a **Triplet Contrastive Learning Network** is developed to coordinate syntactic and semantic information. First, an aspect-oriented sub-tree is constructed to replace the syntactic adjacency matrix. Next, a sentence-level contrastive learning scheme is proposed to highlight the features of sentiment words. Based on triplet contrastive learning, the syntactic and semantic information thoroughly interact and are coordinated, while the global semantics and syntax can be exploited. Extensive experiments on three benchmark datasets, with BERT-based accuracies of 87.40, 82.80 and 77.55 on Rest14, Lap14 and Twitter, respectively, demonstrate that our approach achieves state-of-the-art results on the Aspect Level Sentiment Classification task.

**Keywords:** Aspect Level Sentiment Classification; Contrastive Learning; Graph Convolutional Networks

**MSC:** 18C50

#### **1. Introduction**

Aspect Level Sentiment Classification (ALSC) is a fundamental subtask of fine-grained sentiment analysis, which currently receives a great deal of attention [1]. The main focus of ALSC is to identify the sentiment polarity (e.g., positive, neutral or negative) of aspects explicitly given in sentences. For example, in the sentence "*The price is reasonable although the service is poor*" (Figure 1), the sentiment toward aspects *price* and *service* is positive and negative, respectively.

Advances in deep neural networks have brought a paradigm shift to various NLP tasks, and ALSC is no exception [2–4]. Attention-based networks are among the most common approaches; they exploit semantic information to capture the sentiment words of the given aspect. In Figure 1, for example, the attention mechanism can assign larger weights to the sentiment words *reasonable* and *poor*. However, relying on semantic features alone can lead to misinterpreting contextual words, especially in sentences with complex syntactic structure. More recently, the application of Graph Convolutional Networks (GCN) to ALSC has proven both creative and practical [5]. On the one hand, encoding syntactic information with GCN mitigates the difficulty of capturing long-distance dependencies among words [6,7]. On the other hand, GCN can process not only syntax but also semantic information, which opens opportunities for integrating semantic features. Accordingly, state-of-the-art approaches develop multi-channel GCNs to handle multiple types of information [8,9].

**Citation:** Xiong, H.; Yan, Z.; Zhao, H.; Huang, Z.; Xue, Y. Triplet Contrastive Learning for Aspect Level Sentiment Classification. *Mathematics* **2022**, *10*, 4099. https://doi.org/10.3390/math10214099

Academic Editors: Jianping Gou, Weihua Ou, Shaoning Zeng and Lan Du

Received: 7 October 2022 Accepted: 1 November 2022 Published: 3 November 2022

**Publisher's Note:** MDPI stays neutral with regard to jurisdictional claims in published maps and institutional affiliations.

**Copyright:** © 2022 by the authors. Licensee MDPI, Basel, Switzerland. This article is an open access article distributed under the terms and conditions of the Creative Commons Attribution (CC BY) license (https://creativecommons.org/licenses/by/4.0/).

<sup>1</sup> School of Electronics and Information Engineering, South China Normal University, Foshan 528225, China; xionghl@m.scnu.edu.cn (H.X.); yzh\_scnu@m.scnu.edu.cn (Z.Y.)



**Figure 1.** An example of ALSC.

Despite the progress of GCN-based methods in ALSC, two main limitations remain. **On the one hand**, most syntactic parsing is performed on the whole sentence without considering the importance of key phrases (e.g., aspect words and opinion words) for sentiment determination. As a result, redundant information or even noise can be introduced during feature extraction. **On the other hand**, current methods extract semantic and syntactic features in two separate spaces and fuse them in an elementary way, ignoring the alignment and uniformity of these two categories of features [10].

Inspired by the methods reported in [8,11], a **Triplet Contrastive Learning Network (TCL)** for ALSC is proposed to address the aforementioned issues. To exploit syntactic information, we first reconstruct the syntax dependency tree with the aspect as its root, following [12] (Figure 2). The dependencies between the aspect and the other words are explicitly established, which helps capture the opinion words of the aspect and limits the introduction of redundant information. As presented in [13], key phrases play a pivotal role in conveying the essence of a text. To further filter noise and highlight key information, a contrastive learning scheme is proposed to magnify the significance of sentiment-related words. In ALSC tasks, the key phrases are nouns, verbs, adjectives or adverbs of degree [14]. Using a masking mechanism, both positive and negative examples are generated and fed into the contrastive learning module to enhance the impact of key phrases and distill the syntactic features.

**Figure 2.** Reconstruction of aspect-oriented syntax dependency tree.

With respect to the integration of sentence syntax and semantics, recent publications reveal that the two are distinct yet related [8,15]. Likewise, the features from both spaces that convey sentiment toward the aspect bear a similar relationship to each other. For this reason, aligning the two kinds of features can facilitate information integration. Concretely, features within either the syntactic or the semantic space that express the same sentiment polarity can be aligned, while those expressing different polarities can be separated. On this basis, the interaction between syntactic and semantic information is carried out through a dual-contrastive learning scheme: for each sample in a mini-batch, features of the same sentiment polarity are pulled closer, and features of different polarities are pushed apart. In this way, the two categories of features thoroughly interact and are aligned, and feature integration can be leveraged to improve ALSC performance.

The contributions of this paper are as follows:


The rest of this paper is organized as follows. Section 2 reviews related work on ALSC and contrastive learning. Section 3 describes the TCL model in detail. Section 4 presents the experiments and analyzes the results. Concluding remarks are given in Section 5.

#### **2. Related Work**

#### *2.1. Aspect Level Sentiment Classification*

Sentiment classification tasks mainly focus on capturing sentiment information from a given text [16,17]. ALSC aims to classify the sentiment polarity of a specific aspect in a given text, which requires a more detailed analysis of the sentiment associated with that aspect. Early research employs CNN- and RNN-based methods, combined with attention mechanisms or knowledge distillation [18,19], to obtain aspect-related information. Using attention mechanisms to precisely capture aspect-aware contextual information has thus become a main topic [2,3]. In recent years, GCN-based models have risen to prominence in a variety of NLP tasks, since they can alleviate the defects of attention networks. For ALSC, Ref. [6] first apply GCN to model syntactic dependencies and resolve long-range multi-word dependencies. Later work aims to establish the syntax structure and extract aspect-related features [7]. Ref. [12] reshape the syntax dependency tree into an aspect-oriented sub-tree to determine the connections between aspects and their opinion words. Ref. [20] fuse dependency types into GCN to highlight the syntax that matters for sentiment classification. There is an ongoing trend to combine sentence syntax and semantics [8,9,21]. Most such approaches construct separate adjacency matrices for syntactic and semantic information, generate the corresponding feature representations, and concatenate them for sentiment classification.

#### *2.2. Contrastive Learning*

A fundamental focus of contrastive learning is learning the alignment and uniformity of the given data [10]. Specifically, alignment indicates the similarity among positive examples, while uniformity refers to how uniformly the features are distributed, so that negative examples are separated from positive ones. In practice, both alignment and uniformity serve as criteria for optimizing feature learning; that is, capturing intra-class similarities and inter-class differences benefits downstream tasks.
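The two properties can be quantified with the alignment and uniformity losses commonly used for L2-normalized features. The NumPy sketch below assumes those standard forms with exponent `alpha = 2` and temperature `t = 2`, which are illustrative defaults rather than settings from this paper; lower values mean better-aligned positives and a more uniform spread, respectively.

```python
import numpy as np

def l2_normalize(x):
    return x / np.linalg.norm(x, axis=1, keepdims=True)

def alignment_loss(x, y, alpha=2):
    """Mean distance**alpha between positive pairs (row i of x with row i of y)."""
    return float(np.mean(np.linalg.norm(x - y, axis=1) ** alpha))

def uniformity_loss(x, t=2.0):
    """Log of the mean Gaussian potential exp(-t * ||x_i - x_j||^2) over distinct pairs."""
    sq_dist = np.sum((x[:, None, :] - x[None, :, :]) ** 2, axis=-1)
    i, j = np.triu_indices(x.shape[0], k=1)
    return float(np.log(np.mean(np.exp(-t * sq_dist[i, j]))))

spread = l2_normalize(np.array([[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0], [0.0, -1.0]]))
collapsed = l2_normalize(np.ones((4, 2)))
print(uniformity_loss(spread) < uniformity_loss(collapsed))  # spread features are more uniform
```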

Recently, a number of studies have applied contrastive learning to NLP tasks with satisfying results [22–24]. Ref. [22] devise a simple contrastive sentence embedding framework that produces superior sentence embeddings on semantic textual similarity tasks. To handle aspect words absent from the training set, Ref. [25] use contrastive learning to capture aspect-invariant and aspect-dependent features and thereby distinguish the roles of valuable sentiment features. Ref. [11] propose a novel contrastive-learning-based approach for text classification that simultaneously learns the features of input samples and the parameters of classifiers in the same space.

#### **3. Proposed Method**

Figure 3 shows the framework of the TCL Network. Let $X = \{x_1, x_2, \dots, x_a, \dots, x_{a+l_a}, \dots, x_N\}$ be an $N$-word sentence containing the aspect $A = \{x_a, \dots, x_{a+l_a}\}$, where $a$ is the starting index of $A$ and $l_a$ is the length of $A$. We feed the sentence into a GloVe [26] or BERT [27] encoder to obtain the sentence embedding. For the GloVe-based model, each word is mapped to a low-dimensional vector by looking up a pretrained word embedding matrix $E \in \mathbb{R}^{|V| \times d_E}$, where $|V|$ is the vocabulary size and $d_E$ is the dimension of the word vectors. The sentence embedding is $x = \{e_1, e_2, \dots, e_N\}$. The hidden states of the sentence are then extracted via a Bi-LSTM, giving the contextual feature matrix $H = \{h_1, h_2, \dots, h_N\}$ with $H \in \mathbb{R}^{N \times 2d}$, where $d$ is the hidden layer dimension. Alternatively, the sequence [*CLS*]*X*[*SEP*]*A*[*SEP*] can be fed into the BERT encoder to obtain $H$. Subsequently, $H$ is taken as the input of both the semantic-learning GCN module and the syntactic-aware GCN module. A multi-layer Biaffine unit is proposed to integrate the semantic and syntactic features. To further align the features from both spaces, the dual contrastive learning scheme is carried out. Each component is detailed as follows.

**Figure 3.** The overall architecture of our Triplet Contrastive Learning Network.

#### *3.1. Syntactic-Aware GCN Module*

The architecture of the syntactic-aware module is shown in Figure 4. As pointed out in the Introduction, the syntactic-aware GCN in our model aims to precisely capture the aspect-related context words and remove redundant information. Following [12], a relational graph attention network is devised. Specifically, we construct an aspect-oriented dependency tree to replace the adjacency matrix of the classical syntax dependency tree. The attention mechanism is then applied to the reshaped sub-tree to capture aspect-specific contextual features. Moreover, to resolve long-range dependencies among words, we treat four categories of words as the key phrases that contribute to sentiment delivery, i.e., nouns, verbs, adjectives and adverbs of degree. Contrastive learning is then performed to enhance the features of key phrases and effectively capture word features over long dependencies.

**Figure 4.** Architecture of syntactic-aware GCN module

#### 3.1.1. Relational Graph Attention Module

At this stage, the aspect *A* is taken as the central word to construct the aspect-oriented dependency tree (see Algorithm 1). For words one hop away from the central word in the syntax tree, the original dependency types are kept. For words *n* hops away ($n \ge 2$), the dependency type is set to (*con* : *n*). If the aspect contains multiple words, they are treated as a whole. In this manner, we obtain the reconstructed dependency sequence $D = \{dep_1, dep_2, \dots, dep_N\}$ and map it into an embedding space to generate the dependency representation $H_D = \{h_{D_1}, h_{D_2}, \dots, h_{D_N}\}$. Notably, a randomly initialized dependency embedding matrix $E_D \in \mathbb{R}^{|V_d| \times d_D}$ is employed, where $|V_d|$ is the number of dependency types. For $H_D \in \mathbb{R}^{N \times d_D}$, $d_D$ is the dimension of the dependency type embeddings.

The relational attention between aspect and the dependency type representation is computed. Specifically, the syntactic dependency of context toward the aspect is incorporated within *HD*. Thus, the attentive weight between *HD* and *H* is calculated using a simplified inner product operation, which is:

$$att = f\left(\frac{\left(W\_{\rm D}H\_{\rm D} + b\_{\rm D}\right) \times \left(W\_{\rm h}H + b\_{\rm h}\right)^{T}}{\sqrt{d\_{\rm m}}}\right) \tag{1}$$

where $W_D \in \mathbb{R}^{d_D \times d_m}$ and $W_h \in \mathbb{R}^{2d \times d_m}$ are linear layer weights; $b_D$ and $b_h$ are bias terms; $f(\cdot)$ is the softmax function; and $d_m$ is the hidden layer dimension of the attention module.

Then, the syntactic representation is given as:

$$H\_{syn} = \text{att} \ast H \tag{2}$$
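A NumPy sketch of Equations (1) and (2) is given below. It assumes row-vector conventions (the projections are written $H_D W_D$ and $H W_h$ so that they match the stated shapes $W_D \in \mathbb{R}^{d_D \times d_m}$ and $W_h \in \mathbb{R}^{2d \times d_m}$), a row-wise softmax, and that $\ast$ in Equation (2) denotes matrix multiplication; the sizes in the usage lines are arbitrary.

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)    # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def relational_attention(H_D, H, W_D, b_D, W_h, b_h):
    """Eqs. (1)-(2): H_D is N x d_D, H is N x 2d; returns H_syn (N x 2d) and att (N x N)."""
    d_m = W_D.shape[1]
    q = H_D @ W_D + b_D                        # N x d_m, projected dependency types
    k = H @ W_h + b_h                          # N x d_m, projected context vectors
    att = softmax(q @ k.T / np.sqrt(d_m), axis=-1)  # each row sums to 1
    return att @ H, att                        # Eq. (2): attention-weighted context

rng = np.random.default_rng(0)
N, d_D, d, d_m = 5, 8, 4, 6                    # arbitrary sizes for illustration
H_D = rng.standard_normal((N, d_D))
H = rng.standard_normal((N, 2 * d))
H_syn, att = relational_attention(H_D, H, rng.standard_normal((d_D, d_m)), np.zeros(d_m),
                                  rng.standard_normal((2 * d, d_m)), np.zeros(d_m))
```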

#### **Algorithm 1** Aspect-Oriented Dependency Tree

**Input**: sentence $X = \{x_1, x_2, \dots, x_N\}$, aspect $A = \{x_a, \dots, x_{a+l_a}\}$, dependency tree $T$, and dependency relations $R$.

**Output**: aspect-oriented dependency tree $\tilde{T}$.

1. Construct the aspect root $\tilde{R}$ for $\tilde{T}$
2. **for** $a$ to $a + l_a$ **do**
3. **for** $j = 1$ to $N$ **do**
4. **if** $x_j \notin A$ and $x_j \xrightarrow{R_{ja}} x_a$ **then**
5. $x_j \xrightarrow{R_{ja}} \tilde{R}$
6. **else if** $x_j \notin A$ and $x_j \xleftarrow{R_{ja}} x_a$ **then**
7. $x_j \xleftarrow{R_{ja}} \tilde{R}$
8. **else**
9. $n = \mathrm{distance}(a, j)$
10. $x_j \xrightarrow{n:con} \tilde{R}$
11. **end if**
12. **end for**
13. **end for**
14. **return** $\tilde{T}$
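The reshaping in Algorithm 1 can be sketched in Python as below. The head-index input format (as in CoNLL-style parses), the `con:n` label string (taken from the text), and the function name are assumptions; hop distances to the merged aspect root are computed with a breadth-first search over the undirected dependency tree.

```python
from collections import deque

def aspect_oriented_relations(heads, rels, aspect):
    """Hypothetical sketch of Algorithm 1 on a head-index dependency parse.

    heads[j]: index of the head of word j, or -1 for the sentence root.
    rels[j]:  label of the arc heads[j] -> j.
    aspect:   indices of the aspect words, merged into a single root.
    Returns {j: relation of word j to the aspect root} for every non-aspect word.
    """
    aspect = set(aspect)
    n = len(heads)
    neighbours = [[] for _ in range(n)]   # undirected tree, used for hop counts
    for j, h in enumerate(heads):
        if h >= 0:
            neighbours[j].append(h)
            neighbours[h].append(j)

    # Multi-source BFS: hop distance of every word to the merged aspect root.
    dist = {i: 0 for i in aspect}
    queue = deque(sorted(aspect))
    while queue:
        u = queue.popleft()
        for v in neighbours[u]:
            if v not in dist:
                dist[v] = dist[u] + 1
                queue.append(v)

    relations = {}
    for j in range(n):
        if j in aspect:
            continue
        if heads[j] in aspect:
            relations[j] = rels[j]        # arc aspect -> x_j: keep its label
        elif any(heads[i] == j for i in aspect):
            # arc x_j -> aspect: keep the label of that arc
            relations[j] = next(rels[i] for i in sorted(aspect) if heads[i] == j)
        else:
            relations[j] = f"con:{dist[j]}"  # two or more hops: virtual relation
    return relations

# "The price is reasonable" with aspect "price" (index 1).
print(aspect_oriented_relations([1, 3, 3, -1], ["det", "nsubj", "cop", "root"], [1]))
```

The sketch assumes a connected parse tree; a word unreachable from the aspect would raise a `KeyError` in the last branch.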

#### 3.1.2. Syntactic Contrastive Learning Scheme

The effect of key phrases (i.e., nouns, verbs, adjectives and adverbs of degree) is highlighted by a sentence-level key-phrase contrastive learning module. Specifically, a mask operation is performed based on the part-of-speech (POS) information of the words in the sentence. A representation is defined as a positive example if key phrases are assigned mask value 1 and all other words mask value 0, i.e., $M_{pos} \in \mathbb{R}^N$. Conversely, in a negative example, key phrases are assigned mask value 0 and all other words mask value 1, i.e., $M_{neg} \in \mathbb{R}^N$.

The dependency types are integrated into both positive and negative examples. We thus compute the positive-example dependency type representation and the negative-example dependency type representation as:

$$H\_{D\_{\text{pos}}} = H\_D \ast M\_{\text{pos}} \tag{3}$$

$$H_{D_{\rm neg}} = H_D \ast M_{\rm neg} \tag{4}$$
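The masks and Equations (3) and (4) can be sketched as follows. Identifying key phrases by Penn Treebank POS prefixes is an assumption, and it keeps all adverbs rather than only adverbs of degree; the tags in the usage lines are assigned by hand to the example sentence of Figure 1, and $\ast$ is read as row-wise masking.

```python
import numpy as np

# Assumed tag prefixes: nouns NN*, verbs VB*, adjectives JJ*, adverbs RB*.
KEY_POS_PREFIXES = ("NN", "VB", "JJ", "RB")

def key_phrase_masks(pos_tags):
    """Return (M_pos, M_neg): M_pos marks key phrases with 1, M_neg marks the other words."""
    m_pos = np.array([1.0 if tag.startswith(KEY_POS_PREFIXES) else 0.0 for tag in pos_tags])
    return m_pos, 1.0 - m_pos

def masked_dependency_features(H_D, mask):
    """Eqs. (3)-(4): zero the rows of H_D (N x d_D) whose mask value is 0."""
    return H_D * mask[:, None]

# "The price is reasonable although the service is poor"
tags = ["DT", "NN", "VBZ", "JJ", "IN", "DT", "NN", "VBZ", "JJ"]
M_pos, M_neg = key_phrase_masks(tags)
H_D = np.random.default_rng(0).standard_normal((len(tags), 4))
H_D_pos = masked_dependency_features(H_D, M_pos)
H_D_neg = masked_dependency_features(H_D, M_neg)
```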

Similar to Equation (1), the attention weights of *HDpos* and *HDneg* toward the context representation are available, as presented in Equations (5) and (6). Thus, the syntactic representation of both positive examples and negative examples can be obtained (Equations (7) and (8)):

$$att_{pos} = f\left(\frac{\left(W_{D_{pos}}H_{D_{pos}} + b_{D_{pos}}\right) \times \left(W_{h_{pos}}H + b_{h_{pos}}\right)^T}{\sqrt{d_m}}\right) \tag{5}$$

$$att_{neg} = f\left(\frac{\left(W_{D_{neg}}H_{D_{neg}} + b_{D_{neg}}\right) \times \left(W_{h_{neg}}H + b_{h_{neg}}\right)^T}{\sqrt{d_m}}\right) \tag{6}$$

$$H_{syn_{pos}} = att_{pos} \ast H \tag{7}$$

$$H_{syn_{neg}} = att_{neg} \ast H \tag{8}$$

For every sentence, we thus have its syntactic representation $H_{syn}$, the syntactic representation with key phrases $H_{syn_{pos}}$, and the syntactic representation without key phrases $H_{syn_{neg}}$. Each of these is fed into a shared-weight biaffine unit to be fused with the semantic representation, as described in the following section. The final syntactic representations, with semantic information integrated, are denoted $M_{syn}$ (derived from Equation (12)), $M_{syn_{pos}}$ and $M_{syn_{neg}}$, respectively.

Aiming to focus more on the key phrases, the contrastive learning scheme is carried out with the loss function set as:

$$\mathcal{L}_{con_{syn}} = -\frac{1}{B} \sum_{j=1}^{B} \frac{1}{N} \sum_{i=1}^{N} \log \frac{e^{\mathrm{sim}(M_{syn_{pos}}^{i}, M_{syn}^{i})/\tau_1}}{\sum_{t=1}^{N} \left(e^{\mathrm{sim}(M_{syn_{pos}}^{t}, M_{syn}^{i})/\tau_1} + e^{\mathrm{sim}(M_{syn_{neg}}^{t}, M_{syn}^{i})/\tau_1}\right)} \tag{9}$$

where $\tau_1$ is the temperature coefficient, $B$ is the batch size, and $N$ is the sentence length defined above.

Unlike existing contrastive learning approaches, besides the positive example $M_{syn_{pos}}^i$, all other examples, namely the $N - 1$ key-phrase-related syntactic representations $M_{syn_{pos}}^t$ ($t \neq i$) and the $N$ syntactic representations without key phrases $M_{syn_{neg}}$, are treated as negative examples. In other words, each word in the sentence has $2N - 1$ negative examples.
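Equation (9) is an InfoNCE-style objective; a NumPy sketch for one sentence, averaged over a batch, follows. The text does not define $\mathrm{sim}(\cdot,\cdot)$, so cosine similarity is assumed, and the temperature value is illustrative. As in Equation (9), the denominator sums over all $2N$ candidates, including the positive itself.

```python
import numpy as np

def cosine_sim(a, b):
    """Pairwise cosine similarity between the rows of a (n x k) and b (m x k)."""
    a = a / np.linalg.norm(a, axis=1, keepdims=True)
    b = b / np.linalg.norm(b, axis=1, keepdims=True)
    return a @ b.T

def syntactic_contrastive_loss(M_syn, M_pos, M_neg, tau=0.1):
    """Eq. (9) for one sentence: anchors M_syn^i, positives M_pos^i; inputs are N x k."""
    s_pos = cosine_sim(M_pos, M_syn) / tau     # s_pos[t, i] = sim(M_pos^t, M_syn^i) / tau
    s_neg = cosine_sim(M_neg, M_syn) / tau     # s_neg[t, i] = sim(M_neg^t, M_syn^i) / tau
    logits = np.vstack([s_pos, s_neg])         # 2N candidates per anchor (one column each)
    m = logits.max(axis=0)                     # log-sum-exp shift for stability
    log_denom = m + np.log(np.exp(logits - m).sum(axis=0))
    return -np.mean(np.diag(s_pos) - log_denom)  # average over the N words

def batch_loss(batch, tau=0.1):
    """Average Eq. (9) over a batch of (M_syn, M_pos, M_neg) triples."""
    return float(np.mean([syntactic_contrastive_loss(*t, tau=tau) for t in batch]))

e = np.eye(6)
good = (e[:3], e[:3], e[3:])   # positives coincide with the anchors
bad = (e[:3], e[3:], e[:3])    # "negatives" coincide with the anchors
print(batch_loss([good]), batch_loss([bad]))
```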

#### *3.2. Semantic-Learning GCN Module*

The sentence semantics is also encoded via GCN to enhance the modelling of sentiment information. Since the self-attention mechanism can extract the semantic relevance between a given word and the other words, we use a self-attention network to construct a semantic adjacency matrix $A^{\rm sem} \in \mathbb{R}^{N \times N}$:

$$A^{\rm sem} = f\left(\frac{Q\mathcal{W}^q \times \left(K\mathcal{W}^k\right)^T}{\sqrt{d}}\right) \tag{10}$$

where both *Q* and *K* equal the context representation *H*, *W<sup>q</sup>* and *W<sup>k</sup>* are trainable weighting parameters and *d* is the hidden layer size of attention network.

The semantic representation is derived from graph convolution, which is:

$$H_{\rm sem} = \sigma(A^{\rm sem}WH + b) \tag{11}$$

where *σ*(·) denotes a nonlinear activation function, such as ReLU.

#### *3.3. Biaffine Unit*

The interaction between semantic and syntactic information is realized via a multi-layer mutual Biaffine transformation. In Equation (12), *H*<sub>syn</sub> and *H*<sub>sem</sub> are first multiplied through the weight matrix *W*<sub>1</sub> to obtain a syntax-related matrix that contains semantic information. This matrix is then normalized via Softmax and multiplied by the original semantic representation, yielding a syntactic feature representation with semantic information integrated. By stacking multiple Biaffine layers, the semantic features are progressively fused into the syntactic representation used for sentiment polarity classification. Equation (13) works symmetrically, fusing syntactic features into the semantic representation.

$$H\_{\rm syn}^{(l)} = f\left(H\_{\rm syn}^{(l-1)} W\_1^{(l-1)} \left(H\_{\rm sem}^{(l-1)}\right)^T\right) H\_{\rm sem}^{(l-1)} \tag{12}$$

$$H\_{\rm sem}^{(l)} = f\left(H\_{\rm sem}^{(l-1)} W\_2^{(l-1)} \left(H\_{\rm syn}^{(l-1)}\right)^T\right) H\_{\rm syn}^{(l-1)} \tag{13}$$

where *l* (*l* = 1, 2, ...) is the layer index of the biaffine unit, and *W*<sub>1</sub> ∈ ℝ<sup>2*d*×2*d*</sup> and *W*<sub>2</sub> ∈ ℝ<sup>2*d*×2*d*</sup> are learnable parameters. The inputs of the biaffine unit are *H*<sup>(0)</sup><sub>syn</sub> = *H*<sub>syn</sub> ∈ ℝ<sup>*N*×2*d*</sup> and *H*<sup>(0)</sup><sub>sem</sub> = *H*<sub>sem</sub> ∈ ℝ<sup>*N*×2*d*</sup>.
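The mutual Biaffine update in Equations (12) and (13) can be sketched as below, with *f* taken to be the row-wise Softmax described above. Both updates at layer *l* read the layer *l* − 1 inputs, so the new values are computed before either is overwritten. Plain lists stand in for tensors, and the function names are hypothetical.

```python
import math
from typing import List, Tuple

Matrix = List[List[float]]

def matmul(a: Matrix, b: Matrix) -> Matrix:
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]

def transpose(a: Matrix) -> Matrix:
    return [list(col) for col in zip(*a)]

def softmax_rows(a: Matrix) -> Matrix:
    out = []
    for row in a:
        m = max(row)
        exps = [math.exp(x - m) for x in row]
        s = sum(exps)
        out.append([e / s for e in exps])
    return out

def biaffine(h_syn: Matrix, h_sem: Matrix,
             w1_layers: List[Matrix], w2_layers: List[Matrix]) -> Tuple[Matrix, Matrix]:
    """Stacked mutual Biaffine layers; returns (M_syn, M_sem)."""
    for w1, w2 in zip(w1_layers, w2_layers):
        # Equation (12): the syntactic rows attend over the semantic rows.
        new_syn = matmul(softmax_rows(matmul(matmul(h_syn, w1), transpose(h_sem))), h_sem)
        # Equation (13): the semantic rows attend over the syntactic rows.
        new_sem = matmul(softmax_rows(matmul(matmul(h_sem, w2), transpose(h_syn))), h_syn)
        h_syn, h_sem = new_syn, new_sem  # both used layer l-1 inputs
    return h_syn, h_sem

# Usage: N = 3 tokens, 2d = 2, two layers (the layer number used in Section 4.1).
I2 = [[1.0, 0.0], [0.0, 1.0]]
syn0 = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]
sem0 = [[0.0, 1.0], [1.0, 0.0], [1.0, 1.0]]
m_syn, m_sem = biaffine(syn0, sem0, [I2, I2], [I2, I2])
```

Each output row is a convex combination of the other channel's rows, which is how features from one space are injected into the other at every layer.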

After the mutual Biaffine transformation, we obtain the final semantic representation with fused syntactic features, *H*<sup>(*l*)</sup><sub>sem</sub> (also denoted as *M*<sub>sem</sub>), and the final syntactic representation with fused semantic features, *H*<sup>(*l*)</sup><sub>syn</sub> (also denoted as *M*<sub>syn</sub>). Average pooling is then performed over the aspect tokens of both outputs:

$$M\_{\rm sem}^A = \text{avgpool}\left(M\_{{\rm sem}\_{a}}, \dots, M\_{{\rm sem}\_{a+l\_a}}\right) \tag{14}$$

$$M\_{\rm syn}^A = \text{avgpool}\left(M\_{{\rm syn}\_{a}}, \dots, M\_{{\rm syn}\_{a+l\_a}}\right) \tag{15}$$

Then, the semantic and syntactic representations of the aspect are concatenated and fed into a linear classifier to determine the sentiment polarity of the given aspect:

$$Z = f\left(W\left[M\_{\rm sem}^A; M\_{\rm syn}^A\right] + b\right) \tag{16}$$

where [; ] denotes vector concatenation, and *W* and *b* are the learnable parameters of the linear layer.
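Equations (14) to (16) amount to pooling the aspect rows of both channels, concatenating the two pooled vectors and applying a linear layer. In this sketch the aspect is passed as a list of token indices (so the exact end-point convention of *a* + *l*<sub>*a*</sub> is left to the caller), and *f* is assumed to be Softmax over the three polarities; the names are hypothetical.

```python
import math
from typing import List, Sequence

def avgpool(rows: List[List[float]], aspect_idx: Sequence[int]) -> List[float]:
    """Mean of the representation rows at the aspect positions (Eqs. (14)-(15))."""
    dim = len(rows[0])
    return [sum(rows[i][c] for i in aspect_idx) / len(aspect_idx) for c in range(dim)]

def softmax(x: List[float]) -> List[float]:
    m = max(x)
    exps = [math.exp(v - m) for v in x]
    s = sum(exps)
    return [e / s for e in exps]

def classify(m_sem: List[List[float]], m_syn: List[List[float]],
             aspect_idx: Sequence[int], w: List[List[float]], b: List[float]) -> List[float]:
    """Z = softmax(W [M_sem^A ; M_syn^A] + b), with softmax assumed for f (Eq. (16))."""
    features = avgpool(m_sem, aspect_idx) + avgpool(m_syn, aspect_idx)  # list concat = [;]
    logits = [sum(wi * xi for wi, xi in zip(row, features)) + bi for row, bi in zip(w, b)]
    return softmax(logits)

# Usage: a 3-token sentence whose aspect spans tokens 1 and 2.
M_sem = [[0.0, 0.0], [1.0, 3.0], [3.0, 1.0]]
M_syn = [[0.0, 0.0], [2.0, 2.0], [0.0, 0.0]]
W = [[0.0] * 4 for _ in range(3)]  # 3 polarities x 4 concatenated features
bias = [0.0, 0.0, 1.0]
probs = classify(M_sem, M_syn, [1, 2], W, bias)
```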

#### *3.4. Dual Contrastive Learning Scheme*

In the proposed model, the main purpose of the dual contrastive learning is to comprehensively align the features of the syntactic space and the semantic space, so that global syntactic and semantic features can be captured. Notably, the outputs of the biaffine unit (i.e., *M*<sub>syn</sub> and *M*<sub>sem</sub>) are taken as the input of the dual contrastive learning module. For each input *X*<sub>*i*</sub>, the sequences within the same batch that have the same sentiment polarity are considered positive examples 𝒫, and the others are considered negative examples 𝒩. The loss functions of the dual contrastive learning are:

$$\mathcal{L}\_{\text{syn}-\text{sem}} = -\frac{1}{B} \sum\_{i=1}^{B} \frac{1}{|\mathcal{P}|} \sum\_{j \in \mathcal{P}} \log \frac{e^{\text{sim}(M\_{\text{syn}\_i}, M\_{\text{sem}\_j}) / \tau\_2}}{\sum\_{t=1}^{B} e^{\text{sim}(M\_{\text{syn}\_i}, M\_{\text{sem}\_t}) / \tau\_2}} \tag{17}$$

$$\mathcal{L}\_{\text{sem}-\text{syn}} = -\frac{1}{B} \sum\_{i=1}^{B} \frac{1}{|\mathcal{P}|} \sum\_{j \in \mathcal{P}} \log \frac{e^{\text{sim}(M\_{\text{sem}\_i}, M\_{\text{syn}\_j}) / \tau\_3}}{\sum\_{t=1}^{B} e^{\text{sim}(M\_{\text{sem}\_i}, M\_{\text{syn}\_t}) / \tau\_3}} \tag{18}$$

where *τ*<sub>2</sub> and *τ*<sub>3</sub> are the temperature coefficients of the contrastive losses.
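Equations (17) and (18) share one form and differ only in which channel supplies the anchors. The sketch below implements that form once, assuming cosine similarity for sim(·,·) and taking the positive set 𝒫 of anchor *i* to be every batch item with the same sentiment label (which includes *i* itself). It is an illustration under those assumptions, not the reference code, and the function names are hypothetical.

```python
import math
from typing import List, Sequence

Vector = List[float]

def cosine(u: Sequence[float], v: Sequence[float]) -> float:
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / (norm + 1e-12)

def cross_view_loss(anchors: List[Vector], targets: List[Vector],
                    labels: List[int], tau: float) -> float:
    """-1/B sum_i 1/|P| sum_{j in P} log softmax_j(sim(a_i, t_.) / tau)."""
    b = len(anchors)
    total = 0.0
    for i in range(b):
        logits = [cosine(anchors[i], targets[t]) / tau for t in range(b)]
        m = max(logits)  # log-sum-exp trick for numerical stability
        log_denom = m + math.log(sum(math.exp(x - m) for x in logits))
        positives = [j for j in range(b) if labels[j] == labels[i]]
        total += -sum(logits[j] - log_denom for j in positives) / len(positives)
    return total / b

def dual_contrastive(m_syn, m_sem, labels, tau2=0.1, tau3=0.1):
    """Returns (L_syn-sem, L_sem-syn) from Equations (17) and (18)."""
    return (cross_view_loss(m_syn, m_sem, labels, tau2),
            cross_view_loss(m_sem, m_syn, labels, tau3))

# Usage: four sentences, two per polarity.
labels = [0, 0, 1, 1]
syn = [[1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0]]
aligned_sem = [[1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0]]
flipped_sem = [[0.0, 1.0], [0.0, 1.0], [1.0, 0.0], [1.0, 0.0]]
```

When the two channels agree by polarity the loss approaches log |𝒫| (here log 2), its floor for this batch; when same-polarity vectors are pushed apart it grows sharply at *τ* = 0.1.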

#### *3.5. Loss Function*

The overall loss function for model training is:

$$\begin{split} \mathcal{L} &= \mathcal{L}\_{\text{CE}} + \alpha \mathcal{L}\_{o} + \beta \left( \mathcal{L}\_{\text{syn}-\text{sem}} + \mathcal{L}\_{\text{sem}-\text{syn}} \right) \\ &+ \gamma \mathcal{L}\_{con\_{syn}} + \lambda \left\| \Theta \right\| \end{split} \tag{19}$$

with

$$\mathcal{L}\_{o} = \left\| A^{\text{sem}} \left(A^{\text{sem}}\right)^T - I \right\|\_F \tag{20}$$

where *α*, *β* and *γ* are hyperparameters; ℒ<sub>CE</sub> is the cross-entropy loss for sentiment polarity classification; Θ denotes the set of trainable parameters; and *λ* is the L2 regularization coefficient. Following [8], the attention distribution of each word over the other words in the sentence should be distinct; in other words, the overlap of attention weights should be minimized, especially because these weights form the semantic graph adjacency matrix. Therefore, an additional orthogonal regularization loss ℒ<sub>o</sub> is introduced. In Equation (20), *I* is the identity matrix and the subscript *F* denotes the Frobenius norm.
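The orthogonal regularizer of Equation (20) and the weighted sum of Equation (19) can be computed as in the sketch below. The individual loss values are passed in as plain numbers and ‖Θ‖ is assumed to be precomputed; the default weights are the Rest14 values reported in Section 4.1. The function names are hypothetical.

```python
import math
from typing import List

Matrix = List[List[float]]

def orthogonal_loss(a: Matrix) -> float:
    """L_o = || A A^T - I ||_F for a square attention matrix A (Equation (20))."""
    n = len(a)
    total = 0.0
    for i in range(n):
        for j in range(n):
            aat_ij = sum(a[i][k] * a[j][k] for k in range(len(a[i])))  # (A A^T)_ij
            identity_ij = 1.0 if i == j else 0.0
            total += (aat_ij - identity_ij) ** 2
    return math.sqrt(total)

def total_loss(l_ce: float, l_o: float, l_syn_sem: float, l_sem_syn: float,
               l_con_syn: float, theta_norm: float,
               alpha: float = 0.1, beta: float = 0.5, gamma: float = 0.5,
               lam: float = 1e-4) -> float:
    """Equation (19); defaults are the Rest14 settings for alpha, beta, gamma, lambda."""
    return (l_ce + alpha * l_o + beta * (l_syn_sem + l_sem_syn)
            + gamma * l_con_syn + lam * theta_norm)

# Usage: one-hot rows are mutually orthogonal, uniform rows overlap completely.
print(orthogonal_loss([[1.0, 0.0], [0.0, 1.0]]))  # 0.0
print(orthogonal_loss([[0.5, 0.5], [0.5, 0.5]]))  # 1.0
```

The two usage cases show the intent of the regularizer: it is zero when each word attends to a different position and grows as attention rows overlap.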

Since the contrastive losses are computed from the model's weight parameters, these parameters can be optimized by backpropagation when the overall loss function is minimized.

#### **4. Experiments**

#### *4.1. Datasets and Settings*

**Datasets:** We evaluate the TCL network on three benchmark datasets: Rest14 and Lap14 from SemEval 2014 Task 4 [28], and Twitter [29]. Each sample in these datasets is a review sentence or a tweet that contains explicit aspect terms and their sentiment polarities. Each aspect is labeled as positive, neutral or negative. Details of each dataset are shown in Table 1.

**Experimental Settings:** For the GloVe-based model, we initialize the word embeddings with 300-dimensional vectors pretrained by GloVe [26]. The dimension of the dependency syntactic embeddings is set to 30, and the hidden layer dimension of the BiLSTM is 50. All weights in the model are initialized from a Xavier uniform distribution. The number of biaffine layers is set to 2. For the contrastive learning scheme, the temperature coefficient determines how much attention the contrastive loss assigns to hard negative samples: the larger the temperature coefficient, the greater the tolerance to negative samples, and vice versa. In the syntactic contrastive learning module, more attention should be given to key phrases while keeping a certain tolerance to other words. Therefore, *τ*<sub>1</sub> of the syntactic contrastive learning is set to 1, while *τ*<sub>2</sub> and *τ*<sub>3</sub> of the dual contrastive learning are set to 0.1. In addition, the Adam optimizer is adopted with a learning rate of 2 × 10<sup>−3</sup>. The batch size ranges from 16 to 64. The L2 regularization coefficient *λ* is set to 1 × 10<sup>−4</sup>. Notably, the values of *α*, *β* and *γ* vary with the dataset: 0.1, 0.5 and 0.5 for Rest14; 0.5, 0.7 and 0.8 for Lap14; and 0.2, 0.2 and 0.7 for Twitter.
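The effect of the temperature can be checked numerically. In an InfoNCE-style loss such as Equations (9), (17) and (18), the gradient each negative receives is proportional to exp(sim/*τ*), so a smaller *τ* concentrates the weight on the most similar (hardest) negatives. The similarity values below are made up for illustration, and the function name is hypothetical.

```python
import math
from typing import List

def negative_weights(sims: List[float], tau: float) -> List[float]:
    """Relative weight of each negative, proportional to exp(sim / tau)."""
    m = max(sims)
    exps = [math.exp((s - m) / tau) for s in sims]  # shift by max for stability
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical anchor-to-negative cosine similarities, hardest negative first.
sims = [0.9, 0.5, 0.1]
for tau in (1.0, 0.1):  # tau_1 = 1 versus tau_2 = tau_3 = 0.1 in the settings above
    print(tau, [round(w, 3) for w in negative_weights(sims, tau)])
```

At *τ* = 1 the weight is spread fairly evenly across the three negatives (more tolerance), while at *τ* = 0.1 almost all of it falls on the hardest one, matching the rationale for the two settings.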


**Table 1.** Statistics of datasets.

#### *4.2. Baselines*

To validate the effectiveness of the proposed model in ALSC, we compare it with 10 state-of-the-art methods:


9. **DR-BERT [33]** The Dynamic Re-weighting Adapter is proposed to encourage the model to better understand aspect-aware sentiment through

#### *4.3. Experimental Results and Analysis*

We use two metrics, accuracy and Macro-F1, to evaluate the performance of the proposed model. The experimental results of 13 different methods are presented in Table 2. Compared with the state-of-the-art methods, the TCL network performs best on most datasets, with a considerable performance gap between the proposed model and the baselines. According to Table 2, models using BERT-based embeddings perform better than those using GloVe-based embeddings. Moreover, the use of GCNs substantially contributes to the encoding of sentence syntax and semantics. In our model, the effective use of syntactic information highlights the contextual words related to the aspect, so that more attention weight is given to words that contribute to sentiment expression. Compared with single-channel GCNs (i.e., [6,12,30]), dual-channel GCN methods (i.e., [8]), which handle both syntactic and semantic information, show their superiority in ALSC tasks. Our model not only integrates different types of features but also exploits global information to further improve the sentiment classification results.

**Table 2.** Experimental results. Bold numbers represent the best results among methods of the same type.


However, the TCL network does not outperform DR-BERT on Rest14. A possible explanation is that the proportions of samples with different sentiments differ significantly in the Rest14 dataset; since positive and negative samples are generated by random sampling, this imbalance affects the performance of the contrastive learning scheme.

#### *4.4. Ablation Study*

An ablation study is carried out on the three datasets to investigate the importance of the contrastive learning losses; see Table 3. The dual contrastive learning scheme involves the syntactic-based semantic learning loss ℒ<sub>sem−syn</sub> and the semantic-based syntactic learning loss ℒ<sub>syn−sem</sub>. The results show that removing both loss functions leads to the largest drop, which suggests that exploiting global features within the mini-batch benefits sentiment classification. The contribution of ℒ<sub>sem−syn</sub> is slightly higher than that of ℒ<sub>syn−sem</sub>, which indicates the effectiveness of semantic alignment. By contrast, the contribution of ℒ<sub>con<sub>syn</sub></sub> in the syntactic learning module is relatively small, but removing it still results in an average decrease of 1.2% in accuracy.


**Table 3.** Ablation study results. Bold numbers represent the best results.

#### *4.5. Case Study*

Four ALSC examples are presented in Figure 5. Aspect words in green, blue and red represent positive, neutral and negative sentiment polarities, respectively. The first case is a sentence with simple syntax and semantics, and all three models correctly identify its sentiment as negative. Sentence 2 contains multiple aspects. ASGCN fails to determine the sentiment of the aspect '*disc drive*', because '*disc drive*' is syntactically close to the negation word '*not*'. Similarly, in sentence 3, the aspect '*apple OS*' has a long-distance dependency on its opinion word, which leads ASGCN to misclassify its sentiment. By contrast, DualGCN, which integrates both syntactic and semantic information, correctly classifies the sentiment toward the aspect '*apple OS*'. In the last sentence, despite its complex syntax and semantics, the TCL network identifies the sentiment polarities of all aspects. Triplet contrastive learning effectively aligns semantic and syntactic features, which indicates its efficacy for ALSC on complex sentences.


**Figure 5.** Case study. ALSC results of TCL, ASGCN and DualGCN on test examples, along with their predictions and the corresponding gold labels. The marker and indicate the correct classification and incorrect classification, respectively.

#### *4.6. Visualization*

#### 4.6.1. Comparison of Syntactic and Semantic Vectors

We visualize the distribution of semantic and syntactic representations to verify the effectiveness of the dual contrastive learning scheme. Figure 6 shows the semantic and syntactic outputs of the dual contrastive learning module, visualized with the t-SNE algorithm [34]. To facilitate comparison, only data with positive and negative sentiment polarities are visualized. Both the full TCL network and TCL without dual contrastive learning can distinguish sentiment polarities within a single type of representation. However, the model without dual contrastive learning fails to bring together the two types of vectors with the same sentiment polarity (see, e.g., the distribution of the red dots), which indicates the importance of aligning the semantic and syntactic spaces. Moreover, vectors with different sentiment polarities overlap substantially; that is, uniformity between syntax and semantics is absent. In comparison, the TCL network considers both the alignment and the uniformity of features. With the dual contrastive learning scheme, representations of the same polarity are more concentrated, and the overlap between representations of different polarities is largely reduced.

**Figure 6.** Visualization of semantic and syntactic vectors. Triangle dots represent syntactic vectors; round dots represent semantic vectors; dots in red represent positive samples; dots in green represent negative samples.

#### 4.6.2. Sentiment Classification Visualization

Similarly, the results of triplet contrastive learning are also visualized; see Figure 7. For ASGCN, which exploits only syntactic features, the neutral samples can be distinguished from those of the other two sentiment polarities. However, separating positive from negative samples is challenging, and many samples are misclassified. Since DualGCN handles both syntactic and semantic information, the samples of the three sentiment polarities are better discriminated, although the distribution of neutral samples is still not distinctive, especially compared with the negative samples. By contrast, our model shows a clear advantage in sentiment classification: samples with the same sentiment are more concentrated. Owing to the introduction of triplet contrastive learning, better feature learning and sentiment classification performance can be expected.

**Figure 7.** Visualization of sentiment classification results. The dots in green, red and blue represent the positive, neutral and negative samples, respectively.

#### **5. Conclusions**

In this work, a TCL network is developed for ALSC tasks, which not only exploits global information but also aligns semantics and syntax. First, an aspect-oriented dependency tree is constructed by reshaping the syntactic adjacency matrix. Then, sentence-level contrastive learning is applied to highlight the contribution of key phrases to sentiment expression. Two GCNs are employed to encode the syntactic and semantic information, respectively. A dual contrastive learning scheme is proposed to align the features from the syntactic and semantic spaces. Experiments on three benchmark datasets show that our method outperforms the state-of-the-art methods on most of them for the task of ALSC.

**Author Contributions:** Conceptualization, H.X. and Y.X.; methodology, H.X.; formal analysis, H.X. and Z.Y.; writing—original draft preparation, H.X. and Z.Y.; writing—review and editing, Y.X. and H.Z.; supervision, Y.X., Z.H. and H.Z.; funding acquisition, H.Z. and Z.H. All authors have read and agreed to the published version of the manuscript.

**Funding:** This work was supported by the Characteristic Innovation Projects of Guangdong Colleges and Universities (Nos. 2018KTSCX049), the Science and Technology Plan Project of Guangzhou under Grant Nos. 202102080258 and 201903010013.

**Institutional Review Board Statement:** Not applicable.

**Informed Consent Statement:** Not applicable.

**Data Availability Statement:** Not applicable.

**Conflicts of Interest:** The authors declare no conflict of interest.

#### **References**


### *Article* **Knowledge-Enhanced Dual-Channel GCN for Aspect-Based Sentiment Analysis**

**Zhengxuan Zhang 1, Zhihao Ma 2, Shaohua Cai 3,\*, Jiehai Chen <sup>1</sup> and Yun Xue <sup>1</sup>**


**\*** Correspondence: caishaohua@m.scnu.edu.cn

**Abstract:** As a subtask of sentiment analysis, aspect-based sentiment analysis (ABSA) aims to identify the sentiment polarity of a given aspect. State-of-the-art ABSA models use graph neural networks to process the semantics and syntax of the sentence. These methods face two challenges. First, semantic-based graph convolutional networks fail to capture the relation between the aspect and its opinion words. Second, little attention is assigned to the aspect word during graph convolution, which introduces contextual noise. In this work, we propose a knowledge-enhanced dual-channel graph convolutional network. For the ABSA task, a semantic-based graph convolutional network (GCN) and a syntactic-based GCN are established. For semantic learning, the sentence semantics are enhanced with commonsense knowledge. A multi-head attention mechanism is used to construct the semantic graph and filter the noise, which facilitates the information aggregation between the aspect and the opinion words. For syntactic information processing, the syntax dependency tree is pruned to remove irrelevant words, and more attention weight is then given to the aspect words. Experiments on four benchmark datasets show that our model significantly outperforms the baseline models, verifying its effectiveness in ABSA tasks.

**Keywords:** aspect-based sentiment analysis; graph convolutional networks; commonsense knowledge graph

**MSC:** 18C50

#### **1. Introduction**

Aspect-based sentiment analysis (ABSA) is a sentiment classification task that aims to identify the sentiment of given aspects [1]. Within ABSA, the sentiment of each aspect is classified into a predefined set of sentiment polarities, i.e., positive, neutral or negative [2]. ABSA yields fine-grained sentiment information, which is useful for applications in a variety of domains [3].

With the advance of deep neural networks, state-of-the-art ABSA methods report high accuracy and strong robustness on benchmark datasets. Research on ABSA has generally progressed in two directions: enhancing the significant information in the given text, and filtering out irrelevant information and its impact. A major step toward the comprehension of semantic information is the integration of the attention mechanism with deep neural networks [4–6], in which more attention weight is assigned to aspect-related words before the sentiment polarity is classified. Nevertheless, it can be challenging for attention-based models to capture the syntactic dependencies between the aspect and its context. More recently, graph neural networks (GNNs) have been used to process the syntactic information from dependency trees, which helps prevent syntactically irrelevant contextual noise [7–9]. Widely used GNNs, such as graph convolutional networks (GCNs) and graph attention networks (GATs), are capable of encoding both semantics and syntax, and incorporating both syntactic and semantic information into GNN-based models has become an ongoing trend [10–12].

**Citation:** Zhang, Z.; Ma, Z.; Cai, S.; Chen, J.; Xue, Y. Knowledge-Enhanced Dual-Channel GCN for Aspect-Based Sentiment Analysis. *Mathematics* **2022**, *10*, 4273. https://doi.org/10.3390/math10224273

Academic Editors: Jianping Gou, Weihua Ou, Shaoning Zeng and Lan Du

Received: 13 October 2022 Accepted: 11 November 2022 Published: 15 November 2022

**Publisher's Note:** MDPI stays neutral with regard to jurisdictional claims in published maps and institutional affiliations.

**Copyright:** © 2022 by the authors. Licensee MDPI, Basel, Switzerland. This article is an open access article distributed under the terms and conditions of the Creative Commons Attribution (CC BY) license (https://creativecommons.org/licenses/by/4.0/).

Despite the joint exploitation of syntax and semantics, two main limitations remain:

(1) GNNs are generally used to process global syntactic information, while a mask operation is performed only at the last step to conceal the context words before the sentiment of the aspect is determined. In practice, this can introduce contextual noise and results in too little importance being given to the aspect words.

(2) Semantic-based GNNs are typically built on attention weights. Owing to the subtle relationship between aspects and opinion words, more attention may be assigned to other words than to the sentiment words, which further confuses sentiment aggregation. As presented in Figure 1, in the sentence *'Meal was very expensive for what you get'*, the attention between the aspect *'meal'* and its opinion word *'expensive'* is weak.

[Figure 1 example sentences, each shown with predicted and desirable attention: "The **menu** may be small, but everything on it is delicious." and "**Meal** was very expensive for what you get."]

**Figure 1.** Attention weights towards aspects. Words in black bold are aspects; words with a blue background are predicted attention weights; words with a green background represent desirable attention distribution. A word in the darker color indicates a greater weight and vice versa.

For the ABSA task, this work establishes a Knowledge-Enhanced Dual-Channel Graph Convolutional Network (KDGCN). Two GCN-based modules, referred to as the syntax-based GCN and the semantic-based GCN, are developed to deal separately with the syntactic structure and the semantic information. On the one hand, the syntactic dependency tree of the sentence is pruned to remove connections of minor relevance to the aspect, so that aspect-oriented syntactic information is sent to the syntax-based GCN. In addition, position information and attention mechanisms are used to highlight the importance of the aspect. On the other hand, external knowledge is introduced to enhance the semantic-based GCN. The word sentiment vectors, together with supplementary words for the aspect, are obtained from SenticNet (a commonsense knowledge base); see Figure 2. A multi-head attention mechanism is then used to re-assign the attention weights among words, so that the sentiment of the opinion words can be aggregated to the aspect via the knowledge-enhanced semantic-based GCN.

[Figure 2 example sentence: "It has so much more **speed** and the screen is very sharp."]

**Figure 2.** Sentiment vectors and aspect supplementary based on SenticNet. The different colors and shades represent the emotional polarity score of the word in SenticNet, where −1 is negative and 1 is positive.

A number of studies leverage commonsense knowledge to enhance sentiment expression and classify the sentiment polarity of the aspect [13,14]. Commonsense knowledge provides background information about the entities under discussion; it is stored in commonsense knowledge bases, such as ConceptNet [15], SenticNet [16] and WordNet [17], and retrieved for processing. In most cases, integrating semantic-related commonsense knowledge can introduce noise from the external information. Our model aims to exploit sentence-related external knowledge: not just the sentiment information of each word, but also knowledge related to the aspect. In this manner, the input of the semantic-based GCN is distilled, so that more relevant information is preserved while noise is removed. The contributions of this paper are threefold and are summarized as follows:


The paper is divided into six sections. In the Introduction, we summarize the content of the article and propose our solutions to the challenges of the current ABSA task. Section 2 summarizes research related to our work. Section 3 introduces the proposed model and each of its modules in detail. Section 4 presents experiments on four public datasets, including ablation experiments. Section 5 further analyzes the model and the experimental results. Section 6 concludes the paper.

#### **2. Related Work**

#### *2.1. Aspect-Based Sentiment Analysis*

As pointed out in the Introduction, ABSA is a fine-grained sentiment classification task. Rather than assigning an overall sentiment polarity to a sentence or a document, ABSA aims to precisely determine the sentiment of a certain aspect. Early methods usually rely on handcrafted features for prediction, which cannot model the dependency relationship between the aspect and its context [18–20].

In recent years, advances in deep learning have significantly improved the performance of ABSA, and more detailed analyses of textual information have emerged [21,22]. Integrating an attention mechanism into deep neural networks highlights the contribution of opinion words toward the aspects [4–6,23–25], and the relationship between an aspect and its opinion words is reliably modeled in attention-based networks. Wang et al. [4] proposed an attention-based long short-term memory (LSTM) method to obtain more information related to a given aspect. Chen et al. [5] devised a hierarchical multi-attention model to address the long-range dependency between the aspect and the opinion words. However, attention mechanisms fail to capture sentence syntax; by contrast, GCNs take advantage of the syntactic dependencies between the aspect and the opinion words. Specifically, an adjacency matrix is formed based on the syntactic dependency tree, which is then used by the GCN to aggregate sentiment information to the aspect [7,8]. Wang et al. [9] eliminated the noise from irrelevant contexts by constructing an aspect-oriented syntactic dependency tree and then encoded the syntactic relations with a GNN. More recently, multi-channel GCN modules have been proposed to resolve both the syntax and the semantics of the given sentence, which effectively improves ABSA results.

#### *2.2. Graph Convolutional Networks*

As a classical variant of GNN, GCN was originally proposed by Kipf et al. [26] in 2017. So far, GCN has shown its superiority in diversified NLP tasks, such as text classification [27,28], relation extraction [29,30], knowledge distillation [31] and machine translation [32].

Most studies [7,8] use GCNs to capture the syntactic information of a sentence, where nodes represent words and edges indicate dependencies, so that the representation vector of each node is induced from the features of its neighborhood. Likewise, the semantic relations within a sentence can also be obtained using GCNs. In [10,11], the semantic graph was constructed with edges representing attention weights. Therefore, both semantic and syntactic features can be extracted via GCN-based modules.

Treating a graph as structured data, multiple GCN layers are responsible for information propagation, so that every node within the graph can learn global information. Let *G* = (*V*, *E*) denote a *k*-node graph, where *V* = {*v*<sub>1</sub>, *v*<sub>2</sub>, ..., *v*<sub>*k*</sub>} is the set of *k* = |*V*| nodes, *E* is the set of edges, and *A* ∈ ℝ<sup>*k*×*k*</sup> is the adjacency matrix. In the graph, *v*<sub>*i*</sub> ∈ *V* denotes a node and *e*<sub>*ij*</sub> = (*v*<sub>*i*</sub>, *v*<sub>*j*</sub>) ∈ *E* denotes an edge between *v*<sub>*i*</sub> and *v*<sub>*j*</sub>.

A single GCN layer can only capture information about a node's immediate neighbors; information from larger neighborhoods can be integrated when multiple GCN layers are stacked. We define *h*<sub>*i*</sub><sup>*l*</sup> as the output of node *i* at the *l*-th layer and *h*<sub>*i*</sub><sup>0</sup> as the initial state of node *i*. The graph convolution of node *i* can be written as:

$$h\_i^l = \sigma\left(\sum\_{j=1}^{k} A\_{ij} W^l h\_j^{l-1} + b^l\right) \tag{1}$$

where *W*<sup>*l*</sup> is the weight of the linear transformation, *b*<sup>*l*</sup> is the bias and *σ* is a nonlinear function such as ReLU.
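Equation (1) can be sketched node by node as below. This is a minimal illustration assuming an unnormalized adjacency matrix with self-loops (so that each node keeps its own state), ReLU as *σ*, and plain lists instead of tensors; the function name is hypothetical. Stacking two calls shows how information travels two hops, as described above.

```python
from typing import Callable, List

Matrix = List[List[float]]

def relu(x: float) -> float:
    return max(0.0, x)

def gcn_layer(h_prev: Matrix, adj: Matrix, w: Matrix, b: List[float],
              sigma: Callable[[float], float] = relu) -> Matrix:
    """h_i^l = sigma(sum_j A_ij W^l h_j^{l-1} + b^l) for every node i."""
    k = len(h_prev)
    # W^l h_j^{l-1} for each node j (w has shape out_dim x in_dim).
    wh = [[sum(w_row[c] * h_prev[j][c] for c in range(len(w_row))) for w_row in w]
          for j in range(k)]
    return [[sigma(sum(adj[i][j] * wh[j][r] for j in range(k)) + b[r])
             for r in range(len(w))] for i in range(k)]

# Usage: a 3-node path graph 0-1-2 with self-loops and scalar node features.
A = [[1.0, 1.0, 0.0],
     [1.0, 1.0, 1.0],
     [0.0, 1.0, 1.0]]
H0 = [[1.0], [2.0], [3.0]]
W = [[1.0]]
b = [0.0]
H1 = gcn_layer(H0, A, W, b)  # node 0 sees nodes 0 and 1 only
H2 = gcn_layer(H1, A, W, b)  # after two layers node 0 also depends on node 2
```

In practice the adjacency matrix is usually normalized (for example by node degree) to keep feature magnitudes stable across layers; that step is omitted here to keep the arithmetic easy to follow.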

#### *2.3. Commonsense Knowledge*

Commonsense knowledge for NLP is typically obtained through large-scale corpus training and stored in commonsense knowledge bases, and it is taken as prior knowledge for pre-training knowledge-enhanced approaches. SenticNet [16] is one such knowledge base, which contains 100*k* concepts related to sentiment expression (e.g., mood, polarity and semantics). These affective properties provide concept-level representations and semantic connections for words.

To facilitate access to corresponding knowledge, SenticNet provides an application programming interface. A series of sentiment scores of the word and its related concepts can be obtained from the interface (as shown in Figure 2), which can expand the semantics of the sentence.

Applying SenticNet to ABSA has shown its distinctive value in sentiment representation learning [13,33]. Ma et al. [13] utilized commonsense knowledge from SenticNet to generate essays more closely related to the semantics of the input topics. Zhou et al. [14] enriched the sentence semantics using SenticNet 5 and then jointly modeled the syntactic dependency trees and the commonsense graph. Although external knowledge provides additional key information, how to filter the noise introduced with it remains an open problem.

#### **3. Methodology**

The architecture of KDGCN is presented in Figure 3. Our model consists of five key components: a sentence encoder, a knowledge enhancement module, a semantic learning module, a syntax-aware module and a sentiment classifier. First, each word of the sentence is encoded as a vector by the sentence encoder. At the same time, the sentence is fed into the knowledge enhancement module, and the sentiment vector of each word and the expansion words of the aspect are obtained from SenticNet. Second, the hidden state vectors of the sentence are sent into the semantic learning module and the syntax-aware module to obtain the semantic and syntactic representations, respectively. Finally, the sentiment polarity of the aspect is obtained from the sentiment classifier.

**Figure 3.** Overall architecture of the proposed Knowledge-Enhanced Dual-Channel Graph Convolutional Network.

#### *3.1. Sentence Encoder*

**GloVe embedding.** For a sentence *c* = {*w*<sub>1</sub>, *w*<sub>2</sub>, ..., *w*<sub>*n*</sub>} with the aspect *a* = {*w*<sub>*a*1</sub>, *w*<sub>*a*2</sub>, ..., *w*<sub>*an*</sub>}, we use the pre-trained embedding matrix *E* ∈ ℝ<sup>|*V*|×*d*<sub>*e*</sub></sup> to map each word into a low-dimensional vector, where |*V*| is the vocabulary size and *d*<sub>*e*</sub> is the dimension of the word vectors [34].

**BERT embedding.** BERT [35] has become a commonly used sentence encoder in recent years. Each sentence is pre-processed by adding *[CLS]* at the beginning and *[SEP]* at the end to obtain *c* = {*w*<sub>0</sub>, *w*<sub>1</sub>, ..., *w*<sub>*n*+1</sub>}, where *w*<sub>0</sub> and *w*<sub>*n*+1</sub> denote the two inserted special tokens. Then, *c* is fed into BERT to obtain the textual feature representation *X* = {*x*<sub>0</sub>, *x*<sub>1</sub>, ..., *x*<sub>*n*+1</sub>}, where *x*<sub>*i*</sub> ∈ ℝ<sup>*d*<sub>bert</sub></sup>.

A bidirectional LSTM (Bi-LSTM) is employed for sentence encoding. The sentence embedding is sent to the Bi-LSTM to generate the hidden state vectors *H*<sup>LSTM</sup> = {*h*<sub>1</sub>, *h*<sub>2</sub>, ..., *h*<sub>*n*</sub>}, where each *h*<sub>*t*</sub> ∈ ℝ<sup>2*d*<sub>*h*</sub></sup> is the hidden state at time step *t* and *d*<sub>*h*</sub> is the hidden state dimension of the LSTM.

#### *3.2. Knowledge Enhancement Module*

**Word sentiment enhancement:** For the given sentence *c*, a 23-dimensional sentiment vector is obtained for each word from the commonsense knowledge in SenticNet; together, these vectors form the sentence sentiment representation *H*<sup>sen</sup>. For words that do not appear in SenticNet, a zero vector is used instead. Then, *H*<sup>LSTM</sup> and *H*<sup>sen</sup> are fused to obtain the sentence representation:

$$H^c = [H^{LSTM}; H^{sen}] \tag{2}$$

with $H^c \in \mathbb{R}^{2d_h+23}$.
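The fusion in Equation (2) is a per-word concatenation with a zero-vector fallback. The following is a minimal sketch, not the authors' code: it assumes the Bi-LSTM states are already available as lists of floats and that SenticNet has been loaded into a hypothetical dictionary `sentic_lookup` mapping a lowercased word to its 23-dimensional vector (the lookup format and the lowercasing are assumptions).

```python
from typing import Dict, List

SENTIC_DIM = 23  # size of the SenticNet sentiment vector in Section 3.2

def fuse_sentiment(words: List[str], h_lstm: List[List[float]],
                   sentic_lookup: Dict[str, List[float]]) -> List[List[float]]:
    """Build H^c row by row: [Bi-LSTM state ; 23-dim sentiment vector]."""
    if len(words) != len(h_lstm):
        raise ValueError("expected one hidden state per word")
    fused = []
    for word, h in zip(words, h_lstm):
        # Words missing from the (hypothetical) lookup fall back to the zero vector.
        sen = sentic_lookup.get(word.lower(), [0.0] * SENTIC_DIM)
        if len(sen) != SENTIC_DIM:
            raise ValueError(f"sentiment vector for {word!r} must have {SENTIC_DIM} entries")
        fused.append(list(h) + list(sen))
    return fused

# Toy example with d_h = 1, so each fused row has 2 + 23 = 25 entries.
lookup = {"great": [0.9] + [0.0] * 22}   # illustrative entry, not a real SenticNet value
H_c = fuse_sentiment(["Great", "food"], [[0.1, 0.2], [0.3, 0.4]], lookup)
```

Each row of `H_c` has $2d_h + 23$ entries, matching the stated dimension of $H^c$.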

**Aspect knowledge enhancement:** For the aspect $a$, the related words of each word in $a$ are collected from SenticNet, i.e., $\{w_{ex_1}, w_{ex_2}, \ldots, w_{ex_n}\}$. To supplement the aspect, the first five related words are used. All related words are also mapped to word embeddings and encoded with the Bi-LSTM encoder, and their hidden states are concatenated with their sentiment vectors:

$$H^{\rm ex} = [H_{\rm ex}^{LSTM}; H_{\rm ex}^{\rm sen}] \tag{3}$$

where $H_{ex}^{LSTM}$ denotes the Bi-LSTM hidden state vectors of the related words, and $H_{ex}^{sen}$ denotes the corresponding sentiment vectors. The aspect expansion vector is denoted $H^{ex} \in \mathbb{R}^{2d_h+23}$.

Notably, since word co-occurrence in the corpus affects the GloVe embeddings, the aspect-related words are not initialized with pretrained GloVe vectors, to avoid introducing noise. We use an *unk* token for related words that are absent from the given texts. Similarly, related words absent from SenticNet are assigned the zero vector.
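The expansion step can be sketched as a lookup followed by truncation and *unk* substitution. This is an illustrative sketch under stated assumptions: SenticNet relations are assumed to be preloaded into a hypothetical dictionary `related_lookup`, the five-word limit is assumed to apply per aspect word (the text does not say whether it is per word or per aspect), and `<unk>` is a hypothetical placeholder token.

```python
from typing import Dict, List, Set

UNK = "<unk>"  # hypothetical placeholder token for words absent from the corpus

def expand_aspect(aspect_words: List[str], related_lookup: Dict[str, List[str]],
                  corpus_vocab: Set[str], max_related: int = 5) -> List[str]:
    """Collect up to `max_related` related words per aspect word; unseen ones become UNK."""
    expanded = []
    for word in aspect_words:
        for rel in related_lookup.get(word, [])[:max_related]:
            expanded.append(rel if rel in corpus_vocab else UNK)
    return expanded

related = {"food": ["meal", "dish", "cuisine", "snack", "dinner", "lunch"]}  # toy relations
vocab = {"meal", "dish", "dinner", "service"}                                # toy corpus vocabulary
expansion = expand_aspect(["food"], related, vocab)
```

Here `"lunch"` is dropped by the five-word limit, and `"cuisine"` and `"snack"` become `<unk>` because they do not occur in the toy vocabulary.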

#### *3.3. Semantic Learning Module*

As noted in [10], most short sentences have unclear syntactic structure, so rigidly extracting syntactic information can lead to misinterpreting the sentiment information. For this reason, we propose a GCN-based semantic learning module to capture the semantic relations among words. Both the enhanced sentiment vectors and the aspect expansion vectors are fed into the semantic learning module to further enrich the semantic information.

**Node construction:** Each word $w_i$ in the sentence, together with each aspect-related word $w_{ex_i}$, is taken as a node. All nodes constitute the node set $V$.

**Edge construction:** Edges indicate relationships between word nodes: two semantically related nodes are connected by an edge, and unrelated nodes are not. To capture the semantic relation of each word, we employ a $K$-head self-attention mechanism to compute the attention weights, i.e.,

$$A_{attn}^{i} = \frac{(H_{se} W_{se,k}^{i})(H_{se} W_{se,q}^{i})^T}{\sqrt{d_{head}}} \tag{4}$$

where

$$H_{se}^{(0)} = H^{c} \tag{5}$$

$$d_{head} = \frac{d_{lstm}}{K} \tag{6}$$

where $H_{se}^{(0)} \in \mathbb{R}^{2d_h+23}$ is the commonsense-enhanced hidden layer output; $K$ is the number of attention heads; and $W_{se,k}^{i}, W_{se,q}^{i} \in \mathbb{R}^{(2d_h+23)\times d_{head}}$ are the trainable matrices of the $i$-th head. Subsequently, a top-k selection is applied: the $k$ largest values in each row of the summed attention matrix are set to 1, and all other values are set to 0. This yields the adjacency matrix $A_{se}$ in Equation (7). Consistent with the edge construction principle, an entry of 1 indicates semantic relevance between two nodes. Notably, $A_{se}$ remains symmetric after the top-k selection.

$$A_{se} = \mathrm{topk}\Big(\sum_{i=1}^{K} A_{attn}^{i}\Big) \tag{7}$$
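Equations (4) and (7) can be illustrated in plain Python. This is a minimal sketch, not the authors' implementation: the toy node features and identity projections are hypothetical, top-k is applied row by row with ties broken by column order, and the final step symmetrizes the matrix explicitly. That last step is an assumption made here, because a row-wise top-k does not by itself guarantee symmetry.

```python
import math
from typing import List

Matrix = List[List[float]]

def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain-list matrix product a @ b."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def transpose(a: Matrix) -> Matrix:
    return [list(col) for col in zip(*a)]

def head_scores(h: Matrix, w_k: Matrix, w_q: Matrix) -> Matrix:
    """One head of Eq. (4): (H W_k)(H W_q)^T / sqrt(d_head)."""
    d_head = len(w_k[0])
    scores = matmul(matmul(h, w_k), transpose(matmul(h, w_q)))
    return [[s / math.sqrt(d_head) for s in row] for row in scores]

def topk_adjacency(heads: List[Matrix], k: int) -> List[List[int]]:
    """Sum the heads, mark the k largest entries of each row with 1, then symmetrize."""
    n = len(heads[0])
    total = [[sum(hd[i][j] for hd in heads) for j in range(n)] for i in range(n)]
    adj = [[0] * n for _ in range(n)]
    for i, row in enumerate(total):
        for j in sorted(range(n), key=lambda c: row[c], reverse=True)[:k]:
            adj[i][j] = 1
    # Assumption: OR with the transpose, since row-wise top-k alone is not always symmetric.
    return [[max(adj[i][j], adj[j][i]) for j in range(n)] for i in range(n)]

h = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]   # toy node features
eye = [[1.0, 0.0], [0.0, 1.0]]              # toy projections (d_head = 2)
A_se = topk_adjacency([head_scores(h, eye, eye)], k=2)  # K = 1 head, k = 2 as in Section 4.2
```

With a single head and $k = 2$, the settings reported in Section 4.2, each node keeps edges to its two highest-scoring nodes before symmetrization.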

A graph $G_{sem} = (A_{se}, H^c)$ consisting of the node representations and the adjacency matrix is thus constructed. The graph is fed into an $N$-layer GCN to obtain the hidden state $H_{se}$:

$$H_{se}^{(l+1)} = \text{GCN}(A_{se}, H_{se}^{(l)}, W_{se}^{(l)}) \tag{8}$$

where $W_{se}^{(l)} \in \mathbb{R}^{(2d_h+23)\times d_{gcn}}$ is the parameter matrix of the GCN. A mask is then applied to the non-aspect words, followed by average pooling, to compute the semantic output $h_{se}$:

$$mask = \begin{cases} 0 & 1 \le t < \tau + 1 \ \text{or}\ \tau + m < t \le n \\ 1 & \tau + 1 \le t \le \tau + m \end{cases} \tag{9}$$

$$h_{se} = f(mask(H_{se})) \tag{10}$$

where $\tau + 1 \le t \le \tau + m$ indicates the aspect positions and $f(\cdot)$ is the average pooling function.
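The text does not define the GCN(·) operator in Equation (8). The sketch below assumes one common formulation, $\mathrm{ReLU}(\tilde{D}^{-1}(A + I)HW)$ with self-loops and degree normalization, and then applies the aspect mask of Equation (9) and the pooling of Equation (10). Averaging over the $m$ aspect tokens rather than all $n$ positions is also an assumption. All names and toy values are illustrative.

```python
from typing import List

Matrix = List[List[float]]

def gcn_layer(adj: List[List[int]], h: Matrix, w: Matrix) -> Matrix:
    """One graph convolution, assumed here to be ReLU(D^-1 (A + I) H W)."""
    n = len(adj)
    out = []
    for i in range(n):
        # Neighbors of node i, including a self-loop (A + I).
        nbrs = [j for j in range(n) if adj[i][j] or i == j]
        # Degree-normalized aggregation of neighbor features.
        agg = [sum(h[j][c] for j in nbrs) / len(nbrs) for c in range(len(h[0]))]
        # Linear projection by W, then ReLU.
        proj = [sum(agg[r] * w[r][c] for r in range(len(w))) for c in range(len(w[0]))]
        out.append([max(0.0, v) for v in proj])
    return out

def aspect_mask(n: int, tau: int, m: int) -> List[int]:
    """Eq. (9) for 1-based positions t = 1..n: 1 on the aspect span, 0 elsewhere."""
    return [1 if tau + 1 <= t <= tau + m else 0 for t in range(1, n + 1)]

def masked_average(h: Matrix, mask: List[int]) -> List[float]:
    """Eq. (10): keep only the aspect rows, then average them."""
    kept = [row for row, keep in zip(h, mask) if keep]
    return [sum(col) / len(kept) for col in zip(*kept)]

# Toy graph with 3 nodes; the third node is isolated.
adj = [[0, 1, 0], [1, 0, 0], [0, 0, 0]]
h0 = [[1.0, 2.0], [3.0, -4.0], [-5.0, 1.0]]
w0 = [[1.0], [1.0]]
h1 = gcn_layer(adj, h0, w0)            # [[1.0], [1.0], [0.0]]
mask = aspect_mask(n=3, tau=0, m=2)    # aspect = first two tokens -> [1, 1, 0]
h_se = masked_average(h1, mask)        # [1.0]
```

An $N$-layer GCN corresponds to calling `gcn_layer` $N$ times with per-layer weight matrices.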

#### *3.4. Syntax Aware Module*

The syntax aware module is adapted from the method proposed by Zhang et al. [7]. The sentence syntax is characterized by the syntax dependency tree. Since not all context words are syntactically related to the aspect, an aspect-related selection approach is used to reshape the syntax dependency tree: a dependency edge between nodes is kept only if the context word reaches the aspect within n hops. We thus revise the adjacency matrix $A^0$ into $A_{sy}$. The revised graph is written as $G_{sy} = (A_{sy}, H^{LSTM})$, where $H^{LSTM}$ is the current node representation. Before $G_{sy}$ is fed into the GCN, a position-aware transformation is performed [7]:

$$q_i = \begin{cases} 1 - \frac{\tau + 1 - i}{n} & 1 \le i < \tau + 1 \\ 0 & \tau + 1 \le i \le \tau + m \\ 1 - \frac{i - \tau - m}{n} & \tau + m < i \le n \end{cases} \tag{11}$$

with

$$\mathcal{F}(h\_i) = q\_i h\_i \tag{12}$$

where $q_i \in \mathbb{R}$ is the position weight of the $i$-th token and $\mathcal{F}(\cdot)$ is the position-weighting function. The syntactic information is then learned by graph convolution, and the syntactic hidden layer output is expressed as:

$$H_{sy}^{(l)} = \mathcal{F}(H_{sy}^{(l-1)}) \tag{13}$$

$$H_{sy}^{(l+1)} = \text{GCN}(A_{sy}, H_{sy}^{(l)}, W_{sy}^{(l)}) \tag{14}$$

$$H_{sy}^{(0)} = \mathcal{F}(H^{LSTM}) \tag{15}$$

where $W_{sy}^{(l)} \in \mathbb{R}^{2d_h \times d_{gcn}}$ is a trainable parameter matrix. As in the semantic GCN, the syntactic hidden state representation $H_{sy}$ is masked (Equation (16)):

$$H^t = \text{mask}(H_{sy}) \tag{16}$$

where $H^t = \{h_1^t, h_2^t, \ldots, h_n^t\}$. The hidden states produced by Equation (16) concentrate on the aspect words. In addition, to detect the important semantic features concealed within the syntactic structure, an attention weight is assigned to each context word. The attention score $\beta_j$ is computed from the dot products between the Bi-LSTM hidden state $h_j$ and the masked states $h_i^t$, and the syntactic representation is obtained as:

$$h_{sy} = \sum_{j=1}^{n} \alpha_j h_j^t \tag{17}$$

$$\alpha_j = \frac{\exp(\beta_j)}{\sum_{i=1}^{n} \exp(\beta_i)} \tag{18}$$

$$\beta_j = \sum_{i=1}^{n} (h_i^t)^{\top} h_j = \sum_{i=\tau+1}^{\tau+m} (h_i^t)^{\top} h_j \tag{19}$$

where the second equality in Equation (19) holds because $h_i^t$ is zero outside the aspect span.
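The non-GCN steps of the syntax aware module, position weighting and aspect-oriented attention, can be sketched in the same plain-Python style. This sketch makes several assumptions. The position weights take the ASGCN form of Zhang et al. [7]: they decay linearly with distance from the aspect and are 0 on the aspect itself. The attention scores are dot products between the masked states $h_i^t$, which are zero outside the aspect span, and the Bi-LSTM states $h_j$, so both are assumed to share one dimension. The function names and toy values are hypothetical.

```python
import math
from typing import List, Tuple

Matrix = List[List[float]]

def position_weights(n: int, tau: int, m: int) -> List[float]:
    """q_i for 1-based positions i = 1..n; the aspect occupies tau+1..tau+m."""
    q = []
    for i in range(1, n + 1):
        if i < tau + 1:                 # context word before the aspect
            q.append(1 - (tau + 1 - i) / n)
        elif i <= tau + m:              # aspect word
            q.append(0.0)
        else:                           # context word after the aspect
            q.append(1 - (i - tau - m) / n)
    return q

def apply_position_weights(h: Matrix, q: List[float]) -> Matrix:
    """F(h_i) = q_i * h_i."""
    return [[qi * v for v in row] for qi, row in zip(q, h)]

def aspect_attention(h_masked: Matrix, h_ctx: Matrix) -> Tuple[List[float], List[float]]:
    """beta_j = sum_i h_i^t . h_j, alpha = softmax(beta), h_sy = sum_j alpha_j h_j^t."""
    betas = [sum(sum(a * b for a, b in zip(ht, hj)) for ht in h_masked) for hj in h_ctx]
    peak = max(betas)                   # shift by the max for a numerically stable softmax
    exps = [math.exp(b - peak) for b in betas]
    total = sum(exps)
    alphas = [e / total for e in exps]
    dim = len(h_masked[0])
    h_sy = [sum(a * row[c] for a, row in zip(alphas, h_masked)) for c in range(dim)]
    return h_sy, alphas

q = position_weights(n=3, tau=1, m=1)           # aspect is token 2
h_t = [[0.0, 0.0], [1.0, 0.0], [0.0, 0.0]]      # masked states: only token 2 is non-zero
h_lstm = [[1.0, 0.0], [2.0, 0.0], [0.0, 1.0]]   # toy Bi-LSTM states
h_sy, alphas = aspect_attention(h_t, h_lstm)
```

Because only the aspect rows of `h_t` are non-zero, the inner sum over `h_masked` reduces to the aspect span, as in the second form of Equation (19).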

#### *3.5. Sentiment Classifier*

Both the semantic and syntactic representations have now been computed. We therefore concatenate $h_{se}$ and $h_{sy}$ to obtain the final representation $h_a$ (Equation (20)). The sentiment polarity of the given aspect is predicted by feeding $h_a$ into a softmax classifier:

$$h_a = [h_{se}; h_{sy}] \tag{20}$$

$$y = \operatorname{softmax}(h_a) \tag{21}$$

#### *3.6. Model Training*

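The model is trained with the categorical cross-entropy loss and $L_2$ regularization:

$$Loss = -\sum_{i} \sum_{j} y_i^j \log\left(p_i^j\right) + \lambda \lVert \Theta \rVert_2^2 \tag{22}$$

where $i$ is the index of the ABSA sample, $j$ is the index of the sentiment polarity, $y_i^j$ and $p_i^j$ are the ground-truth and predicted probabilities, $\lambda$ is the $L_2$ regularization weight and $\Theta$ denotes all trainable parameters.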
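The loss can be computed directly from one-hot labels and predicted class probabilities. This is a minimal sketch: the parameters are flattened into a single list for simplicity, the default weight 1e-4 is the $L_2$ value reported in Section 4.2, and the `eps` clamp that guards against `log(0)` is an addition not mentioned in the text.

```python
import math
from typing import List

def ce_l2_loss(y_true: List[List[float]], p_pred: List[List[float]],
               params: List[float], l2_weight: float = 1e-4, eps: float = 1e-12) -> float:
    """Categorical cross-entropy summed over samples and classes, plus an L2 penalty."""
    ce = -sum(y * math.log(max(p, eps))
              for y_row, p_row in zip(y_true, p_pred)
              for y, p in zip(y_row, p_row))
    l2 = l2_weight * sum(w * w for w in params)
    return ce + l2

# Two samples, three polarities (positive, negative, neutral); toy parameters.
y = [[1, 0, 0], [0, 0, 1]]
p = [[0.5, 0.25, 0.25], [0.1, 0.1, 0.8]]
loss = ce_l2_loss(y, p, params=[1.0, 2.0])
```

Only the probability assigned to the true class of each sample contributes to the cross-entropy term.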

#### **4. Experiment**

In this section, we conduct the main experiments and an attention visualization to verify the effectiveness of our model on the ABSA task. Specifically, we first introduce the benchmark datasets and then briefly describe the experimental details and the selected baselines. Next, we present and analyze the main experimental results. In addition, to explore the contribution of each module, we conduct ablation experiments, and we analyze the mechanism of knowledge enhancement through attention visualization.

#### *4.1. Dataset*

To evaluate the proposed model, experiments were carried out on four publicly available benchmark datasets containing restaurant and laptop reviews: Rest14 and Lap14 from SemEval 2014 [36], Rest15 from SemEval 2015 [37] and Rest16 from SemEval 2016 [1].

Every sentence in the datasets contains at least one aspect, and the sentiment polarity of each aspect is labeled as positive, negative or neutral. For example, the sentence "*Great food but the service was dreadful!*" contains two aspect terms, *'food'* and *'service'*, whose sentiment polarities are positive and negative, respectively. The details of each dataset are presented in Table 1.


#### *4.2. Implementation Details*

The best test result of each method was taken for evaluation. For the proposed model, word embeddings were initialized with GloVe [38] and uncased BERT [35], respectively. The GloVe-based model uses 300-dimensional pretrained word vectors, a learning rate of 0.001 and a batch size of 64. The BERT-based model uses 768-dimensional embeddings, a learning rate of 0.00002 and a batch size of 32. The number of heads in the multi-head attention network was set to 1, and the value of $k$ in the top-k selection was 2. The Adam optimizer was employed, and the $L_2$ regularization weight was 0.0001. The dropout rate was selected from the interval [0.4, 0.6] by grid search. For the GCNs in our model, the number of layers and the hidden dimension were likewise selected by grid search within [1, 4] and [100, 200], respectively.
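The reported settings can be collected into a configuration and a grid-search enumeration. This is a sketch for orientation only: the fixed values come from Section 4.2, but the specific grid points inside each reported range (for example, a step of 0.1 for dropout and 50 for the hidden dimension) are assumptions, since the text gives only the ranges.

```python
from itertools import product

# Settings reported in Section 4.2.
GLOVE_CFG = {"embed_dim": 300, "lr": 1e-3, "batch_size": 64}
BERT_CFG = {"embed_dim": 768, "lr": 2e-5, "batch_size": 32}
SHARED_CFG = {"num_heads": 1, "top_k": 2, "optimizer": "adam", "l2_weight": 1e-4}

# Search ranges from the text; the grid points inside each range are assumptions.
DROPOUT_GRID = [0.4, 0.5, 0.6]
GCN_LAYERS_GRID = [1, 2, 3, 4]
GCN_HIDDEN_GRID = [100, 150, 200]

def search_space(encoder_cfg):
    """Yield one full configuration per grid point."""
    for dropout, layers, hidden in product(DROPOUT_GRID, GCN_LAYERS_GRID, GCN_HIDDEN_GRID):
        yield {**encoder_cfg, **SHARED_CFG,
               "dropout": dropout, "gcn_layers": layers, "gcn_hidden": hidden}

configs = list(search_space(GLOVE_CFG))
```

Under these assumed grids, each encoder yields 3 × 4 × 3 = 36 candidate configurations.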

#### *4.3. Baseline*

To validate the effectiveness of our model, we compared it with twelve state-of-the-art methods, which are presented as follows:


#### *4.4. Experimental Results*

Experimental results on all datasets are shown in Table 2. Accuracy and macro-F1 were used as the evaluation metrics. Compared with the baseline models, **KDGCN** generally obtained the best and most consistent results across the evaluation settings. However, with the BERT encoder, our model was less competitive than **DMGCN+BERT** on Rest14. A possible explanation is that pretrained BERT already encodes a wealth of semantic information, so the additional semantic enhancement from SenticNet is less pronounced. With GloVe-based word embeddings, **KDGCN** outperformed **DMGCN** on Rest14 by 0.93% in accuracy and 2.89% in macro-F1.

Overall, current GCN-based models encode either syntactic information (e.g., **ASGCN**, **CDT**, **R-GAT** and **TGCN+BERT**) or semantics-integrated syntactic information (e.g., **DualGCN** and **DMGCN**), and their performance largely depends on their fitting capability. By contrast, the proposed model adopts the aspect-related selection approach to prune the edges of the syntax dependency tree, which eliminates information unrelated to the aspect. In addition, commonsense knowledge is introduced to enhance the semantic information and the sentiment of the aspect. Together, these design choices improve ABSA results.

Furthermore, **SK-GCN** also uses external knowledge derived from SenticNet to construct a syntax-based GCN and a semantic-based GCN. Compared with **SK-GCN**, our model performs significantly better on all datasets, which indicates that **KDGCN** exploits commonsense knowledge more effectively in ABSA. These results suggest that integrating external knowledge into the given sentence can improve sentiment classification.

**Table 2.** Experimental results on four public datasets. The results of **R-GAT** and **R-GAT+BERT** are retrieved from [40], and others are retrieved from the original papers.


#### *4.5. Ablation Study*

An ablation study was conducted to quantitatively investigate the importance of the different modules of the proposed model. The results are given in Table 3 and Figure 4. Taking the full KDGCN as the baseline, we ablated the knowledge enhancement module, the semantic learning module, the syntax aware module and the aspect-related selection procedure. According to Table 3, the most important component of the proposed model is the syntax aware module: removing it decreased the accuracy on the four datasets by 6.78%, 6.12%, 4.61% and 3.08%, respectively. This indicates that syntactic information plays a pivotal role in ABSA. Moreover, the contributions of the semantic learning module and the knowledge enhancement module are comparable, and integrating commonsense knowledge into the semantic learning process improves sentiment classification performance. Lastly, removing the aspect-related selection also caused a minor performance decrease.

**Table 3.** Results of the ablation study.


**Figure 4.** Results of the ablation study. Different columns show the performance of different models on different datasets.

#### *4.6. Attention Visualization*

To investigate the effectiveness of knowledge enhancement, we visualized the attention matrices. In our model, semantic enhancement is carried out using commonsense knowledge from SenticNet, which establishes and strengthens the connection between the aspect and its opinion word, while the syntax-based GCN removes irrelevant information by encoding the syntax dependency tree. Several cases are presented to illustrate the attention weight distribution. In the first row of Figure 5, the attention weights are assigned by a basic multi-head attention mechanism. Only minor attention is given to the opinion word **'excellent'** of the aspect **'food'**, and, likewise, the attention weight of **'food'** toward **'excellent'** is weak. With the integration of commonsense knowledge, relationships are established between the context word **'meal'** and both **'food'** and **'excellent'**. That is, the **'food-meal'** and **'excellent-meal'** edges can be constructed by top-k selection. As a result, the sentiment information of **'excellent'** can be aggregated onto the aspect word **'food'** through GCN encoding. In addition, the syntactic-based GCN, which models the syntactic relations among words, also helps determine the aspect sentiment polarity.

Similarly, the two panels in the second row show that the aspect word **'waiter'** establishes a direct connection with the opinion word **'helpful'** after knowledge enhancement. Likewise, in the two panels of the last row, the aspect word **'sauce'** and the opinion word **'flavorful'** become connected through the path **'sauce-dough-flavorful'** after knowledge enhancement, so that the sentiment polarity of the aspect word can be better predicted by the subsequent network layers.

**Figure 5.** An illustration on knowledge-enhancement. (**a**) Basic attention matrix of the sentence. (**b**) Knowledge-enhanced attention matrix of the sentence. The red words are aspect words, the blue words are opinion words and the black bold words are aspect-expansion words.

#### **5. Discussion**

Through a series of experiments, we have shown that KDGCN performs well on the ABSA task. Specifically, in the main experiments (Section 4.4), the accuracy and F1-score of our model on the four datasets are generally higher than those of the baselines; in particular, compared with SK-GCN [14], which also uses SenticNet for knowledge enhancement, our improvement was 2–5%. In the ablation study, removing the semantic learning module, the syntax aware module and the other components showed that both semantics and syntax are important for ABSA. In addition, removing the knowledge enhancement module also decreased performance significantly on the four datasets, indicating that our knowledge enhancement benefits ABSA.

We also identified limitations of our model. Taking DMGCN [11] with the GloVe encoder as an example, KDGCN's improvement on Lap14 was smaller than that on Rest14 (0.52% and 0.93%, respectively). This may be because many terms in Lap14 are proper nouns (such as *Windows 7* and *Microsoft*) that carry no obvious sentiment cues. In contrast, most of the words in Rest14 are everyday words, so their sentiment information is rich and can be further enhanced through SenticNet. To obtain more semantic information and deeper connections, large-scale knowledge graphs can be introduced into the ABSA task in future work.

#### **6. Conclusions**

In this work, we propose a knowledge-enhanced dual-channel graph convolutional network for ABSA. A semantic-based GCN and a syntactic-based GCN are devised to encode the sentence semantics and syntax, respectively. On the one hand, external commonsense knowledge is introduced to enhance the semantics, so that more attention is assigned to the aspect and its relevant words. On the other hand, the syntactic-based GCN operating on the syntax dependency tree further filters out weakly dependent words. We demonstrate the effectiveness of our method on four benchmark datasets, obtaining state-of-the-art results in both accuracy and macro-F1. Compared with the baseline models, the proposed method produces considerably better results than widely applied ABSA approaches. The ablation experiments measured the contribution of each module and verified the effectiveness of our design. In addition, a case analysis intuitively demonstrated the role of knowledge enhancement in the task.

However, SenticNet is a small-scale knowledge base with shallow and limited semantics, which limits the performance of the model. Future work can therefore explore larger-scale knowledge sources (such as Wikipedia) to enhance ABSA with more clues for predicting the sentiment polarity of the aspect.

**Author Contributions:** Conceptualization, Z.Z. and Y.X.; methodology, Z.Z.; formal analysis, Z.Z. and Z.M.; writing—original draft preparation, Z.Z.; writing—review and editing, S.C., J.C. and Y.X.; supervision, S.C. and Y.X.; funding acquisition, S.C. All authors have read and agreed to the published version of the manuscript.

**Funding:** This work was supported by the Characteristic Innovation Projects of Guangdong Colleges and Universities (No. 2018KTSCX049) and the Science and Technology Plan Project of Guangzhou under Grant Nos. 202102080258 and 201903010013.

**Institutional Review Board Statement:** Not applicable.

**Informed Consent Statement:** Not applicable.

**Data Availability Statement:** Not applicable.

**Conflicts of Interest:** The authors declare no conflict of interest.

#### **References**


### *Article* **Deep Learning-Based Cyber–Physical Feature Fusion for Anomaly Detection in Industrial Control Systems**

**Yan Du <sup>1</sup>, Yuanyuan Huang <sup>1,\*</sup>, Guogen Wan <sup>1</sup> and Peilin He <sup>2</sup>**

<sup>2</sup> Department of Informatics and Networked Systems, University of Pittsburgh, Pittsburgh, PA 15260, USA

**\*** Correspondence: iyyhuang@hotmail.com

**Abstract:** In this paper, we propose an unsupervised anomaly detection method based on a Long Short-Term Memory Autoencoder (LSTM-Autoencoder) network and a Generative Adversarial Network (GAN) to detect anomalies in industrial control systems (ICS) using cyber–physical fusion features. This method improves the recall of anomaly detection and overcomes the challenges of imbalanced datasets and insufficient labeled samples in ICS. As a first step, additional network features are extracted and fused with physical features to create a cyber–physical dataset. Next, the model is trained on normal data so that it can properly reconstruct normal samples. In the testing phase, samples with unknown labels are used as inputs, and the model outputs an anomaly score for each sample; a sample is classified as anomalous if its anomaly score exceeds a threshold. Experiments with both supervised and unsupervised algorithms show that (1) cyber–physical fusion features can significantly improve the performance of anomaly detection algorithms; (2) the proposed method outperforms several other unsupervised anomaly detection methods in terms of accuracy, recall and F1 score; and (3) the proposed method can detect the majority of anomalous events with a low false negative rate.

**Keywords:** deep learning; anomaly detection; cyber–physical; industrial control systems

**MSC:** 68T09

#### **1. Introduction**

In recent years, cyberattacks have caused significant damage to industrial production and national infrastructure [1]. In 2010, the Stuxnet worm spread through industrial systems worldwide and carried out targeted attacks on infrastructure, with Iran suffering the most severe effects [2]. In 2015, a malicious program called BlackEnergy affected multiple substations in the Ukrainian power sector [3]. In 2017, many Ukrainian government agencies and companies were attacked by the NotPetya ransomware, which ultimately caused havoc worldwide [4]. Besides threats from the Internet, failures of hardware or software within an ICS can also result in serious disasters. Globally, ICS security incidents occur frequently.

In order to secure ICS, anomaly detection is a promising approach [5]. It is usually physical faults or network attacks that cause anomalous events to occur in ICS. Sensors, actuators, pipelines, and other industrial equipment may malfunction due to physical faults. A network attack refers to an attack on a communication channel, host, or process control system, such as a man-in-the-middle attack (MITM), a denial of service (DoS), or a scanning attack. The purpose of industrial sensors is to collect status information (referred to in this paper as physical information) about the various industrial equipment in the system and to reflect the physical processes that take place within it. Physical faults have an impact on the physical operation of the system, but not on its network traffic. This results in physical faults not being detected by anomaly detection methods based solely

**Citation:** Du, Y.; Huang, Y.; Wan, G.; He, P. Deep Learning-Based Cyber–Physical Feature Fusion for Anomaly Detection in Industrial Control Systems. *Mathematics* **2022**, *10*, 4373. https://doi.org/10.3390/math10224373

Academic Editor: Jonathan Blackledge

Received: 23 October 2022 Accepted: 18 November 2022 Published: 20 November 2022

**Publisher's Note:** MDPI stays neutral with regard to jurisdictional claims in published maps and institutional affiliations.

**Copyright:** © 2022 by the authors. Licensee MDPI, Basel, Switzerland. This article is an open access article distributed under the terms and conditions of the Creative Commons Attribution (CC BY) license (https://creativecommons.org/licenses/by/4.0/).

<sup>1</sup> Department of Network Engineering, Chengdu University of Information Technology, Chengdu 610225, China

on network traffic. The physical processes of a system may not necessarily be affected by some network attacks. Consequently, algorithms that detect anomalies based solely on physical information are not able to detect these attacks. The use of anomaly detection algorithms that are based solely on physical information cannot detect network attacks in a timely manner, since most network attacks against ICS do not immediately cause the system to enter an abnormal state. Our conclusion is that taking into account both network traffic information and physical information is an effective way to improve the detection performance for anomaly detection algorithms that are used in industrial control systems [6], which has tended to be ignored in past studies.

In the past decade, artificial intelligence (AI) has developed rapidly and been applied in various fields [7–10]. Following the success of AI in traditional IT security [11], a number of AI-based approaches have emerged in ICS security, which can be categorized into supervised and unsupervised algorithms. Many supervised anomaly detection algorithms have been proposed. However, as the industrial Internet continues to develop, attacks from the Internet keep emerging in new forms, and supervised algorithms have a limited ability to detect unknown attacks. Moreover, ICS datasets are highly imbalanced, with far fewer abnormal samples than normal ones, and sufficient labeled samples are lacking. For these reasons, supervised algorithms are increasingly unsuitable for ICS security problems. Unsupervised algorithms such as One-Class SVM (OCSVM) [12] and isolation forests [13] can overcome these limitations of supervised learning.

An autoencoder is an unsupervised model consisting of an encoder and a decoder [14]. The encoder maps the input X to a latent variable Z, and the decoder then maps Z to a reconstruction R. The deviation between X and R is called the reconstruction error, and autoencoder-based anomaly detection uses it to calculate an anomaly score. Autoencoders trained only on normal data can therefore detect anomalies: during training, the model is assumed to learn to reconstruct only normal samples, so during testing it may fail to reconstruct anomalous samples well, and they produce higher reconstruction errors than normal samples. However, small anomalies can lead to small reconstruction errors, which makes them difficult to detect. Generative adversarial networks (GANs) may be used to amplify reconstruction errors and thereby identify small anomalies [15]. Autoencoders and GANs are both unsupervised artificial neural networks; the difference is that GANs include an adversarial game mechanism [16]. A GAN contains a generator and a discriminator. The generator is trained to produce data that are as realistic as possible and thus fool the discriminator (i.e., to maximize the probability that the discriminator is wrong), whereas the discriminator is trained to minimize its own error probability, i.e., to distinguish real from generated data with high accuracy. Because ICS data are time series, individual samples cannot be considered independently. Compared with ordinary autoencoders, LSTM-based autoencoders [17] are more capable of reconstructing time series data.
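The scoring and thresholding step shared by these reconstruction-based detectors can be sketched independently of the model itself. This sketch assumes mean squared error as the deviation measure, which the text does not specify, and uses a hypothetical clipping function as a toy stand-in for a trained model. It does not implement the LSTM-Autoencoder or the GAN.

```python
from typing import Callable, List, Sequence

Vector = Sequence[float]

def reconstruction_error(x: Vector, r: Vector) -> float:
    """Mean squared deviation between an input and its reconstruction (assumed measure)."""
    return sum((a - b) ** 2 for a, b in zip(x, r)) / len(x)

def detect(samples: List[Vector], reconstruct: Callable[[Vector], Vector],
           threshold: float) -> List[bool]:
    """Flag a sample as anomalous when its anomaly score exceeds the threshold."""
    return [reconstruction_error(x, reconstruct(x)) > threshold for x in samples]

# Toy stand-in for a trained model: it can only reproduce values in the "normal" range [0, 1].
def clip_model(x: Vector) -> List[float]:
    return [min(max(v, 0.0), 1.0) for v in x]

flags = detect([[0.2, 0.5], [0.3, 3.0]], clip_model, threshold=0.1)
```

The first sample lies in the range the toy model can reproduce, so its score is 0. The second has a reconstruction error of 2.0 and is flagged.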

In light of the above issues, the main contributions of this paper include the following:


from the network can significantly improve the performance of the anomaly detection algorithm.

Acquiring data from industrial sensors is an inherent function of ICS, and ICS network traffic can be collected by listening to communication channels. It is therefore feasible to collect both network and physical data and then extract cyber–physical fusion features. The massive amount of data generated during the normal operation of an ICS is sufficient to train the unsupervised model. As long as the components of the ICS do not change significantly, the model needs to be trained only once to detect anomalies, including those caused by novel attack methods. In practice, the components of an ICS remain fixed for long periods, and the network topology is not easily changed. Therefore, the proposed unsupervised anomaly detection model using cyber–physical fusion features can help industrial control systems cope with various cyber and physical threats, attacks and challenges in a cost-effective and profitable manner.

In the remainder of this paper, the following sections are presented: Section 2 discusses related work in the area of anomaly detection of ICS; Section 3 describes the dataset used in this paper; Section 4 describes the method proposed in this paper; Section 5 describes the experimental setup and presents the experimental results and analysis; and Section 6 concludes with our future plans.

#### **2. Related Work**

It is necessary to detect anomalies in ICS in order to ensure its security. Studies conducted in the past can be categorized according to their use of physical information or network traffic, depending on the features selected.


neural networks are used. Feng et al. [27] proposed a multi-level anomaly detection scheme combining LSTMs and Bloom filters to detect malicious traffic in SCADA datasets. Zhang et al. [28] proposed an algorithm for detecting anomalous traffic in which ICS traffic feature values were converted into grayscale images that were then used to train the model, improving the accuracy of anomaly detection.

From another perspective, past research can also be divided into supervised and unsupervised approaches. Some studies [29,30] have demonstrated the use of supervised machine learning to detect anomalous events or attacks. Despite good results, these systems could only detect known attacks, not unknown or zero-day attacks. ICS datasets are also often imbalanced, i.e., anomalous samples are far fewer than normal samples, which limits the performance of supervised algorithms. Other studies have used unsupervised or semi-supervised algorithms to overcome these limitations. Kravchik et al. [31] reported that their semi-supervised algorithm based on a one-dimensional CNN detected 31 out of 36 network attacks. Chang et al. [32] reported that an anomaly detection framework based on k-means and convolutional autoencoders achieved an F1-score of 0.9373 on water storage tank datasets. Audibert et al. [33] proposed an autoencoder-based anomaly detection model that used the reconstruction error as the loss function during training and as the anomaly score during testing; a sample is considered anomalous when its anomaly score exceeds a predetermined threshold. Lu et al. [34] proposed an improved generative adversarial network with an adaptive update strategy based on WGAN-GP that produces fake anomalous samples, improving the accuracy of anomaly detection. The GAN-based semi-supervised method MAD-GAN of Li et al. [35] uses LSTMs as both the generator and the discriminator to capture the temporal correlation of time series distributions and the potential interactions between variables, and it detects anomalies effectively.

The design of anomaly detection algorithms depends on selecting appropriate features. To detect anomalies in ICS, it is not enough to rely solely on physical information; network information must also be considered. However, the above methods have limitations arising from their datasets, which suffer from the following problems: (1) the dataset was not acquired in an ICS environment; (2) the dataset was not public; (3) the dataset was outdated; or (4) the dataset was restricted to either physical process data or network traffic. The authors of [36] compared the classification performance of algorithms using only network features with that of algorithms using both physical and network features, demonstrating that fusing physical and network information improves classification accuracy. However, their experiments covered four supervised machine learning algorithms and did not consider unsupervised algorithms.

Due to the above deficiencies, the following improvements have been made in this paper.


#### **3. Dataset Description**

The dataset used in this study was compiled from four acquisitions on the Water Distribution Testbed, covering normal operation as well as network attacks and physical faults. In Table 1, each acquisition is represented as a sub-dataset. During the first acquisition, eight network attacks or physical faults were conducted, resulting in eight scenarios. Similarly, the second and third acquisitions yielded thirteen and seven scenarios, respectively. During the fourth acquisition, the system operated normally, without any network attacks or physical faults. In total, the physical faults comprised two water leaks and six sensor and pump breakdowns, and the network attacks comprised eight man-in-the-middle (MITM) attacks, five denial of service (DoS) attacks and seven scanning attacks. Figure 1 shows the number and proportion of normal and malicious samples in each acquisition.

**Table 1.** Data acquisition and description.

**Figure 1.** Number and proportion of samples divided into normal and malicious.

This dataset provides both the physical process data and the corresponding raw network traffic. The physical process data record 40 physical state variables of the system every second, such as whether a pump is switched on and the pressure sensor value of a water tank. In addition, the dataset provides some network features extracted from the raw network traffic, such as the IP and MAC addresses of packets. For more details on this dataset, please refer to [37].

#### **4. Methodology**

First, we describe how additional network features are extracted and fused with the original physical features. We then formulate the problem, and finally we describe the proposed anomaly detection model in detail.

#### *4.1. Extraction and Fusion of Additional Network Features*

Because traditional information technology is widely adopted in ICS, and because the damage caused by cyberattacks becomes visible only with a delay (hysteresis), the physical information collected by industrial sensors alone cannot detect abnormal behavior in time. It is therefore important to consider both physical information and network traffic when extracting features for anomaly detection.

The original physical process data have a sampling interval of one second, whereas the original network data contain more than 1000 samples (packets) per second. By re-extracting network features from the raw traffic according to the characteristics of the ICS network, the physical and network features can be fused so that each second is summarized by a single sample. To improve anomaly detection performance, 22 additional features were extracted from the original dataset. These features are listed in Table 2.


**Table 2.** Additional extracted features.

In addition, a feature named stage is added. It describes the position of the current moment within the process cycle, and its value range is (0,1]. Taking the fourth acquisition as an example, there are 3423 sampling points in total, covering 12 complete process cycles. As shown in Figure 2, in each process cycle the water level of Tank\_1 gradually rises from 0 to its maximum value, then gradually falls back to 0 and stays there for a period of time. Correspondingly, the value of stage gradually increases from 0 to 1 and is then held for a period of time.

The architecture for extracting and fusing the additional features is shown in Figure 3. Each row (sample) of the original physical dataset has a sampling time t (e.g., 09/04/2021 11:30:55). All packets with timestamp t are collected from the network traffic corresponding to this physical dataset, and the features listed in Table 2 are extracted from them. The newly extracted features are then fused with the original physical features to form a cyber–physical dataset. Incomplete samples were deleted; they are concentrated mainly at the beginning and end of the dataset. They are incomplete because, at those times, physical data were being recorded while the corresponding network capture had either not yet started or had already ended. Table 3 summarizes the resulting cyber–physical dataset.
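The per-second aggregation and fusion step can be sketched as follows. This is a minimal illustration under stated assumptions, not the authors' pipeline:

- Timestamps are assumed to be integer seconds.
- The three network features (`pkt_count`, `n_src_ips`, `mean_len`) are hypothetical stand-ins for the 22 features of Table 2.
- Seconds inside the capture window that contain no packets are assumed to receive zero-valued features.
- Physical rows outside the capture window are dropped as incomplete.

```python
from collections import defaultdict

# Hypothetical per-second network features (the real 22 features are listed in Table 2).
ZERO_FEATURES = {"pkt_count": 0, "n_src_ips": 0, "mean_len": 0.0}

def aggregate_packets_per_second(packets):
    """Group packets by integer second 't' and summarize each group."""
    buckets = defaultdict(list)
    for pkt in packets:
        buckets[pkt["t"]].append(pkt)
    return {
        t: {
            "pkt_count": len(pkts),                         # packets seen in this second
            "n_src_ips": len({p["src_ip"] for p in pkts}),  # distinct source IPs
            "mean_len": sum(p["length"] for p in pkts) / len(pkts),
        }
        for t, pkts in buckets.items()
    }

def fuse(physical_rows, packets):
    """Attach the network features of second t to the physical sample taken at t.

    Physical rows sampled before the first or after the last captured packet are
    dropped, since the network capture was not running then (incomplete samples).
    """
    net = aggregate_packets_per_second(packets)
    if not net:
        return []
    first, last = min(net), max(net)
    fused = []
    for row in physical_rows:
        if not first <= row["t"] <= last:
            continue  # outside the capture window
        fused.append({**row, **net.get(row["t"], ZERO_FEATURES)})
    return fused
```

Each fused row keeps its physical columns and gains the network columns of the same second, giving one cyber–physical sample per second.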

**Figure 2.** Changes of stage value and Tank\_1 water level during process cycle.

**Figure 3.** Additional network feature extraction and fusion architecture.

**Table 3.** Cyber–physical dataset after feature fusion.


#### *4.2. Problem Formulation*

In this paper, a dataset with *T* samples is treated as a multivariate time series *TS* of length *T*, where *x<sub>t</sub>* is the vector of all *m* physical and network features at time *t*:

$$TS = \{\mathbf{x}\_1, \mathbf{x}\_2, \dots, \mathbf{x}\_T\} (\mathbf{x}\_t \in \mathbb{R}^m, 1 \le t \le T) \tag{1}$$

To exploit the correlation between the current observation and previous ones, a time window *W<sub>t</sub>* is defined. It contains the *K* most recent observations: the current one and the *K* − 1 observations preceding it. The original time series *TS* is thereby transformed into a time window series *W*:

$$\mathcal{W} = \{\mathcal{W}\_1, \mathcal{W}\_2, \dots, \mathcal{W}\_T\}, \quad \mathcal{W}\_t = \{\mathbf{x}\_{t-K+1}, \dots, \mathbf{x}\_{t-1}, \mathbf{x}\_t\} \in \mathbb{R}^{K \times m}, \quad 1 \le t \le T \tag{2}$$
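A minimal sketch of Eq. (2), assuming each observation is a plain Python list of *m* feature values. The text does not specify how the first *K* − 1 time steps, which lack a full history, are handled. This version therefore simply starts at *t* = *K*. With the window size of 3 used in Section 5.5, a series of *T* observations yields *T* − 2 windows.

```python
def make_windows(ts, k):
    """Build the time window series of Eq. (2) from a list of m-dimensional observations.

    Window W_t holds x_{t-K+1}, ..., x_t (a K x m block). Assumption: windows are only
    formed for t >= K, so a series of length T yields T - K + 1 windows.
    """
    if not 1 <= k <= len(ts):
        raise ValueError("window size k must satisfy 1 <= k <= len(ts)")
    # 'end' is the exclusive slice bound, so ts[end - 1] is the current observation x_t
    return [ts[end - k:end] for end in range(k, len(ts) + 1)]
```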

The time window series *W*, rather than the raw time series *TS*, is used as the input to the model. Before the conversion into windows, each feature *j* of *TS* is min–max normalized:

$$TS^j = \left\{ \mathbf{x}\_1^j, \mathbf{x}\_2^j, \dots, \mathbf{x}\_T^j \right\}, \quad \mathbf{x}\_t^j \leftarrow \frac{\mathbf{x}\_t^j - \min\{TS^j\}}{\varepsilon + \max\{TS^j\} - \min\{TS^j\}}, \quad 1 \le j \le m,\ 1 \le t \le T \tag{3}$$

where *ε* is a small positive constant that prevents division by zero when a feature is constant.
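Eq. (3) can be sketched in a few lines. The value `eps=1e-8` is an assumption, since the text only calls *ε* "very small". Because *ε* > 0, every normalized value lies in [0, 1), and a constant feature maps to 0.

```python
def min_max_normalize(ts, eps=1e-8):
    """Apply Eq. (3): rescale each feature column j of TS using its own min and max.

    ts is a list of T observations, each a list of m values. eps (an assumed value)
    keeps the denominator non-zero when a feature is constant.
    """
    m = len(ts[0])
    mins = [min(x[j] for x in ts) for j in range(m)]
    maxs = [max(x[j] for x in ts) for j in range(m)]
    return [
        [(x[j] - mins[j]) / (eps + maxs[j] - mins[j]) for j in range(m)]
        for x in ts
    ]
```

Eq. (3) takes the minima and maxima over the series being normalized. The text does not state whether the training statistics were reused for the test acquisitions.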

#### *4.3. Proposed Model*

The proposed model consists of three modules: an LSTM encoder network *LE* and two LSTM decoder networks *LD*<sub>1</sub> and *LD*<sub>2</sub>. As shown in Figure 4, these modules form two LSTM-Autoencoders, *LAE*<sub>1</sub> and *LAE*<sub>2</sub>, which share the encoder network. The hyperparameters of the model are listed in Table 4. Training consists of two phases.

**Figure 4.** Proposed model architecture. The proposed model consists of three modules: an encoder network *LE* and two decoder networks *LD*<sub>1</sub> and *LD*<sub>2</sub>.



4.3.1. Phase 1—Input Reconstruction

The goal of this phase is to train *LAE*<sub>1</sub> and *LAE*<sub>2</sub> to reconstruct their input. Within a window, the observations are re-indexed as *W<sub>t</sub>* = {*x*<sub>1</sub>, ..., *x*<sub>*K*−1</sub>, *x*<sub>*K*</sub>}. The window *W<sub>t</sub>* is fed to the encoder network *LE*. The encoder outputs the hidden state *h<sub>K</sub>* ∈ R<sup>*n*</sup>, where *n* is the number of cells in the LSTM hidden layer. Each decoder is conditioned on *h<sub>K</sub>* and on *x<sub>K</sub>*, the last observation of *W<sub>t</sub>*. Each decoder then outputs a reconstruction of *W<sub>t</sub>* in reverse order, giving *O*<sub>1</sub> and *O*<sub>2</sub>. The reconstruction loss of each decoder is defined with the L2 norm:

$$O\_1 = LAE\_1(\mathcal{W}\_t), O\_2 = LAE\_2(\mathcal{W}\_t) \tag{4}$$

$$Loss1 = \|\mathcal{W}\_t - O\_1\|\_2, \quad Loss2 = \|\mathcal{W}\_t - O\_2\|\_2 \tag{5}$$

4.3.2. Phase 2—Adversarial Training

In the second phase, *LAE*<sub>1</sub> and *LAE*<sub>2</sub> are trained adversarially. The reconstruction *O*<sub>1</sub> is fed into *LAE*<sub>2</sub> again, which outputs the reconstruction *O*<sub>3</sub>. *LAE*<sub>2</sub> is trained to distinguish real data from reconstructions produced by *LAE*<sub>1</sub>. Conversely, *LAE*<sub>1</sub> is trained to fool *LAE*<sub>2</sub>, so that *LAE*<sub>2</sub> cannot tell whether its input is real data or the output of *LAE*<sub>1</sub>. The training objective is:

$$O\_3 = LAE\_2(LAE\_1(\mathcal{W}\_t)) \tag{6}$$

$$\min\_{LAE\_1} \max\_{LAE\_2} \|\mathcal{W}\_t - O\_3\|\_2 \tag{7}$$

Accordingly, *LAE*<sub>1</sub> aims to minimize the distance between *O*<sub>3</sub> and *W<sub>t</sub>*, while *LAE*<sub>2</sub> aims to maximize it. The losses are defined as follows:

$$Loss1 = +\|\mathcal{W}\_t - O\_3\|\_2, \quad Loss2 = -\|\mathcal{W}\_t - O\_3\|\_2 \tag{8}$$

Then, an evolutionary loss function, whose weights change as training proceeds, combines the losses of the two phases into the total loss of each LAE:

$$Loss1 = \frac{1}{n}\|\mathcal{W}\_t - O\_1\|\_2 + \left(1 - \frac{1}{n}\right)\|\mathcal{W}\_t - O\_3\|\_2 \tag{9}$$

$$Loss2 = \frac{1}{n}\|\mathcal{W}\_t - O\_2\|\_2 - \left(1 - \frac{1}{n}\right)\|\mathcal{W}\_t - O\_3\|\_2 \tag{10}$$

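where *n* is the number of training iterations performed so far (not the hidden size *n* above). The weight therefore shifts gradually from the reconstruction term to the adversarial term as training proceeds. Figure 5a illustrates the training process. The anomaly score is defined as: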

$$AnomalyScore = \frac{1}{2}\|\mathcal{W}\_t - O\_1\|\_2 + \frac{1}{2}\|\mathcal{W}\_t - O\_3\|\_2 \tag{11}$$
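The sketch below evaluates Eqs. (4)–(11) for a single window. It is a hedged illustration, not the authors' code. The two autoencoders are abstracted as plain callables; the names `lae1` and `lae2` are hypothetical stand-ins for trained PyTorch modules. Only the loss and score arithmetic is shown. In an actual implementation, presumably as in USAD, `loss1` would be minimized with respect to the parameters of *LAE*<sub>1</sub> and `loss2` with respect to those of *LAE*<sub>2</sub>; both parameter sets include the shared encoder.

```python
import math

def l2(a, b):
    """L2 norm of the difference between two windows given as lists of feature vectors."""
    return math.sqrt(sum((u - v) ** 2 for xa, xb in zip(a, b) for u, v in zip(xa, xb)))

def training_losses(window, lae1, lae2, n):
    """Per-window total losses of Eqs. (9) and (10) at training iteration n >= 1.

    lae1 and lae2 are placeholders for the two LSTM-Autoencoders: any callable that
    maps a window to a reconstruction of the same shape.
    """
    o1 = lae1(window)  # Eq. (4)
    o2 = lae2(window)  # Eq. (4)
    o3 = lae2(o1)      # Eq. (6): LAE_2 re-reconstructs LAE_1's output
    w = 1.0 / n        # weight of the phase-1 (reconstruction) term
    loss1 = w * l2(window, o1) + (1 - w) * l2(window, o3)  # LAE_1 minimizes the O_3 error
    loss2 = w * l2(window, o2) - (1 - w) * l2(window, o3)  # LAE_2 maximizes it
    return loss1, loss2

def anomaly_score(window, lae1, lae2):
    """Eq. (11): equal-weight sum of the O_1 and O_3 reconstruction errors."""
    o1 = lae1(window)
    o3 = lae2(o1)
    return 0.5 * l2(window, o1) + 0.5 * l2(window, o3)
```

At *n* = 1, only the phase-1 reconstruction terms contribute; as *n* grows, the adversarial *O*<sub>3</sub> term dominates both losses.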

After the training is completed, the model is used to calculate the anomaly scores for each time window in the normal dataset, and then a threshold is determined based on the distribution of the anomaly scores. During the testing phase, shown in Figure 5b, for each unseen time window, the trained model will output its anomaly score. When the anomaly score of a time window is higher than the threshold, the model judges it as an anomaly.
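The thresholding step can be sketched as below. The quantile rule and its 0.99 default are assumptions, since the text states only that the threshold is derived from the distribution of anomaly scores on normal data. Taking the maximum normal score would be another plausible choice.

```python
def choose_threshold(normal_scores, quantile=0.99):
    """Return the score at the given quantile of the anomaly scores on normal data.

    The quantile rule (and the 0.99 default) is an assumption made for illustration.
    """
    ordered = sorted(normal_scores)
    index = min(len(ordered) - 1, int(quantile * len(ordered)))
    return ordered[index]

def detect(scores, threshold):
    """Flag a time window as anomalous when its score is strictly above the threshold."""
    return [score > threshold for score in scores]
```

A higher quantile yields fewer false positives, at the cost of possibly more false negatives.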

**Figure 5.** Proposed model training and testing flow chart. (**a**) Training flow chart; (**b**) testing flow chart.

#### **5. Experiments and Results Analysis**

#### *5.1. Experiment Environment and Metrics*

The experiments were performed on the following hardware and software platform: Intel(R) Core(TM) i5-12400 CPU, Windows 10 Professional (64-bit), NVIDIA GeForce GTX 1650 Super, NVIDIA CUDA 11.1, Python 3.7.13, PyTorch 1.8.2, and scikit-learn 1.0.2.

The proposed model is evaluated using recall, precision, F1-score, and accuracy. *TP*, *TN*, *FP*, and *FN* represent true positive, true negative, false positive, and false negative, respectively.

$$Recall = \frac{TP}{TP + FN} \tag{12}$$

$$Precision = \frac{TP}{TP + FP} \tag{13}$$

$$\text{F1-score} = 2 \times \frac{Recall \times Precision}{Recall + Precision} \tag{14}$$

$$Accuracy = \frac{TP + TN}{TP + FP + FN + TN} \tag{15}$$
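Eqs. (12)–(15) translate directly into code. This sketch assumes label 1 marks an anomaly (the positive class) and 0 marks normal data. It also returns 0.0 wherever a denominator would be zero, a convention the text does not specify.

```python
def confusion_counts(y_true, y_pred):
    """Count TP, TN, FP, FN with 1 = anomaly (positive) and 0 = normal."""
    pairs = list(zip(y_true, y_pred))
    tp = sum(1 for t, p in pairs if t == 1 and p == 1)
    tn = sum(1 for t, p in pairs if t == 0 and p == 0)
    fp = sum(1 for t, p in pairs if t == 0 and p == 1)
    fn = sum(1 for t, p in pairs if t == 1 and p == 0)
    return tp, tn, fp, fn

def metrics(y_true, y_pred):
    """Recall, precision, F1-score and accuracy as in Eqs. (12)-(15)."""
    tp, tn, fp, fn = confusion_counts(y_true, y_pred)
    recall = tp / (tp + fn) if tp + fn else 0.0
    precision = tp / (tp + fp) if tp + fp else 0.0
    f1 = 2 * recall * precision / (recall + precision) if recall + precision else 0.0
    accuracy = (tp + tn) / len(y_true)
    return {"recall": recall, "precision": precision, "f1": f1, "accuracy": accuracy}
```

As a consistency check on Eq. (14), consider the precision of 0.800 and recall of 0.720 reported for the proposed model in Section 5.6. These give 2 × 0.800 × 0.720 / (0.800 + 0.720) ≈ 0.758, matching the reported F1 score.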

#### *5.2. Dataset*

The dataset needs to be divided differently for supervised and unsupervised algorithms.

#### 5.2.1. Dataset for Supervised Algorithms

In this paper, the datasets are organized chronologically, and each cyberattack or physical fault lasts for a period of time, corresponding to multiple consecutive samples. If the dataset is shuffled before being divided, some samples of an abnormal event end up in the training set and the rest in the testing set. The model then achieves a higher accuracy on the testing set, but this accuracy is illusory [36], because the test samples of an event closely resemble its training samples. Therefore, all samples of a given scenario are assigned either to the training set or to the testing set. Of the normal data, 85% are assigned to the training set and the rest to the testing set. For the anomaly scenarios, scenarios 1.1–1.6, 2.1–2.7, and 3.1–3.3 are assigned to the training set and the remaining scenarios to the testing set. The data are normalized with min–max scaling, and the resulting dataset is summarized in Table 5.
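A scenario-level split of this kind can be sketched as follows. The scenario sets come from the text. Representing samples as `(scenario_id, features)` pairs is a hypothetical layout, and assigning the *first* 85% of the normal samples in time order is an assumption, since the text does not say how the normal data were divided.

```python
# Scenarios assigned to the training set in the text: 1.1-1.6, 2.1-2.7 and 3.1-3.3.
TRAIN_SCENARIOS = (
    {f"1.{i}" for i in range(1, 7)}
    | {f"2.{i}" for i in range(1, 8)}
    | {f"3.{i}" for i in range(1, 4)}
)

def split_by_scenario(samples, train_scenarios=TRAIN_SCENARIOS, normal_train_frac=0.85):
    """Split samples so that every anomaly scenario lies entirely in train or in test.

    samples: chronologically ordered (scenario_id, features) pairs, with scenario_id
    None for normal operation. Taking the first 85% of normal samples is an assumption.
    """
    normal = [s for s in samples if s[0] is None]
    cut = int(normal_train_frac * len(normal))
    train = normal[:cut] + [s for s in samples if s[0] in train_scenarios]
    test = normal[cut:] + [
        s for s in samples if s[0] is not None and s[0] not in train_scenarios
    ]
    return train, test
```

Because membership is decided per scenario rather than per sample, no abnormal event contributes samples to both sets.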

**Table 5.** Dataset information for supervised algorithms.


#### 5.2.2. Dataset for Unsupervised Algorithms

The fourth acquisition (no anomalies) is used as the training set, and the other three acquisitions (with anomalies) are used as testing sets to evaluate the model.

#### *5.3. Experiments of Using Supervised Algorithms*

Three supervised machine learning algorithms were used for training and testing: random forest (RF), support vector machine (SVM), and naïve Bayes (NB). They were implemented with the *RandomForestClassifier*, *SVC*, and *GaussianNB* classes of the Python scikit-learn library, using the default hyperparameters of scikit-learn 1.0.2.

The experimental results are shown in Table 6. All three algorithms perform poorly when only physical features are used; even the best of them, RF, reaches an F1 score of only 0.28. With cyber–physical features, the performance of all three algorithms improves greatly, with F1 scores exceeding 0.87. These results show that the additionally extracted network features significantly improve the anomaly detection performance of supervised algorithms.


**Table 6.** Performance of three supervised machine learning algorithms.

#### *5.4. Experiments of Unsupervised Algorithms*

5.4.1. Performance of the Proposed Model

Two settings are considered: using only physical features, and using cyber–physical features. Table 7 shows the performance of the proposed model in both settings. Figures 6 and 7, respectively, show the anomaly scores the model assigns on the three test sets in each setting. Cyberattacks and physical faults are marked in red and blue, respectively, in the figures.


**Table 7.** Performance of the proposed model.

**Figure 6.** Anomaly scores for the proposed model (using only physical features).

When only physical features are used, the model performs poorly on all test sets. When the additionally extracted network features are included, performance improves greatly on every test set. We believe that physical features alone give poor results because some network attacks barely affect the physical state of the system, so the model fails to detect them. In the network attack scenarios, the anomaly scores the model assigns to anomalous time points are significantly higher than those for non-anomalous time points, indicating that network attack events are easy for the model to detect. In the physical fault scenarios, the anomaly scores of anomalous time points are less pronounced but still sufficient to detect most physical fault events.

**Figure 7.** Anomaly scores for the proposed model (using cyber–physical features).

Because the impact of an attack or fault on the system builds up gradually, the system may not deviate noticeably at the beginning, resulting in false negatives. Furthermore, the system may take some time to return to normal after an attack has ended, which may result in false positives. Scenario 1.6 is an example: it simulates the rise of the water level in Tank 3 caused by a leak in the pipeline. Figure 8a shows the water level in Tank 3 over time. Figure 8b shows the corresponding anomaly scores together with the period during which the fault occurred (scenario 1.6). At first, the rising water level was consistent with the tank's normal filling. During this period, the anomaly score did not exceed the threshold, and the model considered the period normal. As the fault persisted, the water level exceeded its normal level and continued to rise, so the anomaly score for this period gradually increased and exceeded the threshold. Once the fault was resolved, the water level began to decline, and the anomaly score declined with it. Nevertheless, the water level remained above its normal level for some time after the fault was resolved. The anomaly score therefore stayed above the threshold, and the model still considered the system abnormal.

Figure 9a shows the water levels of Tank 1 and Tank 5 over time, and Figure 9b shows the corresponding anomaly scores. The periods of the three fault scenarios are marked in red, green, and blue, respectively. Scenario 3.1 simulates a fault that pauses the transfer of water from Tank 1 to Tank 5. Scenario 3.3 simulates a fault by closing the Tank 5 outlet valve, which slows the outflow of water from Tank 5. Scenario 3.4 simulates a fault that suspends the transfer of water from the reservoir to Tank 1. None of these three faults caused the water level to exceed its normal level. As a result, none of the anomaly scores exceeded the threshold, the model considered the system to be in a normal state, and these faults went undetected.

**Figure 8.** Scenario 1.6. (**a**) Water level in Tank 3; (**b**) anomaly score.

**Figure 9.** Scenarios 3.1, 3.3, and 3.4. (**a**) Water level in Tank 1 and Tank 5; (**b**) anomaly score.

5.4.2. Comparison with Other Unsupervised Algorithms

This section compares the performance of OCSVM [12], Isolation Forest (iForest) [13], USAD [33], and the proposed model. USAD is implemented based on the authors' GitHub repository. OCSVM and iForest are taken from the Python scikit-learn library (*OneClassSVM* and *IsolationForest*) and use default parameters.

As shown in Table 8, the proposed model outperforms the other algorithms. OCSVM and iForest achieve high recall, but their many false positives lead to low F1 scores. USAD achieves a much higher F1 score than these two algorithms, but a lower recall. Low recall means more false negatives, that is, anomalies that go undetected, which is a critical weakness for an anomaly detection system. As shown in Figure 10, USAD detects most network attacks but almost none of the physical faults, whereas the proposed model detects most physical faults. These results show that the proposed model performs better on the ICS anomaly detection task.


**Table 8.** Performance comparison of the proposed model with other methods.

**Figure 10.** Anomaly scores for USAD (using network and physical features).

#### *5.5. Ablation Experiments*

In the ablation experiments, the LSTM autoencoder in the proposed model is replaced in turn by a standard autoencoder, a BiLSTM autoencoder, and a GRU autoencoder, whose hyperparameters are shown in Table 9. In addition, the adversarial training phase is removed from the proposed model; this variant is hereafter referred to as the proposed model with no adversarial training. All of these models share the same training settings:

- batch size: 32;
- window size: 3;
- optimizer: Adam, with a learning rate of 0.001;
- maximum number of epochs: 100;
- initial parameters: PyTorch 1.8.2 defaults.


**Table 9.** Hyperparameters of the three autoencoders.

#### *5.6. Discussion*

As shown in Tables 6 and 7, adding network features raises the accuracies of RF, SVM, and NB from 0.777, 0.763, and 0.690 to 0.958, 0.953, and 0.945, respectively. It also raises the F1 score, precision, recall, and accuracy of the proposed model from 0.425, 0.479, 0.382, and 0.686 to 0.758, 0.800, 0.720, and 0.860, respectively. This is because some network attacks, such as scanning attacks, only generate anomalous network traffic without substantially affecting the physical behavior of the system. Fusing network traffic data with physical sensor data therefore helps improve anomaly detection.

Compared with other unsupervised algorithms, the proposed unsupervised anomaly detection model performs better. As shown in Table 8, USAD achieves a recall of only 0.613 with the cyber–physical fusion features, whereas the proposed model raises the recall to 0.720, a relative improvement of about 17.5%. Figures 7 and 10 show that USAD assigns broadly similar anomaly scores to normal and abnormal data. This suggests that it does not reconstruct normal data well enough and therefore cannot clearly separate normal from abnormal samples. In contrast, the proposed model assigns markedly different anomaly scores to normal and abnormal data, indicating that it detects anomalies well.
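The figure of about 17.5% is the relative gain over USAD's recall:

$$\frac{0.720 - 0.613}{0.613} = \frac{0.107}{0.613} \approx 0.175$$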

Table 10 compares five models: the standard autoencoder, the BiLSTM autoencoder, the GRU autoencoder, the proposed model (LSTM autoencoder), and the proposed model with no adversarial training. Figure 11 shows the time each model needs for one training run. The standard autoencoder trains fastest and achieves the highest recall, but its precision and F1 score are the lowest, which means that it flags many normal samples as abnormal. The BiLSTM autoencoder takes more time to train but yields only a marginal improvement in performance. Training the GRU autoencoder takes slightly less time than training the proposed model, but its performance is much worse. It achieves a recall of 0.618, compared with 0.72 for the proposed model, and we believe this significant improvement justifies the small additional training time. As shown in Figure 12, with adversarial training the loss decreases earlier and reaches smaller values. Furthermore, removing adversarial training lowers the recall from 0.72 to 0.652. This indicates that adversarial training based on generative adversarial networks helps the model identify small anomalies by amplifying their reconstruction error.

**Figure 11.** Training time for proposed model and other models.

**Figure 12.** Loss with adversarial training and without adversarial training.


**Table 10.** Performance comparison of LSTM autoencoder with others.

#### **6. Conclusions**

Given the special characteristics of ICS networks, we designed a method to extract network features. Applying this method to the latest publicly available ICS dataset, we built an ICS cyber–physical dataset, and anomaly detection algorithms trained on the fused features perform better. In addition, we proposed an unsupervised anomaly detection method based on an LSTM-Autoencoder and a GAN. The ablation experiments show that the LSTM autoencoder is the best choice among the tested architectures and that GAN-based adversarial training helps the model detect more anomalies.

This paper uses an ICS dataset acquired with only the Modbus TCP protocol, whereas other protocols such as S7 and EtherNet/IP are also widely used in industry. Our future work will investigate more effective and more broadly compatible methods for detecting ICS anomalies, based on a more comprehensive dataset.

**Author Contributions:** Conceptualization, Y.D. and Y.H.; funding acquisition, Y.H. and G.W.; methodology, Y.D. and Y.H.; resources, G.W.; software, Y.D.; validation, Y.H., G.W., and P.H.; writing—original draft, Y.D.; writing—review and editing, G.W. and P.H. All authors have read and agreed to the published version of the manuscript.

**Funding:** This research was supported by Natural Science Foundation of Sichuan Province (No. 2022NSFSC0557), Foundation of Sichuan Network Culture Research Center (No. WLWH22-18), National Natural Science Foundation of China (No. 62102049, No. 62076042), and Key Research and Development Project of Sichuan (No. 2021YFSY0012, No. 2021YFG0332).

**Data Availability Statement:** The data presented in this study are openly available in the IEEE DataPort at [10.21227/rbvf-2h90] accessed on 20 November 2022, reference number [37].

**Conflicts of Interest:** The authors declare no conflict of interest.

#### **References**


### *Article* **Deep Large-Margin Rank Loss for Multi-Label Image Classification**

**Zhongchen Ma <sup>1,2,\*,†</sup>, Zongpeng Li <sup>1,2,†</sup> and Yongzhao Zhan <sup>1,2</sup>**


**Abstract:** The large-margin technique has served as the foundation of several successful theoretical and empirical results in multi-label image classification. However, most large-margin techniques are only suitable for shallow multi-label models with preset feature representations. The few large-margin techniques for neural networks only enforce margins at the output layer, which makes them ill-suited to deep networks. Building on the large-margin technique, a deep large-margin rank loss function is proposed. It can impose a margin on any chosen set of layers of a deep network, allows any *ℓ<sub>p</sub>* norm (*p* ≥ 1) to be chosen as the metric measuring the margin between labels, and is applicable to any network architecture. The complete computation of the deep large-margin rank loss has O(*C*<sup>2</sup>) time complexity, where *C* denotes the size of the label set, which would cause scalability issues when *C* is large. A negative sampling technique is therefore proposed to make the loss function scale linearly in *C*. Experimental results on two large-scale datasets, VOC2007 and MS-COCO, show that the deep large-margin rank loss improves the robustness of the model in multi-label image classification tasks while enhancing its anti-noise performance.

**Keywords:** image classification; large-margin technique; deep neural network; robustness; anti-noise performance

**MSC:** 68T01

#### **1. Introduction**

Multi-label image classification (MLiC), one of the most important problems in computer vision, aims to predict the set of visual concepts present in an image. It has numerous real-world applications, such as scene recognition [1,2] and medical diagnosis [3,4]. Single-class or multi-class image classification associates each image with a single label from a set of disjoint class labels. In contrast, MLiC allows an image to be associated with more than one class label. MLiC is thus more general and realistic than these tasks, and this generality also makes it more difficult.

Two main approaches have been used to cope with this task. The first, called problem transformation, transforms the multi-label learning problem into several binary or multi-class classification problems. Representative algorithms include binary relevance [5] and random k-labelsets [6]. The second, called algorithm adaptation, adapts popular learning techniques to handle multi-label data directly. Representative algorithms include ML-kNN [7] and Rank-SVM [8]. Conventionally, most of these methods use handcrafted features for image classification, such as SIFT [9], histograms of oriented gradients [10], and local binary patterns [11]. Such limited feature representations may restrict the performance of traditional methods in multi-label image classification tasks.

**Citation:** Ma, Z.; Li, Z.; Zhan, Y. Deep Large-Margin Rank Loss for Multi-Label Image Classification. *Mathematics* **2022**, *10*, 4584. https://doi.org/10.3390/math10234584

Academic Editor: Jakub Nalepa

Received: 27 October 2022; Accepted: 29 November 2022; Published: 3 December 2022

**Publisher's Note:** MDPI stays neutral with regard to jurisdictional claims in published maps and institutional affiliations.

**Copyright:** © 2022 by the authors. Licensee MDPI, Basel, Switzerland. This article is an open access article distributed under the terms and conditions of the Creative Commons Attribution (CC BY) license (https://creativecommons.org/licenses/by/4.0/).

Motivated by the success of deep neural networks, some approaches combine deep representation learning and multi-label learning into an end-to-end trainable system. If the original multi-label classification problem is divided into multiple independent binary classification tasks, convolutional neural networks (CNNs) can be applied naturally. However, such methods ignore label correlations. This limitation has motivated research into deep learning methods that capture and exploit label correlations, of which RNN-CNN [12] and ML-GCN [13] are two typical representatives. Some recent approaches also explore label correlations: ref. [14] designed a label correlation term defined on some anchor data, and ref. [15] proposed a novel framework with local feature selection and local label correlation.

For simplicity, most deep MLiC classifiers are trained with the binary cross-entropy (BCE) loss. Training such a deep multi-label image classifier requires clean multi-label annotations for a large number of images, which are costly or even impossible to collect in real-world applications. Moreover, even slight label perturbations may degrade the performance of traditional deep MLiC classifiers. The large-margin technique, which maximizes the distance from each training point to the decision boundary, can effectively mitigate this problem [16]. Specifically, suppose the classifier attains a margin of *γ*, that is, the decision boundary is at least *γ* away from all training images. Then no input perturbation smaller than *γ* can flip the predicted label of a training image. For deep MLiC classifiers, the margin is conventionally defined on the output values. However, the input-space margin is often of more practical interest; for example, a large margin in the input space implies robustness to input perturbations. Unfortunately, the margin in the input space is computationally intractable for deep MLiC classifiers.
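The margin guarantee above can be illustrated in the simplest case, a linear score f(x) = w·x + b. The *ℓ<sub>p</sub>* distance from x to the boundary f = 0 is |f(x)| / ‖w‖<sub>q</sub>, where q is the dual exponent (1/p + 1/q = 1). By Hölder's inequality, any perturbation δ with ‖δ‖<sub>p</sub> below that distance cannot change the sign of f. This sketch illustrates a standard textbook result; it is not the paper's formulation. Deep classifiers have no such closed form, which is why the input-space margin is hard to compute for them.

```python
import math

def lp_norm(v, p):
    """l_p norm of a vector for p >= 1; p = math.inf gives the max norm."""
    if p == math.inf:
        return max(abs(x) for x in v)
    return sum(abs(x) ** p for x in v) ** (1.0 / p)

def dual_exponent(p):
    """Exponent q with 1/p + 1/q = 1 (p = 1 -> q = inf, p = inf -> q = 1)."""
    if p == 1:
        return math.inf
    if p == math.inf:
        return 1.0
    return p / (p - 1.0)

def linear_margin(w, b, x, p):
    """l_p distance from x to the hyperplane w.x + b = 0, i.e. |w.x + b| / ||w||_q."""
    f = sum(wi * xi for wi, xi in zip(w, x)) + b
    return abs(f) / lp_norm(w, dual_exponent(p))
```

For example, with w = (3, 4), b = −1, and x = (1, 1), the score is 6. The margins are therefore 6/5 = 1.2 under *ℓ*<sub>2</sub>, 6/4 = 1.5 under *ℓ*<sub>1</sub>, and 6/7 under *ℓ*<sub>∞</sub>. Moving x toward the boundary by 99% of the *ℓ*<sub>2</sub> margin keeps the sign, while moving it by 101% flips it.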

To address these issues, a novel deep large-margin rank loss function (DlmRl) for the MLiC task is proposed. DlmRl treats the activations of each intermediate layer of a deep MLiC classifier as an intermediate representation of the image. This allows it to impose a margin on any chosen set of layers of a deep network. The margin between labels can be measured with any *ℓ<sub>p</sub>* norm (*p* ≥ 1), which applies to any network architecture and makes the loss more practical. The complete computation of DlmRl has O(*C*<sup>2</sup>) time complexity, where *C* denotes the size of the label set. We therefore propose a negative sampling technique that makes the loss function scale linearly in *C*. Experimental results on VOC2007 and MS-COCO show the effectiveness of our approach. Our contributions are three-fold:

(1) We design a novel deep large-margin rank loss for multi-label image classification tasks that can be applied at any layers of a deep network. Its implementation is flexible and compatible with different architectures, which enhances the universality of the deep network;

(2) The proposed method measures the margin with an arbitrary *ℓ<sub>p</sub>* norm (*p* ≥ 1), which makes the margin measurable. This metric enhances the controllability of the labels, improves the confidence of the label data, and therefore strengthens the comprehensibility and trustworthiness of the deep network.

(3) We propose a negative sampling technique for the large-margin loss in multi-label image classification tasks. It reduces the computational cost of DlmRl from quadratic to linear in the number of labels *C*, thereby improving its efficiency.

#### **2. Related Works**

#### *2.1. Multi-Label Image Classification*

Deep convolutional neural networks have made great progress on the MLiC task. Some works embed label dependencies into the deep model to improve MLiC accuracy. A popular approach uses recurrent neural networks (RNNs) [17] or long short-term memory (LSTM) [18] to model the label dependencies; however, its performance depends on the label order. Recent works use graph neural networks (GNNs) to model label dependencies explicitly. For example, the works [13,19,20] use GNNs to propagate the dependencies and learn inter-dependent classifiers.

Some works focus on learning deep attentional representations for each label by treating an image as multiple images sampled from different regions. For example, ref. [21] introduced a max pooling layer that hypothesizes the possible locations of a label in an image. Ref. [22] studied capturing the proximity and geometric structure of k-nearest neighbors. Ref. [23] combined global average pooling with class activation maps to give CNNs localization ability. Ref. [24] proposed a new activation function that outputs sparse probabilities for each label. Ref. [25] generated class-specific features for every category through a simple spatial attention score. Ref. [26] unites similarity-based learning and generalized linear models to achieve the best of both worlds.

Recent works exploit the label noise property of the multi-label problem. For example, ref. [27] proposed a robust logistic loss function to train CNNs from user-provided tags. Ref. [28] exploited the potential connections between noisy labels and feature contents to identify noisy labels. Ref. [29] proposed a curriculum learning strategy to predict missing labels. Ref. [30] handled training data with noisy labels through a loss function that measures the smoothness of labels and image features on the data manifold. Although these methods achieve good performance, they all add specific noisy-label-processing terms to a traditional multi-label loss function, e.g., the BCE with logits loss (bce) [31]. In this paper, we aim to propose a plug-and-play loss that performs well on MLiC tasks and is also robust to label noise.

#### *2.2. Large-Margin Classification*

The large-margin technique plays a key role in many machine learning algorithms. Traditional large-margin algorithms are designed for shallow models and have good interpretability. Support vector machine (SVM) [32] is a well-known large-margin technique, which tries to separate the training examples of different classes with a maximized margin. The margin provides good support to the generalization performance of SVM and has also been extended to interpret the good generalization of many other learning algorithms, such as AdaBoost [33].

In the context of deep neural networks, the large-margin technique has also shown promise. Ref. [34] encouraged large-margin solutions of the cross-entropy loss through additional terms; however, these terms encourage margins only at the output layer of a deep neural network. Ref. [35] demonstrated that deep networks can attain a max-margin solution with their proposed regularizer; however, the regularizer may not be robust to deviations in the data. Ref. [16] formulated a loss function that directly maximizes the margin at any layer, including the input, hidden and output layers. Its formulation is general with respect to margin definitions under different distance metrics (e.g., the ℓ1, ℓ2 and ℓ∞ norms) and is thus relatively robust to data disturbances. Inspired by this large-margin loss formulation, we propose a large-margin rank loss for the MLiC task that inherits these properties, and we show its effectiveness on large-scale MLiC datasets.

#### **3. Method**

#### *3.1. Notations*

The goal of the MLiC task is to find all labels of an image. Suppose we have $N$ training images $I\_1, \ldots, I\_N$ and observe their label vectors $\{\mathbf{y}^i\}\_{i=1}^{N}$, where $\mathbf{y}^k = [y^k\_1, \ldots, y^k\_C] \in \mathcal{Y} \subseteq \{-1, 1\}^C$ and $C$ denotes the number of labels. For a given image $I\_k$ and label $c$, $y^k\_c = 1$ (resp. $-1$) indicates the presence (resp. absence) of label $c$ in image $I\_k$. Let $P\_k$ and $N\_k$ denote the sets of positive and negative labels in $\mathbf{y}^k$.

#### *3.2. Large-Margin Ranking Loss*

This task can be cast as learning a deep prediction model $f(I; \theta) \in \mathbb{R}^C$ with parameters $\theta$ by solving the following optimization problem [36]:

$$\min\_{\theta} \frac{1}{N} \sum\_{k=1}^{N} l\left(f(I\_k; \theta), \mathbf{y}^k\right) + \mathcal{R}(\theta) \tag{1}$$

where $l\left(f(I\_k; \theta), \mathbf{y}^k\right)$ is a loss function and $\mathcal{R}(\theta)$ is a regularization term. Let $f^i\_c$ denote the prediction score of the deep network for classifying image $I\_i$ as label $c$.

The multi-label pairwise ranking loss aims to produce a score vector for image $I\_k$ whose values for the positive labels $P\_k$ are greater than those for the negative labels $N\_k$, i.e., $f\_u(I\_k) > f\_v(I\_k)$ for all $u \in P\_k$ and $v \in N\_k$:

$$\ell\_{\text{rank}} = \sum\_{v \in N\_k} \sum\_{u \in P\_k} \max\left(0, \alpha + f\_v(I\_k) - f\_u(I\_k)\right) \tag{2}$$

where *α* is a hyper-parameter that determines the margin, commonly set to 1 [31].
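For concreteness, here is a minimal sketch of Eq. (2) in plain Python for a single image. It assumes labels are encoded in $\{-1, 1\}$ as in Section 3.1 and that the scores $f\_c(I\_k)$ are given as a list; the helper names `split_labels` and `pairwise_rank_loss` are hypothetical, not taken from the authors' code.

```python
def split_labels(y):
    """Return (P_k, N_k): indices of the positive (+1) and negative (-1) labels."""
    pos = [c for c, v in enumerate(y) if v == 1]
    neg = [c for c, v in enumerate(y) if v == -1]
    return pos, neg


def pairwise_rank_loss(scores, y, alpha=1.0):
    """Eq. (2): hinge penalty on every (positive u, negative v) pair of one image."""
    pos, neg = split_labels(y)
    return sum(max(0.0, alpha + scores[v] - scores[u]) for v in neg for u in pos)
```

For `scores = [2.5, 0.2, 0.9]` and `y = [1, -1, 1]`, only the pair (label 2, label 1) falls inside the unit margin, so the loss is about 0.3; the pair (label 0, label 1) already clears the margin and contributes nothing.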

Although the pairwise ranking loss has achieved state-of-the-art results on various MLiC benchmarks, it only encourages margins at the output layer of a deep neural network. We argue that a margin in the input space is more robust to input perturbations and is thus often of greater practical interest.

Specifically, an MLiC model is robust to a perturbation $\delta$ of image $I\_k$ if $\operatorname{sign}(f\_v(I\_k) - f\_u(I\_k)) = \operatorname{sign}(f\_v(I\_k + \delta) - f\_u(I\_k + \delta))$ for all $u \in P\_k$ and $v \in N\_k$, where $\operatorname{sign}(\cdot)$ returns the sign (positive or negative) of its argument; that is, the perturbation does not change the relative order of any positive-negative label pair. Figure 1 illustrates this goal.
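The robustness condition above can be checked directly for one given perturbation. The sketch below (a hypothetical helper `order_preserved`, plain Python, labels in $\{-1, 1\}$) compares the sign of $f\_v - f\_u$ before and after the perturbation for every positive-negative pair. It tests only the single $\delta$ whose scores are supplied; it is not a guarantee over all perturbations.

```python
def order_preserved(clean_scores, perturbed_scores, y):
    """True if sign(f_v - f_u) is unchanged for every u in P_k and v in N_k."""
    def sign(x):
        return (x > 0) - (x < 0)  # +1, 0 or -1

    pos = [c for c, v in enumerate(y) if v == 1]
    neg = [c for c, v in enumerate(y) if v == -1]
    return all(
        sign(clean_scores[v] - clean_scores[u])
        == sign(perturbed_scores[v] - perturbed_scores[u])
        for u in pos
        for v in neg
    )
```

With clean scores `[3.0, 1.0, 2.5]` and `y = [1, -1, 1]`, the perturbed scores `[2.0, 1.5, 1.8]` keep every positive label above the negative one, whereas `[2.0, 1.9, 1.7]` push label 2 below label 1 and fail the check.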

**Figure 1.** The left side shows the prediction scores of the clean image ($I\_k$) produced by the prediction model, and the right side shows the scores of the perturbed image ($I\_k + \delta$) produced by the same model. The positive labels of the clean image are umbrella, rain coat, car and person; the negative labels are trunk and sunglasses. We want the model to be robust to the perturbation: for example, car is a positive label, and after the perturbation is added, its predicted score is still higher than those of the negative labels.

To this end, we propose a deep large-margin ranking loss for MLiC, i.e., DlmRl, which can impose a margin on any chosen set of layers of a deep network, allows any $\ell\_p$ norm ($p \geq 1$) to be used as the metric measuring the margin between labels, and is applicable to any network architecture. We define the ranking boundary between any pair of labels $\{u, v\}$, where $u \in P\_k$ and $v \in N\_k$, as

$$\mathcal{D}\_{\{u, v\}} \triangleq \left\{ I\_k \mid f^k\_u = f^k\_v \right\} \tag{3}$$

Under this definition, the distance of an image *Ik* to the ranking threshold is defined as the smallest displacement of the point that results in a score tie:

$$\begin{aligned} d\_{f, I\_k, \{u, v\}} & \triangleq \min\_{\delta} \lVert \delta \rVert\_p \\ \text{s.t.} \quad & f\_u(I\_k + \delta) = f\_v(I\_k + \delta) \end{aligned} \tag{4}$$

Because the exact computation of $d$ is intractable when the $f$s are nonlinear, ref. [16] approximated $d$ by linearizing $f$ with respect to $\delta$ around $\delta = 0$:

$$\begin{aligned} \tilde{d}\_{f, I\_k, \{u, v\}} & \triangleq \min\_{\delta} \lVert \delta \rVert\_p \\ \text{s.t.} \quad & f\_u^k + \left\langle \delta, \nabla\_{I\_k} f\_u^k \right\rangle = f\_v^k + \left\langle \delta, \nabla\_{I\_k} f\_v^k \right\rangle \end{aligned} \tag{5}$$

According to [16], this problem has the following closed-form solution:

$$\tilde{d}\_{f, I\_k, \{u, v\}} = \frac{\left| f\_u^k - f\_v^k \right|}{\left\lVert \nabla\_{I\_k} f\_u^k - \nabla\_{I\_k} f\_v^k \right\rVert\_q} \tag{6}$$

where $\lVert \cdot \rVert\_q$ is the dual norm of $\lVert \cdot \rVert\_p$, with $1/p + 1/q = 1$. Specifically, if distances are measured with respect to the $\ell\_1$, $\ell\_2$ or $\ell\_\infty$ norm, the dual norms are the $\ell\_\infty$, $\ell\_2$ or $\ell\_1$ norm, respectively.
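As a sanity check of Eq. (6), consider a linear scorer $f\_c(x) = \langle w\_c, x \rangle + b\_c$, assumed here purely for illustration. Its gradient is the constant $w\_c$, so the linearization in Eq. (5) is exact and Eq. (6) gives the true distance to the ranking boundary. The sketch computes it for the three norms mentioned above; `dual_norm` and `linear_rank_distance` are hypothetical helpers.

```python
import math


def dual_norm(vec, p):
    """||vec||_q, where q is the dual exponent of p (1/p + 1/q = 1)."""
    if p == 1:             # dual of l1 is l_inf
        return max(abs(a) for a in vec)
    if p == math.inf:      # dual of l_inf is l1
        return sum(abs(a) for a in vec)
    q = p / (p - 1)        # general case, e.g. p = 2 gives q = 2
    return sum(abs(a) ** q for a in vec) ** (1.0 / q)


def linear_rank_distance(x, w_u, b_u, w_v, b_v, p=2):
    """Eq. (6) for f_c(x) = <w_c, x> + b_c, whose gradient w.r.t. x is w_c."""
    f_u = sum(a * b for a, b in zip(w_u, x)) + b_u
    f_v = sum(a * b for a, b in zip(w_v, x)) + b_v
    grad_diff = [a - b for a, b in zip(w_u, w_v)]
    return abs(f_u - f_v) / dual_norm(grad_diff, p)
```

For `x = [1.0, 1.0]`, `w_u = [3.0, 4.0]`, `w_v = [0.0, 0.0]` and zero biases, the score gap is 7, so the distance is 7/5 = 1.4 under $\ell\_2$, 7/4 = 1.75 under $\ell\_1$ (dual $\ell\_\infty$) and 7/7 = 1 under $\ell\_\infty$ (dual $\ell\_1$).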

Consider a triplet $(I\_k, u, v)$. We penalize the displacement of $I\_k$ so as to satisfy the margin constraint $f^k\_u > f^k\_v$, which leads to the following loss function:

$$\max\left\{0, \gamma + d\_{f, I\_k, \{u, v\}} \operatorname{sign}\left(f\_v^k - f\_u^k\right)\right\} \tag{7}$$

where $\operatorname{sign}(\cdot)$ adjusts the polarity of the distance. The intuition is as follows. If the constraint $f^k\_u > f^k\_v$ is already satisfied, we only want to ensure that $I\_k$ lies at distance $\gamma$ from the ranking threshold, and we penalize in proportion to the amount by which the distance $d\_{f, I\_k, \{u, v\}}$ falls short of $\gamma$, so the penalty is $\max\{0, \gamma - d\}$. If the constraint is not satisfied, we also want to penalize the pair for not being correctly ranked. Hence, the penalty includes the distance that $I\_k$ needs to travel to reach the ranking threshold, plus another $\gamma$ to travel on the correct side of the threshold to attain the $\gamma$ margin, so the penalty becomes $\max\{0, \gamma + d\}$. For image $I\_k$, we aggregate the individual losses arising from each $u \in P\_k$ and $v \in N\_k$ to obtain the DlmRl formulation, i.e.,

$$\ell\_{\text{DlmRl}} = \sum\_{u \in P\_k, v \in N\_k} \max \left\{ 0, \gamma + d\_{f, I\_k, \{u, v\}} \operatorname{sign} \left( f\_v^k - f\_u^k \right) \right\} \tag{8}$$
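A worked instance of the two cases of Eq. (7), using illustrative numbers rather than values from the experiments: take $\gamma = 1$ and a pair whose distance to the ranking threshold is $d = 0.3$.

$$\begin{aligned} f\_u^k > f\_v^k: & \quad \max\left\{0,\ 1 + 0.3 \cdot (-1)\right\} = 0.7 \\ f\_u^k < f\_v^k: & \quad \max\left\{0,\ 1 + 0.3 \cdot (+1)\right\} = 1.3 \end{aligned}$$

A correctly ranked pair stops contributing once $d \geq \gamma$, whereas a misranked pair always pays at least $\gamma$.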

Plugging (6) into (8), the loss function becomes:

$$\sum\_{u \in P\_k, v \in N\_k} \max \left\{ 0, \gamma + \frac{\left| f\_u^k - f\_v^k \right| \text{sign} \left( f\_v^k - f\_u^k \right)}{\left\| \nabla\_{I\_k} f\_u^k - \nabla\_{I\_k} f\_v^k \right\|\_q} \right\} \tag{9}$$

This further simplifies into the following loss formulation:

$$\sum\_{u \in P\_k, v \in N\_k} \max \left\{ 0, \gamma + \frac{f\_v^k - f\_u^k}{||\nabla\_{I\_k} f\_u^k - \nabla\_{I\_k} f\_v^k||\_q} \right\} \tag{10}$$
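The step from Eq. (9) to Eq. (10) uses the identity $|a| \operatorname{sign}(-a) = -a$, which holds for $a > 0$, $a < 0$ and $a = 0$ alike. With $a = f\_u^k - f\_v^k$, the numerator becomes

$$\left| f\_u^k - f\_v^k \right| \operatorname{sign}\left( f\_v^k - f\_u^k \right) = -\left( f\_u^k - f\_v^k \right) = f\_v^k - f\_u^k$$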

In deep networks, the activations at each intermediate layer can be interpreted as an intermediate representation of the data. To enforce a large margin throughout the network rather than only at its output, the loss can be defined on any intermediate representation together with the final ranking thresholds.

Thus, the loss formulation (10) can impose a margin on any chosen set of layers of a deep network (including input and hidden layers) by replacing the input with its intermediate representations. It can be adapted as below to incorporate intermediate margins:

$$\sum\_{u \in P\_k, v \in N\_k} \max \left\{ 0, \gamma + \frac{f\_v^k - f\_u^k}{\epsilon + \left\| \nabla\_{h\_l} f\_u^k - \nabla\_{h\_l} f\_v^k \right\|\_q} \right\} \tag{11}$$

where $h\_l$ denotes the output of the $l$th layer ($h\_0 = I$), $\gamma$ (written $\gamma\_l$ when a different margin is used for each layer) is the margin enforced on the corresponding representation, and $\epsilon$ is a small constant that prevents numerical problems.
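A minimal sketch of Eq. (11) for one image, under a simplifying assumption: the margin is imposed on the features $h$ that feed a final linear layer $f\_c(h) = \langle W\_c, h \rangle + b\_c$, so $\nabla\_h f\_c = W\_c$ in closed form. For a margin on a deeper block, the gradients would instead come from automatic differentiation (for example `torch.autograd.grad(..., create_graph=True)` in PyTorch, so that the loss itself stays differentiable). The function names are hypothetical, and `q` is the dual exponent appearing in Eq. (11).

```python
import math


def q_norm(vec, q):
    """||vec||_q for q >= 1, including q = inf."""
    if q == math.inf:
        return max(abs(a) for a in vec)
    return sum(abs(a) ** q for a in vec) ** (1.0 / q)


def dlmrl_linear_head(h, W, b, y, gamma=1.0, q=math.inf, eps=1e-6):
    """Eq. (11) for one image when scores are a linear head on features h.

    f_c(h) = <W[c], h> + b[c], so grad_h f_c = W[c] exactly (no autograd needed).
    Labels y are in {-1, 1}.
    """
    scores = [sum(w * x for w, x in zip(W[c], h)) + b[c] for c in range(len(W))]
    pos = [c for c, v in enumerate(y) if v == 1]
    neg = [c for c, v in enumerate(y) if v == -1]
    loss = 0.0
    for u in pos:
        for v in neg:
            grad_diff = [wu - wv for wu, wv in zip(W[u], W[v])]
            # signed, gradient-normalized distance to the ranking threshold
            dist = (scores[v] - scores[u]) / (eps + q_norm(grad_diff, q))
            loss += max(0.0, gamma + dist)
    return loss
```

With `h = [1.0, 1.0]`, `W = [[3.0, 4.0], [0.0, 0.0]]`, zero biases, `y = [1, -1]`, `gamma = 2.0`, `q = 2` and `eps = 0.0`, the signed distance is $-7/5 = -1.4$ and the loss is about 0.6; with `q = math.inf` the denominator is 4 and the loss is 0.25.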

#### *3.3. Negative Sampling*

The complete calculation of the loss involves $|P\_k| \times |N\_k|$ pairwise comparisons and thus has $O(C^2)$ time complexity, which can cause scalability issues when $C$ is large. To make the loss scale linearly in $C$, we sample at most $t$ pairs from the Cartesian product. Denoting this sample by $\Phi(I\_k; t) \subseteq P\_k \times N\_k$, the DlmRl loss formulation becomes

$$\sum\_{(u, v) \in \Phi(I\_k; t)} \max \left\{ 0, \gamma + \frac{f\_v^k - f\_u^k}{\epsilon + \left\lVert \nabla\_{h\_l} f\_u^k - \nabla\_{h\_l} f\_v^k \right\rVert\_q} \right\} \tag{12}$$

We set *t* = 100 by default, which achieves better performance in most cases.
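The text does not specify how the $t$ pairs are drawn. The sketch below assumes uniform sampling without replacement from $P\_k \times N\_k$; it samples flat indices of the product with `random.sample` instead of materializing all $|P\_k| \times |N\_k|$ pairs, so the work per image is $O(C + t)$. `sample_pairs` is a hypothetical helper name.

```python
import random


def sample_pairs(y, t=100, rng=None):
    """Return at most t (u, v) pairs drawn without replacement from P_k x N_k."""
    rng = rng if rng is not None else random.Random()
    pos = [c for c, v in enumerate(y) if v == 1]
    neg = [c for c, v in enumerate(y) if v == -1]
    n_pairs = len(pos) * len(neg)
    # Sample flat indices of the Cartesian product without building it.
    idx = range(n_pairs) if n_pairs <= t else rng.sample(range(n_pairs), t)
    return [(pos[i // len(neg)], neg[i % len(neg)]) for i in idx]
```

When an image has no more than $t$ positive-negative pairs, all of them are kept, so the sampled loss coincides with Eq. (11).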

#### **4. Discussion**

We evaluate our method on the VOC2007 [37] and the MS-COCO [38] datasets. For each dataset, we use the standard training/test sets. To evaluate the performances, we show the results for the mean average precision (MAP) [39] and the instance-centric mean average precision (MiAP), which are standard multi-label classification metrics. We compare our DlmRl loss against different loss functions in three scenarios: (a) full-image labels, where only a subset of the images are labeled, but the labeled images have the annotations for all the categories; (b) partial labels [29], where all the images are used but a subset of images only have one positive label; (c) noisy labels [40], where the categories of all images are labeled but some labels are wrong. The experiments are carried out on a single NVIDIA V100 GPU.
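To make the two metrics concrete, here is a minimal sketch. It assumes MAP averages the average precision (AP) of each label's ranking of the test images, and MiAP averages the AP of each image's ranking of its labels, with non-interpolated AP and binary $\{0, 1\}$ ground truth; labels or images without any positive are skipped. The official VOC2007 development kit uses an 11-point interpolated AP, so numbers from this sketch may differ slightly. Function names are hypothetical.

```python
def average_precision(scores, targets):
    """Non-interpolated AP of one ranking; targets are 1 (relevant) or 0.

    Returns None when there is no positive target, so callers can skip it.
    """
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    hits, precisions = 0, []
    for rank, i in enumerate(order, start=1):
        if targets[i] == 1:
            hits += 1
            precisions.append(hits / rank)  # precision at each relevant rank
    return sum(precisions) / len(precisions) if precisions else None


def _mean(values):
    values = [v for v in values if v is not None]
    return sum(values) / len(values)


def map_score(S, Y):
    """Class-centric MAP: rank images for each label, then average over labels."""
    n_labels = len(S[0])
    return _mean(
        average_precision([row[c] for row in S], [row[c] for row in Y])
        for c in range(n_labels)
    )


def miap_score(S, Y):
    """Instance-centric MiAP: rank labels for each image, then average over images."""
    return _mean(average_precision(s, y) for s, y in zip(S, Y))
```

Here `S[k][c]` is the score of label `c` for image `k` and `Y[k][c]` its ground truth; multiply by 100 to obtain the percentages plotted in the figures.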

#### *4.1. Implementation Details and Baselines*

All the deep models used in our experiments are implemented in PyTorch. ResNet-101 is employed as our classification network; it is initialized with weights pretrained on ImageNet for single-label image classification, and the weights of all layers are fine-tuned. Note that we deliberately use a standard CNN rather than more advanced frameworks in order to focus on the advantages of DlmRl rather than to show state-of-the-art results. We train the model with a stochastic gradient descent (SGD) optimizer with an initial learning rate of 0.1. When the validation loss stops decreasing for 5 epochs, the learning rate decays to one tenth of its value. We stop training when the learning rate drops to 0.0001, which takes fewer than 20 epochs in most cases.
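The learning-rate schedule can be simulated in a few lines of plain Python with a hypothetical helper `plateau_schedule`. It assumes that "stops decreasing for 5 epochs" means five consecutive epochs without a new best validation loss. PyTorch provides a comparable built-in, `torch.optim.lr_scheduler.ReduceLROnPlateau` (e.g., with `factor=0.1` and `patience=5`), whose patience counting and improvement threshold may shift the reduction points by an epoch relative to this sketch.

```python
def plateau_schedule(val_losses, lr=0.1, factor=0.1, patience=5, min_lr=1e-4):
    """Simulate the schedule described above.

    Returns (learning rate used at each epoch, number of epochs trained).
    """
    best, bad_epochs, lrs = float("inf"), 0, []
    for epoch, loss in enumerate(val_losses):
        lrs.append(lr)
        if loss < best:                      # validation loss improved
            best, bad_epochs = loss, 0
        else:
            bad_epochs += 1
        if bad_epochs >= patience:           # no improvement for `patience` epochs
            lr, bad_epochs = lr * factor, 0
            if lr <= min_lr * (1 + 1e-9):    # tolerance for float rounding
                return lrs, epoch + 1        # stop once lr reaches min_lr
    return lrs, len(val_losses)
```

With a constant validation loss, this sketch trains for 16 epochs (0.1 → 0.01 → 0.001 → 0.0001), the shortest possible run under its counting.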

Since our loss function is designed for a variety of multi-label scenarios, a fair comparison is against traditional loss functions without complex regularization terms. In the experiments, we therefore compare our DlmRl loss against two classic loss formulations, i.e., BCE with Logits Loss (bce) [31] and MultiLabel SoftMargin Loss (slm) [41], whose formulations are shown below, where $\sigma(\cdot)$ denotes the sigmoid function:

$$\ell\_{\text{bce}} = -\sum\_{c=1}^{C} \left[ y\_c^k \log \sigma\left(f\_c^k\right) + \left(1 - y\_c^k\right) \log\left(1 - \sigma\left(f\_c^k\right)\right) \right] \tag{13}$$

and

$$\ell\_{\text{slm}} = -\sum\_{c=1}^{C} \left[ y\_c^k \log\left(\left(1 + \exp(-f\_c^k)\right)^{-1}\right) + \left(1 - y\_c^k\right) \log\left(\frac{\exp(-f\_c^k)}{1 + \exp(-f\_c^k)}\right) \right] \tag{14}$$
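A numerically stable sketch of Eq. (13) for one image in plain Python. It assumes the scores are raw logits and the labels are encoded in $\{0, 1\}$, so the $\{-1, 1\}$ labels of Section 3.1 must be mapped first. Since $(1 + \exp(-x))^{-1} = \sigma(x)$ and $\exp(-x)/(1 + \exp(-x)) = 1 - \sigma(x)$, the per-class terms of Eq. (14) are algebraically identical to those of Eq. (13), so the same function evaluates both up to normalization. In PyTorch, the corresponding classes are `torch.nn.BCEWithLogitsLoss` and `torch.nn.MultiLabelSoftMarginLoss`; their averaging conventions are described in the PyTorch documentation. The names below are hypothetical.

```python
import math


def bce_with_logits(logits, labels):
    """Eq. (13) for one image, summed over classes; labels are in {0, 1}.

    Uses log(sigmoid(x)) = -softplus(-x) and log(1 - sigmoid(x)) = -softplus(x).
    """
    def softplus(x):
        # log(1 + exp(x)) without overflow for large |x|
        return max(x, 0.0) + math.log1p(math.exp(-abs(x)))

    return sum(y * softplus(-x) + (1 - y) * softplus(x) for x, y in zip(logits, labels))


def to_01(y_pm):
    """Map {-1, 1} labels to the {0, 1} labels that Eq. (13) expects."""
    return [(v + 1) // 2 for v in y_pm]
```

For example, a logit of 0 costs $\log 2$ regardless of the label, and a confidently wrong logit of 1000 costs about 1000 without overflowing.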

#### *4.2. Results on VOC2007 Dataset*

VOC2007 is a widely used multi-label image classification dataset. It has 9963 images and 20 classes, in which the training set has 5011 images and the test set has 4952 images.

**Full labels:** We randomly sample a subset of the standard training set for training. The proportion ranges from 10% (10% of the training images are used) to 100% (all training images are used). The results of ResNet-101 with different loss functions are shown in Figures 2 and 3, from which we can see the following: (1) as the number of training samples increases, the performance of all models gradually improves; (2) our method performs slightly worse than bce when only 10% of the training data are available, which can be viewed as the cost of learning more robust feature representations. As the amount of training data increases, DlmRl consistently achieves the highest scores. We attribute this to the margin playing a smaller role when training data are scarce than when they are plentiful, which indicates that DlmRl is particularly effective at improving accuracy when large amounts of training data are available.

**Figure 2.** The figure shows the MAP score (%) of three different loss methods on VOC2007 with full labels. The orange line indicates the accuracy rate using BCE with Logits Loss (bce); the red line indicates the accuracy rate using MultiLabel SoftMargin Loss (slm); and the blue line indicates the accuracy rate using our DlmRl.

**Figure 3.** The figure shows the MiAP score (%) of three different loss methods on VOC2007 with full labels. The orange line indicates the accuracy rate using BCE with Logits Loss (bce); the red line indicates the accuracy rate using MultiLabel SoftMargin Loss (slm); and the blue line indicates the accuracy rate using our DlmRl.

**Partial labels:** We generate an extreme partial-label dataset by keeping only one positive label per image. This simulates real-world datasets with only a single label per image, e.g., ImageNet. If an image has more than one positive label, we randomly select one of them and switch the other positive labels to negative labels. The proportion of partial images in the standard training set ranges from 10% (10% of the training images have only one positive label) to 90% (90% of the training images have only one positive label). The performances of different loss functions on the partial dataset are shown in Figures 4 and 5, from which we can see that: (i) as the proportion of partial training samples increases, the performance of all loss functions gradually degrades; (ii) the performance of the *bce* loss function drops the fastest, whereas that of our loss function degrades the slowest; (iii) when the proportion exceeds 30%, our loss function consistently outperforms the other loss functions. This shows that DlmRl copes very well with extreme datasets: when almost all images have a single label, our method performs far better than the other methods, which indicates that DlmRl is highly robust to sparse labels. Owing to this robustness, it may suffice in practice to annotate only the main objects of each image.
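A sketch of this partial-label simulation, assuming the proportion refers to the fraction of training images that are reduced to one positive label (images that already have a single positive are left unchanged). `make_partial` is a hypothetical helper; labels use the $\{-1, 1\}$ encoding of Section 3.1.

```python
import random


def make_partial(Y, fraction, rng=None):
    """Reduce a `fraction` of images to a single, randomly kept positive label.

    Y holds {-1, 1} label vectors; a modified copy is returned.
    """
    rng = rng if rng is not None else random.Random()
    Y = [list(y) for y in Y]
    for k in rng.sample(range(len(Y)), round(fraction * len(Y))):
        pos = [c for c, v in enumerate(Y[k]) if v == 1]
        if len(pos) > 1:
            keep = rng.choice(pos)
            for c in pos:
                if c != keep:
                    Y[k][c] = -1  # switch the other positives to negatives
    return Y
```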

**Figure 4.** The figure shows the MAP score (%) of three different loss methods on VOC2007 with partial labels. The orange line indicates the accuracy rate using BCE with Logits Loss (bce); the red line indicates the accuracy rate using MultiLabel SoftMargin Loss (slm); and the blue line indicates the accuracy rate using our DlmRl.

**Figure 5.** The figure shows the MiAP score (%) of three different loss methods on VOC2007 with partial labels. The orange line indicates the accuracy rate using BCE with Logits Loss (bce); the red line indicates the accuracy rate using MultiLabel SoftMargin Loss (slm); and the blue line indicates the accuracy rate using our DlmRl.

**Noisy labels:** In this experiment, for each label of each training image, we randomly choose whether to flip it from positive to negative or vice versa. The fraction of flipped labels ranges from 5% to 20% in increments of 5%; a fraction of 5% means that 5% of the labels are wrong during training, while the other 95% are clean. Table 1 shows the performance of different loss functions on the noisy-label dataset. Compared with bce, DlmRl improves the MAP by 1.98%, 5.07%, 7.41% and 9.85% for noisy-label ratios of 5%, 10%, 15% and 20%, respectively. We can also see that: (i) under all noise ratios, DlmRl is consistently more robust than the other methods; (ii) as the noise ratio increases, the performance of DlmRl decreases only slightly; (iii) as the noise ratio increases, the performance of slm degrades the fastest, which reveals the limitation of the traditional large-margin technique.
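A sketch of the label-flipping protocol, assuming the noise ratio is applied over all image-label entries so that exactly that fraction of entries is flipped (flipping each entry independently with probability equal to the ratio would match it only in expectation). `flip_labels` is a hypothetical helper; labels are in $\{-1, 1\}$.

```python
import random


def flip_labels(Y, ratio, rng=None):
    """Flip exactly round(ratio * N * C) randomly chosen {-1, 1} labels."""
    rng = rng if rng is not None else random.Random()
    Y = [list(y) for y in Y]
    n_labels = len(Y[0])
    n_total = len(Y) * n_labels
    for idx in rng.sample(range(n_total), round(ratio * n_total)):
        k, c = divmod(idx, n_labels)  # image index, label index
        Y[k][c] = -Y[k][c]
    return Y
```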



*4.3. Results on MS-COCO Dataset*

Microsoft COCO (MS-COCO) is widely used for segmentation, classification, detection and captioning. We use COCO-2014 in our experiments, which has 82,081 training images, 40,137 validation images and 80 object classes. Due to the large scale of this dataset, we conduct only one experiment for each of the three label scenarios, i.e., full labels, partial labels and noisy labels, with the ratios set to 10%, 10% and 5%, respectively. From Table 2, we can see that DlmRl achieves performance comparable to its counterparts in the full-label scenario, but significantly better performance in the partial-label and noisy-label scenarios.


#### *4.4. Ablation Study*

In this subsection, we conduct experiments to study the effect of different hyperparameters or components of our loss function on the VOC2007 dataset. To discuss the effect of one hyper-parameter, we conduct experiments on its different values, but keep other hyper-parameters or components fixed.

Figures 6 and 7 show the effect of $\gamma$ with values in $\{10^1, 10^2, 10^3, 10^4\}$. Recall that $\gamma$ is the extra distance that $I\_k$ must travel on the correct side of the ranking threshold to attain the margin. As can be seen, the performance is very similar across these values, so the classification performance is not very sensitive to $\gamma$.

**Figure 6.** The figure shows the effect of *γ* on the MAP score using our DlmRl.

**Figure 7.** The figure shows the effect of *γ* on the MiAP score using our DlmRl.

Figures 8 and 9 show the effect of $\epsilon$ with values in $\{10^{-1}, 10^{-2}, \ldots, 10^{-6}\}$. As can be seen, a small value of $\epsilon$ is very important; once the value is small enough, the classification performance changes only slightly. This result is reasonable: $\epsilon$ is only meant to prevent numerical problems, and when it is too large it dominates the gradient-norm term in the denominator of Formula (12), so the loss no longer reflects the distance to the ranking threshold and the advantage of DlmRl disappears.

The architecture of ResNet-101 consists of four blocks from bottom to top, i.e., Block1, Block2, Block3 and Block4, followed by two fully connected layers. To analyze the effect of imposing a margin on different hidden layers, we conduct experiments on each of the four blocks of ResNet-101. Figures 10 and 11 show the results: imposing a margin on Block4 achieves the best MAP and MiAP scores.

Figures 12 and 13 show the effect of *q* with values in {1, 2, ∞}. As can be seen, the classification performance is sensitive to this parameter, and *q* = ∞ (i.e., margins measured in the ℓ1 norm) performs best.

**Figure 8.** The figure shows the effect of *ε* on the MAP score using our DlmRl.

**Figure 9.** The figure shows the effect of *ε* on the MiAP score using our DlmRl.

**Figure 10.** The figure shows the effect of imposing a margin on different blocks on the MAP score using our DlmRl.

**Figure 11.** The figure shows the effect of imposing a margin on different blocks on the MiAP score using our DlmRl.

**Figure 12.** The figure shows the effect of *q* on the MAP score using our DlmRl.

**Figure 13.** The figure shows the effect of *q* on the MiAP score using our DlmRl.

According to the above analysis, in all the experiments described previously, we set $\gamma = 10^3$, $\epsilon = 10^{-6}$ and $q = \infty$, and impose the large margin on Block4 of ResNet-101 by default.

#### **5. Conclusions**

In this paper, we have proposed a novel loss, i.e., DlmRl, for the MLiC task. It is a plug-and-play loss and is thus applicable to any network architecture. Whereas the traditional ranking loss encourages margins only at the output layer of a deep neural network, the proposed loss formulation can impose a margin on any chosen set of layers of a deep network and allows any $\ell\_p$ norm ($p \geq 1$) to be used as the metric measuring the margin between labels, which makes it far more flexible. We design a negative sampling technique to make it more computationally efficient, thus addressing the scalability issues brought by the full computation. Experiments on the VOC2007 and COCO datasets have verified that DlmRl outperforms the other methods by imposing a margin on hidden layers rather than only at the output, and negative sampling greatly improves its computational efficiency. Extensive experiments show that our loss formulation is more robust than traditional loss formulations for MLiC.

**Author Contributions:** Writing—original draft, Z.L.; Writing—review & editing, Z.M. and Y.Z.; Project administration, Y.Z. All authors have read and agreed to the published version of the manuscript.

**Funding:** This research was funded by the National Natural Science Foundation of China (NSFC) grant number 62006098, and the Fellowship of China Postdoctoral Science Foundation grant number 2020M681515.

**Institutional Review Board Statement:** Not applicable.

**Informed Consent Statement:** Not applicable.

**Data Availability Statement:** Not applicable.

**Conflicts of Interest:** The authors declare no conflict of interest.

#### **References**


### *Article* **Dual-Word Embedding Model Considering Syntactic Information for Cross-Domain Sentiment Classification**

**Zihao Lu <sup>1</sup>, Xiaohui Hu <sup>2,\*</sup> and Yun Xue <sup>2</sup>**


**Abstract:** The purpose of cross-domain sentiment classification (CDSC) is to fully utilize the rich labeled data in the source domain to help the target domain perform sentiment classification even when labeled data are insufficient. Most of the existing methods focus on obtaining domain transferable semantic information but ignore syntactic information. The performance of BERT may decrease because of domain transfer, and traditional word embeddings, such as word2vec, cannot obtain contextualized word vectors. Therefore, achieving the best results in CDSC is difficult when only BERT or word2vec is used. In this paper, we propose a Dual-word Embedding Model Considering Syntactic Information for Cross-domain Sentiment Classification. Specifically, we obtain dual-word embeddings using BERT and word2vec. After performing BERT embedding, we pay closer attention to semantic information, mainly using self-attention and TextCNN. After word2vec word embedding is obtained, the graph attention network is used to extract the syntactic information of the document, and the attention mechanism is used to focus on the important aspects. Experiments on two real-world datasets show that our model outperforms other strong baselines.

**Keywords:** cross-domain sentiment classification; word embedding; GAT

**MSC:** 68T50

#### **1. Introduction**

Sentiment classification is an important task in natural language processing, and it can help people make better decisions in daily life [1,2]. Over the past few decades, many machine learning methods have been introduced for classification tasks, such as logistic regression, collaborative representation, support vector machines and neural networks [3–7]. With the development of the internet, a large number of user comments and other texts containing sentiment have been generated in different domains. However, classical sentiment classification methods require the training and testing data to come from the same domain [8,9]. In addition, training deep networks relies on a large amount of labeled data, but texts in many domains lack sufficient labeled data. Cross-domain sentiment classification (CDSC) is a promising direction that makes full use of the rich labeled data in a source domain to assist sentiment classification in a target domain that lacks labeled data.

Traditional word-level vector representations, such as word2vec [10], GloVe [11] and fastText [12], use a single vector to represent all possible meanings of a word. As a result, words that express different sentiment polarities in different domains receive the same representation. In recent years, pre-trained language models, such as ELMo [13] and BERT [14], have been widely used in natural language processing (NLP) tasks because they produce contextualized word embeddings. Notably, BERT has achieved state-of-the-art results on many NLP tasks thanks to its strong language understanding capabilities. In cross-lingual tasks, multilingual BERT (mBERT) can share part of its representation space between languages [15]. In addition, mBERT can transfer syntactic knowledge across languages and can embed the dependency parse trees of sentences cross-lingually [16]. This shows that BERT's representations, including its syntactic knowledge, transfer well across different tasks. However, directly fine-tuning BERT on CDSC tasks raises some problems [17]. One of BERT's pre-training tasks randomly masks 15% of the words and asks the model to fill them back in, and different domains may fill in different words. In addition, because no labeled data exist in the target domain, fine-tuning only on the labeled source-domain data reduces performance, since the training and test distributions differ. Therefore, using only BERT or only word2vec to obtain word embeddings is insufficient for CDSC. On the other hand, many current models aim to learn transferable semantic information in CDSC to predict the sentiment polarity of the target domain. However, in addition to semantic information, syntactic information is equally important. Therefore, extracting transferable syntactic information is important for CDSC tasks to better support sentiment classification in the target domain.

**Citation:** Lu, Z.; Hu, X.; Xue, Y. Dual-Word Embedding Model Considering Syntactic Information for Cross-Domain Sentiment Classification. *Mathematics* **2022**, *10*, 4704. https://doi.org/10.3390/math10244704

Academic Editors: Jianping Gou, Weihua Ou, Shaoning Zeng and Lan Du

Received: 9 November 2022 Accepted: 9 December 2022 Published: 11 December 2022

**Publisher's Note:** MDPI stays neutral with regard to jurisdictional claims in published maps and institutional affiliations.

**Copyright:** © 2022 by the authors. Licensee MDPI, Basel, Switzerland. This article is an open access article distributed under the terms and conditions of the Creative Commons Attribution (CC BY) license (https://creativecommons.org/licenses/by/4.0/).

To solve the above problems, we propose a dual-word embedding model considering syntactic information for CDSC. The model performs dual-word embedding through BERT and word2vec to obtain rich word embedding information. Unlike most previous models, which only consider semantic information, we adopt a dual-channel structure to obtain transferable semantic and syntactic information. Semantic information is obtained by self-attention and TextCNN. Syntactic information is obtained through the graph attention network so that the aspects in the sentence can obtain syntactic information [18]. Then, an attention mechanism focuses on the important aspects so that their syntactic information can play a role. Finally, domain-invariant features are obtained through adversarial training. The contributions of our study can be summarized as follows:


#### **2. Related Work**

#### *2.1. CDSC*

CDSC aims to utilize a source domain with rich labeled data to help sentiment classification in a target domain without labeled data. Traditional CDSC methods need to manually select pivots. Blitzer et al. [19] proposed the structural correspondence learning (SCL) method, which selects as pivot features the words that appear frequently in both the source and target domains and are good predictors of source-domain labels. Pan et al. [20] proposed spectral feature alignment (SFA) for CDSC, which associates the source domain with the target domain by aligning pivots with non-pivots. However, manually obtaining domain-invariant features with these traditional methods is time-consuming and expensive. With the rise of neural networks in recent years, many scholars have explored the application of deep learning to CDSC tasks. Among them, the domain adversarial neural network (DANN) [21] learns domain-invariant features through a min-max game between the domain classifier and the feature extractor during adversarial training. Li et al. [22] proposed a hierarchical attention transfer network (HATN) that automatically captures pivots and non-pivots through hierarchical attention and auxiliary tasks. Zhang et al. [23] designed an interactive attention transfer network (IATN) that applies interactive attention to CDSC, considering the influence of aspects in sentences. Yang et al. [24] proposed a dual-channel mutual learning domain adaptive model. In recent years, BERT has gradually been applied to CDSC because of the advantages of its pre-trained model. Du et al. [17] designed a domain-aware BERT (BERT-DAAT) to apply BERT to unsupervised CDSC tasks. Du et al. [25] designed a Wasserstein-based transfer network (WTN) to obtain rich domain-invariant features. Fu et al. [26] paid closer attention to the intra-domain structure and proposed domain adaptation with a contractible difference strategy. The successful application of the attention mechanism has improved classification accuracy substantially. However, it is difficult to obtain syntactic information using attention alone. In this paper, we add a graph attention network to obtain transferable syntactic information.

#### *2.2. Graph Attention Network*

Graph neural networks have received extensive attention from scholars in recent years because they allow deep learning frameworks to be used on graph-structured data [27–30]. Many mature neural network models are designed for regular, grid-like data. The graph convolutional network (GCN) [31], proposed as a deep convolutional learning paradigm for graph-structured data, filled this gap in deep learning. To capture the dependencies between discontinuous and long-distance words in a document, Vashishth et al. [32] used a GCN to characterize the dependency tree of each sentence in the document. However, different nodes in a graph should have different importance, which a GCN cannot handle. Therefore, some researchers have introduced the attention mechanism into graph convolutional networks. Veličković et al. [33] proposed the graph attention network (GAT), which improves on GCN by using attention to aggregate the features of neighbor nodes with different weights. Therefore, compared with GCN, GAT can better handle dynamic graphs. Huang et al. [18] used GAT to establish dependencies between words. Although GCN and GAT are commonly used to obtain syntactic information in single-domain tasks, few works extract syntactic information in CDSC tasks.

#### *2.3. Word Embedding*

Word vector representations transform words in natural language into a form that computers can process [34]. Word vectors can be obtained with word embedding methods such as word2vec and GloVe. Nguyen et al. [35] applied a word2vec embedding model to construct a semantic vector for the plot content of each movie. Wang et al. [36] trained their personality classification model on a shared latent feature space learned by predictive text embedding. Naderalvojoud et al. [37] proposed two methods for creating sentiment-aware word embeddings that improve on the pre-trained word2vec and GloVe embeddings.

In recent years, BERT has received a lot of attention because it learns contextualized word representations. BERT is a multilayer bidirectional Transformer encoder, so each token representation integrates both left and right context. Jawahar et al. [38] investigated which elements of English language structure BERT learns. They showed that BERT captures phrase-level information in the lower layers, syntactic features in the intermediate layers, and semantic features in the higher layers, and that the information in the lower layers is diluted in the higher layers. In this paper, we combine word2vec and BERT to obtain rich word vector information. To prevent the lower-layer information from being diluted in the higher layers, we use the weighted sum of the hidden states of all BERT layers as the input vector.

#### **3. Methodology**

In this section, we describe the DWE framework in technical detail. We first define the problem and present the model structure, and then we detail the training strategy.

#### *3.1. Problem Definition*

In the CDSC task, we are given two domains, $D\_s$ and $D\_t$, which denote the source domain and the target domain, respectively. The source domain $D\_s$ provides a set of labeled data $\{X\_s, Y\_s\} = \{(x\_i^s, y\_i^s)\}\_{i=1}^{N\_s}$ containing $N\_s$ labeled samples. The target domain $D\_t$ provides a set of unlabeled data $\{X\_t\} = \{x\_i^t\}\_{i=1}^{N\_t}$ containing $N\_t$ unlabeled samples. The goal of the CDSC task is to use the labeled source data to help classify the sentiment of the target domain, which lacks labeled data.

#### *3.2. Model Structure*

As shown in Figure 1, DWE mainly contains three parts: a feature extraction module, a domain discriminator, and a sentiment classifier. The feature extraction module uses two channels to obtain semantic information and syntactic information. The domain discriminator is trained adversarially against the feature extractor so that the extracted features become domain-invariant. The sentiment classifier uses the softmax activation function to output the probability of each sentiment label.

**Figure 1.** Model architecture.

#### *3.3. Feature Extraction*

To obtain rich word embedding information, we use BERT and word2vec to produce dual word embeddings. The two embeddings then feed two channels that extract transferable semantic information and syntactic information, respectively.

#### 3.3.1. BERT Semantic Channel

In this channel, we mainly extract semantic information. We first use BERT to obtain word vectors. To prevent the loss of information, instead of using only the final hidden state of BERT as is common, we follow an approach similar to that of Du et al. [25] and use the weighted sum of the hidden states of all layers as the input vector. We denote the hidden state of the *n*th token at the *m*th layer as $h\_n^m$. Suppose that a document contains *S* sentences with *k* words, and $w\_i$ is the *i*th word of the input document. $w\_i$ is tokenized into *q* BPE (byte pair encoding) tokens, $w\_i = \{b\_i^1, b\_i^2, \ldots, b\_i^q\}$. The word vector obtained by BERT is defined as

$$e\_i^B = \sum\_{m=1}^L \alpha\_m \cdot \frac{\sum\_{n=1}^q h\_n^m}{q} \tag{1}$$

where $\alpha\_m$ is the weight coefficient of layer *m* and *L* is the number of hidden layers of BERT. BiGRU, a variant of BiLSTM, can learn long-term dependencies and can therefore model the sequential information of words or sentences. We thus feed the word vectors into a BiGRU to obtain the hidden states

$$h\_i^B = \mathrm{BiGRU}\left(e\_i^B\right) \tag{2}$$
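As a rough illustration of Equation (1), the NumPy sketch below averages the *q* subword states of one word within each BERT layer and then mixes the layers with the weights $\alpha\_m$. The array layout (layers × subword tokens × hidden size), the function name `bert_word_vector`, and passing the word's subword span as a slice are assumptions made for the example; obtaining the hidden states from an actual BERT model is not shown.

```python
import numpy as np


def bert_word_vector(hidden_states, token_span, layer_weights):
    """Sketch of Eq. (1): e_i^B = sum_m alpha_m * (1/q) * sum_n h_n^m.

    hidden_states: array (L, T, H) with the states of all L BERT layers
                   for the T subword tokens of the input (assumed layout).
    token_span:    slice covering the q subword tokens b_i^1..b_i^q of word w_i.
    layer_weights: array (L,) holding the layer weights alpha_m.
    """
    # Average the q subword states of the word inside each layer -> (L, H)
    per_layer = hidden_states[:, token_span, :].mean(axis=1)
    # Weighted sum over the L layers -> (H,)
    return np.tensordot(layer_weights, per_layer, axes=1)
```

The text does not say how the $\alpha\_m$ are parameterized; one option (an assumption here) is a softmax over *L* learnable scalars, as in ELMo-style scalar mixing.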

Different words in a sentence affect the sentence sentiment differently because they express different semantic information. The attention mechanism can focus on the words that play an important role in sentence sentiment according to their attention coefficients. In this paper, we use self-attention to compute word-to-word associations in sentences, which lets the model focus on the words with a stronger impact on sentence sentiment. The attention scores are calculated as follows:

$$g\_i^B = \tanh\left(W h\_i^B + b\right) \tag{3}$$

where *W* and *b* represent the learnable weight matrix and bias in the network, respectively.

We then normalize the attention scores with the softmax function to obtain the attention coefficient $\alpha\_i^B$ of each word:

$$\alpha\_i^B = \frac{\exp\left(g\_i^B\right)}{\sum\_{j=1}^{k} \exp\left(g\_j^B\right)}\tag{4}$$

The attention coefficients are combined with the hidden states obtained by BiGRU to produce the sentence vector $s^B$:

$$s^B = \sum\_{j=1}^k \alpha\_j^B \cdot h\_j^B \tag{5}$$

where · denotes the element-wise product. After the sentence vectors are obtained, TextCNN [39], which mainly consists of a convolution layer and a pooling layer, is used to further extract important semantic information. First, the sentence vector is fed into the convolution layer, where the convolution operation uses the filter $w\_{cnn}$:

$$c^{B} = F\left(w\_{cnn} \circ s^{B} + b\_{cnn}\right) \tag{6}$$

where ◦ denotes the convolution operation, $b\_{cnn}$ is the bias term, and *F* is a nonlinear function such as ReLU. Then, max pooling is performed to retain the most important features. Finally, dropout is applied to prevent overfitting, which yields the sentence representation of the semantic channel. The relevant formulas are as follows:

$$c\_p^B = \mathrm{Maxpooling}\left(c^B\right) \tag{7}$$

$$d^B = dropout\left(c\_p^B\right) \tag{8}$$
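The semantic-channel head in Equations (3) to (8) can be sketched in NumPy as follows. Two assumptions are made: the scoring parameter is taken as a vector so that Equation (3) yields one scalar score per word, and the convolution slides over the stacked sentence vectors $s^B$ of a document. The names `attention_pool` and `textcnn` are hypothetical, and ReLU is used for *F* as the text suggests.

```python
import numpy as np


def attention_pool(h, w, b):
    """Eqs. (3)-(5). h: (k, d) BiGRU states; w: (d,) scoring vector; b: scalar bias."""
    g = np.tanh(h @ w + b)                # Eq. (3): one score per word, shape (k,)
    alpha = np.exp(g - g.max())           # subtract max for numerical stability
    alpha = alpha / alpha.sum()           # Eq. (4): softmax over the k words
    return alpha @ h, alpha               # Eq. (5): weighted sum of hidden states, shape (d,)


def textcnn(S, filters, bias, drop_rate=0.5, training=False, rng=None):
    """Eqs. (6)-(8). S: (num_sentences, d) stacked sentence vectors (assumed);
    filters: (F, win, d) playing the role of w_cnn; bias: (F,) playing the role of b_cnn."""
    F, win, _ = filters.shape
    n = S.shape[0] - win + 1              # number of valid window positions (needs n >= 1)
    # Eq. (6): each filter is applied to every window of `win` consecutive sentence vectors
    c = np.stack([np.einsum('fwd,wd->f', filters, S[t:t + win]) + bias for t in range(n)])
    c = np.maximum(c, 0.0)                # F = ReLU
    c_p = c.max(axis=0)                   # Eq. (7): max pooling over positions -> (F,)
    if not training:                      # Eq. (8): dropout is only active during training
        return c_p
    rng = rng if rng is not None else np.random.default_rng()
    keep = rng.random(F) >= drop_rate
    return c_p * keep / (1.0 - drop_rate) # inverted dropout keeps the expected value
```

TextCNN is commonly used with several window sizes whose pooled outputs are concatenated; a single window size is used here to keep the sketch short.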

3.3.2. Word2vec Syntax Channel

In this channel, we first use word2vec to obtain the word vector representation

$$e\_i^w = word2vec(w\_i) \tag{9}$$

Then, the word vectors are fed into a BiGRU to extract the sentence representation. The hidden output of BiGRU is expressed as follows:

$$h\_i^w = \mathrm{BiGRU}\left(e\_i^w\right) \tag{10}$$

To obtain syntactic information, the syntactic dependency tree of each sentence is built in advance and then converted into a graph in which each node represents a word. Given a dependency graph with *N* nodes, each node representation is computed by aggregating the hidden states of its neighbors. After *l* layers of GAT, the last layer outputs the syntactic representation. The output of the *i*th node at layer *l* is denoted $g\_i^l$, and $g\_i^0 = h\_i^w$ is the initial node state. The nodes are updated as follows:

$$e\_{ij}^{l} = \mathrm{LeakyReLU}\left(\boldsymbol{\alpha}^{l\top}\left[W\_g^{l} g\_i^{l} \,\Vert\, W\_g^{l} g\_j^{l}\right]\right) \tag{11}$$

$$\alpha\_{ij}^{l} = \frac{\exp\left(e\_{ij}^{l}\right)}{\sum\_{k \in \mathcal{N}(i)} \exp\left(e\_{ik}^{l}\right)} \tag{12}$$

$$g\_i^{l+1} = \sigma\left(\sum\_{j \in \mathcal{N}(i)} \alpha\_{ij}^{l} W\_g^{l} g\_j^{l}\right) \tag{13}$$

where $W\_g^l$ and $\boldsymbol{\alpha}^{l}$ are a trainable weight matrix and a trainable weight vector, respectively, and ‖ denotes vector concatenation. $e\_{ij}^l$ is the raw attention score between the *i*th and *j*th nodes, $\mathcal{N}(i)$ is the set of nodes adjacent to node *i*, $\alpha\_{ij}^l$ is the normalized attention weight, and *σ* denotes the ReLU activation function. For simplicity, this feature propagation process can be written as

$$g\_i^{l+1} = \mathrm{GAT}\left(g\_i^{l}, A, \theta\_l\right) \tag{14}$$

where *A* is the adjacency matrix of the graph and $\theta\_l$ is the set of parameters at layer *l*. Finally, we feed the syntactic representation into a BiGRU followed by an attention layer. The BiGRU builds the long-term dependencies of sentences in a document, and the attention mechanism lets the syntactic information of the most important parts of the syntactic representation play a more critical role. We thus obtain the final representation of the syntactic channel:

$$H\_i^w = \mathrm{BiGRU}\left(g\_i^w\right) \tag{15}$$

$$\alpha\_i^w = \frac{\exp\left(\tanh\left(W\_w H\_i^w + b\_w\right)\right)}{\sum\_{j=1}^{k} \exp\left(\tanh\left(W\_w H\_j^w + b\_w\right)\right)}\tag{16}$$

$$d^w = \sum\_{j=1}^k \alpha\_j^w \cdot H\_j^w \tag{17}$$

where $\alpha\_i^w$ is the attention weight, · is the element-wise product, and $W\_w$ and $b\_w$ are the learnable weight matrix and bias in the network, respectively.
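The sketch below shows one single-head GAT layer implementing Equations (11) to (13), together with a helper that turns a dependency parse into the adjacency matrix *A*. The head-index input format, the added self-loops, and the LeakyReLU slope of 0.2 are assumptions for the example rather than details given in the text; the dependency parser itself is not shown.

```python
import numpy as np


def adjacency_from_heads(heads):
    """heads[i]: index of the syntactic head of word i, or -1 for the root (assumed format).
    Returns an undirected adjacency matrix with self-loops (an assumption)."""
    n = len(heads)
    A = np.eye(n)
    for i, h in enumerate(heads):
        if h >= 0:
            A[i, h] = A[h, i] = 1.0
    return A


def gat_layer(G, A, W, a, slope=0.2):
    """One single-head GAT layer, Eqs. (11)-(13).

    G: (N, d_in) node states g^l;  A: (N, N) adjacency matrix;
    W: (d_out, d_in) weight matrix W_g^l;  a: (2 * d_out,) attention vector.
    """
    Z = G @ W.T                                        # row j holds W_g^l g_j^l
    d_out = Z.shape[1]
    # Eq. (11): a^T [Z_i || Z_j] splits into a_left . Z_i + a_right . Z_j
    e = (Z @ a[:d_out])[:, None] + (Z @ a[d_out:])[None, :]
    e = np.where(e > 0, e, slope * e)                  # LeakyReLU
    # Eq. (12): softmax restricted to the neighbours N(i) of each node
    e = np.where(A > 0, e, -np.inf)
    alpha = np.exp(e - e.max(axis=1, keepdims=True))
    alpha = alpha / alpha.sum(axis=1, keepdims=True)
    # Eq. (13): ReLU of the attention-weighted sum of neighbour features
    return np.maximum(alpha @ Z, 0.0), alpha
```

With the three GAT layers used in the experiments, the layer would be applied three times with separate parameters, e.g. `for W, a in params: G, _ = gat_layer(G, A, W, a)`, where `params` is a hypothetical list of per-layer weights.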

#### 3.3.3. Final Document Representation

The final document representation is obtained by concatenating the representations of the two channels:

$$d = \left[ d^B, d^w \right] \tag{18}$$

#### *3.4. Sentiment Classifier*

The ultimate goal of our task is to predict sentiment labels. In this module, we use the softmax activation function to obtain the sentiment prediction label for the document

$$y = softmax(\mathcal{W}\_y d + b\_y) \tag{19}$$

where *Wy* and *by* represent the learnable weight matrix and bias, respectively.

#### *3.5. Domain Discriminator*

The purpose of the domain discriminator (D) is to make the feature extractor (FE) learn domain-invariant representations, which we achieve through adversarial training. The domain discriminator tries to identify which domain a document vector comes from, while the feature extractor tries to deceive the discriminator so that it cannot tell the domains apart, which enables the transfer of domain information. The domain discriminator takes the document representation produced by the feature extractor as input and outputs the probability that the document comes from the source domain. We set $r\_i = 1$ if a document belongs to the source domain and $r\_i = 0$ if it belongs to the target domain. To optimize this adversarial objective with standard back-propagation, we introduce a gradient reversal layer (GRL), which leaves its input unchanged in the forward pass and reverses the gradient direction in the backward pass. The GRL can be treated as a pseudo-function *G*(*x*). Formally, the feature extractor and the domain discriminator play a min-max game that optimizes the parameters $\Theta\_{FE}$ and $\Theta\_D$ as follows:

$$
\tilde{d} = G(d) \tag{20}
$$

$$y\_d' = \mathrm{softmax}\left(W\_d \tilde{d} + b\_d\right) \tag{21}$$

$$\Theta\_{FE}, \Theta\_D = \arg\max\_{\Theta\_{FE}} \min\_{\Theta\_D} L\_{dom} \tag{22}$$

$$L\_{dom} = -\left(r\_i \ln y\_d' + (1 - r\_i) \ln(1 - y\_d')\right) \tag{23}$$

where *Ldom*, Θ*FE*, and Θ*<sup>D</sup>* represent the domain loss, parameters of the feature extractor, and parameters of the domain discriminator, respectively.
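To make the role of the GRL concrete, the following sketch runs one manual gradient step of the min-max game in Equations (20) to (23) for a single document. It replaces the real feature extractor with a toy linear map `W_fe` and uses a logistic (sigmoid) discriminator, which is equivalent to the two-class softmax in Equation (21); the function name `domain_step`, the reversal scale `lam`, and the learning rate `lr` are hypothetical. In a framework such as PyTorch, the same behavior is usually written as a custom `torch.autograd.Function` whose backward pass negates the incoming gradient.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def domain_step(x, r, W_fe, w_d, b_d, lam=1.0, lr=0.1):
    """One manual-gradient step for one document.

    x: (n_in,) input features;  r: 1 for source, 0 for target (r_i in Eq. 23);
    W_fe: (n_feat, n_in) toy linear feature extractor;
    w_d, b_d: weights and bias of a logistic domain discriminator.
    Returns (domain loss, new W_fe, new w_d, new b_d)."""
    d = W_fe @ x                          # document representation d
    d_tilde = d                           # Eq. (20): the GRL is the identity in the forward pass
    y_d = sigmoid(w_d @ d_tilde + b_d)    # probability that the document is from the source
    loss = -(r * np.log(y_d) + (1 - r) * np.log(1 - y_d))   # Eq. (23)

    # Backward pass (sigmoid + binary cross-entropy: dL/dlogit = y_d - r)
    dlogit = y_d - r
    grad_w_d, grad_b_d = dlogit * d_tilde, dlogit
    grad_d_tilde = dlogit * w_d
    grad_d = -lam * grad_d_tilde          # GRL: reverse (and scale) the gradient
    grad_W_fe = np.outer(grad_d, x)

    # Plain descent on both parts: D minimises L_dom, while FE, receiving the
    # reversed gradient, effectively maximises it.
    return loss, W_fe - lr * grad_W_fe, w_d - lr * grad_w_d, b_d - lr * grad_b_d
```

Because the feature extractor receives the negated gradient, an ordinary descent step on it increases the domain loss, while the discriminator's own descent step decreases it.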

#### *3.6. Training Strategy*

We apply the cross-entropy loss function to the sentiment classifier to obtain the sentiment classification loss

$$L\_{sen} = -\left(y' \ln y + (1 - y') \ln(1 - y)\right) \tag{24}$$

where *y*′ is the ground-truth sentiment label and *y* is the predicted probability from Equation (19). The total loss function is then

$$L\_{\text{total}} = L\_{\text{sen}} + L\_{\text{dom}} + \rho L\_{\text{reg}} \tag{25}$$

$$L\_{\text{reg}} = \lambda \|\theta\|^2 \tag{26}$$

where *L*<sub>reg</sub> is the *L*<sup>2</sup> regularization term used to avoid overfitting, *ρ* is the regularization parameter, *λ* is a hyperparameter, and *θ* denotes all parameters in the network. The *L*<sup>2</sup> term penalizes large parameter values, shrinking the magnitude of the weights so that less informative features receive smaller weights.
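For reference, Equations (23) to (26) combine into the following NumPy sketch of the total loss for a single sample. The function name `total_loss` is hypothetical, and the default values of `rho` and `lam` are placeholders, since the text does not report them.

```python
import numpy as np


def total_loss(y_true, y_pred, r, y_dom, params, rho=1.0, lam=1e-4, eps=1e-12):
    """Single-sample total loss of Eqs. (23)-(26).

    y_true: ground-truth sentiment label y' in {0, 1}
    y_pred: predicted probability y (Eq. 19)
    r:      domain label r_i (1 = source, 0 = target)
    y_dom:  discriminator output y'_d (Eq. 21)
    params: list of parameter arrays making up theta
    """
    y_pred = np.clip(y_pred, eps, 1 - eps)   # avoid log(0)
    y_dom = np.clip(y_dom, eps, 1 - eps)
    l_sen = -(y_true * np.log(y_pred) + (1 - y_true) * np.log(1 - y_pred))  # Eq. (24)
    l_dom = -(r * np.log(y_dom) + (1 - r) * np.log(1 - y_dom))              # Eq. (23)
    l_reg = lam * sum(np.sum(p ** 2) for p in params)                       # Eq. (26)
    return l_sen + l_dom + rho * l_reg                                      # Eq. (25)
```

Because the target domain has no sentiment labels, $L\_{sen}$ can only be computed on source-domain samples, whereas $L\_{dom}$ uses samples from both domains.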

#### **4. Experiment**

#### *4.1. Datasets*

To verify the effectiveness of the proposed model, we use two datasets built from Amazon product reviews. Dataset 1 has been widely used in CDSC tasks. It contains reviews from four domains: Books (B), DVDs (D), Electronics (E), and Kitchen (K). Each domain contains 2000 labeled reviews, 1000 positive and 1000 negative. We select 800 positive and 800 negative source-domain reviews as the training data, use 1600 target-domain reviews (800 positive and 800 negative, without their labels) for domain classification, and use the remaining 200 positive and 200 negative target-domain reviews as the test data. Table 1 records the details of Dataset 1.
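The Dataset 1 split described above can be sketched as follows, assuming each domain is given as two lists of 1000 positive and 1000 negative reviews. The function name `split_dataset1` and the random shuffling with a fixed seed are illustrative assumptions; the text does not say how the per-class split was drawn.

```python
import numpy as np


def split_dataset1(src_pos, src_neg, tgt_pos, tgt_neg, seed=0):
    """Each argument is a list of 1000 reviews of one polarity (assumed input format).

    Returns (source_train, target_unlabeled, target_test)."""
    rng = np.random.default_rng(seed)

    def shuffled(reviews):
        return [reviews[i] for i in rng.permutation(len(reviews))]

    src_pos, src_neg, tgt_pos, tgt_neg = map(shuffled, (src_pos, src_neg, tgt_pos, tgt_neg))
    # 800 positive + 800 negative labeled source reviews for training
    source_train = [(r, 1) for r in src_pos[:800]] + [(r, 0) for r in src_neg[:800]]
    # 1600 target reviews with labels dropped; only the domain discriminator uses them
    target_unlabeled = tgt_pos[:800] + tgt_neg[:800]
    # remaining 200 positive + 200 negative target reviews for testing
    target_test = [(r, 1) for r in tgt_pos[800:]] + [(r, 0) for r in tgt_neg[800:]]
    return source_train, target_unlabeled, target_test
```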

Dataset 2, constructed by He et al. [40], contains three sentiment labels (positive, neutral, and negative), which makes results on it more convincing. Dataset 2 also covers four domains: Book (BK), Beauty (BT), Music (M), and Electronics (E). Each domain has two subsets, Set 1 and Set 2. Set 1 is balanced, with 2000 reviews per sentiment label, while Set 2 is unbalanced. For Dataset 2, we follow a setup similar to that of Du et al. [25]: the balanced Set 1 serves as the training data of the source domain, and the unbalanced Set 2 serves as the training data of the target domain. We select 1200 reviews from the source-domain training set as the development set. The balanced Set 1 of the target domain is used as the test set. Table 2 presents an overview of the datasets.



**Table 2.** Statistics of Dataset 2.

#### *4.2. Experiment Setup*

In the experiments, we use the widely used word2vec and BERT models to obtain dual word embeddings. First, we use 300-dimensional word2vec vectors trained on 100 billion words from Google News as one of the word embeddings and fine-tune them during training. Out-of-vocabulary words are randomly initialized from the uniform distribution U(−0.25, 0.25). In addition, we use BERT with 12 layers, 768 hidden units, 12 self-attention heads, and 110 million parameters as the other word embedding. The dimension of the attention vector is set to 200. The dimension of the feature representation in each field and the maximum number of words per review are both set to 200. The weight matrices in the network are randomly initialized from the uniform distribution U(−0.01, 0.01). The dropout rate is 0.5 to prevent overfitting. The number of GAT layers is set to 3, and the Adam algorithm is used as the optimizer.
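The settings above can be collected in a small configuration object, sketched below together with the out-of-vocabulary initialization. Only values stated in the text are included; the class, field, and function names are hypothetical, and the learning rate and batch size are omitted because they are not reported.

```python
from dataclasses import dataclass

import numpy as np


@dataclass(frozen=True)
class DWEConfig:
    """Hypothetical container for the hyperparameters reported in Section 4.2."""
    w2v_dim: int = 300               # Google News word2vec vectors, fine-tuned in training
    oov_init_range: float = 0.25     # OOV words ~ U(-0.25, 0.25)
    bert_layers: int = 12            # BERT with 12 layers,
    bert_hidden: int = 768           # 768 hidden units,
    bert_heads: int = 12             # and 12 self-attention heads
    attention_dim: int = 200
    feature_dim: int = 200
    max_words: int = 200             # maximum number of words per review
    weight_init_range: float = 0.01  # weight matrices ~ U(-0.01, 0.01)
    dropout: float = 0.5
    gat_layers: int = 3
    optimizer: str = "adam"


def init_oov_vector(cfg: DWEConfig, rng: np.random.Generator) -> np.ndarray:
    """Random vector for an out-of-vocabulary word, drawn from U(-r, r)."""
    return rng.uniform(-cfg.oov_init_range, cfg.oov_init_range, size=cfg.w2v_dim)
```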

#### *4.3. Experimental Results*

Following previous studies, we use accuracy, the percentage of correctly classified samples among all samples, as the evaluation metric. The best results are highlighted in bold. We compare the proposed DWE model with the following classic baselines:


Table 3 records the classification accuracy of different models on Dataset 1. The results show that DWE achieves the best performance on 11 cross-domain pairs. On average, our model outperforms DANN by 12.24%, AMN by 9.64%, DAS by 9.44%, HATN by 6.74%, IATN by 5.64%, WTN by 1.14%, and PTASM by 0.44%. DAS refines its classifier with entropy minimization and self-ensembling, which improves on DANN and AMN. HATN and IATN improve substantially on DAS by adding attention, reflecting the effectiveness of the attention mechanism. WTN and PTASM both apply BERT to CDSC and improve greatly over earlier methods. WTN uses the Wasserstein distance as its domain discrepancy learning module, while PTASM uses an attention transfer mechanism and hierarchical attention to improve target-domain classification. Unlike previous methods, our model uses dual word embeddings to compensate for the deficiencies of a single word embedding, and it considers both transferable semantic information and syntactic information, which may explain its improvement.


**Table 3.** Classification accuracy of various models on Dataset 1.

Furthermore, we also compare our proposed model DWE with other baseline models on Dataset 2 and conduct ablation experiments simultaneously.

Table 4 records the classification accuracy on Dataset 2. DWE achieves the best performance on all cross-domain pairs. On average, it outperforms AuxNN by 9.5%, DAS by 6.5%, and WTN by 3.8%, which demonstrates the effectiveness of the proposed model. In the ablation study, removing the BERT word embedding and the word2vec word embedding decreases the average performance by 8.9% and 1.9%, respectively, which validates the proposed dual word embedding. A possible reason is that a single word embedding causes the model to lose part of the information; in particular, removing the BERT word embedding discards a large amount of context-related information.


**Table 4.** Classification accuracy of various models on Dataset 2.

#### *4.4. Case Study*

To demonstrate the role of the proposed DWE model, we selected a sample from BK as a case study and compared DWE with WTN, with BT as the source domain and BK as the target domain. Figure 2 shows the attention weights of DWE and WTN for the sample; darker colors indicate higher attention weights.

**Figure 2.** Case Study of Book Domain.

Figure 2 shows that WTN focuses more on "alot" and "love," while DWE focuses more on the most important sentiment word, "profound". The main reason may be that our syntactic module establishes a syntactic connection between "profound" and "book", which draws attention to the more important sentiment word.

#### *4.5. Visualization of Feature Representation*

In this section, we visualize the data in two cross-domain pairs, namely, M → BK and BT → E in Dataset 2. Figure 3 shows the feature representation of M as the source domain and BK as the target domain. Figure 4 shows the feature representation of BT as the source domain and E as the target domain.

**Figure 3.** Visualization of feature representation on M → BK.

**Figure 4.** Visualization of feature representation on BT → E.

Figures 3 and 4 show that the sample features of the two domains are aligned: there is no obvious boundary between the domains, and they are difficult to distinguish. This indicates that the two domains share the learned feature representation and that information from the source domain can be transferred to the target domain.

#### **5. Conclusions**

In this paper, we proposed a dual-word embedding model that incorporates syntactic information for CDSC. The dual word embeddings are obtained from BERT and word2vec; transferable syntactic and semantic information is then extracted by combining the dual channels with adversarial training. Experiments on two real-world datasets showed that our model achieves better results than the baselines. In future work, we will apply the model to cross-domain aspect-based sentiment analysis.

**Author Contributions:** Conceptualization, Z.L. and Y.X.; methodology, Z.L.; formal analysis, Z.L.; writing—original draft preparation, Z.L.; writing—review and editing, X.H.; supervision, Y.X. and X.H.; funding acquisition, X.H. All authors have read and agreed to the published version of the manuscript.

**Funding:** This work was supported by the Characteristic Innovation Projects of Guangdong Colleges and Universities (Nos. 2018KTSCX049) and the Science and Technology Plan Project of Guangzhou under Grant Nos. 202102080258 and 201903010013.

**Institutional Review Board Statement:** Not applicable.

**Informed Consent Statement:** Not applicable.

**Data Availability Statement:** Not applicable.

**Conflicts of Interest:** The authors declare no conflict of interest.

#### **References**


### *Article* **Keyword-Enhanced Multi-Expert Framework for Hate Speech Detection**

**Weiyu Zhong †, Qiaofeng Wu †, Guojun Lu, Yun Xue and Xiaohui Hu \***

**\*** Correspondence: huxh@scnu.edu.cn

† These authors contributed equally to this work.

**Abstract:** The proliferation of hate speech on the Internet harms the psychological health of individuals and society. Thus, establishing and supporting the development of hate speech detection and deploying evasion techniques is a vital task. However, existing hate speech detection methods tend to ignore the sentiment features of target sentences and have difficulty identifying some implicit types of hate speech. The performance of hate speech detection can be significantly improved by gathering more sentiment features from various sources. When external sentiment information is used, the key information of the sentences must not be ignored. Thus, this paper proposes a keyword-enhanced multi-expert framework. First, the multi-expert module of multi-task learning shares parameters and thereby introduces sentiment information. In addition, the critical features of the sentences are highlighted through contrastive learning. The model thus attends to both the key information of the sentence and the external sentiment information. Experimental results on three public datasets demonstrate the effectiveness of the proposed model.

**Keywords:** hate speech detection; contrastive learning; multi-task learning

**MSC:** 18C50

#### **1. Introduction**

With the widespread use of social media and mobile Internet platforms, the speed at which online speech spreads and the freedom to publish it have led to the malicious prevalence of hate speech. Exposure to such language can harm the mental health of victims [1] and may lead to severe social problems. To prevent further harm, authorities need to intervene by detecting hate speech online. Thus, rapid and accurate automatic hate speech detection has become a popular research topic in natural language processing and has gained increasing attention in recent years.

Figure 1 shows an example in which the first sentence contains the hate term *fucking aids*, an obvious form of offensive hate speech, while the second sentence, which has no obvious hate words or semantics, is a positive sentence.


**Figure 1.** An example sentence from the Ruddit dataset. The offensive score ranges between −1 (maximally supportive) and 1 (maximally offensive).

Deep learning approaches to hate speech detection have been the focus of most research in recent years [2–5]. However, previous research disregarded the sentiment features of the target sentences and relied only on pre-trained models or deeper neural networks to obtain semantic features. Wang, C. [6] showed that the semantics of hate speech bear a strong tendency toward negative sentiment. To overcome this problem, recent studies have proposed multi-task learning (MTL), which improves hate speech detection by using sentiment information [7]. Transfer learning is the process of transferring generalizable knowledge gained from training data to a target task. MTL is a type of transfer learning that learns several related tasks simultaneously, allowing them to share information during learning and exploiting the correlations between tasks to improve the model's performance and generalization on each task. Kapil, P. [8] proposed a deep MTL framework that exploits useful information from several related classification tasks for hate speech detection; however, this framework uses hard parameter sharing, which is prone to negative transfer. Zhou, X. [9] used multiple feature extraction units to share multi-task parameters so that the model can share sentiment knowledge, and then fused the features with gated networks for hate speech detection. This model uses soft parameter sharing by dividing a single expert into multiple experts, which mitigates the negative transfer caused by hard parameter sharing.

**Citation:** Zhong, W.; Wu, Q.; Lu, G.; Xue, Y.; Hu, X. Keyword-Enhanced Multi-Expert Framework for Hate Speech Detection. *Mathematics* **2022**, *10*, 4706. https://doi.org/10.3390/math10244706

Academic Editors: Jianping Gou, Weihua Ou, Shaoning Zeng and Lan Du

Received: 12 November 2022 Accepted: 8 December 2022 Published: 11 December 2022

**Publisher's Note:** MDPI stays neutral with regard to jurisdictional claims in published maps and institutional affiliations.

**Copyright:** © 2022 by the authors. Licensee MDPI, Basel, Switzerland. This article is an open access article distributed under the terms and conditions of the Creative Commons Attribution (CC BY) license (https://creativecommons.org/licenses/by/4.0/).

School of Electronics and Information Engineering, South China Normal University, Foshan 528225, China

Although hate speech detection has achieved good performance in recent years, the following problems remain. (1) The latest multi-task framework used in hate speech detection relies on soft parameter sharing [9], in which all experts are shared by all tasks. However, hate speech detection and sentiment analysis are both positively and negatively correlated. Positive correlations are parameter relationships that benefit the fit of the primary task; negative correlations do not. If the negatively correlated parameters of the tasks are not separated, noise is introduced into each task, which leads to negative transfer. Moreover, because the multiple experts carry abundant information from different tasks, simple gated networks cannot effectively fuse and filter this information. (2) Current work lacks the ability to extract critical information (e.g., keywords) from sentences [5]. It cannot effectively identify different types of hate words, such as profanities, or recognize the association between certain identity terms and offensive statements. Certain identity terms (especially those referring to minority groups) appear mainly in offensive texts [10]. For example, the sentence *"This is also the reason that so many of Obama's policies are being overturned/undone, it's just because the Black Guy did them."* contains no conspicuous hate words but expresses racial discrimination through the identity term *Black*.

To solve these problems, we propose the following approaches. **(1) For the first problem,** inspired by the recent progressive layered extraction (PLE) model [11] and research on gated networks [12], we divide the feature extraction units (i.e., expert modules) into a shared part and task-specific parts. This strengthens the task-specific features and reduces the negative transfer caused by sharing parameters between weakly correlated tasks. Moreover, we design a feature-filtering gate that better fuses and filters the information from multiple expert modules. **(2) For the second problem,** inspired by a recent contrastive learning model [13], we apply contrastive learning to English hate speech detection, using a swearing dictionary and an identity term dictionary to construct positive and negative examples. This makes the model more sensitive to critical words, so it can learn the association between various types of hate words or identity terms and offensive statements. In summary, the contributions of our study are as follows:


#### **2. Related Work**

Recently, researchers have widely studied automatic hate speech detection. In this section, we review related work on deep-learning-based methods for hate speech detection, especially MTL-based methods, as well as related work on contrastive learning.

Recently, deep-learning-based approaches have achieved considerable success in hate speech detection. Ref. [14] proposed a transformed word embedding model (TWEM), which achieves high performance with a simple structure. Ref. [3] proposed a deep neural network combining CNN and GRU as a feature extractor to learn the semantic features of hate speech. Ref. [4] built a large-scale dataset of hate speech and its responses and used the pre-trained language model GPT-2 to detect hate speech. Ref. [5] created the first English Reddit comment dataset with fine-grained, real-valued scores and used the pre-trained model HateBERT to detect hate speech. These works show that deep learning models can extract the underlying semantic features of text, which provide the most direct clues for detecting hate speech.

Transfer learning can bring more useful information to hate speech detection; common transfer learning methods include multi-task learning and knowledge distillation [15]. Knowledge distillation transfers knowledge from a large network (the teacher) to a small network (the student). Multi-task learning trains multiple related tasks simultaneously while sharing information between them. In recent years, multi-task learning has produced some results in hate speech detection [7]. Ref. [16] proposed a theoretical framework for hate speech type detection that includes fuzzy multi-task learning. Ref. [17] proposed an MTL approach based on the pre-trained model BERT for hate speech detection. Ref. [8] proposed a deep MTL framework that improves hate speech detection by exploiting useful information from multiple related classification tasks. Ref. [9] proposed a hate speech detection framework based on sentiment knowledge sharing. These studies show that MTL can exploit the relevance between sentiment analysis and hate speech detection, which improves model performance and generalization in hate speech detection.

In addition, some optimization algorithms [18,19] have recently been proposed to obtain better classification results and semantic representations, and contrastive learning is one of them. Contrastive learning aims to learn effective representations by pulling semantically similar sentences together and pushing dissimilar sentences apart [20]. Several recent approaches apply contrastive objectives to different views obtained from data augmentation or from different copies of the model [21–24]. For example, Ref. [24] proposed ConSERT, a Contrastive Framework for Self-Supervised SEntence Representation Transfer, which uses contrastive learning to fine-tune BERT in an unsupervised manner. SimCSE [25] takes the simplest approach: it applies only standard dropout as noise to obtain two different outputs of the same sentence, which form a positive pair. We use contrastive learning for hate speech detection to increase the model's sensitivity to the key information of a sentence and thereby improve task performance.
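As a minimal illustration of the contrastive objective described above, the sketch below computes an InfoNCE-style loss with cosine similarity for one anchor, one positive, and several negatives. This is the generic form used by methods such as SimCSE, not necessarily KMT's exact loss; the function name `info_nce` is hypothetical and the default temperature is a placeholder.

```python
import numpy as np


def info_nce(anchor, positive, negatives, tau=0.05):
    """anchor, positive: (d,) vectors; negatives: (K, d) array; tau: temperature.

    Returns -log( exp(sim(a, p)/tau) / sum over {p} and negatives of exp(sim/tau) )."""
    def cos(u, v):
        return (u @ v) / (np.linalg.norm(u) * np.linalg.norm(v))

    sims = np.array([cos(anchor, positive)] + [cos(anchor, n) for n in negatives]) / tau
    sims = sims - sims.max()              # stabilise the log-sum-exp
    # Index 0 is the positive pair; the loss is its negative log-softmax probability
    return -(sims[0] - np.log(np.exp(sims).sum()))
```

Pulling the anchor toward the positive lowers the loss, and pulling it toward a negative raises it, which is the behavior described in the paragraph above.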

#### **3. Methodology**

In this section, we present our keyword-enhanced multi-expert framework for hate speech detection (KMT). The model exploits the critical information of a sentence and external sentiment information to improve hate speech detection.

The general architecture of KMT is shown in Figure 2. The framework consists of four modules: **(1) Textual input module.** The bottom of the figure shows the textual input module, where the pre-trained model BERT or HateBERT encodes the input sentences and generates the contextual, semantically integrated input vector *x*; **(2) Multi-task learning module.** The top left of the figure shows the multi-task learning module, where a multi-task learning framework lets sentiment information and hate information interact and learns shared and task-specific features so that sentiment information assists hate speech detection; **(3) Feature-filtering module.** The gate in the figure is the feature-filtering module, which filters and fuses the features output by the expert modules to select the important sentiment and hate speech information; **(4) Contrastive learning module.** The top right of the figure shows the contrastive learning module, which extracts critical information within the sentences to improve the sensitivity of the model to sentence keywords. Finally, the MTL and contrastive learning modules are trained jointly.

**Figure 2.** The overall architecture of our proposed Keyword-enhanced Multi-expert Framework for Hate Speech Detection (KMT).

Given the input text *s* = {*w*<sub>1</sub>, *w*<sub>2</sub>, ..., *w*<sub>*n*</sub>}, where *n* is the length of *s*, we feed the sequence [*CLS*]*s*[*SEP*] to the BERT or HateBERT encoder in the textual input module to obtain the input vector *x* with contextual information. Subsequently, *x* is fed to both the multi-task learning module and the contrastive learning module. In the multi-task learning module, the hate information and sentiment information in *x* interact through the shared expert and task-specific expert modules; the resulting features are fused and filtered by a feature-filtering gate; and finally, hate speech detection is performed by the tower containing the classification layer. In the contrastive learning module, positive and negative examples are generated by masking *x*. The model is then trained to bring *x* closer to the positive examples and push it away from the negative examples, which makes it focus on the key information in the sentences. Each module is described in detail below.

#### *3.1. Multi-Task Learning Module*

Due to the diversity of language, the insulting meaning of many sentences is implicit, which makes it difficult to determine whether a sentence is offensive. For example, the sentence *"These guys are all a bunch of pigs."* does not contain an explicitly hateful word, but it still constitutes hate speech. Although the word *pig* is neutral, most people associate it with foolishness and clumsiness, so likening the guys to pigs demeans them. The key to judging such sentences effectively is grasping emotional common sense. Conversely, the sentence *"He's a fucking good player."* contains the obvious hate word *fucking*. However, here *fucking* is merely an intensifier expressing excitement; hence, the sentence does not constitute hate speech. These two examples show that although hate speech tends to contain hate words, it is difficult to detect it well by using only the hate information of the sentence itself. To introduce external sentiment information, we incorporate generic sentiment datasets and use the MTL approach to let the information from the sentiment datasets and the hate dataset interact, which improves the performance of hate speech detection.

In MTL frameworks, the risk of overfitting is substantially reduced by the extensive use of shared expert layers. However, the effectiveness of the framework may be affected by the seesaw phenomenon and the negative transfer problem because of differences between tasks and data distributions [11]. Thus, we adopt the Progressive Layered Extraction (PLE) structure [11]. As shown in Figure 2, the model is divided into task-specific tower structures at the top and expert modules at the bottom. Each expert module comprises several sub-networks known as Experts, and the number of Experts in each expert module is a hyperparameter to be tuned. The shared experts in PLE extract shared features, while the task-specific experts extract task-specific features. Each tower network draws information from the shared experts and from its own task-specific experts. Our expert modules and tower networks consist of feed-forward neural networks. During gradient backpropagation, the parameters of the expert modules are updated. Because the output of a task-specific expert module is passed only to the tower of its own task, its parameters are affected only by that task's gradients. By contrast, the parameters of the shared expert modules are affected by the mixed gradients of all tasks, because their output is passed to the towers of all tasks.

In the MTL module, features are extracted by the shared experts *E*<sub>*s*</sub><sup>*T*</sup> and task *k*'s specific experts *E*<sub>*k*</sub><sup>*T*</sup>. The extracted features are then concatenated to form *S*<sup>*k*</sup>(*x*), as shown in Equations (1)–(3):

$$E\_k^T = \left[ E\_{(k,1)}^T, E\_{(k,2)}^T, \cdots, E\_{(k,m\_k)}^T \right] \tag{1}$$

$$E\_s^T = \left[ E\_{(s,1)}^T, E\_{(s,2)}^T, \cdots, E\_{(s,m\_s)}^T \right] \tag{2}$$

$$S^k(\mathbf{x}) = \left[ E\_k^T, E\_s^T \right]^T \tag{3}$$

where *x* is the input vector; *m*<sub>*s*</sub> and *m*<sub>*k*</sub> are the numbers of sub-networks in the shared experts *E*<sub>*s*</sub><sup>*T*</sup> and in task *k*'s specific experts *E*<sub>*k*</sub><sup>*T*</sup>, respectively; and *E*<sup>*T*</sup><sub>(*k*,*j*)</sub> and *E*<sup>*T*</sup><sub>(*s*,*j*)</sub> denote the *j*-th sub-networks in task *k*'s specific experts and in the shared experts, respectively. The features *S*<sup>*k*</sup>(*x*) of the shared experts and task *k*'s specific experts are selectively fused through a feature-filtering gate (Gate). The filtered features of task *k* are formulated as Equation (4):

$$\mathbf{G}^k(\mathbf{x}) = \mathrm{Gate}\left(\mathbf{x}, \mathbf{S}^k(\mathbf{x})\right) \tag{4}$$

Finally, the prediction for task *k* is produced by the tower network:

$$O^k(\mathbf{x}) = f^k\left(G^k(\mathbf{x})\right) \tag{5}$$

where *f*<sup>*k*</sup>(·) denotes task *k*'s tower network, which consists of feed-forward neural networks.
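The routing in Equations (1)–(3) can be illustrated with a minimal sketch. It assumes each Expert maps a *d*-dimensional vector to another *d*-dimensional vector and stands in for the feed-forward Experts with hypothetical scaling functions (`make_scaling_expert` is illustrative, not the authors' code). It shows that *S*<sup>*k*</sup>(*x*) stacks only task *k*'s own Experts plus the shared Experts, which is why task-specific parameters receive gradients only from their own task.

```python
from typing import Callable, Dict, List

Vector = List[float]
Expert = Callable[[Vector], Vector]


def make_scaling_expert(scale: float) -> Expert:
    """Hypothetical stand-in for a feed-forward Expert: it just scales its input."""
    return lambda x: [scale * v for v in x]


def expert_features(task: str, x: Vector,
                    task_experts: Dict[str, List[Expert]],
                    shared_experts: List[Expert]) -> List[Vector]:
    """Stack S^k(x): task k's own Experts (Eq. 1) followed by the shared Experts (Eq. 2).

    Experts that belong to other tasks are never called, so in a real network
    they would receive no gradient from task k's loss, whereas the shared
    Experts appear in S^k(x) for every task and receive mixed gradients.
    """
    own = [expert(x) for expert in task_experts[task]]     # m_k rows of E_k^T
    shared = [expert(x) for expert in shared_experts]      # m_s rows of E_s^T
    return own + shared                                    # Eq. (3): (m_k + m_s) rows


# Usage: two tasks with two task-specific Experts each and two shared Experts.
experts = {
    "hate": [make_scaling_expert(1.0), make_scaling_expert(2.0)],
    "sentiment": [make_scaling_expert(-1.0), make_scaling_expert(0.5)],
}
shared = [make_scaling_expert(3.0), make_scaling_expert(4.0)]
S_hate = expert_features("hate", [1.0, 2.0], experts, shared)
print(S_hate)  # [[1.0, 2.0], [2.0, 4.0], [3.0, 6.0], [4.0, 8.0]]
```

The two Experts per module mirror the setting reported in the training details; the scaling Experts exist only to make the stacking order visible.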

#### *3.2. Feature-Filtering Module*

The multiple-expert setting in MTL enables better interaction between affective and hate information, but because the multiple experts produce a large amount of information, a structure is needed to fuse it selectively. Inspired by research on gating modules [12], we therefore design a feature-filtering module that not only fuses the information between experts better but also reduces noise. As shown in Figure 2, the input vector *x* is used as a selector to obtain useful information from the selected vector (here, the experts' output *S*<sup>*k*</sup>(*x*)), as shown in Equations (6)–(9):

$$\mathbf{g}^k(\mathbf{x}) = \mathbf{W}\_g^k \mathbf{x} \tag{6}$$

$$\text{parallel:}\quad p^k(\mathbf{x}) = \frac{\mathbf{S}^k(\mathbf{x}) \cdot \mathbf{x}}{\mathbf{x} \cdot \mathbf{x}} \mathbf{x} \tag{7}$$

$$\text{orthogonal:} \quad o^k(\mathbf{x}) = S^k(\mathbf{x}) - p^k(\mathbf{x}) \tag{8}$$

$$G^k(\mathbf{x}) = \text{concat}\left(g^k(\mathbf{x})o^k(\mathbf{x}), \left(1 - g^k(\mathbf{x})\right)p^k(\mathbf{x})\right) \tag{9}$$

where *W*<sup>*k*</sup><sub>*g*</sub> ∈ *R*<sup>(*m*<sub>*k*</sub>+*m*<sub>*s*</sub>)×*d*</sup> is a parameter matrix, *d* is the dimension of the input vector, and *g*<sup>*k*</sup>(*x*) is the weight vector for task *k* obtained by a linear transformation. *S*<sup>*k*</sup>(*x*) is decomposed into a parallel component and an orthogonal component. The parallel component *p*<sup>*k*</sup>(*x*) is the projection of *S*<sup>*k*</sup>(*x*) onto *x* and thus contains part of the information of *x*. In contrast, *o*<sup>*k*</sup>(*x*) is orthogonal to *x* and therefore contains new information. For example, if *x* is a hate speech input, *p*<sup>*k*</sup>(*x*) is the part of *S*<sup>*k*</sup>(*x*) that carries hate speech information, *o*<sup>*k*</sup>(*x*) is the part that carries sentiment information, and *G*<sup>*k*</sup>(*x*) represents the fusion of these two components. *g*<sup>*k*</sup>(*x*) regulates the proportions of the two components to obtain the best fusion.
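Equations (6)–(9) can be sketched as follows. This is a minimal illustration under stated assumptions: *S*<sup>*k*</sup>(*x*) is treated as one *d*-dimensional row per Expert, *W*<sub>*g*</sub><sup>*k*</sup> has one row per Expert so that *g*<sup>*k*</sup>(*x*) gives one scalar weight per Expert, each row is projected onto *x* separately, and no activation is applied to *g*<sup>*k*</sup>(*x*) because the text describes it as a linear transformation. The helper names are hypothetical.

```python
from typing import List, Tuple

Vector = List[float]


def dot(a: Vector, b: Vector) -> float:
    return sum(u * v for u, v in zip(a, b))


def decompose(s: Vector, x: Vector) -> Tuple[Vector, Vector]:
    """Split s into its projection onto x (Eq. 7) and the orthogonal remainder (Eq. 8)."""
    coef = dot(s, x) / dot(x, x)
    p = [coef * v for v in x]
    o = [si - pi for si, pi in zip(s, p)]
    return p, o


def feature_filter_gate(x: Vector, S: List[Vector], W_g: List[Vector]) -> List[Vector]:
    """Sketch of Eqs. (6)-(9) under the assumptions stated above.

    S holds one d-dimensional output per Expert, and W_g has one row per Expert,
    so g = W_g x yields one scalar weight per Expert row (linear, as in Eq. 6).
    """
    g = [dot(w, x) for w in W_g]                                  # Eq. (6)
    fused = []
    for g_j, s_j in zip(g, S):
        p, o = decompose(s_j, x)                                  # Eqs. (7)-(8)
        fused.append([g_j * v for v in o] + [(1.0 - g_j) * v for v in p])  # Eq. (9)
    return fused


print(feature_filter_gate([1.0, 0.0], [[3.0, 4.0]], [[0.5, 0.0]]))  # [[0.0, 2.0, 1.5, 0.0]]
```

In the example, the Expert output [3, 4] splits into a parallel part [3, 0] (shared with *x*) and an orthogonal part [0, 4] (new information), and a gate value of 0.5 weights the two halves equally.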

#### *3.3. Contrastive Learning Module*

Because pre-trained models lack the ability to grasp critical word information in sentences, they cannot effectively distinguish between different types of hate words or identify the relationship between certain identity terms and offensive statements. Contrastive learning has shown an excellent ability to acquire and distinguish crucial knowledge by attracting positive examples and contrasting them against negative examples, which has led to considerable advances in many tasks. Our goal is to make our model more sensitive to the essential words within a text. To this end, we use a contrastive learning module that pulls the positive examples closer while pushing the negative ones away, allowing the model to distinguish more effectively between important and minor information. To create a positive example *x*<sup>*p*</sup>, we replace each non-key token representation in the input vector *x* with a constant vector *m* ∈ *R*<sup>*d*</sup> whose entries all equal 1 × 10<sup>−6</sup>. In this way, the positive example retains the key information and eliminates unimportant words. To obtain the negative example *x*<sup>*n*</sup>, we similarly replace the key token representations in *x* with *m*.
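The construction of the two views can be sketched as below. This is a minimal sketch, assuming *x* is given as an *N* × *d* list of token representations and that a boolean key-token flag per token is already available; how key tokens are identified is not specified in this section, so `is_key` is an assumed input and `make_contrastive_views` is a hypothetical name.

```python
from typing import List, Sequence, Tuple

Matrix = List[List[float]]
MASK_VALUE = 1e-6  # every entry of the constant mask vector m, as stated in the text


def make_contrastive_views(x: Matrix, is_key: Sequence[bool]) -> Tuple[Matrix, Matrix]:
    """Build the positive view x^p (non-key tokens masked) and the negative view
    x^n (key tokens masked) from token representations x of shape N x d.

    `is_key` marks key tokens; how key tokens are chosen is an assumed input here.
    """
    if len(x) != len(is_key):
        raise ValueError("need one key flag per token")
    m = [MASK_VALUE] * len(x[0])
    x_p = [list(tok) if key else list(m) for tok, key in zip(x, is_key)]
    x_n = [list(m) if key else list(tok) for tok, key in zip(x, is_key)]
    return x_p, x_n


x = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
x_p, x_n = make_contrastive_views(x, [False, True, False])
print(x_p)  # [[1e-06, 1e-06], [3.0, 4.0], [1e-06, 1e-06]]
print(x_n)  # [[1.0, 2.0], [1e-06, 1e-06], [5.0, 6.0]]
```

A small nonzero constant (rather than an all-zero vector) keeps masked token vectors from having zero norm, which matters if they are later compared with cosine similarity.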

We then encode *x*, *x*<sup>*p*</sup>, and *x*<sup>*n*</sup> separately with feed-forward neural networks, as in Equations (10)–(12):

$$\mathbf{c} = f(\mathbf{x}) \tag{10}$$

$$\mathbf{c}^p = f(\mathbf{x}^p) \tag{11}$$

$$\mathbf{c}^{n} = f(\mathbf{x}^{n}) \tag{12}$$

where *f*(·) denotes the feed-forward neural networks. We then compute the cosine similarity between the representation of *x* and those of the positive and negative examples, as in Equation (13):

$$\text{sim}\left(c\_1, c\_2\right) = \frac{c\_1^T c\_2}{\|c\_1\| \cdot \|c\_2\|}\tag{13}$$

where sim(*c*<sub>1</sub>, *c*<sub>2</sub>) stands for either sim(*c*, *c*<sup>*p*</sup>) or sim(*c*, *c*<sup>*n*</sup>). We follow the contrastive training objective developed in [26], given in Equation (14):

$$l\_{\rm con} = -\sum\_{k=1}^{K} \sum\_{i=1}^{N} \log \frac{e^{\frac{\text{sim}\left(c\_i, c\_i^{p}\right)}{\tau}}}{\sum\_{j=1}^{N} \left(e^{\frac{\text{sim}\left(c\_i, c\_j^{p}\right)}{\tau}} + e^{\frac{\text{sim}\left(c\_i, c\_j^{n}\right)}{\tau}}\right)}\tag{14}$$

where *N* is the length of a sentence, the subscripts *i* and *j* index token positions within the sentence, *K* is the batch size, and *τ* is a temperature hyperparameter, which is set to 1 in our model.
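Equations (13) and (14) can be sketched in plain Python as below. This is a minimal illustration assuming *c*, *c*<sup>*p*</sup>, and *c*<sup>*n*</sup> are given as lists of *N* token vectors per sentence (no vector may be all zeros); the function names are hypothetical, and a practical implementation would use a tensor library with batched operations.

```python
import math
from typing import List, Sequence, Tuple

Vector = List[float]


def cosine(a: Vector, b: Vector) -> float:
    """Eq. (13): cosine similarity."""
    num = sum(u * v for u, v in zip(a, b))
    return num / (math.sqrt(sum(u * u for u in a)) * math.sqrt(sum(v * v for v in b)))


def sentence_loss(c: List[Vector], c_p: List[Vector], c_n: List[Vector],
                  tau: float = 1.0) -> float:
    """Inner sum of Eq. (14) over token positions i = 1..N for one sentence."""
    n = len(c)
    loss = 0.0
    for i in range(n):
        positive = math.exp(cosine(c[i], c_p[i]) / tau)
        denom = sum(math.exp(cosine(c[i], c_p[j]) / tau) + math.exp(cosine(c[i], c_n[j]) / tau)
                    for j in range(n))
        loss -= math.log(positive / denom)
    return loss


def contrastive_loss(batch: Sequence[Tuple[List[Vector], List[Vector], List[Vector]]],
                     tau: float = 1.0) -> float:
    """Outer sum of Eq. (14) over the K sentences in a batch."""
    return sum(sentence_loss(c, c_p, c_n, tau) for c, c_p, c_n in batch)


c = [[1.0, 0.0], [0.0, 1.0]]
flipped = [[-1.0, 0.0], [0.0, -1.0]]
good = contrastive_loss([(c, c, flipped)])   # positives aligned with c
bad = contrastive_loss([(c, flipped, c)])    # positives and negatives swapped
print(good < bad)  # True
```

The comparison shows the intended behavior: the loss is lower when each token representation agrees with its positive view and disagrees with its negative view.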

#### *3.4. Loss Function*

During training, we jointly optimize the objectives of the multi-task learning module and the contrastive learning module. Our training aims to minimize the total loss in Equation (15):

$$\text{loss} = \sum\_{i=1}^{n} \lambda\_i l\_i + \lambda l\_{\rm con} \tag{15}$$

where *n* represents the number of tasks, *l*<sub>*i*</sub> is the loss function of the *i*-th task in the MTL module, *l*<sub>con</sub> is the contrastive loss in Equation (14), and *λ* and *λ*<sub>*i*</sub> are hyperparameters.
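As a small worked instance of Equation (15), the sketch below combines per-task losses and the contrastive loss. The loss values and weights in the example are hypothetical placeholders, since the section does not report the tuned values of *λ* and *λ*<sub>*i*</sub>.

```python
from typing import Sequence


def total_loss(task_losses: Sequence[float], task_weights: Sequence[float],
               l_con: float, lam: float) -> float:
    """Eq. (15): weighted sum of the n MTL task losses plus the weighted contrastive loss."""
    if len(task_losses) != len(task_weights):
        raise ValueError("need one weight lambda_i per task loss l_i")
    return sum(w * l for w, l in zip(task_weights, task_losses)) + lam * l_con


# Hypothetical values: hate task loss, sentiment task loss, and the contrastive term.
print(total_loss([0.5, 0.2], [1.0, 0.5], l_con=0.8, lam=0.1))  # approximately 0.68
```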

#### **4. Experiments**

*4.1. Datasets*

In our experiments, we employed two sentiment datasets and three public hate speech datasets. Table 1 displays the statistics of the datasets.

**Ruddit [5]** This is the first English dataset of Reddit comments with fine-grained, real-valued scores ranging from −1 (maximum support) to 1 (maximum offense).

**OffensEval 2019 (Offen) [27]** This dataset was released for the SemEval 2019 evaluation exercise, Task 6. It contains a total of 14,100 tweets, divided into a training set of 13,240 tweets and a test set of 860 tweets. Of these, 4400 tweets in the training set and 240 in the test set are labeled as offensive.

**AbusEval (Abuse) [28]** This dataset was created by adding a layer of abusive language annotation to OffensEval 2019. It is the same size as OffensEval 2019 and is likewise divided into a training set of 13,240 texts and a test set of 860 texts.

**Reddit Sentiment Analysis (RSA) (https://www.kaggle.com/datasets/cosmos98/twitter-and-reddit-sentimental-analysis-dataset, accessed November 2022)** This dataset was produced by a university study that used PySpark to conduct sentiment analysis across multiple social media networks. It contains approximately 37,000 comments, each with a sentiment label. Since this dataset serves as an auxiliary dataset for training the multi-task learning module, we use only its training set.

**Tweet Sentiment Analysis (TSA) (https://www.kaggle.com/datasets/dv1453/twitter-sentiment-analysis-analytics-vidya, accessed November 2022)** This is a tweet sentiment dataset from Kaggle (2018) that contains more positive tweets than negative tweets. As with RSA, we use only its training set.

We used Pearson correlation (Pear) and mean squared error (MSE) as evaluation metrics for the Ruddit dataset, and Macro F1 (F1) for the Offen and Abuse datasets.
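For reference, the three metrics can be written out as below. These are the standard definitions; in practice a library such as SciPy or scikit-learn would typically be used, and the function names here are illustrative.

```python
import math
from typing import Hashable, Sequence


def pearson(x: Sequence[float], y: Sequence[float]) -> float:
    """Pearson correlation between predicted and gold real-valued scores."""
    n = len(x)
    mx, my = sum(x) / n, sum(y) / n
    cov = sum((a - mx) * (b - my) for a, b in zip(x, y))
    sx = math.sqrt(sum((a - mx) ** 2 for a in x))
    sy = math.sqrt(sum((b - my) ** 2 for b in y))
    if sx == 0 or sy == 0:
        raise ValueError("Pearson correlation is undefined for constant inputs")
    return cov / (sx * sy)


def mse(y_true: Sequence[float], y_pred: Sequence[float]) -> float:
    """Mean squared error."""
    return sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)


def macro_f1(y_true: Sequence[Hashable], y_pred: Sequence[Hashable]) -> float:
    """Unweighted mean of per-class F1 = 2TP / (2TP + FP + FN)."""
    labels = sorted(set(y_true) | set(y_pred))
    scores = []
    for c in labels:
        tp = sum(1 for t, p in zip(y_true, y_pred) if t == c and p == c)
        fp = sum(1 for t, p in zip(y_true, y_pred) if t != c and p == c)
        fn = sum(1 for t, p in zip(y_true, y_pred) if t == c and p != c)
        denom = 2 * tp + fp + fn
        scores.append(2 * tp / denom if denom else 0.0)
    return sum(scores) / len(scores)


print(mse([1.0, 2.0], [1.0, 4.0]))        # 2.0
print(macro_f1([1, 1, 0, 0], [1, 0, 0, 0]))  # 11/15, about 0.733
```

Macro F1 averages the per-class scores without weighting by class size, so the minority offensive class counts as much as the majority class.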


**Table 1.** Statistics of three experimental datasets.

#### *4.2. Training Details*

We use five-fold cross-validation to evaluate the performance of our model on all three datasets. Following [5], we split the original dataset into five equal parts, using one part for testing and the remaining four for training. To prevent data imbalance in multi-task learning, we use the WeightedRandomSampler approach to sample the data according to weights. In the MTL module, both the shared expert module and the task-specific expert module contain two sub-networks. Each expert has a dropout layer with a rate of 0.1, and the tower network also uses a dropout rate of 0.1. For the contrastive learning module, the temperature parameter *τ* is set to 1. We use the Adam optimizer with a learning rate of 2 × 10<sup>−5</sup> and a batch size of 16.
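The splitting and sampling steps can be sketched with the standard library. This is an illustration under assumptions, not the authors' pipeline: folds are formed from a seeded shuffle, and each example is weighted by the inverse frequency of its group (for instance, its source dataset or its label), which mirrors the idea behind PyTorch's `WeightedRandomSampler` but the exact weights used in the paper are not stated. `k_fold_splits` and `balanced_sample` are hypothetical names.

```python
import random
from collections import Counter
from typing import Hashable, Iterator, List, Sequence, Tuple


def k_fold_splits(n: int, k: int = 5, seed: int = 0) -> Iterator[Tuple[List[int], List[int]]]:
    """Yield (train, test) index lists; each of the k parts serves once as the test part."""
    idx = list(range(n))
    random.Random(seed).shuffle(idx)
    folds = [idx[i::k] for i in range(k)]      # k (nearly) equal parts
    for i in range(k):
        train = [j for f, fold in enumerate(folds) if f != i for j in fold]
        yield train, folds[i]


def balanced_sample(groups: Sequence[Hashable], num_samples: int, seed: int = 0) -> List[int]:
    """Draw example indices with replacement, weighting each example by the inverse
    frequency of its group (assumed weighting; the text does not give the weights)."""
    counts = Counter(groups)
    weights = [1.0 / counts[g] for g in groups]
    return random.Random(seed).choices(range(len(groups)), weights=weights, k=num_samples)


# Usage: a 90/10 imbalance between two hypothetical groups becomes roughly 50/50.
groups = ["sentiment"] * 90 + ["hate"] * 10
draws = balanced_sample(groups, 10000)
frac_hate = sum(1 for i in draws if groups[i] == "hate") / len(draws)
print(round(frac_hate, 2))
```

Sampling with replacement by inverse group frequency gives each group equal total probability mass, so the smaller group is not drowned out within a training epoch.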

#### *4.3. Comparison with Baselines*

We compare our model (KMT) with a number of reliable baselines. The following is a brief description of the models:

**BERT [29]** This pre-trained model is mainly used to capture sentence features for the detection of hate speech.

**HateBERT [30]** This is a BERT variant re-trained specifically to recognize hate speech in English. HateBERT was trained on RAL-E, a large dataset of Reddit comments from communities that were banned for hateful or offensive speech. On three popular datasets, OffensEval 2019 [27], AbusEval [28], and HatEval [31], HateBERT significantly outperforms the BERT model.

**KMT** It is our proposed hate speech detection model based on sentence critical information and external sentiment information.

Table 2 compares the overall performance of KMT with the baselines. From the results in this table, the following conclusions can be drawn:



**Table 2.** Comparative results of KMT and existing methods. Superscript \* indicates data obtained from the literature. The best results for each model are shown in boldface.

#### *4.4. Ablation Experiments*

We analyze the effect of the different modules on the performance of our model. The results are shown in Table 3, where *w*/*o cl* denotes the model without contrastive learning; *w*/*o s* denotes the model with the MTL module removed, so that the sentiment datasets are not used as input; and *w*/*o gate* denotes the model in which the feature-filtering gate module is replaced with a simple feed-forward neural network and a softmax layer.

According to the results in Table 3, we find that:



**Table 3.** Results of ablation experiments. The best results for each model are shown in boldface.

#### *4.5. Effect of Number of Experts*

Each expert module in the multi-task module consists of multiple sub-networks called Experts. To investigate how the number of Experts in the shared expert module (*E*<sub>*s*</sub><sup>*T*</sup>) and in the task-specific expert module (*E*<sub>*k*</sub><sup>*T*</sup>) affects performance, we evaluate our model on the Ruddit dataset with 1 to 4 Experts in each module. As shown in Figure 3, the model performs best when both the shared expert module and the task-specific expert module have two Experts, which justifies the number of Experts chosen in our experimental setup. In addition, the model performs worse when the shared expert module has three or four Experts. This result indicates that a larger number of parameters does not improve performance, possibly because too many parameters make the model harder to train and an excessive number of Experts may produce redundant information.

**Figure 3.** Pear mean value of model with different number of Experts, where the darker color indicates a higher Pear value.

#### *4.6. Effect of Extraction Network Layer Number*

The extraction networks belong to the multi-task module; each consists of the expert modules and the feature-filtering module (Gate) shown in Figure 2 and is mainly used to extract features. To investigate the effect of the number of extraction network layers on performance, we test one-layer and two-layer extraction networks on the Ruddit dataset. The number of trainable parameters grows with the depth of the network. As shown in Table 4, the model performs better with a one-layer extraction network. As the depth of the extraction network increases, performance decreases, likely because the more complex model overfits and becomes unstable. Furthermore, we compare the overall running time of the two models on an NVIDIA RTX 3090 GPU, as shown in Figure 4. The results show that a one-layer extraction network not only improves the overall performance of the model but also reduces the number of parameters, which improves the model's efficiency.

**Table 4.** Effect of number of extraction network layers


**Figure 4.** Runtime comparison.

#### **5. Conclusions and Future Work**

In this work, we propose a keyword-enhanced multi-expert framework for hate speech detection. This model can leverage both external sentiment information and the critical information of the sentence itself. The model mainly uses a shared expert module to share certain parameters across multiple tasks. Through this approach, the model can share sentiment information more effectively and then fuse features with a feature-filtering gate to detect hate speech. We use contrastive learning for keyword enhancement, which enables the model to better identify critical information in sentences. Experiments show that our keyword-enhanced multi-expert framework performs better on three datasets. Finally, detailed analysis further demonstrates the effectiveness of our model and the contribution of each module. In future work, we will explore the portability and generalization of the model and conduct portability experiments across datasets. Building on this work, we also plan to incorporate image information for multimodal hate speech detection.

**Author Contributions:** Conceptualization, W.Z. and G.L.; methodology, W.Z.; formal analysis, W.Z. and Q.W.; writing—original draft preparation, W.Z. and Q.W.; writing—review and editing, Y.X. and X.H.; supervision, Y.X. and X.H.; funding acquisition X.H. All authors have read and agreed to the published version of the manuscript.

**Funding:** This work was supported by the Characteristic Innovation Projects of Guangdong Colleges and Universities (No. 2018KTSCX049) and the Science and Technology Plan Project of Guangzhou under Grant Nos. 202102080258 and 201903010013.

**Institutional Review Board Statement:** Not applicable.

**Informed Consent Statement:** Not applicable.

**Data Availability Statement:** Not applicable.

**Conflicts of Interest:** The authors declare no conflict of interest.

#### **References**


### *Article* **A Novel Deep Reinforcement Learning Based Framework for Gait Adjustment**

**Ang Li 1,2, Jianping Chen 2,3,\*, Qiming Fu 1,2,\*, Hongjie Wu 1,2, Yunzhe Wang 1,2 and You Lu 1,2**


**Abstract:** Nowadays, millions of patients suffer from physical disabilities, including lower-limb disabilities. Researchers have adopted a variety of physical therapies based on the lower-limb exoskeleton, in which it is difficult to adjust equipment parameters in a timely fashion. Therefore, intelligent control methods, for example, deep reinforcement learning (DRL), have been used to control the medical equipment used in human gait adjustment. In this study, based on the key-value attention mechanism, we reconstructed the agent's observations by capturing the self-dependent feature information for decision-making in regard to each state sampled from the replay buffer. Moreover, based on Softmax Deep Double Deterministic policy gradients (SD3), a novel DRL-based framework, key-value attention-based SD3 (AT\_SD3), has been proposed for gait adjustment. We demonstrated the effectiveness of our proposed framework in gait adjustment by comparing different gait trajectories, including the desired trajectory and the adjusted trajectory. The results showed that the simulated trajectories were closer to the desired trajectory in both shape and value. Furthermore, a comparison with other state-of-the-art methods showed that our proposed framework performed better.

**Keywords:** deep reinforcement learning; attention mechanism; state reconstruction; gait adjustment

**MSC:** 03D80; 68Q30

#### **1. Introduction**

Regaining the ability to walk is a primary goal of recovery for stroke patients. However, patients often experience restrictions on their daily communication and freedom of movement. Therefore, gait rehabilitation is urgently needed for these patients [1]. In the fields of gait rehabilitation and walking assistance, most lower-limb exoskeletons are developed for assisting paraplegic patients with disabilities of both of their legs. Through gait rehabilitation, we can achieve the goal of helping patients with mobility disorders in the rehabilitation of their musculoskeletal strength, motor control, and gait.

In traditional rehabilitation therapies, intensive labor is involved, and physical therapists have to provide patients with highly repetitive training that is usually inefficient and time-consuming [2]. The inherent shortcomings of these therapies include their failure to autonomously adapt to the user's changing needs, as well as the lack of sensory feedback that they provide to the user regarding the states of the limb and of the device. Compared to traditional physical therapies, exoskeleton-assisted rehabilitation has the advantages of reducing the work of therapists, and it is more convenient to use for quantitatively assessing the patient's level of recovery by measuring force and movement patterns [3].

**Citation:** Li, A.; Chen, J.; Fu, Q.; Wu, H.; Wang, Y.; Lu, Y. A Novel Deep Reinforcement Learning Based Framework for Gait Adjustment. *Mathematics* **2023**, *11*, 178. https://doi.org/10.3390/math11010178

Academic Editor: Jianping Gou, Weihua Ou, Shaoning Zeng and Lan Du

Received: 18 November 2022 Revised: 15 December 2022 Accepted: 21 December 2022 Published: 29 December 2022

**Copyright:** © 2022 by the authors. Licensee MDPI, Basel, Switzerland. This article is an open access article distributed under the terms and conditions of the Creative Commons Attribution (CC BY) license (https://creativecommons.org/licenses/by/4.0/).

To date, studies on exoskeleton control methods have achieved remarkable results. Mendoza-Crespo, Rafael et al. [4] developed and presented a method to acquire and analyze subject-specific gait data while the subject wears a passive lower-limb exoskeleton. In [5], a trajectory tracking controller based on the boundary-layer augmented sliding mode control (BASMC) law was implemented to guide the subject's limbs along physiological gait trajectories. However, in the above-mentioned methods, patients are normally trained to passively follow a predefined gait reference trajectory, and their initiative or motivation is usually not considered. Therefore, adaptive control techniques and deep reinforcement learning (DRL)-based control methods have been proposed. DRL can potentially be used for exoskeleton control without requiring a predefined gait trajectory. More importantly, it enables interaction between the lower-limb exoskeleton and the patient during rehabilitation. Thus, in this study, we focused on the control of a lower-limb exoskeleton using DRL.

#### **2. Novelty and Contribution of the Study**

In this study, in order to achieve the goal of gait rehabilitation and walking assistance, we simulated an exoskeleton based on the lower-limb musculoskeletal model used in the 2019 NeurIPS "Learning to Move–Walk Around" challenge.

Firstly, we adopted the Markov decision process (MDP) to model the gait adjustment problem, which provided an intelligent policy for controlling the exoskeleton. Secondly, to address the curse of dimensionality caused by the complexity of the musculoskeletal model, we proposed a DRL-based framework named AT\_SD3, which incorporates key-value attention-based state reconstruction and Softmax Deep Double Deterministic policy gradients (SD3). Based on the key-value attention mechanism, we presented a novel state reconstruction framework in which all sampled states are fused proportionally with the initial observations; this enables the model to extract the self-dependent feature information of each sampled state and reconstruct an effective and interpretable state. The DRL agent can then select a better action under the same policy. Moreover, we used an autoencoder to extract features from the reconstructed state to address the curse of dimensionality. Finally, we compared gait trajectories, including the desired trajectory, the unadjusted trajectory obtained in previous works, and the adjusted trajectory obtained in this work. The results showed that the adjusted trajectory was closer to the desired trajectory in both shape and value than the unadjusted trajectory, and that our proposed framework performed better than other state-of-the-art DRL algorithms.

The related code and dataset are available at https://github.com/li0516/opensim-rl.git (accessed on 17 November 2022).

#### **3. Related Works**

#### *3.1. Adaptive Control Techniques*

Adaptive control techniques utilize dynamics models of both the user and the exoskeleton. Fatai Sado proposed a control strategy that integrated a dual unscented Kalman filter (DUKF), used for trajectory generation and the prediction of the spatio-temporal features of human walking, with an impedance-cum-supervisory controller that enables the exoskeleton to follow this trajectory and thus synchronize with human walking [6]. To improve the control performance, the authors introduced a linear quadratic regulator with integral action (LQRi) and an unknown input observer (UIO) to compensate for disturbances [7]. In [8], an adaptive oscillator method named the amplitude omega adaptive oscillator (A*ω*AO), comprising both low-level classifiers (to detect activities) and high-level classifiers (to detect transitions between activities), was proposed to provide bilateral hip assistance for human locomotion. Sado, F. et al. [9] proposed an exoskeleton controller consisting of a low-level linear quadratic Gaussian (LQG) torque controller, a middle-level user-input torque estimator based on a dual extended Kalman filter (EKF), and a novel high-level supervisory algorithm for detecting movement and synchronizing the exoskeleton with the user.

#### *3.2. DRL-Based Control Methods*

As a learning-based control method, deep reinforcement learning (DRL) has been applied to lower-limb exoskeleton control. A human–robot interactive control scheme, designed with a sigmoid function and a reinforcement learning algorithm, was proposed to govern the assistance provided by a lower-limb exoskeleton robot to patients during gait rehabilitation training [10]. In [11], Zhang, Y. et al. proposed a reinforcement-learning-based impedance controller that actively reshapes the stiffness of the force field according to the subject's performance. In [12], an optimal adaptive compliance control was proposed for a robotic walk-assist device; its reinforcement-learning-based strategy is completely free of a dynamic model and uses joint position and velocity feedback, as well as the sensed joint torque applied by the user while walking, for compliance control. In [13], Rose, L. et al. presented the first end-to-end, model-free deep reinforcement learning method for an exoskeleton that learns to follow a desired gait pattern while taking the user's existing gait pattern into account and remaining robust to their perturbations and interactions. Oghogho, Martin et al. [14] employed the Twin Delayed Deep Deterministic Policy Gradient (TD3) method to rapidly learn appropriate controller gain values and to deliver personalized assistive torques through the exoskeleton to different joints, assisting the wearer in a weight-handling task. In [15], Kumar, V.C.V. et al. used Proximal Policy Optimization (PPO) to develop a human locomotion policy that imitates a reference human walking motion. These achievements indicate that DRL-based control is inherently both adaptive and optimal, and can adapt to uncertainty and unforeseen changes in the robot dynamics [12].

Previous studies have shown that DRL is effective for lower-limb exoskeleton control. Moreover, inspired by the concept of strengthening the discrimination among similar classes using specific weights [16], we propose in this paper a DRL-based framework that incorporates the recent DRL algorithm SD3 and the key-value attention mechanism. Compared with previous DRL methods, our framework can deal with the curse of dimensionality caused by the musculoskeletal model's high number of degrees of freedom. From this perspective, our framework can substantially improve the performance of the DRL algorithm when a reinforcement learning (RL) agent observes a high-dimensional state; more importantly, experimental results show that our proposed framework achieves state-of-the-art performance in gait adjustment.

#### **4. Preliminaries**

#### *4.1. Reinforcement Learning*

We usually model the reinforcement learning problem as an MDP. An MDP is a quintuple (*S*, *A*, *R*, *P*, *γ*), where *S* is the state space, *A* is the action space, *R* is the reward function, *P* is the transition probability distribution, and *γ* is the discount factor. At time step *t*, the agent selects and executes an action *a*<sub>*t*</sub> ∈ *A* according to the policy *π*, which maps a state *s* to the probability of an action *a*. The environment then moves to a new state *s*<sub>*t*+1</sub> ∈ *S*, which is determined by the transition probability *P*(*s*<sub>*t*+1</sub>|*s*<sub>*t*</sub>, *a*<sub>*t*</sub>). Simultaneously, the agent receives the immediate reward *r*<sub>*t*+1</sub> ∼ *R*(*s*<sub>*t*</sub>, *a*<sub>*t*</sub>). The interaction between the agent and the environment is shown in Figure 1.
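The interaction loop described above can be sketched with a toy environment. The `ChainEnv` class below is a hypothetical five-state chain MDP invented for illustration (it is not the OpenSim musculoskeletal environment used in this study): the chosen move succeeds with probability 1 − `slip`, which plays the role of *P*(*s*<sub>*t*+1</sub>|*s*<sub>*t*</sub>, *a*<sub>*t*</sub>), and a reward of 1 is given on reaching the last state.

```python
import random
from typing import Callable, List, Tuple


class ChainEnv:
    """Hypothetical toy MDP: states 0..n-1, actions -1/+1, reward 1 on reaching state n-1."""

    def __init__(self, n: int = 5, slip: float = 0.1, seed: int = 0):
        self.n, self.slip = n, slip
        self.rng = random.Random(seed)
        self.state = 0

    def reset(self) -> int:
        self.state = 0
        return self.state

    def step(self, action: int) -> Tuple[int, float, bool]:
        # Transition P(s'|s, a): the intended move succeeds with probability 1 - slip.
        move = action if self.rng.random() >= self.slip else -action
        self.state = min(max(self.state + move, 0), self.n - 1)
        done = self.state == self.n - 1
        return self.state, (1.0 if done else 0.0), done


def run_episode(env: ChainEnv, policy: Callable[[int], int],
                max_steps: int = 100) -> List[Tuple[int, int, float]]:
    """Collect (s_t, a_t, r_{t+1}) tuples by following the policy until termination."""
    trajectory = []
    s = env.reset()
    for _ in range(max_steps):
        a = policy(s)                      # a_t chosen by pi given s_t
        s_next, r, done = env.step(a)      # environment returns s_{t+1} and r_{t+1}
        trajectory.append((s, a, r))
        s = s_next
        if done:
            break
    return trajectory


traj = run_episode(ChainEnv(slip=0.0), policy=lambda s: 1)
print(traj)  # [(0, 1, 0.0), (1, 1, 0.0), (2, 1, 0.0), (3, 1, 1.0)]
```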

In RL, we aim to find an optimal policy that maximizes the return

$$G\_t = \sum\_{k=0}^{\infty} \gamma^k r\_{t+k+1}.$$

To achieve this, we evaluate the policy *π* by estimating value functions, namely the state-value function *V*<sub>*π*</sub> and the action-value function *Q*<sub>*π*</sub>. Here, the state-value function *V*<sub>*π*</sub> is the expected return *G*<sub>*t*</sub> when starting in state *s* and following policy *π* thereafter:

$$V\_{\pi}(\mathbf{s}) = E\_{\pi}[\mathbf{G}\_{t} \mid \mathbf{s}\_{t} = \mathbf{s}],\tag{1}$$

where *E*<sub>*π*</sub>[·] denotes the expected value of the return *G*<sub>*t*</sub> given that the agent follows policy *π*. The action-value function, also called the Q-value, *Q*<sub>*π*</sub>(*s*, *a*), represents the expected return *G*<sub>*t*</sub> after taking action *a* in state *s* and thereafter following policy *π*:

$$Q\_{\pi}(s, a) = E\_{\pi}[G\_t \mid s\_t = s, a\_t = a].\tag{2}$$

**Figure 1.** The interaction between the agent and the environment in RL.
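The return *G*<sub>*t*</sub> and the value function *V*<sub>*π*</sub> defined above can be made concrete with a short sketch. It computes *G*<sub>*t*</sub> for a finite reward sequence using the recursion *G*<sub>*t*</sub> = *r*<sub>*t*+1</sub> + *γG*<sub>*t*+1</sub>, and estimates *V*<sub>*π*</sub>(*s*) as the sample mean of returns observed from *s* (a Monte Carlo estimate); the function names are illustrative.

```python
from typing import List, Sequence


def discounted_return(rewards: Sequence[float], gamma: float) -> float:
    """G_t = sum_k gamma^k r_{t+k+1}, with rewards = [r_{t+1}, r_{t+2}, ...]."""
    g = 0.0
    for r in reversed(rewards):
        g = r + gamma * g
    return g


def returns_per_step(rewards: Sequence[float], gamma: float) -> List[float]:
    """G_t for every step of a finished episode, via G_t = r_{t+1} + gamma * G_{t+1}."""
    out, g = [], 0.0
    for r in reversed(rewards):
        g = r + gamma * g
        out.append(g)
    return out[::-1]


def mc_state_value(observed_returns: Sequence[float]) -> float:
    """Monte Carlo estimate of V_pi(s): mean of the returns observed from state s."""
    return sum(observed_returns) / len(observed_returns)


print(discounted_return([1.0, 0.0, 2.0], gamma=0.5))  # 1.5
print(returns_per_step([1.0, 0.0, 2.0], gamma=0.5))   # [1.5, 1.0, 2.0]
```

Computing the sum backwards avoids recomputing powers of *γ* and yields the return for every time step in a single pass.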

Thereafter, the optimal policy *π*<sup>∗</sup> can be obtained by maximizing the state-value function or the action-value function, denoted *V*<sup>∗</sup> and *Q*∗, respectively. These two functions can be defined as follows:

$$V\_{\ast}(s) = \max\_{\pi} V\_{\pi}(s), \tag{3}$$

$$Q\_{\ast}(s, a) = \max\_{\pi} Q\_{\pi}(s, a) = E\left[R\_{t+1} + \gamma \max\_{a'} Q\_{\ast}(s\_{t+1}, a') \mid s\_t = s, a\_t = a\right].\tag{4}$$
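To make Equation (4) concrete, the following minimal Python sketch repeatedly applies the Bellman optimality operator (Q-value iteration) to a tiny, fully known MDP. The two-state MDP, its rewards and *γ* = 0.9 are hypothetical and chosen only so that the fixed point can be checked by hand; tabular value iteration is not part of the proposed method, which learns *Q* with neural networks from sampled transitions.

```python
import numpy as np

def q_value_iteration(P, R, gamma=0.9, tol=1e-10, max_iters=10_000):
    """Approximate Q_* by iterating the Bellman optimality equation (Eq. (4)).

    P[s, a, s2]: transition probability P(s2 | s, a), shape (S, A, S)
    R[s, a]:     expected immediate reward, shape (S, A)
    """
    Q = np.zeros(R.shape)
    for _ in range(max_iters):
        # E[R_{t+1} + gamma * max_a' Q_*(s_{t+1}, a') | s_t = s, a_t = a]
        Q_new = R + gamma * (P @ Q.max(axis=1))
        if np.max(np.abs(Q_new - Q)) < tol:
            return Q_new
        Q = Q_new
    return Q

# Hypothetical 2-state, 2-action MDP: in state 0, action 1 moves to the
# absorbing state 1 and earns reward 1; action 0 stays in state 0 for free.
P = np.zeros((2, 2, 2))
P[0, 0, 0] = 1.0
P[0, 1, 1] = 1.0
P[1, :, 1] = 1.0
R = np.array([[0.0, 1.0], [0.0, 0.0]])

Q_star = q_value_iteration(P, R)
V_star = Q_star.max(axis=1)            # V_*(s) = max_a Q_*(s, a), consistent with Eq. (3)
greedy_policy = Q_star.argmax(axis=1)  # pi_*(s) = argmax_a Q_*(s, a)
```

Here `Q_star` converges to [[0.9, 1.0], [0.0, 0.0]]: leaving state 0 immediately (reward 1) beats waiting one step (0.9 × 1), so the greedy policy picks action 1 in state 0.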

#### *4.2. Softmax Deep Double Deterministic Policy Gradients*

The deep deterministic policy gradient (DDPG) algorithm is often used to solve continuous control problems [17,18]. However, a major concern with DDPG is that it suffers from overestimation, because the critic network is used to select the action with the highest action-value estimate [19]. To reduce the adverse impact of overestimation, double estimators were proposed for the critic in TD3 [20]. Nevertheless, directly taking the minimum of the action-value estimates of the two critics in TD3 introduces a large underestimation bias [21].

To tackle this problem, Pan et al. [19] proposed SD3, which combines the softmax operator with action-value estimation based on double critic estimators. In SD3, two actor networks and two critic networks are built to select candidate actions and to evaluate the corresponding action-values, respectively. Specifically, candidate actions are selected by the different actor networks, and the minimum action-value is obtained by evaluating each action with the two critic networks and taking the smaller estimate:

$$\hat{Q}\_{i=1,2}(s',a') = \min\left(Q\_{i=1}\left(s',a';\theta\_{i=1}^{-}\right), Q\_{i=2}\left(s',a';\theta\_{i=2}^{-}\right)\right). \tag{5}$$

Thereafter, the softmax operator is applied to this minimum Q-value, with the expectations estimated by importance sampling. The softmax Q-value is defined as follows:

$$\text{softmax}\_{\beta}(Q(s', \cdot; \theta^{-})) = \frac{E\_{\hat{a} \sim p}\left[\frac{\exp\left(\beta Q(s', \hat{a}; \theta^{-})\right) Q(s', \hat{a}; \theta^{-})}{p(\hat{a})}\right]}{E\_{\hat{a} \sim p}\left[\frac{\exp\left(\beta Q(s', \hat{a}; \theta^{-})\right)}{p(\hat{a})}\right]},\tag{6}$$

where *β* is the parameter of the softmax operator and *p*(*â*) is the probability density function of the Gaussian distribution used for importance sampling. *E*<sub>*â*∼*p*</sub>[·] denotes the expectation over actions *â* sampled from *p*(*â*), where *â* is the action with additional exploration noise drawn from this Gaussian distribution. Finally, the softmax Q-value is used to compute the target value, where *r* is the immediate reward and *d* ∈ {0, 1} indicates whether the episode has terminated:

$$y = r + \gamma (1 - d) \operatorname{softmax}\_{\beta} \left( Q(s', \cdot; \theta^-) \right). \tag{7}$$
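The following NumPy sketch illustrates how Equations (5) to (7) fit together when computing a single target value. It is a simplified illustration under stated assumptions, not the SD3 reference implementation: the two critics are hypothetical Python callables standing in for the target critic networks, the sampling distribution *p* is assumed to be an isotropic Gaussian centred on the target actor's action, and *d* is taken to be the episode-termination flag. Subtracting the maximum log-weight before exponentiating avoids overflow and leaves the ratio in Equation (6) unchanged.

```python
import numpy as np

def clipped_double_q(q1, q2):
    """Eq. (5): pessimistic estimate min(Q_1, Q_2) from the two target critics."""
    return lambda s, a: min(q1(s, a), q2(s, a))

def softmax_q(q_fn, s_next, a_mean, beta=1.0, noise_std=0.2, n_samples=50, rng=None):
    """Self-normalised importance-sampling estimate of softmax_beta(Q(s', .)), Eq. (6).

    Actions a_hat are drawn from p = N(a_mean, noise_std^2 I) (an assumed sampling distribution).
    """
    rng = np.random.default_rng() if rng is None else rng
    eps = rng.normal(0.0, noise_std, size=(n_samples, a_mean.shape[0]))
    a_hat = a_mean + eps
    # log p(a_hat) up to an additive constant, which cancels in the ratio of Eq. (6)
    log_p = -0.5 * np.sum((eps / noise_std) ** 2, axis=1)
    q = np.array([q_fn(s_next, a) for a in a_hat])
    # log of exp(beta * Q) / p(a_hat), shifted by its max for numerical stability
    log_w = beta * q - log_p
    w = np.exp(log_w - log_w.max())
    return float(np.sum(w * q) / np.sum(w))

def sd3_target(r, d, gamma, soft_q):
    """Eq. (7): y = r + gamma * (1 - d) * softmax_beta(Q(s', .))."""
    return r + gamma * (1.0 - d) * soft_q

# Hypothetical toy critics over a 2-D action space.
q1 = lambda s, a: -float(np.sum((a - 0.5) ** 2))
q2 = lambda s, a: -float(np.sum((a - 0.4) ** 2))
rng = np.random.default_rng(0)
soft = softmax_q(clipped_double_q(q1, q2), None, np.array([0.5, 0.5]), beta=5.0, rng=rng)
y = sd3_target(r=1.0, d=0.0, gamma=0.99, soft_q=soft)
```

Because the estimate is a positively weighted average of sampled Q-values, it always lies between their minimum and maximum; larger *β* moves it towards the maximum.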

#### *4.3. Key-Value Attention Mechanism*

The attention mechanism [22] in neural networks is designed to focus on the information that is critical to the current task among a large amount of input information. It is therefore often used to alleviate information overload and to improve the efficiency and accuracy of task processing.

However, the basic attention mechanism is not suitable for some specific problems. Vaswani et al. [22] therefore introduced the key-value attention mechanism, which represents the input information as key-value pairs. The keys are used to compute the attention distribution *α*<sub>*i*</sub>, and the values are used to compute the aggregated information. As shown in Figure 2, (*K*, *V*) = [(*k*<sub>1</sub>, *v*<sub>1</sub>), ..., (*k*<sub>*N*</sub>, *v*<sub>*N*</sub>)] represents *N* pieces of input information, and the vector *q* is the query vector for a given task. The attention function is then defined as follows:

$$\text{att}(X, q) = \sum\_{i=1}^{N} \alpha\_i x\_i = \sum\_{i=1}^{N} \frac{\exp(s(k\_i, q))}{\sum\_{j=1}^{N} \exp(s(k\_j, q))} v\_{i},\tag{8}$$

where *s* is the attention evaluation (scoring) function and *x*<sub>*i*</sub> = *v*<sub>*i*</sub> denotes the value of the *i*-th piece of input information. The final output is thus a weighted average of the values *v*<sub>*i*</sub>, with weights given by the attention distribution *α*<sub>*i*</sub> computed from the function *s*.

**Figure 2.** The key-value attention mechanism.
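Equation (8) can be sketched in a few lines of NumPy. The dot-product default for the scoring function *s* is an assumption made for illustration (any callable `score(k, q)` can be passed instead); subtracting the maximum score before exponentiating is a standard stabilisation that leaves *α* unchanged.

```python
import numpy as np

def key_value_attention(keys, values, query, score=None):
    """Eq. (8): sum_i alpha_i v_i, where alpha = softmax over i of s(k_i, q).

    keys: (N, d_k), values: (N, d_v), query: (d_k,).
    score: attention evaluation function s(k, q); defaults to a dot product (an assumption).
    """
    if score is None:
        scores = keys @ query
    else:
        scores = np.array([score(k, query) for k in keys], dtype=float)
    # Subtracting the max leaves alpha unchanged but avoids overflow in exp.
    e = np.exp(scores - scores.max())
    alpha = e / e.sum()
    return alpha @ values, alpha

keys = np.ones((3, 2))                 # identical keys -> uniform attention
values = np.array([[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]])
out, alpha = key_value_attention(keys, values, np.array([1.0, 1.0]))
```

With identical keys every *α*<sub>*i*</sub> equals 1/3, so `out` is simply the mean of the values, [1.0, 1.0]; distinct keys shift the weight towards the values whose keys score highest against the query.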

#### *4.4. Parameter Space Noise for Exploration*

Traditional RL methods encourage exploration by adding noise, for example Gaussian noise, to the output of the actor network. This noise is independent of the state *s*<sub>*t*</sub>; in other words, exploration is state-independent. Hence, the same state *s*<sub>*t*</sub> can yield a different action *a*<sub>*t*</sub> each time, and the perturbation sometimes has nothing to do with *s*<sub>*t*</sub>.

Therefore, Fortunato et al. [23] and Plappert et al. [24] proposed adding noise to the agent's parameters instead. A perturbed policy *π̂* is sampled by adding Gaussian noise to the parameters of the current policy *π*; under a fixed perturbation, the same action *a*<sub>*t*</sub> = *π̂*(*s*<sub>*t*</sub>) is produced every time the same state *s*<sub>*t*</sub> is fed into the actor network.
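The contrast between the two exploration schemes can be sketched with a toy linear actor. `LinearPolicy`, the 4-dimensional state and the noise scales are hypothetical stand-ins for the actor network; only the 22-dimensional [0, 1] action range comes from the environment description. The adaptive noise scaling used by Plappert et al. is not reproduced here.

```python
import numpy as np

class LinearPolicy:
    """Hypothetical stand-in for the actor: a = clip(W s + b, 0, 1)."""

    def __init__(self, W, b):
        self.W = W
        self.b = b

    def __call__(self, s):
        return np.clip(self.W @ s + self.b, 0.0, 1.0)

def perturb_parameters(policy, sigma, rng):
    """Parameter-space noise: one perturbed policy pi_hat with Gaussian noise on its weights."""
    return LinearPolicy(policy.W + sigma * rng.standard_normal(policy.W.shape),
                        policy.b + sigma * rng.standard_normal(policy.b.shape))

def action_space_noise(policy, s, sigma, rng):
    """Action-space noise: a fresh, state-independent Gaussian perturbation on every call."""
    return np.clip(policy(s) + sigma * rng.standard_normal(policy.b.shape), 0.0, 1.0)

rng = np.random.default_rng(0)
pi = LinearPolicy(0.1 * rng.standard_normal((22, 4)), np.full(22, 0.5))
pi_hat = perturb_parameters(pi, sigma=0.05, rng=rng)  # e.g. resampled at the start of each episode
s = rng.standard_normal(4)                            # hypothetical 4-D state
```

Calling `pi_hat(s)` twice returns the same action, whereas two calls to `action_space_noise(pi, s, 0.1, rng)` generally differ even though the state is identical.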

#### **5. Problem Modeling**

In our previous work, we conducted gait simulation experiments with DRL algorithms based on the lower limb musculoskeletal model. The results showed that DRL algorithms are effective for gait simulation. However, abnormal gaits sometimes occur during the simulation. In this paper, we therefore model the gait adjustment problem based on the musculoskeletal model as an MDP.

#### *5.1. The Lower Limb Musculoskeletal Model*

In our work, the simulated environment used for gait adjustment is osim-rl, the environment of the 2019 NeurIPS "Learning to Move–Walk Around" challenge, which combines the lower limb musculoskeletal model with DRL to provide accurate human movement simulation. The lower limb musculoskeletal model, built in OpenSim, has 8 internal degrees of freedom (4 per leg) and is actuated by 22 muscles (11 per leg). During the simulation, the muscles are driven by muscle activations (the control signals that determine the force each muscle produces), and the resulting states of the musculoskeletal model, including joint angles, body location and ground reaction forces, are returned. The lower limb musculoskeletal model is shown in Figure 3. A more detailed description of the environment can be found at http://osimrl.kidzinski.com/docs/nips2019/environment/ (accessed on 17 November 2022).

**Figure 3.** The lower limb musculoskeletal model.

#### *5.2. MDP Modeling*

#### 5.2.1. State Space

The observation of the DRL agent consists of two parts: a target velocity map *T* and a body state *S*. First, as shown in Figure 4, the target velocity map *T* is a randomly generated target velocity matrix whose entries are 2-dimensional target velocity vectors, each derived from the target position and the current position of the model. Second, the body state *S* is a 97-dimensional vector consisting of the pelvis state, ground reaction forces, joint angles and the states of the lower limb muscles. The variables of the state space are listed in Table 1.

**Figure 4.** The target velocity map.

#### **Table 1.** State space.


#### 5.2.2. Action Space

The action space [0, 1]<sup>22</sup> represents the activations of the 22 muscles. The muscles respond to these activations by generating forces, and the model moves accordingly, for example, forward. At the same time, the state of the model changes accordingly.

#### 5.2.3. Reward Function

The DRL agent receives the reward *J*(*π*), defined as follows:

$$J(\pi) = R\_b + R\_g, \tag{9}$$

where *R*<sub>*b*</sub> and *R*<sub>*g*</sub> are the rewards for the initial gait simulation and for gait adjustment according to the desired trajectory, respectively. Specifically, *R*<sub>*b*</sub> ensures that a basic human gait can be obtained from the musculoskeletal model. However, deformed gaits sometimes appeared during the simulation, so *R*<sub>*g*</sub> is designed to correct these gait defects, which are reflected in the deviation between the simulated angle and the desired angle of each lower limb joint.

Firstly, the specific definition of *Rb* is as follows:

$$R\_b = M\_{alive} + M\_{step} \tag{10}$$

where *M*<sub>*alive*</sub> rewards the model for remaining standing as long as possible and *M*<sub>*step*</sub> rewards it for moving with minimal effort according to the target velocity map. *M*<sub>*alive*</sub> and *M*<sub>*step*</sub> are defined as follows:

$$M\_{alive} = \sum\_{i} m\_{alive}, \tag{11}$$

$$M\_{step} = \sum\_{step\_i} \left( w\_{step} m\_{step} - w\_{vel} c\_{vel} - w\_{eff} c\_{eff} \right). \tag{12}$$

In Equation (11), *m*<sub>*alive*</sub> denotes the unit time of "model survival". In Equation (12), *m*<sub>*step*</sub> is the stepping reward, which represents the total elapsed time steps of "model survival" in the simulation, and *c*<sub>*vel*</sub> and *c*<sub>*eff*</sub> are the velocity and effort costs, respectively. Moreover, *w*<sub>*step*</sub>, *w*<sub>*vel*</sub> and *w*<sub>*eff*</sub> are the weights for the stepping reward and for the velocity and effort costs. Note that *w*<sub>*step*</sub> is used to prevent the agent from obtaining a higher reward by taking small steps in the human gait simulation.
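Equations (10) to (12) amount to a simple accumulation, sketched below. The per-step survival reward is assumed constant, and all default weights are placeholders; the values actually used in this work are not given in this section.

```python
def basic_reward(n_alive_steps, step_terms, m_alive=0.1, w_step=1.0, w_vel=1.0, w_eff=1.0):
    """R_b = M_alive + M_step (Eqs. (10)-(12)).

    n_alive_steps: number of simulation steps the model stayed upright
    step_terms:    one (m_step, c_vel, c_eff) tuple per completed step
    All default weights are placeholders, not the values used in the paper.
    """
    # Eq. (11): a constant per-step survival reward summed over the episode
    M_alive = m_alive * n_alive_steps
    # Eq. (12): stepping reward minus weighted velocity and effort costs
    M_step = sum(w_step * m_step - w_vel * c_vel - w_eff * c_eff
                 for m_step, c_vel, c_eff in step_terms)
    return M_alive + M_step
```

For example, 10 upright steps and one completed step with (*m*<sub>*step*</sub>, *c*<sub>*vel*</sub>, *c*<sub>*eff*</sub>) = (1.0, 0.2, 0.3) give *R*<sub>*b*</sub> = 1.0 + 0.5 = 1.5 with the placeholder weights.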

Secondly, *R*<sub>*g*</sub> is designed based on how the real-time angle of each joint changes relative to the desired trajectory in each episode, for example, whether it approaches or even exceeds the desired trajectory. It is defined in Equation (13):

$$R\_{g} = \sum\_{i=0}^{n} \left(w\_h r\_{i\_h} + w\_k r\_{i\_k} + w\_d r\_{i\_d}\right), \tag{13}$$

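where *r*<sub>*i*</sub> with the subscripts *h*, *k* and *d* denotes the reward at time step *i* for each of the three lower limb joints, and *w*<sub>*h*</sub>, *w*<sub>*k*</sub> and *w*<sub>*d*</sub> are the corresponding weights. For each joint, the reward *r*<sub>*i*</sub> at time step *i* is defined as follows: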

$$r\_i = w\_F F(q\_i) + w\_G G(q\_i), \tag{14}$$

where *w*<sub>*F*</sub> and *w*<sub>*G*</sub> are the weights for the reward *F*(*q*<sub>*i*</sub>) and the penalty *G*(*q*<sub>*i*</sub>), respectively. On the one hand, the function *F*(*q*<sub>*i*</sub>), which rewards the joint angle for approaching the desired trajectory, is based on the Gaussian function:

$$F(q\_i) = \frac{1}{\sigma\sqrt{2\pi}}e^{-\frac{1}{2}\left(\frac{d-\mu}{\sigma}\right)^2},\tag{15}$$

where *μ* and *σ* are the mean and the standard deviation (SD) of the desired joint angle, respectively, and *d* is the absolute difference between the real-time angle *q*<sub>*i*</sub> and the desired joint angle *q*<sub>*d*<sub>*i*</sub></sub>:

$$d = |q\_i - q\_{d\_i}|.\tag{16}$$

On the other hand, the function *G*(*q*<sub>*i*</sub>), which penalizes exceeding the desired trajectory, is defined in Equation (17):

$$G(q\_i) = -M(y\_{\max}) - M(y\_{\min}),\tag{17}$$

where *M*(·) is defined as follows:

$$M(y) = \begin{cases} 0 & y \le 0 \\ y & y > 0 \end{cases},\tag{18}$$

and

$$y\_{\max} = q\_i - q\_{\max} \tag{19}$$

$$y\_{\min} = q\_{\min} - q\_{i}, \tag{20}$$

where *q*<sub>*max*</sub> and *q*<sub>*min*</sub> are the maximum and minimum joint angles, respectively.
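The joint-level reward of Equations (13) to (20) can be transcribed directly, as in the sketch below. The formulas follow the equations literally (including the Gaussian of the deviation *d*); the default weights and any angle values a caller passes are placeholders, since the concrete values are not given in this section.

```python
import math

def F(q, q_des, mu, sigma):
    """Eqs. (15)-(16): Gaussian-shaped reward of the deviation d = |q - q_des|."""
    d = abs(q - q_des)
    return math.exp(-0.5 * ((d - mu) / sigma) ** 2) / (sigma * math.sqrt(2.0 * math.pi))

def M(y):
    """Eq. (18): ramp function, equal to max(y, 0)."""
    return y if y > 0 else 0.0

def G(q, q_min, q_max):
    """Eqs. (17), (19), (20): penalty for leaving the range [q_min, q_max]."""
    return -M(q - q_max) - M(q_min - q)

def joint_reward(q, q_des, mu, sigma, q_min, q_max, w_F=1.0, w_G=1.0):
    """Eq. (14): r_i = w_F * F(q_i) + w_G * G(q_i); weights are placeholders."""
    return w_F * F(q, q_des, mu, sigma) + w_G * G(q, q_min, q_max)

def gait_reward(per_step_rewards, w_h=1.0, w_k=1.0, w_d=1.0):
    """Eq. (13): R_g = sum over time steps of (w_h r_h + w_k r_k + w_d r_d)."""
    return sum(w_h * r_h + w_k * r_k + w_d * r_d for r_h, r_k, r_d in per_step_rewards)
```

Inside the allowed range the penalty *G* is zero, and outside it grows linearly with the overshoot, so the agent is pushed back towards [*q*<sub>*min*</sub>, *q*<sub>*max*</sub>] while *F* rewards staying close to the desired angle.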

#### **6. Methodology**

#### *6.1. Overall Framework*

As depicted in Figure 5, the overall framework for gait adjustment consists of two parts: state reconstruction and SD3. In the first part, after the simulated environment is initialized, we reconstruct the initial observation by extracting features from existing states with the attention mechanism, where the states are randomly sampled, together with their paired actions, from the replay buffer.

In the second part, the reconstructed state is taken as the input of SD3. Each actor network selects an action *a*<sub>*i*</sub> according to the observation, where *i* indexes the actor network that produced the action, and the critic networks then evaluate the state-action pair, giving *Q*(*s*, *a*<sub>*i*</sub>). The final action *a* is determined by comparing the action-values evaluated by the two critic networks. Note that we add noise directly to the actor network parameters for state-dependent exploration, which ensures a dependency between the sampled state and the corresponding selected action.

**Figure 5.** The overall framework for gait adjustment.

#### *6.2. Key-Value Attention-Based State Reconstruction*

In this work, the initial observation is a 339-dimensional state consisting of a 97-dimensional body state and a 242-dimensional target velocity map. Because this high-dimensional observation contains much redundant information, it is difficult for the RL agent to extract effective information and hence to choose better actions. Moreover, in RL, the observed state *s* and the selected action *a* play a significant role in training, and the information in each state usually determines the choice of action. For example, under the same policy and without active exploration, the RL agent takes different actions in different states. As shown in Figure 6, the actions taken to reach *s*<sub>3</sub> and *s*<sub>4</sub> are indicated by arrows. Although *s*<sub>1</sub> and *s*<sub>2</sub> are very close in space, they are functionally different, and each contains the self-dependent feature information the agent needs to perform the corresponding action. In other words, the self-dependent feature information in a state such as *s*<sub>1</sub> differs from the information shared by all states, and it is necessary for decision making, for example, for choosing *a*<sub>1</sub> rather than *a*<sub>2</sub>. In our work, the musculoskeletal model moves according to the target velocity map; when it reaches the target position, a new target position is randomly generated, and the RL agent immediately takes a new action, for example turning right, to move towards the new target. In this case, we refer to the specific information in the state signaling that the musculoskeletal model has reached the target position as self-dependent information, since it leads the agent to take a specific action.

**Figure 6.** The choice of different actions under the same policy and different states.

The attention mechanism is introduced to focus on the information that is critical to the current task among the input information. On the one hand, we use the key-value attention mechanism to reconstruct the current observation by capturing the self-dependent feature information in each sampled state. Specifically, we first randomly sample *n* state-action pairs (*s*<sub>1</sub>, *a*<sub>1</sub>), (*s*<sub>2</sub>, *a*<sub>2</sub>), ..., (*s*<sub>*n*</sub>, *a*<sub>*n*</sub>) from the replay buffer. In our framework, each sampled pair (*s*<sub>*i*</sub>, *a*<sub>*i*</sub>) plays the role of (*k*<sub>*i*</sub>, *v*<sub>*i*</sub>) in the key-value attention mechanism: the state *s*<sub>*i*</sub> and the action *a*<sub>*i*</sub> are used to calculate the attention distribution and the aggregated information, respectively. Moreover, we adopt state-dependent exploration to ensure a dependency between the sampled state *s*<sub>*i*</sub> and the sampled action *a*<sub>*i*</sub>; in other words, under the same policy, the selected action depends only on the state input to the policy. Second, because the critic network is well suited to continuous action spaces, such as the simulated environment in our work, and is commonly used to approximate the action-value function [25], we take the critic network as the attention evaluation function. Thus, we calculate the action-value *q*<sub>*i*</sub> of each sampled action with the critic network, which takes the current observation and the sampled action *a*<sub>*i*</sub> as input.

In this way, we obtain a series of action-values *q*<sub>*i*</sub> for the sampled actions, which serve as the basis for distinguishing the corresponding sampled states and reconstructing the initial observation. Next, a softmax function normalizes the action-values *q*<sub>*i*</sub>, and the normalized action-value *w*<sub>*i*</sub> represents the proportion of the corresponding sampled state in the reconstructed state. The computed proportion *w*<sub>*i*</sub> can thus be seen as the attention distribution *α*<sub>*i*</sub> in the key-value attention mechanism. Based on the attention distribution *w*<sub>*i*</sub>, the sampled states *s*<sub>*i*</sub> are then fused with the initial observation proportionally. In short, the self-dependent feature information in each sampled state whose action has a higher action-value *q*<sub>*i*</sub> accounts for a larger proportion of the reconstructed state. Note that feature fusion is performed by element-wise addition. With this approach, the reconstructed state is influenced by the agent's actions and therefore contains the information necessary for choosing them, so the RL agent can select the corresponding action based on this information.
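One possible reading of the reconstruction steps just described is sketched below: score each sampled action with the critic under the current observation, normalize the scores with a softmax, and add the weighted sampled states to the observation element-wise. The critic here is a hypothetical callable, the sampled states are assumed to have the same dimension as the observation, and the subsequent autoencoder stage is not included.

```python
import numpy as np

def reconstruct_state(obs, sampled_states, sampled_actions, critic):
    """Attention-weighted fusion of replay-buffer states into the current observation.

    obs:             current observation, shape (D,)
    sampled_states:  s_1..s_n from the replay buffer, shape (n, D)
    sampled_actions: a_1..a_n paired with those states, shape (n, A)
    critic:          callable Q(obs, a) -> float, used as the attention evaluation function
    """
    # q_i: the critic's action-value of each sampled action under the current observation
    q = np.array([critic(obs, a) for a in sampled_actions], dtype=float)
    # w_i = softmax(q)_i, playing the role of alpha_i in Eq. (8)
    e = np.exp(q - q.max())
    w = e / e.sum()
    # element-wise addition of the weighted sampled states to the observation
    return obs + w @ sampled_states, w

obs = np.zeros(3)
states = np.array([[1.0, 2.0, 3.0], [3.0, 2.0, 1.0]])
actions = np.eye(2)                     # hypothetical 2-D actions
flat_critic = lambda s, a: 0.0          # equal scores -> equal weights
recon, w = reconstruct_state(obs, states, actions, flat_critic)
```

With a flat critic both sampled states receive weight 0.5 and `recon` is [2.0, 2.0, 2.0]; a critic that strongly prefers the first action would make the first sampled state dominate the fused result.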

On the other hand, an autoencoder [26] is an unsupervised neural network consisting of an encoder and a decoder, and dimensionality reduction can be achieved by adjusting the hidden layers of both modules, in particular the width of the bottleneck layer between them. Therefore, we use an autoencoder to overcome the curse of dimensionality caused by the high-dimensional musculoskeletal model. The process is depicted in Figure 7.

**Figure 7.** State reconstruction.
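To illustrate how a bottleneck reduces dimensionality, the following sketch trains a linear autoencoder with plain gradient descent in NumPy. This is a deliberate simplification and an assumption: the architecture, activation functions and layer sizes of the autoencoder used in this work are not specified here, and the synthetic 20-dimensional data (lying in a 4-dimensional subspace) are hypothetical.

```python
import numpy as np

def train_linear_autoencoder(X, k, lr=0.01, steps=2000, seed=0):
    """Linear autoencoder: encoder W_e (D x k), decoder W_d (k x D).

    Full-batch gradient descent on L = sum_j ||x_j W_e W_d - x_j||^2 / n.
    Returns (W_e, W_d, loss_history).
    """
    rng = np.random.default_rng(seed)
    n, D = X.shape
    W_e = 0.1 * rng.standard_normal((D, k))
    W_d = 0.1 * rng.standard_normal((k, D))
    history = []
    for _ in range(steps):
        Z = X @ W_e                  # k-dimensional code (the reduced state)
        X_hat = Z @ W_d              # reconstruction in the original D dimensions
        R = X_hat - X
        history.append(float(np.sum(R ** 2) / n))
        G = 2.0 * R / n              # dL/dX_hat
        grad_W_d = Z.T @ G
        grad_W_e = X.T @ (G @ W_d.T)
        W_d -= lr * grad_W_d
        W_e -= lr * grad_W_e
    return W_e, W_d, history

rng = np.random.default_rng(1)
latent = rng.standard_normal((200, 4))
mixing = 0.5 * rng.standard_normal((4, 20))
X = latent @ mixing                  # hypothetical 20-D data in a 4-D subspace
W_e, W_d, history = train_linear_autoencoder(X, k=4)
codes = X @ W_e                      # reduced 4-D representation
```

Because the synthetic data have only four underlying factors, a 4-unit bottleneck can reconstruct them almost perfectly, and the reconstruction loss falls far below its initial value; real state vectors would need a nonlinear network and a bottleneck width chosen empirically.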

#### *6.3. AT\_SD3 for Gait Adjustment*

Algorithm 1 presents the pseudocode of AT\_SD3 for the gait adjustment.



#### **7. Experiment Analysis**

#### *7.1. Experiment Preparation*

#### 7.1.1. Dataset

To validate the kinematics and ground reaction forces obtained by simulation with DRL algorithms, we compare the simulated data with the experimental data in a public dataset [27] (details of the comparison are given in Section 7.2.2). The dataset is a single-source, readily accessible repository of comprehensive gait data for a large group of children walking at a wide range of speeds: very slow (below average speed), slow, free, fast and very fast (above average speed). It contains seven kinds of gait data: joint rotations, ground reaction forces, joint moments, joint power, EMG (electromyography), cycle events, and an ANOVA table with results for selected parameters.

#### 7.1.2. Evaluation Metrics

To compare the similarity between the experimental gait data and the simulated gait data, two evaluation metrics are adopted in this paper, namely the mean absolute error (MAE) and the root mean square error (RMSE), defined as follows:

$$\text{MAE} = \frac{1}{m} \sum\_{i=1}^{m} |y\_i - y\_i'|, \tag{21}$$

$$\text{RMSE} = \sqrt{\frac{1}{m} \sum\_{i=1}^{m} \left(y\_i - y\_i'\right)^2},\tag{22}$$

where *m* denotes the total number of gait data samples, and *y*<sub>*i*</sub> and *y*′<sub>*i*</sub> denote the simulated and experimental data of the *i*-th sample, respectively.
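Equations (21) and (22) translate directly into NumPy, as sketched below. Tables 3 and 4 report the metrics in units of SD; how that normalization is performed is not specified here, so the sketch returns the raw errors in the units of the input data (for example, degrees of joint angle).

```python
import numpy as np

def mae(simulated, experimental):
    """Eq. (21): mean absolute error between simulated y_i and experimental y'_i."""
    y = np.asarray(simulated, dtype=float)
    y_ref = np.asarray(experimental, dtype=float)
    return float(np.mean(np.abs(y - y_ref)))

def rmse(simulated, experimental):
    """Eq. (22): root mean square error between simulated and experimental data."""
    y = np.asarray(simulated, dtype=float)
    y_ref = np.asarray(experimental, dtype=float)
    return float(np.sqrt(np.mean((y - y_ref) ** 2)))
```

Since RMSE squares the errors before averaging, it is never smaller than MAE and grows faster when a few samples deviate strongly, which is why the two metrics are usually reported together.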

#### 7.1.3. Parameter Settings

The hyperparameters of all methods are summarized in Table 2. Two hidden layers are used, containing 128 and 64 neurons, respectively. Considering the high-dimensional environment, we set the replay buffer size to 5 × 10<sup>6</sup> and the batch size to 256. The learning rate is set to 0.0001 for TD3, AT\_SD3, SD3, SD3\_AE and PPO, and to 0.01 for DDPG. The hyperparameters related to the noise added to the actor network are also listed in Table 2. All parameters were determined through extensive numerical experiments.
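The settings stated in this paragraph can be collected in a small configuration object, sketched below. Only the values given in the text are included; the noise hyperparameters from Table 2 are not reproduced, and fields that do not apply to a given method (for example, the replay buffer for the on-policy PPO) would simply be ignored by that method.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class TrainingConfig:
    """Settings stated in Section 7.1.3; noise hyperparameters from Table 2 are omitted."""
    hidden_sizes: tuple = (128, 64)       # two hidden layers
    replay_buffer_size: int = 5_000_000   # 5 x 10^6 transitions
    batch_size: int = 256
    learning_rate: float = 1e-4           # TD3, AT_SD3, SD3, SD3_AE and PPO

CONFIGS = {name: TrainingConfig() for name in ("TD3", "AT_SD3", "SD3", "SD3_AE", "PPO")}
CONFIGS["DDPG"] = TrainingConfig(learning_rate=1e-2)
```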

**Table 2.** Hyperparameters of TD3 [14], DDPG [13], SD3, SD3\_AE, PPO [15] and AT\_SD3.


#### *7.2. Results and Analysis*

#### 7.2.1. Algorithm Performance

To verify the effectiveness of AT\_SD3 for gait adjustment based on the musculoskeletal model, we compare it with other state-of-the-art DRL algorithms, including TD3 [14], DDPG [13], PPO [15], SD3\_AE and SD3, on the gait adjustment problem. The results are shown in Figure 8.

**Figure 8.** Performance of AT\_SD3 and other state-of-the-art DRL algorithms.

Figure 8 shows the performance of AT\_SD3 and the other DRL algorithms for gait adjustment, where the horizontal axis represents the number of episodes and the vertical axis the average reward. Each curve shows the average reward obtained by one DRL algorithm over a total of 12,000 episodes, and the shaded area represents the SD around the mean of three independent runs with the same hyperparameters.

On the one hand, after a certain number of episodes, AT\_SD3 outperforms the traditional DRL algorithms DDPG, PPO and TD3 as well as the recent algorithm SD3. On the other hand, the performance of PPO remains stable throughout the simulation, while that of DDPG is consistently poor compared with the other algorithms, which may be due to their limited ability to deal with the curse of dimensionality in DRL. By contrast, TD3, with its more complex network structure, performs better than PPO and DDPG. In our work, the current observation is a 339-dimensional musculoskeletal state, which may explain this phenomenon. We therefore introduce SD3 to address the difficulty of gait adjustment caused by this problem; owing to its more complex network structure, SD3 has a relative advantage over other RL algorithms in dealing with 'the curse of dimensionality'. However, as Figure 8 shows, after a certain number of episodes the performance of SD3 gradually stabilizes at relatively low rewards, which motivates our attention mechanism-based framework for gait adjustment. Based on the reward gap between AT\_SD3 and the other algorithms in Figure 8, we conclude that AT\_SD3 is more effective than the traditional algorithms for gait adjustment. Moreover, we conduct an ablation experiment, named SD3\_AE, in which SD3 is combined with the autoencoder for gait adjustment. As can be seen in Figure 8, SD3\_AE performs better than SD3, owing to the autoencoder's feature extraction and its mitigation of the curse of dimensionality. More importantly, the comparison between AT\_SD3 and SD3\_AE indicates that state reconstruction through the key-value attention mechanism is effective for gait adjustment. Taken together, these comparative experiments demonstrate the effectiveness of fusing the self-dependent feature information necessary for decision making in each sampled state with the current observation.

#### 7.2.2. Gait Adjustment

We compare different gait trajectories including the unadjusted trajectory obtained in previous work, the adjusted trajectory obtained in this work and the desired trajectory obtained in [27].

a. Unadjusted Trajectory and Desired Trajectory

Figure 9 shows the gait trajectories of different joints, namely ankle flexion/extension, knee flexion/extension, hip adduction/abduction and hip flexion/extension, in sub-figures (a) to (d), respectively, where the horizontal axis represents the gait cycle and the vertical axis the joint trajectory. In each sub-figure, the red curve indicates the desired trajectory and the other curve represents the unadjusted trajectory obtained by the human gait simulation in previous work. Table 3 reports the MAE and RMSE between the desired trajectory and the unadjusted trajectory simulated in previous work.

**Figure 9.** The simulated kinematics compared to the experimental data in [27].

As can be seen from Figure 9, the unadjusted trajectories of the different joints obtained in previous work are similar in shape to the desired trajectory, which is the mean kinematics calculated from the maximum and minimum values of the kinematics. However, as shown in Table 3 and Figure 9, there is a deviation between the unadjusted trajectory and the desired trajectory, which results from the randomness of the gait simulated by the algorithms in previous work. As Table 3 shows, both metrics obtained in previous work lie between 1.22 SD and 2.64 SD.


**Table 3.** Metrics between desired trajectory and the unadjusted trajectory.

b. Adjusted Trajectory and Desired Trajectory

Figure 10 shows the trajectories of different joints, namely ankle flexion/extension, knee flexion/extension, hip adduction/abduction and hip flexion/extension, in sub-figures (a) to (d), respectively, where the horizontal axis represents the gait cycle and the vertical axis the joint trajectory. In each sub-figure, the red curve indicates the desired trajectory and the other curve represents the adjusted trajectory obtained in this work. Table 4 summarizes the MAE and RMSE between the desired trajectory and the adjusted trajectory obtained in this work.

**Figure 10.** Desired trajectory and the adjusted trajectory based on the simulated lower limb exoskeleton.

**Table 4.** Metrics between desired trajectory and the adjusted trajectory.


As Figure 10 shows, the gait trajectories of the different joints obtained in this work are almost consistent with the desired trajectory in both shape and value. This demonstrates the effectiveness of gait adjustment with the simulated lower limb exoskeleton, which is modeled as an MDP in this work. However, in sub-figure (c), the adjusted trajectory for hip adduction/abduction deviates from the desired trajectory in part of the gait cycle, which may result from randomness. As can be seen from Table 4, the metrics are no more than 0.32 SD, much lower than the figures in Table 3, which further demonstrates the effectiveness of gait adjustment with the simulated exoskeleton.

#### **8. Conclusions and Future Work**

To verify the effect of gait rehabilitation for patients with mobility disorders, one available approach is to adjust the gait in simulation without physical equipment, using the musculoskeletal model from the 2019 NeurIPS "Learning to Move–Walk Around" challenge. In this paper, we model the gait adjustment problem as an MDP. Based on DRL and the attention mechanism, we propose a framework named AT\_SD3 for gait adjustment. Taking advantage of the attention mechanism, the framework captures the self-dependent feature information for decision making in the sampled states generated by the agent's actions and uses it to reconstruct the initial observation with more interpretable information. Considering the high dimensionality of the RL state, an autoencoder is applied to alleviate 'the curse of dimensionality'. To evaluate the proposed framework, we apply it and other traditional DRL algorithms to gait adjustment; the comparison shows that the proposed framework outperforms the other algorithms. Moreover, we compare the unadjusted trajectory obtained in previous work with the adjusted trajectory obtained in this paper: the trajectories simulated with the proposed framework are closer to the desired trajectory in both shape and value and, in terms of MAE and RMSE, are more accurate than those obtained in previous work.

As for future work, how to extract the information in each sampled state that is critical to the selected action is still worth studying. Moreover, we will purchase a physical lower limb exoskeleton to verify the effectiveness of the proposed exoskeleton control framework; in controlling the physical exoskeleton, the adjustment of exoskeleton parameters and the RL modeling of exoskeleton control will also be worth studying.

**Author Contributions:** Conceptualization, A.L.; data curation, A.L.; formal analysis, A.L.; funding acquisition, J.C., Q.F. and Y.W.; investigation, H.W., Y.W. and Y.L.; methodology, A.L. and Q.F.; project administration, J.C., H.W., Y.W. and Y.L.; software, A.L.; supervision, J.C., Q.F. and Y.L.; validation, A.L.; writing—original draft, A.L.; writing—review and editing, Q.F. All authors have read and agreed to the published version of the manuscript.

**Funding:** This work was financially supported by National Key R&D Program of China (No. 2020YFC2006602), National Natural Science Foundation of China (No. 62172324, No. 62072324, No. 61876217, No. 61876121), University Natural Science Foundation of Jiangsu Province (No. 21KJA520005), Primary Research and Development Plan of Jiangsu Province (No. BE2020026), Natural Science Foundation of Jiangsu Province (No. BK20190942).

**Data Availability Statement:** The data presented in this study are openly available at https://doi.org/10.1016/j.jbiomech.2008.03.015 (accessed on 17 November 2022).

**Conflicts of Interest:** The authors declare no conflict of interest.

#### **Abbreviations**

The following abbreviations are used in this manuscript:


#### **References**


**Disclaimer/Publisher's Note:** The statements, opinions and data contained in all publications are solely those of the individual author(s) and contributor(s) and not of MDPI and/or the editor(s). MDPI and/or the editor(s) disclaim responsibility for any injury to people or property resulting from any ideas, methods, instructions or products referred to in the content.

### *Article* **Embedding Uncertain Temporal Knowledge Graphs**

**Tongxin Li, Weiping Wang, Xiaobo Li, Tao Wang \*, Xin Zhou and Meigen Huang**

School of Systems Engineering, National University of Defense Technology, Changsha 410000, China **\*** Correspondence: wangtao1976@nudt.edu.cn

**Abstract:** Knowledge graph (KG) embedding for predicting missing relation facts in incomplete knowledge graphs (KGs) has been widely explored. In addition to the benchmark triple structural information, such as head entities, tail entities, and the relations between them, KGs contain a large amount of uncertain and temporal information, which is difficult to exploit in KG embeddings. Some embedding models have been designed specifically for uncertain KGs or for temporal KGs. However, these models utilize either uncertain information or temporal information alone, without integrating both kinds of information into an underlying model that exploits triple structural information. In this paper, we propose an embedding model for uncertain temporal KGs called the confidence score, time, and ranking information embedded jointly model (CTRIEJ), which aims to preserve the uncertainty, temporal and structural information of relation facts in the embedding space. To further enhance the precision of the CTRIEJ model, we also introduce a self-adversarial negative sampling technique to generate negative samples. We use the embedding vectors obtained from our model to complete the missing relation facts and predict their corresponding confidence scores. Experiments are conducted on an uncertain temporal KG extracted from Wikidata via three tasks, i.e., confidence prediction, link prediction, and relation fact classification. The CTRIEJ model captures uncertain and temporal knowledge effectively, achieving promising results, and it consistently outperforms baselines on the three downstream experimental tasks.

**Keywords:** uncertain temporal knowledge graph; temporal knowledge graph; knowledge graph embedding; confidence score

**MSC:** 68T07; 68T30

### **1. Introduction**

KGs, which store various relation facts about the real world, are extensively applied in downstream tasks such as natural language processing [1], information retrieval [2], and knowledge question answering [3]. A relation fact (or triple) is composed of two entities (as nodes) and the relation that connects them (as the edge), and can be written as (*h*, *r*, *t*) or (*s*, *p*, *o*) [4]. Although KGs contain millions of such triples, they are known to suffer from incompleteness. This issue gives rise to the task of KG completion, which entails predicting the information missing from KGs. KG embedding, also known as knowledge representation learning, has become the mainstream method for KG completion; it builds distributed representations (or vector embeddings) of entities and relations [5].

Specifically, KG embedding represents a symbolic triple (*h*, *r*, *t*) as low-dimensional, dense, real-valued vectors (**h**, **r**, **t**) corresponding to the head entity, relation, and tail entity, respectively. Various embedding methods have emerged, mainly translation-distance-based and semantic-matching-based models. TransE [6] is the pioneering translation-distance-based model and is known for its effectiveness and simplicity. In TransE, for each relation fact the sum of the head entity vector **h** and the relation vector **r** should be close to the tail entity vector **t**, i.e., **h** + **r** ≈ **t**. TransE can effectively capture the structural and semantic information of the KG, but it cannot handle complex relations such as one-to-many, many-to-one, and many-to-many relations. To solve this problem, researchers have proposed many further models [7–9]. In addition, there are many embedding models based on semantic matching [10–12], which have achieved high accuracy in link prediction tasks.

**Citation:** Li, T.; Wang, W.; Li, X.; Wang, T.; Zhou, X.; Huang, M. Embedding Uncertain Temporal Knowledge Graphs. *Mathematics* **2023**, *11*, 775. https://doi.org/10.3390/math11030775

Academic Editor: Mikhail Goubko

Received: 30 December 2022 Revised: 31 January 2023 Accepted: 1 February 2023 Published: 3 February 2023

**Copyright:** © 2023 by the authors. Licensee MDPI, Basel, Switzerland. This article is an open access article distributed under the terms and conditions of the Creative Commons Attribution (CC BY) license (https://creativecommons.org/licenses/by/4.0/).

The above methods all reason over deterministic and static KGs without considering the uncertain and temporal information of triples, which leads to two key issues. The first is how to embed an uncertain KG. Uncertain KGs, such as ConceptNet [13] and NELL [14], associate each relation fact with a confidence score representing the likelihood that the fact is true. During the construction of a KG, many automated methods introduce noise and conflicts, resulting in a degree of uncertainty for each triple. Embedding such uncertain knowledge is critical for capturing the uncertain nature of reality and enables more precise reasoning. The second is how to learn the temporal dynamics of the relation facts in KGs. Most relation facts in KGs change over time; for example, the fact (Claudio Raineri, coach, Chelsea) is only true from 2000 to 2004, and ignoring such temporal information may lead to ambiguity and misunderstanding. The temporal information of relation facts also carries essential causal patterns that can assist link prediction. In summary, embedding the uncertain and temporal characteristics of relation facts can help KGs support better reasoning.

Regarding the uncertainty of triples, uncertain KG embedding (UKGE) [15] computes a score function based on the DistMult model and uses probabilistic soft logic to generate confidence scores for unseen relation facts, but it does not fully exploit the structural information in the KG. Structural and uncertain knowledge embedding (SUKE) [16] employs an evaluator and a confidence generator to embed confidence scores and structural information simultaneously, but the two components are not combined into a unified framework, so the entity and relation vectors they generate are not shared. Chen et al. [17] abandoned the use of probabilistic soft logic to generate extra training samples and instead leveraged a pool-based semi-supervised learning model, PASSLEAF, to generate confidence scores for unseen relation facts. This model partially solves the false-negative problem caused by random negative sampling, but it only considers knowledge confidence and ignores the rich information contained in the graph structure. Regarding temporal information, a significant number of temporal KG representation learning models have recently emerged. TTransE [18] and HyTE [19] learn distinct representations for each snapshot, and ATiSE [20] simplifies the evolution of a temporal KG into diachronic entity representations. More recently, most models apply neural networks to characterize the structural information and temporal evolution of KGs [21–23]. However, none of the aforementioned studies exploits both uncertain and temporal information. Chekol et al. [24] explored Markov logic networks and probabilistic soft logic for reasoning on uncertain temporal KGs without using embedding-based approaches, which results in high computational complexity and low efficiency.

To address these issues, we propose the confidence score, time, and ranking information embedded jointly model (CTRIEJ) for uncertain temporal KG embedding, which integrates uncertainty, temporal information, and structural information into a unified framework. CTRIEJ first uses a sequence model to incorporate temporal information into the relation embeddings and is then trained with the sum of two loss functions: a square loss that represents confidence prediction and a pairwise ranking loss that represents structural information. When evaluating the model on multiple downstream tasks, we employ a semantic-matching-based score function for confidence prediction and relation fact classification, and we design a score function based on translation distance and semantic matching to predict missing relation facts in the uncertain temporal KG. In addition, we adopt a self-adversarial negative sampling technique to train the model.

The main contributions of this paper can be summarized as follows:

• We leverage a GRU-based sequence model to incorporate temporal information into the embedding of the relation sequence, and we combine two score functions, based on semantic matching and translation distance, respectively, to characterize the confidence information and the structural information of the uncertain temporal KG in a unified framework.


The rest of the paper is organized as follows. Section 2 introduces the definition of uncertain temporal KGs and reviews related work. Sections 3 and 4 present the CTRIEJ model and the experiments, respectively. Finally, Section 5 concludes the paper.

#### **2. Related Work**

To the best of our knowledge, there is currently no embedding learning method for uncertain temporal KGs, so we review related work from three aspects: deterministic KG embedding models, temporal KG embedding models, and uncertain KG embedding models. For clarity, we first define the relevant concepts of the uncertain temporal KG.

#### *2.1. Problem Definition*

The relevant definitions of the uncertain temporal KG are given as follows.

**Definition 1.** *Temporal knowledge graph: A temporal KG can be denoted by $G = (E, R, Q)$, where $E$ and $R$ represent the sets of entities and relations, respectively, and $Q$ represents the set of temporal relation facts. Each relation fact $(h, r, t)$ in the graph has a valid time $[T_s, T_e]$, which denotes the closed interval from $T_s$ to $T_e$, with $T_s \le T_e$ and $T_s, T_e \in T$, i.e., $f = (h, r, t, [T_s, T_e])$. We refer to $f$ as a temporal fact.*

For a temporal KG *G*, its snapshot at time *T* is the graph (the nontemporal KG): *G*(*T*) = {(*h*,*r*, *t*)|(*h*,*r*, *t*, [*T*, *T*]) ∈ *G*}.

**Definition 2.** *Uncertain temporal knowledge graph: An uncertain temporal KG consists of temporal relation facts with confidence scores that typically model their inherent uncertainty. We can represent such a fact as $u = \langle f, s_f \rangle$, where $f = (h, r, t, [T_s, T_e])$ is a temporal relation fact and $s_f \in [0, 1]$ is a real-valued weight assigned to $f$.*
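To make Definitions 1 and 2 concrete, the following minimal Python sketch stores each uncertain temporal fact as $\langle (h, r, t, [T_s, T_e]), s_f \rangle$ and checks the two constraints $T_s \le T_e$ and $s_f \in [0, 1]$. The class name `UncertainTemporalFact`, the use of integer years as timestamps, the entities "A" and "B", and the confidence values are illustrative assumptions, not part of the paper. The helper `snapshot` follows the snapshot definition given above literally, so it only returns facts whose validity interval is exactly $[T, T]$.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class UncertainTemporalFact:
    """u = <f, s_f> with f = (h, r, t, [T_s, T_e]) (Definitions 1 and 2)."""
    h: str
    r: str
    t: str
    ts: int   # start of the closed validity interval (a year, for simplicity)
    te: int   # end of the closed validity interval
    s: float  # confidence score s_f

    def __post_init__(self):
        if self.ts > self.te:
            raise ValueError("Definition 1 requires T_s <= T_e")
        if not 0.0 <= self.s <= 1.0:
            raise ValueError("Definition 2 requires s_f in [0, 1]")


def snapshot(graph, T):
    """G(T) = {(h, r, t) | (h, r, t, [T, T]) in G}, taken literally from the text."""
    return {(u.h, u.r, u.t) for u in graph if u.ts == T and u.te == T}


graph = [
    # The interval comes from the text's example; the confidence 0.9 is made up.
    UncertainTemporalFact("Claudio Raineri", "coach", "Chelsea", 2000, 2004, 0.9),
    # A hypothetical single-year fact.
    UncertainTemporalFact("A", "member of", "B", 2003, 2003, 0.4),
]
print(snapshot(graph, 2003))  # {('A', 'member of', 'B')}
```

Because `snapshot` matches intervals exactly, the multi-year coaching fact does not appear in any single-year snapshot under this definition.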

**Example 1.** *Uncertain temporal knowledge graph: the following uncertain temporal KG represents the career of the sports personality Claudio Raineri [24]:*


**Definition 3.** *Uncertain temporal knowledge graph embedding: Given an uncertain temporal KG, the embedding can be expressed as a mapping function $f: h \mapsto \mathbf{h} \in \mathbb{R}^{d_E},\ r \mapsto \mathbf{r} \in \mathbb{R}^{d_R},\ t \mapsto \mathbf{t} \in \mathbb{R}^{d_E},\ T_{token} \mapsto \mathbf{T}_{token} \in \mathbb{R}^{d_T}$, where $\mathbf{h}$, $\mathbf{r}$, and $\mathbf{t}$ are the vector representations of the head entity, relation, and tail entity, respectively; $\mathbf{T}_{token}$ is the vector representation of the temporal token, which is described in detail in Section 3.2; and $d_E$, $d_R$, and $d_T$ denote the dimensions of the entity, relation, and temporal token vectors, respectively. In this model, we set $d_E = d_R = d_T = d$.*

#### *2.2. Deterministic Knowledge Graph Embeddings*

A deterministic KG contains a set of triples (*h*, *r*, *t*), where *h*, *t* ∈ *E* and *r* ∈ *R*. It can be regarded as an uncertain KG whose triples all have a confidence score of one. At present, deterministic KG embedding models can be mainly divided into three categories: tensor-decomposition-based models, translation-distance-based models, and neural-network-based models.

Structured embedding (SE) [25] is one of the earliest knowledge representation methods. For a relation fact, SE projects the head and tail entity vectors into a relation-specific space through two relation matrices and then calculates the distance between the two projected vectors in this space. This distance reflects the semantic relevance of the two entities under the relation: the smaller the distance, the more likely the triple holds. The semantic matching energy model (SME) [26] defines several projection matrices and uses bilinear functions to describe the internal relationship between entities and relations. Bilinear functions are also used in the latent factor model (LFM) [27], which employs a relation-based bilinear transformation to characterize the second-order relationship between entities and relations. The DistMult model [11] explores a simplified form of the LFM that restricts the relation matrix to be diagonal. Building on the LFM, the neural tensor network (NTN) [28] model further employs a bilinear transformation of the relation to characterize the relationship between entities and relations. In addition, some researchers have applied matrix factorization to knowledge representation learning, with RESCAL [10] as the representative method. The basic idea of RESCAL is similar to that of the LFM; the difference is that RESCAL optimizes all positions in the tensor, including those with a value of zero, whereas the LFM only optimizes the triples that exist in the KG.
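The relationship between the full bilinear score used by LFM/RESCAL-style models, $\mathbf{h}^\top \mathbf{M}_r \mathbf{t}$, and its diagonal DistMult simplification can be checked numerically. The sketch below uses hypothetical helper names and random toy vectors; it confirms that setting $\mathbf{M}_r = \mathrm{diag}(\mathbf{r})$ reproduces the DistMult score, and that the diagonal form scores $(h, r, t)$ and $(t, r, h)$ identically, which is why DistMult cannot distinguish a triple from its reverse.

```python
import random


def bilinear_score(h, M_r, t):
    """Full bilinear score h^T M_r t with a d x d relation matrix."""
    d = len(h)
    return sum(h[i] * M_r[i][j] * t[j] for i in range(d) for j in range(d))


def distmult_score(h, r, t):
    """DistMult: M_r is restricted to diag(r), so the score is sum_k h_k r_k t_k."""
    return sum(hk * rk * tk for hk, rk, tk in zip(h, r, t))


rng = random.Random(0)  # toy random embeddings, not trained values
d = 4
h = [rng.uniform(-1, 1) for _ in range(d)]
r = [rng.uniform(-1, 1) for _ in range(d)]
t = [rng.uniform(-1, 1) for _ in range(d)]
diag_r = [[r[i] if i == j else 0.0 for j in range(d)] for i in range(d)]

print(bilinear_score(h, diag_r, t), distmult_score(h, r, t))  # equal up to rounding
print(distmult_score(h, r, t), distmult_score(t, r, h))       # symmetric in h and t
```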

Inspired by the translation invariance of semantic and syntactic relationships in word vector spaces, Bordes et al. proposed the TransE model [6], which treats a relation in the KG as a translation vector between the head and tail entities. Compared with previous models, TransE has fewer parameters and lower computational complexity, and it can directly establish complex semantic connections between entities and relations. Bordes et al. conducted evaluation tasks such as link prediction on the WordNet and Freebase data sets, and the experimental results showed that TransE significantly improved performance, especially on large-scale sparse KGs. However, TransE has difficulty handling one-to-many, many-to-one, and many-to-many relations. To overcome these shortcomings, TransH [7] introduces a relation-specific hyperplane so that an entity can have different vector representations in triples with different relations; in this way, TransH distinguishes the different roles of the same entity in different triples. TransR [8] places entities and relations in different representation spaces, possibly of different dimensions, and maps entities into the relation space through a relation-specific projection. There are many other variants of TransE, including TransM [29], TransF [30], and TransA [9], most of which aim to further address the defects of TransE and improve the expressive ability of the model. Besides translations, rotations in the representation space have also been explored: RotatE [31] represents each relation in the KG as a rotation in complex space based on Euler's formula. With this design, RotatE can simultaneously model symmetric, antisymmetric, reciprocal (inverse), and compositional relations, which previous models could not.
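The following sketch illustrates two of the claims above with toy vectors: a TransH hyperplane projection lets one head entity fit two different tails under the same relation (a one-to-many case TransE cannot fit), and a RotatE relation whose phases all equal $\pi$ is symmetric. The function names and all vectors are hypothetical; only the Euclidean ($l_2$) distance is shown.

```python
import cmath
import math


def transe_distance(h, r, t):
    """TransE: ||h + r - t||_2; a smaller distance means a more plausible triple."""
    return math.sqrt(sum((hi + ri - ti) ** 2 for hi, ri, ti in zip(h, r, t)))


def transh_distance(h, w_r, d_r, t):
    """TransH: project h and t onto the hyperplane with unit normal w_r, then translate by d_r."""
    def project(v):
        dot = sum(vi * wi for vi, wi in zip(v, w_r))
        return [vi - dot * wi for vi, wi in zip(v, w_r)]
    return transe_distance(project(h), d_r, project(t))


def rotate_distance(h, phases, t):
    """RotatE: r_k = exp(i * theta_k) has unit modulus; distance ||h o r - t|| in complex space."""
    r = [cmath.exp(1j * theta) for theta in phases]
    return math.sqrt(sum(abs(hk * rk - tk) ** 2 for hk, rk, tk in zip(h, r, t)))


# One-to-many: two tails that differ only along w_r both fit (h, r) under TransH ...
w_r, d_r = [0.0, 0.0, 1.0], [0.0, 1.0, 0.0]
h, t1, t2 = [1.0, 0.0, 0.0], [1.0, 1.0, 5.0], [1.0, 1.0, -3.0]
print(transh_distance(h, w_r, d_r, t1), transh_distance(h, w_r, d_r, t2))  # 0.0 0.0
# ... whereas TransE with r = d_r cannot fit both tails at once.
print(transe_distance(h, d_r, t1), transe_distance(h, d_r, t2))  # 5.0 3.0

# Symmetry: with every phase equal to pi, r o r = 1, so t = h o r implies h = t o r.
h_c = [1 + 2j, -0.5j]
phases = [math.pi, math.pi]
t_c = [hk * cmath.exp(1j * math.pi) for hk in h_c]
print(rotate_distance(h_c, phases, t_c) < 1e-9, rotate_distance(t_c, phases, h_c) < 1e-9)  # True True
```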

Depending on the type of neural network they use, neural-network-based KG embedding models can generally be divided into five categories: linear/bilinear neural networks, convolutional neural networks (CNNs) [32–34], recurrent neural networks (RNNs) [35–37], graph neural networks (GNNs) [38–41], and generative adversarial networks (GANs) [42].

#### *2.3. Temporal Knowledge Graph Embeddings*

Current KG embedding research focuses on static KGs, in which relation facts do not change over time, such as the TransE, TransH, and RESCAL models mentioned above. However, KGs in practical applications are usually dynamic: facts evolve over time and are only valid for specific periods. Static KG embedding models ignore temporal information entirely, which limits their use in such practical scenarios. Therefore, a significant number of temporal KG embedding models have emerged.

Know-Evolve [21] updates the embeddings of entities subject to temporal changes by building an RNN on top of a static KG representation. TTransE [18] uses time information to constrain triples and models time-predicate sequences for inference. TA-TransE and TA-DistMult [22] use temporal information to constrain relation representations, constructing a temporal relation representation for each knowledge instance with a digit-level long short-term memory (LSTM) model. ATiSE [20] mines the impact of time on the evolution of entities, capturing not only the influence of past time but also that of future time through the trend, cyclical, and random components of time series. RE-NET [23] converts time into a sequence of events with temporal information, builds an RNN-based encoding of the entities in the sequence to capture the influence of their history, and finally uses a relation-aware GCN to aggregate information about entities at the same time. Chang2vec [43] splits a temporal KG into multiple static KGs, one per snapshot, and employs metapath encoding for each KG to recompute the representations of nodes that have changed and update their embeddings. CyGNet [44] exploits the historical information in KGs through a dedicated copy (replication) module, while a generation module predicts knowledge that appears for the first time. xERTE [45] combines low-dimensional static vectors and temporal functions to represent entities, capturing both the long-term properties of entities that do not change over time and the characteristics that do; it can also visualize its reasoning paths interpretably. RE-GCN [46] learns evolutional representations of entities and relations at each timestamp by modeling the KG sequence recurrently, and it incorporates static properties of entities (such as entity types) via a static graph constraint component to obtain better entity representations.

Most of the above approaches make use of the temporal and structural information in KGs, but they all assume that triples are deterministic, and none of them considers the confidence score of each relation fact.

#### *2.4. Uncertain Knowledge Graph Embeddings*

Some open KGs with uncertain information, such as NELL, ConceptNet, etc., add a confidence score to each triple to describe the uncertainty of this relation fact. Different KGs have different strategies for calculating confidence scores. The confidence level is obtained through the frequency of crowdsourcing annotations in ConceptNet [13], while NELL calculates the confidence value with probabilistic semantics by the EM algorithm [14].

Compared with a deterministic KG, an uncertain KG carries additional triple confidence information. Recently, research has addressed the representation and inference of uncertain KGs from different perspectives. GTransE [47] aims to make the representation model more robust when learning from noisy data. Specifically, it uses the confidence scores of triples to dynamically adjust the margins in the pairwise ranking loss, so that higher-confidence triples have larger margins between positive and negative examples, making the model focus more on learning higher-confidence triples.

UKGE [15] was the first to propose the task of learning representations of uncertain KGs, embedding structural information and confidence information at the same time. Specifically, it computes a mean squared error (MSE) loss that fits the confidence scores of triples using the energy function of DistMult. In this way, the confidence information is embedded into the distances between entities and relations, and the energy function of a triple can be used to predict its confidence score. In addition, UKGE introduces logic rules as prior knowledge, uses probabilistic soft logic (PSL) to infer unseen facts, and adds them to the training data, thereby preserving the constraints of the rules in the embedding representation.

SUKE [16] also uses the DistMult model as its energy function and explores different logistic functions to transform the energy score into a structural information function and a confidence prediction function. The model consists of two parts: an evaluator and a confidence generator. For unseen triples, the evaluator learns structural and uncertain information to evaluate their plausibility and produces a candidate set. The confidence generator then predicts the corresponding confidence scores by learning the uncertain information of the triples in the candidate set. However, the entity and relation embeddings generated by the two components are independent of each other, which doubles the required storage and computation.

PASSLEAF [17] argued that setting the confidence scores of all unobserved triples to zero causes a false-negative problem: in an uncertain KG, in addition to the visible triples with confidence scores, there are many more unseen triples that may also have a variety of confidence scores. The model leverages semi-supervised learning and a sample pool to generate training samples that account for the confidence scores of unseen triples. Moreover, multiple types of score functions were compared in its experiments.

#### **3. Confidence, Time, and Ranking Information Embedded Jointly**

#### *3.1. The Framework Overview*

In this section, we propose the CTRIEJ model, which can simultaneously infer missing relation facts and predict their confidence scores. The overall framework of the model is shown in Figure 1. It consists of three main components: a time-aware embedding model that incorporates time embedding into the relation embedding, a confidence prediction model that characterizes the uncertain information, and a pairwise ranking loss model that represents the structural information. In Section 3.2, a gated recurrent unit (GRU) processes the sequence of the relation and time to obtain a relation embedding that incorporates time. In Section 3.3, we describe in detail two score functions, based on semantic matching and translation distance, which characterize the uncertain information and the structural information in the uncertain temporal KG, respectively. Finally, we combine the loss functions of the two components into a joint embedding model and adopt a self-adversarial negative sampling technique that weights negative triples according to the current embedding vectors. The details are given in Section 3.4.

**Figure 1.** The overall framework of the CTRIEJ model.

#### *3.2. GRU for Time-Aware Embedding Sequences*

Whereas TA-TransE and TA-DistMult [22] encode digit-level temporal tokens with an LSTM, we encode sequences of temporal tokens with a GRU, a recurrent neural network architecture particularly suited to modeling sequential data. Given an uncertain temporal KG in which triples are augmented with temporal information, we decompose each timestamp into a sequence of the temporal tokens shown in Figure 2.

As shown in Figure 2, the digits of the month and the day are represented by the tokens 0 to 9. In addition to these digits, the year has an extra token "-", which is used at the beginning to indicate BC. The year usually consists of 4 digits, the month of 2 digits (January to December), and the day of 2 digits (a day within the month). Because every token also carries a year, month, or day suffix, the vocabulary contains 11 year tokens, 10 month tokens, and 10 day tokens, i.e., 31 temporal tokens in total. A complete timestamp contains a start time $T_s$ and an end time $T_e$, which we concatenate into a single sequence of temporal tokens. Moreover, for each triple, we refer to the concatenation of the relation and its sequence of temporal tokens as the relation sequence

$$r_{\mathrm{seq}} = \left( r, T_s^{1y}, T_s^{2y}, T_s^{3y}, T_s^{4y}, T_s^{1m}, T_s^{2m}, T_s^{1d}, T_s^{2d}, T_e^{1y}, T_e^{2y}, T_e^{3y}, T_e^{4y}, T_e^{1m}, T_e^{2m}, T_e^{1d}, T_e^{2d} \right)$$

of length 17, where the suffixes $y$, $m$, and $d$ indicate whether the digit belongs to the year, month, or day. An uncertain temporal KG can now be represented as a set of quadruples of the form $\langle h, r_{\mathrm{seq}}, t, s \rangle$, where the relation sequence $r_{\mathrm{seq}}$ includes the temporal information. These relation token sequences are used as the input to a GRU, which is defined by the following equations:

$$\begin{aligned} \Gamma_u &= \sigma\left(\mathbf{W}_u [\mathbf{c}_{n-1}, \mathbf{x}_n] + \mathbf{b}_u\right) \\ \Gamma_r &= \sigma\left(\mathbf{W}_r [\mathbf{c}_{n-1}, \mathbf{x}_n] + \mathbf{b}_r\right) \\ \mathbf{c}_n &= \Gamma_u \ast \tanh\left(\mathbf{W}_c [\Gamma_r \ast \mathbf{c}_{n-1}, \mathbf{x}_n] + \mathbf{b}_c\right) + (1 - \Gamma_u) \ast \mathbf{c}_{n-1} \end{aligned} \tag{1}$$

where $n = 1, 2, \dots, 17$; $\Gamma_u$ and $\Gamma_r$ are the update and reset gates, respectively; $\mathbf{c}_n$ is the hidden state at step $n$; $\sigma(\cdot)$ is the sigmoid function; $\ast$ denotes the elementwise product; and $\mathbf{x}_n \in \mathbb{R}^d$ is the embedding of the $n$th element of the relation token sequence $r_{\mathrm{seq}}$.

Each token of the input sequence $r_{\mathrm{seq}}$ is first mapped to a randomly initialized $d$-dimensional embedding, and the resulting embedding sequence is fed to the GRU. The relation sequence embedding is the last hidden state of the GRU, that is, $\mathbf{r}_{\mathrm{seq}} = \mathbf{c}_{17}$. With this time-aware relation sequence embedding, the next section combines it with the head and tail entity embeddings in two loss functions.
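The tokenization and Eq. (1) can be sketched in plain Python as follows. This is an untrained illustration under several assumptions not stated in the text: timestamps are given as (year, month, day) tuples, a BC year is written with a leading "-" inside the four year slots, the initial hidden state $\mathbf{c}_0$ is zero, biases start at zero, and token embeddings are created lazily. The names `date_tokens`, `relation_sequence`, and `TinyGRU` are hypothetical, and the month and day of the coaching fact are made up because the text only gives the years 2000 and 2004.

```python
import math
import random

# 11 year tokens (0-9 and "-"), 10 month tokens, 10 day tokens: 31 in total.
TEMPORAL_VOCAB = ([c + "y" for c in "0123456789-"]
                  + [c + "m" for c in "0123456789"]
                  + [c + "d" for c in "0123456789"])


def date_tokens(year, month, day):
    """Split one timestamp into 8 suffixed tokens: 4 year, 2 month, 2 day."""
    y = f"{year:04d}"  # 2000 -> "2000"; a BC year such as -753 -> "-753"
    if len(y) != 4 or not 1 <= month <= 12 or not 1 <= day <= 31:
        raise ValueError("timestamp does not fit the 4 + 2 + 2 token layout")
    return ([c + "y" for c in y]
            + [c + "m" for c in f"{month:02d}"]
            + [c + "d" for c in f"{day:02d}"])


def relation_sequence(relation, start, end):
    """r_seq = (r, 8 tokens of T_s, 8 tokens of T_e): 17 tokens in total."""
    return [relation] + date_tokens(*start) + date_tokens(*end)


class TinyGRU:
    """Eq. (1) in plain Python with random, untrained weights (illustration only)."""

    def __init__(self, d, seed=0):
        self.d = d
        self.rng = random.Random(seed)
        self.W_u, self.W_r, self.W_c = (self._matrix(d, 2 * d) for _ in range(3))
        self.b_u, self.b_r, self.b_c = [0.0] * d, [0.0] * d, [0.0] * d
        self.emb = {}  # token -> randomly initialized d-dimensional embedding

    def _matrix(self, rows, cols):
        return [[self.rng.uniform(-0.1, 0.1) for _ in range(cols)] for _ in range(rows)]

    def embed(self, token):
        if token not in self.emb:  # create x_n the first time a token is seen
            self.emb[token] = [self.rng.uniform(-0.1, 0.1) for _ in range(self.d)]
        return self.emb[token]

    @staticmethod
    def _affine(W, v, b):
        return [sum(w * x for w, x in zip(row, v)) + bi for row, bi in zip(W, b)]

    def encode(self, tokens):
        c = [0.0] * self.d  # assumed initial hidden state c_0 = 0
        for tok in tokens:
            x = self.embed(tok)
            g_u = [1 / (1 + math.exp(-z)) for z in self._affine(self.W_u, c + x, self.b_u)]
            g_r = [1 / (1 + math.exp(-z)) for z in self._affine(self.W_r, c + x, self.b_r)]
            gated = [g * ci for g, ci in zip(g_r, c)] + x  # [Gamma_r * c_{n-1}, x_n]
            cand = [math.tanh(z) for z in self._affine(self.W_c, gated, self.b_c)]
            c = [u * ct + (1 - u) * ci for u, ct, ci in zip(g_u, cand, c)]
        return c  # for a 17-token input this is c_17, i.e. r_seq


seq = relation_sequence("coach", (2000, 7, 1), (2004, 5, 31))  # hypothetical dates
r_seq = TinyGRU(d=8).encode(seq)
print(len(seq), len(r_seq))  # 17 8
```

Since each $\mathbf{c}_n$ is a convex combination of a $\tanh$ output and $\mathbf{c}_{n-1}$, every coordinate of the resulting embedding stays strictly inside $(-1, 1)$.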

**Figure 2.** The temporal tokens.

#### *3.3. Incorporating Uncertain Information and Structural Information*

We leverage two score functions, based on semantic matching and translation distance, namely $S(h, r_{\mathrm{seq}}, t)_{\mathrm{unce}}$ and $S(h, r_{\mathrm{seq}}, t)_{\mathrm{rank}}$, and the corresponding loss function consists of two terms, $L_{\mathrm{unce}}$ and $L_{\mathrm{rank}}$, where $L_{\mathrm{unce}}$ characterizes confidence prediction and $L_{\mathrm{rank}}$ models the graph structure information. The first score function, $S(h, r_{\mathrm{seq}}, t)_{\mathrm{unce}}$, is used to predict the confidence scores of triples, and the second, $S(h, r_{\mathrm{seq}}, t)_{\mathrm{rank}}$, is mainly designed to complete missing relation facts. The MSE loss function in UKGE uses the semantic-matching-based DistMult model as its energy function and performs satisfactorily; therefore, CTRIEJ preserves the uncertainty information through the MSE loss function. Specifically, we first compute the DistMult-based energy function:

$$f = \mathbf{r}_{\mathrm{seq}} \cdot (\mathbf{h} \circ \mathbf{t}) \tag{2}$$

where $\mathbf{h}$ and $\mathbf{t}$ represent the head and tail entity embeddings of the triple, $\mathbf{r}_{\mathrm{seq}}$ denotes the relation sequence embedding obtained by the GRU in the previous step, $\circ$ is the elementwise product, and $\cdot$ is the inner product. Then, as in UKGE [15], we use two different conversion functions to map energy scores to confidence scores in the range [0, 1]:

$$S(h, r_{\mathrm{seq}}, t)^{\mathrm{logi}}_{\mathrm{unce}} = \frac{1}{1 + e^{-(wf + b)}} \tag{3}$$

$$S(h, r_{\mathrm{seq}}, t)^{\mathrm{rect}}_{\mathrm{unce}} = \min\left(\max(wf + b, 0), 1\right) \tag{4}$$

where $w$ is a weight, $b$ is a bias, and $r_{\mathrm{seq}}$ is the time-aware sequence of relation tokens described in the previous section. $S(h, r_{\mathrm{seq}}, t)^{\mathrm{logi}}_{\mathrm{unce}}$ denotes the confidence score function obtained with the logistic function, and $S(h, r_{\mathrm{seq}}, t)^{\mathrm{rect}}_{\mathrm{unce}}$ denotes the confidence score function obtained with the bounded rectifier.

The MSE loss function over positive samples $D_{\mathrm{pos}}$ and negative samples $D_{\mathrm{neg}}$ is as follows:

$$L_{\mathrm{unce}} = \left| S(h, r_{\mathrm{seq}}, t)_{\mathrm{unce}} - s \right|^2 + \left| S(h', r'_{\mathrm{seq}}, t')_{\mathrm{unce}} \right|^2 \tag{5}$$

where $(h, r_{\mathrm{seq}}, t) \in D_{\mathrm{pos}}$ is an observed fact in the data set, $s$ is its confidence score, $(h', r'_{\mathrm{seq}}, t') \in D_{\mathrm{neg}}$ is a corresponding negative sample obtained through random negative sampling, and $S(\cdot)_{\mathrm{unce}}$ can be either $S(\cdot)^{\mathrm{logi}}_{\mathrm{unce}}$ or $S(\cdot)^{\mathrm{rect}}_{\mathrm{unce}}$.
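Equations (2)–(5) can be traced on toy numbers with the short Python sketch below. The 2-dimensional embeddings, the corrupted head, and the values of $w$ and $b$ are hypothetical (in the model, $w$ and $b$ are learned); the function names are illustrative.

```python
import math


def distmult_energy(h, r_seq, t):
    """Eq. (2): f = r_seq . (h o t), i.e. sum_k r_k * h_k * t_k."""
    return sum(rk * hk * tk for rk, hk, tk in zip(r_seq, h, t))


def confidence_logi(f, w, b):
    """Eq. (3): logistic mapping of the energy into (0, 1)."""
    return 1.0 / (1.0 + math.exp(-(w * f + b)))


def confidence_rect(f, w, b):
    """Eq. (4): bounded rectifier min(max(w*f + b, 0), 1)."""
    return min(max(w * f + b, 0.0), 1.0)


def mse_unce_loss(pos_conf, s, neg_conf):
    """Eq. (5): fit the observed score s and push the negative sample toward 0."""
    return (pos_conf - s) ** 2 + neg_conf ** 2


h, r_seq, t = [1.0, 0.5], [0.5, 0.25], [1.0, 2.0]  # toy embeddings
h_neg = [0.0, 0.5]                                  # a corrupted head entity
w, b = 2.0, -1.0                                    # toy conversion parameters
f_pos = distmult_energy(h, r_seq, t)      # 0.5 + 0.25 = 0.75
f_neg = distmult_energy(h_neg, r_seq, t)  # 0.25
loss = mse_unce_loss(confidence_rect(f_pos, w, b), 0.5, confidence_rect(f_neg, w, b))
print(f_pos, f_neg, loss)  # 0.75 0.25 0.0
```

Here the bounded rectifier maps the positive energy to exactly the target score 0.5 and clips the negative one to 0, so the loss vanishes; the logistic variant would give about 0.62 for the same positive triple.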

Next, we compute the structural loss of the KG with a score function based on TransE:

$$S(h, r_{\mathrm{seq}}, t)_{\mathrm{rank}} = -d(h, r_{\mathrm{seq}}, t) = -\left\| \mathbf{h} + \mathbf{r}_{\mathrm{seq}} - \mathbf{t} \right\|_{l_1/l_2} \tag{6}$$

where $\|\cdot\|_{l_1/l_2}$ denotes the $l_1$ or $l_2$ norm. The smaller the value of the distance function $d(h, r_{\mathrm{seq}}, t)$, the more likely the triple holds.

Following TransE, we adopt a margin-based pairwise ranking loss. Since triples differ in their confidence levels, we use the confidence score of each sample as the weight of its ranking loss, which gives the following loss function:

$$L_{\mathrm{rank}} = s \cdot \max\left(\gamma - S(h, r_{\mathrm{seq}}, t)_{\mathrm{rank}} + S(h', r'_{\mathrm{seq}}, t')_{\mathrm{rank}},\ 0\right) \tag{7}$$

where *γ* > 0 is a margin hyperparameter. This weighting focuses learning on triples with higher confidence scores and reduces the contribution of triples with lower confidence scores. To validate the generality of the proposed framework, relatively simple score functions are used in both components above; higher-performance score functions based on semantic matching and translation distance can be integrated into the framework in future work.
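A minimal sketch of Eqs. (6) and (7) on toy 2-dimensional vectors follows; the helper names and all numbers are hypothetical. It shows that a negative sample already farther than the margin contributes nothing, while a closer negative is penalized in proportion to the positive triple's confidence $s$.

```python
import math


def rank_score(h, r_seq, t, norm="l2"):
    """Eq. (6): S_rank = -||h + r_seq - t|| under the l1 or l2 norm."""
    diff = [hk + rk - tk for hk, rk, tk in zip(h, r_seq, t)]
    if norm == "l1":
        return -sum(abs(x) for x in diff)
    return -math.sqrt(sum(x * x for x in diff))


def weighted_margin_loss(s, pos_score, neg_score, gamma):
    """Eq. (7): s * max(gamma - S_pos + S_neg, 0), weighted by the confidence s."""
    return s * max(gamma - pos_score + neg_score, 0.0)


h, r_seq, t = [0.0, 0.0], [1.0, 0.0], [1.0, 0.0]  # a perfectly translated positive
far_tail, near_tail = [0.0, 3.0], [1.0, 1.0]      # two corrupted tails
pos = rank_score(h, r_seq, t)                     # -0.0, i.e. zero distance
far = rank_score(h, r_seq, far_tail)              # -sqrt(10)
near = rank_score(h, r_seq, near_tail)            # -1.0
print(weighted_margin_loss(0.5, pos, far, gamma=2.0))   # 0.0: outside the margin
print(weighted_margin_loss(0.5, pos, near, gamma=2.0))  # 0.5 = 0.5 * (2 - 0 - 1)
```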

#### *3.4. Joint Loss Function*

Negative sampling has been shown to be quite effective for learning KG embeddings. However, the commonly used uniform negative sampling produces low-quality negative samples that contribute little to training. Using a GAN to generate negative samples can make negative sampling more effective, but it also increases the complexity of the model. To improve the quality of negative sampling without introducing additional model parameters, we adopt the self-adversarial negative sampling technique proposed in the RotatE model [31], which computes the scores of negative samples from the current entity and relation embeddings. The higher a negative sample's score, the higher its weight, so that high-quality negative samples contribute more to training.

When computing the MSE loss function, we first use uniform negative sampling to randomly generate $n$ negative samples for an observed triple $(h, r_{\mathrm{seq}}, t)$, and then we assign different weights to the negative samples based on their scores under the current entity and relation embeddings:

$$w_{\mathrm{unce}}\big((h'_i, r'_{\mathrm{seq},i}, t'_i) \mid (h, r_{\mathrm{seq}}, t)\big) = \frac{\exp S(h'_i, r'_{\mathrm{seq},i}, t'_i)_{\mathrm{unce}}}{\sum_{j=1}^{n} \exp S(h'_j, r'_{\mathrm{seq},j}, t'_j)_{\mathrm{unce}}} \tag{8}$$

where $i = 1, 2, \dots, n$, and $w_{\mathrm{unce}}\big((h'_i, r'_{\mathrm{seq},i}, t'_i) \mid (h, r_{\mathrm{seq}}, t)\big)$ is the weight of the $i$th negative sample when computing the MSE loss of the triple $(h, r_{\mathrm{seq}}, t)$. In this way, we obtain the MSE loss function with the self-adversarial negative sampling technique:

$$L_{\mathrm{unce}} = \left| S(h, r_{\mathrm{seq}}, t)_{\mathrm{unce}} - s \right|^2 + \sum_{i=1}^{n} w_{\mathrm{unce}}\big((h'_i, r'_{\mathrm{seq},i}, t'_i) \mid (h, r_{\mathrm{seq}}, t)\big) \cdot \left| S(h'_i, r'_{\mathrm{seq},i}, t'_i)_{\mathrm{unce}} \right|^2 \tag{9}$$
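Equations (8) and (9) amount to a softmax over the negatives' current confidence scores followed by a weighted squared penalty. The sketch below works directly on already-computed scores; the function names and the three negative scores are hypothetical, and shifting by the maximum before exponentiating is a standard numerical-stability step that leaves the softmax unchanged.

```python
import math


def self_adversarial_weights(neg_scores):
    """Eq. (8): softmax over the current scores of the n negative samples."""
    m = max(neg_scores)  # subtracting the max avoids overflow; weights are unchanged
    exps = [math.exp(x - m) for x in neg_scores]
    total = sum(exps)
    return [e / total for e in exps]


def unce_loss_self_adversarial(pos_conf, s, neg_confs):
    """Eq. (9): |S_pos - s|^2 + sum_i w_i * |S_neg_i|^2."""
    weights = self_adversarial_weights(neg_confs)
    return (pos_conf - s) ** 2 + sum(w * c ** 2 for w, c in zip(weights, neg_confs))


neg_confs = [0.1, 0.6, 0.3]  # hypothetical predicted confidences of three negatives
weights = self_adversarial_weights(neg_confs)
loss = unce_loss_self_adversarial(0.8, 0.9, neg_confs)
print([round(w, 3) for w in weights], round(loss, 4))
```

The negative with the highest predicted confidence (0.6) receives the largest weight, so the loss is larger than it would be with uniform weights, which is the intended emphasis on hard negatives.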

Similarly, when computing the pairwise ranking loss function, we also employ this technique to assign different weights to negative samples and obtain the final ranking loss function.

$$w_{\mathrm{rank}}\big((h'_i, r'_{\mathrm{seq},i}, t'_i) \mid (h, r_{\mathrm{seq}}, t)\big) = \frac{\exp S(h'_i, r'_{\mathrm{seq},i}, t'_i)_{\mathrm{rank}}}{\sum_{j=1}^{n} \exp S(h'_j, r'_{\mathrm{seq},j}, t'_j)_{\mathrm{rank}}} = \frac{\exp\left(-d(h'_i, r'_{\mathrm{seq},i}, t'_i)\right)}{\sum_{j=1}^{n} \exp\left(-d(h'_j, r'_{\mathrm{seq},j}, t'_j)\right)} \tag{10}$$

$$L_{\mathrm{rank}} = s \cdot \max\left(\gamma - S(h, r_{\mathrm{seq}}, t)_{\mathrm{rank}} + \sum_{i=1}^{n} w_{\mathrm{rank}}\big((h'_i, r'_{\mathrm{seq},i}, t'_i) \mid (h, r_{\mathrm{seq}}, t)\big) \cdot S(h'_i, r'_{\mathrm{seq},i}, t'_i)_{\mathrm{rank}},\ 0\right) \tag{11}$$

Combining Equations (9) and (11), we obtain the final joint loss function with self-adversarial negative sampling:

$$L_{\mathrm{joint}} = L_{\mathrm{unce}} + L_{\mathrm{rank}}, \quad \text{s.t.}\ \|\mathbf{h}\|_2 \le 1,\ \|\mathbf{r}\|_2 \le 1,\ \|\mathbf{t}\|_2 \le 1 \tag{12}$$

We use two different conversion functions for the score $S(h, r_{\mathrm{seq}}, t)_{\mathrm{unce}}$, and we refer to the variant using Equation (3) as CTRIEJ*logi* and the variant using Equation (4) as CTRIEJ*rect*.
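The ranking side of the joint objective, Eqs. (10)–(12), can be sketched the same way on precomputed scores. The function names and numbers are hypothetical, and the paper only states the norm constraints in Eq. (12) without saying how they are enforced; rescaling a vector back onto the unit ball after each update, as `project_to_unit_ball` does here, is one common choice and is an assumption.

```python
import math


def softmax(xs):
    """Numerically stable softmax, used for the weights in Eq. (10)."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def rank_loss_self_adversarial(s, pos_score, neg_scores, gamma):
    """Eq. (11): s * max(gamma - S_pos + sum_i w_i * S_neg_i, 0), w_i from Eq. (10)."""
    weights = softmax(neg_scores)  # S_rank = -d, so closer negatives get larger weights
    weighted_neg = sum(w * x for w, x in zip(weights, neg_scores))
    return s * max(gamma - pos_score + weighted_neg, 0.0)


def joint_loss(l_unce, l_rank):
    """Eq. (12): L_joint = L_unce + L_rank."""
    return l_unce + l_rank


def project_to_unit_ball(v):
    """Assumed enforcement of ||v||_2 <= 1: rescale the vector if it is too long."""
    norm = math.sqrt(sum(x * x for x in v))
    return v if norm <= 1.0 else [x / norm for x in v]


l_rank = rank_loss_self_adversarial(1.0, 0.0, [-1.0, -1.0], gamma=2.0)
print(l_rank, joint_loss(0.25, l_rank), project_to_unit_ball([3.0, 4.0]))  # 1.0 1.25 [0.6, 0.8]
```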

#### **4. Experiments**

Our proposed model was evaluated on three tasks: confidence prediction, link prediction, and relation fact classification. Confidence prediction aims to obtain the confidence scores of existing facts: given a relation fact with its head entity, relation, tail entity, and time, the corresponding confidence score should be predicted. Link prediction aims to predict missing relation facts; e.g., given the head entity, the relation, and the corresponding time, the missing tail entity should be predicted. Relation fact classification is a binary classification problem: we classified the relation facts in wikidata\_5k into strong and weak relation facts according to a given threshold *τ*, where facts with confidence scores above the threshold were considered strong relation facts and the others were considered weak relation facts.

#### *4.1. Datasets*

At present, no universal datasets for uncertain temporal KGs are available, so we used the datasets extracted from Wikidata described in [24]. Wikidata contains structured temporal information obtained from various sources using open information extraction (OIE). The authors of [24] extracted over 6.3 million temporal facts with confidence scores from Wikidata, covering relations such as plays for (>4 million facts), educated at (>6K), member of (>23K), occupation (>4.5K), and spouse (>20K). Because several of the extracted datasets are similar in composition, we chose only one of them, named wikidata\_5k.

**Data preprocessing**. We first preprocessed this dataset. The initial confidence scores in wikidata\_5k range from 1 to 10, and 96.4% of them are less than or equal to 5.0. For normalization, we first clipped the confidence scores *s* to the interval [1.0, 5.0] and then applied min-max normalization to map them into [0.0, 1.0]. After preprocessing, the wikidata\_5k dataset contained 2233 entities, 6 relations, and 4818 uncertain temporal relation facts, with a mean confidence score of 0.269 and a variance of 0.225.
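The normalization described above amounts to clipping each raw score to [1.0, 5.0] and then applying min-max scaling with minimum 1.0 and maximum 5.0. A minimal sketch (the function name and the sample scores are hypothetical):

```python
def normalize_confidence(raw, low=1.0, high=5.0):
    """Clip a raw wikidata_5k confidence score to [low, high], then min-max scale to [0, 1]."""
    clipped = min(max(raw, low), high)
    return (clipped - low) / (high - low)


raw_scores = [1.0, 3.0, 5.0, 8.5]  # hypothetical raw scores on the original 1-10 scale
print([normalize_confidence(s) for s in raw_scores])  # [0.0, 0.5, 1.0, 1.0]
```

Because every raw score above 5.0 is clipped first, the roughly 3.6% of facts in that range all map to 1.0.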

#### *4.2. Experimental Setup*

We divided the dataset into 85% for training, 7% for validation, and 8% for testing. To test whether our model could correctly identify negative links, we added to the test set the same number of negative links as existing relation facts. We used the Adam optimizer for training and grid search to select the optimal hyperparameters from the following sets: the embedding dimension *d* ∈ {64, 128, 256, 512} of entities, relations, and time; the training batch size *b* ∈ {128, 256, 512, 1024}; the learning rate *lr* ∈ {0.001, 0.005, 0.01}; and the margin *γ* ∈ {1, 2, 10} in the ranking loss. We used the *L*2-norm when computing the translation distance. On the wikidata\_5k dataset, the best parameters for CTRIEJ*logi* were {*d* = 512; *b* = 256; *lr* = 0.001; *γ* = 2}, and the best parameters for CTRIEJ*rect* were {*d* = 128; *b* = 256; *lr* = 0.001; *γ* = 2}. All models were evaluated under their best parameter settings.
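The split and the grid-search space can be sketched as follows. The shuffling seed, the truncation used for split sizes, and the helper names are assumptions; the sketch only enumerates the 4 × 4 × 3 × 3 = 144 candidate configurations and omits the training loop and the negative test links.

```python
import itertools
import random


def split_dataset(facts, seed=0):
    """Shuffle and split facts 85% / 7% / 8% into train, validation, and test sets."""
    facts = list(facts)
    random.Random(seed).shuffle(facts)
    n_train = int(0.85 * len(facts))
    n_valid = int(0.07 * len(facts))
    return facts[:n_train], facts[n_train:n_train + n_valid], facts[n_train + n_valid:]


# Hyperparameter grid from Section 4.2.
GRID = {
    "d": [64, 128, 256, 512],
    "b": [128, 256, 512, 1024],
    "lr": [0.001, 0.005, 0.01],
    "gamma": [1, 2, 10],
}
configs = [dict(zip(GRID, values)) for values in itertools.product(*GRID.values())]

train, valid, test = split_dataset(range(4818))  # 4818 facts after preprocessing
print(len(train), len(valid), len(test))  # 4095 337 386
print(len(configs))  # 144
```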

#### *4.3. Baselines*

We considered three types of baselines in our comparison, which included the deterministic KG embedding models TransE [6] and DistMult [11], the uncertain KG embedding models UKGE*rect* and UKGE*logi* [15], and the temporal KG embedding models TA-TransE and TA-DistMult [22].


#### *4.4. Confidence Prediction*

**Evaluation metrics**: The goal of confidence prediction is to obtain the confidence scores of existing relation facts. We computed the confidence score of each relation fact with Equation (3) or Equation (4) and used the mean squared error (MSE) and mean absolute error (MAE) as evaluation metrics. The smaller the MSE and MAE, the more accurate the prediction and the better the model performance.
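Both metrics compare predicted and ground-truth (normalized) confidence scores fact by fact. A minimal sketch with hypothetical values:

```python
def mse(predicted, target):
    """Mean squared error between predicted and ground-truth confidence scores."""
    return sum((p - t) ** 2 for p, t in zip(predicted, target)) / len(target)


def mae(predicted, target):
    """Mean absolute error between predicted and ground-truth confidence scores."""
    return sum(abs(p - t) for p, t in zip(predicted, target)) / len(target)


pred = [0.2, 0.5, 0.9]   # hypothetical predicted confidences
truth = [0.0, 0.5, 1.0]  # hypothetical normalized ground-truth confidences
print(mse(pred, truth), mae(pred, truth))
```

Since Table 1 reports values in units of 10<sup>−2</sup>, a result from this sketch would be multiplied by 100 before being compared with the table.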

**Experimental results**: The confidence prediction results are shown in Table 1. Deterministic KG embedding models cannot predict confidence scores, so we used only the uncertain KG embedding model UKGE as the baseline. On the wikidata\_5k dataset, both of our variants outperformed the corresponding UKGE variants, and CTRIEJ*rect* performed best on both MSE and MAE. Compared with the best-performing baseline, UKGE*rect*, CTRIEJ*rect* reduced the MSE by approximately 13.8% and the MAE by approximately 19.6%. These results indicate that incorporating temporal and structural information helps the model predict the confidence scores of relation facts more accurately.


**Table 1.** MSE and MAE of relation fact confidence prediction (×10<sup>−2</sup>).

#### *4.5. Link Prediction*

**Evaluation metrics**: Link prediction is a typical KG embedding evaluation task: it predicts missing head or tail entities from known entities and relations, or sometimes predicts the relation between a known head entity and a known tail entity. In our experiments, we predicted missing tail entities from the known head entities, relations, temporal information, and uncertainty information. We ranked each candidate tail entity by the value of the score function and then calculated Hit@K and the mean rank. Hit@K is the proportion of test facts whose correct tail entity is ranked in the top K candidates, and the mean rank is the average rank of the correct tail entities. Since the confidence score of each triple differs, we followed the PASSLEAF model [17] and weighted Hit@K and the mean rank linearly by confidence to obtain WH@K and WMR as follows:

$$\mathrm{WH@}K = \frac{\sum\_{(h, r\_{\rm seq}, t, s) \in T\_K} s}{\sum\_{(h, r\_{\rm seq}, t, s) \in T} s} \tag{13}$$

$$\mathrm{WMR} = \frac{\sum\_{(h, r\_{\rm seq}, t, s) \in T} s \cdot \mathrm{rank}\_{(h, r\_{\rm seq}, t)}}{\sum\_{(h, r\_{\rm seq}, t, s) \in T} s} \tag{14}$$

where *T* is the test set, *T*<sub>*K*</sub> is the subset of test facts whose correct tail entity is ranked in the top *K*, and rank(*h*, *r*<sub>seq</sub>, *t*) is the rank of the correct tail entity *t* for the query (*h*, *r*<sub>seq</sub>, ?). To rank the candidate tail entities, we used as the score function the sum of the translation-distance-based energy function and the semantic-matching-based confidence prediction function. When WH@K and WMR are computed on the test set, some candidate tail entities form facts that appear in the training or validation set, and such candidates cannot be considered wrong. We therefore removed these candidates to obtain the filtered WH@K and WMR. The larger the WH@K and the smaller the WMR, the better the model performance. For WH@K, we conducted experiments with *K* = 2 and *K* = 10.
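WH@K and WMR (Equations (13) and (14)) and the filtered setting can be sketched as below. The sketch assumes that a higher score means a more plausible candidate and breaks ties in favor of the correct tail; the helper names, the candidate scores, and the (confidence, rank) pairs are hypothetical.

```python
def filtered_rank(scores, true_tail, known_tails):
    """Rank of true_tail among candidates (1 = best; higher score = more plausible).

    Candidates in known_tails form valid facts already seen in the training or
    validation set, so they are skipped instead of counted as errors (filtered setting).
    """
    true_score = scores[true_tail]
    better = sum(
        1
        for cand, s in scores.items()
        if cand != true_tail and cand not in known_tails and s > true_score
    )
    return better + 1


def weighted_metrics(ranked_facts, k):
    """ranked_facts: list of (confidence s, rank of correct tail). Returns (WH@K, WMR)."""
    total_s = sum(s for s, _ in ranked_facts)
    wh_k = sum(s for s, rank in ranked_facts if rank <= k) / total_s
    wmr = sum(s * rank for s, rank in ranked_facts) / total_s
    return wh_k, wmr


# Hypothetical query: "b" is the correct tail, "a" is a known tail from training.
scores = {"a": 0.9, "b": 0.7, "c": 0.8, "d": 0.1}
print(filtered_rank(scores, "b", known_tails={"a"}))  # 2

facts = [(0.8, 1), (0.5, 3), (0.2, 12)]  # hypothetical (confidence, rank) pairs
print(weighted_metrics(facts, k=2))
```

High-confidence facts dominate both sums, so a model is rewarded most for ranking its most certain facts correctly.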

**Experimental results**: The WMR, WH@2, and WH@10 results are reported in Table 2. The CTRIEJ models generally outperformed the baseline models: CTRIEJ*logi* performed best on WMR, and CTRIEJ*rect* performed best on WH@2 and WH@10. The deterministic KG embedding models TransE and DistMult did not perform as well as our model because they consider neither temporal information nor confidence scores. UKGE also performed poorly because it considers only confidence scores and leverages neither temporal nor structural information. TA-TransE and TA-DistMult embed only temporal information, so they also fell short of our model. Overall, for link prediction, our model performed best, followed by the deterministic KG embedding models and the temporal KG embedding models, and finally UKGE, which further shows the importance of temporal and structural information. In this paper, we used the sum of the translation-distance-based energy function and the confidence prediction function as the evaluation function; better ways of fusing these functions to rank triples remain to be explored in future work.


**Table 2.** Tail entity prediction.

#### *4.6. Relation Fact Classification*

**Evaluation metrics**: We set the confidence score threshold to *τ* = 0.3 to divide uncertain temporal relation facts into strong and weak relations. Under this setting, 36.03% of the relation facts in wikidata\_5k were strong relations. By fitting a function that maps the predicted confidence scores in the training set to their relation categories, we obtained a binary classification model, which we then applied to classify the relation facts in the test set. We used the F-1 score and accuracy to evaluate classification performance.
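The classification protocol can be sketched as follows. Ground-truth labels come from *τ* = 0.3; because the fitted function is not specified here, the sketch substitutes a single decision threshold on the predicted confidence, chosen to maximize training accuracy. That substitution, the helper names, and all numbers are assumptions for illustration only.

```python
def label_strong(confidences, tau=0.3):
    """Ground-truth labels: 1 = strong relation fact (confidence above tau), 0 = weak."""
    return [1 if c > tau else 0 for c in confidences]


def classify(pred_scores, threshold):
    """Predict 'strong' when the predicted confidence reaches the threshold."""
    return [1 if p >= threshold else 0 for p in pred_scores]


def accuracy(pred, gold):
    return sum(p == g for p, g in zip(pred, gold)) / len(gold)


def f1_score(pred, gold):
    tp = sum(1 for p, g in zip(pred, gold) if p == 1 and g == 1)
    fp = sum(1 for p, g in zip(pred, gold) if p == 1 and g == 0)
    fn = sum(1 for p, g in zip(pred, gold) if p == 0 and g == 1)
    if tp == 0:
        return 0.0
    precision, recall = tp / (tp + fp), tp / (tp + fn)
    return 2 * precision * recall / (precision + recall)


def fit_threshold(pred_scores, labels):
    """Assumed stand-in for the fitted classifier: threshold maximizing training accuracy."""
    candidates = sorted(set(pred_scores)) + [float("inf")]
    return max(candidates, key=lambda t: accuracy(classify(pred_scores, t), labels))


# Hypothetical training data: ground-truth and predicted confidences.
train_labels = label_strong([0.1, 0.5, 0.2, 0.8])
threshold = fit_threshold([0.15, 0.4, 0.35, 0.7], train_labels)

# Hypothetical test data.
test_labels = label_strong([0.6, 0.1, 0.45])
test_pred = classify([0.5, 0.2, 0.3], threshold)
print(threshold, accuracy(test_pred, test_labels), f1_score(test_pred, test_labels))
```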

**Experimental results**: The results are shown in Table 3. Overall, our two variants outperformed the baseline models. In terms of F-1 score, the baseline models performed similarly, whereas our two variants improved the results substantially; CTRIEJ*rect* achieved the best result, nearly 29.6% higher than that of the best-performing baseline, TA-TransE. In terms of accuracy, our model slightly outperformed the baselines, with CTRIEJ*logi* performing best, 2.1% higher than the best-performing baseline, DistMult. In conclusion, because our model embeds confidence scores, temporal information, and structural information simultaneously, it performed better than the deterministic KG embedding models, UKGE, and the temporal KG embedding models.

**Table 3.** F-1 scores (%) and accuracies (%) of relation fact classification.


#### *4.7. Ablation Study*

To verify the effects of incorporating temporal and structural information and of adopting self-adversarial negative sampling, we took the variant CTRIEJ*logi* as an example and constructed three simplified versions of it, called CTRIEJ*t*−, CTRIEJ*s*−, and CTRIEJ*n*−. In CTRIEJ*t*−, we kept only the head entity, tail entity, relation, and confidence score of each relation fact and removed the time information. In CTRIEJ*s*−, we kept the MSE loss for confidence prediction and removed the ranking loss that characterizes structural information. In CTRIEJ*n*−, we used uniform negative sampling instead of self-adversarial negative sampling to obtain negative samples.
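A minimal sketch of uniform negative sampling as it might be used in CTRIEJ*n*− is given below. It corrupts either the head or the tail with an entity drawn uniformly at random; the helper name and the entities are hypothetical. Giving every negative the equal weight 1/*n* in Equation (11), in place of the self-adversarial weight of Equation (10), is an assumed reading of this ablation rather than a detail confirmed by the text.

```python
import random


def uniform_negatives(fact, entities, n, seed=0):
    """Draw n negatives for (h, r_seq, t) by replacing the head or the tail with an
    entity chosen uniformly at random (hypothetical helper, not the authors' code)."""
    if len(set(entities)) < 2:
        raise ValueError("need at least two distinct entities to corrupt a fact")
    rng = random.Random(seed)
    h, r_seq, t = fact
    negatives = []
    while len(negatives) < n:
        e = rng.choice(entities)
        if rng.random() < 0.5:
            if e != h:  # corrupt the head
                negatives.append((e, r_seq, t))
        elif e != t:  # corrupt the tail
            negatives.append((h, r_seq, e))
    return negatives


negs = uniform_negatives(("h1", "r_seq", "t1"), ["h1", "t1", "e1", "e2"], n=4)
# Assumed ablation weighting: every negative gets the same weight 1/n in Eq. (11).
weights = [1 / len(negs)] * len(negs)
print(negs, weights)
```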

We evaluated these three simplified versions on four metrics, MSE, MAE, F-1 score, and accuracy; the results are shown in Table 4. None of the three simplified versions performed as well as the full model CTRIEJ*logi*, which verifies the effectiveness of the removed components of our proposed model.


**Table 4.** MSE (×10<sup>−2</sup>), MAE (×10<sup>−2</sup>), F-1 score (%), and accuracy (%).

#### **5. Conclusions and Future Work**

In this paper, we proposed an embedding model, the CTRIEJ model, for uncertain temporal KGs. The model leveraged a GRU-based sequence model to incorporate temporal information into the embedding of relation sequences and then combined semantic-matching-based and translation-distance-based energy functions to integrate the confidence scores and structural information of KGs into a unified framework. Moreover, a self-adversarial negative sampling technique was adopted to generate negative samples for training our model. The CTRIEJ model outperformed the baseline models in three downstream tasks: confidence prediction, link prediction, and relation fact classification. In future work, we will investigate how to integrate better-performing embedding models into our framework and how to better utilize these score functions for evaluating downstream tasks. In addition, predicting the relation facts, and the corresponding confidence scores, that will exist at future moments in uncertain temporal KGs is another topic worth investigating.

**Author Contributions:** Conceptualization, T.L. and W.W.; methodology, T.L.; software, T.L.; validation, T.L., T.W. and X.L.; writing—original draft preparation, X.Z.; writing—review and editing, M.H. All authors have read and agreed to the published version of the manuscript.

**Funding:** This work is supported in part by the National Natural Science Foundation of China under Grant 72101263.

**Institutional Review Board Statement:** Not applicable.

**Informed Consent Statement:** Not applicable.

**Data Availability Statement:** Not applicable.

**Acknowledgments:** National University of Defense Technology.

**Conflicts of Interest:** The authors declare no conflict of interest.

#### **References**


**Disclaimer/Publisher's Note:** The statements, opinions and data contained in all publications are solely those of the individual author(s) and contributor(s) and not of MDPI and/or the editor(s). MDPI and/or the editor(s) disclaim responsibility for any injury to people or property resulting from any ideas, methods, instructions or products referred to in the content.

MDPI St. Alban-Anlage 66 4052 Basel Switzerland Tel. +41 61 683 77 34 Fax +41 61 302 89 18 www.mdpi.com

*Mathematics* Editorial Office E-mail: mathematics@mdpi.com www.mdpi.com/journal/mathematics


ISBN 978-3-0365-7263-5