**About the Editors**

#### **Yanhui Guo**

Yanhui Guo received his Ph.D. degree from the Department of Computer Science, Utah State University, USA. He was a research fellow in the Department of Radiology at the University of Michigan and an assistant professor at St. Thomas University. Dr. Guo is currently an associate professor in the Department of Computer Science at the University of Illinois Springfield. His research areas include computer vision, machine learning, data analytics, neutrosophic sets, computer-aided detection/diagnosis, and computer-assisted surgery. He has published 3 books, more than 110 journal papers, and 40 conference papers; completed more than 10 grant-funded research projects; holds 2 patents; and has served as an associate editor for several international journals and as a reviewer for top journals and conferences. Dr. Guo successfully applied the neutrosophic set to image processing in 2008 and has published many research works in this area. He was the co-founder and chief scientist of MedSights Tech Inc., a top technology company focused on computer-assisted surgery systems. Dr. Guo was named a University Scholar in 2019, the university system's highest faculty honor, which recognizes outstanding teaching and scholarship.

#### **Deepika Koundal**

Deepika Koundal is currently associated with the University of Petroleum and Energy Studies, Dehradun. She received recognition and an honorary membership from the Neutrosophic Science Association at the University of New Mexico, USA. She was also selected as a Young Scientist at the 6th BRICS Conclave in 2021 and received the 2023 best paper award from *JVCI* (Elsevier). She received her Master's and Ph.D. degrees in computer science and engineering from Panjab University, Chandigarh, in 2015, and her B.Tech. degree in computer science and engineering from Kurukshetra University, India. She was the awardee of the research excellence award given by UPES in 2022 and 2023 and by Chitkara University in 2019. She has published more than 100 research articles in reputed SCI- and Scopus-indexed journals and conference proceedings, as well as three books. She has served as a Guest Editor for *Computers & Electrical Engineering*, the *Internet of Things Journal*, *IEEE Transactions on Industrial Informatics*, and *Computational and Mathematical Methods in Medicine*, and is serving as an Associate Editor for *IEEE Transactions on Artificial Intelligence*, *Healthcare Analytics*, *Supply Chain Management*, and the *International Journal of Computer Applications*. She has also served on many technical program and organizing committees and has been invited to give guest lectures and tutorials in faculty development programs, international conferences, and summer schools. Her areas of interest include artificial intelligence, biomedical imaging and signals, image processing, soft computing, and machine learning/deep learning. She has also served as a reviewer for many reputed journals of *IEEE*, *Springer*, *Elsevier*, *IET*, *Hindawi*, *Wiley*, and *Sage*.

#### **Rashid Amin**

Rashid Amin works as an Assistant Professor at the Department of Computer Science, University of Chakwal, Pakistan. Before this, he worked as a Lecturer at the Department of Computer Science, University of Engineering and Technology, Taxila, Pakistan, for more than seven years, and at the University of Wah, Wah Cantt, Pakistan, for four years. He received his Ph.D. degree from COMSATS University Islamabad, Wah Campus. His area of research is hybrid Software-Defined Networking. He received MS Computer Science (MSCS) and Master of Computer Science (MCS) degrees from the International Islamic University, Islamabad. He has supervised many MS-level students, and five Ph.D. students are working under his supervision. His current research interests include SDN, hybrid SDN, distributed systems, P2P, machine learning, and network security. He has published several research papers on hybrid SDN, SDN, clouds, IoT, and machine learning in well-reputed journals (e.g., *IEEE Communications Surveys & Tutorials*, *IEEE Access*, *IEEE TEM*, *Sensors*, *Electronics*, *AIHC*, and *CIN*). He has also served as a reviewer for international journals and conferences (e.g., NetSoft, LCN, Globecom, FIT, *IEEE Wireless Communications*, *IEEE Internet of Things Journal*, *IEEE JSAC*, *IEEE Access*, and *IEEE Systems Journal*).

### *Editorial* **Deep Learning in Big Data, Image, and Signal Processing in the Modern Digital Age**

**Deepika Koundal 1,\*, Yanhui Guo 2,\* and Rashid Amin <sup>3</sup>**

<sup>1</sup> School of Computer Science, University of Petroleum & Energy Studies, Dehradun 248007, India

#### **1. Introduction**

Data, such as images and signals, are constantly generated by various industries and by the internet [1,2]. As a result, new technologies have surfaced to track the origin of data and determine their potential for collection, quantification, decoding, and analysis. The analysis of big data, signals, and images has gained significance due to the vast amount of domain-specific and valuable information they contain. These data are crucial for addressing issues such as national intelligence, cyber security, marketing, medical informatics, and fraud detection [3]. Deep learning techniques are highly popular in today's modern digital age. They enable the analysis and learning of substantial quantities of unsupervised data, making them significant for information processing when raw data are mostly uncategorized and unlabeled [4]. Moreover, these applications can include working with medical images as well as signal processing for wellness devices, remote monitoring, and neural devices [5]. Industrial data may also be used for early warning alert systems in assembly lines, whereas big data can be derived from large-scale sources such as electronic health records and hospital information systems [6,7]. This Special Issue aimed to delve into the application of deep learning for addressing significant challenges related to big data, images, and signals.

#### **2. Brief Description of the Contributions**

Aleem et al. discussed machine learning (ML) models for diagnosing depression (Contribution 1). The depression diagnosis model covered data extraction, preprocessing, detection, classification, and evaluation. In (Contribution 2), four pre-trained models, MobileNetV2, VGG16, DenseNet121, and InceptionV3, were proposed for detecting Hurricane-Harvey-induced building destruction in the Greater Houston region (2017). Tahir et al. (Contribution 3) proposed m7G-LSTM to predict N7-methylguanosine sites. The LSTM model demonstrated superior performance in identifying N7-methylguanosine sites compared to the CNN model. Wang et al. introduced a GAN-based data augmentation model for accurately segmenting ischemic stroke. Experimental evaluation showed superior performance and high-quality generated stroke images compared to alternative methods (Contribution 4). Albahli et al. introduced a sentiment lexicon and employed ELM and RNN for predicting the stock market. Its performance was evaluated using Twitter data and the Sentiment140 dataset across ten different brands (Contribution 5). In (Contribution 6), an innovative and robust vehicle detection system was proposed based on the DNN You Only Look Once (YOLOv2), with DenseNet-201 used for feature extraction. In (Contribution 7), the authors analyzed cloud/sky image classification features. Advanced GAN technologies (KernelGAN, ESRGAN, PatchGAN, etc.) were utilized to estimate the degradation kernel and inject noise for the Super-Resolution (SR) of Sentinel-2 remote-sensing images (Contribution 8). In (Contribution 9), a feature-processing method based on the gray-level co-occurrence matrix was presented. By integrating multi-head self-attention with a residual neural network for recognizing X-ray images, the method successfully detected COVID-19 images and produced favorable experimental results.

A more straightforward bubble model and an associated solver optimization technique were suggested to address the problems of insufficiently realistic simulation and convoluted solutions for bubble-motion behavior in water. First, the computation was simplified by ignoring the internal bubble velocity and creating the bubble model from only the net flow of the intake and outflow bubbles (Contribution 10). Hamdi et al. (Contribution 11) developed a classification model to detect COVID-19 in X-ray images of the human chest. Unbalanced classes were addressed using preprocessing techniques, including conditional GANs for data generation. Transfer learning was used to fine-tune the VGG16 model after pre-training it on ImageNet. A fog-based anomaly detection system was introduced to demonstrate the effective utilization of fog nodes to decentralize cloud-based architectures for IoT networks (Contribution 12). In (Contribution 13), a Multi-Access Edge Computing (MEC) system was designed to minimize energy consumption by employing multiple Mobile Devices (MDs) and servers. For identifying low-resource malware families, the authors examined cross-family knowledge transfer (Contribution 14). They assessed the knowledge transfer through supportive scores between families and presented the Sequential Family Selection algorithm to improve detection. For COVID-19 image classification, a BoT-ViTNet model based on ResNet50 was presented (Contribution 15). It incorporated a multi-scale self-attention block to enhance global information modeling and a two-resolution transformer-based vision transformer block for fusing local and global information in complex lesion regions.

(Contribution 16) presented a novel VPN traffic classification method by means of Packet Block images. The notion of a Packet Block, i.e., the accumulation of consecutive packets in the same direction, was suggested in a deep-learning-based traffic categorization approach. The Packet Blocks' characteristics were extracted from network traffic and converted into images. In (Contribution 17), a novel scoring function was proposed to rank Fermatean Fuzzy numbers, addressing Fermatean Fuzzy uncertainty within a precise setting. An integration of an evolution-based genetic algorithm (GA) and the deep CNN VGG16 was presented (Contribution 18): VGG16 performed feature learning on an eight-vehicle-class dataset, followed by feature selection using the GA; classification was then carried out using an SVM classifier. Furthermore, (Contribution 19) introduced a human action recognition method based on a temporal and spatial fusion network model. In (Contribution 20), EEG signals were gathered to evaluate how well three deep learning (DL) approaches performed in predicting individuals' eye state from the EEG information. A DL framework was used to separate out the best vector quantization from the underlying VQ systems, introducing a discriminative approach to LVQ. In (Contribution 21), a new ensemble model with data mining techniques was presented to predict student performance. The datasets incorporated features such as student behavior, parental involvement in academic performance, online classes, and demographic information, demonstrating a significant connection between student conduct and performance. In (Contribution 22), a deep deformable network was used to create an aesthetic font style transfer network: the stylistic components of an image were transferred onto the text of a text image, and the font distortion was controlled by adjusting parameters to create a range of style migrations. Furthermore, (Contribution 23) introduced FedTCM, a federated two-tier cache scheme that mitigated the impact of Non-IID (Non-Independent and Identically Distributed) data on user behavior modeling. A neural-network-based sign language recognition model was presented, utilizing an assistive glove to capture real-time data (Contribution 24). Sensor-based assistive gloves were developed to collect signals for the alphabet and numbers; these symbols make up a small but crucial portion of the ASL vocabulary since they are central to fingerspelling, a common technique for conveying emphasis, lexical gaps, personal names, and technical words. A self-assembled dataset of isolated static digit, alphabetic, and alphanumeric character postures was utilized to train a fully connected neural network using a scaled-conjugate-gradient-based backpropagation technique. In conclusion, Contribution 25 presents a feature-guide conditional generative adversarial network (FG-CGAN). To reduce the identity difference between the input and output face images of the generator and maintain the identity of the input facial image during the generating process, a feature guide module introduces both perceptual loss and L2 loss.

#### **3. Conclusions**

This Editorial introduces 25 research articles focusing on the applications of deep learning in the modern digital age. The goal was to collect relevant contributions from domains of the modern digital age, such as industry, education, healthcare, and security. The innovative approaches presented in this Special Issue are expected to be regarded as interesting and constructive and to achieve recognition from the scientific community and international industry. The research results showcased in this collection point toward more active development and research in the realm of deep learning in the future. To achieve this, future approaches could involve leveraging deep learning models to improve prediction accuracy and enhance the reliability of prediction models.

**Author Contributions:** Conceptualization, D.K.; formal analysis, R.A. and D.K.; writing—review and editing, D.K., R.A. and Y.G.; supervision, D.K. All authors have read and agreed to the published version of the manuscript.

**Funding:** This research received no external funding.

**Acknowledgments:** We thank the writers for their important contributions to this Special Issue. We would like to express our gratitude to the hardworking reviewers for their thorough and prompt evaluations, which significantly raised the standard of this publication. Finally, we would like to express our gratitude to the editorial staff of *Electronics* for their continuous support in making this Special Issue possible.

**Conflicts of Interest:** The authors declare no conflict of interest.

#### **List of Contributions**


#### **References**


**Disclaimer/Publisher's Note:** The statements, opinions and data contained in all publications are solely those of the individual author(s) and contributor(s) and not of MDPI and/or the editor(s). MDPI and/or the editor(s) disclaim responsibility for any injury to people or property resulting from any ideas, methods, instructions or products referred to in the content.

### *Article* **Face Aging with Feature-Guide Conditional Generative Adversarial Network**

**Chen Li 1, Yuanbo Li 1, Zhiqiang Weng 2, Xuemei Lei 3,\* and Guangcan Yang <sup>1</sup>**


**Abstract:** Face aging is of great importance for the information forensics and security fields, as well as entertainment-related applications. Although significant progress has been made in this field, the authenticity, age specificity, and identity preservation of generated face images still need further discussion. To better address these issues, a Feature-Guide Conditional Generative Adversarial Network (FG-CGAN) is proposed in this paper, which contains an extra feature guide module and an age classifier module. To preserve the identity of the input facial image during the generating procedure, in the feature guide module, perceptual loss is introduced to minimize the identity difference between the input and output face images of the generator, and L2 loss is introduced to constrain the size of the generated feature map. To make the generated image fall into the target age group, in the age classifier module, an age-estimated loss is constructed, in which L-Softmax loss is combined to make the sample boundaries of different categories more obvious. Abundant experiments are conducted on the widely used face aging datasets CACD and Morph. The results show that target aging face images generated by FG-CGAN have promising validation confidence for identity preservation. Specifically, the validation confidence levels for the age groups 20–30, 30–40, and 40–50 are 95.79%, 95.42%, and 90.77%, respectively, which verifies the effectiveness of our proposed method.

**Keywords:** face aging; feature guide; information preserving; generative adversarial network; age classifier module

#### **1. Introduction**

Face aging, also known as the age image generation or age regression problem [1–4], can be defined as the process of aesthetically rendering a face image so that the processed image shows visually convincing, natural aging or rejuvenation of the human face. Face aging has a broad range of applications in different fields, such as cross-age facial recognition, searching for lost children, and audio–visual entertainment.

An ideal face aging algorithm should possess the following key characteristics: authenticity, identity preservation, and accuracy when generating images within the target age group. Previous research on facial aging has primarily focused on two categories of methods: physical model-based [4–7] and prototype-based [2,8,9]. Physical model-based methods rely on adding or removing age-related features, such as wrinkles, gray hair, or beards, that align with image generation rules. Prototype-based methods first compute the average face of each age group and then use the differences between age groups as an aging pattern to synthesize aging faces, which does not preserve identity well. However, these methods often lack a deep understanding of facial semantics, resulting in generated images that may not be authentic.

Along with the rapid development of deep neural networks, generative adversarial networks (GANs) have drawn much attention from researchers in the face aging field [10–16] and have been proven to generate images with better quality, identity consistency, and aging accuracy than traditional methods. Many studies [17,18] have used unpaired face aging data to train models. However, these methods mainly focus on face aging itself and ignore other key conditional information of the input face (e.g., facial attributes). As a result, incorrect face attributes may appear in the generated results, and the identity information of the generated face cannot be preserved well. In order to suppress such undesired changes in semantic information during face aging, many recent face aging studies [19,20] have attempted to supervise the output by enforcing identity consistency, which to some extent preserves identity information. However, significant unnatural variation in facial attributes is still observed. This indicates that enforcing identity consistency alone is not sufficient to achieve satisfactory face aging performance.

To combat this, Variational Autoencoders (VAEs) [21] are combined with GANs to generate new images. Age estimation is used to help implement the generation of aging face images [15]. The subsequent Pyramid Face Aging-GAN [22] incorporates a pyramid weight-sharing scheme to ensure that face aging changes slightly between adjacent age groups and dramatically between distant age groups. Most existing GAN-based methods [18,23] usually use pixel-level loss to train the model to preserve identity consistency and background information. However, because they minimize the Euclidean distance between the synthesized image and the input image, the aging accuracy of their generated results is not very high, which indicates that preserving identity information well does not guarantee a reasonable aging result.

To more effectively preserve the identity information of faces in face aging tasks, a new face aging framework called Feature-Guide Conditional Generative Adversarial Network (FG-CGAN) is proposed in this article. Compared with existing methods in the literature, feature-guide methods are introduced to ensure that the generative model preserves the identity information of face images well. Specifically, the network feature maps produced in the process of image generation are constrained, as are the identity features extracted from the original and generated images. At the same time, to ensure that the generated image falls into the target age group, an age classifier module is attached to the discriminator network. Finally, abundant comparison experiments are conducted. To summarize, the main contributions are as follows:


#### **2. Related Work**

#### *2.1. Face Aging Methods*

In this part, representative and inspiring works on face aging are reviewed. To date, existing face aging research can be divided into three phases: physical model-based methods, prototype-based methods, and deep generative-based methods.

**Physical model-based methods:** as seen in early face aging applications [4–7], intuitively adding or smoothing "age factors" to an image is a simple way to simulate the appearance of the face at a target age. The advantage of these methods is that they are easily applicable because they only require adding artificial elements to existing face images. However, these methods do not guarantee the visual authenticity of the generated faces. Moreover, the preservation of identity information is not considered.

**Prototype-based methods:** the prototype-based methods take the average face of each age group as the prototype and map the differences between the age group to the input face image. Refs. [2,8] exploit the differences between the average faces of different age groups to transform the age patterns. However, these methods usually ignore the differences between individuals. At the same time, some important age characteristics may be lost due to averaging.

**Deep generative-based method:** with the rapid development of deep learning, deep generative models are widely used to synthesize aging face images. Refs. [24,25] use deep generative models with temporal architectures to synthesize face images. However, the most critical problem of these methods is that multiple facial images of the same person at different ages are needed in the training stage, so their potential in practical applications is limited. The appearance of GANs reshaped the research landscape of face generation. Grigory et al. [26] propose a conditional generative adversarial network, which brings a supervised learning scheme to GANs. It restricts the feature information of the latent vectors in the down-sampling stage by L2 regularization so that the network can minimize the difference between the original and the generated image during the training process, which helps preserve the identity information. However, it may cause the generated image to be unable to reach the target age group.

To generate enhanced aging face images, Neha Sharma et al. [27] use a fusion-based generative adversarial network. Zhu et al. [10] propose a spatial-attention-mechanism-based GAN. It limits image modification to areas that are closely related to age changes, which helps maintain high visual fidelity when synthesizing images under unknown circumstances. Ref. [11] achieves good results in the pixel-based image migration task. However, attention/semantic-based work cannot describe local areas of the input face well. It is necessary to deal with a deeper entanglement between age characteristics and identity information.

Hence, improved generative adversarial network methods are discussed. DualGAN [12], DiscoGAN [13], and CycleGAN [14] utilize cyclic consistency between the input image and the generated image, which keeps the identity of the generated face the same as the input face. However, because the generative network achieves stable generation quality in a multimodal state, the inputs may even be ignored. Wang et al. [15] propose adding identity preservation optimization measures to the generator, i.e., using a hidden-vector-based L2 paradigm to decrease the loss of identity information caused during down-sampling. However, such an approach may reduce the diversity of the generated images, and it also forces the retention of unnecessary identity information, which may generate images not belonging to the expected age groups.

Most previous works focus on progression and neglect the discussion of a wide range of age transformations. To address this issue, Makhmudkhujaev et al. [28] propose Re-Aging GAN, which can learn personalized age features through high-level interaction between a given identity and target age. These features include identity and target age information, which provides an important indication of how the input face should look at a specific age. How to take advantage of relevant facial conditions is also a research focus. Shen et al. [29] propose the InterFaceGAN framework to learn facial semantic information in the latent space, which can manipulate the corresponding facial attributes without retraining the model. More accurate control of the attribute operation can change the facial pose and repair artifacts accidentally generated by the GAN. Liu et al. [30] propose embedding the facial attribute vector into the generator and discriminator in order to stimulate each synthesized aging face image to abide by its corresponding input attribute. Yang et al. [16] use coupling to simulate the constraints of intrinsic subject-specific features and age-specific facial changes over time, respectively. To render realistic facial details, the advanced age-specific features conveyed by synthetic faces are estimated at multiple scales by pyramid adversarial discriminators. A3GAN [31] embeds facial attribute

vectors into generators and discriminators; this helps synthesized faces be faithful to the attributes of the corresponding input. Moreover, it also utilizes attention mechanisms to limit modifications of age-relevant regions and therefore preserve image detail.

In conclusion, how to take advantage of useful conditions and preserve the identity information of the input face while guaranteeing aging and visual accuracy remains the main challenge of face aging.

#### *2.2. Age Prediction/Estimation Methods*

Age prediction involves recognizing and predicting age information from facial images, analyzing and processing various features such as wrinkles, eye bags, and facial contours that are important in the aging process of the face. Accurate age prediction provides fundamental data to help facial aging algorithms simulate and predict changes that occur in the face over time. This task is critical for applications related to face recognition, surveillance, and human-computer interaction.

In recent years, deep learning technology has greatly improved the performance of facial age prediction. Deep learning methods can automatically learn feature representations and model parameters without the need for manual design and extraction of features, resulting in improved prediction performance. For instance, Levi et al. [32] propose a method for age prediction using convolutional neural networks (CNN), achieving excellent prediction results. Rothe et al. [33] propose a method for age prediction based on VGG-16 architecture, which can predict real and apparent ages from a single image without the need for facial landmarks.

Li et al. [34] propose a label refinery network (LRN) and a slack regression refinement method that can progressively learn specific age-label distributions for different facial images without assumptions on fixed distribution formulations. To reduce the overlap of face features between adjacent ages and improve age prediction accuracy, Xia et al. [35] propose a face age estimation method that considers various factors affecting biometric information in a face image. The proposed multi-stage feature constraints learning method refines features through three stages to reduce overlap and increase the discrimination between age ranges, resulting in improved accuracy and fast estimation.

Our proposed framework builds upon the foundation of GANs, with the generator utilizing an encoder–decoder architecture to generate images while the discriminator follows existing methods. While our encoder and decoder network structure is similar to existing works, we introduce several modifications that enhance performance (see details below). One unique aspect of our framework is the integration of two additional subnetworks: the feature guide module and the age classifier module. These modules provide information about what a face at a given age should look like and play a guiding role in generating images. The method proposed in this article not only preserves the identity information of the face but also produces a more visible and intuitive age change. The generated images look more realistic and reliable, and facial features are well preserved, which is crucial in applications such as age progression and regression.

#### **3. Methodology**

As illustrated in Figure 1, the proposed framework contains four components: the generator module, the feature guide module, the age classifier module, and the discriminator module. In order to ensure that comparisons are fair, age categories are divided into five groups: 10–20, 20–30, 30–40, 40–50, and over 50 years old. One-hot labels are used to indicate the age groups: the one-hot label is filled with 1 in the dimension indicating the target age group and with 0 in the other dimensions.
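To make the conditioning concrete, the sketch below builds such a one-hot label as per-group channel planes and concatenates it with an image, a common way of feeding labels to a convolutional conditional GAN. The shapes, names, and PyTorch usage are illustrative assumptions, not the authors' code.

```python
import torch

NUM_GROUPS = 5  # 10-20, 20-30, 30-40, 40-50, over 50

def one_hot_condition(group_idx: int, h: int, w: int) -> torch.Tensor:
    """One spatial plane per age group: the target group's plane is 1s, the rest 0s."""
    cond = torch.zeros(NUM_GROUPS, h, w)
    cond[group_idx] = 1.0
    return cond

# Example: condition a 128x128 RGB image on the 40-50 age group (index 3).
img = torch.rand(3, 128, 128)
x = torch.cat([img, one_hot_condition(3, 128, 128)], dim=0)  # shape (8, 128, 128)
```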

**Figure 1.** Feature-Guide Conditional Generative Adversarial Network (FG-CGAN). FG-CGAN mainly adds a feature guide module and an age classification module to the basic GAN structure, enabling the original face image to generate new images that preserve identity information and have good visual effects.

The purpose of FG-CGAN is to generate an aging face that conforms to the target age group from the original face image. The generator module is composed of an encoder and a decoder. In the process of generating a face, the feature guide module extracts the corresponding feature maps of the encoder and decoder, respectively, and makes them conform to L2 constraints. At the same time, L2 is also used to constrain the identity features extracted by the pre-trained network. The discriminator is used to distinguish whether the generated face is true or false. The age classifier module determines whether the age of the generated face is within the target age group. The details of this framework are explained below.

#### *3.1. Base Network*

The fundamental architecture of the network comprises a generator and a discriminator that utilize GAN principles. The generator has been modified from the Variational Autoencoder (VAE) [21] architecture to enhance its performance.

**Generator module**: let $x \in \mathbb{R}^{h \times w \times N}$ be an input face image and $C_g \in \mathbb{R}^N$ the target age groups, where $h$ and $w$ represent the height and width of a feature map and $N$ represents the number of age groups. To generate a synthetic face image $x_t$ within the target age group $C_t$, a generator $G$ is built, following the work on VAEs. The synthetic face image is given by Equation (1):

$$x_t = G(x, C_t), \tag{1}$$

*G* consists of both an encoder and a decoder.

The encoder module aims to encode the high-dimensional input $x$ into a low-dimensional latent vector, thereby forcing the neural network to learn the most informative features. The encoder module in this paper is constructed as a fully convolutional neural network. In order to preserve the semantic information of the image, the fully connected layer of the network is replaced with a convolutional layer in our proposed network. The number of input channels of the input layer is adjusted to be consistent with the image matrix dimension after adding the one-hot coding. Using this encoder structure allows the generator to obtain representation learning ability in the hidden space, so that the data can be manipulated at the semantic level through interpolation or conditional embedding in the hidden variable space [21].
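A rough sketch of such an encoder is given below, with the input width widened to accommodate the one-hot planes and a convolution standing in for the fully connected head; layer counts and channel widths are our own illustrative assumptions, not the paper's exact architecture.

```python
import torch
import torch.nn as nn

class Encoder(nn.Module):
    """Fully convolutional encoder: no fully connected layers, conditioned input."""
    def __init__(self, in_ch: int = 3 + 5, latent_ch: int = 256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_ch, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.ReLU(inplace=True),
            # a convolution replaces the usual fully connected layer,
            # keeping the latent code spatially structured
            nn.Conv2d(128, latent_ch, kernel_size=4, stride=2, padding=1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)

z = Encoder()(torch.rand(1, 8, 128, 128))  # latent map of shape (1, 256, 16, 16)
```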

The decoder module aims at restoring the latent vector of the hidden layer to the initial dimension, making the output $x_l \approx x$. A fractionally strided (transposed) convolution structure is introduced to construct the decoder network.

**Discriminator module:** this module is used to determine the authenticity of the input synthetic image $x_t$, which is functionally consistent with the discriminator model in the original generative adversarial network, that is, $\max_D D(x_t, y)$, where $D$ is a discriminator identifying the facticity of the synthetic face and $y$ is a real target image belonging to the target age group $C_t$.

**Adversarial loss:** the original GANs cannot generate images with specific attributes. To address this, the core of CGANs [26] is to integrate attribute information into the generator $G$ and discriminator $D$, where the attribute information can be any label information. The objective function of CGANs can be expressed as:

$$\min_G \max_D V(D, G) = E_{x \sim p_x(x)}[\log D(x|C_t)] + E_{y \sim p_y(y)}[\log(1 - D(G(y|C_t)))], \tag{2}$$

where $p_x(x)$ and $p_y(y)$ denote the probability distributions of $x$ and $y$, respectively.

However, CGANs share the same drawbacks as the original GANs as they employ cross entropy as the loss function, which results in generated samples being distant from the decision boundary and only achieving a small loss. This instability in the training process leads to a low-quality output from the generator. In contrast, LSGAN [36] aims to minimize the distance between the generated and real faces, thereby making it difficult for the discriminator to distinguish between them. FG-CGAN adopts the conditional LSGAN function as its adversarial loss:

$$L_D = \frac{1}{2} E_{x \sim p_x(x)}\Big[\big(D(x|C_t) - 1\big)^2\Big] + \frac{1}{2} E_{y \sim p_y(y)}\Big[\big(D(G(y|C_t))\big)^2\Big], \tag{3}$$

$$L_G = \frac{1}{2} E_{y \sim p_y(y)}\Big[\big(D(G(y|C_t)) - 1\big)^2\Big] \tag{4}$$
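For illustration, the two least-squares losses above can be written as follows; this is a minimal sketch assuming a discriminator `D(image, cond)` that returns a realness score, with function names and calling conventions of our own choosing.

```python
import torch

def d_loss(D, x_real, x_fake, c_t):
    """Equation (3): push scores on real images toward 1 and on fakes toward 0."""
    real_term = 0.5 * ((D(x_real, c_t) - 1.0) ** 2).mean()
    fake_term = 0.5 * (D(x_fake.detach(), c_t) ** 2).mean()  # detach: no G gradient
    return real_term + fake_term

def g_loss(D, x_fake, c_t):
    """Equation (4): push the discriminator's score on fakes toward 1."""
    return 0.5 * ((D(x_fake, c_t) - 1.0) ** 2).mean()
```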

#### *3.2. Feature Guide Module*

Preserving the identity information of the input face is a critical requirement in the face aging process. However, only utilizing the adversarial loss to make the generated sample and target sample distributions similar may not adequately preserve identity information. To address this issue, we introduce the identity information preservation function to supervise the image generation process by using the features extracted from the network. Specifically, this function is called the feature guide module.

This module obtains the feature maps of the encoder and decoder separately and compares the corresponding feature maps. This requires that the encoder and decoder be symmetrical when designing the network. The advantage of this is that the extracted corresponding feature maps are of the same size, and they can be fed into the same pre-trained network, ensuring that the final output is a one-dimensional vector for metric comparison based on the L2 norm. Thus, over the corresponding same-sized feature maps, $L_{layer}$ is defined in Equation (5):

$$L_{layer} = \sum_{i=1}^{k} \left\| f^{i}_{encoder} - f^{i}_{decoder} \right\|^2 \tag{5}$$

Here, $k$ represents the total number of network layers, and $f^{i}_{encoder}$ and $f^{i}_{decoder}$ represent the feature maps of the encoder and the decoder at the $i$-th layer, respectively.

For identity consistency, the perceptual loss is introduced to minimize the distance between the input and output face image identities of the generator, as shown in Equation (6):

$$L_{id} = \sum_{x \in p_x(x)} \left\| h^{x_l}_{id} - h^{x}_{id} \right\|^2, \tag{6}$$

where $h^{x}_{id}$ represents the features extracted from a specific feature layer of the pre-trained model with $x$ as input. The difference metric over the paired feature maps preserves the identity information between the original and generated images. The pre-trained network uses ResNet-34 as the basic network structure to classify the age of the generated images. The reason why the formulas use L2 instead of L1 as the metric is that L1 is pixel-based and strongly supervises each pixel, leading the generator to reproduce the original image conservatively. Consequently, the generative network may lack diversity in the generated images. The overall loss function for identity preservation is:

$$L_{identity} = L_{id} + L_{layer} \tag{7}$$
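A compact sketch of these three losses follows; `enc_feats`/`dec_feats` are assumed lists of same-sized feature maps from the symmetric encoder and decoder, and `embed` an assumed pre-trained identity feature extractor (ResNet-34 in the paper).

```python
import torch

def layer_loss(enc_feats, dec_feats):
    """Equation (5): L2 distance between corresponding encoder/decoder feature maps."""
    return sum(((fe - fd) ** 2).sum() for fe, fd in zip(enc_feats, dec_feats))

def id_loss(embed, x, x_out):
    """Equation (6): perceptual L2 distance between identity features of input/output."""
    return ((embed(x_out) - embed(x)) ** 2).sum()

def identity_loss(embed, x, x_out, enc_feats, dec_feats):
    """Equation (7): combined identity-preservation objective."""
    return id_loss(embed, x, x_out) + layer_loss(enc_feats, dec_feats)
```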

#### *3.3. Age Classifier Module*

The age classifier module is used to classify the age group $C_t$ of the generated face image $x_t$. This module is primarily structured as a ResNet-34.

In addition to meeting the requirements of visual perception, the synthetic face image must also satisfy the target age group condition. To achieve this, the generator utilizes the age classifier module to regulate the age distribution of the synthetic image through loss estimation. This enables the generator to generate a synthetic image $x_t$ that conforms to the target age condition by comparing it with the target image $y$. The loss function of the age classifier module is:

$$L_{age} = -\frac{1}{M} \sum_{s=1}^{M} \sum_{j=1}^{N} \mathrm{sign}_{sj} \log P_{sj}, \tag{8}$$

where $M$ represents the number of samples and $N$ the number of age groups. $\mathrm{sign}_{sj}$ is the indicator function: it takes 1 if the age group of sample $s$ is equal to the true group $C_t$ (i.e., group $j$) and 0 otherwise, and $P_{sj}$ is the predicted probability of sample $s$ belonging to age group $j$.

The L-Softmax loss is introduced to learn the intra-class compactness and inter-class separability between features. This loss function enhances the distinguishability of sample boundaries for different categories by adjusting the inter-class angle boundary constraints. By multiplying the preset constant with the angle between the sample and the ground truth class, an angular margin is created. The strength of the margin around the ground truth category is determined by the preset constant, allowing the L-Softmax loss to be customized according to the task requirements.
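Without the angular margin, Equation (8) reduces to a standard cross-entropy over age groups, as in the simplified sketch below; the full method additionally applies the L-Softmax margin described above, which this sketch omits.

```python
import torch
import torch.nn.functional as F

def age_loss(logits: torch.Tensor, target_group: torch.Tensor) -> torch.Tensor:
    # logits: (M, N) classifier scores; target_group: (M,) true age-group indices.
    # cross_entropy takes -log P of the true group and averages over the batch,
    # matching the indicator and the 1/M factor in Equation (8).
    return F.cross_entropy(logits, target_group)

loss = age_loss(torch.randn(4, 5), torch.tensor([0, 2, 4, 1]))  # batch of 4, 5 groups
```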

#### *3.4. Overall Objective Function*

The final integrated objective function can be obtained by combining the aforementioned equations:

$$L_{final\text{-}G} = \lambda_G L_G + \lambda_{identity} L_{identity} + \lambda_{age} L_{age} \tag{9}$$

$$L_{final\text{-}D} = L_D, \tag{10}$$

where $\lambda_G$, $\lambda_{identity}$, and $\lambda_{age}$ are hyper-parameters used to balance the weights of the objective function terms.
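Assembled in code, the overall objectives look as follows; the lambda values are placeholders, since the text introduces them only as balancing hyper-parameters without stating their values here.

```python
# Illustrative weights: the paper does not state their values at this point.
lambda_g, lambda_identity, lambda_age = 1.0, 1.0, 1.0

def final_g_loss(l_g: float, l_identity: float, l_age: float) -> float:
    """Equation (9): weighted sum of the generator's three loss terms."""
    return lambda_g * l_g + lambda_identity * l_identity + lambda_age * l_age

# Equation (10): the discriminator is trained on the adversarial term alone,
# i.e., l_final_d = l_d.
```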

#### **4. Experiments and Evaluation**

#### *4.1. Dataset*

To ensure fair comparisons, the CACD dataset [37] is introduced to evaluate face generation based on identity preservation. This dataset comprises over 160,000 face images, with variations in pose, illumination, and expression, of 2000 celebrities aged between 16 and 62. All the images are age-annotated, although not very accurately. We first use target detection to calibrate the face position and then perform various data enhancement processes on the input images, including adjusting saturation and brightness, horizontally flipping the image, randomly rotating the angle, and normalizing the image. The final set comprises approximately 146,794 images with a resolution of 400 × 400 pixels, split into training and validation parts with 90% and 10% of the images, respectively. The face images are divided into five age groups: 10–20, 20–30, 30–40, 40–50, and over 50 years old, with the number of samples in each age group being 8656, 36,662, 38,736, 35,768, and 26,972, respectively. To further validate the effectiveness of our method, we utilize the Morph dataset [38] for testing. Table 1 shows the comparison information of both datasets.
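For illustration, a helper of the kind one might use to bin annotated ages into these five groups is sketched below; since the stated ranges overlap at their endpoints, the boundary handling is our own assumption.

```python
def age_group(age: int) -> int:
    """Map an annotated age to one of the five group indices used above."""
    if age < 20:
        return 0  # 10-20
    if age < 30:
        return 1  # 20-30
    if age < 40:
        return 2  # 30-40
    if age < 50:
        return 3  # 40-50
    return 4      # over 50

assert age_group(16) == 0 and age_group(62) == 4  # the CACD age extremes
```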

**Table 1.** Specific data of CACD-ours and Morph.


#### *4.2. Experimental Details*

In this section, the FG-CGAN method is compared with IPCGANs [15], acGANs [39], and CAAE [19], which can generate real face images of specific age groups with identity conditions as constraints.

The image size of the input to the generation network is 128 × 128 × 3. The training parameters of the network follow the strategy of Wang et al. [25]. The age recognition network proposed in Section 3.3 is used as a feature extractor to extract the feature maps between symmetric network layers. During the training phase, the batch size is set to 32 and the learning rate to 0.001. Adam is used as the optimizer, taking into account that generative adversarial networks are themselves difficult to train and optimize. In the end, the whole network is trained for 50,000 iterations. The hyperparameter settings are consistent with Fang et al.'s work [40].
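The stated setup translates to roughly the following PyTorch configuration; the stand-in generator and discriminator modules are placeholders so the snippet runs, not the architectures of Section 3.

```python
import torch
import torch.nn as nn

# Stand-in modules; the real G and D are the networks described in Section 3.
generator = nn.Conv2d(8, 3, kernel_size=3, padding=1)
discriminator = nn.Conv2d(3, 1, kernel_size=3, padding=1)

batch_size = 32
total_iterations = 50_000
opt_g = torch.optim.Adam(generator.parameters(), lr=1e-3)  # lr = 0.001 as stated
opt_d = torch.optim.Adam(discriminator.parameters(), lr=1e-3)
```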

#### *4.3. Classification Accuracy Influence for Age Classifier Module*

To supervise the training process of the generative adversarial network, we utilize age classifier modules with varying classification accuracy to identify the optimal classifier network for the current generative adversarial network. During the training process of the age classifier module, we save the models with training accuracies of 62%, 72%, and 82%, respectively. These three models are then incorporated into the model training process. The experimental results for different accuracies are presented in Figure 2.


**Figure 2.** The results for different age classification accuracies: (**a**) the effect of a classification accuracy of 82%; (**b**) the effect of a classification accuracy of 62%.

Figure 2a depicts the final training output when using the age classifier module with an accuracy of 82%. The generated image in Figure 2a displays a trend where the generative adversarial network is less likely to alter the face image due to the larger penalty imposed by the age classifier module. A higher classification accuracy can reduce diversity across age groups, causing the generated images to remain similar or unchanged from the original images. Conversely, using a model with low classification accuracy, as seen in Figure 2b, results in obvious artifacts and unrealistic features in the generated images. A lower classification accuracy leads to less punishment during training, causing the generative adversarial network to generate images in a more chance-based manner. Based on these results, a model with a classification accuracy of 72% is selected as the age classifier module. The weights of the age classifier module are set according to the experimental findings in IPCGANs.

#### *4.4. Intuitive Visual Display*

**Visual effect:** Figure 3 shows the outputs for images randomly selected from the Morph dataset in the 11–20 and 21–30 age groups as inputs. It can be directly seen that the algorithm proposed in this paper performs well in preserving face identity information and in producing diverse age changes, including changes in hair color, an increase in facial wrinkles, and an enlarged jaw with age. These changes are in line with people's intuitive understanding of how the face changes as it ages.

**Figure 3.** The generated aged faces by FG-CGAN. The images demonstrate the direct visual effect of FG-CGAN on eight randomly selected input images from the CACD and Morph datasets. As shown, all of these images exhibit a high-quality visual effect of facial aging, showcasing the effectiveness of the model.

**Intuitive comparison:** based on previous work, we compare the effects of different face-generation methods. We randomly extract four face images from the CACD dataset. Different models are then used to generate images for the corresponding age groups to visually compare the realism and information preservation of the images. Since IPCGANs do not provide an official implementation, we re-implement the IPCGANs network and use the acGANs algorithm model available online. We evaluate the effects of the different algorithms and explain the problems and advantages of each in generating images. Figure 4a–d shows the comparison results:

**Figure 4.** Performance comparison with prior works on CACD dataset: (**a**–**d**). Using the same input face images with acGANs, IPCGANs, and FG-CGANs, the synthetic aging images for the age groups of 20–30, 30–40, and 40–50 are displayed.

Figure 4a–d provides a comprehensive comparison of different face-generation methods. The first column of the figure represents the selected original face image, and the last three columns are arranged by age order of 20–30, 30–40, and 40–50. In Figure 4a, the age change of the face generated by the acGANs algorithm is not very evident, and a certain level of smoothing effect is observed. In contrast, the age change of the images generated by IPCGANs and the method proposed in this paper is more apparent, and the identity information of the face is well preserved.

In Figure 4c, acGANs still exhibit a noticeable smoothing effect, and the images suffer from blurred backgrounds. While IPCGANs have a considerable impact on age change and identity preservation, noise appears, resulting in unclear images. The approach proposed in this paper preserves the identity characteristics of the face relatively well and generates images that meet the target age range. Moreover, the images are clearer than those generated by IPCGANs.

In Figure 4b, acGANs result in a certain degree of face identity information loss, and the significant smoothing effect even makes it difficult to recognize the gender of the person in the image. Both IPCGAN and the method proposed in this paper perform well.

In Figure 4d, acGANs still suffer from problems related to face identity information loss, and the age change is not very evident. IPCGANs differ from the method proposed in this paper in the direction of image aging: IPCGANs prefer to add beards and other elements to show the age difference, while the algorithm in this paper highlights the changes in wrinkles and hair color that occur as the face ages.

After conducting the above comparison analysis, it is evident that the method proposed in this paper produces more realistic and reasonable age changes than IPCGANs. The algorithm highlights the features that change with age, such as wrinkles, more frequently. However, both approaches perform well in preserving the identity information of the face. In contrast, the images generated by acGANs exhibit severe dermabrasion and smoothing effects, causing the loss of identity information in the generated images. It is difficult to identify whether the images before and after the generation correspond to the same person or even estimate the gender of the original face. In summary, the approach proposed in this paper and IPCGANs generate images that appear more reliable and realistic as compared to acGANs.

Intuitively looking at the images generated by the three models, the images generated by the work in this paper can not only preserve the identity information features of the face intact but also make the change of age more obvious and intuitive.

**Effectiveness of our algorithm:** the generation-effect verification presented above primarily focuses on gradually generating aging images from young images, but this approach does not fully demonstrate the algorithm's effectiveness. To address this, we also verify generation from middle age toward both younger and older ages. Specifically, we randomly select the faces of individuals aged 30–45 years and input them into our network to generate images of both young and old age. The experiment is illustrated in Figure 5a–d, and we evaluate the generation effect of each image.

**Figure 5.** Display of face aging results on middle-aged faces by FG-CGAN: (**a**–**d**). Inputting randomly selected face images of individuals aged 30–45 into the network, the resulting images of the same face at a younger and older age are shown. This comparison illustrates the ability of the model to transform a face's appearance across a range of ages.

Figure 5a depicts the original faces of individuals aged 41 and 42 as input, with the generated facial appearance displaying a gradual change in wrinkles and graying of hair.

Similarly, Figure 5b shows the increased wrinkles, graying of hair, and enlargement of the jaw, while preserving the identity information features of the face. The generated images meet the expectations of the experiment. Figure 5c also demonstrates a similar effect to the two groups in Figure 5a,b.

However, in Figure 5d, a small number of whiskers appear on the female face in the first row of images. This can be attributed to the principle of generative adversarial networks, which only fit the distribution of data and may struggle to ensure the strict distinction between the male and female sexes throughout the training process.

#### *4.5. Quality Comparison*

Face aging aims to convert the face of the input image to the target age while preserving personal identity. Therefore, the face aging model can be evaluated from two perspectives: (1) how well the identity of the original image is preserved; (2) what is the quality of the age classification of the generated aging images.

**Identity preservation:** we randomly select 500 samples and calculate the face verification confidence by comparing the input images with the images generated for each age group. This approach comprehensively evaluates the performance of each generative model in terms of preserving identity and achieving accurate age classification. CAAE, acGANs, and IPCGANs are measured in comparison.

To evaluate identity preservation in our face aging experiments, we conducted face verification experiments using FaceNet. For each input image of a young face, not only are the original input and the generated faces compared but also the generated faces themselves. In addition, we verified the face verification rate for each age group using different methods on the dataset.

Table 2 reports verification confidences between synthetic aging images from different age groups of the same face, where high verification confidence indicates consistent preservation of identity information. Notably, as the face ages, authentication confidence decreases, indicating that face aging changes appearance. Table 3 also reports the results of the face verification rate, where the threshold is set to 1.0 in FaceNet. Although IPCGAN has an identity retention module, our proposed feature preservation module can better improve identity preservation.


**Table 2.** Face verification confidence quantitative results.

**Table 3.** Face verification rate quantitative results.


**Accurate age classification:** at this stage, volunteers are asked to predict the age of the images. One hundred images from the Morph and CACD datasets are randomly selected as input to the generative models. Then, different generative models are used to generate images for the target age groups, resulting in each model contributing the original 100 images and 300 generated images. From the 100 original images, we randomly select 20 images, and we randomly select 20 images from the corresponding 300 generated images, distributing them to volunteers for discrimination in two directions.

For face verification, we generate three images for each input image based on the age label. These images are then grouped into three pairs: (input image, age label 0 image), (age label 1 image, age label 2 image), and (age label 3 image, an image randomly selected from the other generated face images). The first two pairs are compared to verify whether the generated images belong to the same person, while the third pair is used to verify whether the generated images are similar to other faces. Volunteers complete the face verification task, and we compare the results of the different methods. The accuracy is calculated using the following formula:

$$\mathrm{acc} = \frac{k_p + k_n}{N_p + N_n} \tag{11}$$

where $N_p$ represents the total number of sample pairs in the first two groups and $N_n$ the number of sample pairs in the third group; $k_p$ represents the number of pairs judged to be the same person in the first two groups, and $k_n$ the number of pairs judged not to be the same person in the third group.
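A toy numerical example of Equation (11), with counts we have invented purely for illustration:

```python
# Assumed counts: 40 pairs in the first two groups (N_p), of which k_p = 36 are
# judged "same person"; 20 pairs in the third group (N_n), of which k_n = 17 are
# judged "different person".
N_p, N_n = 40, 20
k_p, k_n = 36, 17
acc = (k_p + k_n) / (N_p + N_n)
print(acc)  # 0.8833...
```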

For age classification, volunteers need to estimate the age of the image they receive, that is, vote for the age group to which the image belongs. After collecting multiple votes for the different models, we record the percentage of face images whose reported target age is consistent with the volunteers' estimates.

On the other hand, to ensure objectivity in the comparison of effects, we use the face-comparing service provided by Face++ to evaluate how well the generated image matches the face information of the original image. Similarly, following previous work, we use the VGG-Face score to measure image quality. Table 4 shows the results of these experiments.


**Table 4.** Generated image quality assessment results.

#### **5. Conclusions**

Face-based human identification is still a complex problem when it encounters images of a person from different age groups. Cross-age applications, such as finding missing children after many years and surveillance, are challenging mainly due to the lack of labeled cross-age face datasets. Hence, generating visually realistic faces with GAN-based methods seems to be a promising way forward. However, identity preservation and aging accuracy are essential characteristics for most cross-age applications. In this paper, we tackle these issues with a proposed Feature-Guide Conditional Generative Adversarial Network (FG-CGAN), which is composed of four sub-modules: the generator module, the feature guide module, the age classifier module, and the discriminator module. In the process of face generation, perceptual loss combined with an L2 constraint is introduced in the feature guide module to minimize the distance between the input and output face image identities. In the age classifier module, to improve classification accuracy, the age-estimated loss is constructed, in which L-Softmax loss is combined to learn the intra-class compactness and inter-class separability between features. Sufficient experiments are conducted on the widely used face aging datasets CACD and Morph. Encouraging results are obtained, which verify the effectiveness of our proposed method.

In subsequent work, we will address some of the limitations of this method. For example, this method does not consider facial differences between different races. In addition, although the paper introduces L-Softmax losses in the experiments to improve classification accuracy, further validation is needed to evaluate the applicability of this method to other datasets and scenarios. Therefore, the universality and applicability of this method need to be further evaluated, and we hope that future research can address these issues more effectively.

**Author Contributions:** Conceptualization, X.L. and G.Y.; methodology, C.L., Y.L., Z.W. and X.L.; software, Y.L. and Z.W.; investigation, G.Y.; writing—original draft preparation, C.L., Y.L. and Z.W.; writing—review and editing, C.L., Y.L. and X.L.; supervision, X.L. and G.Y. All authors have read and agreed to the published version of the manuscript.

**Funding:** This paper is supported by the Research Project of the Beijing Young Topnotch Talents Cultivation Program (Grant No. CIT&TCD201904009), partially by the National Natural Science Foundation of China (Grant No. 62172006 and 61977001) and the Great Wall Scholar Program (CIT&TCD20190305).

**Data Availability Statement:** Data will be made available on request.

**Conflicts of Interest:** The authors declare no conflict of interest.

#### **References**


**Disclaimer/Publisher's Note:** The statements, opinions and data contained in all publications are solely those of the individual author(s) and contributor(s) and not of MDPI and/or the editor(s). MDPI and/or the editor(s) disclaim responsibility for any injury to people or property resulting from any ideas, methods, instructions or products referred to in the content.

### *Article* **Assistive Data Glove for Isolated Static Postures Recognition in American Sign Language Using Neural Network**

**Muhammad Saad Amin <sup>1</sup>, Syed Tahir Hussain Rizvi <sup>2,</sup>\*, Alessandro Mazzei <sup>1</sup> and Luca Anselma <sup>1</sup>**

<sup>1</sup> Dipartimento di Informatica, Università degli Studi di Torino, 10149 Torino, Italy


**Abstract:** Sign language recognition is one of the most challenging tasks of today's era. Most of the researchers working in this domain have focused on different types of implementations for sign recognition. These implementations require the development of smart prototypes for capturing and classifying sign gestures. Keeping in mind the aspects of prototype design, sensor-based, vision-based, and hybrid approach-based prototypes have been designed. The authors in this paper have designed sensor-based assistive gloves to capture signs for the alphabet and digits. These signs are a small but important fraction of the ASL dictionary since they play an essential role in *fingerspelling,* which is a universal signed linguistic strategy for expressing personal names, technical terms, gaps in the lexicon, and emphasis. A scaled conjugate gradient-based back propagation algorithm is used to train a fully connected neural network on a self-collected dataset of isolated static postures of digits, alphabetic, and alphanumeric characters. The authors also analyzed the impact of activation functions on the performance of the neural networks. Successful implementation of the recognition network produced promising results for this small dataset of static gestures of digits, alphabetic, and alphanumeric characters.

**Keywords:** assistive glove; American Sign Language (ASL); gesture recognition; neural network (NN); sign recognition

#### **1. Introduction**

In today's world of smart technology, sign language (SL) recognition is a major task. It is also the need of the hour, as it can be used to bridge the communication gap for the Deaf (the capitalized "Deaf" refers to a community of deaf people who share a language and a culture; in contrast, the lower-case "deaf" refers to the audiological condition of not hearing). Globally, almost every country has Deaf communities (according to the world's population, 15% to 20% of people are part of the deaf population [1]), and people from these communities are not always able to communicate using the vocal national language in written form. So, to help Deaf communities overcome the language barrier, many researchers try to develop software and hardware translation systems. For this purpose, different methodologies, such as sensor-based, vision-based, or hybrid approaches, have been adopted in the literature to design assistive models for capturing sign gestures [2]. These methodologies require the acquisition of posture data made by Deaf people.

Sensor-based prototypes rely exclusively on different types of sensors [3–5]. Choosing a good combination of different sensors is a subjective matter [6]. Based on the dataset and classification requirements, a variety of different sensors can be used collectively. However, this creates a problem: as the number of sensors increases, system complexity and cost also increase, and complex systems often yield poor accuracy [7]. Similarly, vision-based approaches can only analyze image-based or video-based data [8]. Usually, there is no proper involvement of sensors in the vision-based model. However, this approach also has some drawbacks regarding data extraction


from the foreground, background, and noisy channels [9]. Lastly, a hybrid approach combines sensor-based and vision-based models [10]. This approach is normally used for experimental setups, though such prototypes are very costly and complex. For fast data computation, GPUs or GPGPUs are normally required [11].

In this paper, we have developed a smart assistive glove (*data glove*) to capture two specific sets of signs: the alphabetical signs and the numbers 0–10. Even though numbers and alphabetical signs are a small fraction (thirty-seven signs) of the ASL dictionary (the project https://www.spreadthesign.com/ (accessed on 1 February 2023) contains more than 20,000 signs for over 40 sign languages), these signs play an essential role in *fingerspelling,* which is a universal signed linguistic strategy for expressing personal names, technical terms, gaps in the lexicon, and emphasis [12]. Both alphabetical signs and numbers can be captured by a data glove since, to the best of our knowledge, in ASL they are naturally signed only with the hands, that is, without using other articulators such as the head, eyebrows, or shoulders.

This data glove contains five flex sensors, one embedded on each finger of the hand, and a gyroscope sensor attached to the top of the palm [13]. Following the standard ASL posture orientations shown in Figure 1, the dataset is collected for thirty-seven different sign postures, covering the digits 0 to 10 and the letters A to Z. The self-collected data for these thirty-seven postures are used to train fully connected bilayered and trilayered neural networks, and a scaled conjugate gradient back propagation algorithm is used to classify the sign data. The complete procedure of the designed model is summarized in Figure 1.

**Figure 1.** An overview of complete methodology for NN-based ASL recognition.


The proposed framework is novel in that it utilizes just two kinds of sensors to capture the complete ASL numbers and alphabet information, which simplifies our model. Previously, various researchers working in the SL space used a wide range of sensors that made their frameworks complex. With a vast amount of information based on sensor values, significant performance parameters such as overall framework precision, effectiveness, and acquisition time are impacted [14–19]. In the framework we propose, just two sorts of sensors are utilized, i.e., a flex sensor to acquire the finger bending data and an accelerometer/gyroscope sensor to obtain the hand orientation. Furthermore, we had to gather the data ourselves since, to the best of our knowledge, no dataset covering the complete set of ASL postures using only these two sorts of sensors is available. This allowed us to acquire the new dataset efficiently, as no complex information had to be collected. In addition, we have evaluated our dataset on several variants of neural networks and obtained noteworthy performance results, as discussed in the later sections. These outcomes reflect the quality of our gathered ASL dataset. Previously, most analysts exploited just one kind of neural model for obtaining maximum precision and accuracy [20].

We also tried TabTransformer- and gMLP-based recent state-of-the-art models, though they did not perform well with our data in some preliminary experiments. This is due to the two following reasons: (one) our dataset does not have any categorical features, as it contains only numeric features representing sensor values; and (two) we observed overfitting, which is obviously not desirable. To keep the classification and recognition process simple, we preferred to use a fully connected neural network (i.e., a multilayer perceptron, MLP).

We are aware that data gloves are not always well accepted by the Deaf community, for at least two reasons. The first, technical, limitation of data gloves is that they cannot capture articulators other than the hands [21,22]. For this reason, in our project we focus on numbers and alphabetical signs, which are signed using only the hands. A second, sociological, limitation of data gloves is that the burden of communication is taken only by the deaf person wearing the glove, producing a one-way, asymmetrical Deaf-to-non-Deaf communication and thus not solving the general problem of accessing speech. We believe that this second limitation is generally true, though in some specific situations data gloves can be used advantageously. For instance, gloves can be used as educational tools for SL learners. Moreover, in particular tasks, such as buying tickets in person, one can imagine a Deaf person using the glove to communicate the name of a city to a non-signing ticket seller by fingerspelling [23].

The remaining paper is structured as follows: a literature review is presented in Section 2. Section 3 focuses on the methodology. Materials and methods are discussed in Section 4. The results of the implementation are discussed in Section 5, and Section 6 presents the conclusions.

#### **2. Literature Review**

Accurate identification and classification of sign gestures is always a challenging task for researchers in this domain. Many different techniques and methodologies have been adopted to perform this task, and different strategies have been adopted for capturing and classifying posture data. Keeping in mind the major aspects of sign language, the reviewed studies are categorized into three main domains: sensor-based recognition models, vision-based recognition models, and hybrid recognition models.

Sensor-based recognition models purely focus on one or a combination of different types of sensors. For data acquisition, flex sensors, gyroscope sensors, accelerometer sensors, contact sensors, optical sensors, or inertial motion sensors have been used [24]. Authors have used the aforementioned sensors alone or in combination with other sensors to capture sign data [25]. Some authors have also worked on EEG signals, capturing brain data as analog signals and then converting the analog data into digital form for machine training [26–29]. Some authors have also used commercial data gloves purpose-built for capturing gesture data; in this scenario, the point of using an off-the-shelf commercial data glove is to increase the accuracy and efficiency of an already-developed model [30]. Some authors in this domain have also focused on regional languages, e.g., Pakistani, Italian, Indian, Arabic, Russian, Chinese, Taiwanese, and Persian SLs [31–35]. This is considered a more challenging task, as no predefined dataset is available for regional languages, and authors must collect their own datasets, often for very few postures [36,37]. The advantage of sensor-based prototypes is that they are easily worn and carried in public. The resultant data are normally displayed on an LCD or transmitted to a mobile or computer screen via a Bluetooth module [38,39].

To conclude the literature discussion, our model improves on the literature models for the following reasons. First, the majority of authors have focused on just a single sort of SL information, for example, numbers or alphabets. Some have focused on both, but did not cover the complete ASL domain due to posture and sensor intricacies. In contrast, we cover all ASL numbers and letters, as well as a blend of numeric and alphabetic information. Second, with an increased number of sensors, overall framework effectiveness and precision have generally not produced impressive outcomes, whereas our model uses a careful blend of two kinds of sensors that gave us the best outcomes with excellent precision and effectiveness. Third, many authors report only the machine learning or neural network model that gives them the best outcomes. By contrast, we have tested our manually collected dataset on different neural models, and it performed very well in all of them, which reflects the quality of our model and data. A point-by-point accuracy comparison is presented in tabular form in the results and discussion section.

#### **3. Methodology**

In sign language recognition, there is a chain of concatenated tasks, starting from capturing posture data with the help of an assistive glove to identifying the resulting values. For the development of the assistive glove, five flex sensors and a gyroscope sensor are used. It is a property of the flex sensor to produce a resistance value based on the bending performed while making gestures. The sensors attached to each finger and to the palm of the hand provide the values for a single sign posture. A user wearing the assistive glove makes ASL sign gestures, and the resulting sensor values are captured and analyzed with the help of a microcontroller. The prototype design is a combination of a microcontroller and sensors. The purpose of the microcontroller in the assistive glove is to capture the sensor values and transmit them to the processing unit, i.e., a computer or server. These collected values are preprocessed and then stored in a database or file with the help of the Parallax microcontroller data acquisition add-on tool for Microsoft Excel (PLX-DAQ). The core functionality of PLX-DAQ is transmitting the sensor values coming through the microcontroller via serial communication directly into an Excel file. This is the point where dataset generation is performed, by collecting all sensor values into a local or online server-based file. The processed data are forwarded to a neural network for training. Once the model is trained, it is tested on new incoming data to analyze its performance. The complete methodology is shown in Figure 1, which displays the neural network-based classification process for digits; the alphabetic and alphanumeric neural models work in the same way.
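To make the acquisition path concrete, here is a minimal sketch in Python of reading sensor lines over serial and logging them to a CSV file; the port name, baud rate, and line format are assumptions for illustration, not the authors' actual PLX-DAQ configuration:

```python
import csv
import serial  # pyserial

# Assumed port and baud rate; the actual values depend on the Arduino setup.
PORT, BAUD = "/dev/ttyUSB0", 9600

# Each serial line is assumed to hold comma-separated sensor readings:
# five flex values plus three gyroscope axes, followed by a posture label.
with serial.Serial(PORT, BAUD, timeout=1) as link, \
        open("asl_postures.csv", "w", newline="") as out:
    writer = csv.writer(out)
    writer.writerow(["flex1", "flex2", "flex3", "flex4", "flex5",
                     "gx", "gy", "gz", "label"])
    for _ in range(200):  # 200 samples per posture, as in the paper
        line = link.readline().decode("ascii", errors="ignore").strip()
        if line:  # skip empty reads caused by timeouts
            writer.writerow(line.split(","))
```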

Neural network-based implementation of sign language recognition requires data in numeric format. The preprocessed data are used as input to train a fully connected neural network. Based on the patterns of sensor values, gesture classification is performed for the thirty-seven sign postures. A scaled conjugate gradient back propagation algorithm is used, which proved helpful in achieving high accuracy.

#### **4. Materials and Methods**

Materials are the connected components used collectively for capturing sign postures. In our developed assistive glove, we have used flex sensors, a gyroscope sensor, resistors, and an Arduino microcontroller as materials, and we have used a neural network with a scaled conjugate gradient back propagation algorithm as the method to classify postures made while wearing the assistive glove. A brief description of the materials and methods is given in the sections below.

#### *4.1. Hardware Components*

#### 4.1.1. Flex Sensor

A flex sensor is also known as a bending sensor. Its internal structure is based on a phenolic resin substrate with conductive ink deposits, which produces increased resistance when the sensor is bent to some angle. A flex sensor works on the principle of the voltage divider rule, where Vin is the input voltage, Vout is the final output voltage, R1 and R2 are fixed resistances, and Rflex is the resistance of the flex sensor, as shown in Equation (1):

$$V_{\text{out}} = V_{\text{in}} \left[ \frac{R_1}{R_2 + R_{\text{flex}}} \right] \tag{1}$$

The resistance of the flex sensor is directly proportional to its bending: the greater the bending, the higher the resistance of the material. The physical shape of the sensor consists of two pins. When interconnecting with the microcontroller, as shown in Figure 2, one pin is connected to an analog pin of the microcontroller, and the other pin is connected to the ground. To avoid voltage overflow, a small fixed resistor is also connected to the pin of the flex sensor. In our assistive glove, we have used five flex sensors and five resistors connected to these flex pins.
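As an illustration of Equation (1), the sketch below recovers the flex sensor's resistance from a raw 10-bit ADC reading; the 5 V supply and 10 kΩ fixed resistor are assumed values, not the exact components of the glove:

```python
V_IN = 5.0        # supply voltage (assumed)
R_FIXED = 10_000  # fixed divider resistor in ohms (assumed)

def adc_to_voltage(adc: int, bits: int = 10) -> float:
    """Convert a raw ADC reading (0..1023 for 10 bits) to volts."""
    return V_IN * adc / (2 ** bits - 1)

def flex_resistance(adc: int) -> float:
    """Invert the voltage divider V_out = V_in * R_fixed / (R_fixed + R_flex)
    to estimate the flex sensor's resistance from the measured V_out."""
    v_out = adc_to_voltage(adc)
    return R_FIXED * (V_IN - v_out) / v_out

# A low ADC reading corresponds to a strongly bent, high-resistance sensor.
print(f"{flex_resistance(150):.0f} ohms")
```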

**Figure 2.** Flex sensor, gyroscope, and Arduino-based data glove designed for capturing ASL-based gestures.

#### 4.1.2. MPU 6050

A gyroscope sensor is a three-axis smart sensor device that assists in capturing object orientation. Regarding the SLR model, not all standard American sign gestures can be captured with flex sensors alone; this is due to the nature of sign gestures. The number-based sign gestures have no posture overlapping. Considering the ASL alphabet postures, however, this becomes a complicated classification problem with 26 classes in which posture overlapping occurs: certain signs cannot be distinguished without capturing hand orientation. A gyroscope sensor is therefore utilized in this experiment to capture sign orientations.

Hand orientations made in any direction are captured as three-axis numeric values; the directional information is captured as the angle changes. The orientation-based or directional change in representing any letter is captured with the assistance of three parametric values, namely the x-axis, y-axis, and z-axis. The complete prototype design is shown in Figure 2.

#### 4.1.3. Arduino Microcontroller

For processing the input data from the sensors, an ATmega328P-based Arduino microcontroller is used. This microcontroller has both analog and digital pins. Its 10-bit ADC produces values ranging from 0 to 1023, and it operates at a 16 MHz frequency. The Arduino has 32 KB of flash memory and 2 KB of RAM for quick data processing. It can be powered by a 5 V DC battery or through the USB port of a computer. When interconnecting with the flex sensors, the five sensor pins are connected to the five analog ports of the Arduino, and the common ground of the Arduino is attached to all the second pins of the flex sensors. A simple interconnection of the flex sensor and the Arduino microcontroller is shown in Figure 3.

**Figure 3.** A simple interconnection of the flex sensor with the Arduino microcontroller.

#### *4.2. Dataset Generation*

For the implementation of SL classification, we have used a self-collected dataset based on the flex and gyroscope sensor values. For this experiment, we created three separate datasets: numeric ASL with 11 classes (numbers 0 to 10), alphabetic ASL with 26 classes (letters A to Z), and alphanumeric ASL with 37 classes (0–10 and A–Z). Every SL pose has 200 examples gathered from 9 distinct male and female volunteers, 24 to 26 years of age. All datasets were gathered under ordinary lab conditions. The dataset size for each variant is the number of sign posture classes multiplied by the number of samples gathered per posture (e.g., 37 × 200 = 7400 samples for the alphanumeric dataset). Each dataset is further split into training, validation, and testing sets for the neural implementation.

#### *4.3. Neural Network Architecture*

The classification of sign gestures is usually considered a complex task. In our experiment, we have used fully connected bilayered and trilayered neural networks with 5 inputs and 11 outputs for the digit dataset, as shown in Figure 1, and, similarly, 8 inputs with 26 and 37 outputs for the alphabet and alphanumeric datasets, respectively. After the input layer, the second layer is the hidden layer and the third is the output layer. The preprocessed training data are fed into the network through the input layer, and the resulting classified data are analyzed through the output layer. All the statistical information of the bilayered and trilayered neural models is listed in Table 1.
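A minimal sketch of such networks in Python with scikit-learn, on placeholder data standing in for the glove readings; note that scikit-learn does not ship a scaled conjugate gradient solver, and the layer width of 25 is an illustrative choice, so this approximates the paper's setup rather than reproducing it:

```python
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.neural_network import MLPClassifier

# Placeholder data: 11 digit classes, 200 samples each, 5 flex features.
rng = np.random.default_rng(0)
X = rng.normal(size=(11 * 200, 5))
y = np.repeat(np.arange(11), 200)
X_tr, X_te, y_tr, y_te = train_test_split(X, y, test_size=0.2,
                                          stratify=y, random_state=0)

# "Bilayered" = one hidden layer; "trilayered" = two hidden layers.
for name, hidden in [("bilayered", (25,)), ("trilayered", (25, 25))]:
    net = MLPClassifier(hidden_layer_sizes=hidden, activation="relu",
                        max_iter=1000, random_state=0)
    net.fit(X_tr, y_tr)
    print(f"{name}: test accuracy = {net.score(X_te, y_te):.3f}")
```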


**Table 1.** Statistical information of different variants of neural models on digit, alphabetic, and alphanumeric datasets.

#### *4.4. Scaled Conjugate Gradient Back Propagation Algorithm*

We consider the scaled conjugate gradient (SCG) back propagation algorithm for implementing back propagation. Compared with other algorithms, it is computationally fast and does not require a line search after each iteration. Equation (2), given below, is the mathematical notation of the SCG algorithm, where E(w) is a global error function that depends on the biases and weights associated with the neural network; E(w) is calculated with one forward pass and E′(w) with one backward pass of the neural network iteration. On each iteration, the optimal distance is measured, which leads to a better search direction for gradient computation, as in Equation (3). In Equation (3), p<sub>k</sub> is the pattern presented to the network as a weight vector during training, and a<sub>k</sub> denotes the step size, which regulates the indefiniteness of the Hessian matrix.

$$E(w + y) = E(w) + E'(w)^{\mathsf{T}} y + \frac{1}{2} y^{\mathsf{T}} E''(w) y \tag{2}$$

$$y_{k+1} = y_k + a_k p_k \tag{3}$$

The complete operational pipeline of the proposed model starts with the prototype design. The purpose of making a new data glove is twofold: (one) it is possible to capture all static sign postures with the help of only two sensors, which makes the computational model less complex and faster; and (two) it allows analysis of the neural model's performance on less complex data samples, i.e., whether it classifies correctly or tends toward underfitting or overfitting. While capturing signs, in-between transitions occurred when the signer switched from one posture to another. To cope with this problem, we adopted a dual conditional approach: we first checked the orientation of each finger for each ASL posture and then analyzed the hand orientation for each individual posture. We then set the minimum and maximum range for each sensor to obtain the label of each posture made by the signer. When a posture matches the sensor-value ranges, the microcontroller outputs the numeric or alphabetic label, e.g., 1, 2, 3 or A, B, C. When there is no match, we get −1 as noise, which was filtered out during dataset formation.
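The dual conditional labeling step can be sketched as follows; the calibration ranges are hypothetical and merely illustrate the min/max matching described above:

```python
# Hypothetical calibration table: per posture, (min, max) for each flex sensor.
POSTURE_RANGES = {
    "1": [(700, 820), (300, 420), (300, 420), (300, 420), (300, 420)],
    "A": [(300, 420), (700, 820), (700, 820), (700, 820), (700, 820)],
}

def label_posture(readings, ranges=POSTURE_RANGES):
    """Return the posture label whose ranges contain every reading,
    or -1 (noise) when no posture matches, as described in the text."""
    for label, bounds in ranges.items():
        if all(lo <= r <= hi for r, (lo, hi) in zip(readings, bounds)):
            return label
    return -1

print(label_posture([760, 350, 350, 400, 380]))  # "1"
print(label_posture([500, 500, 500, 500, 500]))  # -1 (transition noise)
```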

#### **5. Results and Discussion**

Sign language recognition, being an emerging and challenging domain, requires efficient and accurate findings. The results obtained after the successful implementation of the discussed models are illustrated in detail in this section. Statistical information on the neural models used for classification and recognition is listed in Table 1, including the preset, the number of fully connected layers, the first layer size, the activation function used, the maximum number of iterations, the prediction speed, the accuracy, and the training time. Since different variants of neural networks are used in the implementation, statistical information for each neural model is included in the table. Apart from the different neural models, three different datasets are also used: digits, alphabets, and alphanumeric. A comprehensive description of each dataset is reported below.

#### a. *Number dataset*

The number dataset contains sensor information for eleven distinct postures, incorporating data for the numbers 0 to 10; hence this is an 11-class problem. Training the neural network produces performance plots for training, validation, and testing, which provide information on the epochs and the cross entropy of the model being trained. The blue line indicates training, the green line validation, the red line testing, and the dotted line highlights the best performance of the model. The best validation performance for digits is 9.1511 × 10<sup>−7</sup> at the 59th epoch, as shown in Figure 4a. For digit classification, only the flex sensors are utilized; therefore, the value ranges of the five flex sensors are listed on the y-axis, and the sign gestures for the 11 ASL number postures are displayed on the x-axis of Figure 4b. Each color represents one flex sensor attached to the prototype.

**Figure 4.** Training, validation, and testing performance plot (**a**) along with flex sensors values plot (**b**) for number dataset.

#### b. *Alphabet dataset*

The alphabet dataset contains sensor information for twenty-six distinct postures, incorporating data for the letters A to Z; hence this is referred to as a 26-class problem. The training, validation, and testing plot of the alphabetic neural network is shown in Figure 5a, with the best validation performance of 1.2097 × 10<sup>−6</sup> at the 62nd epoch. For alphabet classification, a combination of flex sensors and an accelerometer/gyroscope sensor is utilized; therefore, the value ranges of the five flex sensors and the three-axis accelerometer/gyroscope are listed on the y-axis, and the sign gestures for the 26 ASL letter postures are displayed on the x-axis of Figure 5b. Each color represents one sensor value of the prototype.


**Figure 5.** Training, validation, and testing performance plot (**a**) along with flex, accelerometer, and gyroscope sensors values plot (**b**) for the alphabets dataset.

#### c. *Alphanumeric dataset*

The alphanumeric dataset contains sensor information for thirty-seven distinct postures, incorporating data for the letters A to Z and the numbers 0 to 10; hence this is referred to as a 37-class problem. The training, validation, and testing plot of the alphanumeric neural network is shown in Figure 6a, with the best validation score of 1.6671 × 10<sup>−6</sup> at the 102nd epoch. For alphanumeric sign classification, the same combination of flex sensors and accelerometer/gyroscope sensor is utilized; therefore, the value ranges of the five flex sensors and the three-axis accelerometer/gyroscope are listed on the y-axis, and the sign gestures for the 37 alphanumeric ASL sign postures are displayed on the x-axis of Figure 6b. Each color represents one sensor value of the prototype.

Activation functions play a very important role in updating the weights of the neural nodes during training, and choosing the most appropriate activation function helps in achieving good accuracy and training results. The authors therefore analyzed the impact of activation functions on the performance of the neural networks by using three different activation functions, i.e., ReLU, Tanh, and Sigmoid. Replicating the same experiment while changing the activation function results in different accuracies, as shown in Figure 7. The experiment was repeated six times: the three activation functions were applied to the bilayered neural network, shown in Figure 7a, and then to the trilayered neural network, shown in Figure 7b. The analysis shows that for the bilayered neural network, ReLU has the highest accuracy for all formats of the dataset, i.e., number, alphabetic, and alphanumeric; Tanh stands second, and Sigmoid lags due to the mathematical behavior of the function. The same holds for the trilayered neural network model: ReLU performs very well, providing the best results for the number, alphabetic, and alphanumeric datasets; Tanh stands second, and Sigmoid comes last in this comparison. All these model values are also listed in Table 1.
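A sketch of this activation comparison in Python with scikit-learn, where "tanh" and "logistic" stand for the paper's Tanh and Sigmoid; the data are random placeholders, so the printed accuracies are not the paper's results:

```python
import numpy as np
from sklearn.model_selection import cross_val_score
from sklearn.neural_network import MLPClassifier

# Placeholder sensor data: 37 classes, 50 samples each, 8 features.
rng = np.random.default_rng(1)
X = rng.normal(size=(37 * 50, 8))
y = np.repeat(np.arange(37), 50)

for activation in ("relu", "tanh", "logistic"):
    clf = MLPClassifier(hidden_layer_sizes=(25,), activation=activation,
                        max_iter=500, random_state=0)
    score = cross_val_score(clf, X, y, cv=3).mean()
    print(f"{activation:8s} mean CV accuracy: {score:.3f}")
```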

**Figure 6.** Training, validation, and testing performance plot (**a**) along with flex sensors values plot (**b**) for the alphanumeric dataset.

**Figure 7.** Activation function impact on the accuracy of the neural model.

Accuracy comparison is a good way of assessing the developed model's progress. Considering the literature on gesture classification, we have compared the reported results with ours. Table 2, given below, lists the algorithmic performance of the literature models together with their accuracies and reference numbers. Comparing our results (in bold) with the literature, it is clear that our model performed very well in all evaluated aspects, i.e., accuracy, speed, and training time.


**Table 2.** Literature review-based accuracy analysis and comparison.

For experimental and educational purposes, these types of assistive technologies play a vital and effective role in society. In experimentation, the focus of researchers is mainly on computational speed, model performance, prototype cost, and recognition response. However, prototypes for real-time recognition or translation of sign postures must also deal with social factors, i.e., enabling two-way communication rather than putting the burden of communication on the Deaf only. In this respect, sign-to-speech (S2S) assistive technologies address only half of the communication problem faced by Deaf people.

Similarly, dealing with regional languages, e.g., Italian, Spanish, etc., requires considerable experimental and analytical work, since sign gestures vary from region to region. Even considering just one regional language, it is not possible to capture and translate all of its postures with a data glove alone: data gloves can only capture hand movements, not arm, head, articulation, and other body movements. If we consider increasing the number of sensors to capture all movement types, then it would be very unrealistic to go out in public with a body full of sensors. These are some of the challenges and future directions associated with our implementation that can lead researchers to think and work accordingly.

#### **6. Conclusions**

In this paper, a neural network-based model for sign language recognition was proposed, in which an assistive glove was designed and implemented for capturing real-time data and compiling it into a dataset. Among the different domains of gesture classification, we focused on a purely sensor-based implementation of standard ASL postures. The assistive glove was used to collect a dataset with 200 samples each for 11 numbers, 26 letters, and 37 alphanumeric sign postures. Fully connected bilayered and trilayered neural networks were used to classify the eleven, twenty-six, and thirty-seven isolated static sign gestures, trained with a scaled conjugate gradient back propagation algorithm on the self-collected datasets. The impact of the activation function on model performance was also analyzed. The successful implementation of the model achieved promising training and testing accuracy for the numeric, alphabetic, and alphanumeric datasets.

However, our self-generated dataset covers only a small portion of the static gestures used by the American Sign Language community. In the future, representative samples of all ASL signs will be collected using this glove, and other models will be trained to perform the recognition.

**Author Contributions:** Conceptualization, M.S.A.; Methodology, M.S.A. and S.T.H.R.; Software, M.S.A.; Writing—original draft, M.S.A.; Writing—review & editing, S.T.H.R., A.M. and L.A.; Supervision, S.T.H.R., A.M. and L.A. All authors have read and agreed to the published version of the manuscript.

**Funding:** This research received no external funding.

**Conflicts of Interest:** The authors declare no conflict of interest.

#### **References**


**Disclaimer/Publisher's Note:** The statements, opinions and data contained in all publications are solely those of the individual author(s) and contributor(s) and not of MDPI and/or the editor(s). MDPI and/or the editor(s) disclaim responsibility for any injury to people or property resulting from any ideas, methods, instructions or products referred to in the content.

### *Article* **A Clustered Federated Learning Method of User Behavior Analysis Based on Non-IID Data**

**Jianfei Zhang \* and Zhongxin Li**

School of Computer Science and Technology, Changchun University of Science and Technology, Changchun 130000, China

**\*** Correspondence: jfzhang@cust.edu.cn

**Abstract:** Federated learning (FL) is a novel distributed machine learning paradigm that can protect data privacy in distributed machine learning. Hence, FL provides new ideas for user behavior analysis. User behavior analysis can be modeled using multiple data sources. However, differences between data sources can lead to different data distributions, i.e., non-identically and non-independently distributed (Non-IID) data. Non-IID data usually introduce bias into the training process of FL models, which affects model accuracy and convergence speed. In this paper, a new federated learning algorithm is proposed to mitigate the impact of Non-IID data on the model, named federated learning with a two-tier caching mechanism (FedTCM). First, FedTCM clusters similar clients based on their data distribution; clustering reduces the extent of Non-IID data between clients in a cluster. Second, FedTCM uses asynchronous communication to alleviate the problem of inconsistent computation speeds across clients. Finally, FedTCM sets up a two-tier caching mechanism on the server to mitigate Non-IID data between different clusters. On multiple simulated datasets, FedTCM achieves accuracy up to 15.8% higher (12.6% on average) than the method without the federated framework, and up to 2.3% higher (1.6% on average) than the typical federated method FedAvg. Additionally, FedTCM achieves better communication performance than FedAvg.

**Keywords:** federated learning; Non-IID; user behavior; user modeling

#### **1. Introduction**

As the computer field has boomed, users generate a variety of behavior data while surfing the Internet, such as video-clicking behavior, shopping behavior, and more. In recent years, deep learning techniques have been used to uncover the hidden information behind such behavioral data. It is well known that the predictive power of deep learning models relies on training data. However, with the increasing emphasis on user privacy, it is becoming more difficult to collect and share data across organizations, thus creating isolated data islands. Moreover, the owners of some highly sensitive data may object to the unrestricted use of such data. In this environment, how to deal with fragmented data and isolated data islands is becoming a primary problem in the field of machine learning.

To simultaneously achieve privacy protection and train models using data, FL is proposed. FL aims to build a federated model based on global data. Edge devices with fragmented data participate in training models using their data while keeping their data secure. In this case, the client does not transmit their local data but rather the parameters of the model trained using the local data. At the end of the training, all clients will obtain a model that meets the requirements.

However, user behavior data are influenced by age, gender, lifestyle, and other factors, so the local data on the client side are likely to be non-independently and non-identically distributed. Due to the "data does not move, the model moves" characteristic of FL, the central server cannot directly operate on the local data of each client. The authors of [1] proposed a


framework for FL and designed the FedAvg algorithm. The FedAvg algorithm performs excellently on identically and independently distributed (IID) data. However, Non-IID data degrade its performance: as the degree of Non-IID increases, there is a significant decrease in model accuracy. Local training in FL relies on stochastic gradient descent, and only IID local data ensure that the local stochastic gradient is an unbiased estimate of the global gradient. Gradient differences cause the global model to be unduly influenced by the client-side local models. Hence, differences in data distribution between clients are a critical issue for federated learning.

Recently, various approaches attempting to mitigate the impact of Non-IID data on FL models have been proposed. One class of methods applies adaptive optimization. In [2–4], these methods limit the update magnitude of the local model by adding proximal terms, reducing the gap between the local model and the global model. However, all these methods require adjusting the ratio parameter of the proximal term, and different ratio parameters affect the accuracy and convergence speed of the model. Another class of approaches tries to adjust the data distribution across clients to improve FL model performance. In [5,6], clients share a part of their data to build a public dataset, which reduces the data distribution differences across clients. In [7,8], data augmentation is performed for classes with fewer samples to balance the data of each client. However, data-sharing methods leak privacy, especially in the field of user behavior analysis; user data are supposed to be highly confidential, and sharing them violates the original intent of federated learning to protect privacy.

As the Internet grows, users may interact with multiple devices, so user behavior data are stored on multiple devices. The devices involved in training may be cell phones, tablets, or resource-restricted edge devices. As a result, during model training the devices have different computational speeds, which is called system heterogeneity. In [1–3], the client and server use synchronous communication, which has the advantage of high accuracy and fast convergence. However, since clients compute at different speeds, with synchronous communication the slowest device prolongs the overall training time. The existing approach [2] drops the slow training devices, but this only alleviates the waiting problem, and randomly dropping devices destroys the integrity of the overall data. Asynchronous communication solves the problem of different client computing speeds: each participant communicates directly with the server after its local update is complete rather than waiting for the others. Due to this communication pattern, each client can only aggregate its model with the server individually, and the aggregation weights are difficult to adjust. A series of methods for setting aggregation weights in asynchronous communication was proposed in [9,10], but these methods still lag behind synchronous communication.

We propose an FL algorithm based on a two-tier cache mechanism (FedTCM) that mitigates the impact of Non-IID data on the model and accounts for the different computational speeds of devices. First, FedTCM clusters all clients based on the similarity of their data distribution, and the clients within a cluster can optimize their model parameters; this process does not involve operating on the clients' private data, so it is more secure. Second, due to differences in device computing speed, the cluster mediators communicate asynchronously with the server. Finally, FedTCM sets up a two-tier cache structure on the server side, where the first-tier cache stores the latest model parameters of each cluster; in this way, FedTCM avoids the difficulty of adjusting aggregation weights in traditional asynchronous communication. In the second-tier cache, we use a random distribution strategy to train the models on different clusters; if the training order is not considered, the models in the second-tier cache will contain data information from all clusters. The impact of Non-IID data on the FL model is thus reduced by the random distribution strategy.

The main contributions of this paper are summarized as follows:

- We propose FedTCM, an FL algorithm that clusters clients by the similarity of their local data distributions, reducing the degree of Non-IID data within each cluster without operating on private data.
- We adopt asynchronous communication between cluster mediators and the server to alleviate the system heterogeneity caused by differing client computation speeds.
- We design a two-tier caching mechanism on the server that avoids the aggregation-weight adjustment problem of asynchronous communication and, via a random distribution strategy, mitigates the impact of Non-IID data between clusters.
#### **2. Related Work**

To solve the impact of Non-IID data, various methods have been proposed. The existing work falls into four main categories: data augmentation, cluster-based multi-model learning, adaptive optimization methods, and personalized federated learning.

#### *2.1. Data Augmentation*

The authors of [5] found that the difference between local and global data distributions reduces the accuracy of the model and proposed an approach based on data sharing. Experiments show that sharing 5% of the data per client improves accuracy by 30%. The literature [7] uses the SMOTE oversampling technique to generate new samples for categories with few samples. Similarly, the authors of [11] use a conditional generative network to generate new samples. These methods reduce the heterogeneity of local client data by way of data expansion. However, they involve operating on clients' local data, which violates the original intent of FL to protect privacy.

#### *2.2. Cluster-Based Multi-Model Learning*

The literature [12] proposed a multi-center approach in which the local models and center locations are iteratively optimized by modifying the loss function of the local clients; clients with similar distances to the same center are divided into groups, and the clients in a group train the same model. The authors of [13] propose a novel clustering approach in which clustering does not require the participation of all clients, avoiding the additional communication overhead of clustering. The literature [14] proposed FedAMP, an attention-based FL algorithm. Unlike the previous approaches, FedAMP does not put a client into one fixed cluster: by introducing an attention mechanism, a client can maintain multiple cluster models at the same time. These methods perform well under Non-IID data, as clients with different data distributions do not interfere with each other. However, in real-world settings, multiple models do not generalize across all clients; our goal is to train one FL model shared by all clients.

#### *2.3. Adaptive Optimization*

The literature [2] proposed adding proximal terms on the client side; the proximal term limits the gap between the global model and the local models. The authors of [3] proposed an FL framework based on gradient correction, in which a gradient correction term limits the extent of model updates. The study [4] proposed estimating global device knowledge separately using local control variables and server control variables. The literature [15] found that when only a small number of devices participate in training, model accuracy decreases significantly, and proposed the momentum-based algorithm FedAvgM to improve accuracy. FedAvgM speeds up model training and improves accuracy by adjusting the gradient. These methods perform well in improving FL model convergence; however, it is difficult to tune the ratio parameter of the proximal term in them.

#### *2.4. Personalized Federated Learning*

The literature [16] introduced the concept of life-long learning, treating clients with different distributions as different tasks; after several training rounds, all the tasks are combined into a global task. The authors of [17] introduced the FL of taskonomy to efficiently aggregate heterogeneous data by learning task correlations between clients. The study [18] proposed a delayed aggregation method: after the local clients finish training, the server collects all model parameters and then sends them to other clients for further training, so that multiple clients train the same model to mitigate the impact of heterogeneous data. The literature [19] proposed the ARUBA framework, which is based on online learning and implements adaptive meta-learning under FL; when combined with FedAvg, it can improve model generalization performance.

In this work, we propose an FL algorithm based on a two-tier caching mechanism called FedTCM. FedTCM clusters similar clients, and clients within the same cluster perform local training synchronously. Clusters send model parameters to the server asynchronously to overcome the heterogeneity of the system. On the server side, we propose a caching mechanism to solve the impact of heterogeneous data on the model. Therefore, FedTCM is a combination of centralized and decentralized federated learning methods.

#### **3. The Principle of FedTCM**

To reduce the impact of Non-IID data on the performance of the FL model, we design an FL algorithm based on a two-tier cache mechanism (FedTCM). The main parameters of the system are summarized in Table 1.


**Table 1.** Notation and definitions.

Figure 1 shows the structure of FedTCM. Although we cannot rearrange the data of each client, we can rearrange the clients themselves. In FedTCM, we first cluster all the clients based on the similarity of their data distributions. A model trained jointly by clients with similar data distributions is better than a model trained by a single client, and the degree of heterogeneity between clients within a cluster is reduced. Clients in the same cluster can optimize their models and reduce the number of communications with the server during training, ensuring system stability. After clustering is complete, the server sends an initial model to each cluster. The cluster mediators share the initial model with each client within the cluster, and the local clients start training. The local clients send their model parameters to the cluster mediator after training is complete. The cluster mediators aggregate the model parameters of all clients within the cluster and send the model to the server through asynchronous communication, which effectively alleviates the problem of system heterogeneity.

**Figure 1.** Structure of FedTCM. FedTCM mainly includes clusters and a server. Clients with different colors denote that the clients have different data distributions. The cluster consists of clients and a cluster mediator. The server contains a first-tier cache list *m* and a second-tier cache list *n*.

At the same time, we design a two-tier caching mechanism on the server side. Each cluster corresponds to one location in the first-tier cache and one in the second-tier cache. When a cluster's computation is completed, its model is sent to the server, which updates the corresponding locations in the first-tier and second-tier caches. The first-tier cache stores the latest model of each cluster and solves the difficulty of adjusting aggregation weights in asynchronous communication: in traditional asynchronous communication, the server must aggregate a single model's parameters with the model parameters held on the server, and the aggregation weights affect the performance of the model. Since FedTCM sets up the first-tier cache, after a cluster's computation is completed, FedTCM only needs to aggregate all the models in the first-tier cache, avoiding the aggregation of single model parameters with server model parameters. The second-tier cache stores the aggregated model parameters corresponding to the first-tier cache and uses a model random distribution strategy to mitigate the impact of Non-IID data on the model. Inspired by the literature [18], we believe that a model trained on all the clusters' data will be more reliable than a model from a single cluster if the training order of the data is not considered. Therefore, FedTCM lets the models in the second-tier cache be trained with data from different clusters. In this way, the models in the second-tier cache carry more information about the clusters' data, and their degree of heterogeneity is lower than that of models trained in a single cluster. The second-tier cache is thus used to mitigate the impact of Non-IID data on the model.

#### *3.1. Clustering the Clients Based on the Distribution of Local Data*

Various existing methods [12,20–22] use the model or the computational speed of the device in the clustering process. However, all these methods require pre-training before clustering. We propose to use the counts of sample labels from each client for clustering. Before local training begins, each client counts its local data and calculates the percentage of each label type according to the number of labels in the sample data; we call this the label count vector. FedTCM clusters clients based on the similarity of their label count vectors. This approach has two advantages: first, sending a label count vector to an honest server does not leak the client's data; second, unlike previous clustering methods, no pre-training is needed before clustering. The label count vector is one-dimensional, and its size is negligible compared to the model. Before training starts, clients send their local label count vectors to the server, and the server clusters all clients according to the cosine similarity of the label count vectors. The cosine similarity is calculated using Equation (1).

$$q = \frac{V_i \cdot V_j}{\|V_i\| \, \|V_j\|} = \frac{\sum_{h=1}^{n} v_{i,h} \, v_{j,h}}{\sqrt{\sum_{h=1}^{n} (v_{i,h})^2} \, \sqrt{\sum_{h=1}^{n} (v_{j,h})^2}}, \tag{1}$$

where *V<sub>i</sub>* and *V<sub>j</sub>* represent the label count vectors of clients *i* and *j*, respectively; *v<sub>i,h</sub>* and *v<sub>j,h</sub>* represent the components of the vectors *V<sub>i</sub>* and *V<sub>j</sub>*; and *q* denotes the cosine similarity between the clients.

The server calculates the cosine similarity between all clients. If the cosine similarity of the label count vectors of several clients exceeds a set threshold, they are placed in the same cluster. The data within a cluster approximately conform to the same distribution.
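A minimal sketch of this clustering step in Python; the similarity threshold and the greedy first-member grouping are illustrative assumptions rather than the paper's exact procedure:

```python
import numpy as np

def cosine_similarity(u, v):
    """Equation (1): cosine similarity of two label count vectors."""
    u, v = np.asarray(u, float), np.asarray(v, float)
    return float(u @ v / (np.linalg.norm(u) * np.linalg.norm(v)))

def cluster_clients(label_counts, threshold=0.9):
    """Greedily group clients whose label count vectors exceed the
    similarity threshold against a cluster's first member."""
    clusters = []
    for cid, vec in enumerate(label_counts):
        for cluster in clusters:
            if cosine_similarity(label_counts[cluster[0]], vec) >= threshold:
                cluster.append(cid)
                break
        else:
            clusters.append([cid])
    return clusters

# Three clients, counts over four labels; clients 0 and 2 are similar.
counts = [[90, 5, 3, 2], [2, 4, 88, 6], [80, 9, 6, 5]]
print(cluster_clients(counts))  # [[0, 2], [1]]
```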

#### *3.2. The Process of Training on Clients*

In the training process, the main tasks of the client are receiving the initial model from the cluster mediator, training the model using local data, and uploading the model parameters to the cluster mediator.

In the process of local client training, the objective of the local client is defined as follows:

$$\min F = \frac{1}{n_{b,j}} \sum_{k \in P_{b,j}} l_k(x_k, y_k; \omega), \tag{2}$$

where *l<sub>k</sub>* is the loss function for the data sample {*x<sub>k</sub>*, *y<sub>k</sub>*}, *x<sub>k</sub>* is the features of the data sample, and *y<sub>k</sub>* is its label. The subscript *b*,*j* denotes client *j* in cluster *b*, and *P<sub>b,j</sub>* is the dataset of client *j*. Let *n<sub>b,j</sub>* = |*P<sub>b,j</sub>*|. *F* is the local empirical risk of the device on the local data *P<sub>b,j</sub>*. The loss function *l<sub>k</sub>* is the cross-entropy loss, defined as follows:

$$l_k = -\sum_{d} y_d \log a_d, \tag{3}$$

where *d* indexes the dataset labels, *y<sub>d</sub>* is the one-hot encoding of *y<sub>k</sub>*, and *a<sub>d</sub>* is the corresponding component of the model output. To minimize the loss during client training, the local client iteration (taking the stochastic gradient method as an example) is:

$$\omega_{b,j}^{t+1} = W^t - \eta \nabla F, \tag{4}$$

where *W<sup>t</sup>* is the model parameters sent by the cluster mediator to client *j* in round *t*, *ω<sup>t+1</sup><sub>b,j</sub>* is the local model computed by the client based on *W<sup>t</sup>*, and *η* represents the learning rate. When the local client has finished computing, it sends the local model parameters to the cluster mediator.
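A compact NumPy sketch of Equations (2)–(4), with a softmax classifier standing in for the client model; the toy data and learning rate are illustrative assumptions:

```python
import numpy as np

def softmax(z):
    z = z - z.max(axis=1, keepdims=True)  # numerical stability
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def local_update(W, X, y_onehot, lr=0.1):
    """One local step: cross-entropy loss (Eq. 3) averaged over the client's
    data (Eq. 2), then a gradient step from the received parameters (Eq. 4)."""
    a = softmax(X @ W)                        # model outputs a_d
    loss = -np.mean(np.sum(y_onehot * np.log(a + 1e-12), axis=1))
    grad = X.T @ (a - y_onehot) / len(X)      # gradient of F w.r.t. W
    return W - lr * grad, loss

# Toy client data: 100 samples, 5 features, 3 classes.
rng = np.random.default_rng(0)
X = rng.normal(size=(100, 5))
y = rng.integers(0, 3, size=100)
Y = np.eye(3)[y]
W = np.zeros((5, 3))                          # parameters W^t from the mediator
for _ in range(20):
    W, loss = local_update(W, X, Y)
print(f"final local loss: {loss:.3f}")
```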

#### *3.3. The Major Tasks of Cluster Mediator*

The cluster mediator has two main tasks in the process of model training. One is to receive the initial model from the server and send it to each client in the cluster. The other is to receive the model parameters of all clients in the cluster, aggregate them, and send the aggregated parameters to the server. The cluster mediator collects the local models and aggregates them using Equation (5).

$$\omega_b^{t+1} = \sum_{j} \frac{n_{b,j}}{n_b} \, \omega_{b,j}^{t+1}, \tag{5}$$

where *ω<sup>t+1</sup><sub>b,j</sub>* represents the model parameters of client *j* in cluster *b* at round *t* + 1, and *ω<sup>t+1</sup><sub>b</sub>* represents the aggregated parameters of cluster *b* at round *t* + 1. The cluster mediators send *ω<sup>t+1</sup><sub>b</sub>* to the server after the aggregation is completed. Due to differences in client computation speed, the cluster mediator communicates with the server asynchronously.
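Equation (5) is a sample-size-weighted average, as the following sketch shows; the client models and sample counts are hypothetical:

```python
import numpy as np

def aggregate_cluster(models, sample_counts):
    """Equation (5): weight each client's parameters by its share
    n_{b,j} / n_b of the cluster's samples."""
    counts = np.asarray(sample_counts, dtype=float)
    weights = counts / counts.sum()
    return sum(w * m for w, m in zip(weights, models))

# Two clients in the cluster, holding 300 and 100 samples respectively.
m1, m2 = np.full((2, 2), 1.0), np.full((2, 2), 5.0)
print(aggregate_cluster([m1, m2], [300, 100]))  # every entry is 2.0
```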

After the clients communicate with the cluster mediator, the mediator immediately aggregates all the received models. Since the data distribution of all clients in the same cluster is consistent, interference from other clients is avoided when aggregating model parameters, and the mediator's model parameters are more representative of the overall data distribution of the cluster. Only the cluster mediators, rather than every client, need to communicate with the server to cover all client information.

We measure the communication performance of FedTCM in terms of the number of times parameters are sent and received. FedTCM puts clients with similar data distributions into the same cluster, so the model of a cluster mediator represents the data of all clients within the cluster, and the number of cluster mediators is much smaller than the number of clients. In addition, FedTCM uses asynchronous communication and the two-tier cache mechanism: when aggregating model parameters, the server does not need to communicate with all clients, but only with some of the cluster mediators. Likewise, each cluster mediator only needs to communicate with the clients in its cluster to complete the local computations. Therefore, with the two-tier caching mechanism and cluster mediators, FedTCM reduces the number of times the server receives and sends model parameters.

#### *3.4. The Major Tasks of the Server*

We designed a two-tier caching mechanism on the server side. The first-tier cache is used to avoid the problem of adjusting the aggregation parameters for asynchronous communication. The second-tier cache uses a random distribution strategy to mitigate the impact of Non-IID data on the model.

In the previous section, we mentioned that the clusters communicate with the server asynchronously. However, asynchronous communication has a disadvantage: aggregation weights are difficult to adjust when a single incoming model is aggregated with a model stored on the server. The first-tier cache is used to avoid this impact of aggregation weights in asynchronous communication.

The first-tier cache structure is shown in Figure 2. The cluster mediator of cluster $b$ sends model parameters $\omega_b^{t+1}$ to the server. The server replaces the stale model parameters $\omega_b^t$ of cluster $b$ with the latest parameters $\omega_b^{t+1}$, which means the first-tier cache always stores the latest model of every cluster. FedTCM uses Equation (6) to aggregate all model parameters in the first-tier cache and stores the aggregated parameters in the corresponding location of the second-tier cache.

$$W\_b^{t+1} = \frac{1}{S} \sum\_{i=1}^{S} m[i], \tag{6}$$

where $S$ is the number of clusters and $m$ is the first-tier cache list. Although FedTCM uses asynchronous communication, it avoids aggregating a single model's parameters directly with the server model parameters.

**Figure 2.** Update principle of the first-tier cache.

The second-tier cache uses a random model distribution strategy, illustrated in Figure 3. After the first-tier aggregation is completed, the aggregated model parameters $W_b^{t+1}$ are stored in the second-tier cache. The server then randomly selects one model parameter set in the second-tier cache to send to cluster $b$ (in Figure 3, the second-tier cache entry of cluster $S$ is selected). Cluster $b$ receives the model parameters and starts the next round of training. We aggregate all model parameters in the second-tier cache to obtain the global model; the global model does not participate in training and can be obtained at any time. This approach allows the models in the second-tier cache to be trained on the data of different clusters, so each model in the second-tier cache carries more information about the clusters' data. The random distribution strategy helps reduce the global loss, and the models in the second-tier cache gradually approach the optimal global model. The heterogeneity among the models in the second-tier cache is lower than that of a single cluster model, so a simple aggregation of the second-tier cache can reduce the impact of Non-IID data on the model.

**Figure 3.** Update principle of the second-tier cache. The second-tier cache sends the model to cluster *b* using a random distribution.
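The server-side logic of Figures 2 and 3 can be summarized in a small sketch. This is an illustrative reconstruction, not the authors' code: `TwoTierCache` is a hypothetical class, and it assumes every cluster has reported at least once so that the first-tier average of Equation (6) is well defined.

```python
import random

class TwoTierCache:
    """Illustrative server-side cache: m (first tier) holds the latest
    model of each cluster; n (second tier) holds the aggregate produced
    when that cluster last reported."""
    def __init__(self, num_clusters):
        self.S = num_clusters
        self.m = [None] * num_clusters   # first-tier cache
        self.n = [None] * num_clusters   # second-tier cache

    def on_cluster_update(self, b, omega_b):
        # Figure 2: replace the stale parameters of cluster b
        self.m[b] = omega_b
        # Equation (6): average the first tier (assumes all slots filled)
        W_b = {k: sum(m_i[k] for m_i in self.m) / self.S for k in omega_b}
        self.n[b] = W_b
        # Figure 3: random distribution strategy -- return a randomly
        # chosen second-tier model for cluster b's next round
        return self.n[random.randrange(self.S)]

    def global_model(self):
        # the global model is the plain average of the second-tier cache
        return {k: sum(n_i[k] for n_i in self.n) / self.S for k in self.n[0]}
```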

#### *3.5. The Process of FedTCM*

FedTCM is illustrated in Figure 4. The clients in cluster $b$ send their model parameters to the cluster mediator (step 1). The cluster mediator aggregates the received parameters and sends the aggregated model parameters $\omega_b^{t+1}$ to the server; asynchronous uploading is used here (step 2). The server updates the model parameters in the first-tier cache (step 3). The server aggregates all model parameters in the first-tier cache and stores the aggregated result $W_b^{t+1}$ in the corresponding second-tier cache slot (step 4). The server randomly selects a model in the second-tier cache and sends it to cluster $b$ (step 5). The cluster mediator sends the model to the clients; intuitively, cluster $b$ is thus trained on model parameters from different clusters (step 6).

**Figure 4.** The process of FedTCM.

The pseudocode of FedTCM is shown in Algorithm 1. Line 3 represents our clustering method based on cosine similarity. Lines 6–7 represent the update and aggregation of the first-tier cache, where $m$ is the first-tier cache list used to store the latest model of each cluster. Lines 8–9 represent our random distribution strategy, where $n$ is the second-tier cache list used to store the aggregated model parameters from the first-tier cache. Lines 13–16 represent the receiving and sending of models by the cluster mediators. Lines 18–22 represent the local training process on the clients.

#### **Algorithm 1.** FedTCM.

**Input:** *cluster*() is the cosine similarity clustering algorithm
**Output:** global model $g^t$
1: **server process:**
2: Before training starts, receive label count vectors $v_i$
3: $C = cluster(v_i)$
4: **for** $t = 0, 1, \ldots, T$ **do:**
5: Receive model $\omega_b^{t+1}$ from cluster $b$
6: $m[b] = \omega_b^{t+1}$
7: $W_b^{t+1} = \frac{1}{S}\sum_{i=1}^{S} m[i]$
8: $n[b] = W_b^{t+1}$
9: Send $n[k]$ to cluster $b$, $k \in random(0, S)$
10: **end for**
11: $g^t = \frac{1}{S}\sum_{i=1}^{S} n[i]$
12: **cluster mediator:**
13: Receive $W^t$ from the server; send $W^t$ to the clients in the cluster
14: Receive $\omega_{b,j}^{t+1}$ from the clients in the cluster
15: Aggregate the collected parameters: $\omega_b^{t+1} = \sum_j \frac{n_{b,j}}{n_b}\,\omega_{b,j}^{t+1}$
16: Send $\omega_b^{t+1}$ to the server
17: **client device:**
18: Receive $W^t$ from the cluster mediator
19: **for** local iteration **do:**
20: local update $\omega_{b,j}^{t+1} \Leftarrow W^t - \eta \nabla F$
21: **end for**
22: Send the updated model $\omega_{b,j}^{t+1}$ to the cluster mediator

#### **4. Experiment and Results**

In this section, we introduce the experiments and analyze the simulation results to verify the performance of FedTCM. In our experiments, there are $N_a$ clients, $S$ mediators, and one central server. The data similarity of the clients is used to divide the $N_a$ clients into different clusters, and in each cluster a mediator is chosen at random from among its clients, giving $S$ mediators in total. We present the results of our experiments on user shopping behavior data and user sports behavior data.

For every experiment on user shopping behavior data, we used the following hyperparameters: SGD as the optimization method (learning rate $\eta = 0.1$; epochs $e = 5$; number of clients participating in training $N_a = 20$). Cosine similarity is used to measure the similarity of client data, with threshold $p = 0.98$.

For every experiment on user sports behavior data, we used the following hyperparameters: SGD as the optimization method (learning rate $\eta = 0.03$; epochs $e = 5$; batch size $B = 3$; number of clients $N_a = 10$). The cosine similarity threshold is $p = 0.98$.

#### *4.1. Dataset and Pre-Processing*

In recent years, online shopping has become the most convenient way to shop, and the huge amount of user behavior data can support various large training tasks. Exploring the hidden information can reduce the recommendation cost of an e-commerce platform and provide great convenience for online shopping in practical applications. The data we selected contain multiple user characteristics and shopping behaviors. After removing null and erroneous values from the dataset, we selected 10 categories of items. This dataset is highly heterogeneous and non-uniform. Users exhibit four types of behavior: browse, like, add to cart, and buy. The task is to forecast a user's behavior from the known features.

Moreover, we evaluated our experiments on the user sports behavior dataset. The smart devices that users carry contain various sensors; they require no active user setup and can record various sports behaviors (including walking, running, etc.). The data we selected contain a variety of user characteristics and exercise habits, so we use the user sports behavior data to predict the user's physical fitness. A user's physical condition reflects their lifestyle habits, and smart devices can send exercise reminders and customize personalized exercise programs based on physical fitness. Hence, user sports behavior analysis is highly relevant.

#### *4.2. Federated Data Splitting*

The goal of FedTCM is to mitigate the impact of Non-IID data on the model. Therefore, we set up Non-IID data when dividing the dataset to represent this problem. We use the Dirichlet distribution to generate Non-IID data for each client. The Dirichlet distribution is also known as the multivariate Beta distribution; its density function is given in Equation (7):

$$\operatorname{Dir}(X \mid \alpha) = \frac{1}{B(\alpha)} \prod\_{i=1}^{d} X\_i^{\alpha\_i - 1}, \tag{7}$$

where $\alpha = \{\alpha_1, \ldots, \alpha_d\} > 0$ and $B(\alpha) = \frac{\prod_{i=1}^{d} \Gamma(\alpha_i)}{\Gamma(\sum_{i=1}^{d} \alpha_i)}$. We can sample the dataset according to the Dirichlet distribution: the parameter $\alpha$ controls the sampling probability of each class of labels. In this way, given the data volume of each client, the number of samples of each label can be calculated.
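A common way to realize such a split, shown below as a NumPy sketch under our own assumptions, is to draw per-client proportions from $\mathrm{Dir}(\alpha)$ for each label class and partition that class's samples accordingly; `dirichlet_split` is a hypothetical helper, and smaller $\alpha$ yields more heterogeneous clients.

```python
import numpy as np

def dirichlet_split(labels, num_clients, alpha, seed=0):
    """Sketch of the Non-IID split: for each class, draw client
    proportions from Dir(alpha) (Equation (7)) and deal that class's
    samples out accordingly."""
    rng = np.random.default_rng(seed)
    client_idx = [[] for _ in range(num_clients)]
    for c in np.unique(labels):
        idx = rng.permutation(np.where(labels == c)[0])
        p = rng.dirichlet(alpha * np.ones(num_clients))  # per-client share of class c
        cuts = (np.cumsum(p)[:-1] * len(idx)).astype(int)
        for client, part in enumerate(np.split(idx, cuts)):
            client_idx[client].extend(part.tolist())
    return client_idx
```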

#### *4.3. Baseline Algorithm*

We choose the following algorithms as baselines:

NonFed: supervised learning tasks are executed on 20 devices, but no federated framework is deployed on these clients.

FedAvg: an FL algorithm with SGD; the global model is aggregated using a weighted average.

#### *4.4. Results and Discussion*

First, we analyze the accuracy of FedTCM compared with NonFed and FedAvg on the user shopping behavior data. As FedTCM is an asynchronous framework, the three methods cannot be compared at the same round, so we must set a common time basis. We use the system time on the central server, and the time baseline is the time consumed by FedAvg running 220 rounds, by which point FedAvg has almost converged. FedTCM and NonFed run within the same time slice. To simulate the real-world situation, we assume that clients compute at different speeds, so we assign each client a different training speed via a delay time. The delay time is controlled by the parameter $ct$: it sets the calculation time of the slowest device, and the calculation time of the other devices is $t \in [1, ct]$. FedTCM and FedAvg account for the effect of device computational speed; NonFed does not, as it executes on a single client. In our experiments, FedAvg runs the fewest rounds in a fixed time, followed by FedTCM, and NonFed runs the most.
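For illustration, the delay assignment described above could be simulated as follows; `assign_delays` is a hypothetical helper, and the uniform draw over $[1, ct]$ is an assumption about how the per-client times are sampled.

```python
import numpy as np

def assign_delays(num_clients, ct, seed=0):
    """Give each client a per-round computation time t in [1, ct];
    one client is pinned to ct so the slowest device takes exactly ct."""
    rng = np.random.default_rng(seed)
    delays = rng.uniform(1.0, ct, size=num_clients)
    delays[rng.integers(num_clients)] = float(ct)  # ensure one slowest client
    return delays
```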

Figure 5 shows the accuracy of the three methods with different degrees of data heterogeneity when batch size $B = 3$ and $ct = 2$. In Figure 5a, when $\alpha = 10$, the data heterogeneity is low and the model is less affected. The accuracy of NonFed fluctuates slightly from start to finish, and FedTCM is 10.1% more accurate than NonFed. The federated methods perform better in the case of low heterogeneity, and their accuracy curves are smooth with almost no fluctuations; FedTCM is 1.2% more accurate than FedAvg. In Figure 5b, the degree of data heterogeneity increases, and the accuracy curves of all three methods fluctuate to different degrees, with NonFed fluctuating most visibly; NonFed is thus more sensitive to the degree of data heterogeneity. Among the federated methods, FedAvg converges more slowly with a slight decrease in accuracy. FedTCM is not affected by Non-IID data in the convergence phase, although its accuracy fluctuates more visibly in the initial training phase; its accuracy is 10.3% and 1.1% higher than NonFed and FedAvg, respectively. In Figure 5c, the degree of data heterogeneity is further increased. Affected by the heterogeneous data, the accuracy of NonFed drops significantly. The federated methods are also significantly impacted, yet the accuracy of FedTCM is still 10.4% higher than NonFed. Comparing the two federated methods, the accuracy fluctuations of FedAvg are more pronounced and its accuracy slightly reduced; FedTCM performs better, with accuracy 1.1% higher than FedAvg. Regardless of the degree of heterogeneity, FedTCM always achieves the highest accuracy among the three methods.

**Figure 5.** Comparing FedTCM with the baseline algorithms on the shopping behavior dataset at batch size *B* = 3, *ct* = 2. (**a**) FedTCM and baseline algorithms are all set *α* = 10. (**b**) FedTCM and baseline algorithms are all set *α* = 5. (**c**) FedTCM and baseline algorithms are all set *α* = 3.

As the degree of data heterogeneity changes, NonFed does not converge, and the accuracy of FedTCM is on average 10.2% higher than NonFed. Because NonFed can only utilize the dataset of one client, it has the worst generalization ability; thus, the federated methods mitigate the impact of Non-IID data better than NonFed. FedTCM and FedAvg use data from 20 devices under the federated framework; FedTCM converges in 350 system time and FedAvg in 380 system time. Due to its simple parameter aggregation method, FedAvg is susceptible to the effects of Non-IID data. Compared with FedAvg, FedTCM is more accurate and converges faster. These experiments verify the effectiveness of the two-tier caching mechanism: FedTCM performs better than the baseline algorithms on the Non-IID shopping behavior data.

As FedTCM uses the two-tier cache mechanism and asynchronous communication between the server and the cluster mediators, the server does not need to communicate with all clients but only with a subset of the cluster mediators, and it is not held back by slow clients. Hence, FedTCM takes less time to run one round. Compared to FedAvg, FedTCM shortens the computation time of each round and correspondingly runs more rounds in the specified time. As shown in Figure 5, if we use 65% as the target accuracy, FedTCM takes 157 system time to reach it while FedAvg takes 228 system time when $\alpha = 10$. When $\alpha = 5$, FedTCM takes 172 system time and FedAvg takes 270 system time. When $\alpha = 3$, FedTCM requires 254 system time and FedAvg requires 302 system time. Compared with FedAvg, FedTCM requires less time to achieve the target accuracy.

Since the user shopping behavior dataset is highly heterogeneous, all three methods are affected by the batch size $B$. The simulation results with batch size $B = 5$ and $ct = 2$ are shown in Figure 6. In Figure 6a, compared to the results in Figure 5a, NonFed can achieve higher accuracy, but its accuracy fluctuates drastically, and the federated methods converge more slowly. The federated methods are more stable than NonFed, and the accuracy of FedTCM is 7.5% higher than NonFed and 1% higher than FedAvg. In Figure 6b, as the degree of data heterogeneity increases, the accuracy of NonFed drops significantly; NonFed is thus more sensitive to parameter changes. The accuracy of the federated methods also decreases, with FedAvg decreasing more. FedTCM remains the best performer among the three methods: its accuracy is 11.2% and 1% higher than NonFed and FedAvg, respectively. In Figure 6c, all three methods are affected to different degrees as the data heterogeneity increases. The accuracy of NonFed fluctuates severely. Although the accuracy of the federated methods slightly decreases, it is still higher than NonFed, especially for FedTCM, which is 10.8% more accurate than NonFed. The federated methods fluctuate drastically before 100 system time. Comparing the two federated methods, the accuracy of FedTCM is 1.2% higher than FedAvg.

The adjustment of batch size impacts all three methods: when the batch size increases, the model is more likely to fall into local minima, which affects the experimental results. In all experiments, as the degree of data heterogeneity increases, NonFed still does not converge and its accuracy fluctuates drastically; NonFed is sensitive to parameter changes and has the worst generalization performance. FedTCM starts converging at 380 system time and FedAvg at 400 system time, so the federated methods still perform better than NonFed. Compared to Figure 5, although the convergence rate of the federated methods is significantly lower, FedTCM outperforms FedAvg in both convergence speed and accuracy. FedTCM is the best suited to processing Non-IID data and has excellent generalization capability.

As in Figure 5, FedTCM takes less time to run one round and runs more rounds than FedAvg in the specified time. In Figure 6, using 65% as the target accuracy, FedTCM takes 257 system time to reach it while FedAvg takes 304 system time when $\alpha = 10$. When $\alpha = 5$, FedTCM takes 312 system time and FedAvg takes 396 system time. When $\alpha = 3$, FedTCM requires 358 system time and FedAvg requires 410 system time. Although we adjusted the experimental parameters, FedTCM still requires less time than FedAvg to achieve the target accuracy.

To verify the influence of client training speed on the experimental results, we designed a group of experiments with different values of $ct$. Figure 7 shows the comparison of FedTCM with the baseline method for different degrees of heterogeneity with $ct = 4$. Since NonFed does not consider the effect of device computation speed, we do not list its results. Consistent with the previous section, even after accounting for the computational speed of the devices, FedTCM outperforms FedAvg in both convergence speed and accuracy; the computational speed of the devices does not affect the effectiveness of FedTCM.

**Figure 6.** Comparing FedTCM with the baseline algorithm with different *α* at batch size *B* = 5, *ct* = 2. (**a**) FedTCM and baseline algorithms are all set *α* = 10. (**b**) FedTCM and baseline algorithms are all set *α* = 5. (**c**) FedTCM and baseline algorithms are all set *α* = 3.

In Figure 7, for the same reasons, FedTCM requires less time than FedAvg to achieve the target accuracy.

Figure 8 shows the accuracy of the three methods with different degrees of data heterogeneity when $ct = 2$ on the user sports behavior data. In Figure 8a, when $\alpha = 10$, all three methods achieve their highest accuracy. However, the accuracy curve of NonFed fluctuates drastically and does not converge; the accuracy of FedTCM is 14.4% higher than NonFed. In contrast, the federated methods handle heterogeneous data better, in terms of both accuracy and convergence speed. Compared to the typical federated algorithm FedAvg, the accuracy of FedTCM is 1.6% higher. In Figure 8b, as the degree of data heterogeneity rises, the accuracy of all three methods is affected and the accuracy curves fluctuate more significantly, especially NonFed's. The accuracy of FedTCM is 15.8% higher than NonFed, and despite the impact of heterogeneous data on FedTCM, its accuracy is 2.3% higher than FedAvg. In Figure 8c, the degree of data heterogeneity continues to rise, and the accuracy of all three methods is drastically affected. Nevertheless, FedTCM still achieves the highest accuracy among the three methods: 15.6% higher than NonFed and 2.2% higher than FedAvg.

On the user sports behavior data, as the degree of data heterogeneity increases, NonFed shows the largest drop in accuracy; hence, NonFed is more sensitive to data changes. In Figure 8, the accuracy of the federated methods is much higher than NonFed, so the federated methods are more suitable for heterogeneous data. The accuracy of FedTCM was, on average, 2.0% higher than FedAvg, and the fluctuations of FedTCM's accuracy curve were the smallest. FedTCM performs better than the baseline algorithms on the Non-IID sports behavior data.

**Figure 7.** Comparing FedTCM with FedAvg with different *α* at batch size *B* = 3, *ct* = 4. (**a**) FedTCM and FedAvg are all set *α* = 10. (**b**) FedTCM and FedAvg are all set *α* = 5. (**c**) FedTCM and FedAvg are all set *α* = 3.

**Figure 8.** Comparing FedTCM with the baseline algorithms on the sports behavior dataset at *ct* = 2. (**a**) FedTCM and baseline algorithms are all set *α* = 10. (**b**) FedTCM and baseline algorithms are all set *α* = 5. (**c**) FedTCM and baseline algorithms are all set *α* = 3.

Additionally, the comparison of computational time between FedTCM and FedAvg on the sports behavior data yields results similar to the experiments on the shopping behavior data: FedTCM requires less time than FedAvg to achieve the same target accuracy.

Tables 2 and 3 show the number of model parameters sent/received by the clients and the server with different degrees of heterogeneity and different device computation speeds on the user shopping behavior data. As the degree of heterogeneity increases, the number of model parameters sent/received by the server rises significantly. This occurs because changing the degree of heterogeneity of the data affects the number of clusters: the number of clusters is highest at $\alpha = 3$ compared with $\alpha = 5$ and $\alpha = 10$. The more clusters there are, the less each cluster is affected by slow clients and the more frequently it communicates with the server. Moreover, the server send/receive counts in Table 2 are significantly lower than in Table 3 because the computational speed of the devices affects the overall experiment time: the slower the devices, the longer the model takes to converge. Faster devices do not wait for slower ones and can therefore communicate with the server more frequently, so when $ct$ increases, the server send/receive count increases. The number of models sent/received by the local clients depends on the computation speed of all clients in the cluster; Tables 2 and 3 report the average per-device counts.

**Table 2.** Number of sent/received model parameters with different degrees of heterogeneity in FedTCM (*ct* = 2), \* denotes the result of averaging.


**Table 3.** Number of sent/received model parameters with different degrees of heterogeneity in FedTCM (*ct* = 4), \* denotes the result of averaging.


Table 4 illustrates the number of model parameters sent/received by FedTCM and FedAvg during a fixed experimental time on the user shopping behavior data. The FedTCM values in Table 4 are the averages of the values shown in Tables 2 and 3. Compared to FedAvg, although FedTCM increases the send/receive count on the local clients, it significantly reduces the send/receive count on the server. For local clients, the increase in the number of local training rounds is not a heavy burden; for central servers in large network structures, however, reducing the communication burden can effectively reduce data congestion and increase the efficiency of communication and computation. Hence, FedTCM also provides better communication performance than FedAvg.

**Table 4.** Number of model parameters sent/received by different methods, \* denotes the result of averaging.


From the client's perspective, FedTCM runs more rounds than FedAvg in the same amount of time, which demonstrates the effectiveness of the designed clustering mechanism and asynchronous communication. Under the same conditions, FedTCM is more robust to variations in device computation speed than FedAvg, so it can fully utilize local computing resources and execute more efficiently. From the server's perspective, FedTCM dramatically reduces the number of communications with the server: unlike typical FL algorithms, FedTCM does not require all clients to communicate with the server. Since each cluster trains its model parameters independently, the aggregated model of the cluster mediator can represent the data information of all clients in the cluster. FedTCM can reduce the communication burden on the server while improving the accuracy of the model.

#### **5. Conclusions**

In this work, we proposed an FL algorithm, FedTCM, based on a two-tier cache mechanism. FedTCM can reduce the impact of Non-IID data on user behavior modeling. Although the method without the federated framework can be trained, it cannot converge on the Non-IID dataset; compared with it, FedTCM exhibits outstanding performance on Non-IID data. On the user shopping behavior dataset, under different degrees of data heterogeneity, the accuracy of FedTCM is higher than NonFed by 11.2% at maximum, 7.5% at minimum, and 10% on average. On the user sports behavior dataset, the accuracy of FedTCM is higher than NonFed by 15.8% at maximum, 14.4% at minimum, and 15.2% on average. Therefore, FedTCM has better generalization ability on Non-IID data. On the user shopping behavior dataset, the accuracy of FedTCM is higher than FedAvg by 1.2% at maximum, 1% at minimum, and 1.1% on average; on the user sports behavior dataset, it is higher than FedAvg by 2.3% at maximum, 1.6% at minimum, and 2% on average. Meanwhile, FedTCM converges faster than FedAvg and provides better communication performance. In the convergence phase, the accuracy of the baseline algorithms becomes more volatile as $\alpha$ decreases, whereas FedTCM maintains a smoother accuracy curve.

The goal of our proposed approach is to use a single model to mitigate the impact of Non-IID data. A potential limitation is that although a single model is an optimal solution for the global task, with Non-IID data it is not optimal for every client task: even though the global model achieves the highest accuracy on the global data, it may not suit each client, because the local data distribution differs from the global one. In the future, we will consider federated learning approaches that generate multiple personalized models, created for each client from both the public knowledge of the other clusters and the specific knowledge of the current client, enhancing the generalizability of personalized models across different data distributions.

**Author Contributions:** Conceptualization, J.Z. and Z.L.; methodology, J.Z. and Z.L.; software, Z.L.; validation, J.Z. and Z.L.; formal analysis, J.Z. and Z.L.; investigation, J.Z. and Z.L.; resources, J.Z. and Z.L.; data curation, J.Z. and Z.L.; writing—original draft preparation, J.Z. and Z.L.; writing—review and editing, J.Z. and Z.L.; visualization, J.Z. and Z.L.; supervision, J.Z.; project administration, J.Z. and Z.L.; funding acquisition, J.Z. All authors have read and agreed to the published version of the manuscript.

**Funding:** This paper is supported by the project "User Behavior Features Oriented Research on Analysis of Multi-Source Data in CDN" (20200401082GX), which is financially supported by the Science and Technology Development Program of Jilin Province, China.

**Data Availability Statement:** Due to the nature of this research, participants of this study did not agree for their data to be shared publicly, so supporting data is not available.

**Conflicts of Interest:** The authors declare no conflict of interest.

#### **References**


**Disclaimer/Publisher's Note:** The statements, opinions and data contained in all publications are solely those of the individual author(s) and contributor(s) and not of MDPI and/or the editor(s). MDPI and/or the editor(s) disclaim responsibility for any injury to people or property resulting from any ideas, methods, instructions or products referred to in the content.

### *Article* **Deep Deformable Artistic Font Style Transfer**

**Xuanying Zhu 1, Mugang Lin 1,2,\*, Kunhui Wen 1, Huihuang Zhao 1,2 and Xianfang Sun <sup>3</sup>**


**Abstract:** The essence of font style transfer is to move the style features of an image into a font while maintaining the font's glyph structure. At present, generative adversarial networks based on convolutional neural networks play an important role in font style generation. However, traditional convolutional neural networks that recognize font images suffer from poor adaptability to unknown image changes, weak generalization ability, and poor texture feature extraction. When the glyph structure is very complex, stylized font images cannot be effectively recognized. In this paper, a deep deformable style transfer network is proposed for artistic font style transfer, which can adjust the degree of font deformation according to the style and realize multiscale artistic style transfer of text. The new model consists of a Sketch module for learning glyph mapping, a Glyph module for learning style features, and a Transfer module for fusing style textures. In the Glyph module, the Deform-Resblock encoder is designed to extract glyph features, in which a deformable convolution is introduced and the size of the residual module is changed to fuse feature information at different scales, better preserve the font structure, and enhance the controllability of text deformation. Therefore, our network has greater control over text, processes image feature information better, and can produce more exquisite artistic fonts.

**Keywords:** style transfer; generative adversarial networks; deformable convolutional networks; artistic font generation

#### **1. Introduction**

An artistic font is a beautifully deformed font based on traditional fonts [1], created from an artistic and decorative interpretation of the meaning, character shape, and structural features of the text. Because of their beautiful and eye-catching characteristics, artistic fonts are widely used in propaganda, advertising, trademarks, and other scenarios and are becoming increasingly popular among the public. Traditional artistic fonts are designed by professional font designers, so their quality is influenced by the designers' professional level and other factors. In recent years, with the advent and development of machine learning technology [2], deep learning methods have been applied to artistic font generation with better results.

Currently, the majority of image style transfer methods are based on convolutional neural networks (CNNs). These methods adjust a noisy random image using an optimization function so that the generated image maintains the content of a content image while keeping part of the style of the style image. Since artistic fonts can be viewed as beautiful images, image style transfer methods can also be applied to artistic font style transfer. However, the key to artistic font generation is to synthesize text texture and add colorful texture information to the target text. Compared with image style transfer methods, artistic font style transfer methods need to extract the edge features of texts more accurately to maintain the integrity of the font structure during stylization.

**Citation:** Zhu, X.; Lin, M.; Wen, K.; Zhao, H.; Sun, X. Deep Deformable Artistic Font Style Transfer. *Electronics* **2023**, *12*, 1561. https:// doi.org/10.3390/electronics12071561

Academic Editor: Donghyeon Cho

Received: 27 February 2023 Revised: 19 March 2023 Accepted: 24 March 2023 Published: 26 March 2023

**Copyright:** © 2023 by the authors. Licensee MDPI, Basel, Switzerland. This article is an open access article distributed under the terms and conditions of the Creative Commons Attribution (CC BY) license (https:// creativecommons.org/licenses/by/ 4.0/).

As CNNs adopt convolution kernels of a fixed shape and lack internal mechanisms to adaptively change that shape, it is difficult to extend them to new tasks with unknown, complex geometric transformations. For example, for visual recognition with fine localization, different locations need scales or receptive field sizes appropriate to them, and fixed convolutional kernels prevent CNNs from satisfying this requirement. For artistic font style transfer methods based on CNNs, both structural disjunction and stroke overlap will occur when the glyph structure is complex. In addition, some conventional CNNs treat background features as edge features of the text during feature extraction, which adds noise points to the image during glyph feature extraction, resulting in double shadows and style spillover in the style transfer process. These problems directly lead to stylized font images that cannot be accurately recognized.

Shape-matching GAN [3] is an effective model that can realize multiscale deformation of artistic fonts. Its encoder consists of a general CNN and a controllable module. The original controllable module is a double-branch network in which each branch has two convolutional layers; the kernel size of each convolutional layer is 3, so the receptive field of the resulting feature map is 5. Deep networks with residual connections at different depths perform better than shallow networks, but the more layers there are, the more overfitting occurs. In addition, features are lost in the convolution process, and a larger receptive field better preserves the integrity of the information. In summary, the inherent nature of general convolutions and the small receptive field of the controllable module prevent shape-matching GAN from recognizing complex fonts, resulting in unclear glyphs and style overflow.

From the above analysis, artistic font style transfer methods based on CNNs face two challenges: (1) how to accurately extract the edge features of texts to provide integral glyphs for generating artistic fonts; (2) how to eliminate the double shadows and style overflow caused by noise. In this paper, a novel artistic font generation network is proposed. To address the first challenge, an encoder for glyph generation is designed that introduces a deformable convolution [4,5], which can freely change the receptive field by adjusting the offsets of the sampling locations, thus improving its ability to handle the geometric variations of texts and letting it learn more complete glyph information. For the second issue, the difference between adjacent pixel values is computed as a smoothing loss, and the smoothness of image edges is maintained by reducing this loss. The contributions of this paper are summarized as follows:

(1) A deep deformable artistic font style transfer network (DAF) is proposed, consisting of a Sketch module for learning glyph mapping, a Glyph module for learning style features, and a Transfer module for fusing style textures. In the Glyph module, a Deform-Resblock encoder (DR encoder) is designed to extract glyph features, in which a dilated convolution and a deformable convolution are used to change the receptive field so that the encoder focuses on the most critical features. The deformable convolution also helps integrate feature information at different scales to ensure that the generated glyphs maintain their complete font structure.

(2) Ghosting is eliminated and the image is smoothed by introducing a smoothing loss function that reduces the difference between adjacent pixel values in the image.

(3) Comparing the proposed model with four current advanced artistic font style transfer methods, the experimental results show that the proposed model is effective and performs better.

The remainder of this paper is organized as follows. Section 2 reviews related work involving deformable convolutional networks, image style transfer, and font style transfer. Section 3 describes the proposed DAF model in detail. To evaluate the performance of our model, a series of experiments is conducted in Section 4. Finally, we summarize this paper in Section 5.

#### **2. Related Work**

#### *2.1. Deformable Convolutional Networks*

Research on CNNs [6,7] dates back to the neocognitron model proposed by the Japanese scientist Kunihiko Fukushima [8] in 1980. It was the first neural network to use convolution and downsampling, and it is the prototype of convolutional neural networks. In 1989, Yann LeCun [9,10] constructed one of the first CNNs for computer vision problems, the original version of LeNet, which used convolutional and pooling layers for the first time and achieved remarkable accuracy in handwritten character recognition tasks. As LeNet continued to be studied, its subsequent variants defined the basic structure of modern CNNs, and its success drew attention to the application of CNNs [11]. Recently, CNNs [12] have achieved significant success in visual recognition tasks, but they are limited in modeling large unknown transformations and lack internal mechanisms to handle geometric transformations, so they have difficulty with finely localized visual recognition of objects of different scales or deformations. Deformable convolutional networks [4,5] overcome these limitations by introducing a deformable convolution module and a deformable RoI pooling module to improve the transformation modeling capability of CNNs. In the deformable module, the grid sampling positions of the standard convolution are shifted by 2D offsets learned by an extra convolutional layer; deformable RoI pooling adds 2D offsets to each bin position of the preceding RoI pooling. Thus, the sampling and pooling of a deformable convolutional network can vary with an object's structure, allowing it to adjust its feature representation to the object's configuration. Currently, deformable convolutional networks are widely used in image processing [13–15], complex vision [16,17], pattern recognition [18,19], and other fields [20,21], where they show powerful performance.
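As a minimal sketch of this mechanism (not the implementation used in this paper), the offsets can be predicted by an ordinary convolution with $2 \times k \times k$ output channels and fed to `torchvision.ops.DeformConv2d`:

```python
import torch
import torch.nn as nn
from torchvision.ops import DeformConv2d

class DeformBlock(nn.Module):
    """Sketch of a deformable convolution: an extra ordinary conv
    predicts 2D offsets (dx, dy) for every kernel sampling point, and
    DeformConv2d samples the input at the shifted positions."""
    def __init__(self, in_ch, out_ch, k=3):
        super().__init__()
        self.offset = nn.Conv2d(in_ch, 2 * k * k, kernel_size=k, padding=k // 2)
        self.deform = DeformConv2d(in_ch, out_ch, kernel_size=k, padding=k // 2)

    def forward(self, x):
        return self.deform(x, self.offset(x))

# usage: x = torch.randn(1, 64, 32, 32); y = DeformBlock(64, 128)(x)  # (1, 128, 32, 32)
```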

#### *2.2. Image Style Transfer*

Image style transfer is the migration of a style image so that the input image acquires the style characteristics of the style image. In 2015, Gatys [22] proposed neural style transfer to facilitate image style transfer. However, neural style transfer has some drawbacks: the network must be trained for each migration, which is very slow and cannot achieve real-time transfer, and style migration on photos may introduce distortion. To address these problems, Johnson [23] proposed a fast neural style transfer method that trains one network per style image, so that only one forward pass is needed to obtain the generated image at test time given a content image. Luan et al. [24] proposed photo style transfer, which solved the photo distortion problem by improving the loss function. Generative adversarial networks (GANs) [25] are neural networks designed to solve the problem of generative modeling, in which the generative model learns to capture the statistical distribution of training data and synthesizes samples from the learned distribution. Consequently, GANs have become a prevalent method for image style transfer [26,27]. Conditional GANs (CGANs) [28] introduce image-to-image translation and a structured loss function so that networks learn not only the mapping from an input image to an output image, but also the loss function for training this mapping. This feature makes CGANs suitable for image generation problems. Because there are no content and style constraints on CGANs, their outputs are more similar to artistic creations, and the effect is significantly improved compared to other image generation models. In recent years, GANs have become a hot research direction, and more variant models have been proposed, such as CycleGAN [29], Wasserstein GAN (WGAN) [30], deep convolutional GAN (DCGAN) [31], and shape-matching GAN [3]. Compared with other methods, GANs produce richer artistic effects for image style transfer.

#### *2.3. Font Style Transfer*

Font style transfer [32], the process of extracting artistic features from images of a given style and integrating artistic characters into text images, is a long-standing research problem. Font synthesis translates a font from one domain to another, and the key to this process is predicting the shapes of the glyphs. Unlike font synthesis, font style transfer is the challenging problem of transferring the color and texture of artistic styles to new glyphs. The BAIR Lab at Berkeley collaborated with Adobe to design a multi-content GAN for font style transfer [33], developing a new decorative network to predict the color and texture of the final glyphs. Yang et al. [34] then studied dynamic artistic text style transfer with control over the degree of glyph stylization and proposed a novel bidirectional shape-matching framework for font style transfer. They introduced a scale-aware shape-matching GAN to learn glyph style shape mapping, model style shape features at multiple scales simultaneously, transfer them to target glyphs, and generate high-quality, controllable artistic text. Subsequently, Zhang et al. proposed a font effect generation model [35] based on pyramid-style features, building on Yang's work and using morphological operations to improve the transfer effect. Recently, a diverse transformation network for text style transfer [36] has been proposed, which can generate text images in multiple styles with a single model, allowing all styles to be trained effectively in one network.

#### **3. Artistic Font Generation Network**

Image style transfer migrates the style of an image to another image. Unlike image style transfer, font style transfer migrates the style of an image into the text of another image; thus, if an image style transfer method is copied directly, the structural characteristics of the text will be destroyed. Yang et al. [3] studied fast and controllable artistic text style transfer in terms of font deformation and proposed a shape-matching GAN for text style transfer. However, by repeatedly testing the shape-matching GAN model, we found that it is unable to extract clear font features for fonts with complex strokes, which leads to problems such as stroke adhesion, fuzzy edges, and severe deformation of the font. Currently, these problems are addressed by preprocessing the input image, but this method is time-consuming and difficult to implement. To overcome these limitations, in this section we propose a deep deformable artistic font style transfer network. Its key features are the DR encoder, designed to learn font features and extract more image information, and a smoothing loss introduced to preserve key edge details of the font images. Thus, the network is better able to extract features, control font deformation, and maintain the structure of complex glyphs.

#### *3.1. Overall Network Architecture*

The overall network architecture consists of three main components: (1) a Sketch module for learning glyph mapping, (2) a Glyph module for learning style features to generate deformed font, and (3) a Transfer module for style texture fusion. The network architecture is shown in Figure 1.

As shown in Figure 1a, the network training process is divided into three modules: the Sketch module, the Glyph module, and the Transfer module. In the Sketch module, we first process the style image and turn it into a mask, which can easily be obtained with image editing tools; we call this mask the structure map $X$ of the style image and use $X$ together with the style image as the input of the sketch module $G_B$. The sketch module $G_B$ is composed of a smoothing block, a transformation block, and a smoothing loss function. The smoothing block smooths the input image, and the transformation block maps the smoothed style image back to the text domain, so the edges of the style image can learn the edge features of the text image and realize the structural transformation. We first smooth the input style image mask $X$ with $G_B$ to weaken its edges, and we introduce a new loss function, the smoothing loss, to maintain the smoothness of the image so that the font can better learn the features of the style image we provide. We transform the style image into different degrees of deformation by adjusting the parameter $l$ ($l \in [0,1]$); after deformation, the mask is generated. In the Glyph module, we train the $G_S$ network. By clipping the mask, training pairs of sketch shapes with different smoothness can be obtained. The training pair is fed into the glyph network $G_S$, which is trained to map it back to the original $X$, so that it can characterize the shape features of $X$ and transfer these features to the target text. This increases data diversity and forces the model to learn more robust features, effectively improving the generalization ability of the model. Through the $G_S$ network, the font learns the style structure features and a deformed font mask is obtained. In the Transfer module, we train the network $G_T$, which is similar to training $G_S$: a style image mask $X$ and a style image are randomly cropped to form a training pair as the input of $G_T$. The network $G_T$ is trained to perform texture rendering instead; style transfer is applied to the input so that the deformed font acquires the style features of the style image.

**Figure 1.** Architecture of the deep deformable artistic font style transfer network.

Figure 1b shows the network testing process. $G_S$ learns the structural features of style images through training. Given a text mask image and a style mask image as input, the text image learns the corresponding style features and a deformed text mask image is generated. The deformed text mask is then input into $G_T$ for style texture transfer to obtain the final result.

#### *3.2. Glyph Networks (GS)*

The generator encoder of a generative adversarial network is generally a convolutional neural network consisting of convolutional layers, pooling layers, and batch normalization layers. The DR encoder is redesigned as shown in Figure 2. We first pad the input feature map to a specific size and then use dilated convolution to expand the network's receptive field on the feature map. Second, we downsample the feature map twice and shift the target features through deformable convolution to obtain more accurate edge features. Finally, the feature map is fed into the controllable deep residual network and linearly superimposed, and the corresponding feature map is output. By continuously learning and adjusting the sizes of the convolutional layers to find the most suitable depth for this network, the texture generation network retains as many complex font structure features as possible. Considering that pooling degrades the performance of generative models, the encoder uses stepwise (strided) convolution for downsampling. In addition, we use transposed convolution for feature upsampling to avoid checkerboard artifacts.

**Figure 2.** DR encoder structure diagram.
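To make the data flow of Figure 2 concrete, the following PyTorch sketch assembles the described stages; all layer widths and the residual block count are illustrative assumptions, not the paper's actual configuration.

```python
import torch
import torch.nn as nn
from torchvision.ops import DeformConv2d

class DREncoder(nn.Module):
    """Rough sketch of the DR encoder in Figure 2: a dilated conv widens
    the receptive field, two strided convs downsample (no pooling), a
    deformable conv refines edge features, and residual blocks are
    summed linearly at the end. Layer sizes are illustrative."""
    def __init__(self, ch=64, n_resblocks=4):
        super().__init__()
        self.dilated = nn.Conv2d(3, ch, 3, padding=2, dilation=2)
        self.down = nn.Sequential(              # stepwise (strided) convolution
            nn.Conv2d(ch, ch * 2, 3, stride=2, padding=1), nn.ReLU(),
            nn.Conv2d(ch * 2, ch * 4, 3, stride=2, padding=1), nn.ReLU())
        self.offset = nn.Conv2d(ch * 4, 2 * 9, 3, padding=1)  # 2 offsets per 3x3 tap
        self.deform = DeformConv2d(ch * 4, ch * 4, 3, padding=1)
        self.res = nn.ModuleList([
            nn.Sequential(nn.Conv2d(ch * 4, ch * 4, 3, padding=1), nn.ReLU(),
                          nn.Conv2d(ch * 4, ch * 4, 3, padding=1))
            for _ in range(n_resblocks)])

    def forward(self, x):
        h = self.down(torch.relu(self.dilated(x)))
        h = self.deform(h, self.offset(h))
        for block in self.res:
            h = h + block(h)                    # residual: linear superposition
        return h
```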

The glyph generation network generator consists of an encoder and a decoder. The encoder is crucial in the glyph network, as it determines whether the feature fusion process can maintain the glyph structure. To allow the network to effectively recognize font details, we design the DR encoder structure shown in Figure 2. The glyph generation network extracts the desired text and style features with the DR encoder and optimizes the training process by learning from a large number of samples. After training, the network has learnt the corresponding stylistic features and can directly perform stylization to generate font masks with those style features, which significantly reduces time and space complexity compared with other networks, making the application of style transfer practical.

The structure of the DR encoder is improved mainly by redesigning the residual module size and introducing deformable convolution. The convolutional layers in the encoder are responsible for acquiring local image features. In ordinary convolution, the receptive field is fixed by the size of the convolution kernel; it can be expanded only by enlarging the kernel or adding convolutional layers, which inevitably increases the number of parameters and computations of the network model and affects efficiency. Therefore, we use a dilated convolutional layer instead of a normal convolutional layer to expand the receptive field without changing the kernel size, increasing the network's attention to more features and obtaining more detailed information. In CNNs, the size of the receptive field is calculated by Equation (1):

$$g\_n = g\_{n-1} + (k\_n - 1) \prod\_{i=1}^{n-1} S\_i \tag{1}$$

where $g_n$ is the receptive field of layer $n$, $S_i$ is the stride of the $i$-th convolutional or pooling layer, and $k_n$ is the kernel size of layer $n$. Stacking layers according to Equation (1) makes the receptive field grow rapidly with depth.
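A small helper makes Equation (1) concrete; applied to the two stacked 3 × 3, stride-1 convolutions of the controllable module mentioned earlier, it reproduces the receptive field of 5.

```python
def receptive_field(kernels, strides):
    """Equation (1): g_n = g_{n-1} + (k_n - 1) * prod(S_1 .. S_{n-1})."""
    g, jump = 1, 1
    for k, s in zip(kernels, strides):
        g += (k - 1) * jump   # each layer adds (k_n - 1) * accumulated stride
        jump *= s
    return g

# two stacked 3x3 convs with stride 1 -> receptive field of 5:
assert receptive_field([3, 3], [1, 1]) == 5
```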

A dilated convolution has a hyperparameter, the dilation rate $r$, which represents the interval between kernel elements; the dilation rate of standard convolution is 1. We calculate $r$ through Equation (2).

$$r = 2^{\log\_2 \text{rate} + 2} - 1 \tag{2}$$

In our formula, *rate* defaults to 1. Dilated convolution increases the receptive field of the kernel while keeping the number of parameters constant, so that each convolution output covers a larger range of information, allowing feature targets and contextual information to be captured better. However, dilated convolution has a limitation for complex fonts: when a font has many strokes, too large a receptive field blurs detailed features. Therefore, to compensate for this insufficiency, we introduce deformable convolution in the encoder and use additional offsets to adjust the spatial sampling positions in the module, so that the convolutional layer can automatically adjust its scale or receptive field to best fit the image.

In addition, an adjustment of the direction vector of the convolution kernel is added to the traditional convolution so that the shape of the kernel shifts closer to the feature object. The convolution process is shown in Figure 3.

**Figure 3.** Schematic diagram of the joined deformable convolutional network.

CNNs extract feature maps, use those feature maps as input, and apply another convolutional layer to them. In Figure 3, an additional convolutional layer learns the offsets and shares the input feature map; its purpose is to obtain the offsets of the convolutional deformation. In Equation (3), an interpolation algorithm is used to learn the offsets, which are trained by backpropagation. What distinguishes deformable ConvNets is that they perform dense spatial transformations in a simple, efficient, deep, end-to-end manner. The deformable convolution introduces an offset $\Delta P_n$, usually fractional, for each point, generated from the input feature map by another convolution. $P_n$ enumerates the positions within the range of the convolution kernel relative to $P_0$ in Equation (3), and the value at a fractional position is computed by bilinear interpolation as in Equation (4).

$$y(P\_0) = \sum\_{P\_n \in \mathcal{R}} w(P\_n) \cdot x(P\_0 + P\_n + \Delta P\_n) \tag{3}$$

$$x(P) = \sum\_{q} \max(0, 1 - |q\_x - p\_x|) \cdot \max(0, 1 - |q\_y - p\_y|) \cdot x(q) \tag{4}$$

The subsequent ablation experiments also verify that introducing deformable convolution allows the encoder to better extract the features of text and style images and improves the control of font deformation. This innovation is key to solving the problem that feature extraction for complex fonts is inadequate and that font deformation severely damages the original font structure.
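For reference, the bilinear interpolation of Equation (4) can be written directly; this NumPy sketch evaluates a feature map at one fractional sampling position $P = P_0 + P_n + \Delta P_n$.

```python
import numpy as np

def bilinear_sample(x, p):
    """Equation (4): value of feature map x at fractional position p,
    as a weighted sum over the integer grid points q around p."""
    (h, w), (py, px) = x.shape, p
    y0, x0 = int(np.floor(py)), int(np.floor(px))
    val = 0.0
    for qy in (y0, y0 + 1):
        for qx in (x0, x0 + 1):
            if 0 <= qy < h and 0 <= qx < w:
                # max(0, 1 - |q - p|) in each axis, per Equation (4)
                wgt = max(0, 1 - abs(qy - py)) * max(0, 1 - abs(qx - px))
                val += wgt * x[qy, qx]
    return val
```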

In addition to improving the encoder with deformable convolution, we find that surface subdivision artifacts appear at the input of the residual module. These artifacts can carry edge features at the borders of the font that are not recognized by the network, which leads to style overflow after style transfer. To address this issue, we increase the depth of the residual module to further control the strength of font deformation.

#### *3.3. Transfer Networks (GT)*

For the $G_T$ module, we use the texture network structure of the shape-matching GAN [3] model. After the glyph generation network, we obtain a text mask image with the learned style features. Similar to the glyph network $G_S$, a large number of data pairs are obtained by cropping the style image and the text mask image, and these pairs are used to quickly train an end-to-end fast text style model, so that the style network can adapt to the shape of the text and quickly generate target images. The network can generate text images in multiple styles and easily control the style of the text. The main idea is that, by taking the deformed text images with style characteristics generated earlier as the input of the transfer network, we can select the style images to be transferred, and all text images can be trained effectively in the network to obtain the corresponding stylized images. The advantage of this network is that multiple text styles can be generated with a single model, and the generation of text styles is controllable.

#### *3.4. Loss Function*

The loss of the network $G_S$ contains a reconstruction loss and an adversarial loss. In the reconstruction loss, $l$ ($l \in [0,1]$) controls the degree of deformation; setting $l$ controls the font deformation and realizes multiscale style transfer. $x$ represents the structural sketch obtained after a binary transformation of the style image, and $y$ represents the raw style image. $\widetilde{x}_{l_i}$ represents the style structure image with degree of deformation $l_i$ obtained from the $G_B$ network. We use a mask image as an information guide to reconstruct the structure of the different style images; the reconstruction loss restores the structure of images at different degrees of deformation to the structure of the original. In the adversarial loss, we add the mask images to the generator and the discriminator, similar to the conditional GAN procedure.

$$\mathcal{L}\_S^{rec} = \sum\_{i=1}^{N} \mathbb{E}\_{x,l,mask}\left[ \left\| G\_S\left( \widetilde{x}\_{l\_i}, l, mask\_i \right) - x\_i \right\|\_1 \right] \tag{5}$$

$$\mathcal{L}\_S^{adv} = \sum\_{i=1}^{N} \mathbb{E}\_{x,mask}\left[\log D\_S(x\_i, mask\_i)\right] + \sum\_{i=1}^{N} \mathbb{E}\_{x,mask}\left[\log\left(1 - D\_S\left( G\_S\left( \widetilde{x}\_{l\_i}, l, mask\_i \right) \right)\right)\right] \tag{6}$$

The overall *GS* loss is as follows:

$$\mathcal{L}\_{G\_S} = \min\_{G\_S} \max\_{D\_S} \lambda\_{S}^{adv} \mathcal{L}\_{S}^{adv} + \lambda\_{S}^{rec} \mathcal{L}\_{S}^{rec} \tag{7}$$
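For readers who prefer code, a minimal PyTorch rendering of Equations (5)–(7) might look as follows. `G_S`, `D_S`, the tensors, and the loss weights are placeholders rather than the authors' implementation, and the adversarial term is written in the usual non-saturating logits form.

```python
# Hedged sketch of Equations (5)-(7): L1 reconstruction plus a conditional
# adversarial term. All names and weights here are illustrative assumptions.
import torch
import torch.nn.functional as F

def gs_objectives(G_S, D_S, x_tilde_l, l, mask, x, lam_adv=1.0, lam_rec=100.0):
    fake = G_S(x_tilde_l, l, mask)
    rec = F.l1_loss(fake, x)                                   # Eq. (5)
    d_real = D_S(x, mask)                                      # conditioned on mask
    d_fake = D_S(fake, mask)                                   # detach fake for the D step in practice
    # Eq. (6), non-saturating logits form
    adv_d = F.binary_cross_entropy_with_logits(d_real, torch.ones_like(d_real)) \
          + F.binary_cross_entropy_with_logits(d_fake, torch.zeros_like(d_fake))
    adv_g = F.binary_cross_entropy_with_logits(d_fake, torch.ones_like(d_fake))
    loss_g = lam_adv * adv_g + lam_rec * rec                   # Eq. (7), generator side
    return loss_g, adv_d
```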

The main task of the *GT* network is to assign texture features to the structural images obtained from *GS*. The loss of the network *GT* includes a reconstruction loss, a conditional adversarial loss, a style loss, and a texture loss. The style loss $\mathcal{L}\_T^{sty}$ was proposed in neural style transfer.

$$\mathcal{L}\_{T}^{rec} = \sum\_{i=1}^{N} \mathbb{E}\_{\mathbf{x}, \mathbf{y}, mask} \left[ \left\| G\_{T}(\mathbf{x}\_{i}, mask\_{i}) - \mathbf{y}\_{i} \right\|\_{1} \right] \tag{8}$$

$$\mathcal{L}\_{T}^{adv} = \sum\_{i=1}^{N} \mathbb{E}\_{\mathbf{x}, mask, \mathbf{y}}\left[\log D\_{T}(\mathbf{x}\_i, mask\_i, \mathbf{y}\_i)\right] + \sum\_{i=1}^{N} \mathbb{E}\_{\mathbf{x}, mask}\left[\log\left(1 - D\_{T}(G\_{T}(\mathbf{x}\_i, mask\_i))\right)\right] \tag{9}$$

The overall *GT* loss is as follows:

$$\mathcal{L}\_{G\_T} = \min\_{G\_T} \max\_{D\_T} \lambda\_{T}^{adv} \mathcal{L}\_{T}^{adv} + \lambda\_{T}^{rec} \mathcal{L}\_{T}^{rec} + \lambda\_{T}^{sty} \mathcal{L}\_{T}^{sty} + \lambda\_{T}^{tex} \mathcal{L}\_{T}^{tex} \tag{10}$$

The loss functions described above apply to our basic networks, *GS* and *GT*. In the sketch model, we first select a text image *t* as the base image and randomly select an *l* value within [0, 1] to reconstruct image *t*.

$$\mathcal{L}\_{B}^{rec} = \mathbb{E}\_{t,l}\left[\left\| G\_{B}(t,l) - t \right\|\_1\right] \tag{11}$$

After obtaining the reconstructed image *t*, we apply an adversarial loss to make the reconstructed image more similar to the original image.

$$\mathcal{L}\_{B}^{adv} = \mathbb{E}\_{t,l}\left[\log D\_{B}\left(t, l, \tilde{t}\_{l}\right)\right] + \mathbb{E}\_{t,l}\left[\log\left(1 - D\_{B}\left(G\_{B}(t, l), l, \tilde{t}\_{l}\right)\right)\right] \tag{12}$$

When the sketch model smooths the edge features of the target image, the recovery algorithm amplifies noise, so the image is not smoothed well: some features are lost, and additional noise is introduced into the image before it is input into *GS*. Consequently, when the output of *GS* is migrated for stylistic features, even a small amount of noise has a large impact on the result, producing shadows at the edges of the images; the proportion of images contaminated by noise is significantly larger than the proportion of noise-free images. Therefore, we design a new smooth loss by adding a regularization term to the sketch model to maintain the smoothness of the image. Reducing this loss suppresses differences between adjacent pixel values, which resolves the edge shading problem. The noise constraint is implemented at the cost of some image sharpness, which finally solves the problems of noise and poor edge smoothing in the image. The regularization term we add is the following equation.

$$\mathcal{R}\_{V^{\beta}}(\mathbf{x}) = \sum\_{i,j}\left(\left(\mathbf{x}\_{i,j-1} - \mathbf{x}\_{i,j}\right)^2 + \left(\mathbf{x}\_{i+1,j} - \mathbf{x}\_{i,j}\right)^2\right)^{\frac{\beta}{2}} \tag{13}$$

The overall sketch model loss is as follows:

$$\mathcal{L}\_{Smooth} = \min\_{G\_B} \max\_{D\_B} \lambda\_{B}^{adv} \mathcal{L}\_{B}^{adv} + \lambda\_{B}^{rec} \mathcal{L}\_{B}^{rec} + \mathcal{R}\_{V^{\beta}}(\mathbf{x}) \tag{14}$$
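Equation (13) is a total-variation-style regularizer over neighboring pixel differences. A direct NumPy rendering (our sketch, not the authors' code) is:

```python
# Direct NumPy rendering of Equation (13); beta trades smoothing strength
# against edge sharpness (beta = 2 gives the classic quadratic penalty).
import numpy as np

def tv_regularizer(x: np.ndarray, beta: float = 2.0) -> float:
    dh = (x[:, 1:] - x[:, :-1]) ** 2   # horizontal neighbour differences
    dv = (x[1:, :] - x[:-1, :]) ** 2   # vertical neighbour differences
    # combine the two squared differences at each interior pixel
    return float(np.sum((dh[:-1, :] + dv[:, :-1]) ** (beta / 2.0)))

x = np.random.rand(8, 8)
print(tv_regularizer(x))
```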

#### **4. Experiment**

#### *4.1. Dataset*

We use the TE141K dataset [37], which contains 152 professionally designed text effects rendered on glyphs, including English letters, Chinese characters, and Arabic numerals. The dataset is divided according to an 8:2 ratio into 608 training images and 152 test images. This dataset is one of the largest font style transfer datasets to date and can be used in research areas such as font style migration, multidomain transfer, and image-to-image translation.

#### *4.2. Training*

Our model consists of the sketch module *GB*, the glyph module *GS*, and the transfer module *GT*, so we divide the training strategy into three steps. Before training starts, images are randomly cropped to 256 × 256. We use the Adam optimizer with a learning rate of 0.0002 and train for 3 epochs.

First, we train the sketch module *GB*, which requires only a style image mask as input. The model smooths the input image to reduce the sharpening of the image edges, and the smooth loss we design further improves this smoothing. *GB* connects the source style domain and the target text domain through a smoothing block, which maps the style image and the style image mask to a smoothing domain where details are eliminated and contours show a similar degree of smoothing. According to the adjustment parameter *l* (*l* ∈ [0, 1]), the smoothed style image is transformed into masks with different degrees of deformation.

Next, we train *GS*. By clipping the mask, training pairs of sketch shapes with different smoothness are obtained. These pairs are fed into *GS*, which is trained to map them back to the original text mask so that *GS* characterizes the shape features of the text image mask and transfers these features to the target text. The encoder we design controls font deformation more flexibly at different levels and enhances the generalization ability of the model; the dilated convolution, deformable convolution, and residual block structure we design make the edges of stylized images converge more closely to the edges of text images and make font deformation more flexible and controllable.

Finally, we train the *GT* module. A style image mask and a style image are randomly cropped to form a training pair as input to *GT*, which is trained to perform texture rendering: style migration is applied to the input image so that the deformed font carries the style features of the style image.
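The training configuration described above could be set up as in the following hedged sketch; the `Conv2d` layers are placeholders for the actual *GB*, *GS*, and *GT* networks, and the data pipeline is illustrative only.

```python
# Hedged sketch of the training configuration: random 256 x 256 crops, Adam
# with learning rate 0.0002, 3 epochs, stage-wise training of three modules.
import torch
from torch import nn, optim
from torchvision import transforms

crop = transforms.RandomCrop(256)
G_B = nn.Conv2d(3, 3, 3, padding=1)   # placeholder sketch module
G_S = nn.Conv2d(3, 3, 3, padding=1)   # placeholder glyph module
G_T = nn.Conv2d(3, 3, 3, padding=1)   # placeholder transfer module

for module in (G_B, G_S, G_T):        # modules are trained in sequence
    opt = optim.Adam(module.parameters(), lr=2e-4)
    for epoch in range(3):            # 3 training epochs
        x = crop(torch.randn(1, 3, 300, 300))       # stand-in training crop
        loss = nn.functional.l1_loss(module(x), x)  # stand-in objective
        opt.zero_grad(); loss.backward(); opt.step()
```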

#### *4.3. Comparisons with State-of-the-Art Methods*

We used shape-matching GAN as the baseline and conducted a number of experiments. The effects of our proposed method on artistic text style transfer are shown in Figure 4. Our method is superior to the baseline at stylizing complex glyphs and represents a significant improvement over it: it ensures a clear font structure and improves legibility.

**Effect picture comparison.** In Figure 5, we qualitatively compare our method with four state-of-the-art style transfer methods: neural style transfer (NST) [21], LapStyle [38], multi-style transfer (MST) [36], and shape-matching GAN [3]. These methods are chosen because they are all one-way style transfers, and most style transfer methods are derivatives of them. (a) NST is the most basic style transfer method; it uses a CNN for feature extraction and then uses the extracted features for reconstruction. It can transfer the style but cannot learn the style features, and the glyphs are homogenized. (b) LapStyle splits the complex style migration into an initial migration at low resolution and a correction process at high resolution, which effectively improves the quality and speed of stylization. LapStyle is therefore better suited to whole-image style migration; it is not applicable to artistic font generation because it is ineffective at extracting the features of fonts, which are only one aspect of text images. (c) MST is a recently proposed diversified transformation network for text style transfer that can generate multiple text images with a single model and control the text style in a simple way. (d) Shape-matching GAN is our baseline method, which cannot maintain the structure of complex font glyphs. As the comparative results in Figure 5 show, our proposed method has obvious advantages in the effectiveness of artistic font generation. Most other methods perform style transfer on the whole style image and thus extract text and style features insufficiently, which prevents them from generating clear and attractive artistic text. In contrast, by introducing deformable convolution and an improved residual module, our proposed network enhances the control of font deformation, enables more detailed font feature extraction, and solves the problems of severe font deformation and unclear character shapes. It differs from other style transfer networks in that the text retains texture details while learning the image style, making the generated artistic characters vivid and artistically ornamental.

**Figure 4.** Our artistic text style transfer effects.

**Execution time comparison.** We compare the time each model needs to generate an image during testing on an Intel Core i7-11700K CPU with an NVIDIA GeForce RTX 3080 (10 GB) GPU, as shown in Table 1. We input 320 × 320 images into each model and average the inference time over 100 pictures. As seen from Table 1, each image generated by our proposed model requires only 0.039 s on average, which allows near-real-time interaction with users. Our time is slightly longer than that of shape-matching GAN [3] because of the deformable convolution added to the model; deformable convolution adds only a small overhead in parameters and computation, but it is precisely what allows our model to better capture the edge features of fonts and produce better results. NST [22] takes a long time to execute because it requires many iterations at test time to generate the final result.
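The timing protocol (averaging per-image inference time over 100 inputs of size 320 × 320) can be reproduced with a sketch such as the following; the model and device are placeholders, not the authors' benchmarking script.

```python
# Sketch of the timing protocol: average per-image inference time over
# n inputs of size x size (model and device are placeholders).
import time
import torch

def average_inference_time(model, n=100, size=320, device="cpu"):
    model = model.to(device).eval()
    x = torch.randn(1, 3, size, size, device=device)
    if device == "cuda":
        torch.cuda.synchronize()      # flush queued GPU work before timing
    start = time.perf_counter()
    with torch.no_grad():
        for _ in range(n):
            model(x)
    if device == "cuda":
        torch.cuda.synchronize()
    return (time.perf_counter() - start) / n
```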

**Table 1.** Execution time comparison.



**Figure 5.** Comparison with state-of-the-art methods on various styles.

#### *4.4. Ablation Study*

To analyze the advantages of our improvements on the baseline model, we design the following experiments with different configurations:


The results of this ablation experiment are shown in Figure 6. Compared with the baseline network, W/o SL enhances the smoothing performance, which allows the text to better learn the style features while maintaining the font. The W/o NCR model improves the legibility of the font and guarantees its structural features; however, unnecessary edge features are still recognized, resulting in style overflow. Owing to the flexibility of deformable convolution in feature extraction, the W/o DC model solves the style overflow caused by identifying unnecessary edge features. In sum, with the full model, the results effectively solve the problem of the missing glyph structure and greatly increase the legibility of the text.

**Figure 6.** Comparison chart of ablation experiments: (**a**) shows our original data, and (**b**) shows the style features to be migrated. The first row on the right is the result of the glyph generation network *GS*, and the second row is the final output. From left to right are the output results of the models described above.

#### **5. Conclusions**

In this paper, we propose a deep deformable artistic font style transfer network that maps the stylistic features of an image onto the text of a text image and controls the degree of font deformation by adjusting parameters, achieving diverse style migration. In the network, the DR encoder we designed effectively extracts font features, controls font deformation, greatly improves the recognition accuracy of complex fonts, and enables the network to generate more exquisite artistic fonts. The DAF network is divided into three modules, each of which can be trained separately. In the sketch module, a smooth loss is introduced to enhance the smoothness of the font edges and improve the similarity between the font edges and the edge transformations of the style images. In the *GS* module, the novel DR encoder better preserves the font structure and improves font legibility. The *GT* module is trained to transfer the style image features to the font image so that the font not only retains its own glyph structure but also integrates the style features. We verified the effectiveness and robustness of the method by comparing it with state-of-the-art migration algorithms. In future work, we hope to integrate an attention mechanism into the DR encoder to improve font adaptivity, making the style transfer more precise for text and the migrated text more beautiful. We will also work on measuring improvements in contour definition.

**Author Contributions:** Conceptualization, X.Z. and M.L.; methodology, X.Z. and M.L.; software, X.Z. and K.W.; validation, X.Z. and K.W.; formal analysis, M.L.; writing, X.Z.; review and editing, M.L., H.Z. and X.S.; visualization, X.Z. All authors have read and agreed to the published version of the manuscript.

**Funding:** This research was supported in part by the Scientific Research Fund of Hunan Provincial Education Department (22A0502), the National Natural Science Foundation of China (61772179), the Hunan Provincial Natural Science Foundation of China (2019JJ40005), the 14th Five-Year Plan Key Disciplines and Application-oriented Special Disciplines of Hunan Province (Xiangjiaotong [2022] 351), the Science and Technology Plan Project of Hunan Province (2016TP1020), the Science and Technology Innovation Project of Hengyang (202250045231), the Open Fund Project of Hunan Provincial Key Laboratory of Intelligent Information Processing and Application for Hengyang Normal University (2022HSKFJJ012), and the Postgraduate Scientific Research Innovation Project of Hunan Province (QL20210262).

**Data Availability Statement:** Not applicable.

**Conflicts of Interest:** The authors declare no conflict of interest.

#### **References**


**Disclaimer/Publisher's Note:** The statements, opinions and data contained in all publications are solely those of the individual author(s) and contributor(s) and not of MDPI and/or the editor(s). MDPI and/or the editor(s) disclaim responsibility for any injury to people or property resulting from any ideas, methods, instructions or products referred to in the content.

### *Article* **Enhancement of E-Learning Student's Performance Based on Ensemble Techniques**

**Abdulkream A. Alsulami 1,2,\*, Abdullah S. AL-Malaise AL-Ghamdi 1,3 and Mahmoud Ragab 4,5**


**Abstract:** Educational institutions have increased dramatically in recent years, producing many graduates and postgraduates each year. One of the critical concerns of decision-makers is student performance. Educational data mining techniques are beneficial for uncovering hidden patterns in educational data in order to analyze student performance. In this study, we investigate student E-learning data, which has increased significantly in the era of COVID-19. This study aims to analyze and predict student performance using information gathered from online systems. Evaluating student E-learning data through the data mining model proposed in this study will help decision-makers make informed and suitable decisions for their institutions. The proposed model includes three traditional data mining methods, decision tree, naive Bayes, and random forest, which are further enhanced by three ensemble techniques: bagging, boosting, and voting. The results demonstrate that the proposed model improved the accuracy from 0.75 to 0.77 when we used the DT method with boosting. Furthermore, the precision and recall results both improved from 0.76 to 0.78.

**Keywords:** educational data mining; student performance; classification techniques; ensemble methods

#### **1. Introduction**

One of the main concerns for educational institutions is analyzing the factors that affect student performance. Every school tries to reduce the failure rate of its students. The most popular technique for evaluating and predicting students' performance is educational data mining (EDM) [1]. EDM is about developing methods to deal with the different types of data in educational systems to improve students' learning outcomes [2]. EDM creates and adapts statistical, machine learning, and data mining approaches, and its primary objective is to extract information from educational data for educational decision-making. EDM can predict students' academic achievement early [3], and its use can enhance the analysis of students' learning processes while taking into account how they interact with the environment. In this study, we investigated an electronic learning (E-learning) data set. E-learning is a field that has grown dramatically in recent years, and organizations and teachers have identified several challenges in it. One of the main challenges is to identify the factors that affect student performance while students take online courses. Therefore, we used a data set with features that allow us to analyze such factors and identify those that most affect performance. The data set has different characteristics, such as demographic features, academic background, and behavioral features recorded during online classes. Then, a proposed model was applied to the data set to analyze and

**Citation:** Alsulami, A.A.; AL-Ghamdi, A.S.A.; Ragab, M. Enhancement of E-Learning Student's Performance Based on Ensemble Techniques. *Electronics* **2023**, *12*, 1508. https://doi.org/ 10.3390/electronics12061508

Academic Editor: Simeone Marino

Received: 28 February 2023 Revised: 9 March 2023 Accepted: 13 March 2023 Published: 22 March 2023

**Copyright:** © 2023 by the authors. Licensee MDPI, Basel, Switzerland. This article is an open access article distributed under the terms and conditions of the Creative Commons Attribution (CC BY) license (https:// creativecommons.org/licenses/by/ 4.0/).

predict student performance. The model used three traditional data mining techniques to produce a performance model: decision trees, naive Bayes, and random forests. Then, two ensemble methods, bagging and boosting, were used to enhance the results of the traditional techniques. Furthermore, combinations of two and three of the classifiers mentioned above were built with the help of a voting process for more accurate prediction.

#### **2. Related Work**

This study aims to provide a comprehensive model for investigating E-learning students' data, specifically to help educational organizations forecast student success. Some studies have examined student performance prediction using previous educational results to predict future results at the same level, while others have examined the factors that affect student performance. The following is a review of some of these studies.

In [4], researchers aimed to predict students' success on an exam. They modeled the study using a decision tree and K-nearest neighbor and concluded that the decision tree predicts students' pass or fail status in an academic course with the best results. In [5], researchers compared different classification techniques—naive Bayes (simple), multilayer perceptron, SMO, J48, and REP-tree—for predicting student performance. The data set was collected from the computer science department of a college, with 300 student records, and WEKA was the tool used in the study. From the results, the researchers concluded that the multilayer perceptron is the most effective algorithm for predicting student performance; its accuracy was higher than that of the other classifiers.

In [6], researchers aimed to determine the basic factors that have a significant effect on secondary student performance. To do so, they combined single and ensemble-based classifiers to create a proper classification model, which they then used to forecast academic success. Initially, three data mining methods were used: decision tree, multilayer perceptron (MLP), and PART; moreover, three ensemble techniques (multi-boost, bagging (BAG), and voting) were applied individually. To improve the performance of these classifiers, a single classifier and an ensemble classifier were combined, generating nine new models. According to the evaluation's findings, multi-boost with MLP outperformed the other approaches in terms of accuracy. In [7], researchers used data mining techniques to predict student dropout. The findings demonstrated that dropout could be predicted with accuracy rates greater than 0.80 in most situations and false positive values varying from 0.10 to 0.15 on average. K-nearest neighbors, random forest, support vector machines, decision trees, logistic regression, and naive Bayes were among the methods contrasted; random forest outperformed the other machine learning techniques in terms of accuracy, F-measure, and precision.

In [8], researchers concentrated on forecasting student performance in several interactive online sessions by exploring information gathered using an E-learning and design suite. The data set keeps track of student participation during classes, including text editing, keystrokes, and the amount of time spent on each assignment. They used five well-known classifiers: naive Bayes, random forest, support vector machine, multilayer perceptron, and logistic regression, evaluated with five-fold cross-validation and random data splits for training and testing; the model was trained in all sessions except the one used for testing. According to the results, the RF classifier obtained the best accuracy. In [9], researchers investigated various classifier algorithms proposed to predict secondary school students' success in mathematics and Portuguese lessons. They classified using support vector machine (SVM), linear discriminant analysis (LDA), and K-nearest neighbor (KNN); their experimental results demonstrated that the SVM method performed better on the unbalanced class distribution problem. In [10], researchers evaluated hybrid classification. To do so, they used the radial basis function network, C4.5, random forest, and multilayer perceptron algorithms and observed that hybrid classification algorithms perform more accurately than single algorithms.

In [11], researchers clustered the data using the K-nearest neighbors (KNN) algorithm with the help of Harris hawks optimization (HHO); once all of the solutions are classified, the solutions are redistributed into a search space. Several machine learning classifiers were used to validate the overall prediction system, such as naive Bayes, KNN, LRNN, and an artificial neural network. The collected results demonstrate the significance of anticipating student performance early to reduce student failure and enhance the overall effectiveness of the educational institution. Furthermore, given that LRNN is a deep learning method that can observe past and current input values, the results showed that the modified HHO and LRNN combination outperforms the other classifiers with an accuracy of 0.92. In [12], researchers concentrated on how crucial it is to take advantage of both technological advancements and potential educational contributions. They tested a new PFA strategy based on various ensemble learning techniques (random forest, AdaBoost, and XGBoost) to improve the forecasting of student performance; the results demonstrated that XGBoost could predict future student acquisition with the highest performance. In [13], researchers presented a data mining technique used to forecast first-year students' academic performance. They chose three different data models for learning stages and tested them based on the dates of entry, the end of the first semester, and the end of the last semester. Records of bachelor students who enrolled in a program offered by the institution between 2006/2007 and 2015/2016 were obtained through the institutional database. The best overall performance was achieved by a support vector machine (SVM) model, which was chosen to perform database sensitivity analysis. Table 1 shows some papers that used different data mining techniques to predict the performance of students.


**Table 1.** Comparison of data mining techniques in predicting student's performance.

In [16], researchers attempted to determine the factors influencing academic performance. They made use of two different data sets. The first data set demonstrates how performance in a course's required prerequisite courses might affect a student's performance in the current course; the second data set suggested that a student's grade in any course is related to their performance in the semester up to the midterm test. In [17], the results of the model showed that the main contributions to predicting academic performance come from the following factors: interview, task, questionnaire, and age. The access factor measures students' access to the module, including access to forums and glossaries; questionnaire factors summarize the variables in the questionnaire related to visits and attempts; and the age factor contains the student's age. In [18], the study aimed to investigate the factors affecting student performance. The researchers reviewed and analyzed 36 articles and concluded that performance in previous classes and grades, students' E-learning activities, and their demographic background had an impact on students' academic performance. To determine whether students' learning behaviors were important, researchers examined the same data set [19]. They used the ensemble methods voting, bagging, and boosting alongside traditional data mining methods: support vector machines, decision tree (ID3), K-nearest neighbor, and naive Bayes. With the help of the voting process, the highest accuracy was achieved. In [14], an investigation was conducted on learners' relationships with E-learning. A combination of ensemble algorithms with three types of classifiers was used: decision trees, K-nearest neighbors, and support vector machines. It was found that learners' features were strongly correlated with their performance, and the ensemble techniques increased accuracy. In [15], to help decision-makers make the best choices for their organizations, researchers used ensemble methods to predict student performance. They used naïve Bayes, decision tree, and K-nearest neighbor methods, combined with the voting technique. In most scenarios, the proposed model improved the accuracy of naïve Bayes.

#### **3. Methodology**

In this section, we describe the data set used to conduct this study, followed by a discussion of the proposed model and the evaluation measures.

#### *3.1. Data Set*

#### 3.1.1. Data Collection

The data set for this study was obtained from the Kalboard 360 E-learning system via the Experience API (xAPI). The data set [20] consists of 480 records with 17 attributes, all of which are either integer or categorical in nature. The features fall into three major types: (1) demographic attributes, such as place of birth, nationality, gender, and parent responsibility for the child; (2) academic attributes, such as educational stage and grade level; and (3) behavioral attributes, such as opening resources and raising hands in class. These different categories make the data set appropriate for the classification and prediction of student performance within E-learning systems. Table 2 illustrates the wide range of data set features, the category to which each belongs, and a description of each attribute.

**Table 2.** Classification results with traditional DM.


#### 3.1.2. Data Visualization

Data visualization is an essential part of preprocessing that uses graphs to simplify complex data. We used the WEKA software to visualize the data set. The graphical representations can help instructors better understand their students and monitor what is happening in online classes. Figure 1 illustrates the gender distribution: the data set consists of 305 males and 175 females.

**Figure 1.** Student gender.

Figure 2 shows the diversity of nationalities in the data set; for example, 179 students are from Kuwait and 172 are from Jordan, with the remainder from other countries.

**Figure 2.** Student nationalities.

The data set also includes a feature recording students' attendance at school to measure its influence. As seen in Figure 3, students are divided into two groups according to the number of absence days: 289 students have fewer than 7 absence days, while 191 students have more than 7.

**Figure 3.** Student absences.

Students are classified into three educational levels: 199 lower-level students, 248 middle-level students, and 33 high-level students. Students' records were gathered over the course of two academic semesters, with 245 records obtained from the first semester and 235 from the second. Students chose various topics across these semesters: 95 students were taking IT, 65 French, 59 Arabic, 51 science, 45 English, 30 biology, 25 Spanish, 24 chemistry, 24 geology, 22 the Quran, 21 math, and 19 history.

#### 3.1.3. Data Cleaning

As part of preprocessing, data cleaning is essential for removing irrelevant objects and handling missing values in the data collection. This data set contains no missing values.

#### *3.2. Features Selection*

Feature selection refers to selecting the relevant features of a data set from an original feature set based on specific criteria. There are two types of data reduction methods: wrapper methods and filter methods. The filter method ranks the features using variable ranking methods, with the highly ranked features being selected and fed into the learning algorithm [21]. In this study, an information gain ranking filter and a correlation ranking filter were used. At each decision tree node, the information gain measure is used to select the test attribute. The information gain (IG) metric tends to favor features with a large number of values. It is calculated with Equation (1).

$$IG(T, a) = H(T) - H(T|a) \tag{1}$$

where *T* is a random variable and *H*(*T*|*a*) is the entropy of *T* given the value of attribute a.
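As an illustration of Equation (1), the following Python snippet computes the information gain of a categorical attribute with pandas and NumPy. It mirrors what the WEKA filter computes but is not the WEKA implementation, and the tiny data frame is fabricated for demonstration.

```python
# Information gain (Equation (1)) for a categorical attribute; the toy
# data frame below is fabricated for illustration only.
import numpy as np
import pandas as pd

def entropy(labels: pd.Series) -> float:
    p = labels.value_counts(normalize=True)
    return float(-(p * np.log2(p)).sum())        # H(T)

def information_gain(df: pd.DataFrame, attribute: str, target: str) -> float:
    h_t = entropy(df[target])
    h_t_given_a = sum((len(g) / len(df)) * entropy(g[target])   # H(T | a)
                      for _, g in df.groupby(attribute))
    return h_t - h_t_given_a

toy = pd.DataFrame({"VisitedResources": ["high", "high", "low", "low"],
                    "Class": ["H", "H", "L", "M"]})
print(information_gain(toy, "VisitedResources", "Class"))  # 1.0 for this toy data
```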

Figure 4 illustrates the feature ranking after using the WEKA tool to apply the information gain filter. Visited resources are ranked first, followed by student absence, raised hands, and other attributes.

**Figure 4.** Information gain filter.

Correlation coefficients are applied to measure correlations between attributes and classes and inter-correlations between features [21]. The coefficient is calculated with Equation (2).

$$\rho(X,Y) = \frac{\text{cov}(X,Y)}{\sigma\_X \sigma\_Y} = \frac{\sum\_{i=1}^n (x\_i - \bar{x})(y\_i - \bar{y})}{\sqrt{\sum\_{i=1}^n (x\_i - \bar{x})^2} \sqrt{\sum\_{i=1}^n (y\_i - \bar{y})^2}} \tag{2}$$

where:


Figure 5 shows the ranking between the attributes and the class. As with the information gain filter, the highest-ranked feature is visited resources, followed by student absence days and raised hands. From Figures 4 and 5, the filters selected 16 attributes. The behavioral features thus have a greater impact on student performance than the other features.

**Figure 5.** Correlation filter.

#### *3.3. Data Mining Tool*

WEKA is a well-known Java-based machine learning suite developed at the University of Waikato in New Zealand [22]. The WEKA package includes visualization tools, data analysis and predictive modeling algorithms, and graphical user interfaces that make this functionality more accessible. It includes numerous algorithms for data mining and machine learning.

#### *3.4. Proposed Model*

The primary purpose of this paper is to compare the performance and results of each prediction model based on traditional data mining techniques and ensemble methods. Figure 6 illustrates the proposed model applied to the data set. First, we collect the data set and prepare it for the study. Then, three traditional data mining methods are applied (decision tree (DT), naïve Bayes (NB), and random forest (RF)) to produce a performance model. In addition to these classifiers, two ensemble methods, boosting and bagging, are applied to enhance the success of the student prediction model. Combinations of two and three of the above methods were added to each ensemble technique using the voting process for more accurate prediction. The last phase of the model involves evaluating and discussing the results. The data were divided into training and test sets, and each prediction model's performance was evaluated using k-fold cross-validation, a technique used to address the variance problem when testing a model. In brief, 10-fold cross-validation divides the training set into 10 folds; during training, 9 folds are used before the final fold is tested. Taking the average of the resulting accuracies better represents the model performance. The method was repeated ten times. All models were run with the WEKA software's default parameters.
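A minimal sketch of this protocol, using scikit-learn in place of WEKA (an assumption of the sketch, not the study's tooling) and synthetic stand-in data for the 480-record set, is:

```python
# Sketch of the evaluation protocol: 10-fold cross-validation over the
# three traditional classifiers. X and y are synthetic stand-ins for the
# real 480-record, 16-feature data set described in Section 3.1.
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import StratifiedKFold, cross_val_score
from sklearn.naive_bayes import GaussianNB
from sklearn.tree import DecisionTreeClassifier

rng = np.random.default_rng(0)
X = rng.integers(0, 4, size=(480, 16))   # stand-in encoded features
y = rng.integers(0, 3, size=480)         # stand-in class labels (L/M/H)

cv = StratifiedKFold(n_splits=10, shuffle=True, random_state=0)
for name, model in [("DT", DecisionTreeClassifier()),
                    ("NB", GaussianNB()),
                    ("RF", RandomForestClassifier())]:
    scores = cross_val_score(model, X, y, cv=cv, scoring="accuracy")
    print(f"{name}: mean accuracy = {scores.mean():.3f}")
```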

**Figure 6.** Proposed model for prediction of student's performance.

#### *3.5. Description of Traditional Data Mining*

#### 3.5.1. Decision Tree

The decision tree algorithm belongs to the family of supervised learning algorithms. Decision trees are one of the most effective techniques in various fields, including machine learning, image processing, and pattern recognition. The decision tree algorithm solves both regression and classification problems. By learning basic decision rules from training data, researchers can create training models that predict the class of a specific variable. Each tree is composed of nodes and branches; each node represents a feature in the classification category, and each branch represents a value the node can take [23].

#### 3.5.2. Random Forest

In machine learning, random forest is a supervised algorithm that predicts output by merging multiple decision trees into a forest, and then by combining the predictions from all decision trees using ensemble learning, an accurate prediction can be obtained [23].

#### 3.5.3. Naive Bayes

Based on Bayes' theorem from probability theory, naive Bayes is a direct and simple Bayesian classifier. An NB algorithm applies Bayes' theorem under the assumption that every pair of features is conditionally independent given the value of the class variable [24]; that is, every pair of features being classified in naive Bayes is independent of the others.

#### *3.6. Description of Ensemble Methods*

#### 3.6.1. Bagging

Bagging is the most well-known independent method. It aims to improve accuracy by combining the results of multiple learned classifiers into a single prediction, creating an improved composite classifier. The classifiers are trained on subsets of instances sampled with replacement from the training set [25]; each sample is the same size as the original training set, so every instance has an equal chance of being selected [26]. In this way, bagging improves the accuracy of unstable classification models by constructing multiple classifiers and combining their results into a single prediction.
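A bagging ensemble as described above can be sketched with scikit-learn (our assumption; the study itself used WEKA):

```python
# Bagging sketch: bootstrap samples the same size as the training set,
# drawn with replacement, combined into one composite classifier.
from sklearn.ensemble import BaggingClassifier
from sklearn.tree import DecisionTreeClassifier

bag = BaggingClassifier(
    DecisionTreeClassifier(),  # base learner
    n_estimators=10,           # number of bootstrap-trained classifiers
    max_samples=1.0,           # each sample matches the training-set size
    bootstrap=True,            # sampling with replacement
)
# bag.fit(X, y)  # X, y as in the cross-validation sketch above
```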

#### 3.6.2. Boosting

Boosting refers to a group of algorithms that can convert weak learners into strong ones. Boosting works by training multiple classifiers sequentially and obtaining their forecasts, then adjusting the instance weights so that the next learner reduces the previous learner's errors. Early boosting was restricted to binary classification; the AdaBoost algorithm overcomes this limitation adaptively. Boosting samples instances according to their weights [27].
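AdaBoost, the adaptive variant mentioned above, can be sketched as follows (scikit-learn with decision stumps as weak learners; the parameters are illustrative, not the study's settings):

```python
# AdaBoost sketch: each round reweights the instances the previous weak
# learner misclassified, then trains the next weak learner.
from sklearn.ensemble import AdaBoostClassifier
from sklearn.tree import DecisionTreeClassifier

boost = AdaBoostClassifier(
    DecisionTreeClassifier(max_depth=1),  # decision-stump weak learner
    n_estimators=50,                      # number of boosting rounds
)
```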

#### 3.6.3. Voting

A voting classifier is a machine learning model that predicts the class receiving the most support across several base models. The voting procedure combines learners by majority vote (for classification) or by averaging (for regression); the class with the highest vote or average is predicted [28].
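A majority-vote combination of the three traditional classifiers, mirroring the DT + NB + RF voting scenario of the proposed model, might look like this in scikit-learn (again an assumed translation of the WEKA setup):

```python
# Majority ("hard") voting over the three traditional classifiers.
from sklearn.ensemble import RandomForestClassifier, VotingClassifier
from sklearn.naive_bayes import GaussianNB
from sklearn.tree import DecisionTreeClassifier

vote = VotingClassifier(
    estimators=[("dt", DecisionTreeClassifier()),
                ("nb", GaussianNB()),
                ("rf", RandomForestClassifier())],
    voting="hard",  # each base model casts one vote; the majority wins
)
```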

An ensemble method can be classified as either a dependent or an independent process. Boosting is a dependent method: each learner is created based on the outcomes of the previous ones. In an independent process, each learner works independently, and their outcomes are merged using a voting process; bagging is an independent method [29].

#### *3.7. Evaluation Measures*

In this study, data were fed into the WEKA data mining tool. Different DM techniques were then compared to determine which had the highest prediction accuracy, and a decision was made on that basis. The following common metrics evaluate a study's performance: accuracy, precision, recall, and F-measure.

#### 3.7.1. Accuracy

This metric represents the classifier's capacity to predict correctly: the accuracy of a predictor reflects how accurately it predicts the value of the target feature for new data. Accuracy is the number of correct predictions divided by the total number of predictions [30]. It is calculated with the following Equation (3):

$$Accuracy = (TP + TN) / (TP + TN + FP + FN),\tag{3}$$

where:


#### 3.7.2. Precision

Precision is calculated as the ratio of correctly classified positive predictions to total positive predictions, whether correctly or incorrectly classified [30]. It is calculated with Equation (4).

$$Precision = TP/(TP + FP),\tag{4}$$

#### 3.7.3. Recall

Recall is determined by calculating the proportion of correctly classified positive instances out of all actual positive instances [30]. It is calculated with Equation (5).

$$Recall = TP / (TP + FN),\tag{5}$$

#### 3.7.4. F-Measure

F-measure conveys both recall and precision in a single measure [30]. It is calculated with Equation (6).

$$F\text{-}measure = (2 \times Recall \times Precision)/(Recall + Precision),\tag{6}$$
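The four metrics of Equations (3)–(6) follow directly from the confusion-matrix counts; the counts in the example call below are illustrative only.

```python
# Equations (3)-(6) computed from raw confusion-matrix counts.
def classification_metrics(tp: int, tn: int, fp: int, fn: int):
    accuracy = (tp + tn) / (tp + tn + fp + fn)                  # Eq. (3)
    precision = tp / (tp + fp)                                  # Eq. (4)
    recall = tp / (tp + fn)                                     # Eq. (5)
    f_measure = 2 * recall * precision / (recall + precision)   # Eq. (6)
    return accuracy, precision, recall, f_measure

print(classification_metrics(tp=50, tn=30, fp=10, fn=10))  # illustrative counts
```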

#### **4. Experimental Results**

The results of each of the prediction models (traditional data mining with and without ensemble methods) are provided in this section.

#### *4.1. Traditional DM Techniques*

WEKA software was used to conduct the experiments, as mentioned in the previous section. Figure 7 explains the implementation of the traditional DM model. First, in WEKA, the data set was uploaded using the CSVLoader operator to start building the model. The description and details of the selected data set were given in the previous section; the data set contains 17 attributes in total, all of which were chosen for this study. By linking the CSVLoader to a text viewer in WEKA, a table of all the attributes can be shown. Second, the data set was passed to the ClassAssigner operator to determine which attribute is the class. Third, once the class had been defined, the data set was connected to WEKA's cross-validation operator, CrossValidationFoldMaker, and divided into two parts: training and testing. In each iteration of the cross-validation process, nine subsets were used to train the model and one to test it. As a result, the model's training and validation were performed concurrently in a single step, which is a valid test because the data used for testing are unseen by the model. After that, as Figure 7 illustrates, the cross-validation operator was connected to each algorithm: the training part trains the algorithm, and the testing part tests it. Finally, the ClassifierPerformanceEvaluator was applied to each algorithm to obtain the validation and performance of the model. All the models in Figures 8 and 9 follow the same procedure as in Figure 7.

**Figure 7.** Traditional DM technique implementation using WEKA.

#### *4.2. Ensemble Methods*

The same procedure is performed for the boosting model: uploading the data set, assigning the class, connecting it to cross-validation, linking it to the methods, and finally applying the model. As Figure 8 shows, six experiments were performed in the boosting model. First, three experiments were run with each of the data mining techniques. Second, with the help of the voting method, boosting was performed with two algorithms at a time: one experiment with naïve Bayes and random forest, and a second with random forest and decision trees. Third, the last experiment applied boosting with all three traditional data mining techniques via the voting method. The idea was to observe any difference in performance when we manipulate the model.

**Figure 8.** Ensemble method (boosting) implementation using WEKA.

The same procedure discussed for the boosting implementation is used with the bagging model, as seen in Figure 9.

**Figure 9.** Ensemble method (bagging) implementation using WEKA.

Table 2 illustrates the evaluation measures for the traditional data mining techniques. The decision tree achieved an accuracy of 75.5%, a precision of 0.760, a recall of 0.758, and an F-measure of 0.759. With naïve Bayes, the accuracy was 67.7%, the precision 0.675, the recall 0.677, and the F-measure 0.671. Random forest achieved an accuracy of 76.6%, with precision, recall, and F-measure all equal to 0.766.

Table 3 illustrates the evaluation measures for boosting with the traditional data mining techniques. Boosting with decision trees resulted in an accuracy of 77.92%, a precision of 0.779, a recall of 0.779, and an F-measure of 0.779. Boosting with random forest gave 76.25% accuracy, 0.763 precision, 0.763 recall, and 0.762 F-measure. Boosting with naïve Bayes gave an accuracy of 72.29%, a precision of 0.724, a recall of 0.723, and an F-measure of 0.723. Boosting with naïve Bayes and decision trees resulted in an accuracy of 76.45%, a precision of 0.764, a recall of 0.765, and an F-measure of 0.764. Finally, boosting with all traditional data mining techniques gave an accuracy of 76.25%, a precision of 0.762, a recall of 0.763, and an F-measure of 0.762.


**Table 3.** Classification results with boosting.

Table 4 illustrates the evaluation measures for bagging with the traditional data mining techniques. Bagging with decision trees gives 74.37% accuracy, 0.744 precision, 0.743 recall, and 0.743 F-measure. Bagging with random forest gives 75.63% accuracy, 0.757 precision, 0.756 recall, and 0.756 F-measure. Bagging with naïve Bayes produces an accuracy of 67.7%, a precision of 0.677, a recall of 0.676, and an F-measure of 0.672. Bagging with naïve Bayes and decision trees achieves an accuracy of 75.62%, a precision of 0.756, a recall of 0.756, and an F-measure of 0.756. The results of bagging with random forest in combination are 76.46% accuracy, 0.766 precision, 0.765 recall, and 0.765 F-measure. Finally, bagging with all three data mining techniques gives an accuracy of 76.87%, a precision of 0.768, a recall of 0.769, and an F-measure of 0.768.


**Table 4.** Classification results with bagging.

#### **5. Evaluation Results and Findings**

In this section, the results of different traditional data mining techniques (decision tree, naïve Bayes, and random forest) and ensemble methods (bagging, boosting, and voting) will be interpreted and evaluated. As mentioned above, four different measures will be used to evaluate the performance: accuracy, precision, recall, and F-measure.

#### *5.1. Accuracy*

For all of the experiments conducted with traditional data mining techniques and ensemble methods, the accuracy values were above 65%. We observed that naïve Bayes had the lowest accuracy, 67.7%, among all methods. Among the traditional DM techniques, as shown in Figure 10, random forest achieved a notably higher accuracy, 76.6%, than the other techniques, indicating that 365 of the 480 students were classified into the correct class labels and 115 were not.

**Figure 10.** The accuracy for traditional DM techniques.

Moreover, Tables 3 and 4 report the accuracy of all the models using ensemble methods. Overall, the accuracy was improved by using ensemble methods. For naïve Bayes, the accuracy was enhanced from 67.7% (without an ensemble) to 72.29% (with boosting). Comparing bagging and boosting with voting against the traditional DM techniques, boosting with decision tree achieved the highest accuracy, 77.9%, versus 75.5% without ensemble methods, as Figure 11 shows; that is, the number of students correctly assigned to the appropriate class labels rose from about 360 to 375. Every ensemble scenario outperformed the plain naïve Bayes accuracy except bagging with naïve Bayes, which was equal. Figure 11 illustrates how the proposed model enhanced the accuracy effectively both when we used an ensemble method with a single traditional data mining technique (boosting + DT) and when we combined several classifiers (boosting + DT + RF) using the voting process.

**Figure 11.** The accuracy of DM techniques and ensemble methods.

#### *5.2. Precision*

All fifteen experiments have precision values above 0.65. Figure 12 shows the precision of the traditional DM techniques; we can observe that random forest outperformed the other methods with a value of 0.75.

**Figure 12.** Precision for DM techniques.

Naïve Bayes has the lowest precision, 0.67, which increased to 0.72 when ensemble methods (boosting) were applied, meaning the number of students correctly assigned to the right class labels improved from 322 to 346. Boosting with a decision tree recorded the highest precision, 0.78, whereas the highest value among the traditional data mining techniques was 0.76. Furthermore, other ensemble classifiers also outperformed the traditional DM techniques, as shown in Figure 13.

**Figure 13.** The precision of DM and ensemble methods.

#### *5.3. Recall*

All fifteen experiments have recall values above 0.65. Among the traditional DM techniques, random forest performed better in recall than the others, with a value of 0.76, as shown in Figure 14.

**Figure 14.** Recall for DM techniques.

Naïve Bayes has the lowest recall, 0.67, which increased to 0.72 when ensemble methods (boosting) were applied; the number of correctly classified students improved from 322 to 346. Boosting with a decision tree, as shown in Figure 15, achieved the highest recall, 0.78, compared to 0.76 in traditional DM.

**Figure 15.** The Recall of DM techniques and ensemble methods.

#### *5.4. F-Measure*

All fifteen experiments have F-measure values above 0.65. Random forests and decision trees have similar values, as Figure 16 illustrates.

**Figure 16.** F-measure for DM techniques.

As shown in Figure 17, three scenarios from the proposed model outperformed the traditional data mining techniques. Naïve Bayes has the lowest F-measure, 0.67, which increased to 0.72 when ensemble methods (boosting) were applied. Boosting with a decision tree achieved the highest F-measure, 0.78, compared to 0.77 for the traditional DM techniques.

**Figure 17.** The F-measure of DM techniques and ensemble methods.

#### *5.5. Validate the Results with Previous Studies*

Validation is essential when creating predictive models because it indicates how reasonable they are. In this study, we compare our results with previous studies that used ensemble methods to predict student performance. The experimental values show the improvement of the proposed model over recent approaches. In [19], when the voting technique was applied, the researchers increased the accuracy by 1%. In [14], a study that employed ensemble methods on the same data set increased the accuracy by 2.1% by using the boosting method with a decision tree. Likewise, in [15], researchers aimed to improve student performance prediction using ensemble methods. In [31], the results indicated that the ensemble methods improved accuracy by 1% when applied to the same data set. In a study that combined traditional data mining with the help of the voting method, the proposed model enhanced accuracy by 2.1% using the exact same data set with ensemble methods [32]. By comparison, our model enhanced performance by 2.4% when the boosting method was applied to a decision tree, confirming that the proposed model improves on the existing approaches. In the future, an ensemble fusion-based DL model can be developed to improve the performance of the proposed technique.

#### **6. Conclusions**

Academic institutions all over the world are concerned about student success. As learning management systems become more widespread, an enormous amount of information about the interaction between the teaching and learning processes is generated. In this study, the authors developed a new technique for predicting student performance that combines data mining techniques with ensemble methods. The data set measured various features, such as demographic information, student behavior in online classes, and parental involvement in academic performance. According to the findings of the study, there is a strong relationship between student behavior and performance. To ensure the enhanced performance of the proposed model, wide-ranging experiments were performed, and the results were inspected from distinct aspects. The proposed model improved the accuracy from 0.75 to 0.77 when we used the DT method with boosting, which resulted in a more accurate prediction of student performance. Furthermore, the precision and recall results both improved from 0.76 to 0.78. Moreover, the extensive experimentation outcomes confirmed the superior performance of the proposed technique compared to other existing techniques. Thus, the proposed technique can be used as a proficient approach for the prediction of student performance. In future development, the presented model can be extended by utilizing ensemble fusion-based DL and data mining techniques for predicting the academic performance of students. Moreover, the proposed model can be applied to several online educational data sets to support decision-makers in high-impact E-learning development.

**Author Contributions:** Conceptualization, A.A.A., A.S.A. and M.R.; methodology, A.A.A. and M.R.; software, A.A.A.; validation, A.A.A., A.S.A. and M.R.; formal analysis, A.A.A. and A.S.A.; investigation, A.A.A. and M.R.; resources, A.A.A., A.S.A. and M.R.; data curation, A.A.A. and A.S.A.; writing—original draft preparation, A.A.A.; writing—review and editing, A.A.A. and M.R.; visualization, A.A.A.; supervision, A.S.A. and M.R.; project administration, M.R.; and funding acquisition, A.A.A., A.S.A. and M.R. All authors have read and agreed to the published version of the manuscript.

**Funding:** This research work was funded by Institutional Fund Projects under grant no. (IFPIP: 26-611-1443).

**Institutional Review Board Statement:** Not applicable.

**Informed Consent Statement:** Not applicable.

**Data Availability Statement:** The data used in this work are available at Kaggle.

**Acknowledgments:** This research work was funded by Institutional Fund Projects under grant no. (IFPIP: 26-611-1443). Therefore, the authors gratefully acknowledge technical and financial support provided by the Ministry of Education and Deanship of Scientific Research (DSR), King Abdulaziz University (KAU), Jeddah, Saudi Arabia.

**Conflicts of Interest:** The authors declare no conflict of interest.

#### **References**


**Disclaimer/Publisher's Note:** The statements, opinions and data contained in all publications are solely those of the individual author(s) and contributor(s) and not of MDPI and/or the editor(s). MDPI and/or the editor(s) disclaim responsibility for any injury to people or property resulting from any ideas, methods, instructions or products referred to in the content.

### *Article* **Human Action Recognition for Dynamic Scenes of Emergency Rescue Based on Spatial-Temporal Fusion Network**

**Yongmei Zhang 1,\*, Qian Guo 1, Zhirong Du 1,2 and Aiyan Wu <sup>1</sup>**


**Abstract:** Targeting the problems of the insufficient utilization of temporal and spatial information in videos and a low accuracy rate, this paper proposes a human action recognition method for dynamic videos of emergency rescue based on a spatial-temporal fusion network. A time domain segmentation strategy based on random sampling maintains the overall time domain structure of the video. Considering the spatial-temporal asynchronous relationship, multiple asynchronous motion sequences are added as input to the temporal convolutional network. Spatial-temporal features are fused in the convolutional layers to reduce feature loss. Because time series information is crucial for human action recognition, the acquired mid-layer spatial-temporal fusion features are fed into a Bidirectional Long Short-Term Memory (Bi-LSTM) network to obtain the human movement features across the whole temporal dimension of the video. Experimental results show that the proposed method fully fuses spatial and temporal information and improves the accuracy of human action recognition in dynamic scenes. It is also faster than traditional methods.

**Keywords:** spatial-temporal fusion; human action recognition; two-stream convolutional neural network; emergency rescue; spatial-temporal asynchronous information

#### **1. Introduction**

Human action recognition has always been one of the most challenging problems in the field of computer vision [1]. Video is a kind of data over time with a strong temporal correlation. Each pixel in a video has great similarity and strong spatial correlation.

Most of the subjects in videos are people, so human action recognition technology has attracted considerable research interest as a novel application. The development of artificial intelligence provides broad scope for developing human action recognition technology in the forms of virtual reality, intelligent monitoring, motion analysis, human-computer interaction, etc. [2].

In recent years, various natural and man-made disasters have had a great impact on people's lives. In the face of emergencies, identifying specific situations that need emergency responders is critical [3]. Applying action recognition technology to rescue scenarios, such as major traffic accidents, major terrorist attacks, and earthquakes, can effectively improve the emergency response of medical rescue team members and help provide auxiliary decisions for decision-making at the disaster site.

Identifying the actions and current state of both first responders and victims is crucial in such situations. A more comprehensive grasp of the on-site rescue work helps achieve efficient guidance and implement the rescue accurately and quickly.

In this work, we propose a two-stream asynchronous fusion network based on Temporal Segment Networks (TSN) and Bi-LSTM for human action recognition in emergency rescue scenes, classifying entire video sequences. The main contributions include the following:

(1) This paper further refines the currently available dataset in the literature [4]. The dataset was constructed with reference to the AVA dataset production method. To improve

**Citation:** Zhang, Y.; Guo, Q.; Du, Z.; Wu, A. Human Action Recognition for Dynamic Scenes of Emergency Rescue Based on Spatial-Temporal Fusion Network. *Electronics* **2023**, *12*, 538. https://doi.org/10.3390/electronics12030538

Academic Editor: Fernando De la Prieta

Received: 30 October 2022 Revised: 15 December 2022 Accepted: 18 December 2022 Published: 20 January 2023

**Copyright:** © 2023 by the authors. Licensee MDPI, Basel, Switzerland. This article is an open access article distributed under the terms and conditions of the Creative Commons Attribution (CC BY) license (https://creativecommons.org/licenses/by/4.0/).

the annotation efficiency, Faster R-CNN was used to detect human positions, and the Via software was used to annotate actions and people. To prevent network overfitting and increase the sample diversity of the data, action sequence video data are collected as the emergency first-responder action dataset for annotation, and a data augmentation method is used to expand the data so that it better reflects the emergency rescue scene.

(2) A spatial-temporal asynchronous fusion network is proposed. TSN is used to randomly sample video snippets. The RGB image corresponding to each segment and its optical flow fields before and after a specified period of time are then input into the spatial and temporal stream networks, respectively, to extract spatial and temporal features. The fusion of spatial-temporal asynchronous information is realized in the convolutional layer. The fused features are input into Bi-LSTM to extract temporal features, and human action recognition is finally implemented using Softmax. By modeling the asynchronous relationship between the moving image and the motion sequence (optical flow), long-term motion can be modeled.

(3) Experiments are conducted on the improved emergency rescue dataset provided in the literature [4] and the publicly available dataset UCF101. We conduct experimental analysis to compare our spatial-temporal fusion network model with other methods, which verifies that the presented method improves the accuracy of action recognition.

The remainder of this paper is structured as follows: Section 2 reviews the related work. Section 3 presents the proposed spatial-temporal fusion network model. Section 4 describes the proposed recognition method. Experiments and results are described in Section 5, while Section 6 concludes the paper.

#### **2. Related Work**

#### *2.1. Traditional Human Action Recognition Methods*

Human action recognition methods include feature extraction and action recognition. Traditional feature extraction methods mainly include global and local feature extraction. Global feature representation is a method of comprehensively describing the overall structure and shape features of moving objects, such as silhouette-based features and optical-flow-based features. A. Mahjoub et al. [5] computed a depth motion map for each sequence to represent the action motion features from the depth data.

The global feature method is greatly affected by occlusion, viewpoint change, and noise, and cannot effectively handle changes in viewpoint or partial occlusion. Methods based on local features do not rely on the global features of the image and extract only local features, so they are less affected by noise and cluttered backgrounds, have good robustness, and are more widely used.

The identification method using local feature representation extracts rich interest points from the video, expresses them with local descriptors, and finally aggregates them. This approach obtains local features directly from the points of interest in the image, thus eliminating the pre-processing step [6]. Methods based on point-of-interest description are often suitable only for simpler scenarios and can suffer decreased recognition performance when the video background is more complex.

The extraction and description of movement trajectories based on tracked points have also been a research focus. The Dense Trajectory (DT) method proposed in the literature [7] densely samples points of interest in the frame image, tracks them with dense optical flow, and connects them into trajectories. Features are extracted along the motion trajectories of the feature points as motion descriptors, and bag of features (BOF) is used to encode the feature groups, thereby describing the video behavior.

Abdelbaky A. et al. [8] used a 2D convolutional network, PCANet, as an unsupervised feature extractor instead of 3D convolutional networks to learn the spatial-temporal features of video actions. The learned spatial and temporal features are combined with BOF and a vector local aggregation descriptor encoding scheme. Chen T. et al. [9] improved the contour features and realized action recognition based on the improved features and multiple features.

Because features are highly susceptible to extrinsic factors, the performance of feature-based video behavior recognition differs across scenarios. Representative feature descriptions must be selected for a specific task, and the selected features have a great influence on the recognition accuracy [10]. Some current methods, such as dense trajectory algorithms, achieve high recognition accuracy at a high computational cost: they require a large number of trajectory operations, so the recognition process is relatively complex and has great limitations. As a result, traditional methods are gradually being replaced by deep-learning-based methods [11].

#### *2.2. Human Action Recognition Methods Based on Deep Learning*

According to the different neural network structures used for recognition, action recognition methods based on deep learning can be divided into action recognition methods based on two-stream convolutional neural networks (CNN), action recognition methods based on 3D CNN, and action recognition methods based on long short-term memory (LSTM). A comparison of the three methods is shown in Table 1.



(1) Action recognition methods based on two-stream CNN:

Inspired by the ventral and dorsal streams of human visual processing, the spatial-temporal two-stream network was proposed by Simonyan et al. [12]. However, the two-stream network has difficulty handling long-duration and complex motions. To address this problem, Wang et al. [13] presented TSN, which introduced a sparse sampling strategy.

Thereafter, researchers proposed many improved models based on the two-stream network structure. For example, Zhuang et al. [14] proposed a novel three-stream spatial-temporal attention-enhancing feature fusion network for action recognition. A two-stream 3D ConvNet fusion framework was proposed by Wang et al. [15].

(2) Action recognition methods based on 3D CNN:

Ji et al. [16] first proposed 3D CNN and applied it to action recognition. Based on this work, Tran et al. [17] presented convolutional 3D (C3D). The network utilizes 3D convolution and 3D pooling to process the input video frames. However, there are still some limitations: 3D convolution produces a large number of parameters, which greatly increases the computation cost, and because it models temporal and spatial information jointly, it easily leads to overfitting.

To address these issues, Carreira et al. [18] proposed the I3D network to obtain greater spatial-temporal resolution by inflating the 2D convolution operations of the InceptionV1 network into 3D convolutions. Diba et al. [19] extended 3D convolution to DenseNet and presented the Temporal 3D ConvNets (T3D). T3D adds a temporal transition layer (TTL) to obtain rich temporal information at different scales and capture short-, medium-, and long-range temporal information. Qiu et al. [20] proposed Pseudo-3D Residual Networks (P3D) to address the large number of C3D parameters: the 3D convolution is factorized into a 2D spatial convolution and a 1D temporal convolution. This method increases the network depth and diversity by separating and flexibly combining the temporal and spatial domains and improves the recognition accuracy.
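As a rough illustration of the P3D idea (not the authors' code; channel sizes and shapes are assumptions), a 3D convolution can be factorized into a 2D spatial convolution followed by a 1D temporal convolution:

```python
# Minimal sketch of P3D-style factorization: a k x k x k 3D convolution is
# replaced by a 1 x k x k spatial convolution and a k x 1 x 1 temporal one.
import torch
import torch.nn as nn

class FactorizedConv3D(nn.Module):
    def __init__(self, in_ch, out_ch, k=3):
        super().__init__()
        # 2D spatial convolution applied frame by frame (size 1 on the time axis)
        self.spatial = nn.Conv3d(in_ch, out_ch, kernel_size=(1, k, k),
                                 padding=(0, k // 2, k // 2))
        # 1D temporal convolution across frames (size 1 on the spatial axes)
        self.temporal = nn.Conv3d(out_ch, out_ch, kernel_size=(k, 1, 1),
                                  padding=(k // 2, 0, 0))

    def forward(self, x):  # x: (batch, channels, time, height, width)
        return self.temporal(self.spatial(x))

clip = torch.randn(2, 3, 16, 112, 112)      # a 16-frame RGB clip
print(FactorizedConv3D(3, 64)(clip).shape)  # (2, 64, 16, 112, 112)
```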

Most of the previous networks extend the 2D CNN along the time dimension; however, this is not necessarily the best choice. The X3D network was proposed by Feichtenhofer [21]. X3D expands a small base network along multiple axes, including the number of input frames, input frame size, sampling frame rate, number of convolutional kernels, feature map width, and network depth. The computation and parameters required by X3D are drastically reduced while achieving state-of-the-art performance.

(3) Action recognition methods based on LSTM:

LSTM-based action recognition methods usually combine a CNN and an LSTM to build the network. This takes full advantage of both: the CNN extracts spatial features, and the LSTM extracts temporal information. The LSTM alleviates the vanishing gradient problem and can thus handle long video data well. Combining CNN and LSTM makes it possible to better capture and fuse spatial-temporal information.

Donahue et al. [22] proposed a Long-Term Recurrent Convolutional Network (LRCN) that combined the traditional CNN and LSTM to extract the spatial-temporal information of the video. The input of the LRCN network can be either a single-frame image or a video with temporal information. In the action recognition neural network structure proposed by Ou et al. [23], CNN is used to separately extract local spatial and local motion information, and LSTM is adopted to extract feature information in video sequences and obtain the context relation of the local spatial-temporal information. Ge et al. [24] improved the Faster R-CNN framework by introducing LSTM, obtaining the spatial features of the action by Faster R-CNN and the temporal features of the action by LSTM. A more accurate recognition effect is obtained by combining the clues of the auxiliary regions.

#### **3. A Spatial-Temporal Fusion Network Model**

Currently, 3D CNNs and two-stream CNNs are mainly used for spatial and temporal fusion. A 3D CNN captures spatial-temporal features directly from the original video sequences with universal applicability. However, it does not consider the relationship between the spatial and temporal features, and the densely sampled video frames of the 3D CNN also produce a large number of parameters, which greatly increases the computation. A two-stream CNN extracts the spatial and temporal information through two parallel networks, which is more conducive to processing and fusing the two, and its training complexity is small. The two-stream network structure is shown in Figure 1.

**Figure 1.** Two-stream network structure.

As can be seen in Figure 1, the spatial stream network processes static frames through multiple convolutional and fully connected layers to provide information about the scenes and objects in the video. Using the same convolutional network structure as the spatial stream, the temporal stream network takes a stack of optical flow from *L* adjacent frames as input, processes the dynamic information, and represents motion and temporal information through multiple frames of optical flow. The two streams are trained independently, and their Softmax predictions are fused.

Action recognition research has been greatly influenced by the two-stream network, yet there is still room for improvement. The fusion is only a direct fusion of the prediction results of the two stream classifiers, which is simple to implement but fails to fully integrate the temporal and spatial information. The spatial stream operates on a single frame, and the dense optical flow of the temporal stream can only learn the motion information between adjacent frames; it cannot perform feature learning over the whole video. Therefore, it cannot learn complex and long-duration videos, and the recognition performance is limited.

For this reason, subsequent researchers proposed various methods to improve the two-stream network. A common solution is to use a more intensive image frame sampling method to obtain the long-term information of videos, but it contains a lot of redundant data and increases the cost. TSN offers a time domain segmentation structure with a sparse sampling frame, and this structure can remove some redundant information and extract action information from the whole video.

The TSN framework uses sparse sampling to extract short segments from the entire video, randomly samples snippets within those segments, and fuses the category scores of different segments by a segment consensus function. The two-stream CNN utilized in the TSN framework merely extracts the motion and spatial information independently and finally combines them without taking into account the correlation between the spatial and temporal information. Fusing after the fully connected layer destroys the spatial-temporal properties to some extent. In addition, both synchronous and asynchronous relationships exist between the motion information of the video behavior and the scene information during an action, but TSN does not consider these relationships.

In view of the defects of the traditional fused spatial-temporal network and the spatial-temporal asynchronous relationship of actions, this paper adopts the TSN time domain segmentation strategy to sample segments randomly. The sampled RGB video frames and the optical flow fields before and after a period of time are sent into the spatial and temporal stream networks, respectively, to learn the spatial-temporal features of the video. Both the synchronous and the asynchronous spatial-temporal information are fused in the convolutional layer. The fused feature information contains the matching relationship between the motion sequence information and the single-frame image. The asynchronous spatial-temporal information of the video can then be effectively combined by performing long-term motion modeling and extracting temporal features with Bi-LSTM. The final classification is performed by the Softmax classifier to complete the action recognition. The proposed structure of the spatial-temporal fusion network model is shown in Figure 2.

The leftmost nine images in Figure 2 represent the input images in groups of three; the images within a group are identical, and the three images of a group represent the three channels of one frame. The three boxes that follow show the early-stage, mid-term, and late spatial-temporal fusion. The upper image of each box indicates the three-channel information of the frame image, and the lower gray image shows the two directions of the stacked optical flow field corresponding to that frame's time.

**Figure 2.** The proposed structure of the spatial-temporal fusion network.

The same video is input into the early-stage, mid-term, and late fusion branches to extract temporal and spatial features separately. The two kinds of features from each fusion branch are then fused, and the fused feature with the best effect is selected as the input of Bi-LSTM to further mine the asynchronous information.

The spatial-temporal fusion network adopts the sparse sampling strategy of the TSN framework, models the long-range temporal structure based on segmented sampling, and obtains video-level prediction results, thereby effectively learning the action model from the entire action video. Considering the spatial-temporal asynchronous relationship between motion and appearance, the asynchronous relationship is obtained, in addition to the synchronous one, by inputting asynchronous motion sequence graphs to achieve effective action recognition. The spatial-temporal information is extracted and fused in the convolutional layer of the two-stream network, avoiding the fully connected layer and effectively reducing feature loss. Inputting the fused feature maps into Bi-LSTM models the temporal relationship of the video segments and gives the correlation between preceding and following video frames. The network can thus extract and fuse spatial-temporal relations well to realize action recognition.

#### **4. A Human Action Recognition Method Based on Spatial-Temporal Fusion Network**

A sparse sampling method is used to sample the videos. An input video *V* is divided into *k* segments {*S*1, *S*2, ..., *Sk*} of equal duration, and a snippet {*T*1, *T*2, ..., *Tk*} is randomly sampled from the corresponding segments, as sketched below. The data are preprocessed to extract the RGB image of each snippet *Tk* and its corresponding *x*- and *y*-direction optical flow maps before and after a period of time, so the optical flow covers the frames before and after the image. In this way, not only the motion information of the current action but also its asynchronous motion information is considered.
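A minimal sketch of this segment-and-sample strategy (function and variable names are our own, not from the paper):

```python
# Split a video of num_frames frames into k equal segments and draw one
# random snippet start index from each segment (TSN-style sparse sampling).
import random

def sample_segments(num_frames: int, k: int) -> list[int]:
    seg_len = num_frames // k
    return [i * seg_len + random.randrange(seg_len) for i in range(k)]

print(sample_segments(num_frames=300, k=3))  # e.g., [37, 154, 281]
```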

To obtain each snippet, this paper adopts a sparse sampling method and utilizes the RGB image together with the *x*- and *y*-direction optical flow maps containing time series information, i.e., not only the short segment of optical flow corresponding to the image, but also the optical flow maps before, during, and after its motion. Assuming that the duration of the motion information is *L* frames, the input optical flow information comprises 2 × *L* channels and fully contains the asynchronous information of time and space.

For the *t*th video frame, its spatial features are extracted from the corresponding single RGB image. The extraction of motion information starts from the *t*th frame's optical flow map and then superimposes the horizontal and vertical optical flow fields of the *N* subsequent video frames; that is, the optical flow fields of [*t*, *t* + *N*] are input into the temporal convolutional neural network for analysis. Traditional two-stream methods merely extract one RGB image and one such set of optical flow features as input for each sampled snippet, but this is only effective for actions whose spatial and temporal information are synchronous. For some actions, the spatial and motion information are not always synchronous; one may lead or lag the other, i.e., they are asynchronous.

For example, consider the carrying behaviors in emergency rescue. When the motion extracted from the optical flow block is exactly the carrying process of the paramedics, a simple two-stream network using only the image and its optical flow cannot accurately distinguish emergency rescue carrying from ordinary carrying. However, in emergency rescue, behaviors such as simple treatment and moving the injured onto stretchers frequently occur before the carrying action, and professional rescue follows it. These actions either lead or lag the carrying action in the image. Thus, in emergency rescue, there is very often strong leading and lagging asynchronous information between the motion of the behavior and the image.

Traditional two-stream networks do not consider the asynchronous characteristics between spatial and motion information, so the proposed method considers the motion features that are asynchronous with the spatial information to obtain more spatial-temporal relationships for effective action recognition. In addition to extracting the *N*-frame optical flow field synchronized with the *t*th frame image, as in traditional methods, the presented method also extracts optical flow features asynchronously associated with the *t*th frame image. Taking Δ as the time interval, a total of [*t* − 2Δ, *t* + 2Δ + *N*] frames of optical flow fields are extracted, and the obtained time domain features are used as the input of the temporal stream network.
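A minimal sketch of this asynchronous window (names are illustrative; clipping to the video length is our assumption):

```python
# Frame indices whose optical flow feeds the temporal stream for a snippet
# at frame t: the interval [t - 2*delta, t + 2*delta + n], clipped to the video.
def flow_window(t: int, n: int, delta: int, num_frames: int) -> range:
    start = max(0, t - 2 * delta)
    stop = min(num_frames, t + 2 * delta + n + 1)  # range end is exclusive
    return range(start, stop)

print(list(flow_window(t=50, n=10, delta=5, num_frames=300))[:3])  # [40, 41, 42]
```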

The spatial-temporal fusion network enables action classification by extracting and fusing the spatial and temporal features of videos. The traditional two-stream network integrates the spatial-temporal information after the fully connected layer, which destroys feature information. To make good use of the pixel-level correlation between the extracted spatial and motion features and achieve full fusion, this paper compares different fusion methods and fuses the spatial-temporal features early, directly in the convolutional layer; the fused feature sequences contain the synchronous and asynchronous correlations between the motion sequence information and the single-frame image. The asynchronous relationship between the motion sequences and the spatial appearance differs across actions. Therefore, temporal modeling is still required for the mid-level spatial-temporal fusion feature maps after the motion and spatial information are fused. This paper introduces Bi-LSTM to extract temporal features from the fused feature sequences. This solves the problem that TSN does not consider the correlation of spatial-temporal information, without destroying the spatial-temporal characteristics, and further extracts the synchronous and asynchronous relationships between motion and scene information. The flowchart of the proposed spatial-temporal fusion network method is shown in Figure 3.

#### *4.1. Spatial Flow Convolutional Neural Networks*

Through the sparse sampling of the videos, a set of single-frame RGB images is obtained as the input of the spatial stream CNN, and the actions are discriminated by the static spatial appearance information. Static RGB images use three channels to store pixel information and represent shape. For actions with an obvious correlation between objects and scene information, such as bandaging, the actions can be classified through the scenes and objects in the video frame.

Actions and certain objects are inseparable in the videos; for example, bandaging a wound requires gauze. The spatial stream CNN extracts the spatial features of actions by identifying the background and shape information in the RGB frames. Therefore, the spatial stream network can effectively recognize video behaviors by directly using image classification networks.

**Figure 3.** The flowchart of the proposed spatial-temporal fusion network method.

Both the spatial and temporal streams use CNNs in the two-stream network architecture. The spatial stream CNN focuses on extracting features from the RGB image sequences, while the temporal CNN processes the optical flow information between adjacent frames [25]. The spatial and temporal convolutional neural networks adopt the same network structure. The CNN used in this paper is the VGG16 network [23], whose structure is shown in Figure 4.


**Figure 4.** The VGG16 Network Structure.

#### *4.2. Temporal Flow Convolutional Neural Networks*

The optical flow is utilized to measure the motion information about the video behaviors in the temporal network of the two-stream network [26]. A video is a combination of video frame images composed of continuous pixels. The optical flow method can detect the speed and direction of the target movement, which is judged by viewing the intensity changes of the pixels between continuous images. It expresses the changes in the images and contains the movement information about the targets, so the temporal features of the moving targets can be obtained [27].

In the temporal network, the input consists of 2*L* channels formed by stacking the *x*-direction horizontal optical flow maps $d^x$ and the *y*-direction vertical optical flow maps $d^y$ of *L* continuous video frames. Assuming that the video frame size is *w* × *h*, for any input video frame $\tau$, the input optical flow block $I_\tau \in \mathbb{R}^{w \times h \times 2L}$ of the temporal stream can be calculated by Equation (1).

$$\begin{aligned} I_{\tau}(u,v,2k-1) &= d_{\tau+k-1}^{x}(u,v), \\ I_{\tau}(u,v,2k) &= d_{\tau+k-1}^{y}(u,v), \\ u = [1;w],\ v &= [1;h],\ k = [1;L] \end{aligned} \tag{1}$$

where $I_\tau$ represents the stacked optical flow block, namely the input of the temporal network; $d^x_\tau$ and $d^y_\tau$ indicate the horizontal and vertical optical flow fields at time $\tau$, respectively; *L* is the number of video frames; and (*u*, *v*) indexes the pixel position.
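A minimal NumPy sketch of Equation (1), assuming `flow_x` and `flow_y` are sequences of per-frame flow maps (the names are ours, not the paper's):

```python
# Interleave the horizontal (d^x) and vertical (d^y) flow maps of L
# consecutive frames into one w x h x 2L input block I_tau.
import numpy as np

def stack_flow(flow_x, flow_y, tau: int, L: int) -> np.ndarray:
    w, h = flow_x[0].shape
    block = np.empty((w, h, 2 * L), dtype=np.float32)
    for k in range(1, L + 1):
        block[:, :, 2 * k - 2] = flow_x[tau + k - 1]  # channel 2k-1 (1-based)
        block[:, :, 2 * k - 1] = flow_y[tau + k - 1]  # channel 2k   (1-based)
    return block
```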

In addition to inputting the superimposed optical flow of the 2 × *L* (*L* = 10) continuous frames at the time of the motion, the proposed model also extracts the optical flow information of the 2*k* frames before and after as input to the temporal network. This is because the optical flow from the motion start frame (*t*<sub>0</sub>) to the motion end frame (*t*<sub>0</sub> + *L*) is the motion information synchronized with the spatial information. Beyond this, there is also motion information that is asynchronous with the spatial information, which is critical to action recognition. Therefore, the model also inputs the *k* frames before the start of motion (*t*<sub>0</sub> − *k*) and the *k* frames after the end of motion (*t*<sub>0</sub> + *L* + *k*) to the temporal network.

Spatial-temporal information is asynchronous. For the two actions of infusion and injection, the motion of touching the human body with a needle tube is very similar. If only the optical flow information at that moment determined the action category, similar motion sequences would easily lead to misjudgment. Therefore, in addition to the frame at the motion moment, the optical flow information before and after must be input to assist the judgment. For an action at time *t*, in addition to the RGB image and optical flow at time *t*, the optical flow information at times *t* − *k*, *t* + *L*/2, *t* + *L*, and *t* + *L* + *k* should be input. The presented two-stream network structure is shown in Figure 5.

**Figure 5.** The presented two-stream network structure.

#### *4.3. Spatial-Temporal Feature Learning and Fusion*

In a two-stream neural network, the spatial and temporal stream networks each obtain their own classification results before fusion produces the final identification. The fusion of spatial-temporal networks makes full use of the spatial and motion features of the videos and combines the correlation between them to judge different behavior types. For example, in the bandaging behavior, the spatial stream network can identify the shape information of the hands and triangular bandage, and the temporal stream network extracts the periodic action of the hands in a specific spatial position; combining both identifies the bandaging action. However, fusing category scores after the fully connected layer cannot achieve true correlation fusion. To fully exploit the connection between spatial and temporal properties, the fusion of the spatial and temporal streams needs to be studied thoroughly. To this end, this paper investigates three potential spatial-temporal fusion techniques.

(1) Early-stage fusion:

Early-stage fusion is performed before the network input by fusing the sparsely sampled single-frame RGB image with the *L*-frame stacked optical flow fields, i.e., the three-channel information of the frame image and the two directions of the optical flow field information are fused to form 3 + 2*L* channels and then input to the network to extract spatial-temporal features and achieve action classification. The early-stage fusion process is shown in Figure 6.

**Figure 6.** Early-stage fusion process.
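A minimal sketch of this channel concatenation (shapes are illustrative):

```python
# Early-stage fusion: concatenate the 3 RGB channels with the 2L stacked
# flow channels into one (3 + 2L)-channel input tensor.
import torch

L = 10
rgb = torch.randn(1, 3, 224, 224)        # single sampled frame
flow = torch.randn(1, 2 * L, 224, 224)   # stacked x/y flow fields
fused_input = torch.cat([rgb, flow], dim=1)
print(fused_input.shape)                 # torch.Size([1, 23, 224, 224])
```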

#### (2) Mid-term fusion:

Mid-term fusion is fusion inside the network. The single-frame image and the stacked optical flow field of its corresponding time are sent as input to the spatial and temporal stream networks, respectively, and the spatial and motion features of the video are extracted by the multi-layer convolutional layers. The extracted spatial-temporal features are fused in the convolutional layers to generate the feature maps and spatial-temporal feature vectors, and then the classifier is used to classify the actions.

Figure 7 shows the mid-term fusion process. Mid-term fusion mainly includes summation fusion, maximum fusion, and mean fusion.

**Figure 7.** Mid-term fusion process.
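A minimal sketch of the three mid-term fusion rules on two same-shaped convolutional feature maps (e.g., Conv3 outputs of the two streams; shapes are assumptions):

```python
# Element-wise fusion of spatial- and temporal-stream feature maps.
import torch

f_spatial = torch.randn(1, 256, 56, 56)
f_temporal = torch.randn(1, 256, 56, 56)

sum_fused = f_spatial + f_temporal                 # summation fusion
max_fused = torch.maximum(f_spatial, f_temporal)   # maximum fusion
mean_fused = (f_spatial + f_temporal) / 2          # mean fusion
```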

(3) Late fusion:

Most of the fusion methods adopted by the traditional two-stream networks are late fusion. After the video information is input to the spatial and temporal flow networks, the corresponding category scores are obtained through feature extraction, and the two scores are directly fused to obtain the final recognition results, as shown in Figure 8.

**Figure 8.** Late fusion process.

Assume that $f_{st}$ and $f_{tp}$ are the feature vectors extracted from the spatial and temporal stream CNNs, respectively. The Softmax classifier computes the category scores as shown in Equations (2) and (3).

$$p(j|f_{st}) = S_{st}^{j} = \frac{\exp(\theta_{st}^{j} \cdot f_{st})}{\sum_{j'=1}^{n} \exp(\theta_{st}^{j'} \cdot f_{st})} \tag{2}$$

$$p(j|f_{tp}) = S_{tp}^{j} = \frac{\exp(\theta_{tp}^{j} \cdot f_{tp})}{\sum_{j'=1}^{n} \exp(\theta_{tp}^{j'} \cdot f_{tp})} \tag{3}$$

where $\theta_{st}^{j}$ and $\theta_{tp}^{j}$ represent the Softmax classifier parameters of the spatial and temporal stream CNNs, respectively, and $p(j|f_{st})$ and $p(j|f_{tp})$ denote the posterior probabilities that $f_{st}$ and $f_{tp}$ belong to the *j*th category [28].
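A minimal NumPy sketch of the per-stream Softmax posterior in Equations (2) and (3) (shapes are illustrative):

```python
# Softmax posterior p(j|f) for one stream: theta holds one weight vector
# per category, f is the stream's feature vector.
import numpy as np

def softmax_scores(theta: np.ndarray, f: np.ndarray) -> np.ndarray:
    logits = theta @ f                 # theta: (n_classes, dim), f: (dim,)
    e = np.exp(logits - logits.max())  # subtract max for numerical stability
    return e / e.sum()                 # probabilities over the n categories
```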

The obtained spatial-temporal high-level features are fused in the convolutional layer to form a spatial-temporal fusion feature map. In this way, pixel-level fusion can be realized directly without passing through the fully connected layer, and the spatial-temporal correlation information can be extracted to achieve full fusion without damaging the features.

The proposed model changes the traditional fusion approach by fusing in the convolutional layer and considers both the synchrony and asynchrony of spatial-temporal information, fully fusing the spatial-temporal features without destroying them. The input of the original two-stream network is a moving image and a motion sequence, which considers only the spatial-temporal synchronous relationship. The presented method adds another motion sequence to the input of the temporal stream to extract the motion information that is asynchronous with the moving image and thereby model long-term motion.

#### *4.4. Bi-LSTM Time-Series Feature Learning Network*

The Bi-LSTM network in Figure 3 extracts the fused temporal features, which contain the matching relationship between motion information and multi-frame images. In the matching process, the asynchronous relationships between motions and images differ across behavior categories, so further deep learning of the fused temporal features is required after the motion and image features are fused. To further mine the synchronous and asynchronous information of the spatial-temporal networks, this paper introduces the Bi-LSTM network to construct the long-term motion model of the fused sequences. Bi-LSTM is well suited to modeling temporal data: the output at a certain moment depends on the video frame information both before and after it, which matches the asynchronous relationships of video actions and fully exploits the temporal information. Compared with LSTM, Bi-LSTM can capture stronger temporal information and effectively integrate the asynchronous video information by learning from both the preceding and following video frames (Algorithm 1). The pseudo-code of the Bi-LSTM temporal feature learning network algorithm is given as Algorithm 1, and the meaning of each symbol in the pseudo-code is shown in Table 2.
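Algorithm 1 itself is not reproduced in this reprint. As a rough illustration only (not the authors' Algorithm 1; the layer sizes and the 12-class output are assumptions based on the dataset described in Section 5), a Bi-LSTM head over fused feature sequences might look like this:

```python
# Bi-LSTM over a sequence of fused spatial-temporal feature vectors,
# followed by a linear layer producing class scores for Softmax.
import torch
import torch.nn as nn

class BiLSTMHead(nn.Module):
    def __init__(self, feat_dim=512, hidden=256, num_classes=12):
        super().__init__()
        self.bilstm = nn.LSTM(feat_dim, hidden, batch_first=True,
                              bidirectional=True)
        self.fc = nn.Linear(2 * hidden, num_classes)

    def forward(self, x):           # x: (batch, time, feat_dim)
        out, _ = self.bilstm(x)     # out: (batch, time, 2 * hidden)
        return self.fc(out[:, -1])  # classify from the final time step

seqs = torch.randn(8, 20, 512)      # fused feature sequences for 8 videos
print(BiLSTMHead()(seqs).shape)     # torch.Size([8, 12]) class scores
```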



**Table 2.** The meaning of each symbol in the pseudocode.


#### **5. Experiment Results**

#### *5.1. Datasets*

This paper studies action recognition over whole emergency rescue video sequences, i.e., the classification of emergency rescue actions for entire videos, so the annotation file takes the form {*video*, *action_id*}, where *action_id* is the specified action for each *video*. The experiments mainly adopt an improved self-built emergency rescue dataset and the publicly available UCF101 dataset.

#### 5.1.1. An Improved Emergency Rescue Dataset

A review of the available literature shows that there are few action datasets for emergency rescue scenes. Existing datasets for spatial-temporal action recognition usually provide sparse annotations for composite actions in brief video clips. The emergency rescue dataset used in the experiments is an improved version of the authors' earlier work [4].

The video dataset of spatiotemporally localized Atomic Visual Actions (AVA) densely annotates 80 atomic visual actions in 430 15-minute video clips, where actions are localized in space and time, resulting in 1.58M action labels, with multiple labels per person occurring frequently. The AVA dataset defines atomic visual actions and uses movies to gather a varied set of action representations. This departs from existing datasets for spatial-temporal action recognition, which typically provide sparse annotations for composite actions in short video clips. With its realistic scene and action complexity, AVA exposes the intrinsic difficulty of action recognition. Since the dynamic emergency rescue scenes to be identified contain many people and multiple actions, the self-built dataset of the literature [4] was built with reference to AVA.

The dataset divides the actions in dynamic emergency rescue scenes into daily actions and medical rescue actions, including carrying, cardiopulmonary resuscitation (CPR), bandage, infusion, injection, oxygen supply, standing, walking, running, lying, sitting, and crouching/kneeling.

We collected various videos of emergency rescue scenes from video websites such as YouTube, Tencent Video, and Bilibili, and cut out the segments related to emergency rescue operations using the video tool FFmpeg (a sketch follows Figure 9), yielding a total of 700 video segments. In addition to the videos collected in the literature [4], the daily actions also use some segments of the KTH public dataset [29]. To increase the recognition accuracy of small targets at long range, some small-target data are also added to the dataset. Some examples of the dataset are shown in Figure 9.

**Figure 9.** Some examples of the dataset.
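An illustrative way to cut such a segment with FFmpeg from Python (not the authors' exact workflow; the file names and timestamps are hypothetical):

```python
# Trim one rescue-related segment out of a collected video without re-encoding.
import subprocess

subprocess.run([
    "ffmpeg", "-i", "source_video.mp4",
    "-ss", "00:01:05", "-to", "00:01:35",  # segment start and end times
    "-c", "copy", "clip_0001.mp4",         # copy streams without re-encoding
], check=True)
```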

A bounding box is used to locate each person and his or her actions. For each piece of video data, keyframes are extracted, and human-centered annotation is performed. In each keyframe, each person is marked with the preset action vocabulary of this paper and may have multiple actions.

To improve the annotation efficiency, Faster R-CNN is used to detect the position of the person, and Via software is utilized to annotate actions and people. In the stage of action annotation, this paper deletes all incorrect bounding boxes and adds missing bounding boxes to ensure high accuracy. During the labeling stage, each video clip is annotated by three independent annotators to ensure the accuracy of the dataset as much as possible.

Because all actions of all people in all keyframes are marked, most person bounding boxes have multiple labels, which naturally leads to class imbalance between action categories. Compared with daily actions, there are fewer medical actions. Following the characteristics of the AVA dataset, this paper runs the identification model on the actions as they are rather than on manually constructed, balanced datasets. For the actions annotated in the self-built dataset, the frequency distribution of the action categories is shown in Figure 10.

**Figure 10.** Action category frequency distribution in the self-built dataset.

The number of manually annotated samples is small. To prevent overfitting and increase sample diversity, this paper collects action sequence video data as the emergency rescue human action data for annotation and then expands the data through data augmentation.

Data augmentation can address sample class imbalance and prevent overfitting in neural networks. The preprocessing includes short-edge resizing as well as standard operations. During model training, image augmentation methods of crop sampling, translation transformation, and random flipping are applied to the images.
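A minimal torchvision sketch of the augmentations named above (parameter values are assumptions):

```python
# Training-time augmentation: crop sampling, translation, random flipping.
from torchvision import transforms

train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),                         # crop sampling
    transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),  # translation
    transforms.RandomHorizontalFlip(),                         # random flipping
    transforms.ToTensor(),
])
```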

The dataset has a great influence on the experiment results. To ensure the accuracy of the dataset, the paper trains the model on both the self-built and UCF101 datasets, and chooses the model with the best result to further test and adjust the dataset.

#### 5.1.2. UCF101 Dataset

The mainstream action recognition dataset UCF101 has 13,320 videos from 101 action categories. The action categories include human–object interaction, human–human interaction, playing musical instruments, body motion only, and sports. Since most earlier action recognition datasets are unrealistic and staged by actors, UCF101 aims to encourage further research on action recognition by learning and exploring new, realistic action categories. The database consists of realistic user-uploaded videos containing camera motion and cluttered backgrounds, making UCF101 one of the most challenging action datasets.

#### *5.2. Metrics*

This paper uses accuracy to evaluate the recognition results and visualizes the recognition results by adopting a confusion matrix. A confusion matrix is a performance measurement for machine learning classification and is mainly used to count the number of predicted values in the wrong and right categories, respectively [30].

In the confusion matrix, rows represent actual values and columns represent predicted values, so the number of columns equals the number of rows. True Positive (TP) means a positive prediction that is correct. True Negative (TN) means a negative prediction that is correct. False Positive (FP) means a positive prediction that is wrong. False Negative (FN) means a negative prediction that is wrong. The four types of samples have no intersection, and the sum of TP, FP, TN, and FN is the total number of samples.

The larger the diagonal values of the confusion matrix (TP, TN), the higher the correct classification probability of the model and the better the model performance. For an ordinary binary task, the confusion matrix is shown in Table 3. It is a table with four different combinations of predicted and actual values.
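A minimal scikit-learn sketch of these metrics (the labels are a toy example, not the paper's data):

```python
# Overall accuracy plus a confusion matrix whose rows are actual classes
# and whose columns are predicted classes.
from sklearn.metrics import accuracy_score, confusion_matrix

y_true = [0, 0, 1, 1, 2, 2]   # actual action ids
y_pred = [0, 1, 1, 1, 2, 0]   # predicted action ids

print(accuracy_score(y_true, y_pred))    # 0.666...
print(confusion_matrix(y_true, y_pred))  # diagonal = correct predictions
```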

**Table 3.** Binary confusion matrix.

| | Predicted Positive | Predicted Negative |
| --- | --- | --- |
| Actual Positive | TP | FN |
| Actual Negative | FP | TN |

#### *5.3. Analysis of the Experiment Results*

To isolate the effectiveness of the spatial-temporal asynchronous fusion network, the VGG16 network structure is adopted in the fusion network.

(1) Analysis of experiment results of spatial-temporal asynchronous information

The model inputs a total of 2(*L* + 2*k*) (*L* = 10) superimposed optical flows of successive frames, including synchronous and asynchronous information on the spatial and motion sequences. The effect of different values of *k* on the model is studied, and the results are shown in Table 4.

**Table 4.** Effects of taking different values of *k* on the recognition results.


The result is best when *k* is 10, that is, a superimposed optical flow field from (*t*<sub>0</sub> − 20) to (*t*<sub>0</sub> + 20) over 40 consecutive frames. When *k* is 15 or 20, too much confusing information is included, which easily decreases the recognition accuracy.

When training the spatial-temporal fusion network, the input of the image recognition network is a static image frame at time *t*<sub>0</sub>. The input of the optical flow network is the superimposed optical flow fields from (*t*<sub>0</sub> − 20) to (*t*<sub>0</sub> + 20) centered on time *t*<sub>0</sub>, forming a 2 × 40 = 80-channel optical flow block, which is cut into 20-channel superimposed optical flow blocks by a sliding window with a step size of 5. Due to hardware memory limitations, the batch size used for network training is 8, which is equivalent to randomly sampling 8 static frames in each training step together with the optical flow fields of the 20 frames before and after each frame.
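A minimal NumPy sketch of this sliding-window split (array shapes are illustrative):

```python
# Cut an 80-channel stacked flow block into 20-channel blocks with stride 5,
# giving (80 - 20) / 5 + 1 = 13 sub-blocks.
import numpy as np

block = np.random.randn(80, 224, 224).astype(np.float32)
windows = [block[s:s + 20] for s in range(0, 80 - 20 + 1, 5)]
print(len(windows), windows[0].shape)  # 13 (20, 224, 224)
```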

Due to the strong spatial-temporal asynchrony of actions, this paper improves accuracy by fully utilizing and integrating the asynchronous spatial and temporal information. The effects of the spatial-temporal information on the recognition accuracy are shown in Table 5.

**Table 5.** Effects of spatial-temporal information on accuracy.


(2) Analysis of the experiment results of the fusion methods

In this paper, spatial-temporal features are fused in a convolutional layer inside the two-stream fused VGG16 model. In the experiments, the spatial-temporal feature fusion is performed on the convolutional layer Conv3 of the two-stream structure. The paper compares the separate spatial stream and temporal stream networks with the two-stream networks. The action recognition accuracy of the different fusion methods is shown in Table 6.

**Table 6.** Comparison of action recognition accuracy for different fusion methods.


As can be seen from Table 6, fusing the spatial-temporal asynchronous information in the convolutional layer Conv3 gives the best effect. Unlike fusion at the fully connected layer, which loses information, fusion in the convolutional layer not only better retains the mid-level spatial-temporal information but also obtains higher accuracy.

(3) Experiment comparison analysis of the action recognition for the spatial-temporal fusion CNN

To verify the effectiveness of the spatial-temporal fusion CNN in action recognition, the proposed method is compared with TSN, two-stream network, and two-stream network + Bi-LSTM methods, as shown in Table 7. It can be seen that the proposed method improves recognition accuracy.

**Table 7.** Comparison of recognition accuracy for different methods.


To give the model performance more intuitively, a confusion matrix is used to present the degree of confusion between the predicted and actual categories of the model, as shown in Figure 11.

**Figure 11.** Confusion matrix.

In Figure 11, C/K represents crouching/kneeling, and O-S indicates oxygen supply. As can be seen from the confusion matrix in Figure 11, among the confusable actions, carrying has a 7% probability of being identified as walking, because carrying overlaps to some extent with walking and running; other actions behave similarly. CPR has a 6% probability of being identified as C/K, since CPR in emergency rescue is mostly performed in a kneeling position. Injection and infusion can be confused with each other, i.e., 5% of injections are predicted as infusions and vice versa. Other actions are misclassified with small probability. Among daily human actions, the common action of walking has a 3% probability of being recognized as running, a 1% probability of being identified as standing, and a 4% probability of being recognized as carrying, because walking is easily misclassified as carrying when several people gather in one place. C/K has a 6% probability of being identified as sitting. The above are daily and medical rescue actions that are easily confused. Some misclassifications of easily confused actions remain, but they are reduced to some extent, and the model distinguishes confusable actions better.

The recognition results are visualized in Figures 12–14. In Figure 12, the optical flow captures the dynamic action sequence information of dressing a wound: although the injured man is sitting and relatively still, the action sequences of the ambulanceman capture the bandaging, so the recognition result is the bandage action. Figure 13 shows the recognition results when the persons in the video are all carrying.

**Figure 12.** Bandage action recognition results in a simpler background.

**Figure 13.** Carrying recognition results in a simpler background.

**Figure 14.** CPR recognition results in a more complex background.

For more complex backgrounds, when the persons in the video perform different actions, the recognition result is the action with the highest probability among all actions, i.e., the most important action being performed. Figure 14 gives the recognition results when the people in the video perform different actions (standing, C/K, and CPR). The dynamic action of the main person is CPR, so the recognition result is CPR, meaning that the main action executed in the video is CPR.

Moreover, to verify the effectiveness of the proposed spatial-temporal fusion model, experiments are also conducted on the mainstream dataset UCF101 to compare the proposed method with the classical and advanced methods.

This paper compares single-stream CNNs and various improved methods based on two-stream CNNs, including the C3D-based algorithm, the traditional recognition algorithm based on two-stream convolutional networks (Two-stream ConvNet), the Long-term Recurrent Convolutional Networks (LRCN)-based recognition algorithm, the fused two-stream network and LSTM recognition algorithm (Two-stream + LSTM), the recognition algorithm fusing the two-stream network and LSTM in the convolutional layer (Two-stream + LSTM + ConvFusion), the improved human action recognition algorithm fusing Spatial Transformer Networks (STN) and CNN [31], and the two-stream 3D ConvNet fusion action recognition algorithm [15]. The comparison results are shown in Table 8.

**Table 8.** The experiment comparison results on the UCF101.


The comparison results show that the proposed spatial-temporal fusion method has the best recognition effect; the method can accurately recognize human actions in the videos, which verifies its effectiveness.

In terms of speed, on the UCF101 public dataset, the time complexity of this method is determined by TSN and Bi-LSTM, with a running speed of 197.2 fps. The time complexity of C3D is determined by the convolutional layers, with a running speed of 313 fps. The time complexity of the two-stream network is also determined by the convolutional layers, with a running speed of 33.3 fps. The LRCN method has a simpler structure than the C3D network, has few parameters, is easy to train, and runs faster than C3D. The time complexity of the literature [32] is determined by the two-stream network together with LSTM, with a running speed of 29.7 fps. The time complexity of the literature [33], literature [31], and literature [15] is determined by the convolutional layers and LSTM. The method of [33] runs at two speeds: 6 fps when the input is an optical flow image and 30 fps when the input is an RGB frame. The time complexity of [31] is determined jointly by STN and CNN, with a running speed of 37 fps. The method of [15] also runs at two speeds: 186.6 fps when the input is an RGB frame and 185.9 fps when the input is an optical flow image.

As can be seen from Table 8, the speed of the proposed method is slower than both the C3D and LRCN methods, mainly due to the high time complexity caused by the introduction of Bi-LSTM. Compared with the other methods, the proposed method is fast. Overall, the proposed method has the highest recognition accuracy and relatively fast speed.

In recent years, human action recognition methods have focused on deep learning. At present, the latest methods mainly include the TS-PVAN action recognition model based on attention mechanism [34], skeleton-based ST-GCN for human action recognition with extended skeleton graph and partitioning strategy [35], human action recognition based on 2D CNN and Transformer [36], linear dynamical system approach for human action recognition with two-stream deep features [37], and hybrid handcrafted and learned feature framework for human action recognition [38]. Comparative analysis with the latest methods is as follows.

(1) The TS-PVAN action recognition model based on attention mechanism can adequately extract spatial features and possess certain generalization abilities, but the temporal network of the TS-PVAN cannot efficiently model long-range time structure and extract rich long-term temporal information. This paper introduces Bi-LSTM to model long-term motion and fully mine the long-term temporal information.

(2) The human action recognition method combined Skeleton-based ST-GCN with extended skeleton graph and partitioning strategy can extract the non-adjacent joint relationship information in the human skeleton images, and divide the input graph of Graph Convolutional Network (GCN) into five types of fixed length tensors by the partition strategy, to include the maximum motion dependency. However, this method does not consider the temporal features. The proposed method extracts temporal information using the temporal network of the two-stream network.

(3) 2D CNNs are among the mainstream methods for human action recognition at present. A 2D CNN-based framework not only is lightweight with fast inference but also operates on short snippets sparsely sampled from whole videos. However, 2D CNNs still suffer from insufficient representation of some action features and a lack of temporal modeling capability. The human action recognition method based on 2D CNN and Transformer adopts a 2D CNN architecture with a channel-spatial attention mechanism to extract spatial features within frames, utilizes the Transformer to extract complex temporal information between different frames, and improves the recognition accuracy. However, the Transformer extracts spatial and temporal features in sequential order, and as the number of frames increases, the number of parameters also increases substantially, burdening the computation. This paper adopts a two-stream network structure to extract the spatial-temporal information, so the spatial and temporal feature extraction proceeds in parallel, and the TSN sparse sampling strategy is used to avoid the greater computational burden caused by an increasing number of frames.

(4) The human action recognition method combined linear dynamical system approach with two-stream deep features captures the spatial-temporal features of human action using a dual-stream structure. The method operates directly on video sequences. The longer the video sequence is, the more time is consumed. The presented method adopts the time domain segmentation strategy for TSN to randomly sample fragments and speed up the operation.

(5) The human action recognition method based on a hybrid handcrafted and learned feature framework uses a two-dimensional wavelet transform to decompose video frames into separable frequency and directional components to extract motion information, and the dense trajectory method is used to extract feature points for tracking across continuous frames. However, this method can only deal with videos that have clear action boundaries, a limitation shared by the method proposed in this paper.

#### **6. Conclusions**

To address the problem that existing spatial-temporal fusion networks do not fully fuse the spatial and temporal information, which decreases human action recognition accuracy, this paper proposes a dynamic-scene human action recognition method based on a spatial-temporal fusion network model. Considering the strong asynchrony and temporal ordering of video actions, a spatial-temporal feature asynchronous fusion framework is designed to extract spatial and asynchronous temporal features for fusion. Bi-LSTM is utilized to fully extract temporal information, capture video-level motion information, and fuse the spatial-temporal information, and human action recognition is realized by Softmax. The presented method can model long-term motion by modeling the asynchronous relationship between the moving image and the motion sequence (optical flow). The proposed method is still far from practical application, so its robustness and real-time performance can be further studied in the future.

**Author Contributions:** Conceptualization, methodology, data curation, validation, writing—original draft preparation Q.G., Y.Z., Z.D. and A.W.; writing—review and editing, Z.D.; supervision, project administration, Y.Z.; funding acquisition, Y.Z. and A.W. All authors have read and agreed to the published version of the manuscript.

**Funding:** This paper was funded by National Key Research and Development Program Project (2020YFC0811004), National Natural Science Fund of China (61371143), and R&D Program of Beijing Municipal Education Commission (KM202110009002).

**Data Availability Statement:** The raw data can be provided upon request.

**Conflicts of Interest:** The authors declare no conflict of interest.

#### **References**


**Disclaimer/Publisher's Note:** The statements, opinions and data contained in all publications are solely those of the individual author(s) and contributor(s) and not of MDPI and/or the editor(s). MDPI and/or the editor(s) disclaim responsibility for any injury to people or property resulting from any ideas, methods, instructions or products referred to in the content.

### *Article* **Vehicle Classification Using Deep Feature Fusion and Genetic Algorithms**

**Ahmed S. Alghamdi <sup>1</sup>, Ammar Saeed <sup>2</sup>, Muhammad Kamran <sup>1,</sup>\*, Khalid T. Mursi <sup>1</sup> and Wafa Sulaiman Almukadi <sup>1</sup>**


**Abstract:** Vehicle classification is a challenging task in the area of image processing. It involves the classification of various vehicles based on their color, model, and make. A distinctive variety of vehicles belonging to various model categories have been developed in the automobile industry, which has made it necessary to establish a compact system that can classify vehicles within a complex model group. A well-established vehicle classification system has applications in security, vehicle monitoring in traffic cameras, route analysis in autonomous vehicles, and traffic control systems. In this paper, a hybrid model based on the integration of a pre-trained Convolutional Neural Network (CNN) and an evolutionary feature selection model is proposed for vehicle classification. The proposed model classifies eight different vehicle categories, including sports cars, luxury cars, and hybrid power-house SUVs. The dataset used in this work is derived from the Stanford Cars dataset, which contains 196 classes of cars and other vehicles. After appropriate data preparation and preprocessing steps, feature learning and extraction are first carried out using a pre-trained VGG16, which learns and extracts deep features from the set of input images. These features are then taken from the last fully connected layer of VGG16, and a feature optimization phase is carried out using the evolution-based, nature-inspired Genetic Algorithm (GA). Classification is performed using several SVM kernels, where Cubic SVM achieves an accuracy of 99.7%, outperforming the other kernels and excelling in performance compared with existing works.

**Keywords:** convolutional neural network; fused deep learning; vehicle classification

#### **1. Introduction**

The evolution of the modern era has had a significant impact on the automobile industry, which has progressed rapidly. Nowadays, vehicles from the same companies are released with various colors, models, and physical attributes, making it difficult to differentiate them without prior knowledge of those models; this makes developing a system that can perform vehicle classification an even bigger challenge. The emerging concept of smart cities relies on an intelligent traffic monitoring and classification system that can detect and surveil different vehicles for traffic rule violations, security, and emergency situations [1]. With the ever-increasing demand, production, and usage of vehicles of all makes, colors, and models, it becomes very difficult for a human agent to perform vehicle monitoring, record keeping, surveillance, and detection of any kind of violation [2]. Therefore, establishing an automated system that can discriminate between various vehicle types is necessary. Such a model could have applications in security, smart traffic systems, self-driving vehicles for environmental understanding and collision avoidance, criminal activity reduction, and vehicle-type detection [3].

An intelligent traffic system could also assist in crime reduction and criminal activity tracking, given that most criminal activities involve the use of some kind of vehicle for movement. In such cases, vehicle data could be obtained from an ITS (Intelligent Transport System) to help with criminal tracking [4].

**Citation:** Alghamdi, A.S.; Saeed, A.; Kamran, M.; Mursi, K.T.; Almukadi, W.S. Vehicle Classification Using Deep Feature Fusion and Genetic Algorithms. *Electronics* **2023**, *12*, 280. https://doi.org/10.3390/ electronics12020280

Academic Editor: Gwanggil Jeon

Received: 15 December 2022 Revised: 29 December 2022 Accepted: 2 January 2023 Published: 5 January 2023

**Copyright:** © 2023 by the authors. Licensee MDPI, Basel, Switzerland. This article is an open access article distributed under the terms and conditions of the Creative Commons Attribution (CC BY) license (https:// creativecommons.org/licenses/by/ 4.0/).

The main focus and purpose of this work is to create an automated, self-contained, and intelligent computerized vision-based system that can differentiate between various vehicle categories with high precision and accuracy. Such a system would have significant applications in traffic monitoring, smart-city traffic control, security, and automatic vehicle detection in drones and self-driving cars. Until now, many conventional and handcrafted means have been used for vehicle discrimination through color, vehicle structure, or model. This identification process produces decent results on the targeted data for which the approach is implemented, but its functionality becomes limited, and the classification accuracy gets very low, when the data perspective changes and varying data are used [5]. Therefore, the latest deep learning and machine learning models are being used for the development of such automatic traffic classification systems. Convolutional neural networks (CNNs) are widely used for this purpose; they can be trained on large datasets initially and then used for more narrowly defined tasks. These deep learning CNN models are much better for classification than the conventional handcrafted methods; they comprise a huge number of deep convolution layers that allow better and deeper learning on arbitrary datasets [6]. Considering all these prospects, the automated vehicle detection and classification system developed in this work is based on a CNN model together with appropriate optimization and classification methods. In this paper, an amalgamation scheme, based on a genetic algorithm (GA) and the pre-trained CNN VGG16, is proposed for automated vehicle classification. The dataset used in the project comprises five vehicle classes (i.e., bike, car, bus, truck, and helicopter), whose images are resized to maintain uniformity in image dimensions. Next, these images are given as input to the deep convolutional network, VGG16, to perform feature extraction and learning within its deep layers. The extracted features are optimized and reduced using a GA that evolves over iterations and looks for the most relevant solution points based on priority while discarding the others. The selected features are then passed on to the classification learner, where they are classified with multiple Support Vector Machine (SVM) classifier variations. The experimental results showed that the proposed model, using the linear SVM (L-SVM) classifier, achieved an accuracy level of 97.8%, outperforming other kernels and previous works. In a nutshell, the main contributions of the proposed work include the following:


The rest of the paper is organized as follows. Related work is discussed in Section 2. Section 3 presents a detailed description of the proposed methodology. The details of the experiments and results are given in Section 4, and Section 5 concludes the paper.

#### **2. Related Work**

Previous studies have proposed different vehicle classification systems, depending on their datasets. Molina-Cabello et al. [7] used a dataset comprising cars, trucks, and bikes and sequences obtained from the next generation simulation (NGSIM) program provided by the highway authorities. Image visual quality was enhanced using single-image super-resolution and median filter transformation. AlexNet was employed for feature learning in

this scheme. The classification phase was performed using various classifiers, including a multilayer perceptron, an SVM, a naïve Bayes (NB), decision trees, and random forests. The model achieved a maximum accuracy of 91.5%. Oh and Ritchie [8] proposed a vehicle classification method based on loop signatures, also known as "blades." The signatures of different vehicles, including cars, pickup trucks, SUVs, and vans, were obtained manually through blade sensors that had been installed along various parkways; the dataset was categorized into five divisions, each division containing 60 vehicles. A probabilistic neural network (PNN) was employed, based on a Bayesian classifier for vehicle data classification. The proposed model was evaluated using the correct classification rate, in this case, 75%.

He, Shao, and Tan [9] used a dataset consisting of 1196 car images with a frontal perspective covering 30 standalone car models in 12 of their makes. Images were enhanced using illumination normalization and a multiscale retinex. A part-based detection model was used to segment various regions of the vehicles and parts by their importance, namely, headlights, logos, and grills, separated by using ROIs (Regions of Interest), making it easy to classify different models of the same vehicle. Local Binary Patterns (LBP) and Histogram of Oriented Gradients (HOG) feature extractors were used to extract geometrical and textural features from the defined image regions. The classification stage of the proposed model was performed by various classifiers, where the maximum detection accuracy, 94.8%, was achieved by the AdaBoost classifier. Psyllos et al. [10] also used a frontal-view vehicle image dataset for vehicle manufacturer and model recognition. The vehicle logos, license plates, and headlight grills were segmented using masking and Phase Congruency Detection. Feature extraction, learning, and classification were performed using a PNN comprising input, radial, and output layers that categorized the input patterns into pre-allocated classes. The proposed model achieved an accuracy of 94%. Sheng et al. [11] used the Stanford dataset and formulated six classes: Volkswagen, Audi, Chevrolet, BMW, Mercedes-Benz, and Ford. The work focused on the classification of vehicle type and area detection for a particular vehicle. The experiments were performed with six CNNs—AlexNet, VGG16, VGG19, GoogleNet, ResNet50, and ResNet101—for both an RCNN and a faster RCNN. The proposed model provided an average accuracy of 93.32% when discriminating six vehicle types.

Soon et al. [12] used the BIT vehicle dataset containing 9850 vehicle images in six vehicle categories: bus, minivan, microbus, sedan, SUV, and truck. Images were preprocessed to eliminate those showing more than one vehicle. The image count after preprocessing in each vehicle class was 558 buses, 883 microbuses, 476 minivans, 5922 sedans, 1392 SUVs, and 822 trucks. A novel principal component analysis (PCA) convolutional network was proposed in which the massive time consumption of the CNN was resolved by composing the convolutional layer filters of the CNN with the help of PCA. This process reduced the training burden and produced flexible features against various aspects. The proposed model yielded an average accuracy above 88.35% in various conditions. Mundhenk et al. [13] compiled a Cars Overhead with Context (COWC) dataset containing 32,716 images from six image classes obtained from various geographic regions. A CNN named ResCeption was proposed, using AlexNet as a baseline and GoogleNet/Inception as a batch normalizer. The proposed model achieved an average accuracy of 97.294%.

Divyavarshini et al. [14] constituted a custom CNN model comprising 25 max-pooling layers for vehicle type recognition. Feature extraction was performed by the proposed CNN model and also by the handcrafted HOG feature extractor. The resultant vectors from both models were fused and classified using an SVM classifier. Ahsan et al. [15] proposed a CNN-based model for vehicle number plate detection and processing. The model takes the image captured by a digital camera and uses the super-pixel resolution method to enhance the image quality. All the characters embedded in the number plate are segmented using bounding boxes. The pre-trained AlexNet is then used to derive 4096 features from the segmented number images, and a maximum detection accuracy of 98.2% is achieved.

Dai et al. [16] improved the pre-trained ResNet-50 model and formulated a faster R-CNN architecture for vehicle distance estimation and pedestrian detection. Real-time images were acquired using infrared cameras covering long-distance roads containing pedestrians, together with tagged values. The model runs at a frame rate of 7 fps and provides an accuracy of 80% on the real-time data.

Some recent works on the classification and detection of vehicles, such as [17–19], also motivated us to investigate deep learning techniques along with evolutionary techniques for intelligent transport systems.

In contrast with the techniques mentioned above, we propose an amalgamation scheme based on a GA and the pre-trained CNN VGG16 for automated vehicle classification. We used the deep convolutional network VGG16 for feature extraction and learning within its deep layers. The extracted features were optimized and reduced using a GA that evolved over iterations, looked for the most relevant solution points based on priority, and discarded the others. The selected features were then passed on to the classification learner, where they were classified with multiple SVM classifier variations.

#### **3. Proposed Methodology**

#### *3.1. Deep Feature Fusion and Genetic Algorithms*

Deep learning models have the tendency to extract deep features from the input data given to them. These extracted features have complex dimensions, are in the form of vectors, and contain most of the information derived from the input data. DL models can be of two types: pre-trained CNN models, such as AlexNet, GoogleNet, and ResNet50, or custom CNN models. The pre-trained CNN models are already trained on massive data compilations and can therefore provide good results in most cases. However, custom CNN models need to be trained well before testing them on real datasets. There are certain cases in which even a single pre-trained model does not provide good results. This can happen when the dataset is not well prepared or when multiple datasets are involved. In such cases, it is better to merge the features learned by two separate CNN models using the methods of transfer learning and feature fusion. The learned features, extracted from the last fully connected layers or pooling layers of the respective CNN models, are merged to formulate a compact combined feature vector. Later, when these merged features are provided to machine learning-based classifiers, a significant increase in performance is observed in most cases.

This feature fusion can help increase the performance of the proposed model, but it leads to feature complexity and model entanglement. In most cases, the model provides better results after feature fusion, but at the cost of increased time. This creates the need for a feature selector or an optimization algorithm that can reduce the complexity of these fused features while maintaining the important details contained within them. Several nature-inspired evolutionary models as well as mathematically formulated feature optimization models exist for this purpose. The Genetic Algorithm (GA) is a metaheuristic algorithm inspired by the process of natural selection, and it belongs to the larger class of evolutionary algorithms. It calculates high-quality and globally optimal solutions for the problem space provided to it. It contains populations of individuals that are the possible solutions for a given problem in a search space. These possible solutions, also termed chromosomes, spread through the problem space and find the nearest optimal solution while keeping the other chromosomes updated with their status. The chromosome nearest to the optimal solution is the best solution and is selected for problem solving; it is then reproduced by the crossover process to generate its offspring. This process is followed by altering some of the genes in the mutation phase, and finally the initial population is encoded. This process continues until all of the problem space is explored.
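To make the fusion step concrete, the following is a minimal NumPy sketch (an illustration, not code from the paper) in which feature matrices from two hypothetical CNN backbones are concatenated per image; the array shapes are placeholders.

```python
import numpy as np

# Stand-ins for deep features extracted from the last fully connected /
# pooling layers of two different CNN backbones (one row per image).
feats_a = np.random.rand(8000, 1000)   # e.g., a 1000-D fc-layer vector per image
feats_b = np.random.rand(8000, 2048)   # e.g., a 2048-D pooled vector per image

# Feature fusion: concatenate the two vectors into one compact combined vector.
fused = np.concatenate([feats_a, feats_b], axis=1)
print(fused.shape)                     # (8000, 3048)
```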

#### *3.2. Proposed Work*

In the proposed work, a deep CNN and a natural evolution-based GA are combined to formulate an automated classification system for eight different vehicle categories. The model initially uses the deep CNN VGG16, whose deep layers perform feature extraction and learning. These features are extracted from the last fully connected layer of the CNN; since they are massive in number and may contain ambiguous information that affects the results, an evolutionary feature selector, GA, is employed to keep the most relevant features and discard the others. The classification phase is performed with several SVM kernels to see which performs best on the nature of the current data. The proposed model workflow is illustrated in Figure 1.

#### **Figure 1.** Proposed model.

#### *3.3. Data Acquisition and Preprocessing*

The dataset used in the proposed work is derived from the publicly available Stanford car dataset. The original dataset contains 196 different vehicle classes and over 8800 images. We selected only eight distinct vehicle classes, each containing approximately 45 images. The selected vehicle classes are based on images from some famous brands, including Acura, Audi, Bentley, BMW, Chevrolet, Dodge, Hyundai, and Tesla. The dataset is passed through several augmentation steps, including image flipping and rotation, to increase the number of images in each class and balance the classes; the post-augmentation dataset contains 1000 images per vehicle class. The final dataset contains a total of 8000 images divided among eight vehicle classes, as shown in Figure 2.

The images were also resized, as VGG16 accepts images with dimensions of 224 × 224, and to create uniformity among images so that the results are not affected by varying sizes. All the dataset images are resized to 224 × 224 before being passed on to the feature extraction stage. Since the Stanford dataset contains images captured in real time through RGB cameras as well as black-and-white CCTV cameras, some of the images are not in the RGB channel. Therefore, during preprocessing, it is checked whether images are in RGB or some other color channel, and those not in RGB are converted into RGB using an RGB color map. To enhance the image quality and contrast, a Guided Filter is also applied. The preprocessing steps are elaborated in Figure 3.
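A minimal Python/OpenCV sketch of these preprocessing steps is shown below; the function name and the guided-filter parameters are illustrative assumptions, since the paper does not report exact settings (`cv2.ximgproc.guidedFilter` requires the opencv-contrib package).

```python
import cv2

def preprocess(path):
    """Resize to 224x224, force three-channel color, and smooth with a guided filter."""
    img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
    if img.ndim == 2:                                  # grayscale CCTV frame
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
    elif img.shape[2] == 4:                            # drop an alpha channel if present
        img = cv2.cvtColor(img, cv2.COLOR_BGRA2BGR)
    img = cv2.resize(img, (224, 224))
    # Edge-preserving smoothing; the radius/eps values are illustrative, not from the paper.
    img = cv2.ximgproc.guidedFilter(guide=img, src=img, radius=4, eps=100)
    return img
```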


**Figure 2.** Sample images from the used dataset.

**Figure 3.** Preprocessing.

#### *3.4. Feature Extraction*

The resized images were passed on to the deep convolutional network model VGG16 for feature extraction and learning. VGG16 is a 16-layer-deep convolutional network trained on the massive ImageNet database that can discriminate among 1000 object categories. The input given to it is 224 × 224 ("VGG-16 Convolutional Neural Network-MATLAB VGG16" n.d.). It contains five combinations of convolutional layers in the form of batches, each containing 2 to 3 convolutional layers, followed by pooling layers, as shown in Figure 4 [20].

**Figure 4.** VGG16 CNN Architecture.

A total of 8000 images were provided as input to the VGG16 model, which performed feature extraction and learning in its deep layers. Table 1 shows the details of the extracted features. The features were taken from the last fully connected layer of the VGG16 model, fc8; the SoftMax and classification layers were not used in this case; instead, the features were optimized first using the GA optimizer and then classified using the SVM classifier.
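The paper performs this step in MATLAB; as a hedged illustration, the following PyTorch/torchvision sketch extracts the 1000-D output of the final fully connected layer ("fc8") of a pre-trained VGG16. The function name and normalization constants are standard torchvision conventions, not details taken from the paper.

```python
import torch
from torchvision import models, transforms
from PIL import Image

# VGG16 pre-trained on ImageNet; its last fully connected layer ("fc8")
# outputs a 1000-D vector, which is used here as the deep feature.
model = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1).eval()

tfm = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

@torch.no_grad()
def fc8_features(path: str) -> torch.Tensor:
    x = tfm(Image.open(path).convert("RGB")).unsqueeze(0)  # (1, 3, 224, 224)
    return model(x).squeeze(0)                             # (1000,) fc8 output
```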

**Table 1.** Overview of extracted features.


#### *3.5. Feature Selection*

The extracted features from the VGG16 model were obtained from its last fully connected layer, fc8, and were given as input to the GA for optimization and reduction. The GA calculates high-quality and globally optimal solutions for optimization problems. It contains populations of individuals based on chromosomes, which are the possible solutions for a given situation in the search space. Each chromosome indicates a candidate solution and is based on a list of variable values. For a problem having *K_N* possible solutions, each chromosome will be a list of *K_N* entries, as represented in Equation (1).

$$Chromosome = [s_1, s_2, s_3, \dots, s_{K_N}], \tag{1}$$

where each *s* represents a possible solution with regard to a particular chromosome, and there can be *S*(*K_N*) solutions. The GA begins with the selection of a random number of such chromosomes, which serve as the agents in the initial iteration.

The process of finding the optimal solution then initiates, and each population of chromosomes begins searching for a solution in the declared search space. Each chromosome in the population maintains a certain fitness value as it searches the space, and this value is evaluated for all chromosomes at the end of the initial iteration. From the population, some chromosomes are retained based on their fitness scores with a user-defined probability, while the rest are discarded. The fittest chromosomes are more likely to be chosen. The probability of a chromosome *C_xy* being selected, considering *g* as a positive function, is represented in Equation (2).

$$P(C_{xy}) = \left| \frac{g(C_{xy})}{\sum_{k=1}^{N} g(C_k)} \right|, \tag{2}$$

where *P*(*C_xy*) represents the selection probability of a random chromosome from the initial population, *N* represents the total population, and *C_k* denotes the *k*th chromosome. The next phase comprises crossover of the selected fittest chromosomes to increase the population of solution-finding agents. For this, a pair of chromosomes with the highest fitness scores is chosen, and offspring are generated from them. The crossover operation is demonstrated in Equation (3).

$$p_{cr} = \begin{cases} p_{max-cr} - \left( \dfrac{p_{max-cr} - p_{min-cr}}{T_n} \right), & f_x \ge f_{avg} \\ p_{max-cr}, & f_x < f_{avg} \end{cases} \tag{3}$$

where *p_cr* is the crossover probability, *p_max−cr* and *p_min−cr* are the maximum and minimum crossover prospects, *T_n* is the maximum possible number of iterations, *f_x* is the fitness of the fitter of the two chromosomes selected for crossover, and *f_avg* denotes the average fitness value of the whole population.

After the process of crossover, the final step of mutation is performed, in which the genes of the newly formulated offspring are altered with already available information to make them more effective. If this set of newly formed offspring provides the optimal solution, then the process is terminated; otherwise, the process is repeated until it does.

For problem-solving, the best individual was selected by the GA, which was then reproduced by the crossover process to generate its offspring. This process was followed by altering some of the genes in the mutation phase and encoding the initial population [21]. Table 2 shows the number of features selected by the GA. In this work, the number of chromosomes is kept at 10, the number of iterations is kept at 100, and the learning rate is 0.001. For the GA, 80% of the data is kept as the training set and 20% as the testing set.
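The following is a minimal Python sketch of GA-based feature selection in the spirit described above: chromosomes are binary feature masks, fitness is the cross-validated accuracy of an SVM on the selected features, and elitist selection, one-point crossover, and bit-flip mutation drive the evolution. The population size and iteration count follow the values stated above; all other parameters and function names are illustrative assumptions.

```python
import numpy as np
from sklearn.svm import SVC
from sklearn.model_selection import cross_val_score

rng = np.random.default_rng(0)

def fitness(mask, X, y):
    """Fitness of a chromosome = cross-validated SVM accuracy on the masked features."""
    if mask.sum() == 0:
        return 0.0
    return cross_val_score(SVC(kernel="linear"), X[:, mask.astype(bool)], y, cv=3).mean()

def ga_select(X, y, pop_size=10, iters=100, p_mut=0.02):
    """Evolve binary feature masks; return the best elite mask found."""
    n = X.shape[1]
    pop = rng.integers(0, 2, size=(pop_size, n))           # random binary chromosomes
    for _ in range(iters):
        scores = np.array([fitness(c, X, y) for c in pop])
        pop = pop[np.argsort(scores)[::-1]]                # fittest first (elitism)
        children = []
        for _ in range(pop_size // 2):
            a, b = pop[rng.integers(0, pop_size // 2, 2)]  # parents from the elite half
            cut = int(rng.integers(1, n))                  # one-point crossover
            child = np.concatenate([a[:cut], b[cut:]])
            child[rng.random(n) < p_mut] ^= 1              # bit-flip mutation
            children.append(child)
        pop = np.vstack([pop[: pop_size - len(children)], np.array(children)])
    return pop[0].astype(bool)

# Tiny usage example with placeholder data (500 candidate deep features):
X = rng.random((120, 500))
y = rng.integers(0, 8, 120)
mask = ga_select(X, y, iters=5)        # few iterations just for the demo
print(mask.sum(), "features kept")
```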

**Table 2.** Feature selection overview for the proposed model.


#### *3.6. Classification*

Finally, the selected features are transferred to the classification learner, where the classification phase is performed using the SVM classifier together with several of its kernels, namely Linear SVM (L-SVM), Cubic SVM (CB-SVM), Quadratic SVM (Q-SVM), and Medium Gaussian SVM (MG-SVM). A total of 500 features, selected by the GA from a set of 1000 deep-model-learned features, are forwarded to these classifiers, and each classifier is individually applied to them. The learning rate for the model is kept at 0.0001. The L-SVM classifier outperforms the others in terms of accuracy. The proposed model performed better than previous work in terms of both accuracy and time consumption, achieving an accuracy of 97.8%.
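A hedged scikit-learn sketch of this kernel comparison is given below; mapping Cubic and Quadratic SVM to polynomial kernels of degree 3 and 2 and Medium Gaussian SVM to an RBF kernel mirrors MATLAB's Classification Learner naming, and the data arrays are random placeholders.

```python
import numpy as np
from sklearn.svm import SVC
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

rng = np.random.default_rng(0)
X_sel = rng.random((800, 500))          # placeholder for the 500 GA-selected features
y = rng.integers(0, 8, 800)             # placeholder labels for eight vehicle classes

kernels = {
    "L-SVM":  SVC(kernel="linear"),
    "Q-SVM":  SVC(kernel="poly", degree=2),   # quadratic
    "CB-SVM": SVC(kernel="poly", degree=3),   # cubic
    "MG-SVM": SVC(kernel="rbf"),              # medium Gaussian
}

X_tr, X_te, y_tr, y_te = train_test_split(X_sel, y, test_size=0.2, random_state=0)
for name, clf in kernels.items():
    clf.fit(X_tr, y_tr)
    print(name, accuracy_score(y_te, clf.predict(X_te)))
```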

#### **4. Experiments and Results**

In the proposed work, a model was organized for the classification of vehicle images. The dataset used in the proposed work contains a total of 8000 images divided among eight vehicle classes. After data preparation and preprocessing, the images were given to the pre-trained VGG16 for feature learning and extraction. The features were extracted from the last fully connected layer of the CNN and optimized using the GA before being passed on to the SVM classifier.

The experiments were performed on an Intel Core i7 machine with 8 GB of RAM running the Windows 10 OS. The system houses a 256 GB Solid State Drive (SSD) on which MATLAB 2020a is installed; all the experiments were performed in this environment.

Table 3 shows the results of various SVM classifiers when directly applied to the features extracted from the CNN without optimization. The results are compared and evaluated using the evaluation measures precision, recall, and F1-score. Cubic SVM stands out in terms of accuracy compared with the rest of the kernels, but a large amount of training time is sacrificed in this case. The goal here is to maintain the same performance but with significantly reduced time consumption. While conducting these experiments, a considerably capable machine was used, and even then, such long training times were encountered; the time would only increase on a weaker machine, which is why optimization is needed.

**Table 3.** Performance evaluation of classifiers used in the proposed model (non-optimized features).


The reason CB-SVM is the best-performing kernel is that it performs best on non-linear features that are not easily separable by a hyperplane, as it repeats the same computation multiple times until the best results are achieved.

Table 4 shows the results of the various SVM kernels when applied to the GA-optimized features. The results are again compared using various performance measures as well as training time. When the results of Tables 3 and 4 are compared, almost the same performance standards are achieved but with a greatly reduced time consumption rate. This makes our model stand out from previous works, as the proposed model provides the best results without compromising time.

Figures 5 and 6 show the ROC curves for Cubic SVM when it classifies non-optimized and GA-optimized features, respectively. Cubic SVM is the best-performing model in both cases, as is evident from the tables above.

**Figure 5.** ROC curve of Cubic SVM on non-optimized features.

The ROC curve for the proposed model is illustrated in Figure 6, which shows that the area under the curve is exactly 1.00 while considering the true and false positive rates.

**Figure 6.** ROC curve of Cubic SVM on GA-optimized features.

Similarly, Figures 7 and 8 demonstrate the confusion matrices for CB-SVM for non-optimized and optimized features, respectively.

The reason for extracting features from the last fully connected layer of VGG16 and using an SVM to classify them is that this layer is followed by the SoftMax and classification layers. The classification layer of a pre-trained VGG16 is programmed to classify 1000 different object classes. In this case, we only needed to classify eight vehicle classes, so the classification layer was not applicable. To make a fair comparison, though, new SoftMax and classification layers were put in place, and the extracted features were passed straight to them after the "fc8" layer, but the results were far behind.

The reason behind this is that the newly implanted layers need to be trained on a massive data corpus, just as the original VGG16 layers have been trained on the ImageNet database with millions of images. In this case, we work with only 8000 images, as achieving the best results with limited data is one of the targets of this research work, and it is also not possible to train a CNN on such a massive dataset without great computing resources. Therefore, the proposed model discards this approach and goes with the concerned approach. Table 5 shows the accuracy comparison between the classification performed by the proposed model and by the CNN.

Figures 9 and 10 demonstrate graphical visualizations of the utilized SVM models in the cases of both non-optimized and optimized features.

The proposed model accomplished almost the same performance standards even after the implementation of the GA and the reduction of many features, and it also helped reduce the training time to a large extent. Figure 11 visualizes the difference in training time between the two use cases.

Finally, Table 6 provides a comparison of the proposed model with the previously proposed works. The proposed model outperforms other works in terms of accuracy as well as time consumption rate.

**Figure 7.** Confusion Matrix of Cubic SVM on non-optimized features.

**Figure 8.** Confusion Matrix of Cubic SVM on GA-optimized features.

**Figure 9.** Graphical comparison of various SVM kernels on non-GA-optimized features

**Figure 10.** Graphical comparison of various SVM kernels on GA-optimized features

Table 4 shows the overall statistics regarding the accuracy, training time, and prediction speed for the various SVM classifiers used in the classification phase. The LSVM classifier is chosen for the proposed model because it excels in both accuracy and prediction speed, with a slight trade-off for training time, which is negligible.


**Table 4.** Performance evaluation of classifiers used in the proposed model (GA-optimized features).

**Table 5.** Proposed model and CNN classification accuracy comparison.


Table 6 shows a comparison of the proposed work with the latest studies presented in [22–24]. The proposed model outperformed the others in terms of accuracy and time management, providing the best accuracy as well as time efficiency.


**Table 6.** Accuracy comparison with existing works.

#### **5. Conclusions**

An effective automated vehicle classification system based on the ideas of deep learning can assist in various real-world applications, including security, monitoring, and surveillance. A combination of the deep CNN VGG16 and an evolution-based GA was proposed in this paper. In the proposed model, feature learning was performed by VGG16 on a dataset containing eight vehicle classes. Feature selection was then performed by the GA. Finally, classification was performed using the SVM classifier. The CB-SVM classifier achieved an accuracy of 99.78%, which was better than the accuracy in previous studies. The proposed model excelled in both accuracy and time consumption compared with previous studies. We believe that the dataset can be enlarged considerably and that further work can explore new pathways based on the proposed model.

**Author Contributions:** Conceptualization, A.S.A., A.S. and M.K.; methodology, A.S.; software, A.S. and M.K.; validation, A.S. and M.K.; formal analysis, A.S. and M.K.; investigation, K.T.M.; resources, A.A. and W.S.A.; data curation, A.S., K.T.M., A.A. and W.A.; writing—original draft preparation, A.S.; writing—review and editing, A.S., M.K., A.A. and K.T.M.; visualization, A.S. and M.K.; supervision, A.A. and M.K.; project administration, A.S. and M.K.; funding acquisition, A.S. and M.K. All authors have read and agreed to the published version of the manuscript.

**Funding:** The authors extend their appreciation to the Deputyship for Research & Innovation, Ministry of Education in Saudi Arabia for funding this research work through the project number MoE-IF-20-07.

**Institutional Review Board Statement:** Not applicable.

**Data Availability Statement:** The data used in this work is available at Kaggle.

**Acknowledgments:** The authors would like to thank the University of Jeddah for providing administrative support for this project.

**Conflicts of Interest:** The authors declare no conflict of interest.

#### **References**


**Disclaimer/Publisher's Note:** The statements, opinions and data contained in all publications are solely those of the individual author(s) and contributor(s) and not of MDPI and/or the editor(s). MDPI and/or the editor(s) disclaim responsibility for any injury to people or property resulting from any ideas, methods, instructions or products referred to in the content.

### *Article* **Fermatean Fuzzy Programming with New Score Function: A New Methodology to Multi-Objective Transportation Problems**

**M. K. Sharma 1, Kamini 1, Arvind Dhaka 2,\*, Amita Nandal 2, Hamurabi Gamboa Rosales 3, Francisco Eneldo López Monteagudo 3,\*, Alejandra García Hernández 3 and Vinh Truong Hoang 4**


**Abstract:** The aim of this work is to establish a new methodology to tackle the multi-objective transportation problems [MOTP] in a Fermatean fuzzy environment that can deal with all the parameters that possess a conflicting nature. In our research work, we developed a new score function in the Fermatean context for converting fuzzy data into crisp data with the help of the Fermatean fuzzy technique. Then, we introduced an algorithm-based methodology, i.e., the Fermatean Fuzzy Programming approach, to tackle transportation problems with multiple objectives. The main purpose of this research work is to give an alternative fuzzy programming approach to handle the MOTP. To justify the potential and validity of our work, numerical computations have been carried out using our proposed methodology.

**Keywords:** multi-objective transportation problem [MOTP]; fuzzy programming [FP]; Fermatean fuzzy programming [FFP]; score function; Fermatean fuzzy transportation problem [FFTP]

### **1. Introduction**

In the present scenario of highly competitive market dynamics, there is pressure on transportation managers to conduct the smooth transport of goods and services, i.e., transportation problems are concerned with finding a way by which a decision maker can deliver the product from warehouses to a destination at a minimum cost. Transportation models have many applications in supply chain and logistics for reducing costs. In transportation problems, there are mainly three parameters that must be considered to solve the transportation problem, i.e.,


To distribute various goods and services from numerous origins to many termini, Hitchcock [1] proposed the transportation problem [TP]. The classical TP is a special type of linear programming problem that is difficult to solve by the simplex method. So, in the literature, many approaches have been developed to find the initial "basic feasible solution" for the classical TP, such as the "Column Minima Method", "North-West Corner Rule", "Row Minima Method", "Matrix Minima Method", and "Vogel's Method". In real life, there is a need to address optimization in the context of various objectives. So, the decision maker wants to handle various objectives, which may include minimizing cost, time, deterioration of a product, and energy consumption, among others. This type of TP is identified as a multi-objective transportation problem [MOTP]. In a MOTP, all the objectives are conflicting in nature and have different scales and units of measurement.

**Citation:** Sharma, M.K.; Kamini; Dhaka, A.; Nandal, A.; Rosales, H.G.; Monteagudo, F.E.L.; Hernández, A.G.; Hoang, V.T. Fermatean Fuzzy Programming with New Score Function: A New Methodology to Multi-Objective Transportation Problems. *Electronics* **2023**, *12*, 277. https://doi.org/10.3390/ electronics12020277

Academic Editor: Cheng-Chi Lee

Received: 30 October 2022 Revised: 10 November 2022 Accepted: 14 November 2022 Published: 5 January 2023

**Copyright:** © 2023 by the authors. Licensee MDPI, Basel, Switzerland. This article is an open access article distributed under the terms and conditions of the Creative Commons Attribution (CC BY) license (https:// creativecommons.org/licenses/by/ 4.0/).

For solving such a type of multi-objective transportation problem, the following methods are used:


Additionally, Lee et al. [2] gave a goal programming method to evaluate the optimal solutions of TP with multiple objectives in multi-dimensional decision-making approaches. In actual life, there are many situations when the available information is not sufficient to judge or formulate the model of the problem. To deal with fuzziness in the real-world era, Zadeh [3] presented the idea of fuzzy set theory. This notion of fuzzy logic is used for the mathematical representation of less knowledgeable or imprecise data by a membership function. Bellman and Zadeh [4] introduced decision-making problems in an environment in which goals or constraints are not defined precisely. Oheigeartaigh [5] described transportation problems in real-life situations and developed an algorithm for transporting goods from supply nodes to demand nodes in a fuzzy environment. Zimmermann [6] showed that the results attained by fuzzy linear programming are consistently ideal and effectual. When the "cost factors (C)", "supply (S)" and "demand (D)" measures are known exactly, many procedures have been established for solving the TP. However, in current circumstances, there are numerous cases in which the C, S and D measures are fuzzy amounts. A fuzzy Transportation Problem is a Transportation Problem in which parameters such as the Cost, Supply and Demand measures are in the form of fuzzy numbers. Chanas et al. [7] analyzed the Transportation Problem under fuzzy uncertainty and used a parametric technique to find the optimal solution for Transportation Problems. Chanas and Kuchta [8] proposed a procedure for attaining an optimal solution for the TP with fuzzy parameters expressed in terms of fuzzy numbers.

Atanassov [9] initiated the notion of the Intuitionistic Fuzzy Set (IFS), a generalization of fuzzy sets that incorporates both truth and false grades with a hesitation margin, such that the sum of the truth and false grades is less than or equal to 1. Then Yager [10] introduced an ortho-pair fuzzy set in which the sum of the squares of the truth and false values is less than or equal to 1, called the Pythagorean fuzzy set [PFS]. Senapati and Yager [11] presented the notion of Fermatean Fuzzy Sets [FFSs] to deal with situations where fuzzy logic fails due to membership grades only.

Several researchers developed the model of the TP as either an individual-objective TP or a MOTP, considering various fuzzy contexts. In the case of multi-objective transportation problems, Ringuest [12] proposed a collaborative approach for solving the MOTP and attained k-dominated and non-dominated solutions. Bit et al. [13] established a new method for the MOTP with the help of fuzzy programming using a linear membership function; the weights and priorities of the objectives are involved in the method. Li and Lai [14] introduced a vague compromise programming method for the MOTP. Then, Ammar [15] characterized the optimal solution based on the alpha level of fuzzy numbers and checked the solidity of the MOTP in a fuzzy environment. Lau et al. [16] developed a model using a genetic algorithm to obtain the solution of the MOTP. Kocken et al. [17] established a model using a compensatory fuzzy technique for multi-objective linear transportation problems with triangular fuzzy number parameters; in this technique, they used Zimmermann's "min" operator and fixed the cost-satisfaction intervals and breaking points. Nomani [18] established a new weight approach based on goal programming for the MOTP.
Ahmad and Adhami [19] considered nonlinearity in the MOTP in a neutrosophic situation. A MOTP with a p-facility location issue was given by Das and Roy [20]. To establish a real transport network, Das et al. [21] incorporated two-fold ambiguity. Ghosh et al. [22] developed a solid MOTP in intuitionistic circumstances. Then, in the TP, Ghosh and Roy [23] observed an extra cost, which is considered a type-II fixed charge. Midya et al. [24] developed a solid MOTP with a fixed charge in an intuitionistic fuzzy environment. Sahoo [25] first solved the TP in the Fermatean fuzzy situation

and planned a procedure in which the Fermatean fuzzy transportation problem [FFTP] is transformed into a conventional TP to obtain the best optimal solution. Thereafter, Sahoo [26] proposed various score functions for converting Fermatean fuzzy data into crisp form and applied the TOPSIS technique to solve the MOTP in the Fermatean fuzzy context. Ghosh [27] introduced preservation technology for the MOTP with a preservation cost to reduce the rate of deterioration and to build a more realistic problem, considering criteria such that all the parameters of the problem are Pythagorean fuzzy sets. Rani and Ebrahimnejad [28] established a new algorithm to tackle the unbalanced, fully rough, fixed-charge MOTP.

The basic purposes of this work are as follows:


The work conducted in this research paper is categorized into eight segments, including the present one, which deals with the introduction and a review of the work conducted by many authors. Section 2 of the manuscript includes basic definitions related to the research work. In Section 3, we formulate a mathematical model for the MOTP in the Fermatean fuzzy situation. In Section 4, we propose a novel score function for the ranking of Fermatean fuzzy numbers. Section 5 includes the development of the proposed Fermatean Fuzzy Programming model for the MOTP. In Section 6, we describe the proposed methodology in the context of the proposed score function. Section 7 includes a mathematical illustration to check the effectiveness of our planned methodology. In the last section, we discuss the conclusion of our planned matter.

#### **2. Basic Definitions**

#### *2.1. Fermatean Fuzzy Set [FFS]*

An FFS can be given as:

$$\tilde{f} = \left\{ \langle w, \theta_{\tilde{f}}(w), \delta_{\tilde{f}}(w) \rangle : w \in X \right\}.$$

where $X$ is a universal set, $\theta_{\tilde{f}}(w) : X \to [0, 1]$ is the degree of satisfaction, and $\delta_{\tilde{f}}(w) : X \to [0, 1]$ is the degree of dissatisfaction of $w \in X$; these two are related by the relation:

$$0 \le \left(\theta\_{\overline{f}}(w)\right)^3 + \left(\delta\_{\overline{f}}(w)\right)^3 \le 1 \,\forall \, w \in X.$$

and $\sigma_{\tilde{f}}(w)$ denotes the grade of indeterminacy of $w \in X$, such that:

$$
\sigma\_{\overline{\mathbf{f}}}(w) = \sqrt[3]{1 - \left(\Theta\_{\overline{\mathbf{f}}}(w)\right)^3 - \left(\delta\_{\overline{\mathbf{f}}}(w)\right)^3}
$$

The set $\tilde{f} = \left\{ \langle w, \theta_{\tilde{f}}(w), \delta_{\tilde{f}}(w) \rangle : w \in X \right\}$ is denoted as $\tilde{f} = \langle \theta_{\tilde{f}}, \delta_{\tilde{f}} \rangle$.

#### *2.2. Arithmetic Operation on Fermatean Fuzzy Sets*

Let $\tilde{f} = \langle \theta_{\tilde{f}}, \delta_{\tilde{f}} \rangle$, $\tilde{f}_1 = \langle \theta_{\tilde{f}_1}, \delta_{\tilde{f}_1} \rangle$ and $\tilde{f}_2 = \langle \theta_{\tilde{f}_2}, \delta_{\tilde{f}_2} \rangle$ be three Fermatean fuzzy sets on a universal set $X$ and $\Lambda > 0$ be any scalar; then the arithmetic operations on the FFSs are defined such that:

$$\begin{array}{ll}
\text{(i).} & \tilde{f}_1 \oplus \tilde{f}_2 = \left\langle \sqrt[3]{\theta_{\tilde{f}_1}^3 + \theta_{\tilde{f}_2}^3 - \theta_{\tilde{f}_1}^3 \theta_{\tilde{f}_2}^3},\; \delta_{\tilde{f}_1}\delta_{\tilde{f}_2} \right\rangle \\
\text{(ii).} & \tilde{f}_1 \otimes \tilde{f}_2 = \left\langle \theta_{\tilde{f}_1}\theta_{\tilde{f}_2},\; \sqrt[3]{\delta_{\tilde{f}_1}^3 + \delta_{\tilde{f}_2}^3 - \delta_{\tilde{f}_1}^3 \delta_{\tilde{f}_2}^3} \right\rangle \\
\text{(iii).} & \Lambda\tilde{f} = \left\langle \sqrt[3]{1 - \left(1 - \theta_{\tilde{f}}^3\right)^\Lambda},\; \delta_{\tilde{f}}^\Lambda \right\rangle \\
\text{(iv).} & \tilde{f}^\Lambda = \left\langle \theta_{\tilde{f}}^\Lambda,\; \sqrt[3]{1 - \left(1 - \delta_{\tilde{f}}^3\right)^\Lambda} \right\rangle \\
\text{(v).} & \tilde{f}_1 \cup \tilde{f}_2 = \left\langle \max\left(\theta_{\tilde{f}_1}, \theta_{\tilde{f}_2}\right),\; \min\left(\delta_{\tilde{f}_1}, \delta_{\tilde{f}_2}\right) \right\rangle \\
\text{(vi).} & \tilde{f}_1 \cap \tilde{f}_2 = \left\langle \min\left(\theta_{\tilde{f}_1}, \theta_{\tilde{f}_2}\right),\; \max\left(\delta_{\tilde{f}_1}, \delta_{\tilde{f}_2}\right) \right\rangle \\
\text{(vii).} & \tilde{f}^c = \left\langle \delta_{\tilde{f}},\; \theta_{\tilde{f}} \right\rangle
\end{array}$$

Example: Let $\tilde{f} = \langle 0.5, 0.6 \rangle$, $\tilde{f}_1 = \langle 0.2, 0.8 \rangle$, $\tilde{f}_2 = \langle 0.9, 0.4 \rangle$ be three FFSs and $\Lambda = 3$ be a scalar; then:

$$\begin{array}{l}
\text{(i).}\;\; \tilde{f}_1 \oplus \tilde{f}_2 = \langle 0.2, 0.8 \rangle \oplus \langle 0.9, 0.4 \rangle = \langle 0.9009, 0.32 \rangle \\
\text{(ii).}\;\; \tilde{f}_1 \otimes \tilde{f}_2 = \langle 0.2, 0.8 \rangle \otimes \langle 0.9, 0.4 \rangle = \langle 0.18, 0.8159 \rangle \\
\text{(iii).}\;\; \Lambda\tilde{f} = 3\langle 0.5, 0.6 \rangle = \langle 0.6911, 0.216 \rangle
\end{array}$$
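As a quick check of these operations, the following small Python class (an illustration, not part of the paper) implements $\oplus$, $\otimes$, and scalar multiplication and reproduces the example values above.

```python
from dataclasses import dataclass

@dataclass
class FFS:
    """A Fermatean fuzzy value <theta, delta> with theta**3 + delta**3 <= 1."""
    theta: float
    delta: float

    def __add__(self, o):        # (i)  f1 (+) f2
        t = (self.theta**3 + o.theta**3 - self.theta**3 * o.theta**3) ** (1 / 3)
        return FFS(round(t, 4), round(self.delta * o.delta, 4))

    def __mul__(self, o):        # (ii) f1 (x) f2
        d = (self.delta**3 + o.delta**3 - self.delta**3 * o.delta**3) ** (1 / 3)
        return FFS(round(self.theta * o.theta, 4), round(d, 4))

    def scale(self, lam):        # (iii) lam * f
        t = (1 - (1 - self.theta**3) ** lam) ** (1 / 3)
        return FFS(round(t, 4), round(self.delta**lam, 4))

f, f1, f2 = FFS(0.5, 0.6), FFS(0.2, 0.8), FFS(0.9, 0.4)
print(f1 + f2)      # FFS(theta=0.9009, delta=0.32)
print(f1 * f2)      # FFS(theta=0.18, delta=0.8159)
print(f.scale(3))   # FFS(theta=0.6911, delta=0.216)
```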


#### *2.3. Score Function*

Let $\tilde{f}$ be a Fermatean fuzzy set $\tilde{f} = \langle \varphi_F, \psi_F \rangle$; then the score function of $\tilde{f}$, denoted by $S_f(\tilde{f})$, is defined as follows:

$$\mathcal{S}\_f(\tilde{f}) = \left(\phi\_{\overline{F}}{}^3 - \psi\_{\overline{F}}{}^3\right).$$

Property: Consider an FFS $F = \langle \varphi_F, \psi_F \rangle$; then $S^*_F(F) \in [0, 1]$.

**Proof.** By the definition of an ortho-pair, $\varphi_F, \psi_F \in [0, 1]$. Then, $\min(\varphi_F, \psi_F) \in [0, 1]$.

$$\begin{array}{l}
\text{Also, } \varphi_F^3 \ge 0,\; \psi_F^3 \ge 0,\; \varphi_F^3 \le 1 \text{ and } \psi_F^3 \le 1 \\
\implies 1 - \psi_F^3 \ge 0 \\
\implies 1 + \varphi_F^3 - \psi_F^3 \ge 0 \\
\therefore\; \frac{1}{2}\left(1 + \varphi_F^3 - \psi_F^3\right) \cdot \min(\varphi_F, \psi_F) \ge 0. \\
\text{Again, } \varphi_F^3 - \psi_F^3 \le 1 \\
\implies 1 + \varphi_F^3 - \psi_F^3 \le 2 \quad (\because\; \varphi_F^3 \ge 0) \\
\implies \frac{1 + \varphi_F^3 - \psi_F^3}{2} \le 1 \\
\implies \frac{1}{2}\left(1 + \varphi_F^3 - \psi_F^3\right) \cdot \min(\varphi_F, \psi_F) \le 1 \quad (\because\; \min(\varphi_F, \psi_F) \le 1) \\
\text{Hence, } S^*_F(F) \in [0, 1]. \;\square
\end{array}$$

#### *2.4. Accuracy Function*

Let $\tilde{f}$ be a Fermatean fuzzy set $\tilde{f} = \langle \theta_{\tilde{f}}, \delta_{\tilde{f}} \rangle$; then the accuracy function of $\tilde{f}$, denoted by $\hat{A}_f(\tilde{f})$, is defined as follows:

$$\mathring{\mathcal{A}}\_f(\vec{f}) = \left(\theta\_{\vec{f}}{}^3 + \delta\_{\vec{f}}{}^3\right).$$

#### **3. Mathematical Formulation for Multi-Objective Transportation Problem in a Fermatean Fuzzy Environment**

Let us consider a MOTP with $k$ objectives containing $m$ supply nodes and $n$ demand nodes. Additionally, $a^f_i = \langle \theta_{a_i}, \delta_{a_i} \rangle$ units are available at the $i$th supply node and $b^f_j = \langle \theta_{b_j}, \delta_{b_j} \rangle$ units are demanded at the $j$th demand node. Suppose $c^f_{ij} = \langle \theta_{c_{ij}}, \delta_{c_{ij}} \rangle$ is the unit Fermatean fuzzy transportation cost from the $i$th source node to the $j$th demand node, and $w_{ij}$ is the number of items carried from the $i$th source node to the $j$th demand node.

The mathematical formulation for the MOTP in the Fermatean Fuzzy Situation is as follows:

$$\begin{array}{ll}
\min F_K = \sum_{i=1}^{m} \sum_{j=1}^{n} c^f_{ijK} w_{ij}, & K = 1, 2, 3, \dots, k \\
\text{s.t. } \sum_{j=1}^{n} w_{ij} \le a^f_i, & i = 1, 2, 3, \dots, m \\
\quad\;\; \sum_{i=1}^{m} w_{ij} \ge b^f_j, & j = 1, 2, 3, \dots, n
\end{array}$$

such that $a^f_i = \langle \theta_{a_i}, \delta_{a_i} \rangle$ where $0 \le (\theta_{a_i})^3 + (\delta_{a_i})^3 \le 1$, $b^f_j = \langle \theta_{b_j}, \delta_{b_j} \rangle$ where $0 \le (\theta_{b_j})^3 + (\delta_{b_j})^3 \le 1$, $c^f_{ij} = \langle \theta_{c_{ij}}, \delta_{c_{ij}} \rangle$ where $0 \le (\theta_{c_{ij}})^3 + (\delta_{c_{ij}})^3 \le 1$, and $w_{ij} \ge 0\; \forall\; i, j$.

#### **4. Proposed Score Function for Fermatean Fuzzy Sets**

#### *4.1. Fermatean Fuzzy Score Functions*

Let $\tilde{f} = \langle \theta_{\tilde{f}}, \delta_{\tilde{f}} \rangle$ be any Fermatean fuzzy number; then the score function of $\tilde{f}$, denoted by $S(\tilde{f})$, is defined as follows:

$$S(\tilde{f}) = \frac{1}{2}\left(1 + \theta_{\tilde{f}} - \delta_{\tilde{f}}\right)\left(\min\left(\theta_{\tilde{f}}, \delta_{\tilde{f}}\right)\right)^2$$
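A one-line Python rendering of this score function (illustrative, not from the paper) reproduces the worked values used later in Section 7.

```python
def fermatean_score(theta: float, delta: float) -> float:
    """Proposed score: S(f) = 1/2 * (1 + theta - delta) * min(theta, delta)**2."""
    return 0.5 * (1 + theta - delta) * min(theta, delta) ** 2

# Worked values matching the computations in Section 7:
print(round(fermatean_score(0.6, 0.4), 4))   # 0.096 -> supply a1
print(round(fermatean_score(0.2, 0.5), 4))   # 0.014 -> demand b1 (and cost c11)
```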

#### *4.2. Property*

Let $\tilde{f} = \langle \theta_{\tilde{f}}, \delta_{\tilde{f}} \rangle$ be any Fermatean fuzzy set; then the score function of $\tilde{f}$ satisfies $S(\tilde{f}) \in [0, 1]$.

**Proof.** By the description of the membership and non-membership pairs, $\theta_{\tilde{f}}, \delta_{\tilde{f}} \in [0, 1]$.

Then, $\min(\theta_{\tilde{f}}, \delta_{\tilde{f}}) \in [0, 1]$. Additionally, $\theta_{\tilde{f}} \ge 0$, $\delta_{\tilde{f}} \ge 0$, $\theta_{\tilde{f}} \le 1$ and $\delta_{\tilde{f}} \le 1$

$$\implies 1 - \delta_{\tilde{f}} \ge 0 \implies 1 + \theta_{\tilde{f}} - \delta_{\tilde{f}} \ge 0$$
$$\therefore\; \frac{1}{2}\left(1 + \theta_{\tilde{f}} - \delta_{\tilde{f}}\right)\left(\min\left(\theta_{\tilde{f}}, \delta_{\tilde{f}}\right)\right)^2 \ge 0, \text{ hence } S(\tilde{f}) \ge 0.$$

Again, $\theta_{\tilde{f}} \le 1$ and $\delta_{\tilde{f}} \le 1$

$$\implies \theta_{\tilde{f}} - \delta_{\tilde{f}} \le 1 \implies 1 + \theta_{\tilde{f}} - \delta_{\tilde{f}} \le 2 \implies \frac{1 + \theta_{\tilde{f}} - \delta_{\tilde{f}}}{2} \le 1,$$

and $\min(\theta_{\tilde{f}}, \delta_{\tilde{f}}) \le 1 \implies \left(\min(\theta_{\tilde{f}}, \delta_{\tilde{f}})\right)^2 \le 1$

$$\implies \frac{1}{2}\left(1 + \theta_{\tilde{f}} - \delta_{\tilde{f}}\right)\left(\min\left(\theta_{\tilde{f}}, \delta_{\tilde{f}}\right)\right)^2 \le 1, \text{ hence } S(\tilde{f}) \le 1 \implies S(\tilde{f}) \in [0, 1]. \;\square$$

#### **5. Proposed Model for Fermatean Fuzzy Programming**

Senapati and Yager [11] introduced FFSs as an extension of intuitionistic and Pythagorean fuzzy sets for situations where the sum of the truth and false grades, and even the sum of their squares, may be greater than 1, but the sum of the cubes of the truth and false grades is less than or equal to 1. FFSs are more realistic and handle more uncertainty than intuitionistic and Pythagorean fuzzy sets. For such environments, Zimmermann [6] introduced a fuzzy programming approach for multi-objective decision-making problems based on a min-max operator. In this approach, we can use linear, exponential, or hyperbolic truth functions to attain compromised optimal solutions to the problems. Then, intuitionistic fuzzy programming was developed for multi-objective problems in an intuitionistic fuzzy environment, in which the truth and false grades may be linear, exponential, or hyperbolic functions. After that, Pythagorean fuzzy programming was developed for such problems in a fuzzy environment. Now, we introduce a non-linear programming approach, Fermatean Fuzzy Programming, to obtain a compromise optimal solution to

all objectives of multi-objective decision-making problems simultaneously, in a Fermatean fuzzy environment or in any other environment, defined as follows:

Let $U_k$ and $L_k$ be the upper and lower bounds, respectively, for objective $F_k(w)$ of the problem, let $\mu(F_k(w))$ be the membership function for objective $F_k(w)$, and let $\vartheta(F_k(w))$ be the non-membership function for the objective function $F_k(w)$. Then, the proposed model for Fermatean Fuzzy Programming is as follows:

$$\text{Max. } \gamma_1^3 - \gamma_2^3$$

where

$$\mu(F_k(w))^3 \ge \gamma_1^3 \;\forall\; k, \qquad \vartheta(F_k(w))^3 \le \gamma_2^3 \;\forall\; k$$

where

$$\mu(F_k(w)) = \begin{cases} 1, & \text{if } F_k(w) \le L_k \\ \dfrac{U_k - F_k(w)}{U_k - L_k}, & \text{if } L_k \le F_k(w) \le U_k \\ 0, & \text{if } F_k(w) \ge U_k \end{cases}$$

and

$$\vartheta(F_k(w)) = \begin{cases} 0, & \text{if } F_k(w) \le L_k \\ \dfrac{F_k(w) - L_k}{U_k - L_k}, & \text{if } L_k \le F_k(w) \le U_k \\ 1, & \text{if } F_k(w) \ge U_k \end{cases}$$

i.e., $(U_k - F_k(w))^3 \ge d_k^3 \gamma_1^3$ and $(F_k(w) - L_k)^3 \le d_k^3 \gamma_2^3$, where $d_k = U_k - L_k$, with respect to the constraints,

$$\begin{array}{c}
w_{11} + w_{12} + \cdots + w_{1n} \le a_1 \\
w_{21} + w_{22} + \cdots + w_{2n} \le a_2 \\
\vdots \\
w_{m1} + w_{m2} + \cdots + w_{mn} \le a_m \\
w_{11} + w_{21} + \cdots + w_{m1} \ge b_1 \\
w_{12} + w_{22} + \cdots + w_{m2} \ge b_2 \\
\vdots \\
w_{1n} + w_{2n} + \cdots + w_{mn} \ge b_n
\end{array}$$

and $\sum_{i=1}^{m} a_i = \sum_{j=1}^{n} b_j$, $w_{ij} \ge 0$, $0 \le \gamma_1^3, \gamma_2^3 \le 1$, $0 \le \gamma_1^3 + \gamma_2^3 \le 1$, and $\gamma_1^3 \ge \gamma_2^3$.
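As a small illustration of the linear membership and non-membership grades defined above, the following Python functions (the names are ours) compute $\mu$ and $\vartheta$ for a given objective value and bounds; note that $\vartheta$ is the complement of $\mu$ on $[L_k, U_k]$.

```python
def mu(F_k: float, L_k: float, U_k: float) -> float:
    """Linear membership (satisfaction) grade of an objective value F_k."""
    if F_k <= L_k:
        return 1.0
    if F_k >= U_k:
        return 0.0
    return (U_k - F_k) / (U_k - L_k)

def theta(F_k: float, L_k: float, U_k: float) -> float:
    """Linear non-membership (dissatisfaction) grade; the complement of mu on [L_k, U_k]."""
    return 1.0 - mu(F_k, L_k, U_k)

print(mu(0.004, 0.003, 0.006))     # 0.666..., i.e., (U - F)/(U - L)
print(theta(0.004, 0.003, 0.006))  # 0.333..., i.e., (F - L)/(U - L)
```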

#### **6. Proposed Methodology**

To handle the MOTP in the Fermatean fuzzy environment, we propose a methodology. The steps involved in the proposed methodology are depicted as follows:

Step 1: First, consider a MOTP in Fermatean fuzzy uncertainty such as:

$$\begin{array}{ll}
\min F_K = \sum_{i=1}^{m} \sum_{j=1}^{n} c^f_{ijK} w_{ij}, & K = 1, 2, 3, \dots, k \\
\text{s.t. } \sum_{j=1}^{n} w_{ij} \le a^f_i, & i = 1, 2, 3, \dots, m \\
\quad\;\; \sum_{i=1}^{m} w_{ij} \ge b^f_j, & j = 1, 2, 3, \dots, n
\end{array}$$

$$\begin{array}{c} \text{Such that} \quad a\_{\bar{i}}^{\bar{t}} = \langle \theta\_{\bar{a}\_{i'}}, \delta\_{\bar{a}\_{i}} \rangle \text{ where } 0 \le \left(\theta\_{\bar{a}\_{\bar{i}}}\right)^{3} + \left(\delta\_{\bar{a}\_{\bar{i}}}\right)^{3} \le 1, \\\ b\_{\bar{j}}^{\bar{t}} = \langle \theta\_{\bar{b}\_{\bar{j}'}}, \delta\_{\bar{b}\_{\bar{j}}} \rangle \text{ where } 0 \le \left(\theta\_{\bar{b}\_{\bar{j}}}\right)^{3} + \left(\delta\_{\bar{b}\_{\bar{j}}}\right)^{3} \le 1, \\\ c\_{\bar{ij}}^{\bar{t}} = \langle \theta\_{\bar{c}\_{ij'}}, \delta\_{\bar{c}\_{ij}} \rangle \text{ where } 0 \le \left(\theta\_{\bar{c}\_{\bar{i}j}}\right)^{3} + \left(\delta\_{\bar{c}\_{\bar{i}j}}\right)^{3} \le 1, \\\ w\_{ij} \ge 0, \forall \ i, j \end{array}$$

Step 2: Then, convert the Fermatean fuzzy data into crisp data by using the proposed score function for Fermatean fuzzy sets as:

$$\min F_K = \sum_{i=1}^{m} \sum_{j=1}^{n} S\left(c^f_{ijK}\right) w_{ij}, \quad K = 1, 2, 3, \dots, k$$

Step 3: Now, solve this problem for all objectives by taking one objective at a time. We obtain basic feasible solutions for all objectives.

Step 4: Now, we build a pay-off matrix for all objectives; then we obtain the upper bound $U_k$ and lower bound $L_k$ for the objective $F_k(w)$ through the pay-off matrix.


Step 5: Then, build a model of the problem using the proposed Fermatean fuzzy programming approach and solve this model with LINGO 19.0 software. The architecture is shown in Figure 1. The numerical computations of the proposed technique are given in Tables 1–3.

**Figure 1.** Structure outlines of proposed algorithm.

#### **7. Numerical Computations**

We consider a MOTP in Fermatean fuzzy uncertainty where all the variables of the problem are Fermatean fuzzy numbers. This is as follows:


**Table 1.** First Objective Cost.

**Table 2.** Second Objective Cost.

**Table 3.** Third Objective Cost.



Step 2: Now, convert the Fermatean fuzzy parameters into crisp form by applying the proposed score function as follows:

Supply:

$$\begin{array}{rcl} S(\mathbf{a}\_{\overline{f\_1}}) &= S(\langle 0.6, 0.4 \rangle) &=& \frac{1}{2}(1 + 0.6 - 0.4)(\min(0.6, 0.4))^2 \\ &=& \frac{1}{2}(1.2)(0.4)^2 = 0.6 \times 0.16 = 0.096 \\ S(\mathbf{a}\_{\overline{f\_2}}) &= S(\langle 0.3, 0.5 \rangle) &=& \frac{1}{2}(1 + 0.3 - 0.5)(\min(0.3, 0.5))^2 \\ &=& \frac{1}{2}(0.8)(0.3)^2 = 0.4 \times 0.09 = 0.036 \\ S(\mathbf{a}\_{\overline{f\_3}}) &= S(\langle 0.4, 0.8 \rangle) &=& \frac{1}{2}(1 + 0.4 - 0.8)(\min(0.4, 0.8))^2 \\ &=& \frac{1}{2}(0.6)(0.4)^2 = 0.3 \times 0.16 = 0.048 \end{array}$$

Demand:

$$\begin{array}{rl}
S(b^f_1) = S(\langle 0.2, 0.5 \rangle) &= \frac{1}{2}(1 + 0.2 - 0.5)(\min(0.2, 0.5))^2 = \frac{1}{2}(0.7)(0.2)^2 = 0.35 \times 0.04 = 0.014 \\
S(b^f_2) = S(\langle 0.4, 0.7 \rangle) &= \frac{1}{2}(1 + 0.4 - 0.7)(\min(0.4, 0.7))^2 = \frac{1}{2}(0.7)(0.4)^2 = 0.35 \times 0.16 = 0.056 \\
S(b^f_3) = S(\langle 0.6, 0.4 \rangle) &= \frac{1}{2}(1 + 0.6 - 0.4)(\min(0.6, 0.4))^2 = \frac{1}{2}(1.2)(0.4)^2 = 0.6 \times 0.16 = 0.096 \\
S(b^f_4) = S(\langle 0.2, 0.5 \rangle) &= \frac{1}{2}(1 + 0.2 - 0.5)(\min(0.2, 0.5))^2 = \frac{1}{2}(0.7)(0.2)^2 = 0.35 \times 0.04 = 0.014
\end{array}$$

Since $\sum_{i=1}^{m} S(a^f_i) = \sum_{j=1}^{n} S(b^f_j)$, the problem is a balanced multi-objective transportation problem.

Costs:

First objective costs:

$$\begin{array}{rl}
S(c^f_{11}) = S(\langle 0.2, 0.5 \rangle) &= \frac{1}{2}(1 + 0.2 - 0.5)(\min(0.2, 0.5))^2 = \frac{1}{2}(0.7)(0.2)^2 = 0.35 \times 0.04 = 0.014 \\
S(c^f_{12}) = S(\langle 0.2, 0.4 \rangle) &= \frac{1}{2}(1 + 0.2 - 0.4)(\min(0.2, 0.4))^2 = \frac{1}{2}(0.8)(0.2)^2 = 0.4 \times 0.04 = 0.016 \\
S(c^f_{13}) = S(\langle 0.3, 0.7 \rangle) &= \frac{1}{2}(1 + 0.3 - 0.7)(\min(0.3, 0.7))^2 = \frac{1}{2}(0.6)(0.3)^2 = 0.3 \times 0.09 = 0.027 \\
S(c^f_{14}) = S(\langle 0.6, 0.5 \rangle) &= \frac{1}{2}(1 + 0.6 - 0.5)(\min(0.6, 0.5))^2 = \frac{1}{2}(1.1)(0.5)^2 = 0.55 \times 0.25 = 0.1375 \\
S(c^f_{21}) = S(\langle 0.3, 0.5 \rangle) &= \frac{1}{2}(1 + 0.3 - 0.5)(\min(0.3, 0.5))^2 = \frac{1}{2}(0.8)(0.3)^2 = 0.4 \times 0.09 = 0.036 \\
S(c^f_{22}) = S(\langle 0.2, 0.9 \rangle) &= \frac{1}{2}(1 + 0.2 - 0.9)(\min(0.2, 0.9))^2 = \frac{1}{2}(0.3)(0.2)^2 = 0.15 \times 0.04 = 0.006 \\
S(c^f_{23}) = S(\langle 0.7, 0.1 \rangle) &= \frac{1}{2}(1 + 0.7 - 0.1)(\min(0.7, 0.1))^2 = \frac{1}{2}(1.6)(0.1)^2 = 0.8 \times 0.01 = 0.008 \\
S(c^f_{24}) = S(\langle 0.4, 0.7 \rangle) &= \frac{1}{2}(1 + 0.4 - 0.7)(\min(0.4, 0.7))^2 = \frac{1}{2}(0.7)(0.4)^2 = 0.35 \times 0.16 = 0.056 \\
S(c^f_{31}) = S(\langle 0.7, 0.1 \rangle) &= \frac{1}{2}(1 + 0.7 - 0.1)(\min(0.7, 0.1))^2 = \frac{1}{2}(1.6)(0.1)^2 = 0.8 \times 0.01 = 0.008 \\
S(c^f_{32}) = S(\langle 0.2, 0.3 \rangle) &= \frac{1}{2}(1 + 0.2 - 0.3)(\min(0.2, 0.3))^2 = \frac{1}{2}(0.9)(0.2)^2 = 0.45 \times 0.04 = 0.018 \\
S(c^f_{33}) = S(\langle 0.5, 0.1 \rangle) &= \frac{1}{2}(1 + 0.5 - 0.1)(\min(0.5, 0.1))^2 = \frac{1}{2}(1.4)(0.1)^2 = 0.70 \times 0.01 = 0.007 \\
S(c^f_{34}) = S(\langle 0.6, 0.4 \rangle) &= \frac{1}{2}(1 + 0.6 - 0.4)(\min(0.6, 0.4))^2 = \frac{1}{2}(1.2)(0.4)^2 = 0.6 \times 0.16 = 0.096
\end{array}$$

Second objective costs:

$$\begin{array}{ll} S(\mathbf{c}\_{\overline{f}\_{11}}) = 0.056, & S(\mathbf{c}\_{\overline{f}\_{12}}) = 0.027\\ S(\mathbf{c}\_{\overline{f}\_{13}}) = 0.1375, & S(\mathbf{c}\_{\overline{f}\_{14}}) = 0.008\\ S(\mathbf{c}\_{\overline{f}\_{21}}) = 0.036, & S(\mathbf{c}\_{\overline{f}\_{22}}) = 0.048\\ S(\mathbf{c}\_{\overline{f}\_{23}}) = 0.027, & S(\mathbf{c}\_{\overline{f}\_{24}}) = 0.028\\ S(\mathbf{c}\_{\overline{f}\_{31}}) = 0.056, & S(\mathbf{c}\_{\overline{f}\_{32}}) = 0.064\\ S(\mathbf{c}\_{\overline{f}\_{33}}) = 0.006, & S(\mathbf{c}\_{\overline{f}\_{34}}) = 0.007\\ \end{array}$$

Third objective costs:

$$\begin{array}{ll} S(\mathbf{c}\_{\overline{f}\_{11}}) = 0.018, & S(\mathbf{c}\_{\overline{f}\_{12}}) = 0.162\\ S(\mathbf{c}\_{\overline{f}\_{13}}) = 0.027, & S(\mathbf{c}\_{\overline{f}\_{14}}) = 0.1375\\ S(\mathbf{c}\_{\overline{f}\_{21}}) = 0.008, & S(\mathbf{c}\_{\overline{f}\_{22}}) = 0.006\\ S(\mathbf{c}\_{\overline{f}\_{23}}) = 0.162, & S(\mathbf{c}\_{\overline{f}\_{24}}) = 0.056\\ S(\mathbf{c}\_{\overline{f}\_{31}}) = 0.112, & S(\mathbf{c}\_{\overline{f}\_{32}}) = 0.027\\ S(\mathbf{c}\_{\overline{f}\_{33}}) = 0.014, & S(\mathbf{c}\_{\overline{f}\_{34}}) = 0.008\\ \end{array}$$

Step 3: We obtain three individual transportation problems. Then, we solve these three transportation problems and obtain the basic feasible or optimal solutions for all problems. For the first objective transportation problem:

$$F\_1(w) = 0.014w\_{11} + 0.016w\_{12} + 0.027w\_{13} + 0.1375w\_{14} + 0.036w\_{21} + 0.006w\_{22} + 0.008w\_{23} + 0.056w\_{24} + 0.008w\_{31} + 0.018w\_{32} + 0.007w\_{33} + 0.096w\_{34}$$

subject to the constraints:

$$\begin{array}{l}w\_{11} + w\_{12} + w\_{13} + w\_{14} \leq 0.096, \\ w\_{21} + w\_{22} + w\_{23} + w\_{24} \leq 0.036, \\ w\_{31} + w\_{32} + w\_{33} + w\_{34} \leq 0.048, \\ w\_{11} + w\_{21} + w\_{31} \geq 0.014, \\ w\_{12} + w\_{22} + w\_{32} \geq 0.056, \\ w\_{13} + w\_{23} + w\_{33} \geq 0.096, \\ w\_{14} + w\_{24} + w\_{34} \geq 0.014, \\ w\_{ij} \geq 0, \sum\_{i=1}^{m} \mathbf{a}\_{i} = \sum\_{j=1}^{n} \mathbf{b}\_{j} \end{array}$$

After solving this problem, we obtain the optimal solution as follows:

 $F\_1 = 0.003090$ ,  $w\_{11} = 0.014$ ,  $w\_{12} = 0.056$ ,  $w\_{13} = 0.026$ ,  $w\_{14} = 0$ ,  $w\_{21} = 0$ ,  $w\_{22} = 0$ ,  $w\_{23} = 0.022$ ,  $w\_{24} = 0.014$ ,  $w\_{31} = 0$ ,  $w\_{32} = 0$ ,  $w\_{33} = 0.048$ ,  $w\_{34} = 0$ .
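Each single-objective problem above is an ordinary transportation LP, so any LP solver should recover the solution reported here. A minimal sketch with `scipy.optimize.linprog` follows; the authors do not state which solver they used for this step, so this is only an illustrative assumption:

```python
import numpy as np
from scipy.optimize import linprog

m, n = 3, 4
c = np.array([0.014, 0.016, 0.027, 0.1375,    # first-objective scores S(c_1j)
              0.036, 0.006, 0.008, 0.056,     # S(c_2j)
              0.008, 0.018, 0.007, 0.096])    # S(c_3j)
a = np.array([0.096, 0.036, 0.048])           # crisp supplies S(a_i)
b = np.array([0.014, 0.056, 0.096, 0.014])    # crisp demands S(b_j)

# Encode row sums <= a_i and column sums >= b_j as A_ub @ w <= b_ub.
A_ub = np.zeros((m + n, m * n))
for i in range(m):
    A_ub[i, i * n:(i + 1) * n] = 1.0          # supply constraint of row i
for j in range(n):
    A_ub[m + j, j::n] = -1.0                  # demand constraint of column j (negated)
b_ub = np.concatenate([a, -b])

res = linprog(c, A_ub=A_ub, b_ub=b_ub, bounds=(0, None))
print(round(res.fun, 6))                      # should recover F_1 = 0.00309
```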

For the second objective transportation problem:

$$F\_2(w) = 0.056w\_{11} + 0.027w\_{12} + 0.1375w\_{13} + 0.008w\_{14} + 0.036w\_{21} + 0.048w\_{22} + 0.027w\_{23} + 0.028w\_{24} + 0.056w\_{31} + 0.064w\_{32} + 0.006w\_{33} + 0.007w\_{34}$$

Subject to the constraints:

$$\begin{array}{l}w\_{11} + w\_{12} + w\_{13} + w\_{14} \leq 0.096, \\ w\_{21} + w\_{22} + w\_{23} + w\_{24} \leq 0.036, \\ w\_{31} + w\_{32} + w\_{33} + w\_{34} \leq 0.048, \\ w\_{11} + w\_{21} + w\_{31} \geq 0.014, \\ w\_{12} + w\_{22} + w\_{32} \geq 0.056, \\ w\_{13} + w\_{23} + w\_{33} \geq 0.096, \\ w\_{14} + w\_{24} + w\_{34} \geq 0.014, \\ w\_{ij} \geq 0, \\ \sum\_{i=1}^{m} \mathbf{a}\_{i} = \sum\_{j=1}^{n} \mathbf{b}\_{j}. \end{array}$$

After solving this problem, we obtain the optimal solution as follows:

$F\_2 = 0.005318$, $w\_{11} = 0.014$, $w\_{12} = 0.056$, $w\_{13} = 0.012$, $w\_{14} = 0.014$, $w\_{21} = 0$, $w\_{22} = 0$, $w\_{23} = 0.036$, $w\_{24} = 0$, $w\_{31} = 0$, $w\_{32} = 0$, $w\_{33} = 0.048$, $w\_{34} = 0$.

For the third objective transportation problem:

$$F\_3(w) = 0.018w\_{11} + 0.162w\_{12} + 0.027w\_{13} + 0.1375w\_{14} + 0.008w\_{21} + 0.006w\_{22} + 0.162w\_{23} + 0.056w\_{24} + 0.112w\_{31} + 0.027w\_{32} + 0.014w\_{33} + 0.008w\_{34}$$

Subject to the constraints:

$$\begin{array}{l}w\_{11} + w\_{12} + w\_{13} + w\_{14} \leq 0.096, \\ w\_{21} + w\_{22} + w\_{23} + w\_{24} \leq 0.036, \\ w\_{31} + w\_{32} + w\_{33} + w\_{34} \leq 0.048, \\ w\_{11} + w\_{21} + w\_{31} \geq 0.014, \\ w\_{12} + w\_{22} + w\_{32} \geq 0.056, \\ w\_{13} + w\_{23} + w\_{33} \geq 0.096, \\ w\_{14} + w\_{24} + w\_{34} \geq 0.014, \\ w\_{ij} \geq 0, \\ \sum\_{i=1}^{m} \mathbf{a}\_{i} = \sum\_{j=1}^{n} \mathbf{b}\_{j}. \end{array}$$

After solving this problem, we obtain the optimal solution as follows:

 $F\_3 = 0.00353$ ,  $w\_{11} = 0.014$ ,  $w\_{12} = 0$ ,  $w\_{13} = 0.082$ ,  $w\_{14} = 0$ ,  $w\_{21} = 0$ ,  $w\_{22} = 0.036$ ,  $w\_{23} = 0$ ,  $w\_{24} = 0$ ,  $w\_{31} = 0$ ,  $w\_{32} = 0.02$ ,  $w\_{33} = 0.014$ ,  $w\_{34} = 0$ .

Step 4: After obtaining the solutions for all objectives individually, we construct the payoff matrix.


So, we can find the upper and lower bounds for all three objectives which are as follows:

$$\begin{array}{lll} L\_1 = 0.003090, & U\_1 = 0.004428, & d\_1 = 0.001338,\\ L\_2 = 0.005318, & U\_2 = 0.015249, & d\_2 = 0.009931,\\ L\_3 = 0.003530, & U\_3 = 0.018077, & d\_3 = 0.014547. \end{array}$$

Step 5: Now, we model the problem using the proposed Fermatean fuzzy programming:

$$\text{Max } \gamma\_1^{\;3} - \gamma\_2^{\;3}$$

where

$$\begin{array}{l} \mu(F\_k(w))^3 \geq \gamma\_1^{\;3} \; \forall \; k,\\ \theta(F\_k(w))^3 \leq \gamma\_2^{\;3} \; \forall \; k, \end{array}$$

i.e., $(U\_k - F\_k(w))^3 \geq d\_k^{\;3}\gamma\_1^{\;3}, \; \forall \; k$

$$\begin{array}{l} \implies \left(0.004428 - F\_1\right)^3 \ge 0.00000000239534646\gamma\_1^3\text{.}\\ \implies \left(0.015249 - F\_2\right)^3 \ge 0.000000979442501\gamma\_1^3\text{.}\\ \implies \left(0.018077 - F\_3\right)^3 \ge 0.00000307836645\gamma\_1^3\text{.} \end{array}$$

Again, $(F\_k(w) - L\_k)^3 \leq d\_k^{\;3}\gamma\_2^{\;3}, \; \forall \; k$

$$\begin{array}{l} \implies \left(F\_1 - 0.003090\right)^3 \le 0.00000000239534646\gamma\_2^3,\\ \implies \left(F\_2 - 0.005318\right)^3 \le 0.000000979442501\gamma\_2^3,\\ \implies \left(F\_3 - 0.003530\right)^3 \le 0.00000307836645\gamma\_2^3. \end{array}$$

subject to the constraints

$$\begin{array}{l} w\_{11} + w\_{12} + \cdots + w\_{1n} \leq \mathbf{a}\_1,\\ w\_{21} + w\_{22} + \cdots + w\_{2n} \leq \mathbf{a}\_2,\\ \qquad\vdots\\ w\_{m1} + w\_{m2} + \cdots + w\_{mn} \leq \mathbf{a}\_m,\\ w\_{11} + w\_{21} + \cdots + w\_{m1} \geq \mathbf{b}\_1,\\ w\_{12} + w\_{22} + \cdots + w\_{m2} \geq \mathbf{b}\_2,\\ \qquad\vdots\\ w\_{1n} + w\_{2n} + \cdots + w\_{mn} \geq \mathbf{b}\_n \end{array}$$
and $\sum\_{i=1}^{m} \mathbf{a}\_{i} = \sum\_{j=1}^{n} \mathbf{b}\_{j}$, $w\_{ij} \geq 0$, $0 \leq \gamma\_1^{\;3}, \gamma\_2^{\;3} \leq 1$, $0 \leq \gamma\_1^{\;3} + \gamma\_2^{\;3} \leq 1$ and $\gamma\_1^{\;3} \geq \gamma\_2^{\;3}$.

Then, by solving this model with Lingo 19.0, we obtain the optimal solution as follows:

$\gamma\_1 = 1$, $\gamma\_2 = 0.004699874$, $F\_1 = 0.004400256$, $F\_2 = 0.01528566$, $F\_3 = 0.003598369$, $w\_{11} = 0.014$, $w\_{12} = 0$, $w\_{13} = 0.082$, $w\_{14} = 0$, $w\_{21} = 0$, $w\_{22} = 0.0350091$, $w\_{23} = 0$, $w\_{24} = 0.000990856$, $w\_{31} = 0$, $w\_{32} = 0$, $w\_{33} = 0.014$, $w\_{34} = 0.0130091$.
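For readers without Lingo, the crisp model can also be sketched in Python with `scipy.optimize.minimize`. Because the cube is monotone and the slacks here are nonnegative, the cubed constraints are equivalent to their uncubed forms, which are linear in $\gamma_1$ and $\gamma_2$. This is an illustrative reconstruction under those assumptions and need not match the Lingo 19.0 output digit for digit:

```python
import numpy as np
from scipy.optimize import minimize

# Crisp score coefficients of the three objectives (w flattened row by row).
C = np.array([
    [0.014, 0.016, 0.027, 0.1375, 0.036, 0.006, 0.008, 0.056,
     0.008, 0.018, 0.007, 0.096],
    [0.056, 0.027, 0.1375, 0.008, 0.036, 0.048, 0.027, 0.028,
     0.056, 0.064, 0.006, 0.007],
    [0.018, 0.162, 0.027, 0.1375, 0.008, 0.006, 0.162, 0.056,
     0.112, 0.027, 0.014, 0.008]])
L = np.array([0.003090, 0.005318, 0.003530])
U = np.array([0.004428, 0.015249, 0.018077])
d = U - L
a = np.array([0.096, 0.036, 0.048])
b = np.array([0.014, 0.056, 0.096, 0.014])

# Decision vector x = [w_11 .. w_34, gamma_1, gamma_2].
cons = [{'type': 'ineq', 'fun': lambda x, i=i: a[i] - x[4*i:4*i+4].sum()}
        for i in range(3)]                                   # supplies
cons += [{'type': 'ineq', 'fun': lambda x, j=j: x[j:12:4].sum() - b[j]}
         for j in range(4)]                                  # demands
for k in range(3):
    # (U_k - F_k)^3 >= d_k^3 g1^3  <=>  U_k - F_k >= d_k * g1
    cons.append({'type': 'ineq',
                 'fun': lambda x, k=k: U[k] - C[k] @ x[:12] - d[k] * x[12]})
    # (F_k - L_k)^3 <= d_k^3 g2^3  <=>  F_k - L_k <= d_k * g2
    cons.append({'type': 'ineq',
                 'fun': lambda x, k=k: L[k] + d[k] * x[13] - C[k] @ x[:12]})
cons.append({'type': 'ineq', 'fun': lambda x: x[12] - x[13]})            # g1 >= g2
cons.append({'type': 'ineq', 'fun': lambda x: 1 - x[12]**3 - x[13]**3})  # g1^3 + g2^3 <= 1

res = minimize(lambda x: -(x[12]**3 - x[13]**3), np.full(14, 0.01),
               bounds=[(0, None)] * 12 + [(0, 1), (0, 1)],
               constraints=cons, method='SLSQP')
print(res.x[12], res.x[13], C @ res.x[:12])  # gamma_1, gamma_2, (F_1, F_2, F_3)
```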

#### **8. Conclusions**

There are many approaches to converting fuzzy data into crisp data, and many methods have been introduced for extensions of fuzzy sets, i.e., for intuitionistic, Pythagorean, Fermatean and other uncertain data. In this paper, we establish a new score function for ranking Fermatean fuzzy numbers, which helps to handle Fermatean fuzzy uncertainty in a crisp environment. Then, we introduce a Fermatean fuzzy programming approach for multi-objective decision-making problems under uncertainty. Fermatean fuzzy programming is a non-linear programming technique for multi-objective problems and an extension of Pythagorean fuzzy programming. With the proposed Fermatean fuzzy programming approach, we built a model of a MOTP and solved a numerical illustration of a MOTP in the Fermatean fuzzy environment. We found that the proposed approach is fruitful in finding a compromise optimal solution for multi-objective decision-making problems. Therefore, the proposed methodology is an alternative way of solving multi-objective decision-making problems in the Fermatean fuzzy environment, and the proposed Fermatean fuzzy programming approach can also be used to solve multi-objective decision-making problems in other fuzzy environments. As future work, we will extend the technique to type-2 fuzzy logic and develop models for applications in engineering and medical areas.

**Author Contributions:** Conceptualization, M.K.S. and K.; methodology, M.K.S. and K.; software, A.D., F.E.L.M. and A.G.H.; validation, A.N.; formal analysis, A.N. and M.K.S.; investigation, M.K.S., K. and A.D.; resources, H.G.R., V.T.H., A.D., F.E.L.M. and A.G.H.; data curation, M.K.S. and K.; writing—original draft preparation, M.K.S.; writing—review and editing, A.N., H.G.R., V.T.H., A.D. and A.G.H.; visualization, H.G.R. and A.G.H.; project administration, A.D.; funding acquisition, H.G.R., F.E.L.M. and A.G.H. All authors have read and agreed to the published version of the manuscript.

**Funding:** This research received no external funding.

**Acknowledgments:** This work is supported by Universidad Autonoma de Zacatecas, Mexico and CONACyT, Mexico.

**Conflicts of Interest:** The authors declare no conflict of interest.


**Disclaimer/Publisher's Note:** The statements, opinions and data contained in all publications are solely those of the individual author(s) and contributor(s) and not of MDPI and/or the editor(s). MDPI and/or the editor(s) disclaim responsibility for any injury to people or property resulting from any ideas, methods, instructions or products referred to in the content.

### *Article* **A Deep Learning-Based Encrypted VPN Traffic Classification Method Using Packet Block Image**

**Weishi Sun, Yaning Zhang, Jie Li, Chenxing Sun and Shuzhuang Zhang \***

School of Computer Science, Beijing University of Posts and Telecommunications, Beijing 100083, China **\*** Correspondence: zhangshuzhuang@bupt.edu.cn

**Abstract:** Network traffic classification has great significance for network security, network management and other fields. However, in recent years, the use of VPN and TLS encryption has presented network traffic classification with new challenges. Owing to the strong performance of deep learning in image recognition, many solutions have focused on deep learning-based methods and achieved positive results. A traffic classification method based on deep learning is provided in this paper, where the concept of the Packet Block, the aggregation of continuous packets in the same direction, is proposed. The features of Packet Blocks are extracted from network traffic and then transformed into images. Finally, convolutional neural networks are used to identify the application type of network traffic. The experiment is conducted using a captured OpenVPN dataset and the public ISCX-Tor dataset. The results show that the accuracy is 97.20% on the OpenVPN dataset and 93.31% on the ISCX-Tor dataset, which is higher than state-of-the-art methods. This suggests that our approach is able to meet the challenges of VPN and TLS encryption.

**Keywords:** deep learning; VPN traffic classification; image recognition

#### **1. Introduction**

In recent years, the study of network traffic classification has become a popular research topic [1–3]. It plays an important role in achieving better quality of service (QoS), network security, and network monitoring [4,5]. On the one hand, traffic classification can optimize network resource allocation for better QoS. For example, the network administrators of enterprises or campuses can observe the distribution of network traffic through traffic classification and then formulate new strategies to improve the efficiency of network resource utilization [6]. On the other hand, in terms of network security, traffic classification is usually the first step in network monitoring activities such as malicious traffic detection [7]. Therefore, network traffic classification has always been an indispensable part of network management and supervision.

Virtual private network (VPN) is a technology that establishes a private network on a public network to access Intranet resources remotely. It is simple to deploy and use. According to statistics on the 45 largest VPN apps in Google Play and the iOS App Store, VPN downloads have exceeded 134 million worldwide [8], which means that VPN use can be seen everywhere in the network. In addition, with increasing awareness of users' privacy protection [9], encryption technology has seen rapid growth and is used more and more widely. Many VPNs use TLS encryption to enhance communication security and protect users' privacy. The use of VPN, especially VPN in a TLS tunnel, poses a great challenge to traffic classification. VPN encapsulates the original traffic and hides some information of the original message. Furthermore, if a VPN uses TLS tunnels for encryption, the TLS tunnel will further group and randomize the traffic payload. As a result, the traffic payload hardly contributes to VPN traffic classification. Therefore, VPN and TLS encryption will affect traditional traffic classification technologies, such as deep packet inspection (DPI) [10] and port detection [11].

**Citation:** Sun, W.; Zhang, Y.; Li, J.; Sun, C.; Zhang, S. A Deep Learning-Based Encrypted VPN Traffic Classification Method Using Packet Block Image. *Electronics* **2023**, *12*, 115. https://doi.org/10.3390/ electronics12010115

Academic Editor: Yanhui Guo

Received: 15 October 2022 Revised: 23 November 2022 Accepted: 21 December 2022 Published: 27 December 2022

**Copyright:** © 2022 by the authors. Licensee MDPI, Basel, Switzerland. This article is an open access article distributed under the terms and conditions of the Creative Commons Attribution (CC BY) license (https:// creativecommons.org/licenses/by/ 4.0/).

At present, VPN traffic classification methods mainly include the fingerprint-based method, the payload-based method and the statistical feature-based method. The fingerprint-based method usually matches the fingerprint of the traffic to be detected against a fingerprint database to identify the traffic. It has high accuracy and is usually applied to fine-grained traffic identification problems, such as Tor traffic identification [12]. The payload-based method directly imports all or part of the packet payload into a classification model. One representative example is the end-to-end encrypted traffic classification method proposed by Wang et al. [13]. The payload-based method, which usually adopts a deep learning algorithm, has the advantage of a simple feature extraction process without manual design. However, the payload-based method is affected by encryption and works better on unencrypted VPN traffic or traffic encrypted in a specific way. The statistical feature-based method extracts statistical features from VPN traffic and then constructs a machine learning or deep learning model to classify the traffic. Statistical features, which are not affected by encryption, can effectively solve the problem of VPN traffic classification. Shapira et al. [14] used packet size distributions at different times to create an image, which they called FlowPic. These FlowPics were then sent into a traditional CNN model (LeNet-5) for classification, and the final experiment achieved good results. However, the FlowPic method also has a limitation: feature collisions caused by the single packet size distribution affect the accuracy of VPN traffic classification.

Recently, many deep learning-based methods have been proposed and have achieved good performance. However, most of them focus only on the encapsulation of VPN. When VPN traffic also uses TLS encryption, these methods are affected. For example, Shapira et al. [14] achieved better classification effectiveness on the unencrypted ISCX-VPN dataset than on the encrypted ISCX-Tor dataset. This is because encryption hides some traffic plaintext and randomizes the payload. In this regard, we propose a Packet Block-based method to improve the effectiveness of encrypted VPN traffic classification. In this method, continuous packets in the same direction are aggregated into basic units called Packet Blocks, so that traffic can be represented as a Packet Block sequence. The Packet Block features use the data size distribution of the whole flow and some packet interaction relations between the communication parties to reduce not only the effect of TLS and VPN encapsulation but also the feature collision of different traffic types. Then, an image is created from the Packet Block features and imported into a CNN model to train the classifier. Finally, our method achieves satisfactory classification accuracy. In this paper, the publicly posted ISCX-Tor dataset and an OpenVPN dataset captured by ourselves are used. The ACC of the traffic classification problem on the ISCX-Tor dataset is 93.31%, and that on the OpenVPN dataset is 97.20%. It can be seen that our method achieves a good classification effect on both an authoritative public dataset and a self-captured dataset.

The rest of this paper is structured as follows: Section 2 describes related work. Section 3 introduces the encrypted VPN network traffic classification method based on the Packet Block image. Section 4 presents the experiment designs and results. Section 5 concludes this paper.

#### **2. Related Work**

This paper mainly studies the classification of VPN traffic in TLS tunnels. Current studies rarely discuss the dual effect of VPN encapsulation and TLS encryption on traffic classification. However, VPN traffic in the TLS tunnel has the dual characteristics of VPN traffic and TLS encryption. Therefore, the classification of VPN traffic can also bring enlightenment to the classification of VPN traffic in TLS tunnels explored in this paper. Recently, VPN traffic classification mainly uses fingerprint-based methods, payload-based methods and statistics-based methods.

Fingerprint-based method: in the handshake process of TLS encryption, the messages before the two programs exchange cipher suites are unencrypted, including the information in the Client Hello phase or Server Hello phase and the certificate-related information provided by the server. They can usually be used to construct the fingerprint characteristics of traffic. In addition, the fingerprint characteristics may also include DNS context information and the HTTP context information of traffic. Fingerprint-based methods usually extract fingerprint information from a large amount of traffic to build a fingerprint database. Then, whenever new traffic to be classified is encountered, its fingerprint is matched with the fingerprints in the database to obtain the final classification result. Johanna et al. [12] found that the issuer and subject fields in normal certificates often do not contain information such as location or company name, and use randomly generated generic names formatted as www. + a base-32 code of 8–20 letters + .com or .net; using the issuer and subject fields in the certificate as fingerprints, they successfully identified Tor traffic from a dataset containing Tor traffic. Fingerprint-based methods usually perform well in some fine-grained classification problems, such as website recognition.

Payload-based method: a payload-based method takes all or part of the packet payload as the input of the classifier. Wang et al. [13] proposed an end-to-end encrypted traffic classification method for the first time. This method uses the first 784 bytes of TCP payload as input data to build a 1D-CNN model for traffic classification. After that, many researchers proposed improved methods on the basis of Wang et al. The research of He et al. [15] aimed to convert the first few non-zero payloads of the session into gray images, and used a convolutional neural network (CNN) to classify the converted gray images. Finally, satisfactory results were achieved in both non-VPN traffic and VPN traffic. All of these methods used the packet payload as the feature. Furthermore, a deep learning model was introduced to build the classifier, which achieved good results. However, the method based on the payload characteristics will be affected by encryption. Wang et al. [13] succeeded in classifying the VPN traffic because their training dataset and test dataset used the same encryption method and encryption key. Bu et al. [16] also proposed that, for VPN traffic, only using the header of the packet will have a better classification effect than using the entire packet. Therefore, when TLS encryption is applied, the packet payload is no longer meaningful due to symmetric encryption, and the effect of the payload-based method will be greatly reduced.

Statistical feature-based method: the statistical feature-based method extracts statistical features from packets and imports them into classifiers. Statistical features are often not limited by TLS encryption or VPN. Mohammad et al. [17] proposed a deep learning-based approach called "Deep Packet", which achieved good effectiveness in both the application identification task and the traffic categorization task. Gil et al. [18] used time-related statistical features to implement Tor traffic classification. They applied C4.5, KNN and random forest as classifiers; the accuracy on the ISCX-Tor dataset is more than 80%. Iliyasu et al. [19] proposed a semi-supervised learning method based on a deep convolutional generative adversarial network (DCGAN), the basic idea of which is to use the samples generated by the DCGAN generator and unlabeled data to improve the performance of a classifier trained with a small number of labeled samples. Their approach achieved good accuracy on both the QUIC and ISCX-VPN datasets. Qin et al. [20] calculated the payload size distribution probability (which they called PSD) of packets in a two-way flow. Then, they used Renyi cross entropy to identify the similarity between the PSD of the traffic to be detected and that of a specific application. Finally, they solved the classification problem of eight kinds of VPN applications. Shapira et al. [14] converted traffic into images for identification. This method used the packet size and packet arrival time to create images called FlowPic. Then, these FlowPics were sent to a CNN for classification. The classifier has excellent classification accuracy in traffic classification and application recognition. Although FlowPic uses the packet size and arrival time as features, the time-related features are greatly affected by the network environment and make little contribution to VPN traffic classification. Therefore, FlowPic actually takes the packet size as the dominant feature. However, the packet size is highly correlated with the protocol and traffic size. When many traffic categories are included, they may have similar packet size distributions. Shapira et al. [14] also mentioned that, in their experiment, 70% of the VoIP traffic was finally identified as file transmission, which was caused by the feature collision of packet size features.

For the classification of VPN traffic in the TLS tunnel in this paper, VPN encapsulation and TLS encryption will have a certain impact on the above methods. Fingerprint-based methods usually extract fingerprint information from the plaintext part of traffic, but there is no meaningful plaintext part of VPN traffic in a TLS tunnel. Thus, fingerprint-based methods will be difficult to apply to the classification of VPN traffic in TLS tunnel. Under the effect of TLS encryption, the payload-based method can hardly contribute to the classification of VPN traffic in the TLS tunnel. As for the statistical-feature-based method, because some statistical features are not affected by VPN encapsulation and TLS encryption, it can be a feasible solution to the classification of VPN traffic in the TLS tunnel. This paper proposes a method based on the Packet Block image, which uses the size and length characteristics of the Packet Block to reduce the feature collision of different kinds of traffic, so as to realize the classification of encrypted VPN traffic.

#### **3. Method**

This paper explores the classification of VPN traffic in TLS tunnels. VPN encapsulation and TLS encryption bring dual challenges to our traffic classification problem. Most VPN applications add the headers of the VPN client and server to the original packets in order to encapsulate them into new TCP/IP packets. As a result, the header information we obtain is actually the header information of the VPN client and server, and the headers of the original traffic are hidden. TLS further encrypts the VPN payload and invalidates methods that use payload and plaintext information. The encapsulation of TLS and VPN has a certain effect on the transmission behavior, but the relative size of packets is not affected by VPN and encryption. Therefore, in recent years, there have been many studies on using packet size distribution features to classify VPN traffic.

Qin et al. [20] first proposed the classification of eight applications using PSD features. Shapira et al. [14] used the packet size and packet arrival time to create an image to classify VPN traffic. Both of their studies used packet size distribution features and obtained good classification results. However, the possible feature collision of packet size distribution affects the accuracy of classification. For example, some applications of video and file transfer use the HTTPS protocol, resulting in similar packet size distributions. In this regard, we put forward the concept of Packet Block. Packet Block is the aggregation of a series of consecutive packets in the same direction. Our method takes Packet Block instead of a packet as the basic unit of flow, and uses the length and size of Packet Block to generate images to realize the classification of VPN traffic in a TLS tunnel. Packet Block is not affected by VPN and TLS encryption, and can reflect the data size distribution of the flow and some packet interaction between communication parties. It can be used to classify VPN traffic in TLS tunnels.

#### *3.1. Packet Block*

Nowadays, most studies on VPN traffic classification regard the packet as the basic unit of flow. Thus, a flow can be represented as an ordered set of packets, and this model of the flow is illustrated in Figure 1. In the figure, a vertical line represents a packet, the height represents the packet size and the distance between the lines represents the time interval between two packets. In order to express deeper information about the traffic, we put forward the concept of the Packet Block, that is, a group of continuous packets in the same direction. We define the number of packets in a Packet Block as the length of the Packet Block, and the average size of the packets in the Packet Block as the size of the Packet Block (the upstream traffic is positive and the downstream traffic is negative). [Packet Block's length, Packet Block's size] can be used to indicate a Packet Block. We consider that a flow consists of an ordered group of Packet Blocks, and this model of the flow is illustrated in Figure 2.

**Figure 1.** Flow model-based packet.

**Figure 2.** Flow model-based Packet Block.

Packet Block feature extraction is simple and is not affected by TLS encryption and VPN encapsulation. Thus, it can be used to classify VPN traffic in TLS tunnels. In the traffic model with the Packet Block, we ignore time-based features. This is because time, which is greatly affected by the network environment, has low robustness, and thus is of little help to our traffic classification problem. Compared with packet features, the Packet Block feature can still represent the byte distribution of a flow. In addition, the Packet Block feature can reflect some other deep features, such as the relationship between the upstream and downstream traffic of the communication parties and the data grouping method of the application type. For example, some VPN tunnels may combine multiple packets into one, which changes the packet size rule of a flow. In a Packet Block, the packet size increases but the packet number is reduced, which causes the total size of the Packet Block to adopt a stable value. Therefore, the Packet Block feature can deal with VPN and TLS encryption better than a single packet.
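As a concrete illustration of the aggregation rule, a flow given as signed packet sizes can be folded into Packet Blocks as below; this sketch and its packet representation are our assumptions, not the authors' released code:

```python
def packet_blocks(packets):
    """Fold a flow, given as signed packet sizes in arrival order
    (upstream positive, downstream negative), into [length, size]
    Packet Blocks, where size is the signed mean packet size."""
    blocks, run = [], []
    for p in packets:
        if run and (p > 0) != (run[0] > 0):   # a direction change closes the block
            blocks.append([len(run), sum(run) / len(run)])
            run = []
        run.append(p)
    if run:
        blocks.append([len(run), sum(run) / len(run)])
    return blocks

# Three upstream packets followed by two downstream packets:
print(packet_blocks([120, 80, 100, -1400, -1400]))  # [[3, 100.0], [2, -1400.0]]
```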

#### *3.2. Packet Block Image*

After describing the traffic model in Packet Blocks, we create a two-dimensional image of the flow, which is called the Packet Block image. The X axis is the length of the Packet Block and the Y axis is the size of the Packet Block. The value of each pixel represents the number of Packet Blocks with the corresponding length and size. For example, if the value of (4, 1200) is 7, this means that there are seven Packet Blocks with a length of 4 and a size of 1200 B. Thus, this image can be regarded as a distribution matrix of the Packet Block length and size. Then, we use these images as the input of a CNN model. We depict the Packet Block images of five different traffic types in the OpenVPN dataset, and the images are shown in Figure 3.

It can be seen that several types of traffic show some rules on the Packet Block image. The Packet Block length of VoIP and chat traffic is relatively short, generally within 10. The Packet Block size of chat traffic is smaller than that of VoIP traffic. The Video, FT (file transfer), and browsing traffic have some similarities in the Packet Block size distribution. The downstream Packet Block size is larger than that of the upstream, and the downstream Packet Block size of FT traffic is almost in a straight line. The downstream Packet Block size distribution of video traffic is relatively dispersed, while the downstream Packet Block size distribution of a browsing traffic is the most dispersed. This is because, when a web page is opened, the web page will load a variety of resources, which have different sizes. However, there are some smaller Packet Blocks in video traffic because video traffic is accompanied by a small amount of audio, text and other traffic. When downloading files, the downstream traffic is relatively pure, which results in the downstream Packet Block size of FT traffic being basically distributed along a straight line near 1400 bytes. In terms of the Packet Block length, that of FT traffic will also be slightly smaller than that of video and browsing traffic.

**Figure 3.** Example of Packet Block images.

The X axis is the length of the Packet Block, with a value ranging from 0 to N. If the value of N is excessively large, the effective part will be compressed on the left side of the image. If the value of N is too small, many Packet Blocks will fall outside the image. Therefore, it is essential to choose an appropriate N value. Observing the Packet Blocks of different traffic types, we initially select the value of N between 10 and 150, and the specific value will be decided through subsequent experiments. The Y axis is the size of the Packet Block. The size is less than 1500 B, which is the MTU of Ethernet. For convenience of calculation, the Y axis value range is (−1500, 1500], with a total of 3000 dimensions. In the proposed classification method, a distribution accurate to 1 byte is unnecessary; we only need the rough distribution of the Packet Block size. Therefore, we aggregate the Y axis in K-byte units, so the Y axis dimension M can be seen as 3000/K. Therefore, the values of the X axis are {0, 1, . . . , N} and those of the Y axis are {(−0.5MK, −(0.5M − 1)K], . . . , (−K, 0], (0, K], (K, 2K], . . . , ((0.5M − 1)K, 0.5MK]}.
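Putting these axis conventions together, a Packet Block image can be sketched as follows; the helper name and the defaults N = 60 and K = 50 (hence M = 60) are illustrative choices consistent with the text, not the authors' code:

```python
import numpy as np

def packet_block_image(blocks, N=60, K=50):
    """blocks: [length, size] pairs with size in (-1500, 1500].
    Returns an M x N count matrix with M = 3000 // K; column x holds
    lengths x + 1 (lengths above N are clipped to N), and row y holds
    sizes in the K-byte bucket ((y - M/2) * K, (y - M/2 + 1) * K]."""
    M = 3000 // K
    img = np.zeros((M, N), dtype=np.int32)
    for length, size in blocks:
        x = min(int(length), N) - 1              # lengths 1..N -> columns 0..N-1
        y = int(np.ceil(size / K)) + M // 2 - 1  # bucket the signed size
        img[min(max(y, 0), M - 1), x] += 1
    return img
```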

#### *3.3. An Encrypted VPN Traffic Classification Framework Based on CNN Model*

This paper proposes an encrypted VPN traffic classification framework based on the CNN model [21]. The specific framework is shown in Figure 4. The whole framework is divided into two parts: model generation, and verification or classification. The model generation can be divided into three parts: data preprocessing, image generation and model training.

**Figure 4.** Traffic classification framework based on CNN model.

• Data preprocessing: the task of data preprocessing is to extract flows from the dataset files and convert them into data that are easy to process. We use the four-tuple of source IP address, source port, destination IP address and destination port to distinguish different flows. In order to increase the size of the dataset and reduce overfitting, we divide each flow into sessions by time T; a four-tuple plus a start time determines a session. In this paper, each session is represented by a group of Packet Blocks, and the length and size of each Packet Block are calculated, because our subsequent experiments only use these two types of information. The labels of each session need to include whether it is VPN traffic, the application type and the application, such as VPN traffic, video and Tencent. If the number of Packet Blocks in a session is less than 10, we consider that the session does not contain valid information and discard it. Finally, the format of each session is shown in Figure 5.

**Figure 5.** Example of sessions.

• Image generation: In this step, we generate the Packet Block image from the sessions obtained by data preprocessing. In this process, we need to determine two parameters: the aggregation degree K of the Packet Block size and the upper limit N of the Packet Block length. If the length of a Packet Block exceeds the upper limit N, it is regarded as N. The final Packet Block image is an M\*N matrix (M = 3000/K).


#### **4. Experiments and Results**

#### *4.1. Dataset*

The ISCX-Tor [22] dataset and ISCX-VPN [18] dataset are widely used by researchers in traffic classification research. However, the ISCX-VPN dataset is not completely applicable to this study. On the one hand, the volume of some types of traffic in the ISCX-VPN dataset is small, such as chat and email. If the dataset is directly applied to the experiment, this imbalance may have some influence on the final result. On the other hand, not all types of traffic in the ISCX-VPN dataset are TLS tunnel traffic, which is not completely consistent with our research scenario. Therefore, the ISCX-Tor dataset and the OpenVPN dataset captured by ourselves are used in this paper. The ISCX-Tor dataset consists of the Tor traffic and non-Tor traffic of seven application types. Five of them are selected in our study: VoIP, video, file transfer, chat and browsing. The OpenVPN dataset we captured also includes these five application types. The application types and specific applications of the ISCX-Tor dataset and the VPN dataset captured by us are shown in Table 1.


**Table 1.** Protocols and applications for each traffic category.

In order to create a complete and representative dataset, this study set up a VPN environment in the laboratory by referring to the capture method of the ISCX-VPN [18] dataset to capture the VPN proxy tunnel traffic encrypted by the TLS of different types and applications. Referring to studies on traffic classification, most of them use traffic from five application types, namely video, VoIP, file transfer, chat and browsing, to classify traffic. Therefore, the dataset of this study includes the traffic of these five application types. For each type of traffic, we captured a regular session and a session over a VPN tunnel, so there are 10 traffic categories. All application types and specific applications contained in the OpenVPN dataset captured by us are shown in Table 1.

The lab environment is shown in Figure 6: a VPN gateway, on which the OpenVPN client and Stunnel are deployed, is set up between the PC and the campus network gateway. The enp3s0 port of the VPN gateway is connected to the PC, and the enp1s0 port is connected to the external network through the campus network. With this configuration, we run tcpdump on the VPN gateway and capture a pair of .pcap files on the enp3s0 and enp1s0 ports, marking non-VPN traffic and VPN traffic, respectively.

**Figure 6.** Network environment of the captured dataset.

The source IP, destination IP, source port and destination port of the VPN traffic we capture are exactly the same, which makes it difficult to distinguish the flows. Therefore, we control the PC to run only one program in a given time period, and all traffic in this time period is labeled as that program. In addition, we also need to observe the non-VPN traffic captured on the enp3s0 port to ensure that the noise traffic is within an acceptable range; otherwise, we discard the .pcap file. Below, we give a detailed description of the generation of the different types of traffic:


For each application of each traffic type, we captured 3 h of .pcap files. Finally, 29.47 GB of VPN traffic and 25.39 GB of non-VPN traffic were captured. The VPN traffic and non-VPN traffic each cover five application types: video, VoIP, file transfer, chat and browsing. Each type has three applications or protocols. See Table 1 for details.

In the data preprocessing of this experiment, in order to increase the number of samples in the dataset and reduce overfitting, each flow is divided into 30 s sessions. Table 2 shows the number of sessions of the various application types in the ISCX-Tor dataset and the captured OpenVPN dataset. It can be seen that the number of sessions in the OpenVPN dataset is greater than that in the ISCX-Tor dataset, and the number of chat sessions is greatly increased compared with the ISCX-Tor dataset, thus ensuring the balance of the dataset.

**Table 2.** The number of sessions in ISCX-Tor and OpenVPN.


#### *4.2. Measurement*

In this study, we use the accuracy (*ACC*) as the measurement. Accuracy is not only one of the most common measurements in deep learning, but also the main measurement adopted by many flow classification methods such as FlowPic. The definition of accuracy in this experiment is as follows.

$$ACC = \frac{\sum\_{i \in A} TP\_i}{\sum\_{i \in A} (TP\_i + FP\_i)}.\tag{1}$$

where *A* = {VoIP, Video, FT, Chat, Browsing}; the true positives (*TP*) and false positives (*FP*), respectively, represent the number of positive samples correctly classified and the number of negative samples incorrectly classified as positive; and *TPi* and *FPi* represent the number of true positives and false positives of category *i*, respectively. The advantage of *ACC* is that it is intuitive: *ACC* directly represents the proportion of correctly predicted samples. However, there is also a problem with the *ACC* index. If one class has a very large number of samples in the dataset, the classifier can maintain a high *ACC* even if it classifies the other classes poorly.

In order to avoid this situation, in addition to *ACC*, we also use the confusion matrix to better observe the multi-class classification problem. In the confusion matrix, each row represents the real label and each column represents the predicted label. The diagonal indicates the proportion of correct predictions for each category. Precision, recall and F1-score are also used. In this experiment, they are defined as follows:

$$Precision = \frac{TP}{TP + FP}.\tag{2}$$

$$Recall = \frac{TP}{TP + FN}.\tag{3}$$

$$F1\text{-score} = \frac{2 \ast Precision \ast Recall}{Precision + Recall}.\tag{4}$$
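All four quantities follow directly from a confusion matrix; a small sketch (an illustrative helper, not code from the paper):

```python
import numpy as np

def classification_metrics(cm):
    """cm: square confusion matrix; rows are true labels, columns predictions."""
    tp = np.diag(cm).astype(float)
    fp = cm.sum(axis=0) - tp       # predicted as class i but actually another class
    fn = cm.sum(axis=1) - tp       # class i predicted as another class
    acc = tp.sum() / cm.sum()      # equation (1): every prediction is a TP or an FP
    precision = tp / (tp + fp)     # equation (2), per class
    recall = tp / (tp + fn)        # equation (3), per class
    f1 = 2 * precision * recall / (precision + recall)   # equation (4), per class
    return acc, precision, recall, f1
```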

#### *4.3. CNN Model*

Traffic classification in this paper adopts the classification framework based on the aforementioned CNN model, and the parameter settings of the CNN model are shown in Table 3. Our CNN model is divided into seven layers, and the input is a matrix of (60, 60). The parameter settings for the seven layers are as follows:


**Table 3.** The main parameters of CNN model.
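Since the entries of Table 3 are not reproduced here, the following PyTorch sketch is only one plausible reading of a seven-layer CNN on 60 × 60 single-channel inputs; every layer size in it is an assumption, not the authors' configuration:

```python
import torch.nn as nn

class PacketBlockCNN(nn.Sequential):
    """Illustrative seven-layer CNN for 1 x 60 x 60 Packet Block images."""
    def __init__(self, num_classes=5):
        super().__init__(
            nn.Conv2d(1, 32, kernel_size=3, padding=1), nn.ReLU(),   # conv layer 1
            nn.MaxPool2d(2),                                         # pool 1: 60 -> 30
            nn.Conv2d(32, 64, kernel_size=3, padding=1), nn.ReLU(),  # conv layer 2
            nn.MaxPool2d(2),                                         # pool 2: 30 -> 15
            nn.Flatten(),
            nn.Linear(64 * 15 * 15, 128), nn.ReLU(),                 # fully connected
            nn.Dropout(0.5),                                         # regularization
            nn.Linear(128, num_classes),                             # per-class logits
        )
```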


#### *4.4. Packet Block Image Size Selection Experiment*

In our experiment, the aggregation degree K of the Packet Block size and the upper limit N of the Packet Block length affect the size of the images. The size of the images can be calculated as M\*N (M = 3000/K). To choose appropriate values of M and N, we designed two groups of experiments on the OpenVPN dataset. The first group set M = 30 to observe the impact of different N values on the classification accuracy. The second group set N = 60 to observe the impact of different M values on the classification accuracy. We conducted five VPN traffic classification experiments in each group to choose the values of M and N.

Table 4 shows the experimental results of traffic classification under different N values when M = 30. It can be seen that the classification accuracy increases with the increase in N at the beginning. When N increases to 60, the classification accuracy remains almost stable. This is because the packet block length of most traffic is less than 60. The continuous growth of N does not provide more available information for the packet block image. In addition, when N = 60, the time of running each epoch is only 1 s longer than when N = 10. Therefore, in subsequent experiments, we will choose the value of N as 60.

Table 5 shows the experimental results of traffic classification under different M values when N = 60. It can be seen that, as M ranges from 30 to 600, the classification accuracy changes very little. This is because, for traffic classification problems of different application types, we only need to understand the approximate byte distribution of the traffic rather than being accurate to each byte. Therefore, the 30-dimensional Packet Block size feature that aggregates 100 bytes can achieve good accuracy. In terms of running time, the time taken by the model to run an epoch decreases significantly as M becomes smaller. Therefore, we choose the value of M as 60 in the subsequent experiments.

#### *4.5. Traffic Classification Experiment Based on Packet Block Image*

We select the parameters obtained above, i.e., M = 60, N = 60, and conduct traffic classification experiments for five application types of traffic on our captured OpenVPN dataset and the public ISCX-Tor dataset. The classifier adopts the CNN model described above. The ACC, precision, recall and F1-score of the final traffic classification problem are shown in Table 6. The classification ACC of the Packet Block image method on the OpenVPN dataset is 97.20%, and the classification ACC on the ISCX-Tor dataset is 93.31%. Furthermore, the F1-scores on the two datasets are 96.70% and 89.28%, respectively.


**Table 4.** Results under different N when M = 30.

**Table 5.** Results of different M when N = 60.


**Table 6.** Results in OpenVPN and ISCX-Tor.


It can be seen that the ACC of our Packet Block image method for the five-class traffic classification problem on the ISCX-Tor dataset is significantly higher than that of the FlowPic method (67.8%). Even when FlowPic balances the ISCX-Tor dataset, our ACC is still higher than their 86.9%. Furthermore, on the OpenVPN dataset captured by ourselves, the classification ACC can reach 97.20%. As for speed, it takes approximately 5–10 min for FlowPic to run an epoch, while the Packet Block image method only needs 3 s per epoch because the dimension of the features is greatly reduced. Therefore, the processing speed of the Packet Block image method is much faster than that of the FlowPic method.

In order to observe the multi-class classification results more clearly, we give the confusion matrices of the Packet Block image method on the OpenVPN dataset and the ISCX-Tor dataset, as shown in Figures 7 and 8. In the OpenVPN dataset, the number of correctly identified samples exceeds 97% for four types of traffic: VoIP, video, FT and chat. However, the classifier has a poorer recognition effect on browsing: only 86.26% of the browsing traffic is correctly recognized, and approximately 14% of browsing traffic is recognized as video, VoIP or FT traffic. On the ISCX-Tor dataset, we can first see that the identification effect for chat traffic is not good, which is within our expectations. This is because the traffic types in the ISCX-Tor dataset are unbalanced, and the amount of chat traffic is much smaller than that of other types. In addition, the classifier has the worst recognition ability for video, with 22.81% of video traffic recognized as browsing traffic. However, it is worth noting that all misclassified traffic is related to browsing traffic: some browsing traffic is identified as other types, and some other types are identified as browsing. If the browsing traffic is ignored, the classification accuracy of the other four traffic types is greatly improved. It can be seen that, regardless of whether the OpenVPN dataset or the ISCX-Tor dataset is used, browsing traffic has a great influence on traffic classification. This is because, when we label the dataset, the browsing traffic overlaps with the other four types of traffic. For example, when we visit the main page of some websites, a video plays automatically; at this time, the browsing traffic is actually also video traffic. This leads to the limitation of browsing traffic in traffic classification.

**Figure 7.** Confusion matrix in an OpenVPN dataset.

#### *4.6. Comparison with Related Methods*

We compare the proposed method with some traditional machine learning methods and recent deep learning methods on the ISCX-Tor dataset. The results are shown in Table 7. We can see that the methods based on deep learning show a clear improvement over traditional ML methods. Furthermore, our proposed method achieves the best accuracy and F1-score among recent deep learning works. Among these methods, the results of the PSD and end-to-end methods were reproduced by us on the ISCX-Tor dataset, because they originally used different datasets. It can be seen that the accuracy of the end-to-end method is only slightly above 20%. This is because payload-based methods have no ability to deal with encrypted traffic. Overall, our proposed approach can deal with VPN traffic classification under TLS encryption and offers a good improvement over recent methods.

**Figure 8.** Confusion matrix in ISCX-tor dataset.



#### **5. Conclusions**

This paper aims to solve the classification problem of VPN traffic in the TLS tunnel. The encryption of the TLS tunnel and the encapsulation of VPN are two of the main difficulties in this traffic classification problem. In this paper, we propose an encrypted VPN traffic classification method based on the Packet Block image. In this method, we represent the flow with a sequence of Packet Blocks and then generate a Packet Block image from it. The Packet Block image represents the distribution of Packet Blocks in VPN traffic. These images are then imported into a CNN for learning. Packet Block images can extract deep features beyond the packet size distribution, which can greatly avoid feature collision and thus ensure the classification result. The publicly available ISCX-Tor dataset and our OpenVPN dataset were selected to verify our method, and traffic classification experiments for five application types were carried out on the two datasets, respectively. Finally, the Packet Block image method has a classification ACC of 93.31% on the ISCX-Tor dataset and 97.20% on the OpenVPN dataset we captured, which is higher than the accuracy of the FlowPic method. Because the Packet Block image is smaller than the FlowPic image, the processing speed of the model is much higher than that of FlowPic. Therefore, it was verified that our Packet Block image method can be used for the classification of VPN traffic in the TLS tunnel.

However, if many different flows are mixed, the effectiveness of our approach is diminished. Almost all methods run their experiments using single-flow datasets, but most real-world traffic is mixed, such as watching a video while downloading files. Thus, the identification and division of mixed VPN traffic is an important direction for future research.

**Author Contributions:** Conceptualization, S.Z. and W.S.; Methodology, S.Z. and W.S.; Software, W.S.; Validation, Y.Z., J.L. and C.S.; Resources, Z.S.; Data Curation, W.S. and Y.Z.; Writing—Original Draft Preparation, W.S.; Writing—Review and Editing, S.Z.; Project Administration, S.Z. All authors have read and agreed to the published version of the manuscript.

**Funding:** This research received no external funding.

**Data Availability Statement:** Publicly available datasets were analyzed in this study. The Tor-nonTor datasets can be found here: https://www.unb.ca/cic/datasets/tor.html. And the data we created can be found here: https://github.com/spirit19970507/OpenVPN (accessed on 14 October 2022).

**Conflicts of Interest:** The authors declare no conflict of interest.

#### **References**


**Disclaimer/Publisher's Note:** The statements, opinions and data contained in all publications are solely those of the individual author(s) and contributor(s) and not of MDPI and/or the editor(s). MDPI and/or the editor(s) disclaim responsibility for any injury to people or property resulting from any ideas, methods, instructions or products referred to in the content.

### *Article* **A Novel COVID-19 Image Classification Method Based on the Improved Residual Network**

**Hui Chen \*, Tian Zhang, Runbin Chen, Zihang Zhu and Xu Wang**

School of Computer Science and Engineering, Anhui University of Science & Technology, Huainan 232001, China **\*** Correspondence: huichen@aust.edu.cn

**Abstract:** In recent years, chest X-ray (CXR) imaging has become one of the significant tools to assist in the diagnosis and treatment of novel coronavirus pneumonia. However, CXR images have complex-shaped and changing lesion areas, which makes it difficult to identify novel coronavirus pneumonia from the images. To address this problem, a new deep learning network model (BoT-ViTNet) for automatic classification is designed in this study, which is constructed on the basis of ResNet50. First, we introduce multi-headed self-attention (MSA) to the last Bottleneck block of the first three stages in the ResNet50 to enhance the ability to model global information. Then, to further enhance the feature expression performance and the correlation between features, the TRT-ViT blocks, consisting of Transformer and Bottleneck, are used in the final stage of ResNet50, which improves the recognition of complex lesion regions in CXR images. Finally, the extracted features are delivered to the global average pooling layer for global spatial information integration in a concatenated way and used for classification. Experiments conducted on the COVID-19 Radiography database show that the classification accuracy, precision, sensitivity, specificity, and F1-score of the BoT-ViTNet model are 98.91%, 97.80%, 98.76%, 99.13%, and 98.27%, respectively, which outperforms other classification models. The experimental results show that our model can classify CXR images better.

**Keywords:** novel coronary pneumonia; image classification; ResNet50; multi-headed self-attention (MSA); TRT-ViT

**Citation:** Chen, H.; Zhang, T.; Chen, R.; Zhu, Z.; Wang, X. A Novel COVID-19 Image Classification Method Based on the Improved Residual Network. *Electronics* **2023**, *12*, 80. https://doi.org/10.3390/ electronics12010080

Academic Editor: Chiman Kwan

Received: 25 November 2022 Revised: 20 December 2022 Accepted: 21 December 2022 Published: 25 December 2022

**Copyright:** © 2022 by the authors. Licensee MDPI, Basel, Switzerland. This article is an open access article distributed under the terms and conditions of the Creative Commons Attribution (CC BY) license (https:// creativecommons.org/licenses/by/ 4.0/).

#### **1. Introduction**

At the end of December 2019, some cases of novel coronavirus pneumonia of unknown origin were reported in Wuhan, Hubei province; the disease was officially named COVID-19 by the World Health Organization in February 2020 [1]. COVID-19 has pandemic characteristics and spread rapidly worldwide, seriously threatening people's lives and health [2]. According to the Global New Coronavirus Pneumonia Epidemic Real-Time Big Data Report, as of October 2022, more than 200 countries and regions worldwide had been affected by novel coronavirus pneumonia, the global number of confirmed cases cumulatively exceeded 600 million, and the number of deaths exceeded 6 million.

COVID-19 is a novel infectious disease caused by severe acute respiratory syndrome coronavirus-2 infection, whose early clinical features are mainly fever, dry cough, and malaise, with a few accompanying symptoms such as runny nose and diarrhea. Severe cases can cause dyspnea and organ failure, which even can lead to death [3,4]. For more than two years, due to the instability of the COVID-19 gene sequence, several variants of COVID-19 have been generated. These variants are characterized by greater concealment, which makes it extremely difficult to diagnose them accurately.

Nucleic acid testing is the most common method to diagnose COVID-19. This method detects viral fragments by using the reverse transcription polymerase chain reaction (RT-PCR) [4] technique. However, nucleic acid testing has disadvantages such as being time consuming, having low sensitivity and a high false-negative rate, and requiring special test kits [5,6], which limits its use. In recent years, medical imaging techniques have been widely used for the diagnosis of various diseases. Chest X-ray (CXR) and computed tomography (CT) [7] are used to diagnose COVID-19. Compared with the nucleic acid test, medical-imaging-based diagnostics are faster and more effective. CT generates a lot of radiation and is not suitable for pregnant women and infants, while CXR involves little radiation and can reduce the risk of cross-infection to some extent. Furthermore, CXR is less costly and more widely used than CT. However, the manual analysis and diagnostic process based on CXR images depends heavily on the expertise of healthcare professionals, and the analysis of image characteristics is time consuming, which makes it difficult to observe occult lesions at an early stage and to distinguish COVID-19 from other viral and bacterial pneumonias [8]. Due to the urgent need, experts recommend the use of computer-aided diagnosis in place of manual diagnosis to improve the efficiency of detection and help doctors diagnose more accurately.

With the development of artificial intelligence, deep learning methods [9,10] have achieved great success in the field of computer vision. Several studies [11–13] have shown that convolutional neural networks (CNNs) have excellent feature extraction capabilities and can accurately extract image features of different scales. Medical image classification using CNNs requires the fusion of feature maps of different scales while taking into account both local and global information. Representative models used for COVID-19 classification include VGG networks, ResNet networks, and high-resolution networks [14]. These experimental results suggest that local feature extraction from medical images using CNNs is feasible. However, a CNN has fixed sampling locations, and its limited receptive field leads to poor global modeling capability, so it cannot learn image features effectively according to the complex changes of lesion regions. COVID-19 CXR images show consolidation in the lung and ground-glass opacities with commonly irregular shapes, such as hazy, patchy, diffuse, and reticular nodular patterns, which greatly increases the difficulty of COVID-19 detection [15,16]. Consequently, improving feature extraction from infected regions with complex shapes and establishing long-distance dependencies between features is the key to recognizing COVID-19 accurately. Transformer is the most advanced sequence encoder, whose core idea is self-attention [17]; it can establish long-range dependencies between feature vectors and improve feature extraction and representation. Vision Transformer (ViT) [18,19] is the representative model. The experimental results indicate that the extraction of global information from medical images using a pure Transformer is practicable, but it tends to lead to excessive memory and computational costs. As a result, some studies [20,21] have shown that combining convolution and Transformer in a hybrid network model can help improve classification performance while reducing computational cost.

In this paper, we design a new deep learning network model (BoT-ViTNet), based on ResNet50, for automatic image classification to help doctors identify COVID-19 and other viral pneumonias more accurately. The network model first combines the advantages of CNNs for extracting local feature representations with multi-head self-attention for global information modeling. Then, TRT-ViT blocks are used in the final stage to fuse global and local feature information. This effectively addresses the problem of learning feature representations from differently and complexly infected regions of CXR images, thus significantly improving the classification performance of the model. The main work includes:


The remainder of this paper is organized as follows. Related work, covering convolutional neural networks, vision Transformers, and hybrid network models, is briefly reviewed in Section 2. Section 3 presents the proposed BoT-ViTNet in detail and describes each of its components. In Section 4, extensive experiments are conducted to prove the effectiveness of BoT-ViTNet, and the experimental results are discussed and analyzed. Finally, the conclusion of the whole paper is given in Section 5.

#### **2. Related Work**

#### *2.1. Convolutional Neural Network*

In recent years, CNN models have been widely used in the field of computer vision for tasks such as image classification, object detection, and semantic segmentation. Anand et al. [23] fine-tuned a VGGNet [24] with an input size of 200 × 200 and tested it using three different pooling layers to obtain high classification accuracy. Rajpal et al. [25] designed a new classification method that consists of three modules. In the first module, they used ResNet50 [26] for feature extraction and addressed the network degradation problem through skip connections in the residual blocks. In the second module, a pool of frequency and texture features was constructed from the selected features, which were further reduced using PCA before being passed to a feed-forward neural network. The third module concatenated the features obtained from the first and second modules, passing them to a dense layer and classifying them with the softmax function. Sarker et al. [27] proposed COVID-DenseNet, which used the densely connected DenseNet-121 [28] as a feature extractor and a fully connected layer with softmax activation as a classifier for COVID-19 patient detection. This method reduces the number of parameters and the computational complexity through feature reuse and can effectively mitigate the vanishing-gradient problem. However, there are still challenges when using CNNs for COVID-19 classification, such as the irregular shapes and complex positional information in CXR images.

#### *2.2. Vision Transformer*

The Transformer was originally applied in the field of natural language processing with significant results. ViT [18,19] showed that the Transformer can also achieve strong performance on computer vision tasks. ViT performs self-attention by mapping a series of image patches to semantic tokens, which helps capture long-term relationships between sequence elements. The Data-efficient image Transformer (DeiT) [19] uses a knowledge distillation strategy based on the ViT architecture. The self-attention mechanism in the encoder attends to different regions of the image and integrates information across the whole image, completing the classification prediction through two connected classifiers. Jalalifar et al. [29] found experimentally that DeiT, without a single convolutional layer, achieved the same performance as DenseNet169, which showed that ViT can be applied to medical image analysis tasks. To reduce the computational effort, the Swin Transformer [30] was proposed to compute self-attention in non-overlapping local windows, which reduces computational complexity, but the sparse attention it employs limits the ability to model long-range relationships. Currently, researchers are focusing more on efficiency, including the effectiveness of self-attention and various training strategies.

#### *2.3. Hybrid Network Models*

Recent research [31,32] has indicated that combining convolution and Transformer in a hybrid network model can fully incorporate the benefits of both. Rao et al. [33] enabled the model to focus more on deep semantic information by introducing a self-attention mechanism into a CNN. Lin et al. [34] proposed an adaptive attention network (AANet). The network first used a deformable ResNet to learn feature representations adapted to the diversity of COVID-19 features. Then, the network utilized a self-attention mechanism to model non-local interactions and learn rich contextual information to detect complex-shaped lesion regions, which improved recognition efficiency. Aboutalebi et al. [35] proposed a multi-scale encoder-decoder self-attention (MEDUSA) model to solve the problem of overlapping image appearance. The model improved the ability to model global long-range spatial context by introducing self-attention modules, achieving good classification performance on several datasets. Li et al. [36] proposed UniFormer, which effectively unifies convolution and self-attention in a concise Transformer format to overcome local redundancy and capture global dependencies. The method achieves better performance on image classification tasks.

#### **3. Method**

To identify COVID-19 CXR images accurately, we propose a novel model called BoT-ViTNet, whose architecture is shown in Figure 1. The BoT-ViTNet model contains three parts. In the first part, we use the Bottleneck block for local feature extraction from the lesion regions of CXR images. In the second part, a multi-head self-attention (MSA) block is introduced to learn the contextual information of the extracted features, which enhances the global modeling capability of the feature information. In the last part, the TRT-ViT block is used to extract both local and global feature information to further enhance the feature representation and the correlation between feature locations.

**Figure 1.** The whole structure of the BoT-ViTNet model. (a) Bottleneck Block (s = 2); (b) Bottleneck Block (s = 1); (c) MSA Block; (d) TRT-ViT Block.

The general structure of BoT-ViTNet is similar to ResNet50, which also goes through four stages, and each stage consists of 3, 4, 6, and 3 blocks, respectively, as shown in Figure 1. The CXR image is first passed through a 7 × 7 convolution layer with a step size of 2 and a 3 × 3 pooling layer with a step size of 2 to obtain a feature map with a resolution of 56 × 56 × 64. Then, the feature map is input into the Bottleneck block, which consists of two types of residual convolution structure, as shown in Figure 1a,b. The Bottleneck block conducts a channel expansion of the feature map when the step size is 1 and performs a downsampling operation to increase the perceptual field on the feature map when the step size is 2. After passing 2 Bottleneck blocks sequentially, the feature map is input to the MSA block, whose structure is shown in Figure 1c. The MSA block learns the global information of image features and establishes long-range dependencies of features to enhance the expression of features. To further fuse the global and local information, the feature maps are input to the TRT-ViT block after being processed by multiple Bottleneck blocks and MSA blocks. The structure of the TRT-ViT block is shown in Figure 1d. The TRT-ViT block extracts global features and local information by Transformer and Bottleneck, respectively, and then fuses the global features and local information, which improves the expression of features and the correlation of positions between features. The global average pooling layer is used to integrate the global spatial information for the fused features. The output features are mapped to the softmax layer for probability prediction. BoT-ViTNet can not only capture deep global semantic information of CXR images but also extract shallow local texture information of CXR images. It inherits the advantages of Transformer and CNN, improving the recognition performance. Table 1 shows the structural details of the BoT-ViTNet model.


**Table 1.** The structure details and specific parameters of the BoT-ViTNet model.

#### *3.1. Bottleneck Block*

Unlike feature extraction with standard convolution, the Bottleneck block can reduce the computational complexity of the model while extracting features. Therefore, we use the Bottleneck block for local feature extraction from complex lesion regions in CXR images. The Bottleneck block consists of two 1 × 1 convolutions and a 3 × 3 depth-wise convolution. The first 1 × 1 convolution reduces the number of channels of the feature map so that feature extraction can be performed more efficiently. The 3 × 3 depth-wise convolution extracts the local feature information of the image. The second 1 × 1 convolution expands the number of channels of the feature map so that the number of output channels equals the number of input channels, allowing the output to be summed with the input. The Bottleneck structure greatly reduces the number of parameters and the amount of computation, thus improving computational efficiency. In addition, a residual connection is added to each output to avoid network degradation and over-fitting problems. The residual block is computed as follows:

$$y = F(x) + x \tag{1}$$

where *x* denotes the input feature map, *y* indicates the output feature map, and *F*(*x*) represents the convolution operation. The residual connection allows earlier layers of the network to act directly on later layers, which mitigates the vanishing-gradient problem during back-propagation.
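As an illustration only, the following is a minimal PyTorch sketch of the Bottleneck block described above (1 × 1 reduce, 3 × 3 depth-wise, 1 × 1 expand, plus the residual connection of Equation (1)). The channel widths and the stride-2 projection shortcut are illustrative assumptions, not the paper's exact configuration.

```python
import torch
import torch.nn as nn

class Bottleneck(nn.Module):
    """1x1 reduce -> 3x3 depth-wise -> 1x1 expand, with residual y = F(x) + x (Eq. 1)."""
    def __init__(self, channels: int, reduced: int, stride: int = 1):
        super().__init__()
        self.f = nn.Sequential(
            nn.Conv2d(channels, reduced, 1, bias=False),        # 1x1: channel reduction
            nn.BatchNorm2d(reduced), nn.ReLU(inplace=True),
            nn.Conv2d(reduced, reduced, 3, stride=stride, padding=1,
                      groups=reduced, bias=False),              # 3x3 depth-wise: local features
            nn.BatchNorm2d(reduced), nn.ReLU(inplace=True),
            nn.Conv2d(reduced, channels, 1, bias=False),        # 1x1: channel expansion
            nn.BatchNorm2d(channels),
        )
        # When s = 2 (Figure 1a) the identity shortcut needs a projection to match shapes.
        self.shortcut = nn.Identity() if stride == 1 else nn.Sequential(
            nn.Conv2d(channels, channels, 1, stride=stride, bias=False),
            nn.BatchNorm2d(channels))
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.relu(self.f(x) + self.shortcut(x))          # y = F(x) + x
```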

#### *3.2. MSA Block*

To enhance the long-term dependencies of the features, we introduce the MSA block into ResNet50, as shown in Figure 1c. MSA [37] is an essential component of the Transformer, which can combine feature information from different locations representing different subspaces. It is an extension of self-attention (SA) that runs k SA operations in parallel and projects their concatenated outputs. We first review the basic SA module, which is widely used in neural network architectures. SA is the core idea of the Transformer and has a weak inductive bias. By performing similarity calculations, it can establish long-distance dependencies between feature vectors and improve feature extraction and expressive ability. The input of each SA consists of a query *Q*, key *K*, and value *V*, which are linear transformations of the input sequence: *Q*, *K*, and *V* are obtained by multiplying the input with the weight matrices *WQ*, *WK*, and *WV*, respectively, which are learned during training. In this section, we use scaled dot-product attention for the similarity calculation among vectors, with the following equation:

$$SA(X) = softmax \left(\frac{QK^T}{\sqrt{d}}\right) V\tag{2}$$

where *X* denotes the input sequence, *SA*(·) represents the SA operation, and *d* means the dimension of the head.

MSA concatenates k single-head self-attentions and performs a linear projection operation on them with the following equation:

$$X_m = MSA(X) = Concat[SA_1(X), \dots, SA_k(X)]W_m \tag{3}$$

In Equation (3), *Xm* is the output of MSA, *MSA*(·) means the MSA operation, *Concat*[·] denotes the connection of feature maps with the same dimension, and *Wm* is the learnable linear transformation.

CNNs have a strong inductive bias and can effectively extract local texture information from feature maps in shallow layers, while MSA has a weak inductive bias and can establish long-range dependencies between features in deep layers. Consequently, combining CNNs with MSA yields powerful feature representation capability and high accuracy.
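The following is a minimal PyTorch sketch of Equations (2) and (3), operating on a flattened token sequence; the fused QKV projection and the default head count are implementation conveniences, not details from the paper.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadSelfAttention(nn.Module):
    """Scaled dot-product self-attention (Eq. 2) over k heads, concatenated and
    projected by a learnable W_m (Eq. 3)."""
    def __init__(self, dim: int, num_heads: int = 8):
        super().__init__()
        assert dim % num_heads == 0
        self.h, self.d = num_heads, dim // num_heads
        self.qkv = nn.Linear(dim, 3 * dim, bias=False)   # W_Q, W_K, W_V fused into one layer
        self.proj = nn.Linear(dim, dim)                  # W_m

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # x: (B, N, dim) token sequence
        B, N, _ = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        # Split into heads: (B, h, N, d), so each head computes one SA_i of Eq. (3).
        q, k, v = (t.view(B, N, self.h, self.d).transpose(1, 2) for t in (q, k, v))
        attn = F.softmax(q @ k.transpose(-2, -1) / self.d ** 0.5, dim=-1)  # Eq. (2)
        out = (attn @ v).transpose(1, 2).reshape(B, N, self.h * self.d)    # Concat[SA_1..SA_k]
        return self.proj(out)                                              # ...W_m
```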

#### *3.3. TRT-ViT Block*

To further improve the feature representation, the TRT-ViT block [22], consisting of a Transformer and a Bottleneck, is introduced in the last stage of ResNet50; it adopts a global-then-local hybrid block pattern for feature extraction. As described in [38], a Transformer with a larger receptive field can usually extract global information from the feature map and enable information exchange. In contrast, a convolution with a small receptive field can only extract local information from the feature map. The TRT-ViT block fully combines the advantages of the Transformer and the Bottleneck, which enhances the expression of features and the correlation of positions between features, thus helping to identify complex lesion regions in CXR images.

The network structure of the TRT-ViT block is shown in Figure 1d, which first uses Transformer to model the global information and then uses Bottleneck to extract the local information. Transformer is calculated as follows:

$$\begin{array}{l} X = X_{in} + MSA(Norm(X_{in}))\\ X_{out} = X + MLP(Norm(X)) \end{array} \tag{4}$$

where $X_{in} \in \mathbb{R}^{H \times W \times C}$ is the input feature map and $X_{out} \in \mathbb{R}^{H \times W \times C}$ is the output feature map. We first perform a channel-reduction operation using a 1 × 1 convolution with a step size of 1, reducing the number of channels of the feature map to half of the original. Then, to capture the long-range dependencies of features in complex lesion regions of the image, we use MSA in the Transformer to extract global information from the feature map and exchange information within each channel. Finally, the global features after information exchange are delivered to the multilayer perceptron (MLP) layer to improve the ability of the network to acquire image background information, which helps to identify complex lesions in CXR images. After the Transformer operation, the feature map containing global information is input into Bottleneck blocks to learn local spatial information. We connect the extracted global features with the local features to enhance the expression of the features and the correlation between feature positions, improving recognition accuracy.
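Continuing the sketches above, a hedged PyTorch rendering of the Transformer sub-block of Equation (4) (pre-norm MSA followed by a pre-norm MLP); it reuses the `MultiHeadSelfAttention` class from the Section 3.2 sketch, and the MLP expansion ratio is an assumption.

```python
import torch.nn as nn

class TransformerSubBlock(nn.Module):
    """Pre-norm Transformer part of the TRT-ViT block (Eq. 4):
    X = X_in + MSA(Norm(X_in)); X_out = X + MLP(Norm(X))."""
    def __init__(self, dim: int, num_heads: int = 8, mlp_ratio: int = 4):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.msa = MultiHeadSelfAttention(dim, num_heads)   # from the Section 3.2 sketch
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_ratio * dim), nn.GELU(),
            nn.Linear(mlp_ratio * dim, dim))

    def forward(self, x):                      # x: (B, N, dim) flattened H*W tokens
        x = x + self.msa(self.norm1(x))        # global information exchange
        return x + self.mlp(self.norm2(x))     # MLP for background context
```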

The Transformer aims to establish global connections between features, whereas convolution captures only local information. The computational cost of the Transformer and the Bottleneck is almost equal when the resolution of the input is low, indicating that placing the Transformer at a later stage of the network helps to balance performance and efficiency [39]. It was further demonstrated that using a global-then-local hybrid block pattern helps to identify complex lesion regions in CXR images. Consequently, in this section, the Bottleneck block is replaced by the TRT-ViT block in the last stage of the ResNet50 network and cross-stacked, which can effectively extract local texture information and global semantic information from infected regions with complex shapes and perform feature fusion, while achieving high performance and high accuracy.

#### **4. Results and Discussion**

#### *4.1. Datasets*

We used the COVID-19 Radiography database [39] as the experimental data; it was created by researchers from Qatar University and Dhaka University in collaboration with physicians from Pakistan and Malaysia. The dataset contains 15,153 CXR images, including 3616 COVID-19-positive (COVID) cases, 1345 viral pneumonia (Viral) cases, and 10,192 uninfected (normal) cases. Partial CXR images from the COVID-19 Radiography database are shown in Figure 2. In the experiments, the dataset is divided into a training set and a testing set in a ratio of 6:4, used for training the model parameters and validating the classification accuracy, respectively. There are 9094 images in the training set and 6059 images in the testing set. The number of images in each category of the dataset is shown in Table 2. To further demonstrate the robustness of our model, another CXR dataset called Coronahack [40] was used. This dataset has 5922 CXR images, including 1576 normal images and 4346 pneumonia images. Detailed information about the dataset is shown in Table 3. We divided this dataset in a ratio of 8:2 to obtain 4737 images in the training set and 1185 images in the test set, as shown in Figure 3.
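For illustration, a minimal sketch of the 6:4 split described above using torchvision; the dataset path and the folder layout (one sub-folder per class) are hypothetical assumptions.

```python
import torch
from torchvision import datasets, transforms

tf = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])
# Hypothetical local path; one sub-folder per class (COVID / Viral / Normal) is assumed.
full = datasets.ImageFolder("COVID-19_Radiography_Database", transform=tf)

n_train = int(0.6 * len(full))          # 6:4 train/test split as in the paper
train_set, test_set = torch.utils.data.random_split(
    full, [n_train, len(full) - n_train],
    generator=torch.Generator().manual_seed(0))   # fixed seed for reproducibility
```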

**Table 2.** Detailed information about the COVID-19 Radiography database.


**Table 3.** Detailed information about the Coronahack dataset.


**Figure 2.** Partial CXR images from the COVID-19 Radiography database (panels: COVID, Viral, Normal).

**Figure 3.** Partial CXR images from the Coronahack dataset (panels: Pneumonia, Normal).

#### *4.2. Experimental Details*

The specific configuration of the environment and the parameters for the experiment are shown in Table 4. Our programming environment used the deep learning framework PyTorch 1.9.0 and the programming language Python 3.8; the operating system is Ubuntu 18.04. During training, we resize all CXR images to 224 × 224 and use the Adam optimizer for model optimization, with the learning rate set to 0.001, the number of iterations set to 100, and the batch size set to 64. All experiments are performed on an RTX A4000 GPU with 16 GB of memory.
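A minimal training-loop sketch matching these settings (Adam, learning rate 0.001, 100 iterations, batch size 64); the stand-in ResNet50 classifier is an assumption in place of the full BoT-ViTNet, and `train_set` refers to the split sketch above.

```python
import torch
from torch.utils.data import DataLoader
from torchvision import models

# Stand-in backbone; the actual experiments use BoT-ViTNet (assumption for brevity).
model = models.resnet50(num_classes=3)          # 3 classes: COVID / Viral / Normal

loader = DataLoader(train_set, batch_size=64, shuffle=True, num_workers=4)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = torch.nn.CrossEntropyLoss()

model.train()
for epoch in range(100):                        # 100 iterations, as in Table 4
    for images, labels in loader:               # images already resized to 224 x 224
        optimizer.zero_grad()
        loss = criterion(model(images), labels)
        loss.backward()
        optimizer.step()
```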


**Table 4.** The configuration of the environment and parameters for the experiment.

#### *4.3. Evaluation Metrics*

To verify the validity and robustness of the BoT-ViTNet model, the confusion matrix and commonly used evaluation metrics were selected for performance evaluation in this experiment, including accuracy, precision, sensitivity, specificity, and F1-score. The equations for each metric are as follows:

$$Accuracy(Acc.) = \frac{N\_c}{N\_t} \tag{5}$$

$$\text{Precision}(\text{Pre.}) = \frac{TP}{TP + FP} \tag{6}$$

$$Sensitivity(Sen.) = \frac{TP}{TP + FN} \tag{7}$$

$$Specificity(Spe.) = \frac{TN}{TN + FP} \tag{8}$$

$$F1\text{-}score(F_1.) = 2 \times \frac{\text{Pre.} \times \text{Sen.}}{\text{Pre.} + \text{Sen.}} \tag{9}$$

where *Nc* is the number of correctly predicted cases and *Nt* is the total number of predicted cases. TP (True Positive) denotes the number of correctly predicted COVID-19 positive cases. TN (True Negative) represents the number of correctly predicted normal and viral pneumonia cases. FP (False Positive) is the number of normal or viral pneumonia cases misdiagnosed as COVID-19 positive. FN (False Negative) indicates the number of COVID-19 positive cases misdiagnosed as normal or viral pneumonia.
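These metrics follow directly from the confusion-matrix counts; a minimal sketch, assuming the counts have already been accumulated:

```python
def classification_metrics(tp: int, tn: int, fp: int, fn: int) -> dict:
    """Evaluation metrics of Equations (5)-(9) from confusion-matrix counts."""
    pre = tp / (tp + fp)
    sen = tp / (tp + fn)
    return {
        "accuracy":    (tp + tn) / (tp + tn + fp + fn),  # N_c / N_t
        "precision":   pre,
        "sensitivity": sen,
        "specificity": tn / (tn + fp),
        "f1":          2 * pre * sen / (pre + sen),
    }
```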

#### *4.4. Experimental Results and Analysis*

#### 4.4.1. Comparison of Classification Effects of Different Models

To validate the effectiveness of the BoT-ViTNet model, we compare it experimentally with several common deep learning models. The classification results on the COVID-19 Radiography database are shown in Table 5. From these results, it can be seen that the BoT-ViTNet model achieves the highest values of 98.91%, 97.80%, 98.76%, 99.13%, and 98.27% in classification accuracy, precision, sensitivity, specificity, and F1-score, respectively. Compared with the other deep learning models, the classification accuracy of our model is improved by 1.98%, 1.80%, 1.72%, 1.53%, 2.11%, 4.19%, 5.99%, 4.70%, 1.29%, and 0.79%, respectively. Table 6 provides detailed results of the different models and BoT-ViTNet on the Coronahack dataset. From these results, it can be seen that the BoT-ViTNet model achieves the highest values of 98.40%, 97.99%, 97.89%, 97.89%, and 97.94% in classification accuracy, precision, sensitivity, specificity, and F1-score, respectively. Compared with the other deep learning models, the classification accuracy of our model is improved by 2.03%, 1.10%, 1.27%, 0.93%, 1.71%, 3.29%, 6.42%, 4.48%, 0.68%, and 0.43%, respectively. These results show that the BoT-ViTNet model makes full use of the advantages of convolution and MSA for CXR image classification: it can not only extract the local texture information of CXR images but also capture their global semantic information. Meanwhile, using the global-then-local hybrid block pattern (TRT-ViT) to acquire image information at a later stage is much more efficient, which helps to identify complex lesions in images and achieve higher classification performance.


**Table 5.** The classification results of different models on the COVID-19 Radiography database.

**Table 6.** The classification results of different models on the Coronahack dataset.


Figure 4 shows the confusion matrices for different models on the test set of the COVID-19 Radiography database. The identification results for COVID, Normal, and Viral can be read directly from the confusion matrices, which show that the CXR images in the test set are largely concentrated on the diagonal, indicating that these images are correctly classified into the categories to which they belong. Meanwhile, it can also be seen that the BoT-ViTNet model misclassifies 21 COVID-19 cases and correctly predicts 1426 COVID-19 cases, an error rate of only 1.47%. Consequently, the BoT-ViTNet model can effectively and robustly identify COVID-19 cases.

Figure 5 shows the confusion matrices for different models on the test set of the Coronahack dataset, which indicate that the BoT-ViTNet model misclassifies 9 pneumonia cases and correctly predicts 860 pneumonia cases, an error rate of only 1.46%. Consequently, BoT-ViTNet can effectively and robustly identify pneumonia. The data in Figure 5 illustrate the good classification performance of BoT-ViTNet.

**Figure 4.** The confusion matrix on the test set of the COVID-19 Radiography database. (**a**) VGG16; (**b**) ResNet50; (**c**) DenseNet121; (**d**) [41]; (**e**) [42]; (**f**) [43]; (**g**) ViT; (**h**) Swin Transformer; (**i**) Uniformer; (**j**) AlterNet; (**k**) BoT-ViTNet.

**Figure 5.** The confusion matrix on the test set of the Coronahack dataset. (**a**) VGG16; (**b**) ResNet50; (**c**) DenseNet121; (**d**) [41]; (**e**) [42]; (**f**) [43]; (**g**) ViT; (**h**) Swin Transformer; (**i**) Uniformer; (**j**) AlterNet; (**k**) BoT-ViTNet.

#### 4.4.2. Comparison of Loss Curves of Different Models

Figures 6 and 7 show the training loss curves of ResNet50, ViT, AlterNet, and BoT-ViTNet on the COVID-19 Radiography database and the Coronahack dataset. The curves show that BoT-ViTNet converges relatively quickly, whereas the ViT model converges most slowly; AlterNet and ResNet50 converge at a similar rate. These results suggest that, on the same dataset, the BoT-ViTNet model has a shorter training time, a lower training loss curve, and a faster rate of convergence, thus improving training efficiency.

**Figure 6.** Loss variation curves of different models on the COVID-19 Radiography database.

**Figure 7.** Loss variation curves of different models on the Coronahack dataset.

4.4.3. Analysis of Classification Results by Each Category

Tables 7 and 8 represent the specific performance of BoT-ViTNet for each category on the COVID-19 Radiography database and Coronahack dataset.

**Table 7.** Recognition results of the BoT-ViTNet model on the COVID-19 Radiography database.



**Table 8.** Recognition results of the BoT-ViTNet model on the Coronahack dataset.

As can be seen from the experimental results in Table 7, BoT-ViTNet achieves high classification results for COVID-19 case recognition on the COVID-19 Radiography database, with precision, sensitivity, specificity, and F1-score of 98.55%, 99.00%, 99.67%, and 98.77%, respectively. Table 8 indicates that the BoT-ViTNet model achieves high classification results for pneumonia case recognition on the Coronahack dataset, with precision, sensitivity, specificity, and F1-score of 98.85%, 98.96%, 96.83%, and 98.90%, respectively. These results illustrate the good recognition performance of BoT-ViTNet.

#### 4.4.4. Ablation Experiments

In this section, we perform ablation experiments to verify the performance impact of introducing the TRT-ViT block and MSA block to replace the Bottleneck block in ResNet50. The results of the ablation experiment on the COVID-19 Radiography database and Coronahack dataset are shown in Tables 9 and 10.

**Table 9.** Results of the ablation experiment on the COVID-19 Radiography database.


**Table 10.** Results of the ablation experiment on the Coronahack dataset.


(1) The effect of the MSA block: To clearly show the positive impact of replacing part of the Bottleneck blocks with the MSA block, we used the Bottleneck block alone for feature extraction in the first three stages, as shown in Table 9. For No. 3, removing the MSA block from BoT-ViTNet results in a significant degradation of performance on the dataset: compared with No. 4, the accuracy, precision, sensitivity, and F1-score of No. 3 decrease by 1.09%, 0.08%, 2.18%, and 1.12%, respectively. In contrast, after introducing the MSA block into the last Bottleneck block of each of the first three stages, No. 2 improves by 0.53%, 1.63%, 0.70%, and 0.77% in accuracy, sensitivity, specificity, and F1-score over the original ResNet50 (No. 1), respectively. This shows that the MSA block is important for improving the performance of the BoT-ViTNet model.

(2) The effect of the TRT-ViT block: We also explored the contribution of introducing the TRT-ViT block in the last stage of ResNet50, as shown in Table 9. Compared with the original ResNet50 (No. 1), the classification results of No. 3 using the TRT-ViT block show a significant improvement in accuracy, precision, sensitivity, specificity, and F1-score of 0.71%, 1.05%, 1.51%, 1.51%, and 1.32%, respectively. These results demonstrate that the TRT-ViT block plays a crucial role in BoT-ViTNet. The data in Table 10 also show that BoT-ViTNet has fairly good classification performance on the Coronahack dataset.

As mentioned above, the MSA block and the TRT-ViT block can effectively improve the performance of COVID-19 classification in BoT-ViTNet. As shown in Tables 9 and 10, the classification effect of No. 4 is superior to the other settings in most metrics. These results show that the MSA block and the TRT-ViT block are important components of BoT-ViTNet for achieving good classification results.

#### 4.4.5. Robustness of BoT-ViTNet

To further verify the robustness of the BoT-ViTNet model, we added CXR image classification experiments with different batch sizes. The batch sizes selected in this experiment were 4, 8, 16, 32, and 64, and the datasets were the same as those used in the classification accuracy tests above. The classification results for different batch sizes on the COVID-19 Radiography database and the Coronahack dataset are shown in Tables 11 and 12.


**Table 11.** Classification results of different batch sizes on the COVID-19 Radiography database.

**Table 12.** Classification results of different batch sizes on the Coronahack dataset.
From the data in Tables 11 and 12, it can be observed that with a batch size of 4, the BoT-ViTNet model has the worst classification performance, with classification accuracies of 96.39% and 96.03%. With a batch size of 64, the BoT-ViTNet model performs best, with classification accuracies of 98.91% and 98.40%, a considerable improvement. During training, an overly small batch size leads to long training times and gradient oscillation, which is not conducive to the convergence of the model parameters.

#### **5. Conclusions**

In this paper, we designed a BoT-ViTNet model for COVID-19 image classification based on ResNet50. First, the MSA block is introduced into the last Bottleneck block of each of the first three stages of ResNet50 to enhance the ability to model global information. Then, to further enhance the correlation between features and the representation of features, the TRT-ViT block, which consists of a Transformer and a Bottleneck, is used in the final stage of ResNet50 to fuse global and local information and improve the recognition of complex lesion regions in CXR images. Finally, the extracted features are delivered to the global average pooling layer for global spatial information integration and used for classification. The experimental results of image classification on the publicly accessible COVID-19 Radiography database and Coronahack dataset show that the BoT-ViTNet model achieves better results. The overall accuracy, precision, sensitivity, specificity, and F1-score of the BoT-ViTNet model on the COVID-19 Radiography database are 98.91%, 97.80%, 98.76%, 99.13%, and 98.27%, respectively. The BoT-ViTNet model achieves strong recognition of COVID-19, with 98.55%, 99.00%, 99.67%, and 98.77% in precision, sensitivity, specificity, and F1-score, respectively. Compared with other classification models, the BoT-ViTNet model performs better in recognizing and classifying COVID-19 images. Although the BoT-ViTNet model achieves good results for the classification of COVID-19 images, further clinical studies and tests are still required.

**Author Contributions:** Methodology, H.C., T.Z. and R.C.; conceptualization, T.Z.; software, T.Z. and R.C.; validation, H.C., Z.Z. and X.W.; writing—original draft preparation, H.C., T.Z., R.C. and Z.Z.; writing—review and editing, H.C., T.Z., R.C. and X.W. All authors have read and agreed to the published version of the manuscript.

**Funding:** This research was supported by the National Natural Science Foundation of China (grant number 61170060) and the Key Teaching Research Project of Anhui Province (grant number 2020jyxm0458).

**Data Availability Statement:** Not applicable.

**Conflicts of Interest:** The authors declare no conflict of interest.

#### **References**


**Disclaimer/Publisher's Note:** The statements, opinions and data contained in all publications are solely those of the individual author(s) and contributor(s) and not of MDPI and/or the editor(s). MDPI and/or the editor(s) disclaim responsibility for any injury to people or property resulting from any ideas, methods, instructions or products referred to in the content.

### *Article* **Low-Resource Malware Family Detection by Cross-Family Knowledge Transfer**

**Yan Lin 1,\*, Guoai Xu 2,\*, Chunlai Du 3, Guosheng Xu <sup>1</sup> and Shucen Liu <sup>1</sup>**


**Abstract:** Low-resource malware families are highly susceptible to being overlooked by automated detection with machine learning or deep learning models because of their small number of data samples. When we aim to train a classifier for a low-resource malware family, data from the family itself are not sufficient to train a good classifier. In this work, we study the relationship between different malware families and improve the performance of machine-learning-based malware detection on low-resource families. First, we propose an empirical supportive score to measure transfer quality and find that transfer performance varies greatly between malware families. Second, we propose a Sequential Family Selection (SFS) algorithm to select multiple families as the training data. With SFS, we transfer knowledge only from supportive families to the target low-resource family. We conduct experiments on 16 families and 4 malware detection models; the results show that our model outperforms the best baselines by 2.29% on average, with a maximum accuracy improvement of 14.16%. Third, we study the transferred knowledge and find that, by proposing a supportive score, our algorithm can capture the common characteristics between different malware families and achieve good detection performance on low-resource malware families. Our algorithm could also be applicable to image detection and signal detection.

**Keywords:** machine learning; knowledge transfer; malware detection

#### **1. Introduction**

Android is an open-source system framework and has become one of the most popular mobile ecosystems. With the popularity of mobile devices in daily life increasing every year, more researchers are paying attention to the security of the Android ecosystem [1], and various malware detection approaches have been proposed by the community.

Malware attempts to control the user's system without authorization, steal personal information, encrypt important electronic files, or cause other damage. A malware family is composed of malware samples with common characteristics, which usually include the same code segments, patterns, application characteristics, and similar behavior. The size of each malware family varies from just a few samples to tens of thousands. A low-resource malware family is a family with so little data that its samples are not enough to train a malware detection model alone. Although low-resource malware families have little data, they can still pose great software security risks. When we aim to train a classifier for a low-resource malware family, the training data from the family itself are not sufficient to train a good classifier. Androzoo [2] currently contains 17,927,200 different APKs covering hundreds of malware families. However, some malware families have fewer than 100 or 500 APKs, which is not enough to train a good malware classifier. For unpopular malware, we may never find enough samples even if we label all the existing APKs. At the same time, it is hard to obtain sufficient labels in time given the continuous evolution of malware.

**Citation:** Lin, Y.; Xu, G.; Du, C.; Xu, G.; Liu, S. Low-Resource Malware Family Detection by Cross-Family Knowledge Transfer. *Electronics* **2022**, *11*, 4148. https://doi.org/10.3390/ electronics11244148

Academic Editor: Aryya Gangopadhyay

Received: 21 October 2022 Accepted: 29 November 2022 Published: 12 December 2022

**Publisher's Note:** MDPI stays neutral with regard to jurisdictional claims in published maps and institutional affiliations.

**Copyright:** © 2022 by the authors. Licensee MDPI, Basel, Switzerland. This article is an open access article distributed under the terms and conditions of the Creative Commons Attribution (CC BY) license (https:// creativecommons.org/licenses/by/ 4.0/).

There are some widely used malware datasets in which a large number of malware families have only a few data samples. For example, the MalGenome [3] dataset covers 49 families, and each family contains 1 to 309 malware samples. The top three families occupy roughly 70% of the overall dataset, while over 30 families have fewer than 10 samples. This distribution suggests that, as long as a detection approach can successfully detect the top families, the overall result will look good enough, and the malware families with few samples are ignored. If we directly use all the training data from other malware families, most malware detection models may not be robust enough to transfer between different malware families. To detect low-resource malware families, Tran et al. first used prototype learning to create prototype representations for the target malware family and used a twin network to classify malware [4]. Subsequent work further improved the generation of prototype representations [5], or improved the training of twin networks using meta-learning [6–8] and contrastive loss functions [9]. Kamaci et al. [10] established novel distance concepts to measure the relative difference between two objects. Alsboui [11] proposed a graph-based dynamic multi-mobile agent itinerary planning method to cover all nodes in the network. However, these methods use only a small number of samples related to the target low-resource malware family for model training, ignoring the relationship between the target malware family and the large number of existing malware families. In this paper, we seek to study the relatedness of malware families and leverage it to improve malware detection performance on low-resource families. Our work focuses on three research questions:

First, can using different malware families as training data help the detection of a target low-resource malware family? Intuitively, training with similar malware tends to transfer better to the low-resource malware family, and dissimilar malware may even harm performance. We propose to measure this similarity with empirical experiments. Specifically, we train a malware detection model on one family *mtrain* and test on the target family *mtest*, and define the test performance as the supportive score from *mtrain* to *mtest*. Our work shows that transfer performance varies greatly between malware families, and we can achieve good performance by selecting the family with the biggest supportive score.

Second, we study whether it is more helpful to use multiple malware families as the training set. We found that if we neglect the differences between distinct malware families and train the model with all the families in the available training data, the malware detection performance may even be worse than selecting only the single most supportive malware family. We propose a Sequential Family Selection (SFS) algorithm to carefully select multiple families as the training set. Our algorithm can easily be adapted to any detection model. We conduct experiments to validate its performance on 16 malware families and four representative detection models: csbd, drebin, mamadroid, and droidsieve. Our results show that SFS improves the performance of all the malware detection models. We also evaluate performance on datasets from a later time period, and SFS still achieves better performance.

Third, we try to understand why the supportive score between some malware families is higher, that is, why they transfer better. We hypothesize that this is due to similar characteristics between malware families. We study two common characteristics: whether the malware steals user data and whether it displays advertisements. We find that malware with the same characteristics tends to have high supportive scores. Most supportive relations are the same across malware detection algorithms, while some vary between detection models. Our work makes three contributions:

1. We make the first systematic study of the relatedness between malware families. We propose to measure the malware family similarity with an empirical supportive score and find it is the key to good transfer performance.


#### **2. Related Work**

In this section, we first overview the common Android malware detection methods. Then, we discuss low-resource malware detection.

#### *2.1. Android Malware Detection Based on Machine Learning*

Many researchers have studied various malware detection methods. Malware detection methods are emerging rapidly, such as methods based on static features [12–16] and methods based on dynamic features [17–21]. Static features of applications include API calls, permissions, opcodes, etc.; these are extracted by analyzing the structure of applications. Dynamic features include system calls, behavior characteristics, network traffic, etc. [22]; these features are extracted while the application is running. Mudflow [23] uses the flows between APIs as features to detect malware. Deep4maldroid [24] leverages a constructed graph to train malware detection models. DroidAPIMiner [25] provides a lightweight malware classifier by conducting a thorough analysis of apks at the API level. However, these Android malware detection methods are more concerned with the performance of the algorithm on the overall dataset and ignore detection performance on low-resource malware families.

#### *2.2. Low-Resource Malware Detection Based on Machine Learning*

To enhance the detection of low-resource malware families, several researchers have improved existing models. In 2019, Tran et al. used prototype learning to create prototype representations for target malware families and used twin networks for malware classification [4]. Subsequent work followed two main directions. First, to further improve the generation of prototype representations, Chai et al. used dynamic prototype networks [5] and Tang et al. used multilayer convolutional neural networks [6]. Second, to train the network better, Bai et al. used a contrastive loss function to train the twin network [9]. In addition, Tran et al. used meta-learning to train memory neural networks for malware family classification [26]. However, all these methods use only samples of the target malware family for model training and prediction, ignoring data samples from malware families related to the target family.

To address the problem of low-resource malware detection, other researchers have increased the data samples of low-resource malware families by generating new data [27–29]. Zahra et al. used generative adversarial networks to increase the data samples of low-resource malware families by generating new malware sample signatures [30]. Chen et al. proposed a malware detection model called Adv4Mal, which generates new data based on specific malware signatures to supplement the training data of the low-resource malware family [31]. These methods use artificially constructed data, whereas this paper uses real data related to the target malware family.

Table 1 shows the differences between these related works on low-resource malware detection based on machine learning. This paper proposes a sequential family selection algorithm that does not require generating any forged data; instead, it supplements low-resource malware families with existing data. The sequential family selection algorithm uses knowledge transfer among malware families to select relevant real data samples and improve the performance of low-resource malware family detection. In principle, the sequential family selection algorithm can be combined with the malware detection methods above to further improve their detection performance on low-resource malware families.


**Table 1.** Low-resource malware detection based on machine learning.

#### **3. Methodology**

In this section, we first introduce how to measure the similarity between different malware families; good transfer performance on a low-resource family can be achieved by selecting the most similar malware family. Then we introduce the Sequential Family Selection (SFS) algorithm, which selects multiple families as training data and achieves better performance.

#### *3.1. Malware Family Similarity*

We notice that malware detection performance differs significantly depending on which malware families are in the training set, even when using the same detection algorithm. Moreover, the impact of different families varies across detection methods. It is therefore interesting to explore the similarity between malware families. There are two general ways to measure this similarity: malware characteristics and empirical metrics.

Researchers obtain the characteristics of mobile applications through dynamic and static analysis, including operation codes, API calls, behavior characteristics, etc. Malware families with similar characteristics are more likely to transfer knowledge to each other, and determining the similarity between malware families based on features is interpretable. However, it is hard to define all the characteristics of Android applications, especially for rarely studied low-resource malware families. Furthermore, it is usually time-consuming to analyze malware applications. We therefore propose an empirical supportive score to measure transfer quality. Specifically, we train a malware detection model on one family *mtrain* and test on the target family *mtest*, and define the test performance as the supportive score from *mtrain* to *mtest*. To achieve good performance on a given test set, we can select the family with the biggest supportive score as training data.

This quantitative relationship between malware families corresponds directly to malware detection performance and can help improve it. We further explore the relationship between the characteristics of malware families and the supportive score, and find that the supportive score is highly correlated with human-summarized characteristics.
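A minimal sketch of computing the empirical supportive score for every (train family, test family) pair, assuming a scikit-learn-style classifier factory with `fit`/`score`; the data-dictionary layout is a hypothetical convention.

```python
def supportive_scores(families, train_data, test_data, make_model):
    """Empirical supportive score: train a detector on family m_train and test it
    on family m_test; the test accuracy is the score from m_train to m_test."""
    scores = {}
    for m_train in families:
        model = make_model()                      # e.g., a fresh sklearn classifier
        X_tr, y_tr = train_data[m_train]          # family malware + sampled benign
        model.fit(X_tr, y_tr)
        for m_test in families:
            X_te, y_te = test_data[m_test]
            scores[(m_train, m_test)] = model.score(X_te, y_te)
    return scores
```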

#### *3.2. Sequential Family Selection (SFS) Algorithm*

In this section, we first formalize the problem. Then we describe two baselines: training with the most supportive family only and training with all the malware families. We further propose a new malware family selection algorithm that carefully selects multiple families as the training set. Finally, we compare the performance of these methods and validate our research questions. Formally, we target the performance on a target malware family *mt*, which has a validation and a test dataset. The training data include a set of malware families *Strain* = {*m*1, *m*2,..., *mn*}, and each malware family corresponds to a training dataset *Di*, which contains malware from the corresponding family and randomly sampled benign samples. The benign samples of different groups of training data do not overlap. We have two baseline settings for this problem. The first is to train on all the malware families in the training data and test on the target malware family; this is equivalent to neglecting the differences between malware families. The second is to find the training family that is most supportive of the target malware family. We empirically calculate the supportive score from malware family *ma* to malware family *mb* by training on *ma* and testing on *mb*.

To achieve better performance, we propose the Sequential Family Selection (SFS) algorithm. SFS selects a subset of the malware families in the training data. It starts from an empty set and selects families one by one. In each step, we try to combine each candidate with the already selected families and evaluate on the target validation set. We select the family with the best performance, i.e., the biggest supportive score, and add it to the selected set. Formally, we initialize the selected family set *Sselected* as the empty set, and all the families in the training data form the candidate set. In the first step, we train on each malware family in the candidate set separately and evaluate on the target malware family *mt*. We add the malware family with the best performance to *Sselected* and remove it from the candidate set. Second, we combine *Sselected* with each remaining malware family in the candidate set separately and evaluate on *mt*. We select the malware family giving the best combination and add it to *Sselected*. In the following steps, we iteratively combine the malware families in the candidate set with *Sselected* and add the best family to *Sselected*. The algorithm terminates when all the families have been added to *Sselected*. Finally, we choose the combination with the best performance over the whole *Sselected* history and return it as our final selection. The whole procedure is described in Algorithm 1. Our algorithm is independent of the malware detection model and can improve the performance of any model.

#### **Algorithm 1** Sequential Family Selection algorithm (SFS).

**Input:** A malware classifier *C*(). A target malware family *mt*, which includes a validation set *Dvalid* and a test set *Dtest*. A set of training malware families *Strain* = {*m*1, *m*2,..., *mn*}, where each *mi* corresponds to a set of training samples *Di*.

**Output:** A subset of *Strain*.
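A minimal Python sketch of the greedy selection loop of Algorithm 1 as described in Section 3.2; `train_and_eval` is an assumed callback that trains any of the four detectors on the selected families (plus matched benign samples) and returns validation accuracy.

```python
def sfs(candidates, train_and_eval):
    """Sequential Family Selection (Algorithm 1): greedily add the family whose
    addition yields the best validation accuracy, then return the best prefix
    seen over the whole selection history."""
    candidates = list(candidates)
    selected, history = [], []
    while candidates:
        best_fam, best_acc = None, -1.0
        for fam in candidates:                    # try each remaining family
            acc = train_and_eval(selected + [fam])
            if acc > best_acc:
                best_fam, best_acc = fam, acc
        selected.append(best_fam)                 # keep the most supportive addition
        candidates.remove(best_fam)
        history.append((list(selected), best_acc))
    return max(history, key=lambda t: t[1])[0]    # best combination in the history
```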


#### **4. Experimental Setup**

#### *4.1. Malware Detection Approaches*

Our method is designed to benefit malware detection by studying the relationships between malware families. As malware detection approaches, we apply four popular methods: csbd, drebin, mamadroid, and droidsieve.

Csbd [19] extracts the control flow graph as the feature for malware detection. It was first proposed by Kevin from the University of Luxembourg. This method performs static analysis on the application bytecode to extract the control flow graph, takes the basic blocks of the control flow graph as features of the application, and uses machine learning classification algorithms to classify it. Drebin [20] uses multiple static analysis approaches to extract features of the application from disassembled code and the manifest file for Android malware detection. It is a lightweight detection method that uses an SVM to automate the classification of applications. Mamadroid [21] is a malware detection model based on application behavior. This method extracts and abstracts the sequence of calls between APIs in an application, constructs feature vectors based on Markov chains, and uses different machine learning classification algorithms to classify applications. Droidsieve [22] classifies malware using obfuscation-invariant features and the artifacts introduced by the obfuscation mechanisms used in malware attacks.

#### *4.2. Malware Corpus*

To prepare for our study, we collect a set of Android applications from Androzoo, an open Android dataset collection project [2]. The Android apks in Androzoo were obtained from various app markets. As VirusTotal is broadly used for Android malware labeling, we use VirusTotal scan results to resolve the labels of the collected apks. VirusTotal uses more than 70 anti-virus scanners and URL/domain blocklist services to check items [32]. An apk is labeled as benign when no engine in VirusTotal marks it as positive. To ensure the reliability of the collected dataset, we label an apk as malware when at least five engines in VirusTotal flag it as malicious. We use Euphony [33] to acquire more information about the malware family of each sample in the collected apks. We collect a dataset containing over 20,000 malware samples with family information and 20,000 benign applications, spanning 2015 to 2016. Note that we construct different groups of training and test sets for the different cases considered in this paper.
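A minimal sketch of the VirusTotal-based labeling rule described above; the function name and the decision to exclude samples with 1–4 positives are assumptions, not details stated in the paper.

```python
def label_apk(vt_positives: int):
    """Labeling rule for the corpus: benign if no VirusTotal engine flags the apk,
    malware if at least five engines do; in-between counts are treated as
    ambiguous and excluded (an assumption, not stated in the paper)."""
    if vt_positives == 0:
        return "benign"
    if vt_positives >= 5:
        return "malware"
    return None
```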

#### *4.3. Testing Dataset*

Many different malware families exist in the wild. We randomly pick 16 families from the malware corpus for our study. To ensure the consistency of our experiments, each malware family contains 500 samples. We then combine 500 benign applications with each malware family to construct the test set. The collected malware families belong to different categories with different malicious characteristics. Table 2 shows the name and description of each selected malware family. Note that the malware families in the test set do not appear in the training dataset.


**Table 2.** The detailed description of the selected malware family.



#### **5. Results and Discussions**

In this section, we illustrate the impact of different malware families in the training set on malware detection and obtain the supportive scores of the 16 malware families.

#### *5.1. Dataset Construction*

We construct 16 malware benchmarks to be used as the 16 test sets. Each test set consists of 1000 apks: 500 malware and 500 benign. The dataset used in this section is from 2015. The test set of each experiment contains only one malware family. To minimize the variability of the experimental results, we perform each set of experiments five times and use the average of the five results as the final outcome.

#### *5.2. Supportive Score of Malware Families*

We further analyze the impact on malware detection performance of using distinct malware families for training. Tables 3–6 show the accuracy of our experiments for the different malware families in the training set. In each table, we bold the maximum value among the experimental groups that test the same malware family; this maximum is the supportive score of the corresponding malware family.

It is interesting to find that accurate malware detection is possible even when the malware families in the training and test sets are completely different. However, detection performance differs between the four algorithms: the accuracy is higher than 55% for 37.7% of the family pairs in csbd, 79.1% of the pairs in drebin, 51.6% of the pairs in mamadroid, and 60.8% of the pairs in droidsieve. Only 25, 109, 36, and 72 groups of experiments (10.4%, 45.4%, 15%, and 30%, respectively) yield an accuracy of over 80%. The accuracy of malware detection is highly related to the malware family used in the training dataset. For example, when testing the family ginmaster in csbd, gingermaster works best and revmob performs worst as the training family, with a difference of 22.2% in accuracy. A training set with leadbolt achieves an accuracy of 92.3% when detecting plankton in drebin, but only 48.7% when admogo is used as the training set.

**Finding 1:** It is possible to transfer knowledge between different families. The malware detection performance on a target malware family depends on its relation to the training family, and it is also related to the malware detection algorithm.


**Table 3.** Malware detection performance with different malware families in the training set using csbd. The numbers in bold show the best results.

**Table 4.** Malware detection performance with different malware families in the training set using drebin. The numbers in bold show the best results.


**Table 5.** Malware detection performance with different malware families in the training set using mamadroid. The numbers in bold show the best results.


**Table 6.** Malware detection performance with different malware families in the training set using droidsieve. The numbers in bold show the best results.


#### *5.3. Low-Resource Malware Family Detection*

In this section, we investigate the performance of our algorithm for low-resource malware detection. We apply the SFS algorithm to four malware detection methods: csbd, drebin, mamadroid, and droidsieve. Each malware detection method is evaluated in 16 sets of experiments.

Low-resource malware families are malware families with too little data to train a malware detection model on their own. A target malware family with only 10 samples in the training set can be considered a low-resource malware family. The other malware families in each group of experiments have 500 samples. We first conduct the experiment using data from 2015.

For a detailed discussion, we illustrate the SFS process for detecting plankton using drebin. The leftmost part of Figure 1 shows the first step of the SFS algorithm: we train on each of the other 15 families separately. It can be seen that training on leadbolt performs better than training on any other malware family, so we take leadbolt as the base set of the next SFS iteration. The rest of the malware families are then combined with the base set leadbolt, one at a time. The second part of Figure 1 shows that adding gingermaster to leadbolt performs best. The combination of leadbolt and gingermaster is taken as the new training dataset for the next round of the experiment. This is repeated until all the malware families have been added to the training set. The best solution found by this method is "leadbolt + gingermaster + ginmaster + utchi + wapsx + mulad + droidkungfu".

**Figure 1.** Study of SFS for plankton using drebin. The numbers in bold and the red boxes show the best results.

To ensure the validity of the experiments, whenever a certain number of malware samples is added in a round of experiments, we combine them with the same number of benign samples. We also perform each experiment five times to minimize the error of the experimental results, taking the average of the five runs. We compare the best detection performance of training with only one family, the performance of training with all families, and the performance of our algorithm. Table 7 shows the comparison of the malware detection performance between the three cases. It can be seen that the SFS algorithm outperforms the baselines when applied to all four malware detection methods. In particular, some low-resource malware family detection can be improved by over 10% in terms of accuracy using our SFS algorithm, such as detecting droidkungfu with csbd and detecting umeng with drebin. The performance of different detection algorithms varies with the malware families in the dataset. For example, the detection accuracy of droidsieve is greater than 90% in 13 of the 16 groups of experiments; with our SFS algorithm, its detection results for artemis and ginmaster are improved by 6.4% and 8.12%, respectively. We also compare the average detection accuracy over the 16 malware families. The performance of every detection method is improved by the SFS algorithm; mamadroid with SFS shows the highest improvement, reaching 6.73%.

Android mobile applications are constantly evolving with the rapid development of technology [52]. New malware is constantly updated to evade detection. As a result, malware detection algorithms that achieve very good results in one year may not be able to classify new malware produced the next year.

**Table 7.** The comparison of low-resource malware family detection. Numbers in bold show the best results.


To explore the sustainability of the SFS algorithm, we use a model trained on data collected in 2015 to test data from 2016; that is, a malware detection classifier trained on outdated data is used to detect future malware samples, and its performance further evaluates the sustainability of the SFS algorithm.

Table 8 shows that the SFS algorithm does not work well for some malware families, such as mulad and umeng, when outdated data are used for training, but it still detects most low-resource malware families better than the baselines, and it performs the best on average. We can therefore conclude that the SFS algorithm supports the sustainability of malware detection algorithms. On average, combining the SFS algorithm with csbd, drebin, mamadroid, and droidsieve improves their detection accuracy by 4.11%, 0.76%, 3.36%, and 4.26%, respectively. For a specific malware family and detection algorithm, the improvement can be much larger: with csbd trained on data from 2015, the SFS algorithm still improves the detection accuracy for the malware family droidkungfu in 2016 by 13.32%.


**Table 8.** The malware detection ability when using outdated training data. Numbers in bold show the best results.

#### *5.4. Zero-Resource Malware Family Detection*

In this section, we further investigate the most extreme case of low-resource malware family detection: zero-resource malware family detection. Zero-resource detection means that the malware detection model has never seen the malware family; that is, the zero-resource malware family is not included in the training dataset. Table 9 compares malware detection performance when the malware family to be detected does not exist in the training dataset. The SFS algorithm performs the best in all of these experiments.

In 60.4% of the experiments, training on all malware families performs worse than training on the families selected by SFS. This shows that we cannot simply improve detection performance on the target family by adding more, unrelated malware data. Using all malware families for training amounts to ignoring the differences between families in the hope of detecting all malware with one model; our results show this is a poor solution for malware detection.

The performance of SFS is higher than both baselines in all of the experiments, demonstrating its effectiveness: carefully selecting multiple malware families is better than selecting only the single most supportive family.

Figure 2 shows the detection accuracy for plankton. The size of the training set for the four malware detection methods is the same in each experiment. The horizontal coordinate indicates the name of the malware family whose addition yields the highest accuracy at that step. We can see that the detection accuracy for plankton reaches its highest point in the middle of the SFS process.


**Table 9.** The comparison of zero-resource malware family detection. Numbers in bold show the best results.

**Figure 2.** The performance of malware detection methods using SFS to detect plankton. (**a**) csbd. (**b**) drebin. (**c**) mamadroid. (**d**) droidsieve.

**Finding 2:** Using multiple malware families without selection may harm detection performance, but carefully selecting the family combination with SFS improves it. Selecting a subset of the training data with our algorithm is better than using the entire training dataset.

#### **6. Relationship between Malware Families**

Our results show that low-resource malware detection can be supported by training with other malware families. However, the performance varies between families, so we investigate further: under what circumstances can one family support low-resource detection of another? We hypothesize that knowledge transfers between different malware families because they share similar characteristics. In this section, we study two popular characteristics: whether malware steals user data and whether malware displays advertisements to mobile users.

#### *6.1. Malware Categories*

**Steal Data.** Some malware families steal user information from the device. The report [53] points out that 60.7% of applications collected the Android ID and other unique device identifiers, 55.4% collected the list of installed applications, and 13.7% collected clipboard contents. Such information can be used for user profiling, personalized push notifications, and other business purposes, and its sensitivity is relatively high.

**Display Advertisements.** Some malware families are classified as adware: malicious applications that put unwanted ads on users' screens, especially when they access web services. Adware lures users to view ads for seemingly lucrative products and entices them to click; once the user clicks on the ad, the developer of the unwanted application generates revenue. Common examples of adware include weight-loss programs that promise quick results and on-screen warnings about fake viruses.

#### *6.2. Analysis*

Table 10 shows the special characteristics of the 16 malware families. We label malware with the characteristics of stealing user data or displaying advertisements as '◦' and malware without these characteristics as '×'. To better show the relationship between malware detection performance and malware characteristics, we further leverage t-SNE to map the detection performance in Table 8 into two-dimensional vectors. The results are shown in Figure 3 for all four malware detection models.
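As a rough illustration of this mapping step, the following scikit-learn sketch embeds per-family performance vectors into two dimensions; the `performance` matrix is a random placeholder standing in for the accuracies of Table 8.

```python
# A minimal t-SNE sketch; `performance` is a random placeholder for the
# per-family detection-performance vectors taken from Table 8.
import numpy as np
from sklearn.manifold import TSNE

families = ["ginmaster", "gingermaster", "nandrobox",
            "waps", "wapsx", "umeng"]           # subset of the 16 families
rng = np.random.default_rng(0)
performance = rng.random((len(families), 16))   # placeholder accuracy vectors

# Perplexity must stay below the number of samples; with all 16 families a
# slightly larger value could be used.
embedding = TSNE(n_components=2, perplexity=3,
                 random_state=0).fit_transform(performance)
for name, (x, y) in zip(families, embedding):
    print(f"{name}: ({x:.2f}, {y:.2f})")
```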

In Figure 3, we find that (1) most malware families with the same characteristics are close to each other for all four malware classification models. For example, the families that steal user data, such as ginmaster, gingermaster, and nandrobox, are close to each other, as are plankton and leadbolt. The families that display ads, such as waps, wapsx, and umeng, are always close to each other. This shows that most of the similarities are captured by all the malware detection models. (2) The malware relatedness differs slightly between detection algorithms. For example, for both csbd and mamadroid, mulad is far from the ads cluster, whereas in drebin and droidsieve, mulad is close to the ads cluster and admogo is the opposite. This may be because drebin and droidsieve are not good at capturing the corresponding similar characteristics between admogo and other ads malware.

**Figure 3.** The relationship between the malware detection performance and the malware characteristics. Red color represents the characteristics of displaying ads and blue color represents the characteristics of stealing user data. (**a**) csbd. (**b**) drebin. (**c**) mamadroid. (**d**) droidsieve.

The results show that most of the supportive scores match human knowledge. A higher supportive score means the characteristics of two malware families are more similar. If the target malware family uses similar technology or has similar targets to a training family, our model can leverage this knowledge and achieve good results.


**Table 10.** The characteristics of the malware families. Malware with the characteristics of stealing user data or displaying advertisements is labeled '◦'; malware without these characteristics is labeled '×'.

Although summarizing common characteristics improves the human understanding of the knowledge transfer, an empirical supportive score is the better tool if the goal is only to improve low-resource malware family performance. There are two reasons. (1) A specific malware detection model may not capture a common characteristic even when it exists; for example, csbd and mamadroid place mulad near the ads cluster while drebin and droidsieve do not. In such cases, it is better to use a different family for the drebin or droidsieve algorithm. (2) Summarizing all characteristics requires a huge effort from experts, and experts may have little interest in studying low-resource malware families because they often have less impact. By contrast, an empirical supportive score is much easier to obtain and only costs computing resources.

**Finding 3:** We validate that knowledge transfers between different malware families because of their shared characteristics. We study two characteristics, stealing user data and showing ads, and show that knowledge transfer between malware families with the same characteristics is better than between others. This study partially explains our experimental results; fully explaining all of them would require studying more characteristics with our methodology.

#### **7. Conclusions and Future Work**

Our work studies cross-family knowledge transfer for low-resource malware family detection. We quantify the knowledge transfer ability between malware families with supportive scores, and we propose the Sequential Family Selection algorithm, which uses these scores to select multiple malware families related to the target family to support its low-resource detection. The experiments show that the Sequential Family Selection algorithm improves the performance of machine-learning-based malware detection models on low-resource malware families, demonstrating that cross-family knowledge transfer can effectively improve the detection of low-resource malware. Furthermore, by analyzing the two behavioral characteristics of stealing user data and displaying advertisements, we find that the knowledge transfer between different malware families is due to their common characteristics.

In future work, we plan to assign each malware family in the training dataset a weight based on its contribution to detecting the target family, so that each target family selects the specific families that improve its detection performance. New knowledge transfer methods can also be explored to achieve better results for low-resource malware detection.

**Author Contributions:** Validation, methodology, Writing—original draft, Y.L.; Writing—review & editing, G.X. (Guoai Xu) and G.X. (Guosheng Xu); Validation, S.L. and C.D. All authors have read and agreed to the published version of the manuscript.

**Funding:** This work was supported by the National Natural Science Foundation of China under grant No. 62172006 and the National Key Research and Development Program of China under grant No. 2021YFB3101500.

**Data Availability Statement:** Not applicable.

**Conflicts of Interest:** The authors declare no conflict of interest.

#### **References**


### *Article* **Energy-Efficient Edge Caching and Task Deployment Algorithm Enabled by Deep Q-Learning for MEC**

**Li Ma, Peng Wang, Chunlai Du and Yang Li \***

School of Information Science and Technology, North China University of Technology, Beijing 100144, China **\*** Correspondence: li\_yang@ncut.edu.cn

**Abstract:** Container technology enables rapid deployment of computing services, while edge computing reduces the latency of task computing and improves performance. However, different edge servers can support only limited types, numbers and performance levels of containers, so a sensible task deployment strategy with a rapid policy response is a must. Therefore, by jointly optimizing the strategies of task deployment, offloading decisions, edge caching and resource allocation, this paper aims to minimize the overall energy consumption of a mobile edge computing (MEC) system composed of multiple mobile devices (MDs) and multiple edge servers integrated with different containers. The problem is formalized as a combinatorial optimization problem containing multiple discrete variables under constraints on container type, transmission power, latency, task offloading and deployment strategies. To solve this NP-hard problem and achieve a fast response with a sub-optimal policy, this paper proposes an energy-efficient edge caching and task deployment policy based on Deep Q-Learning (DQCD). Firstly, the exponential action space consisting of offloading decisions, task deployment and caching policies is pruned and optimized to accelerate the training of the model. Then, the training model is iteratively optimized using a deep neural network. Finally, the sub-optimal task deployment, offloading and caching policies are obtained from the trained model. Simulation results demonstrate that the proposed algorithm converges in very few iterations and greatly reduces system energy consumption and policy response delay compared to other algorithms.

**Keywords:** deep Q-learning; edge caching; task deployment; computing offload; edge computing

#### **1. Introduction**

With the advent of the 5G era, the internet and mobile devices have undergone great development, generating massive amounts of computing tasks and data. However, due to the limited computing ability of mobile devices, the delay in offloading computing tasks to remote cloud servers is too large to meet the needs of delay-sensitive tasks such as cognitive assistance, mobile gaming and virtual/augmented reality (VR/AR) [1]. Mobile edge computing pushes computing tasks and data from centralized cloud computing to the edge of the network [2], so that tasks are processed closer to where they are generated. This improves the utilization of underlying resources and the QoS of users, and can therefore solve the above problems well.

Compared with mobile devices, edge servers have larger computing ability and storage. Mobile devices can offload delay-sensitive tasks to edge servers for computing, which not only meets user needs but also reduces the energy consumption of mobile devices. In [3], an algorithm based on global search is used to find the optimal offloading strategy, with the goal of maximizing the weighted sum of task completion delay and system energy consumption. At the same time, in research on computation offloading, containers have the characteristics of light weight, fast deployment speed, small footprint, high portability and high security, and can provide quality computing services for computing-intensive tasks such as virtual/augmented reality and mobile games.

**Citation:** Ma, L.; Wang, P.; Du, C.; Li, Y. Energy-Efficient Edge Caching and Task Deployment Algorithm Enabled by Deep Q-Learning for MEC. *Electronics* **2022**, *11*, 4121. https:// doi.org/10.3390/electronics11244121

Academic Editor: Fernando De la Prieta

Received: 30 October 2022 Accepted: 9 December 2022 Published: 10 December 2022


In [4], multiple containers per application are proposed in each system, thereby improving the overall efficiency of the system.

The above works do not study the multi-user multi-edge-server environment, in particular the case where the number of tasks is so large that the edge server cannot process them in time and the latency demand is not met. In this case, user request data can be cached on the edge server, greatly reducing the delay and energy consumption of task processing. In [5], the authors study the management of an MEC cache to improve the quality of service of mobile augmented reality (MAR) services. In [6], a data-driven approach using model-free reinforcement learning is adopted to optimize edge cache allocation.

For the deployment strategy in mobile edge computing, researchers have proposed optimization methods such as dynamic programming [7] and genetic algorithms [3], with some improvements, but the execution time of these algorithms grows exponentially with the number of users, which does not meet the need for rapid deployment-strategy responses in practice. With the development of artificial intelligence, optimization methods based on deep reinforcement learning (DRL) have received more and more attention [8]. The biggest advantage of DRL is that it can obtain a sub-optimal deployment strategy while greatly shortening the algorithm execution time in the multi-MD multi-server environment. Moreover, the above studies have not conducted in-depth research on the joint optimization of edge caching, computation offloading and task deployment strategies.

In general, in the context of massive computing tasks and data, the computing power and storage capacity of edge servers are used to compensate for the limited computing power of mobile devices and to reduce the delay and energy consumption of processing tasks on them. By jointly optimizing computing power, latency, task offloading, task deployment and cache deployment in an edge computing environment, the underlying computing and communication resources are used to the fullest and system energy consumption is minimized. Using a DQN to solve this NP-hard problem with discrete variables not only yields a suboptimal solution for system energy consumption, but also keeps the processing time of the algorithm within the delay requirement. The edge computing environment has multiple edge servers and multiple MDs based on container technology. As shown in Figure 1, in the MEC system, the task of each MD can be executed locally or offloaded to a suitable edge server. At the same time, an appropriate edge server must be selected to cache each MD's data, to maximize the use of system resources and minimize total energy consumption while the delay constraint is met.

**Figure 1.** The multi-MD multi-container multi-server MEC system.

In this paper, we study the problem of container-based edge caching and task deployment in the MEC system. Meanwhile, we utilize DRL to obtain a near-optimal task deployment strategy. Then, a DQN-empowered efficient edge caching and task deployment (DQCD) algorithm is proposed to solve the problem of energy consumption minimization. The main contributions of this paper are as follows:


The remainder of this paper is organized as follows. Section 2 reviews related work. Section 3 presents the system model and problem formulation. In Section 4, we propose an energy-efficient DQN-based edge caching and task deployment (DQCD) algorithm to solve the formalized problem. Section 5 presents the simulation results and discussion. Finally, we conclude the paper in Section 6.

#### **2. Related Work**

Edge computing has been the topic of numerous studies. In [9], a novel permissionless mobile system based on QR codes and NFC technology is proposed for privacy-preserving smart contact tracing. An urban safety analysis system is proposed in [10], which can use multiple points of cross-domain urban data to infer safety indices. These studies address practical applications of edge computing but do not study delay and energy consumption in the system. In [11], a mobile-phone-assisted positioning algorithm is proposed, formulating the positioning scheme and the communication between the service areas of the remote control. In [12], a container engine was suggested that makes it easy to create, manage and delete containerized apps; running several containers incurs little overhead and a small amount of CPU power. Several studies boost the network performance of edge computing using container technology. It has been suggested to utilize containers to administer and monitor apps [13]. One of the most useful and noticeable aspects of network edge computing is the scalability of containers, and an IoT resource management framework uses a federated domain-based strategy in light of this. Resources provided by IoT-based edge devices, such as smartphones or smart automobiles, can be used more effectively in containers that are dynamically allocated at runtime.

To address the shortcomings of mobile devices in resource storage, computational performance and energy efficiency, computation offloading shifts resource-intensive computing from resource-constrained mobile devices to resource-rich adjacent infrastructure [14]. In [15], the authors consider a wearable camera uploading real-time video to a remote distribution server over a cellular network, maximizing the quality of the uploaded video while meeting latency requirements. Offloading technology minimizes transmission time while also relieving demand on the core network. MEC is able to execute brand-new, sophisticated applications on user equipment (UE), and computation offloading is a crucial edge computing technology. Numerous relevant research findings exist, mostly focusing on the two core issues of resource allocation and offloading decisions. Offloading decisions cover what to offload from mobile devices, how much to offload and how to offload computing duties; the analysis of where to place offloaded resources is known as resource allocation.

The UE in an offloading system typically consists of a decision engine, a system parser and a code parser, and the execution of its offloading decision consists of three steps; whether something can be offloaded depends on the type of application and the partition of code and data. The code parser first determines what can be offloaded. The system parser then monitors various parameters, such as bandwidth availability, the size of the data that need to be offloaded or the amount of energy used by running applications locally [16]. According to the optimization target of the offloading decision, computation offloading can be divided into three categories: lowering latency, reducing energy consumption and balancing energy consumption against delay. Provided the delay requirement of the device is met, the energy consumption of the device is usually the user's second concern. In [17], energy harvesting methods are introduced for wearable devices. In [18], the data communication power consumption of wearable devices is minimized under the premise of meeting the delay requirements.

Base station caching, mobile content distribution networks and transparent caching are examples of mobile edge caching systems. Technology for cache acceleration can boost content delivery effectiveness and enhance user experience. Users can access nearby content once it has been cached at the mobile network's edge, eliminating repetitive content transfer and relieving demand on the backhaul and core networks. Additionally, edge caching might lessen user request network latency, enhancing the user's network experience. Edge caching can also enable the mobile network resource environment to offer tenants and users enhanced services [19]. A Cognitive Agent (CA) is suggested to assist users in pre-caching and performing tasks on MEC and to coordinate communication and caching to lessen the burden of MEC [20].

According to [21], a coordinate descent (CD) method that searches along one variable dimension at a time is suggested for task deployment decisions in edge computing. In [22], the authors study an analogous heuristic search technique that iteratively adjusts binary offloading decisions for multi-server MEC networks. Another extensively used heuristic is convex relaxation: for instance, Ref. [23] shows that integer variables can be relaxed to be continuous between 0 and 1, and Ref. [24] shows that binary constraints can be approximated by quadratic constraints. On the one hand, a heuristic algorithm with a lower level of complexity cannot ensure the quality of the solution; on the other hand, search-based approaches and convex relaxation methods are not appropriate for rapidly fading channels, since both require a large number of iterations to attain adequate local optima.

Our work is motivated by the cutting-edge benefits of deep reinforcement learning for solving reinforcement learning problems with vast state and action spaces [25]. Deep reinforcement learning has a very wide application range and has been used successfully in fields such as route planning, indoor positioning and edge computing. In [26], the authors utilize a multi-objective hyperheuristic algorithm with reinforcement learning for smart city route planning. The goal of [27] was to automatically draw indoor maps through reinforcement learning on changes in AP signal strength in different rooms. In [28], the authors utilize deep learning techniques to build a fine-grained, facility-labeled interior floor plan system. In [29], the authors use the training of WiFi fingerprints to construct indoor floor plans. Deep neural networks (DNNs) in particular are used to build the best state-space-to-action-space mappings by learning from training data samples [30]. Regarding MEC network offloading based on deep reinforcement learning, Ref. [31] suggested a distributed deep-learning-based offloading (DDLO) technique for MEC networks by utilizing parallel computing. To improve the computational performance of energy-harvesting MEC networks, Ref. [32] presented an offloading approach based on Deep Q-Network (DQN). In [33], the authors examined a DQN-based online computation offloading approach under random task arrival in a comparable network environment. However, these two DQN-based works both take the computation offloading and task deployment decisions as the state input vector and consider neither the container and task types nor the edge cache.

#### **3. System Model**

In this paper, we study a multi-MD multi-server environment in a given area, as shown in Figure 1. The containers of an edge server can provide computing services to mobile devices with limited resources, and each edge server has a certain amount of storage for caching user data. *M* = {1, 2, 3, ... , *M*} represents the set of edge servers, while *N* = {1, 2, 3, ... , *N*} represents the set of MDs. The network is divided into several small areas, and $S_i \subseteq M$, ∀*i* ∈ *N*, represents the set of edge servers with which MD *i* can communicate wirelessly. Services are abstractions of applications requested by MDs; example services include video streaming, face recognition and augmented reality. Computation tasks need a corresponding container type on the BS.

Assume all MDs are randomly distributed within the coverage of the edge servers, and all edge servers are connected by wired links. To allow multiple access, all MDs follow the Orthogonal Frequency Division Multiple Access (OFDMA) protocol. Assuming that the delay constraints of all tasks are the same and equal to *T*, the tasks will be executed within this execution deadline, and each MD has one indivisible task within *T*. $R(i) = (I_i, C_i, V_i, T_i)$, ∀*i* ∈ *N*, represents a task, where $I_i$ is the size of the input data, $C_i$ is the number of CPU cycles required to compute 1 bit of the task, $V_i$ is the task type, and $T_i$ is the time constraint of the task. The goal is to perform edge offloading and task deployment under the constraints of task type, computing power, delay and task offloading so as to minimize the total energy consumption of the system.
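For concreteness, the task and server descriptions above might be represented as follows; this is an illustrative sketch, and all field names are ours, not the authors'.

```python
# Illustrative data structures for the system model; field names are ours.
from dataclasses import dataclass

@dataclass
class Task:
    I: float      # input data size I_i (bits)
    C: float      # CPU cycles required per bit, C_i
    V: int        # task (service) type V_i
    T: float      # delay constraint T_i (s); assumed equal for all tasks
    L: float      # size of the MD's cacheable data, L_i (bits)

@dataclass
class EdgeServer:
    G: set        # container types hosted, G_j
    Q_max: float  # maximum cache storage, Q_j^max (bits)
    k: float      # chip energy coefficient k_j

def can_serve(task: Task, server: EdgeServer) -> bool:
    # The container-type constraint of Section 3.3: the server must host a
    # container of the task's type.
    return task.V in server.G
```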

#### *3.1. Task Deployment Model*

The computation of an MD's task can be accomplished locally or transferred to the optimal edge server in *M*. Assume MD *i* is within the coverage of the edge servers $S_i \subseteq M$. We introduce the decision matrix *a*, where $a_{i,j}$, $i \in N$, $j \in M$, is the task placement strategy that determines whether the computing task of MD *i* is offloaded to server *j*. If $a_{i,j} = 1$, task *i* is offloaded to edge server *j*: $R(i)$ is transmitted locally to $S_i$, and from $S_i$ to the edge server *j*. This paper ignores the data transmission delay between edge servers because optical cable transmission is used. The task deployment strategy $a_{i,j}$ is denoted as follows:

$$a\_{i,j} \in \{0, 1\}, i \in N, j \in M \tag{1}$$

$$\sum\_{j=1}^{M} a\_{i,j} = 1, i \in N \tag{2}$$

#### *3.2. Edge Cache Model*

We introduce a decision matrix *b*, where $b_{i,j}$, $i \in N$, $j \in M$, is the caching strategy that determines whether the data of MD *i* are cached on server *j*. If $b_{i,j} = 1$, the data of MD *i* are cached on edge server *j*. The caching strategy is shown as follows:

$$b\_{i,j} \in \{0, 1\}, i \in N, j \in M \tag{3}$$

At the same time, the storage of each edge server is limited. Assuming that the maximum storage space allocated on edge server *j* is $Q_j^{\max}$ and the size of the MD's data is $L_i$, we have the following cache constraint:

$$0 \le \sum\_{i=1}^{N} b\_{i,j} L\_i \le Q\_j^{\text{max}} \tag{4}$$

#### *3.3. Task Placement Model*

Different MDs have different task types; at the same time, different edge servers have different container types and can process only a limited number and variety of tasks. The container type set of edge server *j* is $G_j$. The computing task *i* can be executed, or the data of MD *i* cached, only when the container types of the edge server include the task type of the MD:

$$V\_i \cap G\_j = V\_i, i \in N, j \in M \tag{5}$$

#### *3.4. Transfer Model*

The computing task $R(i)$ is offloaded to $S_i$; $p_i$ represents the transmit power of MD *i* in the current time block, $h_{i,j}$ represents the channel gain of the upload link, and $W_i$ is the channel bandwidth. The number of MDs in the offloading process will not exceed the number of sub-channels. With environmental noise power $\sigma^2$, the offload time of $R(i)$ obtained from Shannon's formula is

$$t\_i^{up} = \frac{(1 - b\_{i,j})a\_{i,j}I\_i}{W\_i \log\_2 \left(1 + \frac{p\_i h\_{i,j}}{\sigma^2}\right)}, \forall i \in N, \forall j \in M \tag{6}$$

The energy consumed to offload $R(i)$ to the edge server is

$$E\_i^{up} = p\_i t\_i^{up}, i \in N \tag{7}$$
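A small numeric sketch of Equations (6) and (7), with illustrative parameter values of our choosing:

```python
# A minimal sketch of the transfer model, Equations (6) and (7); all values
# below are illustrative only.
import math

def offload_time(a_ij, b_ij, I_i, W_i, p_i, h_ij, sigma2):
    """Upload time t_i^up; zero if the data are cached (b=1) or not offloaded."""
    rate = W_i * math.log2(1.0 + p_i * h_ij / sigma2)   # Shannon capacity
    return (1 - b_ij) * a_ij * I_i / rate

def offload_energy(p_i, t_up):
    """Transmission energy E_i^up = p_i * t_i^up, Equation (7)."""
    return p_i * t_up

# Example: a 1-Mbit task offloaded (a=1, not cached) over a 10-MHz channel.
t_up = offload_time(a_ij=1, b_ij=0, I_i=1e6, W_i=10e6,
                    p_i=0.5, h_ij=1e-6, sigma2=1e-9)
print(t_up, offload_energy(0.5, t_up))   # ~0.011 s, ~0.006 J
```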

#### *3.5. Computational Model*

The computing model consists of two parts: local computing and offloaded computing on the optimal edge server. Tasks are completed within *T*. Suppose $t_i^{loc}$ represents the time required to complete $R(i)$ locally on the MD, and $t_i^{e}$ is the completion time on the edge server; then we have the following time constraint:

$$0 < (1 - a\_{i,j})(1 - b\_{i,j})t\_i^{loc} + a\_{i,j}(1 - b\_{i,j})(t\_i^{e} + t\_i^{up}) \le T \tag{8}$$

#### 3.5.1. Local Computing

We consider that each MD has one task $R(i)$ within *T* that cannot be divided into subtasks. $F_i^{loc}$ represents the CPU frequency allocated by the MD to $R(i)$ in *T*. We can obtain $F_i^{loc}$ as follows:

$$F\_i^{loc} = \frac{(1 - b\_{i,j})a\_{i,j}I\_i C\_i}{t\_i^{loc}}, i \in N \tag{9}$$

In local computing, the lower the CPU frequency, the smaller the energy consumption. Let $k_i$ be the chip energy coefficient of the MD; the local computing energy consumption is as follows:

$$E\_i^{loc} = \frac{k\_i((1 - b\_{i,j})a\_{i,j}I\_iC\_i)^3}{T^2} \tag{10}$$

Let $F_i^{\max}$ denote the maximum computation frequency of the MD; the task is also subject to the following constraint on the local computation frequency:

$$0 \le \frac{(1 - b\_{i,j})a\_{i,j}I\_i C\_i}{T} \le F\_i^{\max}, i \in N \tag{11}$$
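The local computing model can be checked numerically; the following sketch evaluates Equations (9) and (10) for illustrative values of our choosing.

```python
# A minimal sketch of Equations (9) and (10); parameter values are
# illustrative only.
def local_frequency(b_ij, a_ij, I_i, C_i, t_loc):
    """CPU frequency F_i^loc allocated to R(i), Equation (9)."""
    return (1 - b_ij) * a_ij * I_i * C_i / t_loc

def local_energy(k_i, b_ij, a_ij, I_i, C_i, T):
    """Local computing energy E_i^loc, Equation (10)."""
    return k_i * ((1 - b_ij) * a_ij * I_i * C_i) ** 3 / T ** 2

# A 1-Mbit task at 500 cycles/bit needs 5e8 cycles: 1 GHz if finished in 0.5 s.
print(local_frequency(b_ij=0, a_ij=1, I_i=1e6, C_i=500, t_loc=0.5))  # 1e9 Hz
# With chip coefficient k_i = 1e-27 and T = 1 s, the energy is 0.125 J.
print(local_energy(k_i=1e-27, b_ij=0, a_ij=1, I_i=1e6, C_i=500, T=1.0))
```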

3.5.2. Edge Server Computing

After all tasks are deployed, the selected server *j* executes its computing tasks immediately. Computing resources of different containers on an edge server do not interfere with each other, and the containers execute in parallel; within the same container, tasks execute serially. The resource constraint on computing tasks at the edge servers is denoted as follows:

$$0 < F\_i^{e} < F\_i^{\max}, i \in N \tag{12}$$

where $F_i^{\max}$ represents the maximum computing resources that the container can allocate. Let $k_j$ be the chip energy coefficient of edge server *j*; then the CPU frequency allocated to $R(i)$ on the edge server is

$$F\_i^{e} = \frac{I\_i \sum\_{j=1}^M (1 - b\_{i,j}) a\_{i,j} C\_j}{t\_i^{e}}, i \in N, j \in M \tag{13}$$

The larger the computing delay on the edge server (that is, when $t_i^{e} + t_i^{up} = T$), the smaller the energy consumption. The computing energy consumption on the edge server side is

$$E\_i^{e} = k\_j (F\_i^{e})^3 t\_i^{e} = \frac{I\_i^3 \sum\_{j=1}^M k\_j a\_{i,j} C\_j^3}{\left(T - t\_i^{up}\right)^2}, i \in N \tag{14}$$

#### *3.6. Problem Formulation*

By jointly optimizing the edge caching and task deployment strategies to minimize the total energy consumption of the system, the problem is expressed as follows:

$$P\_1 = \min\_{a, b, t\_i^{up}, t\_i^{e}} \left\{ \sum\_{i=1}^N E\_i^{up} + E\_i^{loc} + E\_i^{e} \right\} \tag{15}$$
 
$$\text{s.t.(1),(2),(3),(4),(5),(8),(11),(12)}$$

Obviously, the problem *P*1 is a combinatorial optimization problem with multiple discrete variables, and it is an NP-hard problem.

#### **4. Problem Solution**

This paper uses deep Q-learning to solve this NP-hard problem: during deployment, the optimal or sub-optimal solution must be found among $(2M+1)^N$ kinds of edge caching and task deployment strategies (for example, with $M = 3$ servers and $N = 10$ MDs there are already $7^{10} \approx 2.8 \times 10^8$ candidates). To solve combinatorial optimization problems with multiple discrete variables, global search algorithms such as the genetic algorithm (GA) and heuristic algorithms such as dynamic programming (DP) can be used. However, as the number of MDs increases, the search space becomes too large, and such algorithms take a long time to obtain an optimal or sub-optimal solution. Therefore, these algorithms do not meet the need for rapid policy responses in real scenarios.

For this reason, this paper proposes an efficient edge caching and task deployment (DQCD) algorithm based on deep Q-learning, which shortens the execution time of the algorithm while obtaining an optimal or sub-optimal solution. The algorithm consists of two parts: generating the edge caching and task deployment strategy, and updating it. The strategy is generated by taking the feature vector $(I_i, C_i, V_i, L_i, \dots)$ of each task and the feature vector $(N_j, k_j, C_j, Q_j^{\max}, \dots)$ of each server as the feature input $\Delta_t$ of the DNN in the $t$-th time block and outputting the predicted task deployment strategy $\hat{a}_t$. The neural network computes the reward value of the predicted strategy and stores the newly obtained actions in the buffer; it then reads training samples from the buffer to train the neural network in the strategy update of the $t$-th time block, which updates the weights and biases of the trained model $\delta_t$. The newly trained model $\delta_{t+1}$ is fed new features in the next time block to generate a task deployment policy $\hat{a}_{t+1}$. Through this iterative process, the DNN training model gradually converges to an optimal or sub-optimal task deployment strategy.

When generating edge caching and task deployment strategies, it would be a huge amount of work to directly generate a policy space of size $(2M+1)^N$ and find the optimal or suboptimal solution in such a large space. This paper therefore optimizes the action space when the strategies are generated, using the container types, task types, edge cache size and other conditions. On the one hand, the computing tasks of MDs have different types, and the numbers and types of containers on the edge servers also differ, so the strategy matrix can be pruned by matching task types against container types. On the other hand, the cache size of each edge server is limited, so the policy matrix can also be pruned using the server's maximum cache capacity. Through these pruning operations, the response speed of the edge caching and task deployment strategies and the convergence speed of model training are both accelerated.
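A minimal sketch of this pruning step, reusing the illustrative `Task`/`EdgeServer` classes from the Section 3 sketch: a candidate (server, cache-flag) pair survives only if the server hosts a matching container type and, when caching, has remaining cache capacity.

```python
# Illustrative action-space pruning; Task and EdgeServer are the sketch
# classes defined earlier, not the authors' data structures.
def prune_actions(task, servers, cache_used):
    """Enumerate feasible (server index j, cache flag b) pairs for one task."""
    feasible = []
    for j, s in enumerate(servers):
        if task.V not in s.G:                    # container/task type mismatch
            continue
        feasible.append((j, 0))                  # offload without caching
        if cache_used[j] + task.L <= s.Q_max:    # cache constraint (4)
            feasible.append((j, 1))              # offload and cache the data
    return feasible
```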

The key elements of the model definition of the algorithm are as follows:

(1) *State*: The state of the system is a set of parameters describing the MEC system; the state in the $t$-th time block can be defined as

$$state\_t \triangleq \Delta\_t = \left( (I\_{i,t}, C\_{i,t}, \dots), \left(N\_{j,t}, C\_{j,t}, \dots\right) \right) \tag{16}$$

(2) *Action*: Given the input, the DNN chooses from the pruned and optimized deployment strategies above, which can be defined as

$$action\_t \triangleq \{a\_t, b\_t\} \tag{17}$$

where $a_t$ is the task deployment vector and $b_t$ is the edge cache deployment vector.

(3) *Reward*: The goal of this paper is to minimize system energy consumption, so the reward in the $t$-th time block can be represented by the objective function (15) and defined as

$$reward \triangleq \left\{ \sum\_{i=1}^{N} E\_{i,t}^{up} + E\_{i,t}^{loc} + E\_{i,t}^{e} \right\} = E\_t \tag{18}$$

In summary, the DNN continuously learns from the optimal state–action matrices selected in the current state and generates better edge caching and task deployment strategies through self-iteration. In our algorithm, the buffer area given to the DNN is limited, so the DNN learns only from the latest edge caching and task deployment strategies and then generates the latest data samples from them. This is a closed-loop self-learning mechanism under reinforcement learning, which improves its own edge caching and task deployment strategies over successive iterations. The pseudocode of DQCD is shown in Algorithm 1.

**Algorithm 1:** Energy-efficient edge caching and task deployment based on DQN (DQCD) algorithm to solve the problem of minimizing system energy consumption.

**Input:** Feature vectors of the task requests and edge server parameter arrays Δ*t*, time blocks *T*, approximate action matrices *K*.

**Output:** Edge caching and task offloading strategy matrix $\hat{a}_t$, and the time overheads $t_i^{up}$, $t_i^{loc}$, $t_i^{e}$ of data transmission and computation within *T*.
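Since the body of Algorithm 1 follows the closed-loop mechanism described above, a highly simplified PyTorch sketch of one time block is given below. The environment step `simulate_energy`, the feature dimensions, and the greedy (exploration-free) action choice are placeholders of ours; a full DQN would add exploration and a target network.

```python
# A simplified, illustrative sketch of the DQCD learning loop; simulate_energy
# and all dimensions are placeholders, not the authors' implementation.
import random
from collections import deque
import torch
import torch.nn as nn

STATE_DIM, MAX_ACTIONS = 32, 64                   # illustrative sizes
dnn = nn.Sequential(nn.Linear(STATE_DIM, 128), nn.ReLU(),
                    nn.Linear(128, MAX_ACTIONS))
optimizer = torch.optim.Adam(dnn.parameters(), lr=1e-3)
buffer = deque(maxlen=1024)                       # limited buffer, as in the text

def simulate_energy(action):
    # Placeholder for evaluating the system energy E_t of Equation (18)
    # under the chosen caching/deployment action.
    return random.random()

def step(state_t):
    """One time block: choose an action, observe energy, store and learn."""
    action = int(dnn(state_t).argmax())           # greedy pick over pruned actions
    energy = simulate_energy(action)
    buffer.append((state_t, action, energy))
    batch = random.sample(buffer, min(32, len(buffer)))
    states = torch.stack([s for s, _, _ in batch])
    acts = torch.tensor([a for _, a, _ in batch])
    targets = torch.tensor([-e for _, _, e in batch])  # lower energy = better
    preds = dnn(states).gather(1, acts.unsqueeze(1)).squeeze(1)
    loss = nn.functional.mse_loss(preds, targets)
    optimizer.zero_grad(); loss.backward(); optimizer.step()

for _ in range(100):                              # a few illustrative time blocks
    step(torch.randn(STATE_DIM))
```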


#### **5. Simulation Results**

In the simulation, unless otherwise specified, we study an MEC system composed of three MEC servers and multiple MDs. The main parameters follow [1]; the remaining parameters are as follows. When an MD transmits to an edge server, the sub-channel bandwidth follows a uniform distribution $W_i \sim U(10, 15)$ Mbps, and the task transmission power follows a uniform distribution $P_i \sim U(0.1, 1)$ W. For the local computing model, the maximum CPU frequency of each MD follows a uniform distribution $F_i^{\max} \sim U(1.5, 3)$ GHz, and the CPU cycles required to compute 1 bit follow a uniform distribution $C_i \sim U(100, 1500)$ cycles. For the edge computing model, the maximum CPU frequency of each server follows a uniform distribution $F_j^{\max} \sim U(32, 128)$ GHz, and the CPU cycles required to compute 1 bit follow a uniform distribution $C_j \sim U(400, 600)$ cycles. We set the maximum number of iterations to *maxNum* = 30,000, the training interval to *Tim* = 10, and the time block to *T* = 1.0 s. The comparison algorithms are as follows: the GACD algorithm uses the genetic algorithm for a global search in the edge caching and task deployment strategy matrices; the DPCD algorithm uses dynamic programming for a heuristic search in the edge caching and task deployment strategy matrices; and the RandCD algorithm uses random edge caching and task deployment strategies.

#### *5.1. Algorithm Convergence Proof*

During the training of the DQCD algorithm, the deep neural network model is updated iteratively, and the test dataset is fed into the current model to obtain the best deployment strategy found so far. As shown in Figure 2, the DQCD algorithm is verified to converge; moreover, it reaches a relatively stable suboptimal solution around the 50th group.

**Figure 2.** DQCD training convergence process.

#### *5.2. Algorithm Feasibility Verification*

Figure 3 compares the energy consumption of the algorithms with T = 1.0 s in an environment of 3 edge servers and 10 MDs. We compare the energy consumption of DQCD with the GACD, DPCD and RandCD algorithms. In addition, we verify the necessity of jointly considering computation offloading and edge caching: DQCD is also compared against variants that only consider the cache (CacheCD), only consider offloading to edge computing (EdgeCD), only consider local computing (LocalCD), or do not consider the cache (NocacheCD). The figure shows that DQCD reaches a suboptimal energy consumption and is clearly better than CacheCD, EdgeCD, LocalCD and NocacheCD, which proves the effectiveness of the proposed joint edge caching and task deployment algorithm in reducing energy consumption.

**Figure 3.** Energy consumption comparison.

Figure 4 compares the energy consumption of the algorithms as T increases from 1.0 to 2.0 s in the environment of 3 edge servers and 10 MDs. Compared with the GACD algorithm, DQCD achieves suboptimal or even optimal task deployment strategies. Compared with DPCD, DQCD reduces energy consumption; compared with RandCD, DQCD achieves a large overall reduction in energy consumption.

**Figure 4.** Comparison of energy consumption.

Figure 5 compares the energy consumption and algorithm execution time in an environment with 3 edge servers and T = 1.0 s as the number of MDs increases from 4 to 12. Compared with the GACD algorithm, DQCD realizes a suboptimal task deployment strategy in terms of energy consumption. Compared with the DPCD algorithm, DQCD achieves a slight reduction in energy consumption; compared with RandCD, it achieves a large overall reduction. In terms of algorithm execution time, DQCD is at a slight disadvantage compared with RandCD, but compared with GACD and DPCD its execution time is greatly reduced, and this advantage becomes more and more obvious as the number of users increases.

The above simulation results show that the proposed DQCD algorithm obtains performance close to the optimal solution for the task deployment problem in a multi-MD multi-server MEC system while realizing a rapid policy response.

**Figure 5.** Comparison of execution time.

#### **6. Conclusions**

This paper studies an MEC system composed of multiple MDs and multiple servers, under constraints on container type, computing ability, delay, task offloading and deployment, with the aim of minimizing system energy consumption. Task offloading, task deployment and edge caching strategies are all discrete variables, so the energy minimization problem is a combinatorial optimization problem containing multiple discrete variables, and it is NP-hard. In order to minimize energy consumption and achieve a fast policy response, this paper proposes a joint optimization algorithm for edge caching and task deployment based on deep Q-learning (DQCD), which optimizes the policy space to accelerate model convergence. At the same time, a deep neural network model close to the optimal solution is obtained through iterative training. The simulation results show that, compared with the existing baseline algorithms, the DQCD algorithm not only achieves near-optimal performance but also effectively reduces the execution delay.

At present, the scale of this research is small because our resources are limited. A large-scale scenario with more edge servers and users would require more powerful computing resources for training the algorithm model; we would then consider acquiring higher-performance computing resources so that larger-scale scenarios could be evaluated. This is our next step.

**Author Contributions:** Conceptualization, Y.L.; Methodology, Y.L.; Software, P.W.; Resources, C.D.; Data curation, P.W.; Writing—original draft, P.W.; Writing—review & editing, C.D. and Y.L.; Supervision, L.M.; Project administration, L.M. All authors have read and agreed to the published version of the manuscript.

**Funding:** This research received no external funding.

**Data Availability Statement:** The raw data can be provided on request.

**Acknowledgments:** We would like to thank all participants.

**Conflicts of Interest:** The authors declare no conflict of interest.

#### **References**


### *Article* **Deep Learning Techniques for Pattern Recognition in EEG Audio Signal-Processing-Based Eye-Closed and Eye-Open Cases**

**Firas Husham Almukhtar 1, Asmaa Abbas Ajwad 2, Amna Shibib Kamil 3, Refed Adnan Jaleel 4,\*, Raya Adil Kamil <sup>3</sup> and Sarah Jalal Mosa <sup>5</sup>**


**Abstract:** Recently, pattern recognition in audio signal processing using electroencephalography (EEG) has attracted significant attention. Changes in eye cases (open or closed) are reflected in distinct patterns in EEG data, gathered across a range of cases and actions. Therefore, the accuracy of extracting other information from these signals depends significantly on the prediction of the eye case during the acquisition of EEG signals. In this paper, we use deep learning vector quantization (DLVQ) and feedforward artificial neural network (F-FANN) techniques to recognize the eye case. The DLVQ is superior to traditional VQ in classification issues due to its ability to learn a code-constrained codebook. When initialized by the k-means VQ approach, the DLVQ shows very promising performance on an EEG-audio information retrieval task, while the F-FANN classifies EEG-audio signals of the eye state as open or closed. The DLVQ model achieves higher classification accuracy, higher F score, precision, and recall, as well as superior classification ability compared with the F-FANN.

**Keywords:** signal processing; information retrieval; deep learning vector quantization; feedforward artificial neural network; electroencephalography; classification

#### **1. Introduction**

The human brain is made up of an enormous number of nerve cells that interact with one another in a sophisticated network using electrical signals. These signals are transmitted to the cell body after passing through electrochemical solutions in order to modify their impact on that cell's output signal. The electrical signal's current, which travels through these cells, modifies the electrical polarity at the connections between the cell body and dendrites, in which the electrochemical solution can be found [1]. The output of the cell is controlled by the electrochemical junctions' conductivity, and it updates its state in response to inputs, which are determined by both conscious decision making and environmental feedback [2]. As a result, a lot of focus has been placed on examining the electrical impulses that are sent and received by various brain regions during various actions and states. Although various techniques are available for tracking brain activity in response to specific states, one of the most common practices is the use of EEG, which monitors the brain's electrical signals [3]. This method is gaining popularity because it is inexpensive and requires only compact equipment to gather the signals, in comparison to other methods such as functional magnetic resonance imaging (fMRI), which depends on monitoring variations in cerebral blood flow during certain states or the performance of various tasks [4]. As a result, in recent years, a greater emphasis has been placed on the analysis of EEG data in order to identify these activities and states, plus the detection of anomalies in these signals for use in medicine [5].

**Citation:** Husham Almukhtar, F.; Abbas Ajwad, A.; Kamil, A.S.; Jaleel, R.A.; Adil Kamil, R.; Jalal Mosa, S. Deep Learning Techniques for Pattern Recognition in EEG Audio Signal-Processing-Based Eye-Closed and Eye-Open Cases. *Electronics* **2022**, *11*, 4029. https://doi.org/10.3390/ electronics11234029

Academic Editors: Yanhui Guo, Deepika Koundal and Rashid Amin

Received: 22 October 2022 Accepted: 2 December 2022 Published: 5 December 2022


A variety of applications and services have used the analysis of EEG signals, such as medical diagnoses and the brain–computer interface (BCI). BCI enables users to operate computers without any direct physical contact, where EEG data are examined in order to determine the computer task that is needed [6]. EEG signals are also frequently used with anomaly detection to forecast various diseases. However, beyond the advantage of understanding the subject's condition at a given moment, EEG data gathered during the same states and actions have shown different EEG patterns in different states of the eye [7]. Eyes exist in one of two states: closed, known as eye closed (EC), or open, known as eye open (EO). This behavior demonstrates the significance of determining the eye's state before performing any additional EEG data analysis [8]. Recent years have seen remarkable progress in the pattern-detecting power of DL methods, which have been applied to an ever-increasing amount of input data and analyze the relations between these inputs in order to extract the required knowledge from the data [9].

Processing can be greatly simplified by DL via autonomous end-to-end learning of feature extraction, classification, and preprocessing, achieving above-average results on the intended task. In fact, DL architectures have had great success over the past few years in processing difficult data such as audio signals, images, and text, producing top-tier results on a wide range of public benchmarks, such as large-scale visual recognition challenges, and expanding into industrial applications [10].

DL is a branch of machine learning that studies how computational techniques can learn hierarchical representations of incoming data through successive nonlinear transformations. Research into the perceptron and other earlier models led to deep neural networks (DNNs), in which (1) a network of artificial "neurons" applies a linear transformation to the information it receives at each successive layer, and (2) the output of each layer's linear transformation is passed through a nonlinear activation function. Significantly, the parameters of these transformations are learned directly by systematically reducing a cost function. Despite the widespread use of the term "deep" for neural networks, there is no agreement on what exactly constitutes a "deep" network, and consequently on what qualifies as a deep network and what does not [11].
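As a concrete (illustrative) rendering of this layered structure, the short PyTorch sketch below stacks linear transformations and nonlinear activations and fits their parameters by systematically reducing a cost function; the dimensions and data are placeholders of ours, not the networks used later in this paper.

```python
# An illustrative two-layer feedforward network: linear maps + nonlinear
# activation, trained by minimizing a cost function. Sizes and data are
# placeholders (e.g., 14 EEG channels in, 2 eye states out).
import torch
import torch.nn as nn

model = nn.Sequential(
    nn.Linear(14, 32), nn.ReLU(),      # layer: linear transform + nonlinearity
    nn.Linear(32, 2),                  # output layer: scores for EC vs. EO
)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.CrossEntropyLoss()        # the cost function being reduced

x = torch.randn(8, 14)                 # a batch of 8 synthetic EEG samples
y = torch.randint(0, 2, (8,))          # labels: 0 = eye closed, 1 = eye open
for _ in range(100):                   # systematically reduce the cost
    optimizer.zero_grad()
    loss = loss_fn(model(x), y)
    loss.backward()
    optimizer.step()
```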

Acquiring and labeling EEG signals is necessary because DL requires a labeled dataset to draw relevant conclusions and build a model that can predict labels for new inputs. The quality of the predictions made by each classifier must also be evaluated, because classifiers differ in their ability to extract accurate knowledge. While the predictions are ultimately intended for unlabeled data, labeled data are still necessary for measuring accuracy: for the instances in the evaluation dataset, defined as the testing data, the predicted labels are compared with the actual labels. The closer the predicted labels are to the actual labels of the test instances, the better the knowledge extraction, and hence the better the performance of the classifier [12].

#### **2. Literature Survey**

The procedure for processing and analyzing EEG data typically consists of two steps: feature extraction and pattern recognition [13,14]. Before the popularity of deep learning, the most common way to extract features was to use signal analysis to pull out time–frequency features, such as spectral density, power density [15], band power [16], separate components [17], and differential entropy [18]. Widely researched pattern recognition and machine learning techniques include ANN [19,20], naive Bayes [21], support vector machines [22,23], etc. With the promotion and widespread use of DL, an ever-increasing number of neuroscience and brain study teams are discovering its strength in building techniques that use EEGs to intelligently comprehend and analyze brain activity, offering an end-to-end approach that unifies feature extraction, classification, and clustering. To categorize mental stresses, the authors of [24] developed a multichannel DNN. In [25], an LSTM network was used to categorize forms of motor imagery, and the network's useful characteristics were extracted using a one-dimensional aggregation approximation technique. In [26], the authors employed a CNN-based predictive modeling strategy for estimating brain age; their research showed that this estimation is extremely accurate. With their suggested spatiotemporal deep convolution model, the authors of [27] significantly enhanced the accuracy of driver fatigue detection by highlighting the significance of the spatial information and timing of EEGs. Further, a full-stack multiview DL architecture was suggested in [28] to automatically detect epileptic convulsions in EEG recordings.

In [29], the authors built a CNN with transfer learning in mind and successfully used the model to diagnose mild depression in patients. Furthermore, a hybrid LSTM neural network operating on time–frequency data was trained with the rectified linear unit (ReLU) activation function to classify sleep stages [30]. Later, a compact model based on a fully convolutional network (EEGNet) was presented for various EEG-based BCI classification tasks [31]. The authors in [32] suggested a parallel and cascaded convolutional recurrent neural network technique that efficiently learns spatiotemporal representations to distinguish human motion commands from actual EEG signals. Moreover, in [33], EEG data are transformed into EEG-dependent optical flow and video information, which is classified by an RNN and CNN, for the development of a successful BCI-based rehabilitation support system.

Large amounts of content information and varied visual qualities are packed into the multimedia data commonly employed in the collection and analysis of EEG data [34–36]. Through the study of EEG signals, researchers have attempted to determine and categorize the content information of the multimedia material users watched [37–39].

The authors of [35] created a mapping link between natural image attributes and EEG representation using LSTM network learning to develop a model for EEG responses to visual cues. The improved EEG signal representation was then used for the classification of natural images. These DL-based strategies produced remarkable classification results, especially when compared to conventional approaches.

Moreover, twenty-eight university students' EEG data were gathered while they were resting with their eyes closed and with their eyes open; Fourier transform was then applied to estimate total power in the theta, delta, beta, and alpha bands across nine regions of the scalp [40]. Arousal level was also determined by measuring skin conductance. Topographic effects of the condition are shown in Figure 1.

Resting EEG data from 70 subjects (46 adults, 29 male) were analyzed by the authors of [41] across testing sessions spaced by 12 ± 1.1 years. EEG alpha activity was separated, quantified, and identified by applying reference-free techniques that combine principal component analysis (PCA) with current source density (CSD). Measures of overall (EC-plus-EO) and net (EC-minus-EO) posterior alpha amplitude and asymmetry were compared across sessions. Figures 2 and 3 show waves 4 and 6 of the resting EEG's CSD-fPCA structure at 13 electrode sites common to both waves and the topographies of the mean alpha factor scores, respectively.

Recent research has demonstrated that multimedia content information may be reconstructed by mining EEG data. In [34], the authors developed a strategy for deducing the information of visual inputs from electrical brain activity. By applying generative adversarial networks (GANs) and a variational autoencoder (VAE), they discovered that patterns relating to visual content can be observed in EEG data, and images that are semantically compatible with the incoming visual stimuli can be produced from this content. Although these techniques have proven that a DL framework may be used for EEG-based image classification, the input is frequently the raw EEG measurements or time–frequency properties retrieved using signal analysis methods, certain aspects of human brains, such as hemispheric lateralization, have not been given much thought, and the classification accuracy produced is 82.9% [35].

**Figure 1.** Environmental influences on topography: (**a**) delta, (**b**) theta, (**c**) alpha, (**d**) beta. Each row corresponds to one EEG band. The left column shows the typical power output across all conditions. The middle columns display the power distributions for the EO and EC conditions; the power decrease in each band when the eyes open can be easily observed. The rightmost column displays standard scores representing the difference (EO-EC) at each electrode site; note the focal changes apparent in the theta, beta, and delta bands.

**Figure 2.** Waves 4 and 6 of the resting EEG's CSD-fPCA structure with alpha prefiltering at 13 electrode locations shared by both waves. (**A**): Factor loadings show the separation of low-frequency alpha (9.37 Hz peak; blue) and high-frequency alpha (10.15 Hz peak; red) from a broad delta factor (1.56 Hz peak; green). (**B**): Topographies of mean factor scores for the EO and EC conditions (averaged across waves 4 and 6).

**Figure 3.** Mean alpha factor score topographies (pooled across low- and high-frequency alpha) for EO (top row), EC (middle row), and overall alpha (EO + EC; bottom row), shown separately for each wave.

#### **3. Materials and Methods**

Many applications and analyses that use EEG signals require knowledge of the subject's eye state, whether open or closed. DLVQ and F-FANN techniques can be used to predict a class, or label, for each data instance, depending on that instance's attribute values, which makes them applicable for predicting the state of the subject's eyes based on the values collected using EEG. However, in order to use a classifier in a certain application, it is important to train that classifier using data collected from the same environment, where each instance is labeled according to its actual state. Thus, to use DLVQ and F-FANN techniques in predicting the eye state based on EEG signals, data must be collected from different subjects in a controlled environment, i.e., the actual state of their eyes is logged alongside the EEG signals. However, DLVQ and F-FANN may perform differently depending on the input data, which imposes the need to evaluate the performance of the classifiers in order to select the most appropriate one. The rationale behind comparing DLVQ and F-FANN in EEG signal processing is to demonstrate that the activations and parameters of a neural network can be quantized using product quantization with shared subdictionaries without materially affecting the network's accuracy.

#### *3.1. Data Collection and Classification*

In order to train classifiers and evaluate their performance, a labeled dataset is required, so that classifiers extract the relations between the attribute values of each instance and the label given to it, while evaluation is conducted by comparing the predictions provided by the classifiers to the actual labels of the instances used in the evaluation. For this purpose, the EEG data included 150 recordings from 27 participants, taking around 24 s in each condition. EEG signals were collected from a 16-channel V-AMP amplifier at a sampling rate of 1024 samples/second, while the impedance of the channels was maintained below 5 kΩ. The electrodes that collect the EEG signals are positioned at Fp2, Fp1, Fz, F3, FCz, F4, T3, T4, CPz, Cz, Pz, P7, C3, P8, C4, and Oz. Undesired frequencies are filtered out using three types of filters: an 80 Hz online low-pass filter, a 0.1 Hz high-pass filter, and a 50 Hz notch filter. Moreover, two reference nodes are connected to the subject's ears, one to each ear. After preprocessing the signals using the Brain Vision Analyzer, the collected data are referenced using the average EEG of all the collected channels, per instance, and the sampling rate is reduced to 256 Hz. Artifact segments and epochs with amplitudes greater than 150 μV are marked for removal from further analysis. Finally, the band power values of the alpha, beta, delta, and theta bands are computed from the collected EEG signals, which results in 64 attributes that describe each data instance in the collected dataset. An epoch refers to one complete pass over all of the training data; training is measured in terms of the number of such passes.
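The following is a minimal sketch of such a band-power feature extraction, assuming a (channels × samples) NumPy array of preprocessed EEG resampled to 256 Hz; function and variable names are illustrative, not taken from the study.

```python
# Hypothetical band-power extraction sketch; 16 channels x 4 bands = 64 attributes.
import numpy as np
from scipy.signal import welch

BANDS = {"delta": (0.5, 4), "theta": (4, 8), "alpha": (8, 13), "beta": (13, 30)}
FS = 256  # sampling rate after downsampling (Hz)

def band_powers(eeg):
    """Return one band-power value per (channel, band) pair."""
    feats = []
    for channel in eeg:
        # Welch periodogram with 1 s windows
        freqs, psd = welch(channel, fs=FS, nperseg=FS)
        for lo, hi in BANDS.values():
            mask = (freqs >= lo) & (freqs < hi)
            feats.append(np.trapz(psd[mask], freqs[mask]))  # integrate the PSD
    return np.asarray(feats)

example = band_powers(np.random.randn(16, FS * 24))  # 24 s, 16 channels
assert example.shape == (64,)
```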

The performance of DLVQ and F-FANN is evaluated in this study in order to select the classifier with the best performance for EEG signal classification, i.e., forecasting the state of the subject's eye from the data extracted from the EEG signals. The data collected from EEG signals are used to train the DL techniques and evaluate their performance, which reflects the quality of the knowledge extracted from the training data.

#### *3.2. Deep Learning Vector Quantizer*

A frame-level codeword sequence and an initial codebook can be obtained using the level-structured k-means VQ approach described in [42], and these can then be used for DLVQ. DLVQ is based on the LVQ principle, which has been found helpful in several disciplines, including ASR and text classification, while simultaneously using the power of deep learning. The difference from the method given in [42] is that in this study it is applied to brain signals rather than heart signals, and the comparison is carried out between DLVQ and F-FANN on EEG signals.

#### 3.2.1. DLVQ System Structure

This study follows the same basic outline as the DLVQ approach used in [42], which employs a DNN as a codebook learner together with VQ. As with DNN-based ASR, a DNN can be trained using the frame-level label information provided by the initial quantizer. Figure 4 depicts the overarching framework of the training procedure. K-means is first used to learn an initial codebook from the training frames (no context frames are employed). Then, using standard VQ, each frame's codeword is acquired. Finally, a DNN is trained with a cross-entropy optimization objective, with the codeword serving as the class target for each individual frame.
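A rough sketch of this initialization step is given below, assuming frame-level feature vectors in a NumPy array; the recursive splitting scheme is our reading of the level-structured design, not code from [42], and the cluster counts are placeholders.

```python
# Sketch of a level-structured k-means codebook learner and standard VQ labeling.
import numpy as np
from sklearn.cluster import KMeans

def level_structured_codebook(frames, clusters_per_level):
    """Recursively split the data, e.g. clusters_per_level = (8, 4)."""
    groups = [frames]
    for k in clusters_per_level:
        next_groups = []
        for g in groups:
            labels = KMeans(n_clusters=k, n_init=10).fit_predict(g)
            next_groups += [g[labels == c] for c in range(k)]
        groups = next_groups
    # One codeword (centroid) per leaf cluster.
    return np.stack([g.mean(axis=0) for g in groups])

def quantize(frames, codebook):
    """Standard VQ: assign each frame the index of its nearest codeword."""
    dists = ((frames[:, None, :] - codebook[None, :, :]) ** 2).sum(-1)
    return dists.argmin(axis=1)  # pseudo labels used as DNN class targets
```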

#### 3.2.2. DNN Training

The DNN's input is a splice of a center frame (whose label is the splice's label) with *n* context frames on each side, e.g., *n* = 6 or *n* = 8. The hidden layers were built from sigmoid units, and a softmax layer was used for the output; it has exactly as many nodes as the VQ initializer has codewords. The DNN's fundamental structure is depicted in Figure 4. In particular, the node values are expressed as in Equations (1) and (2),

$$\mathbf{vx}^{i}=\begin{cases}\mathbf{W}_{1}\,\mathbf{o}^{t}+\mathbf{bv}_{1}, & i=1\\ \mathbf{W}_{i}\,\mathbf{vy}^{i-1}+\mathbf{bv}_{i}, & i>1\end{cases}\tag{1}$$

$$\mathbf{vy}^{i}=\begin{cases}\operatorname{sigmoid}\left(\mathbf{vx}^{i}\right), & i<tn\\ \operatorname{softmax}\left(\mathbf{vx}^{i}\right), & i=tn\end{cases}\tag{2}$$

where **bv**<sub>1</sub>, ... , **bv**<sub>*i*</sub> are the bias vectors; *tn* is the total number of layers; and the sigmoid and softmax functions are applied element-wise. The vector **vx**<sup>*i*</sup> holds the prenonlinearity activations and **vy**<sup>*i*</sup> is the vector of neuron outputs at the *i*th layer. Codeword posterior estimates were derived from the softmax outputs as in Equation (3):

$$P\left(CW_{j} \mid \mathbf{o}_{t}\right) = \mathbf{E}_{t}^{tn}(j) = \frac{\exp\left(\mathbf{vx}_{t}^{tn}(j)\right)}{\sum_{i} \exp\left(\mathbf{vx}_{t}^{tn}(i)\right)}\tag{3}$$

where *CW<sub>j</sub>* represents the *j*th codeword and **E**<sub>*t*</sub><sup>*tn*</sup>(*j*) is the *j*th element of **E**<sub>*t*</sub><sup>*tn*</sup>.

**Figure 4.** Structure of DLVQ system.

The DNN was trained by increasing the log posterior probability across the training frames, which is equivalent to minimizing the negative cross-entropy loss function. Let X represent the entire training set with N frames, i.e., **o**<sub>1:*N*</sub> ∈ X; then the loss with respect to X is given by Equation (4):

$$\mathcal{L}_{1:N} = -\sum_{t=1}^{N} \sum_{j=1}^{J} \mathbf{l}_t(j) \log P\left(CW_j \mid \mathbf{o}_t\right) \tag{4}$$

where *P*(*CW<sub>j</sub>* | **o**<sub>*t*</sub>) is given in Equation (3) and **l**<sub>*t*</sub> is the label vector at frame *t*, i.e., the pseudo label obtained from the k-means VQ initializer. Utilizing error backpropagation, a gradient-descent-based optimization technique developed for neural networks, we are able to reduce the loss objective function. Calculating the partial derivatives of the loss objective function with respect to the output layer's prenonlinearity activations **vx**<sup>*tn*</sup> produces the error vector to be backpropagated to the previous hidden layers. The backpropagated error vectors at the output and hidden layers are described in Equations (5) and (6):

$$\epsilon_t^{tn} = \frac{\partial \mathcal{L}_{1:N}}{\partial \mathbf{vx}_t^{tn}} = \mathbf{E}_t^{tn} - \mathbf{l}_t \tag{5}$$

$$\epsilon_t^{i} = \left(\mathbf{W}_{i+1}^{T}\,\epsilon_t^{i+1}\right) \ast \mathbf{vy}^{i} \ast \left(1 - \mathbf{vy}^{i}\right), \quad i < tn \tag{6}$$

where ∗ refers to element-wise multiplication. Combining the error vectors from the hidden layers across the training frames, the overall gradient with respect to the weight matrix **W**<sub>*i*</sub> is computed by Equation (7):

$$\frac{\partial \mathcal{L}_{1:N}}{\partial \mathbf{W}_{i}} = \mathbf{vy}_{1:N}^{i-1} \left(\epsilon_{1:N}^{i}\right)^{T} \tag{7}$$

In Equation (7), both **vy**<sup>*i*−1</sup><sub>1:*N*</sub> and *ε*<sup>*i*</sup><sub>1:*N*</sub> are matrices constructed by stringing together the corresponding vectors of each training frame, from frame 1 to frame *N*, i.e., *ε*<sup>*i*</sup><sub>1:*N*</sub> = [*ε*<sup>*i*</sup><sub>1</sub>, ... , *ε*<sup>*i*</sup><sub>*t*</sub>, ... , *ε*<sup>*i*</sup><sub>*N*</sub>]. With a batch-based gradient-descent update, the parameters are recalculated using the gradient in Equation (7) only once per sweep; parallelization can thus be readily carried out to hasten learning after each sweep across the entire training set. Stochastic gradient descent (SGD), on the other hand, typically functions more effectively in practice. SGD assumes that the true gradient may be approximated by the gradient at a single frame *t*, i.e., **vy**<sup>*i*−1</sup><sub>*t*</sub>(*ε*<sup>*i*</sup><sub>*t*</sub>)<sup>*T*</sup>, and the parameters are updated immediately after viewing each frame. Minibatch SGD is more popular because, with an appropriate minibatch size, all of the matrices fit into GPU memory, resulting in a more computationally efficient learning procedure. In this work, the parameters are updated using minibatch SGD.
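A minimal Keras sketch of such a codebook-learner DNN, following Equations (1)–(4), is shown below: sigmoid hidden layers, a softmax output with one node per codeword, and minibatch SGD on the cross-entropy loss. The layer width, minibatch size, and learning rate follow the values stated later in the text; the frame feature dimension is an assumption.

```python
# Sketch only: a DNN trained on pseudo codeword labels with minibatch SGD.
import tensorflow as tf

FRAME_DIM = 64            # assumed feature dimension of a single frame
N_CONTEXT = 8             # context frames on each side of the center frame
N_CODEWORDS = 128         # size of the VQ initializer's codebook

layers = [tf.keras.layers.Dense(2048, activation="sigmoid",
                                input_shape=(FRAME_DIM * (2 * N_CONTEXT + 1),))]
layers += [tf.keras.layers.Dense(2048, activation="sigmoid") for _ in range(6)]
layers += [tf.keras.layers.Dense(N_CODEWORDS, activation="softmax")]
model = tf.keras.Sequential(layers)

model.compile(
    optimizer=tf.keras.optimizers.SGD(learning_rate=0.09),
    loss="sparse_categorical_crossentropy",  # Equation (4) with index targets
    metrics=["accuracy"],
)
# Pseudo codeword labels come from the k-means VQ initializer:
# model.fit(spliced_frames, pseudo_labels, batch_size=256, epochs=...)
```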

Training the DNN with the cross-entropy loss function encourages it to reproduce the labels given by its VQ initializer; that is, a "perfect" training run would enable the DNN to achieve the same VQ outcomes as its initializer. In contrast, low frame accuracy was observed throughout the realistic training procedure: less than half on the testing and training data. This shows that the DNN is capturing new information in the input rather than learning exactly what its initializer does.

#### *3.3. Feedforward Artificial Neural Network*

F-FANN is implemented to predict the eye state from the input values collected from the EEG signals. As each instance consists of a one-dimensional vector of 64 values, the implemented neural network uses only fully connected layers. According to the number of attributes in the data, the input layer is set to 64 neurons: 1 neuron per input value. This input layer is linked to the first hidden layer, which consists of 256 neurons. In addition to this hidden layer, 3 more hidden layers are used before the output layer, with 256, 128, and 64 neurons, sequentially, producing a total of 4 hidden layers. The output layer consists of a single neuron, as a single output is required from the neural network to describe the probability of the input being collected from a subject in the EC state. A summary of the implemented feedforward artificial neural network is shown in Figure 5.

**Figure 5.** F-FANN for EEG classification.

The use of such a topology allows the extraction of complex features from the input without dramatically increasing the computational complexity of the neural network, which would require more computing resources or execution time. Moreover, owing to the benefits of the ReLU activation function, including faster learning and elimination of the vanishing gradient problem, all hidden layers use this activation function. As the output required from the neural network is limited to the range from zero to one, the output neuron uses the sigmoid activation function. In artificial neural networks, overfitting occurs when the neural network overemphasizes a certain path among neurons to reach the required solution. As training continues in an overfitted neural network, more emphasis is added to that path by amplifying the weights among its neurons during backpropagation. Thus, to avoid such behavior, a predefined percentage of the neurons in every layer is dropped during training, which is defined as the dropout rate. These neurons are selected randomly per training epoch, so that the neural network is forced to find multiple paths that arrive at the same prediction. Relying on the different features enforced by dropout, the output of the neural network considers all these features, which eliminates the errors that may occur from strict reliance on a single feature. Figure 5 illustrates an example of dropout during training.
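A sketch of this topology in Keras follows: 64 inputs, four fully connected hidden layers (256, 256, 128, 64) with ReLU and dropout, and a single sigmoid output. The dropout rate, optimizer, and loss are placeholders, since the text states only that a predefined percentage of neurons is dropped.

```python
# Sketch of the F-FANN architecture described above.
import tensorflow as tf

DROPOUT_RATE = 0.2  # assumed value; the text gives only a "predefined percentage"

def build_ffann():
    model = tf.keras.Sequential()
    model.add(tf.keras.layers.Dense(256, activation="relu", input_shape=(64,)))
    model.add(tf.keras.layers.Dropout(DROPOUT_RATE))
    for units in (256, 128, 64):
        model.add(tf.keras.layers.Dense(units, activation="relu"))
        model.add(tf.keras.layers.Dropout(DROPOUT_RATE))
    # Single sigmoid neuron: probability that the input comes from the EC state.
    model.add(tf.keras.layers.Dense(1, activation="sigmoid"))
    # Optimizer and loss are assumptions; the paper does not state them.
    model.compile(optimizer="adam", loss="binary_crossentropy",
                  metrics=["accuracy"])
    return model
```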

#### **4. Performance Evaluation and Results**

In order to choose the classifier with the best assessment in EEG signal classification, i.e., predicting the eye state of the subject, the predictions provided by each DL classifier are compared to the original eye states in the dataset by distributing these predictions and the actual states in a confusion matrix. True EO represents the number of instances collected from subjects with their eyes open and forecasted by the DL classifier as EO. False EO is the number of EC instances that the classifier predicts as EO. True EC represents the number of EC instances correctly predicted by the classifier as EC, while false EC is the number of EO instances predicted as EC by the classifier. Using the values in the confusion matrix created from the classification results of a certain classifier, performance measures are applied to describe the performance of that classifier. Thus, the accuracy of the predictions is calculated using Equation (8). Moreover, the precisions of the predictions provided for the EC and EO classes are shown in Equations (9) and (11), respectively, while the recalls of these classes are calculated using Equations (10) and (12). These values are then used to calculate the F scores for the EO and EC classes, according to Equations (13) and (14). Moreover, as some applications that rely on EEG classification to estimate the subject's eye state require faster decisions, the average time required by each classifier to produce a prediction for a single instance is also measured. Based on these measures, the DL technique with the best performance can be selected for eye state prediction based on EEG signals.

$$\text{ACC} = \frac{\text{True EC} + \text{True EO}}{\text{True EC} + \text{False EC} + \text{True EO} + \text{False EO}} \tag{8}$$

$$\text{Precision}_{\text{EC}} = \frac{\text{True EC}}{\text{True EC} + \text{False EC}} \tag{9}$$

$$\text{Recall}_{\text{EC}} = \frac{\text{True EC}}{\text{True EC} + \text{False EO}} \tag{10}$$

$$\text{Precision}_{\text{EO}} = \frac{\text{True EO}}{\text{True EO} + \text{False EO}} \tag{11}$$

$$\text{Recall}_{\text{EO}} = \frac{\text{True EO}}{\text{True EO} + \text{False EC}} \tag{12}$$

$$\text{FScore}_{\text{EO}} = 2 \times \frac{\text{Precision}_{\text{EO}} \times \text{Recall}_{\text{EO}}}{\text{Precision}_{\text{EO}} + \text{Recall}_{\text{EO}}} \tag{13}$$

$$\text{FScore}_{\text{EC}} = 2 \times \frac{\text{Precision}_{\text{EC}} \times \text{Recall}_{\text{EC}}}{\text{Precision}_{\text{EC}} + \text{Recall}_{\text{EC}}} \tag{14}$$
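For concreteness, the following small sketch computes Equations (8)–(14) directly from the confusion-matrix counts; the variable names mirror the text (EO = eyes open, EC = eyes closed), and the function name is illustrative.

```python
# Eye-state metrics from the four confusion-matrix cells, per Equations (8)-(14).
def eye_state_metrics(true_ec, false_ec, true_eo, false_eo):
    acc = (true_ec + true_eo) / (true_ec + false_ec + true_eo + false_eo)  # Eq. (8)
    prec_ec = true_ec / (true_ec + false_ec)            # Eq. (9)
    rec_ec = true_ec / (true_ec + false_eo)             # Eq. (10)
    prec_eo = true_eo / (true_eo + false_eo)            # Eq. (11)
    rec_eo = true_eo / (true_eo + false_ec)             # Eq. (12)
    f_eo = 2 * prec_eo * rec_eo / (prec_eo + rec_eo)    # Eq. (13)
    f_ec = 2 * prec_ec * rec_ec / (prec_ec + rec_ec)    # Eq. (14)
    return {"accuracy": acc,
            "precision_EC": prec_ec, "recall_EC": rec_ec, "f_score_EC": f_ec,
            "precision_EO": prec_eo, "recall_EO": rec_eo, "f_score_EO": f_eo}
```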

Each channel's EEG signal is first normalized to zero mean and unit variance. Additionally, it is segmented into adjacent frames (each lasting one second, which is equivalent to one hundred and fifty samples). The result is 10 frames per channel. Figure 6 shows an illustration of this procedure.
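A short sketch of this per-channel standardization and framing step is given below, assuming a (channels × samples) array; the frame length follows the text (one second, here taken as 150 samples).

```python
# Standardize each channel, then cut it into adjacent one-second frames.
import numpy as np

FRAME_LEN = 150  # samples per one-second frame, as stated above

def to_frames(eeg):
    # zero mean, unit variance per channel
    eeg = (eeg - eeg.mean(axis=1, keepdims=True)) / eeg.std(axis=1, keepdims=True)
    n_frames = eeg.shape[1] // FRAME_LEN  # 10 frames per channel in the text
    return eeg[:, : n_frames * FRAME_LEN].reshape(eeg.shape[0], n_frames, FRAME_LEN)
```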

Three level-structured k-means VQ systems were developed for use as benchmarks, with 128 codewords (8 clusters on the first level and 4 on the second level), 256 codewords (three levels with 16, 4, and 2 clusters per level), and 512 codewords (four levels with 32, 8, 16, and 4 clusters per level), denoted 128 k-means, 256 k-means, and 512 k-means, respectively.

The number of times each codeword appears in an audio EEG clip was used to create the BoW vector representation of that clip. The pseudo codeword labels produced by the baseline k-means systems were used to build the DLVQ systems. The input of all DNNs is a splice of the center frame and its 8 context frames, and each of the 7 hidden layers has 2048 nodes, as in [42]. Each system's output softmax layer shares the same dimensionality as the codebook vocabulary of its respective VQ initializer system, that is, 128, 256, and 512, respectively. We developed the DNN systems based on the Kaldi voice recognition tools. The DNN is trained using the following method: layer-by-layer generative pretraining is used for parameter initialization. As a next step, we use backpropagation and the cross-entropy objective function to train the network discriminatively. The initial learning rate is set at 0.09, and the minibatch size is set to 256. Frame accuracy is then checked on the development set after every training iteration, and the learning rate is halved if the improvement is less than 0.5%. Once the frame accuracy improvement drops below 0.1%, the training procedure is terminated. In practice, the 128-codeword k-means-based DNN obtained 34% frame accuracy on the training and test datasets; the 256-codeword k-means model obtained 23% and 29%, respectively; and the 512-codeword k-means model obtained 24% and 27%. Figures 7 and 8 show these findings, which indicate that the frame-accuracy trends on the training and development sets are similar and primarily growing. This demonstrates that the DNN can successfully imitate its VQ initializer through cross-entropy training (the "labels" produced by k-means VQ being retained); however, because the final frame accuracy was below 50%, we may infer that the DNN is not merely figuring out the specifics of its initializer's operation, but rather is actively gathering fresh information. The BoW representation of an audio EEG clip was then made by running the clip's frames through the trained DNN and summing the resulting posterior vectors. For both the baseline and proposed frameworks, the BoW vector representation of each clip was normalized as a histogram so that its entries summed to 1. SVMs with the histogram intersection kernel (HIK) were applied as the classifiers. It is evident that, in MAP, DLVQ achieves a 4.5% relative increase over the k-means baseline. An approximate 10.5% relative gain was obtained when fusing the findings of the baseline and proposed systems; DLVQ picks up some supplementary information that k-means misses. In the basic late fusion approach, the two systems' classifier scores are simply weighted together according to their AP scores on the development set. These encouraging results demonstrate that DLVQ does help improve the representative power of VQ-based BoW vectors. Figure 7 shows the accuracy of DLVQ with different codebooks on the training sets, and Figure 8 shows the accuracy of DLVQ with different codebooks on the development sets.
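The normalized BoW construction described above can be sketched as follows; the function name is illustrative, and for DLVQ the per-frame softmax posterior vectors would be summed over the clip and normalized the same way.

```python
# Build a normalized bag-of-words (BoW) histogram from frame-level codeword labels.
import numpy as np

def bow_vector(codeword_labels, n_codewords):
    counts = np.bincount(codeword_labels, minlength=n_codewords).astype(float)
    return counts / counts.sum()  # histogram normalized to sum to 1
```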

**Figure 6.** Amplitude versus time for EEG signal.

The F-FANN is implemented using the Keras library and evaluated using the 5-fold cross-validation method. In each iteration of the cross-validation, the model is trained for 1000 epochs using the training bins and evaluated using the testing bin. These results are used to calculate the performance evaluation measures. The average time consumed by the feedforward artificial neural network to come up with a prediction per data instance is 0.009 ms.
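A sketch of this 5-fold evaluation loop is shown below, assuming a feature matrix X (instances × 64) and binary labels y as NumPy arrays; `build_ffann()` is the sketch given earlier.

```python
# 5-fold cross-validation sketch for the F-FANN.
from sklearn.model_selection import KFold

def cross_validate(X, y, epochs=1000):
    scores = []
    for train_idx, test_idx in KFold(n_splits=5, shuffle=True).split(X):
        model = build_ffann()  # fresh model per fold
        model.fit(X[train_idx], y[train_idx], epochs=epochs, verbose=0)
        scores.append(model.evaluate(X[test_idx], y[test_idx], verbose=0))
    return scores  # one (loss, accuracy) pair per fold
```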

The results show that DLVQ scores better overall accuracy than the feedforward neural network. The feedforward artificial neural network scores higher precision in predicting the EO state and higher recall in the EC predictions. However, the F score for both states is equal, at 91%; thus, the overall F score is 91% as well. Figure 9 illustrates the accuracy of F-FANN on the training sets, while Figure 10 shows the accuracy of F-FANN on the development sets. Tables 1 and 2, as well as Figures 11 and 12, present the precision, recall, and F score for DLVQ and F-FANN, respectively.

**Figure 7.** Accuracy of DLVQ with different codebooks in training sets.

**Figure 8.** Accuracy of DLVQ with different codebooks in development sets.

**Figure 9.** Accuracy of F-FANN with different codebooks in training sets.

**Figure 10.** Accuracy of F-FANN in development sets.

**Table 1.** Precision, recall, and F score for DLVQ.


**Table 2.** Precision, recall, and F score for F-FANN.

**Figure 11.** Precision, recall, and F score for the DLVQ technique.

**Figure 12.** Precision, recall, and F score for F-FANN technique.

F-FANN shows the fastest prediction, with an average prediction time of 0.009 ms per data instance, while DLVQ averages a prediction time of 0.074 ms.

#### **5. Limitations and Optimal Points**

Large volumes of data are a major roadblock for the proposed work. Training on huge and complicated data models can be costly, and substantial hardware is required to perform the complicated mathematical computations. There is no standard or single way to choose DL tools. It is not always possible to obtain answers using DL algorithms when dealing with interdisciplinary issues, and a perfect solution might not always be possible with DL. Poor-quality, incomplete, or incorrect data might result in inaccurate or incorrect output. In fact, DL may not be able to answer problems that are not posed in a classification format, as its methods are optimized for such situations.

The optimal points are briefly as follows: the best features of the proposed system are that DLVQ and F-FANN perform well with unstructured or unlabeled data, and that different DL algorithms, libraries, and open-source frameworks are available. Their practical uses are extensive, scalable, and efficient.

DLVQ and F-FANN make it easier to automatically recognize features without first extracting those characteristics. One neural network-based technique may be modified and applied to a variety of data kinds and applications, since it is a resilient system.

#### **6. Conclusions and Recommendation**

In this paper, EEG signals were collected from 27 participants in order to evaluate the performance of two DL techniques, namely DLVQ and F-FANN, in predicting the state of the subject's eye based on the collected EEG signals. The collected data were split into five bins, where each bin was used once for evaluation while the remaining bins were used for training, following a 5-fold cross-validation evaluation approach. This approach avoids the bias that arises when data instances in a single randomly selected testing set happen to be more suitable for one classifier than another. DLVQ showed the highest overall performance measure. Additionally, we provide a discriminative approach to LVQ in this research, employing a DL framework to extract a superior VQ representation from the baseline VQ initializer systems. When combined with its k-means VQ initializer, the DLVQ system is able to capture novel information and achieve a highly encouraging relative performance improvement.

There are still many areas where DL training techniques may be enhanced. We would also like to examine DLVQ's and F-FANN's performance in other fields, such as computer vision, and investigate the theoretical relationship between DLVQ and its initializers. Furthermore, the integration of DLVQ and F-FANN with preexisting technology has become more feasible, including the brain–computer interface, big data, and the Internet of Things (IoT).

**Author Contributions:** Methodology, F.H.A.; conceptualization, A.A.A.; writing—original draft, A.S.K.; review and editing, R.A.K.; software and supervision, R.A.J.; Methodology, S.J.M. All authors have read and agreed to the published version of the manuscript.

**Funding:** This paper received no external funding.

**Data Availability Statement:** The data shall be made available on request.

**Conflicts of Interest:** The authors declare no conflict of interest.

#### **References**


### *Article* **Anomaly Detection in Fog Computing Architectures Using Custom Tab Transformer for Internet of Things**

**Abdullah I. A. Alzahrani 1, Amal Al-Rasheed 2, Amel Ksibi 2,\*, Manel Ayadi 2, Mashael M. Asiri <sup>3</sup> and Mohammed Zakariah <sup>4</sup>**


**Abstract:** Devices that are part of the Internet of Things (IoT) have strong connections; they generate and consume data, which necessitates data transfer among various devices. Smart gadgets collect sensitive information, perform critical tasks, make decisions based on indicator information, and connect and interact with one another quickly. Securing this sensitive data is one of the most vital challenges. A Network Intrusion Detection System (IDS) is often used to identify and eliminate malicious packets before they can enter a network. This operation must be done at the fog node because Internet of Things devices are naturally low-power and do not possess significant computational resources. In this context, we offer a novel intrusion detection model capable of deployment at the fog nodes to detect undesired traffic towards the IoT devices by leveraging features from the UNSW-NB15 dataset. Before training the models, correlation-based feature extraction is performed to weed out redundant information contained within the data. This helps in the development of a model that has a low overall computational load. The proposed Tab transformer model performs well on the existing dataset and outperforms both the traditional Machine Learning (ML) models developed and previous efforts on the same dataset. The Tab transformer model was designed to handle only continuous data. The proposed model obtained an accuracy of 98.35% when classifying normal traffic data versus abnormal traffic data, while its accuracy for predicting attacks involving multiple classes was 97.22%. The imbalanced data appear to cause issues with the performance on the underrepresented classes. However, the evaluation results indicate that the proposed model opens new avenues of research on detecting anomalies at fog nodes.

**Keywords:** network security; deep learning; feature selection; intrusion detection


#### **1. Introduction**

IoT devices are currently widely employed in intelligent applications, including smart cities, healthcare [1], and transportation. All of these IoT-enabled applications share two similar functions: "monitoring" (regularly checking the sensors' state) and "actuating" (acting on the data gathered during monitoring). Additionally, IoT is a networked system built on recognized standards that exchange knowledge. Further, many communication standards, tools, and protocols have been developed due to the many application domains. As a result, the Internet of Things (IoT) is frequently referred to as the Internet of People (IoP) because practically everyone uses it regularly, from individuals to institutions. Moreover, it enables measurement collection from small, affordable, intelligent end nodes dispersed over a vast physical region with less expensive implementation and operation [2]. However, these advantages come at a cost in terms of finite resources, particularly the end nodes' battery life.

Some of the data collected by IoT devices are unexpected. Such surprising data may come from environmental changes, deliberate action, faulty operation, or coincidence, and are described as anomalies [3]. Anomalies are expected to occur even though the sensors employed at the edge are limited, and the end nodes' battery lives may suffer due to the nodes' rapid processing. Due to these restrictions, the network may be more susceptible to errors and malicious attacks [4].

IoT devices are vulnerable to assault since they are connected to the Internet and lack proper security measures. An attacker can swiftly hack IoT devices by taking over smart gadgets that can be used maliciously to exploit other IoT-connected devices [5,6]. Therefore, it is crucial to recognize improper actions to ensure the network operates reliably and securely. Additionally, IoT networks can prevent the broadcasting of useless or inaccurate measurements by spotting intriguing or uncommon events. As a result, the network's dependability can increase while energy consumption is decreased [7].

Anomaly detection entails the identification of noteworthy or unexpected occurrences in the network [8]. Finding a model for the vast majority of normal data is essential to identify anomalies in a dataset. The anomalies can then be identified as those data vectors that considerably depart from the normal model. Finding abnormalities in the network [9] while minimizing overhead and obtaining high detection accuracy is a major challenge.

In the Internet of Things, there are two categories of anomaly detection mechanisms: statistical and machine learning [10]. Statistical methods use only regular IoT traffic to create trained models [11], whereas machine learning techniques use both legitimate and malicious communications to train their models. Based on the learning process, these methods are divided into supervised, unsupervised, and semi-supervised categories [12]. During the supervised learning process, the traffic features are mapped to a traffic class, such as normal or attack; only labeled datasets are used in this learning procedure. By finding interesting structures in the data, the unsupervised learning process learns the traffic features without being aware of the traffic class. In semi-supervised learning, unsupervised learning groups comparable data, whereas labeled data are used to categorize the unlabeled data.

The current anomaly detection methods depend primarily on a centralized cloud [13], which cannot address IoT requirements such as resource allocation and scalability. With IoT, operations are carried out across many devices, and large amounts of data are generated exponentially [14]. Since it enables users to access Internet-based services, the cloud is essential to the Internet of Things (IoT). However, because of its centralized architecture, it cannot manage IoT devices even while it performs expensive calculations. The great distance between an IoT device and the centralized anomaly detection system also results in a high detection time. Since the centralized cloud environment cannot accommodate the service requirements of IoT, anomaly detection in IoT differs from currently used methodologies [15]. A newer distributed intelligence technique called fog computing is used to reduce this gap. The fog exchanges information by processing data near the data sources, i.e., IoT devices. Security measures can be put into place at the fog layer, as depicted in Figure 1, where fog nodes perform dispersed processing [16]. To implement distributed security mechanisms, expensive computations and storage may be offloaded from IoT devices [17].

**Figure 1.** Implementation of the anomaly detection system in IoT systems.

In this study, we introduce a framework model and a hybrid algorithm for effective ML algorithm selection, in order to discover, from many ML algorithms, workable methods for detecting anomalies and intrusions in IoT network traffic in a fog environment.

Significant contributions of the current work include:


The rest of the paper is organized as follows. Section 2 discusses the literature review of past work. The proposed strategy, including the dataset, model construction, and performance evaluation, is discussed in Section 3. In Section 4, we use experiments to evaluate the proposed methodologies quantitatively and discuss the methods and results, and we conclude the paper in Section 5.

#### **2. Literature Review**

This section presents relevant research and comprehensive background information on machine learning (ML) selection for detecting anomalies and intrusions in IoT networks [22] traffic.

Anomaly detection in IoT data using deep learning was proposed by [15] and shown to be more effective than a conventional IDS for identifying coordinated IoT fog assaults. The NSL-KDD intrusion dataset was used. Compared to the standard model's binary classification recall of 97.50%, deep learning achieved a score of 99.27%. In addition, machine learning earned an average recall of 93.66% in multi-classification, whereas deep learning scored an average recall of 96.5%.

The authors of [23] suggested cognitive fog computing for IDS in an IoT network. The suggested methodology could detect malicious behavior in nearby fog nodes as opposed to employing a centralized cloud-based infrastructure. The cloud stores a list of all fog nodes for future research. The proposed model is assessed using the NSL-KDD dataset, and detection is accomplished using the online sequential extreme learning machine (OSELM) method. Their model has a 0.37% FAR and a 97.36% accuracy rate.

The authors of [24] suggested an adaptive IDS for IoT that can recognize DoS threats. In this work, a fresh dataset was gathered using Wireshark over the course of four consecutive days on an IoT testbed. Their suggested model outperforms the Naive Bayes classifier.

The authors of [25] proposed an IDS based on neural networks and locust swarm optimization. For this experiment, which makes use of the NSL-KDD and UNSW-NB15 datasets, the accuracy and FAR are 94.04% and 2.21%, respectively.

Li et al. suggested a K-means clustering technique combined with a PCA fog computing design for anomaly detection. An ELM-based Semi-supervised Fuzzy C-Means (ESFCM) technique was put forward by [12]. The NSL-KDD dataset was utilized. The suggested system outperformed the centralized attack detection framework in terms of performance, reporting an accuracy rate of 86.53% and a decreased detection time of 11 milliseconds.

To put in place an adaptive Intrusion Detection System (IDS) that can recognize when a fog node has been hacked and then take appropriate action to ensure communication availability [26], the authors in [18] developed an Anomaly Behavior Analysis Methodology based on Artificial Neural Networks [27] and an ensemble approach [21]. The training dataset was produced using an IoT testbed. The accuracy rate of the approach was 97.51%.

Similarly, the authors in [22] suggested a variational long short-term memory (VLSTM) learning model based on reconstructed feature representation for intelligent anomaly identification. Experiments using the publicly available UNSW-NB15 IBD dataset demonstrate that the proposed VLSTM model can successfully address the imbalance and high-dimensionality issues, and that it can also significantly improve accuracy and decrease false rates in anomaly detection.

By dividing the Intrusion Detection System functions across the fog nodes and the cloud, the authors of [19] achieved low resource overheads. As a result, an accuracy of up to 98.8% was achieved. In addition, compared to installing a neural network on the fog node, a 10% decrease in the energy usage of the fog node is observed.

This work is novel since it develops intrusion detection for IoT traffic using SDN and deep learning. SDN enables intelligent network management by separating the control and data planes. In the current IDS, deep learning-based classifiers outperform traditional classifiers in terms of results. Any infiltration in networking systems, in particular IoT networks, is detected by the suggested model [28]. Current existing work related to Anomaly detection is listed in Table 1.

Encryption is necessary to safeguard delicate data transmitted over the Internet and other networks and to stop such errors. To strengthen the safety of delicate data or information, the author of that paper created an improved variety of the Caesar cipher and developed a technique in which flexible arithmetic is used to transform plaintext into ciphertext. The author also created a decryption method that is entirely unrelated to the encryption by incorporating divisibility tests and arithmetic modulo.


**Table 1.** Existing work on Anomaly detection.

The conventional approach to situational awareness prediction in network security is comparatively simple. For perception and prediction, only one algorithm is typically utilized, and its prediction accuracy is constrained. This study optimizes a radial basis function (RBF) neural network using the simulated annealing (SA) algorithm and the hybrid hierarchy genetic algorithm (HHGA). Hence, it constructs an RBF neural network prediction model based on the HHGA optimization and performs relevant experiments to investigate the application impact of intelligent learning algorithms. The results show that the projected scenario value of the enhanced RBF neural network is relatively close to the actual situation value in 15 instances. The neural network has a significant predictive influence and can assist with network security maintenance [29].

To the best of our knowledge, no study demonstrates which ML algorithm is efficient for identifying dangerous IoT traffic, despite numerous research proposals on various identification models for accurately detecting malicious IoT traffic. Most academics conduct experiments to evaluate an ML algorithm's performance and, based on the results, choose the most efficient method. However, it is crucial to research and find the most efficient machine learning method for identifying anomalies and intrusions in IoT network traffic by reviewing frequently cited and primarily studied literature.

#### **3. Materials and Methods**

#### *3.1. Dataset*

The current analysis employed the UNSW-NB15 dataset as a benchmark [30–32]. Previous datasets, including NSL-KDD [33], KDD98, KDDCUP 99 [34], CIDDS-001, DARPA, and ADFA, were already accessible for Network Intrusion Detection System (NIDS) research [35]. These datasets, most of which date back more than 20 years, have several limitations, making them unreliable and out-of-date. Such datasets are no longer thought to provide a complete or accurate representation of contemporary attack environments, and algorithms trained on them will not exhibit realistic output performance. These databases distort regular traffic and exclude modern attack types, making it simple for stealthy/spy attacks to pass as normal activity.

The following dataset-specific issues also exist: an uneven number of records from various types of traffic, an excessive number of attacks, incomplete training sets that do not accurately reflect all attacks found in the testing set, a dearth of validation work, data generation techniques, and low data rates, etc. [36,37].

The Australian Center for Cyber Security (ACCS) produced a more recent dataset in collaboration with several specialists worldwide to solve the problems presented by earlier datasets in the field. It has been publicly available for current NIDS research since 2015. As indicated in Table 2, the dataset has 45 total network attributes, including flow- and network-based properties; flow, basic, content, time, and additionally generated features are the feature categories. Approximately 2.5 million CSV-formatted records in total, including 175,341 training records and 82,331 testing records, constitute the entire dataset. The training and testing datasets are devoid of duplicate data to guarantee the dependability of the NIDS evaluation. Two distinct traffic labels are initially applied to the dataset (attack and normal). The attack category is further classified into nine class types according to the attack type, as shown in Table 2.

**Table 2.** Type of attacks present in the UNSW-NB15 dataset and their description.


#### *3.2. Data Preprocessing*

In Machine Learning, more data results in more accurate models. However, data from the real world is inconsistent, noisy, incomplete, and consists of missing values as it is compiled utilizing data mining and storage. Therefore, it is crucial to pre-process raw data into the processed form. The data preparation enhances data quality so that valuable insights can be extracted. This will be beneficial for model development and training. The approaches used to pre-process the UNSW-NB15 dataset are described below.

#### *3.3. Data Cleaning*

We listed the count of the missing values in the dataset corresponding to each feature. The feature "service" had 94,168 missing values in the train set and 47,153 in the test set. After removing the records with missing features, the count of the records corresponding to each class in the total dataset was reduced. Figure 2 shows the modified distribution of categories in the total dataset.

#### *3.4. Data Transformation*

The characteristics "proto", "service", "state", and "attack cat" contained categorical information that could not be directly fed into the ML models. We utilized one-hot encoding to encode the categorical values into binary format, except for "attack cat", which was the target multiclass attack label that the model had to predict. The original columns of the three one-hot encoded characteristics were eliminated, bringing the total number of columns to 61.

The range of the numerical characteristics in the dataset is varied. Therefore, it was essential to normalize the values. Except for the "id" and "label" columns, the numerical feature columns have been normalized using the "MinMaxScaler."

For binary categorization of the characteristics into "normal" and "abnormal", the "labels" column was encoded using LabelEncoder() as "0" for the normal class and "1" for the abnormal class. Again, the binary dataset contains 61 columns.

For multiclass classification, the "attack cat" attribute's nine categories were label encoded as 0 ('Analysis'), 1 ('Backdoor'), 2 ('DoS'), 3 ('Exploits'), 4 ('Fuzzers'), 5 ('Generic'), 6 ('Normal'), 7 ('Reconnaissance'), and 8 ('Worms'). Consequently, the total number of attributes in the multiclass classification dataset increased to 69.
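A condensed sketch of these preprocessing steps is shown below, assuming the UNSW-NB15 CSV loaded into a pandas DataFrame; the file name and the exact missing-value markers are assumptions, while the column names follow the dataset.

```python
# Preprocessing sketch: cleaning, one-hot encoding, scaling, label encoding.
import pandas as pd
from sklearn.preprocessing import LabelEncoder, MinMaxScaler

df = pd.read_csv("UNSW_NB15_training-set.csv")       # hypothetical file name
df = df.dropna(subset=["service"])                   # drop records missing "service"

# One-hot encode the categorical columns (the original columns are removed).
df = pd.get_dummies(df, columns=["proto", "service", "state"])

# Min-max scale numerical features except the id and label columns.
num_cols = df.select_dtypes("number").columns.difference(["id", "label"])
df[num_cols] = MinMaxScaler().fit_transform(df[num_cols])

# The binary target "label" is already 0/1; label encode the multiclass target.
df["attack_cat"] = LabelEncoder().fit_transform(df["attack_cat"])
```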


**Figure 2.** The distribution of labels: (**a**) binary (train); (**b**) multiclass (train).

#### *3.5. Feature Selection*

Feature selection is essential for the efficient training of machine-learning models [38], because it keeps the features that contribute the most to accomplishing the task and eliminates unneeded or redundant ones [39]; otherwise, the model can learn from noise and pick up insignificant patterns. Consequently, feature selection enhances processing and prediction reliability [38]. In this paper, correlation-based feature selection is used.

#### *3.6. Model Development*

#### 3.6.1. Tab Transformer

The widely used Transformer design by the authors in [40] served as an inspiration for the TabTransformer architecture developed by the authors in [41]. A column embedding layer, a stack of N Transformer layers, and a multilayer perceptron are the components of the suggested design [42]. As described by [43], each Transformer layer comprises a multi-head self-attention layer followed by a position-wise feed-forward layer. In the present study, we have utilized a variation of the modified tab transformer model proposed by the authors in [44]. The proposed model is illustrated in Figure 3. The revised version utilizes only the Tab-transformer's capability to handle continuous input features: it removes the categorical features and the subsequent normalization and concatenation layers related to those features. In other words, it only uses the Tab-transformer to handle the continuous features in the input.
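The following is a schematic Keras analogue of such a continuous-only variant, not the paper's pytorch-widedeep implementation: each continuous feature is projected to a token embedding, passed through a stack of Transformer encoder blocks, and classified by an MLP head. All dimensions are placeholders.

```python
# Sketch of a Tab-transformer-style model restricted to continuous features.
import tensorflow as tf

N_FEATURES, EMBED_DIM, N_BLOCKS, N_HEADS = 15, 32, 3, 4  # assumed sizes

inputs = tf.keras.Input(shape=(N_FEATURES,))
# Treat each scalar feature as a token and embed it.
x = tf.keras.layers.Reshape((N_FEATURES, 1))(inputs)
x = tf.keras.layers.Dense(EMBED_DIM)(x)
for _ in range(N_BLOCKS):
    # Multi-head self-attention followed by a position-wise feed-forward layer,
    # each with a residual connection and layer normalization.
    att = tf.keras.layers.MultiHeadAttention(N_HEADS, EMBED_DIM // N_HEADS)(x, x)
    x = tf.keras.layers.LayerNormalization()(x + att)
    ff = tf.keras.layers.Dense(EMBED_DIM, activation="gelu")(x)
    x = tf.keras.layers.LayerNormalization()(x + ff)
x = tf.keras.layers.Flatten()(x)
x = tf.keras.layers.Dense(64, activation="relu")(x)       # MLP head
outputs = tf.keras.layers.Dense(1, activation="sigmoid")(x)  # binary variant
model = tf.keras.Model(inputs, outputs)
```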

The detailed methodology for detecting anomalies in the fog node is depicted in Figure 4.

#### 3.6.2. Model Training Pipeline

Following the data cleaning process, there were 141,321 data samples. In total, 80% of those samples were designated for training, while the remaining 20% were used for testing. The sklearn and Keras libraries were utilized during the development of the machine learning models, and pytorch-widedeep was used for the implementation of the Tab transformer. A total of ten epochs were used to train the tab transformer model; on an NVidia T4 GPU with 40 GB of RAM, each epoch took 15 s to complete.

#### 3.6.3. Performance Evaluation

Accuracy: The ratio of the number of correct predictions to the total number of predictions represents how often the classifier makes accurate predictions, as shown in Equation (1).

Recall: The fraction of true positives successfully identified, as shown in Equation (2).

Precision: The proportion of anticipated positives that are positive, as shown in Equation (3).

F1 score: The harmonic mean of recall and precision, as shown in Equation (4).

$$\text{Accuracy} = (\text{TP} + \text{TN}) / (\text{TP} + \text{TN} + \text{FP} + \text{FN}) \tag{1}$$

$$\text{Recall} = \text{TP} / (\text{TP} + \text{FN}) \tag{2}$$

$$\text{Precision} = \text{TP} / (\text{TP} + \text{FP}) \tag{3}$$

$$\text{F1 score} = (2 \times \text{Precision} \times \text{Recall}) / (\text{Precision} + \text{Recall}) \tag{4}$$

**Figure 3.** Proposed tab transformer architecture.

**Figure 4.** Proposed framework for detecting anomalies from IoT networks.

#### **4. Experimental Environment and Discussion**

#### *4.1. Feature Selection*

The correlation matrix of features in the binary dataset, excluding the ID feature, is displayed in Figure 5. This graph illustrates the correlational relationships between dataset features, through which the importance of the features can be understood. By establishing a threshold for the correlation with the label feature, we extracted only the most essential features from the binary dataset of 61 features.
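A sketch of this correlation-based selection is given below, using the threshold of 0.3 reported in the Figure 5 caption; `df` is the preprocessed DataFrame from the earlier sketch.

```python
# Keep only features whose absolute correlation with the label exceeds 0.3.
corr_with_label = df.corr(numeric_only=True)["label"].drop("label").abs()
selected = corr_with_label[corr_with_label > 0.3].index.tolist()
X = df[selected]  # per Figure 5, 15 of the 61 columns survive for the binary dataset
```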

#### *4.2. Multi Class Dataset*

The correlation matrix of features in the multiclass dataset is shown in Figure 6. From the correlation matrix, with a threshold of 0.3 set on the correlation with the label class, 14 features were shortlisted from the original 69-feature multiclass dataset.

#### *4.3. Model Development*

This section presents the results of the numerous machine-learning algorithms used for the chosen features. The results of the many classic machine learning models applied to the binary classification (normal/abnormal problem) are provided in Table 3.

**Table 3.** Performance of the binary classification (normal/abnormal) of the data by various ML models.


The 1D CNN performs significantly better than the other classical ML models when classifying the records into normal or abnormal categories. However, the performance of the other models was not significantly lower; there is only a marginal discernible change in performance. The results of the various machine learning models' attempts to categorize the data into multiple attack categories are presented in Table 4. Here also, the highest performance was achieved for the 1D CNN.

**Figure 5.** Correlation matrix of the binary class dataset. By setting a threshold value of >0.3, 15 columns of features from the total of 61 columns are extracted.

**Figure 6.** Correlation matrix of the multiclass dataset.


**Table 4.** Performance of the multiclass classification of the data by various ML models.

#### 4.3.1. Performance of the Customized Tab Transformer

A performance evaluation of the tab transformer was carried out after it had been trained for 10 epochs. The accuracy and loss plots for the training and validation datasets showed that the model performed adequately on both during training, indicating that the model was fit for use. The loss plots displayed a steady decline, while the accuracies on the training and test sets showed persistent improvement. The plots in Figures 7 and 8 revealed that the model did not demonstrate any clear evidence of overfitting. This was likely because the feature selection handled the risk of learning unnecessary information, which could otherwise cause the model to learn noise and overfit.

**Figure 7.** Loss plot for binary classification (Normal vs. Abnormal).

**Figure 8.** Loss plot for multiclass classification (9 classes).

The performance of the suggested tab transformer for the binary and multi-class problems is provided in Table 5. According to the findings, the newly generated model performs better than the typical machine learning models already produced, as shown in Tables 3 and 4. In addition, the findings suggest that although the Tab transformer model was trained for only a limited number of epochs, it was still able to demonstrate satisfactory performance.


**Table 5.** Performance evaluation of the tab transformer.

In addition, an essential piece of information was gleaned from the findings. In the instance of binary classification, every single evaluation metric showed satisfactory performance. Even the uneven data distribution between the normal and abnormal classes does not significantly harm the model's performance in classifying the two groups (Figure 9a), as Table 6 demonstrates.

| Model | Accuracy | Precision | Recall | F1 score |
|---|---|---|---|---|
| Tab transformer (multiclass) | 97.22% | 57.37% | 55.02% | 56.04% |

**Table 6.** Class wise performance evaluation of the tab transformer for binary classification.


However, for multiclass classification, accuracy was the only metric that showed a positive result; the remaining metrics hovered around the 50% mark. This may be because of the data imbalance present in the initial dataset. Table 7 also reveals that certain classes, such as "Analysis", "Backdoor", and "Worms", which have a smaller number of instances in the training dataset, exhibit almost null values in terms of precision, recall, and accuracy, as they were confused with the over-represented classes (Figure 9b).

**Table 7.** Class wise performance evaluation of the tab transformer for multiclass classification.


#### 4.3.2. Discussion

Safety concerns are elevated because systems based on the Internet of Things are advancing at a rapid rate. Any unintended action taken by an Internet of Things device could result in significant damage; hence, these devices need to be carefully monitored [45]. However, on the cloud side, it is incredibly challenging to solve the problem, since a large volume of data arrives at that end. Hence, a more effective strategy is the identification of unexpected patterns of data, also known as anomalies, at the fog node. As a result, security procedures can be implemented at the fog layer, where fog nodes carry out distributed processing, and the burden of performing costly computations and storing data on IoT devices can be offloaded. This enables the deployment of distributed security mechanisms. The primary objective of this design should be to reduce the number of false positive alerts while improving detection accuracy.

**Figure 9.** Confusion metrics of the (**a**) Binary and (**b**) Multiclass anomaly detection.

Even though several methods already exist for anomaly detection in IoT devices, such as firewalls and rule-based systems [44], these methods appear insufficient for detecting unknown attacks. This is likely because they cannot keep up with the most recent and sophisticated attacks, whose patterns are mostly unknown to the rule book. In addition, it is incredibly challenging to write rules for each of these attacks.

Therefore, the most effective way forward is ML-based methods, as they enable systems to learn and improve through the use of historical data. ML-based computer programs do not require explicit engineering (programming); they are capable of acquiring knowledge on their own [46]. In the vast majority of prior research on identifying abnormalities at the fog node, conventional ML techniques were utilized and yielded the best results. The authors of [47] employed conventional ML techniques to identify assaults at the edge nodes with the highest accuracy. As demonstrated in Tables 3 and 4, we obtained the best performance for classical ML models in our study. In both studies, feature selection was performed before the classification task. In contrast to [47], which utilized Chi-Square, a filter-based feature selection strategy, we employed correlation-based methods. The feature selection technique undoubtedly contributed to the high accuracy, as it eliminates noise from the data. However, the most significant contribution of the fog node-based anomaly detection model is that it is lightweight: it handles a smaller number of features, allowing it to be readily integrated into fog nodes, which are battery-constrained and computationally limited.

The model created in this paper is a Tab transformer that outperforms existing ML approaches and represents a novel way of studying anomaly detection on the UNSW-NB15 dataset, as shown in Table 5. We use this technique because it has been demonstrated to be superior to contemporary deep neural networks for tabular data [44] and because we aim to manage data imbalance more intelligently. Given that our study comprises both binary and multiclass classification and that the dataset at hand is uneven in its number of entries per class (Figure 4), accuracy alone may not be an adequate evaluation criterion; precision, recall, and F1 scores must also be considered when comparing ML algorithms [48]. The binary classification results (normal/abnormal) demonstrate that the proposed Tab transformer model successfully resolved the data imbalance issue, as shown in Table 6.
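For readers unfamiliar with the architecture, below is a minimal PyTorch sketch of a TabTransformer-style classifier: each categorical column is embedded, the embeddings are contextualized by a transformer encoder, and the result is concatenated with normalized continuous features before an MLP head. The layer sizes and cardinalities are illustrative assumptions, not the configuration used in this study.

```python
import torch
import torch.nn as nn

class TabTransformer(nn.Module):
    def __init__(self, cat_cardinalities, num_continuous, n_classes,
                 dim=32, depth=4, heads=4):
        super().__init__()
        self.embeds = nn.ModuleList(
            [nn.Embedding(card, dim) for card in cat_cardinalities])
        layer = nn.TransformerEncoderLayer(
            d_model=dim, nhead=heads, batch_first=True)
        self.encoder = nn.TransformerEncoder(layer, num_layers=depth)
        self.norm = nn.LayerNorm(num_continuous)
        in_features = dim * len(cat_cardinalities) + num_continuous
        self.head = nn.Sequential(
            nn.Linear(in_features, 64), nn.ReLU(),
            nn.Linear(64, n_classes))

    def forward(self, x_cat, x_cont):
        # x_cat: (batch, n_cat) integer codes; x_cont: (batch, n_cont) floats
        tokens = torch.stack(
            [emb(x_cat[:, i]) for i, emb in enumerate(self.embeds)], dim=1)
        ctx = self.encoder(tokens).flatten(1)   # contextual column embeddings
        return self.head(torch.cat([ctx, self.norm(x_cont)], dim=1))

# Illustrative shapes: 3 categorical columns, 39 continuous features, 2 classes.
model = TabTransformer([135, 13, 11], num_continuous=39, n_classes=2)
logits = model(torch.randint(0, 11, (8, 3)), torch.randn(8, 39))
```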

For multiclass classification, on the other hand, the model achieved high accuracy but not the highest performance in the other metrics, as shown in Table 5. The per-class results reveal a correlation between prediction performance and the number of samples per class, as shown in Table 7. Before training a Tab transformer, a data augmentation strategy or additional data collection for the underrepresented classes would therefore be preferable; a balanced dataset would clearly allow the Tab transformer architecture to perform better.
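As one illustration of how such imbalance can be countered at training time, the sketch below oversamples rare classes with a weighted sampler; this is an assumed remedy for illustration, not the procedure used in the paper, which instead suggests augmentation or extra data collection.

```python
# A minimal sketch: draw rare classes more often via a weighted sampler.
# The tensors below are placeholders, not the UNSW-NB15 data.
import torch
from torch.utils.data import TensorDataset, DataLoader, WeightedRandomSampler

labels = torch.randint(0, 10, (1000,))            # placeholder class labels
class_counts = torch.bincount(labels).float()
weights = 1.0 / class_counts[labels]              # rare classes get larger weight
sampler = WeightedRandomSampler(weights, num_samples=len(labels),
                                replacement=True)
dataset = TensorDataset(torch.randn(1000, 42), labels)
loader = DataLoader(dataset, batch_size=64, sampler=sampler)
```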

Even though the data imbalance affected the multiclass categorization of anomalies, the proposed model appears to beat many earlier attempts in the domain using the same dataset. Despite a complex feature selection procedure, the authors of [49] reached an accuracy of 90.85% for binary classification using a decision tree architecture, whereas the decision tree built in our work achieved 96.5% accuracy with correlation-based feature selection. The authors of [48] created a Support Vector Machine to detect binary and multiclass anomalies in the dataset, but the model's accuracy was poor; since the original dataset is unbalanced between classes, that work also suggested that accuracy may not be the most appropriate metric for evaluating model performance. The authors of [50] developed an integrated rule-based approach that differs from ML approaches. Their work is truly innovative in that they developed rules for understanding the features of multiple classes; however, they only achieved an accuracy of 84.83% on multiclass prediction, as shown in Table 8. These values are taken from the respective papers.

The proposed model therefore appears to be a good fit for identifying anomalies at the fog node. A few items still need to be addressed, such as the data imbalance problem in multiclass classification tasks. In addition, because edge devices typically have limited computational resources and memory, an exhaustive analysis of the computation power and time needed for predicting anomalies is necessary, given that the model must be deployed at the fog node.


**Table 8.** Comparison of the proposed model with the previously developed anomaly/intrusion detection systems.

#### **5. Conclusions**

The current work proposed a fog-based anomaly detection system for IoT networks. The implementation indicated that fog nodes can be utilised effectively to decentralize an IoT network built on cloud architecture. The suggested model was developed on the UNSW-NB15 dataset and used to identify aberrant traffic in IoT networks. The proposed detection technique reduced the number of features for the multiclass and binary datasets using correlation-based feature selection. Although the test dataset remains unbalanced, both the classical ML models and the suggested Tab transformer demonstrated satisfactory performance. Our Tab transformer design outperforms conventional ML models, obtaining 98.35% accuracy on binary classification (normal vs. abnormal traffic) and 97.22% accuracy on the multiclass detection task.

Furthermore, by comparing the performance of the proposed model to that of previously created models on the same dataset, we have demonstrated the significance of the correlation-based feature selection method. As IoT devices have varying memory capacity, network bandwidth, and battery life, an optimal collection of attributes lets us construct a lightweight anomaly detection model. In the future, we intend to test the proposed model on additional balanced IoT-based datasets and to study its computational complexity and prediction time.

#### *Limitations*

Although the proposed methodology detects anomalies better than other approaches, it nevertheless has some limitations.

Applying new data augmentation techniques would further increase the computational complexity. The number of parameters could be reduced by applying customized models, and a few more features could be added to enhance the accuracy further.

The techniques used by the authors of [19] are lightweight and inspired by the human immune system, whereas we applied the Tab transformer technique: we offer a novel intrusion detection model that can be deployed at the fog nodes to detect undesired traffic toward IoT devices by leveraging features from the UNSW-NB15 dataset. A further limitation of [19] is that it did not give a comparison with other works; moreover, no technical details are clearly mentioned about the features extracted or how the processing was done.

**Author Contributions:** A.I.A.A. drafted the problem and designed the methodology for implementation, A.A.-R. worked on dataset exploration, A.K. started the implementation beginning with data preprocessing, M.A. implemented the designed plan, M.M.A. drafted the initial paper and designed the sections, M.Z. finished the paper writing by compiling all the sections together with formatting and English proofreading. All authors have read and agreed to the published version of the manuscript.

**Funding:** This study is supported by Princess Nourah bint Abdulrahman University Researchers Supporting Project number (PNURSP2022R235), Princess Nourah bint Abdulrahman University, Riyadh, Saudi Arabia.

**Data Availability Statement:** The dataset used in this work is taken from: N. Moustafa and J. Slay, "UNSW-NB15: a comprehensive data set for network intrusion detection systems (UNSW-NB15 network data set)," in 2015 military communications and information systems conference (MilCIS), 2015, pp. 1–6.

**Conflicts of Interest:** The authors declare no conflict of interest.

#### **References**


### *Article* **Machine-Learning-Based COVID-19 Detection with Enhanced cGAN Technique Using X-ray Images**

**Monia Hamdi 1, Amel Ksibi 2, Manel Ayadi 2,\*, Hela Elmannai <sup>1</sup> and Abdullah I. A. Alzahrani <sup>3</sup>**


**Abstract:** The coronavirus disease pandemic (COVID-19) is a contemporary disease. It first appeared in 2019 and has attracted much attention in the public media and in recent studies due to its rapid spread around the world and the millions of individuals it has infected; many people have died in a short time. In recent years, several studies in artificial intelligence and machine learning have been published to aid clinicians in diagnosing and detecting the virus before it spreads throughout the body, as well as in recovery monitoring, disease prediction, surveillance, and tracking. This paper aims to diagnose and detect COVID-19 disease from chest X-ray images. The dataset used in this work is the COVID-19 RADIOGRAPHY DATABASE, released in 2020, which consists of four classes. The work is conducted on two classes of interest: the normal class, which indicates that the person is not infected with the coronavirus, and the COVID-19 class, which indicates that the person is infected. Because the two classes are highly unbalanced (more than 10,000 images in the normal class and fewer than 4000 in the COVID-19 class), and because obtaining or gathering additional medical images is difficult, we took advantage of a generative network to produce fresh samples that appear real, balancing the number of images in each class. Specifically, a conditional generative adversarial network (cGAN) was used to solve the problem; the architecture of the employed cGAN is explored in detail in the Data Preparation Section. As the classification model, we employed VGG16; detailed information on its configuration and hyperparameters is given in the Materials and Methods Section. We tested our improved model on a test set of 20% of the total data and achieved 99.76 percent accuracy for the combined cGAN and VGG16 models with a variety of preprocessing steps and hyperparameter settings.

**Keywords:** COVID-19; pretrained models; convolutional neural networks; generative adversarial networks

### **1. Introduction**

Coronavirus is a novel disease that emerged in 2019. COVID-19 initially appeared in Wuhan, China, and then spread worldwide. In just over a year, the number of individuals infected with COVID-19 topped 450 million, with more than six million deaths reported in more than 200 countries. The World Health Organization [1] recognizes this number of cases, but in reality the number of individuals infected with the coronavirus is likely much higher than reported. Infection with COVID-19 can result in many complications, such as acute respiratory distress syndrome (ARDS), pneumonia, septic shock, multi-organ failure, lung injury, and acute liver injury. Beyond the symptoms and health effects it causes, COVID-19 has also inflicted global economic damage due to business closures and reduced productivity.

**Citation:** Hamdi, M.; Ksibi, A.; Ayadi, M.; Elmannai, H.; Alzahrani, A.I.A. Machine-Learning-Based COVID-19 Detection with Enhanced cGAN Technique Using X-ray Images. *Electronics* **2022**, *11*, 3880. https:// doi.org/10.3390/electronics11233880

Academic Editor: Chunjie Zhang

Received: 25 October 2022 Accepted: 21 November 2022 Published: 24 November 2022

**Publisher's Note:** MDPI stays neutral with regard to jurisdictional claims in published maps and institutional affiliations.

**Copyright:** © 2022 by the authors. Licensee MDPI, Basel, Switzerland. This article is an open access article distributed under the terms and conditions of the Creative Commons Attribution (CC BY) license (https:// creativecommons.org/licenses/by/ 4.0/).

Many tests have been used [2] to determine positivity or negativity for COVID-19. Molecular tests take a sample from the throat, the nose, or both using a cotton swab; the PCR test, the most common among them, likewise uses a patient sample to detect COVID-19 infection. Rapid diagnostic tests (RDTs) detect the presence of viral proteins. The virus can also be detected through chest X-rays, C.T. scans, and many other means. Artificial intelligence and machine learning have shown much progress in many applications, such as patient diagnosis and recovery monitoring during the COVID-19 pandemic. Many research papers have been released in the past two years that address the pandemic by helping and assisting doctors and by replacing manual detection, diagnosis, and patient tracking with effective automatic methods.

Modern artificial intelligence (A.I.) uses machine-learning algorithms to find patterns and relationships in data, as opposed to knowledge-based A.I. of previous generations, which relied on professionals' prior medical knowledge and formulated rules [3–5]. Deep learning, which uses large, labeled datasets to train artificial neural networks, has had a significant impact on A.I.'s recent resurgence [3–5]. A sophisticated deep-learning network typically contains numerous hidden layers [6]. Because of this resurgence, many people now wonder whether robot doctors will soon take the place of human doctors. In the meantime, experts think that AI-driven intelligent systems may considerably assist doctors in making better judgments (e.g., in radiography) and even eliminate the need for human decisions in some cases [7]. The recent success of A.I. in healthcare may be attributed to the increase in healthcare data arising from the greater use of digital technologies and big data analytics. Because of the widespread use of mobile devices, it is now much simpler to gather and obtain these data using mobile apps [8]. Even though A.I. research in healthcare is on the rise, most studies focus on cancer, neurology, and cardiovascular disease. With evidence-based A.I., medical data may be mined for insights that can be used for decision making and forecasting [9–11]. Researchers believe that A.I. might be valuable in the battle against COVID-19, given its success in healthcare; it has already contributed to predicting pandemics and creating antiviral compounds. According to recent studies, A.I. can detect COVID-19 infection and infected populations, as well as future outbreaks, attack patterns, and even treatments [12–14]. In the last few years, A.I. techniques [15], such as biological data mining and machine learning (ML) [16], have been used to help with detection, diagnosis, classification, and vaccine development for COVID-19 [17].

The new coronavirus infection may be diagnosed using artificial intelligence tactics such as case-based reasoning (CBR) [18], long short-term memory networks (LSTM) [19], and sentiment analysis [20]. CNN-based techniques, however, are more successful and promising. Since a great deal of research has been conducted on using CNNs to recognize and categorize COVID-19 in digital images, deep learning has emerged as one of the most often used and successful approaches, as indicated by recent studies on the topic [21–29]. COVID-19 may be diagnosed from clinical pictures, C.T. scans of the chest, and chest X-rays. Because of its capacity to automatically learn features from digital images, this approach has been shown to be very successful. Deep machine-learning algorithms have many advantages, and it would be beneficial to examine how they may be improved for even greater efficiency. When building CNNs and other image-processing algorithms, past research did not pay enough attention to the selection of hyperparameters.

In this work, we focused on recognizing chest X-rays using convolutional neural networks (CNNs) and on generating chest X-ray images, because the number of available photos, particularly medical photographs, is limited and difficult to gather manually. The generative adversarial network (GAN), one of the most influential and innovative methods for creating realistic images, is used to generate the chest X-ray images; GANs are highly powerful and successful at working with small datasets and rendering images. As a result, the GAN improves classification accuracy by increasing and balancing the number of images in the dataset.

The significant contributions of the work can be summarized as follows:


The Abstract summarizes the main aspects of the paper and the obtained results. Section 1 is the Introduction, followed by Section 2, the Literature Review, which surveys the current state of the art. Section 3, Materials and Methods, contains a Dataset subsection (Section 3.1) describing the data and the number of images in detail; Section 3.4 discusses the proposed methodology, and Section 3.5 presents the results obtained with the GAN and the modified VGG16 architecture. Section 4 is dedicated to the discussion, which includes a comparative analysis, and Section 5 concludes the paper.

#### **2. Literature Review**

The coronavirus illness (COVID-19) has become one of the world's most well-known and challenging problems in the last two years. The coronavirus has affected more than 200 countries, infected millions of people, and killed many others [30]. To fight COVID-19, much research has been published introducing innovative solutions and techniques, especially in early detection and diagnosis, drug discovery, and many other areas. This section emphasizes research directed at COVID-19 analysis and detection from chest X-ray images.

In [1], CovidGAN sought to produce synthetic chest X-ray images using the GAN classification technique. This network obtained 95% accuracy, 90% sensitivity, and 97% specificity using the combination of three databases (COVID-19 Radiography Database, COVID-19 Chest X-ray Dataset Initiative, and IEEE8023/Covid Chest X-Ray Dataset) with roughly 80% training and 20% testing data. All three of the fine-tuned CNNs achieved accuracy > 99%, sensitivity from 98% (AlexNet) to 100% (GoogLeNet and SqueezeNet), and specificity > 99% using the same database.

In [31], X-ray images are classified into three classes: normal, bacterial pneumonia, and COVID-19. With roughly 89% of the data used for training (3697 images) and validation (462 images) and 11% (459 images) for testing, Bayes-SqueezeNet employed the COVID-19 Radiography Database (Kaggle) and the IEEE8023/Covid Chest X-Ray Dataset. Data augmentation was used for network training. The accuracy, specificity, and F1 scores were 98%, 99%, and 0.98, respectively.

The authors of [21] used the IEEE8023/Covid Chest X-Ray Dataset and the Chestx-ray8 Database [27] to perform 2-class classification (COVID-19 and no findings) and 3-class classification (COVID-19, no findings, and pneumonia). For the 2-class classification, Dark-CovidNet achieved accuracy = 98%, sensitivity = 95%, specificity = 95%, precision = 98%, and F1 score = 0.97 using fivefold cross-validation (80% training and 20% testing).

Moreover, in [23], VGG-19, MobileNet-v2, Inception, Xception, and Inception ResNet-v2 were implemented for the classification of COVID-19. The networks were trained and tested using the IEEE8023/Covid Chest X-Ray Dataset and other chest X-rays collected on the internet. The best 10-fold cross-validation (90% training and 10% testing) results, obtained with VGG-19 on the COVID-19 Radiography Database, were accuracy = 98.75% for the 2-class classification and 93.48% for the 3-class classification. Some studies conducted on features are [32,33].

Polsinelli et al. [34] used a CNN to identify COVID-19. A pre-trained CNN encoder and machine-learning algorithms were used to extract feature representations from the data, and the suggested technique was found to improve COVID-19 detection accuracy to 95.6%. For COVID-19 diagnosis, Shorfuzzaman and Hossain [35] recommended using VGG-16 fast regions with CNNs (R-VGG-16); applied to X-ray pictures to distinguish COVID-19, R-VGG-16 achieved a 97.36 percent accuracy rate in identifying COVID-19 patients. A CNN ResNet variant created by Karthik et al. [36] was utilized to detect the new COVID-19 virus in C.T. scans; with the help of ResNet, researchers could accurately detect COVID-19 from C.T. scan images with an accuracy of 95.09%.

Similarly, [37] used the grey wolf optimization approach to identify COVID-19 patients with a CNN architecture: grey wolf optimization was employed to optimize the CNN hyperparameters and create a model for COVID-19 recognition. Using X-ray pictures as input, the tuned CNN could correctly spot COVID-19 with an accuracy of 97.78 percent. Using a mixture of 3D and 2D CNNs, another work [38] recognized COVID-19 in X-ray pictures [39]; this model's training and merging method employed a CNN, a depth-wise separable convolution layer, and a spatial pyramid pooling module, and the findings showed that COVID-19 could be detected 96.71 percent of the time using the suggested method. Bayoudh and coworkers [39] used a transfer-learning-based CNN to identify COVID-19 in X-ray pictures. Al-antari et al. [40] employed the YOLO predictor to construct a computer-aided diagnostic tool for the simultaneous recognition and diagnosis of COVID-19 lung disease from entire chest X-ray pictures. They evaluated their model on two different sets of chest X-ray scans, and the evidence showed that it performed well in the fivefold tests for multiclass prediction problems; the CAD system attained a recognition and classification accuracy of more than 90%, demonstrating its ability to identify and classify COVID-19-contaminated regions accurately. Seven of the twelve current CNN designs were worse than SqueezeNet and DenseNetv2, and two were better; AlexNet and GoogleNet were superior to the proposed CNN architecture, whereas VGG16 and VGG19 were found to be inferior to it. Majeed et al. [41] created an online attention module that employed a 3D CNN to diagnose COVID-19 accurately from C.T. scan images, and their results have been published in *Nature Communications*.

The literature review above shows that the accuracy achieved so far is not fully satisfactory and that computational complexity remains a significant challenge. We therefore designed the methodology below to overcome these shortcomings.

#### **3. Materials and Methods**

#### *3.1. Dataset*

The dataset used is the COVID-19 RADIOGRAPHY DATABASE, released in 2020 by a team of investigators from Qatar University in collaboration with the University of Dhaka [42]. The database consists of 4 classes: COVID-19 positive cases, normal cases, non-COVID-19 lung infection, and viral pneumonia. It contains about 21,200 images, distributed among the classes as follows: the COVID-19 class has 3616 images, the normal class 10,192, the non-COVID-19 lung infection class 6012, and the viral pneumonia class 1345 [43,44].

In this research, we consider only the classes that help determine whether or not a patient has COVID-19 from the chest X-ray images in this dataset. Therefore, only the COVID-19 and normal classes, with 3616 and 10,192 images respectively, were examined.

Figure 1 shows random samples of the two classes in the dataset, and Figure 2 shows the data distribution (number of images) of the two classes in the dataset.

**Figure 1.** Random samples of the two classes (COVID-19 class and normal class).

**Figure 2.** Distribution of the two classes in the dataset.

As shown in Figure 2, the two classes are imbalanced: the COVID-19 class contains fewer than 4000 images, while the normal class contains more than 10,000.

#### *3.2. Data Preparation: Data Augmentation Using Conditional GAN*

Generative modeling is a branch of unsupervised learning widely used in machine learning and deep learning. Generative models try to learn the structure and distribution of the data in order to generate new samples that look like real data. GANs are applied to image generation, video generation, audio and text generation, image super-resolution, translation from one image mode to another, and so on. The GAN architecture was introduced in 2014 by Goodfellow et al. [45].

GANs have been used widely in data generation and augmentation to deal with the unbalanced classes problem and small datasets problem. GANs have also been used widely in data augmentation in the medical field. In addition, GANs have been widely used in image generation in the medical field to improve classification accuracy, since manual data collection in the medical field is not easy to obtain [46].

We trained the discriminator on real images so it learns which images are real: for each input, the discriminator decides whether the image looks real, and it is told whether the image was in fact real or fake. In this way, we obtain a discriminator that can differentiate between the poor-quality X-rays generated by the generator in the first epochs and the real ones in the dataset. Training continues until the generator produces convincing fake X-ray samples that can fool the discriminator. With the discriminator's feedback, the generator knows in which direction to go and improves more and more by looking at the scores the discriminator assigns.

The primary purpose of the two networks is thus adversarial: the generator tries to learn how to produce good fakes that look real so the discriminator cannot decide whether an image is fake, forging images that look as realistic as possible in the hope of deceiving the discriminator, while the discriminator learns to differentiate between real samples from the dataset and the fakes generated by the generator network.

To perform augmentation, we used a conditional GAN (cGAN) to address the problem of the unbalanced number of images in each class. As shown in Figure 3, our cGAN model was used to train both the generator and discriminator networks.

**Figure 3.** GAN training for both generator and discriminator networks.

#### *3.3. Data Preprocessing*

Data preprocessing is an initial step that transforms the data into a proper structure before feeding it to the model. Several preprocessing techniques were used to improve the accuracy of the model's results.

• Formatting Images

We resized our images to a fixed size of 224 × 224 to be compatible with transfer learning for models such as the one used in this paper, VGG16.

• Eliminating the noise and flattening the images

Median filtering is applied to smooth the images and reduce noise; it is an effective technique with the dual ability to suppress noise while preserving the edges of the image (see the sketch after this list).

• Data augmentation

For the medical images used in this study, image augmentation is applied to zoom, flip, and adjust the brightness of the images.

• Data balancing and filtration

Since our two classes were unbalanced (the normal class had more than 10,000 images, while the COVID-19 class had around 3500), we used GANs to generate data for the COVID-19 class, as mentioned above, and applied random filtration to the normal-class images. This yielded almost balanced classes of 6200 images in the COVID-19 class and about 7500 in the normal class, which made our model much more generalizable on the test set.

• Splitting Dataset for Training and Testing

The dataset is split with a ratio of 80% training to 20% testing: the proposed model was trained on about 11,000 images and tested on about 2700 images. From the data allotted for training, 20% was kept separately for validation.
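Putting the steps above together, a minimal sketch of the preprocessing pipeline might look as follows; the OpenCV calls, parameter values, and stand-in data are illustrative assumptions rather than the paper's exact implementation.

```python
import cv2
import numpy as np
from sklearn.model_selection import train_test_split

def preprocess(img: np.ndarray) -> np.ndarray:
    img = cv2.resize(img, (224, 224))   # match the VGG16 input size
    return cv2.medianBlur(img, 3)       # 3x3 median filter: denoise, keep edges

def augment(img: np.ndarray) -> np.ndarray:
    img = cv2.flip(img, 1)                              # horizontal flip
    img = cv2.convertScaleAbs(img, alpha=1.0, beta=15)  # brightness shift
    h, w = img.shape[:2]
    crop = img[h // 10:-h // 10, w // 10:-w // 10]      # mild zoom-in
    return cv2.resize(crop, (w, h))

# Stand-in data: 100 random grayscale "X-rays" with binary labels.
images = np.random.randint(0, 256, (100, 256, 256), dtype=np.uint8)
labels = np.random.randint(0, 2, 100)

X = np.stack([preprocess(im) for im in images])
X = np.concatenate([X, np.stack([augment(im) for im in X])])  # augmented copies
y = np.concatenate([labels, labels])

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, stratify=y, random_state=0)
```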

#### *3.4. Proposed Methodology*

The generator network was trained with the following procedure:


The discriminator network is trained with the following procedure:


The generator and discriminator networks (G and D) are trained simultaneously [47]: the generator updates its parameters to minimize log(1 − D(G(z))), while the discriminator updates its parameters to maximize log D(x). Here x ∼ p_data denotes real data; by definition of G, G(z) is fake generated data; D(x) is the discriminator's output for a real input x, and D(G(z)) is its output for the fake generated data G(z).

$$\min\_{\mathbf{G}} \max\_{\mathbf{D}} V(\mathbf{D},\mathbf{G}) = \mathbb{E}\_{\mathbf{x},\mathbf{y} \sim p\_{\text{data}}(\mathbf{x},\mathbf{y})}\left[ \log \mathbf{D}(\mathbf{x},\mathbf{y}) \right] + \mathbb{E}\_{\mathbf{z} \sim p\_{\mathbf{z}},\, \mathbf{y} \sim p\_{\mathbf{y}}}\left[ \log\left( 1 - \mathbf{D}(\mathbf{G}(\mathbf{z},\mathbf{y}),\mathbf{y}) \right) \right] \tag{1}$$
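As a concrete illustration of Equation (1), the following minimal PyTorch sketch alternates the discriminator and generator updates; `G` and `D` are assumed to be the label-conditioned modules sketched after Figure 4 below, and `z_dim` is an illustrative choice, not a value from the paper.

```python
import torch
import torch.nn.functional as F

def train_step(G, D, opt_G, opt_D, real, y, z_dim=100):
    b = real.size(0)
    z = torch.randn(b, z_dim)
    # Discriminator step: push D(x, y) toward 1 and D(G(z, y), y) toward 0,
    # i.e., ascend log D(x, y) + log(1 - D(G(z, y), y)).
    opt_D.zero_grad()
    loss_D = (F.binary_cross_entropy(D(real, y), torch.ones(b, 1)) +
              F.binary_cross_entropy(D(G(z, y).detach(), y), torch.zeros(b, 1)))
    loss_D.backward()
    opt_D.step()
    # Generator step: the usual non-saturating surrogate for minimizing
    # log(1 - D(G(z, y), y)) is to push D(G(z, y), y) toward 1.
    opt_G.zero_grad()
    loss_G = F.binary_cross_entropy(D(G(z, y), y), torch.ones(b, 1))
    loss_G.backward()
    opt_G.step()
    return loss_D.item(), loss_G.item()
```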

We used a modified architecture for the generator and discriminator that fits the shape of our input images, with the best hyperparameters for our data. Figure 4 shows the generator and discriminator networks. The convolution layers in our discriminator network use strides of 2 to down-sample the feature map to half its size each time, and the last layer uses sigmoid activation to assign a probability to the given input image (the probability that the image is real rather than fake). The transpose convolution layers in the generator network up-sample the feature map to double its size each time, and the activation function used in the generator's final layer is tanh. Leaky ReLU with alpha = 0.2 is used in both the discriminator and generator networks.

**Figure 4.** Discriminator and generator networks.
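The following minimal PyTorch sketch mirrors the architectural choices just described: stride-2 convolutions with LeakyReLU(0.2) and a sigmoid output in the discriminator, stride-2 transpose convolutions with a tanh output in the generator, and the class label injected via an embedding. Layer counts, widths, and the 64 × 64 image size are illustrative assumptions, not the exact networks of Figure 4.

```python
import torch
import torch.nn as nn

class Generator(nn.Module):
    def __init__(self, z_dim=100, n_classes=2):
        super().__init__()
        self.label_emb = nn.Embedding(n_classes, z_dim)
        self.net = nn.Sequential(           # 1x1 -> 64x64, doubling each step
            nn.ConvTranspose2d(z_dim, 256, 4, 1, 0), nn.BatchNorm2d(256), nn.ReLU(),
            nn.ConvTranspose2d(256, 128, 4, 2, 1), nn.BatchNorm2d(128), nn.ReLU(),
            nn.ConvTranspose2d(128, 64, 4, 2, 1), nn.BatchNorm2d(64), nn.ReLU(),
            nn.ConvTranspose2d(64, 32, 4, 2, 1), nn.BatchNorm2d(32), nn.ReLU(),
            nn.ConvTranspose2d(32, 1, 4, 2, 1), nn.Tanh())  # grayscale in [-1, 1]

    def forward(self, z, y):
        h = (z * self.label_emb(y)).view(z.size(0), -1, 1, 1)  # condition on label
        return self.net(h)

class Discriminator(nn.Module):
    def __init__(self, n_classes=2, img=64):
        super().__init__()
        self.label_map = nn.Embedding(n_classes, img * img)  # label as a channel
        self.net = nn.Sequential(                            # stride 2 halves size
            nn.Conv2d(2, 64, 4, 2, 1), nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, 4, 2, 1), nn.LeakyReLU(0.2),
            nn.Conv2d(128, 256, 4, 2, 1), nn.LeakyReLU(0.2),
            nn.Conv2d(256, 1, 8, 1, 0), nn.Sigmoid())        # real/fake probability

    def forward(self, x, y):
        ymap = self.label_map(y).view(-1, 1, x.size(2), x.size(3))
        return self.net(torch.cat([x, ymap], dim=1)).view(-1, 1)

G, D = Generator(), Discriminator()
fake = G(torch.randn(4, 100), torch.randint(0, 2, (4,)))  # (4, 1, 64, 64)
score = D(fake, torch.randint(0, 2, (4,)))                # (4, 1) probabilities
```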

In the early epochs, the generator network of a GAN produces almost entirely noisy images, as shown in Figure 5. It takes many epochs for the generator to begin learning the underlying structure of the data and to produce good fakes that look realistic and fool the discriminator. Figure 5 shows random samples generated by the generator at different epochs during the learning process; the images look increasingly real as the epochs increase. Using the cGAN, we generated about 2600 images for the COVID-19 class, bringing it to a total of about 6200 images. We first trained the proposed conditional GAN (cGAN) on the X-ray images of the COVID-19 category so the generator would learn to produce new X-ray images belonging to that class, thereby increasing its image count and solving the imbalance between the two classes. After the augmentation step, all images in both classes were resized to 224 × 224.

**Figure 5.** Generator learning process.

#### *3.5. VGG16 Architecture*

VGG16 is a simple and widely used convolutional neural network (CNN) architecture consisting of 16 layers: 13 convolutional layers used for feature extraction and 3 dense (fully connected) layers. VGG16 is used in many classification problems due to its simplicity and performance [48]. The VGG16 model is also pretrained on the Google ImageNet dataset [49], so it can be fine-tuned rather than trained from scratch. The general VGG16 architecture is shown in Figure 6. The input to the network is a colored RGB image of size 224 × 224. The image is fed to a set of convolutional layers with filters of size 3 × 3 each. These convolution layers are responsible for extracting features, and the number of filters applied to the produced feature maps increases as the network goes deeper, so the later layers extract higher-level features. The ReLU activation function is applied to the feature map after each convolutional layer to add non-linearity to the network. With padding set to 1 and stride set to 1, the convolution layers output feature maps with the same spatial resolution as their input. The VGG16 network also contains 5 MaxPooling layers with a window size of 2 × 2 and a stride of 2, added after certain convolutional layers; each halves the feature map, so the 224 × 224 input is reduced five times to a size of 7 × 7, as shown in Figure 6.

The last feature map is flattened to a 1-d vector of size 7 × 7 × 512 and passed through fully connected layers; the final layer in the network is the softmax layer, which distributes probabilities among the classes. The layers, hyperparameters, optimizer, and other details of the modified VGG16 architecture are shown in Table 1.

**Figure 6.** VGG16 architecture.

**Table 1.** Summary of the modified VGG16 model.
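As an illustration of this transfer-learning setup, the sketch below fine-tunes an ImageNet-pretrained VGG16 from torchvision for the binary normal/COVID-19 task; the frozen layers, replacement head, and optimizer settings are assumptions for illustration, not the exact Table 1 configuration.

```python
import torch
import torch.nn as nn
from torchvision import models

model = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1)
for p in model.features.parameters():
    p.requires_grad = False                  # keep pretrained convolutional features
model.classifier[6] = nn.Linear(4096, 2)     # replace the 1000-way output layer

optimizer = torch.optim.Adam(
    (p for p in model.parameters() if p.requires_grad), lr=1e-4)
criterion = nn.CrossEntropyLoss()

x = torch.randn(4, 3, 224, 224)              # a batch of 224x224 RGB X-rays
loss = criterion(model(x), torch.tensor([0, 1, 0, 1]))
loss.backward(); optimizer.step()
```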


#### *3.6. Performance Evaluation*

Precision, also named the positive predictive value, is the number of true positive predictions divided by the total number of positive class predictions. Equation (2) is used to calculate precision.

$$\text{Precision} = \frac{\text{True Positive}}{\text{True Positive} + \text{False Positive}} \tag{2}$$

Recall, also known as sensitivity, is the number of true positive predictions divided by the total number of actual positive class instances. Equation (3) is used to calculate recall.

$$\text{Recall} = \frac{\text{True Positive}}{\text{True Positive} + \text{False Negative}} \tag{3}$$

The F1 score, also named the F-score or F-measure, captures the balance between precision and recall: it combines the two to guarantee that the model has both high precision and high recall, and its value is large only if both precision and recall are high. F1-score values lie in the interval [0, 1], and the higher the value, the better the classification [50]. The F1 score is calculated by Equation (4).

$$\text{F1-score} = \frac{2 \ast \text{Precision} \ast \text{Recall}}{\text{Precision} + \text{Recall}} \tag{4}$$
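In practice, Equations (2)–(4), together with the AUC and confusion matrix used later, can be computed with scikit-learn, as in the minimal sketch below; the label and score arrays are placeholders, not results from this study.

```python
from sklearn.metrics import (precision_score, recall_score, f1_score,
                             roc_auc_score, confusion_matrix)

y_true = [0, 0, 1, 1, 1, 0, 1, 0]                   # placeholder test labels
y_pred = [0, 0, 1, 0, 1, 0, 1, 0]                   # placeholder predictions
y_score = [0.1, 0.2, 0.9, 0.4, 0.8, 0.3, 0.7, 0.2]  # placeholder class-1 scores

print("precision:", precision_score(y_true, y_pred))
print("recall:   ", recall_score(y_true, y_pred))
print("f1:       ", f1_score(y_true, y_pred))
print("auc:      ", roc_auc_score(y_true, y_score))
print(confusion_matrix(y_true, y_pred))
```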

#### **4. Results**

Model performance was evaluated on the testing set after splitting the dataset into training and testing sets of 80% and 20%, respectively. We first applied the preprocessing techniques described in Section 3.3 and then used our modified VGG16 model to differentiate between the normal and COVID-19 classes.

The model performed better on the normal-class images due to the difference in the number of images per class: as noted in the Dataset subsection (Section 3.1), the normal class has far more images than the COVID-19 class.

Figure 7 displays the training accuracy versus the validation accuracy over 27 epochs for the modified VGG-16 model, while Figure 8 shows the training and validation losses across the same 27 epochs. Figures 9 and 10 show the accuracy and loss across 27 epochs for our modified VGG-16 model after using the cGAN model to generate images for the COVID-19 class.

**Figure 7.** Model accuracy before using cGAN.

**Figure 8.** Model loss before using cGAN.

**Figure 9.** Model accuracy after using cGAN.

**Figure 10.** Model loss after using cGAN.

We can conclude from the figures that, when using the cGAN, the model is more stable and achieves higher accuracy and lower loss on the testing set.

We also evaluated our model based on parameters such as confusion matrix, precision, recall, F1 score, and ROC.

The confusion matrix tells us how the model performs on each class individually; evaluating with these metrics ensures the model performs well in each class independently, unlike the accuracy metric, which measures performance on the whole test set and may be biased toward one class more than others. Figures 11 and 12 show the confusion matrices of our VGG-16 model without and with the cGAN, respectively. The confusion matrix is an N × N matrix, where N is the number of classes (two in our case).

The confusion matrices show that the model performs much better on the COVID-19 class when the cGAN is used. In Figure 11 (before using the cGAN), 27 of the COVID-19 testing images were misclassified as normal although they were actually COVID-19 cases. This is a type-II error and is treated as a crisis in the medical world: if someone who actually has COVID-19 is diagnosed as normal, he will be left without treatment, which may lead to further side effects or even death. After applying the cGAN model and plotting the confusion matrix again (Figure 12), the number of patients misclassified as normal while actually having COVID-19 decreased from twenty-seven samples to only two, even though the testing set after applying the GAN is much larger for the COVID-19 class. In Figure 11, 27 cases are misclassified as normal out of 420 total COVID-19 images; in Figure 12, 2 cases are misclassified as normal out of 1200. The misclassification ratio thus decreased from 0.065 to 0.0015 after applying our GAN model for data generation.

**Figure 11.** Confusion matrix before using cGAN.

**Figure 12.** Confusion matrix after using cGAN.

In addition, we employed other assessment metrics, namely precision, recall, and F1 score, to assess the proposed model on each of our two classes, normal and COVID-19, independently.

As shown in Tables 2 and 3, the model's precision, recall, and F1 score improved slightly after using the cGAN.


**Table 2.** Model evaluation on different metrics before using cGAN.

**Table 3.** Model evaluation on different metrics after using cGAN.


Another evaluation metric is the area under the curve (AUC), which measures a classifier's ability to distinguish between classes by constructing a graph of the model's performance at different classification thresholds; the AUC value is the area under this graph. The higher the AUC, the better the model is at differentiating between the positive and negative classes; when the AUC approaches one, the classifier can perfectly distinguish the two classes.

The AUC score achieved by our classification model is 0.968 before using the cGAN and 0.99877 after. A summary of our results is shown in Table 4.

**Table 4.** Overview of our obtained final results.


As shown in Table 5, our modified VGG16 achieved an accuracy of about 98.6%; after using the conditional GAN to solve the unbalanced-classes problem by generating data for the COVID-19 class, the accuracy improved to 99.76%.

**Table 5.** Comparison of the proposed model with related investigations.


#### *Comparative Analysis*

We analyzed numerous recent papers that are closely connected to our findings.

One disadvantage of the present work is that, because the research involves medical images, we need to understand the degree of error associated with each prediction. A CNN cannot by default provide the uncertainty associated with each prognosis; it is limited to expressing the probability of each class at the final softmax layer. This can be addressed in subsequent work. Additionally, more focus should be placed on additional categories, such as viral pneumonia, bacterial pneumonia, and other lung illnesses, so that the model may be applied in practice.

#### **5. Conclusions**

The primary objective of this research was to construct a classification model capable of determining whether or not COVID-19 is present in acquired chest X-rays. The study used data from the COVID-19 RADIOGRAPHY DATABASE for training and testing, concentrating on two groups of interest: normal and COVID-19. To address the issue of unbalanced classes, we used a variety of preprocessing strategies, including data generation using a conditional GAN, to increase the number of images in the COVID-19 class. All images were resized to a fixed width and height of 224 × 224, a median filter of size 3 × 3 was used to smooth the images and remove noise, and the acquired chest X-ray images were further augmented through zooming, flipping, and brightness and contrast adjustment. The training dataset contains approximately 11,000 images. The VGG16 model was pre-trained on Google ImageNet data, and its weights were fine-tuned using transfer learning rather than training from scratch. The model was validated on over 2700 images and achieved 99.78% accuracy. Additionally, we examined model performance using the confusion matrix, precision, recall, F1 score, and area under the curve (AUC) to confirm that our model performs well across all classes.

**Author Contributions:** M.H. carried out the implementation; M.A. worked on the problem statement, the implementation idea, and results visualization; A.K. started data searching, preprocessing, and implementation; H.E. started the manuscript drafting and design; A.I.A.A. proofread and formatted the paper. All authors have read and agreed to the published version of the manuscript.

**Funding:** This research was funded by Princess Nourah bint Abdulrahman University researchers supporting project number (PNURSP2022R125), Princess Nourah bint Abdulrahman University, Riyadh, Saudi Arabia.

**Data Availability Statement:** The dataset used in this study is publicly available at https://www. kaggle.com/datasets/tawsifurrahman/covid19-radiography-database (accessed on 24 October 2022).

**Conflicts of Interest:** The authors declare no conflict of interest.

#### **References**


### *Article* **Realistic Rendering Algorithm for Bubble Generation and Motion in Water**

**Huiling Guo 1,2,3,\*, Hongyu Wang 1,3, Jing Zhao 1,3 and Yong Tang 1,3,\***


**Abstract:** A simplified bubble model and an optimized solver are proposed to address the poor realism and complex solution procedures of current simulations of bubble motion in water. Firstly, the velocity field inside the bubble is ignored, and the bubble model is established by considering only the net flux into and out of the bubble, which reduces the computational complexity. The bubble constraint is then introduced into the motion equation of water, and a mixed Euler–Lagrangian method is used to solve it: FLIP particles track the bubble position, velocity, and deformation, while the mesh updates the vector field. At the same time, the viscosity term is simplified. Finally, the method is combined with implicit incompressible SPH particles for volume correction. The experimental results show that the method in this paper produces a richly detailed and realistic simulation of bubbles in water, whether compared with actual photographs or with existing methods.

**Keywords:** bubble formation; FLIP method; volume correction

**Citation:** Guo, H.; Wang, H.; Zhao, J.; Tang, Y. Realistic Rendering Algorithm for Bubble Generation and Motion in Water. *Electronics* **2022**, *11*, 3689. https://doi.org/10.3390/ electronics11223689

Academic Editor: Soon Ki Jung

Received: 14 October 2022 Accepted: 8 November 2022 Published: 10 November 2022

**Publisher's Note:** MDPI stays neutral with regard to jurisdictional claims in published maps and institutional affiliations.

**Copyright:** © 2022 by the authors. Licensee MDPI, Basel, Switzerland. This article is an open access article distributed under the terms and conditions of the Creative Commons Attribution (CC BY) license (https:// creativecommons.org/licenses/by/ 4.0/).

#### **1. Introduction**

With the continuous development of computer graphics in fluid simulation, the requirements on fluid animation keep rising: more artistic visual effects, more accurate flow trends, more efficient calculation methods, and greater micro detail, all of which pose significant challenges to physics-based fluid animation technology.

Fluid movement in real life usually involves gas. Bubbles are generated by entrained air flow, whether in boiling water, splashing water droplets, or carbonated drinks. If the bubble phenomenon cannot be simulated realistically, the realism of the fluid animation is significantly reduced. Wherever bubbles exist, liquid and gas coexist. However, because bubble motion is complex, a high accuracy of detail is demanded, and the influence of the gas on the free surface must also be considered. Bubbles and related two-phase flow phenomena are often ignored in practical fluid simulation because of the high computational cost. Ideally, one would like to avoid simulating the air, or at least drastically reduce the degrees of freedom spent on capturing it. It is standard in bubble simulation to ignore the air and assume a free-surface boundary condition at the interface [1]. However, this treats air as a massless void that collapses when entrained by the liquid, because no opposing force preserves its original volume. Furthermore, water and air differ in density by about three orders of magnitude, making the system difficult to solve [2].

Bubbles make the liquid simulation more vivid. In this paper, the Navier–Stokes equation (N–S equation) method is used to simulate the bubble, and the complexity and reality of its calculation are studied.

In order to reduce the computational complexity, based on satisfying the momentum conservation equation of incompressible fluid and the standard pressure projection, the incompressibility constraint of the bubble is taken as a boundary condition; thus, a simplified model that ignores the velocity change inside the bubble is established to spare computational cost. At the same time, a space-partitioning method is used in the particle neighborhood search phase: the space to be searched is divided into equidistant grid cells, and only distances between particles in a cell and those in adjacent cells are computed, which dramatically reduces the calculation time and makes the neighborhood search more efficient.

In enhancing realism, the viscosity coefficient is simplified as a constant extracted in the divergence calculation to correctly represent the vibration deformation of the bubble surface in a low-viscosity liquid. The implicit incompressible Smoothed Particle Hydrodynamics (SPH) particles are generated according to fluid implicit particle (FLIP) interpolation. The pressure quadratic projection is completed by solving the Poisson equation of SPH particles to achieve the purpose of volume correction, which makes the simulation effect of bubbles more realistic and natural.

In summary, the main contributions of this paper are the following:


#### **2. Related Work**

A gas-containing fluid is usually produced in complex fluid scenes: when the fluid flows rapidly and violently, gas becomes entrained in it. Based on the visual effect, this gas material can be divided into large-size and small-size gas material. We mainly focus on the coupling problem of gas–liquid two-phase flow and simulate the gas phase as observable bubbles in the fluid. Large-size gas materials exhibit rich gas-movement behavior. The level set method and the SPH method are the two methods generally used for bubble modeling and simulation.

As early as 2001, Foster and Fedkiw et al. used the particle level set method to model and simulate bubbles through SPH particles [3], which could simulate bubbles at a resolution below that of the Euler mesh. Greenwood et al. also used the particle level set method, labeling particles to simulate bubbles but ignoring their shape changes [4]. At the same time, Song and Zheng et al. used the region level set method to model and simulate large bubbles on a grid [5,6]; however, this could not simulate bubbles smaller than the established mesh size. Kim et al. [7] extended the application of particle level sets in fluid simulation, adding massless particles and using the particle system to update their positions when these particles left the host fluid. In 2008, Hong et al. [8] used a velocity field to couple the Lagrangian and Eulerian methods to simulate bubbles in water: the level set method simulated the deformable bubbles and fluids, while a grid-based Eulerian method captured the movement of water around the bubbles, with gas and liquid coupled through the velocity field. Due to the limitation of mesh resolution, however, it was difficult to simulate the unstable path of rising foam. In 2011, Ihmsen et al. [9] simulated the movement of bubbles and water with the single-phase SPH method, calculating the density and pressure of the gas–liquid two-phase flow and adding a tension model to handle the interaction between the two phases. This method solved the high-density-ratio problem and simulated the complex motion of foam in water. However, the SPH method must maintain a certain distance between particles during initialization, and the number of larger-size gas particles is limited by the resolution of the fluid particles, which makes the foam less dense. In 2020, Wang Hui et al. [10] proposed a new Eulerian–Lagrangian hybrid method to simulate air bubbles with moving-least-squares (MLS) particles.

It is not difficult to see that the level set method tracks the phase interface well, while the SPH method is more suitable for solving fluid motion on complex surfaces because it does not rely on a mesh and thus avoids the mesh distortion and reconstruction problems caused by violent fluid movement. Both methods are widely used, but each has several defects.

The level set method continuously loses volume due to numerical errors, so the bubbles disappear. Kim and Busaryev et al. chose volume control methods to solve this problem [11,12]: by strengthening the interaction, the corrected bubble volumes yielded arrangements and geometry more in line with the laws of physics. Stomakhin et al. [13] achieved high resolution by simulating only the narrow band around the critical area, while using FLIP particles to conserve bubble volume in water and to track the interface accurately. However, the density at the free surface was misestimated.

Unlike the above methods, Hong et al. simulated the bubbles and surrounding liquid as a multiphase flow [14], combining the volume-of-fluid (VOF) and front-tracking methods to track the bubble surface. Thürey et al. coupled a particle system with a two-dimensional shallow-water model to simulate the motion of bubbles and water bodies and calculated the flow field around bubbles with the experimental curl potential function [15]. Cleary et al. [16] used the gas-diffusion equation to simulate nucleation bubbles generated by gas dissolution: bubbles were modeled as discrete spheres with fixed shapes, the coupling between gas material and liquid was realized by a drag force, and the influence of the gas on the liquid was ignored. In 2012, Shao et al. [17] proposed a gas-material generation model considering gas concentration and solid–liquid velocity difference: a gas particle on the surface of a rigid body is regarded as a virtual nucleation point, absorbing gas from the surrounding liquid until it becomes observable foam. In 2015, Ando et al. proposed a stream-function solver to calculate complex bubble formation without explicitly solving the gas phase [18]. Aanjaneya [19] developed an efficient solver to accelerate fluid–structure coupling and improve calculation efficiency. In 2020, Ishida et al. [20] took into account the thickness of the gas-particle film and realistically simulated the deformation of a single foam bubble as well as multi-foam collapse and fusion. Liu Sinuo et al. [21] suggested a multi-scale method for simulating gas in liquid under a unified particle framework, modeling gas particles of different radii and their coupling with liquid particles, which avoided the instability caused by random initialization. Luan et al. [22] presented a novel velocity-transport technique for two individual particle types based on the affine particle-in-cell (APIC) method to solve the instability caused by fluid-implicit particles, effectively realizing two-phase fluid simulation.

Minor errors and incompressibility in surface tracking inevitably lead to volume changes in liquid simulation. Regarding volume correction, Kim et al. [11] first proposed using divergence sources to recover the lost volume. Particle-based models introduced direct particle position and density corrections [23–25]. In the context of mesh-based surface tracking, Langlois et al. used scalar fields to track the identity and volume of bubbles so that their volumes could be used to synthesize physically based bubble sounds [26]; however, that method did not describe how to redistribute volume, particularly when multiple topological changes coincide in nearby bubbles. Chen et al. [27] proposed an extended-cut grid method to deal with liquid structures with lower surface tension than grid elements.

Targeting the high computational cost of bubbles in water, we propose a bubble simulation method which simplifies the viscosity term in the motion equation, uses marker and cell (MAC) mesh to capture bubbles and liquids, and strives to simulate bubbles in water as accurately as possible. At the same time, by interpolating FLIP particles, implicit incompressible SPH particles are generated for secondary pressure projection to correct the bubble volume.

#### **3. Bubble Modeling**

#### *3.1. Simplified Bubble Model*

Gases in liquids usually occur at different sizes, and visible air masses are regarded as large-size gas material. When modeling large-size gas bodies, the fluid is usually treated as a gas–liquid two-phase fluid, and a complex physical model including buoyancy and resistance is established. Such a model can vividly describe the shape change of a bubble, but the calculation process is very time-consuming.

It is well known that air is much lighter than water and cannot usually transfer much momentum to water, but it still has the constraint of incompressibility: bubbles in water retain their volume to a large extent. Therefore, we set the mass of the bubble to be zero, so its momentum can be ignored, and only the net flow into and out of each bubble is considered. Thus, a simplified model that ignores the velocity change inside the bubble is established, saving computational costs. The schematic diagram of the simulation area is shown in Figure 1.

**Figure 1.** Schematic diagram of the simulation area.

To maintain the volume of each bubble, for the *i*th bubble, we have

$$\iint\limits\_{\partial\Omega\_{A\_i}} \mathbf{u}\bullet \mathbf{n}\, dA = 0 \tag{1}$$

where Ω*Ai* is the *i*th bubble region, *∂*Ω*Ai* is its surface, and **u** is the velocity of the incompressible liquid.

Formula (1) can be split into liquid and solid contributions:

$$\begin{cases} \mathbf{B}\_i \mathbf{u} = \sum\limits\_{\Omega\_L \cap\, \partial\Omega\_{A\_i}} \mathbf{u}\bullet \mathbf{n}\_f\, A\_f \\\\ b\_{S\_i} = \sum\limits\_{\Omega\_S \cap\, \partial\Omega\_{A\_i}} \mathbf{u}\_S \bullet \mathbf{n}\_f\, A\_f \end{cases} \tag{2}$$

where Ω*L* is the liquid region, Ω*S* is the solid region, and **B***<sup>i</sup>* is a row vector representing the discretization of the *i*th bubble constraint, summing the net flow through the liquid faces incident on the bubble. Similarly, the solid contribution is denoted *bSi* ; **n***<sup>f</sup>* is the unit face normal oriented out of the bubble region, and A*<sup>f</sup>* is the area of the associated face.
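To make the discretization concrete, the following minimal NumPy sketch accumulates the net flux **B***i***u** of Equation (2) over the liquid-facing faces of one bubble's cells on a 2D staggered grid; the paper works in 3D, so this is an illustrative reduction with assumed array layouts.

```python
import numpy as np

def bubble_net_flux(u, v, bubble, liquid, dx):
    """Sum u . n_f * A_f over faces between one bubble's cells and liquid cells.
    u: (nx+1, ny) x-face velocities; v: (nx, ny+1) y-face velocities;
    bubble, liquid: boolean cell masks; dx: cell size (face 'area' in 2D)."""
    flux = 0.0
    nx, ny = bubble.shape
    for i in range(nx):
        for j in range(ny):
            if not bubble[i, j]:
                continue
            if i > 0 and liquid[i - 1, j]:
                flux += -u[i, j] * dx        # outward normal points in -x
            if i < nx - 1 and liquid[i + 1, j]:
                flux += u[i + 1, j] * dx     # outward normal points in +x
            if j > 0 and liquid[i, j - 1]:
                flux += -v[i, j] * dx
            if j < ny - 1 and liquid[i, j + 1]:
                flux += v[i, j + 1] * dx
    return flux  # the pressure solve drives this toward -b_Si

bubble = np.zeros((4, 4), bool); bubble[1:3, 1:3] = True
liquid = ~bubble
u = np.ones((5, 4)); v = np.zeros((4, 5))
print(bubble_net_flux(u, v, bubble, liquid, dx=1.0))  # 0.0 for uniform flow
```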

#### *3.2. Solution of Motion Equation*

In this paper, the fluid momentum equation under incompressible conditions is considered

$$\frac{Du}{Dt} = \mathbf{g} - \frac{1}{\rho} \nabla \mathbf{p} \tag{3}$$

Here *ρ* is the liquid density and **p** is the liquid pressure.

All bubble constraints are represented by a matrix **B**, and the constraint conditions are associated with the pressure term; the partial differential equation is obtained as follows:

$$\begin{cases} \rho \frac{\partial \mathbf{u}}{\partial t} = -\nabla \mathbf{p} - \mathbf{B}^T \boldsymbol{\lambda}, \\\\ \nabla \bullet \mathbf{u} = 0, \\\\ \mathbf{B} \mathbf{u} = -\mathbf{b}\_S \end{cases} \tag{4}$$

where *λ* is a Lagrange multiplier assigned to each bubble.

When the FLIP method is used to solve the problem, the pressure and velocity of the FLIP particles are first interpolated onto the grid and solved there; the increments of these physical attributes are then interpolated back to the particles to update the particle attributes.

These are directly discretized on the standard MAC grid. As shown in Figure 2, the MAC grid is a staggered grid divided along the spatial coordinates. The grid and the liquid are independent, and different physical quantities are stored at different grid positions: each cell center stores the fluid pressure value, and the cell faces store the components of the fluid velocity field in a staggered arrangement. The tracking of the liquid interface is shown in Figure 3. Cells containing liquid are marked L, cells on a solid boundary are marked S, and air cells are marked A; together, these markers describe the shape of the liquid surface. During the computation, velocity and pressure are the main flow-field variables, and the pressure and velocity fields of the fluid motion are obtained by solving the continuity and momentum equations for a viscous incompressible fluid.

**Figure 2.** MAC mesh structure.


**Figure 3.** Marker grid 2D example.
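To make the storage layout concrete, the following minimal NumPy sketch allocates a MAC grid in the arrangement described above (pressure at cell centers, velocity components on cell faces, and per-cell L/S/A markers as in Figure 3). The class and field names are illustrative assumptions, not the paper's implementation.

```python
# Hedged sketch of the staggered MAC layout; grid dimensions are illustrative.
import numpy as np

class MACGrid:
    def __init__(self, nx, ny, nz):
        self.p = np.zeros((nx, ny, nz))        # pressure at cell centers
        self.u = np.zeros((nx + 1, ny, nz))    # x-velocity on x-faces
        self.v = np.zeros((nx, ny + 1, nz))    # y-velocity on y-faces
        self.w = np.zeros((nx, ny, nz + 1))    # z-velocity on z-faces
        # Marker per cell: 'L' liquid, 'S' solid, 'A' air/bubble (Figure 3)
        self.marker = np.full((nx, ny, nz), 'A')
```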

#### *3.3. Viscous Fluid Simulation*

The viscous force reduces the bubble velocity and has a non-negligible effect on the bubble shape. After solving the linear system for **p** and *λ*, the influence of the viscosity term on the velocity is computed.

The momentum equation of a viscous liquid is as follows:

$$\frac{D\mathbf{u}}{Dt} = -\frac{1}{\rho}\nabla \mathbf{p} + \frac{1}{\rho}\nabla \cdot \left(\mu(\nabla \mathbf{u} + \nabla \mathbf{u}^T)\right) \tag{5}$$

Generally, the simulation of bubbles in highly viscous liquids is not considered. To reduce the computational complexity, the above formula is simplified in this paper. For simulation scenarios with constant viscosity, that is, when *μ* is constant, the viscosity coefficient can be pulled out of the divergence, and Formula (5) becomes

$$\begin{aligned} \frac{D\mathbf{u}}{Dt} &= -\frac{1}{\rho}\nabla \mathbf{p} + \frac{\mu}{\rho}\nabla \cdot (\nabla \mathbf{u} + \nabla \mathbf{u}^T) \\ &= -\frac{1}{\rho}\nabla \mathbf{p} + \frac{\mu}{\rho}\left[\nabla \cdot \nabla \mathbf{u} + \nabla \cdot (\nabla \mathbf{u}^T)\right] \end{aligned} \tag{6}$$

The results are as follows:

$$\frac{D\mathbf{u}}{Dt} = -\frac{1}{\rho}\nabla \mathbf{p} + \frac{\mu}{\rho}\left[\nabla \cdot \nabla \mathbf{u} + \nabla(\nabla \cdot \mathbf{u})\right] \tag{7}$$

Since the liquid is incompressible, ∇·**u** = 0. Finally, substituting the kinematic viscosity *ν* = *μ*/*ρ*, the viscosity term reduces to *ν*∇<sup>2</sup>**u**.

Either an explicit or an implicit method can be used to solve the viscosity term on the grid. If the forward Euler method is applied directly to the viscosity term at the current velocity **u**, the change of the velocity field can be computed directly at each time step, using the central difference as the difference scheme for the Laplace operator:

$$\begin{aligned} \nabla^2 \mathbf{u}^n = L(\mathbf{u}^n) &= \frac{u_{i+1,j,k} - 2u_{i,j,k} + u_{i-1,j,k}}{\Delta x^2}\,\mathbf{i} \\ &+ \frac{u_{i,j+1,k} - 2u_{i,j,k} + u_{i,j-1,k}}{\Delta y^2}\,\mathbf{j} \\ &+ \frac{u_{i,j,k+1} - 2u_{i,j,k} + u_{i,j,k-1}}{\Delta z^2}\,\mathbf{k} \end{aligned} \tag{8}$$

The stability conditions are as follows:

$$\frac{12\,\Delta t\,\nu}{\Delta x^2} < 1$$

It is easy to see that this method is strictly limited by the step size and produces an unstable diffusion solution when the time step is too large.

In this paper, the backward Euler iterative equation is used to obtain:

$$\mathbf{u}^{n+1} = \mathbf{u}^n + \Delta t\, \nu \nabla^2 \mathbf{u}^{n+1} \tag{9}$$

Equation (9) is a linear system that is solved in the same way as the pressure term. The conjugate gradient method is used to solve the linear equations, and bubble motion in liquids of different viscosities can be simulated by supplying different viscosity coefficients.

The pressure solve enforces **p** = 0 on the free surface and **u**·**n** = **u**<sub>S</sub>·**n** at solid boundaries; the no-slip boundary condition is used when solving the viscosity term.
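As a concrete illustration of this implicit step, the following sketch solves the backward Euler system of Equation (9) for one velocity component with the conjugate gradient method, as described above. The 1D grid, the homogeneous no-slip boundary treatment, and all names are simplifying assumptions for illustration only.

```python
# Hedged sketch: backward Euler viscosity step (Eq. 9) solved with CG.
import numpy as np
import scipy.sparse as sp
from scipy.sparse.linalg import cg

def viscosity_step(u, dt, nu, dx):
    """Solve (I - dt*nu*L) u_new = u for one velocity component."""
    n = u.shape[0]
    # Central-difference Laplacian; boundary rows see zero (no-slip) values.
    lap = sp.diags([1.0, -2.0, 1.0], [-1, 0, 1], shape=(n, n)) / dx**2
    A = (sp.identity(n) - dt * nu * lap).tocsr()   # SPD backward Euler matrix
    u_new, info = cg(A, u)                          # conjugate gradient solve
    return u_new

u = np.sin(np.linspace(0.0, np.pi, 64))             # a smooth initial profile
u = viscosity_step(u, dt=1e-2, nu=0.5, dx=1.0 / 63) # one diffused step
```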

#### **4. Bubble Tracking**

Having built the physical model for simulating bubbles, reducing bubble volume drift is the next significant difficulty. Due to the instability of the bubble shape, the complexity of the motion, and the uncertainty of a moving object's position, correctly tracking bubbles is extremely complex.

Previous bubble simulation methods mostly used the level set or particle methods to directly track liquid and bubble regions. This paper extends the basic FLIP method by simply labeling any non-solid, non-liquid region as a bubble. However, it is well known that cumulative numerical advection errors can cause liquid FLIP particles to separate or aggregate over time, resulting in erroneous volume changes and occasionally creating false gaps or voids. Since we do not explicitly track the geometry of the air, liquid volume drift destructively modifies the implied bubble volume, while artificial voids produce false bubbles that begin to rise. Precise tracking of the bubble material can prevent voids and reduce volume changes to some extent, but it adds additional overhead. To solve these problems, we track bubbles implicitly by adding a bubble ID attribute to each FLIP particle.

We used the old bubble ID stored on the FLIP particles in the previous step and the bubble ID assigned to the new bubble regions to build a bubble ID map from the previous time step to the next time step. As shown in Figure 4, the particle was initially assigned its adjacent bubble ID, and after advection, the particle ID was used to map the old bubble ID to the new bubble ID. The mapping forms a bipartite graph in which uppercase letters represent the old bubble regions, lowercase letters represent the new bubble regions in the next step, and edges indicate whether the bubbles are simply advected or have undergone more complex mergers and splits.

**Figure 4.** Bubble ID mapping legend. (**a**) old bubble regions; (**b**) new bubble regions.

Using this mapping, bubbles generated without physical cause correspond to new bubble ID nodes with no incoming edges. They can be collapsed by applying a negative divergence. The volume change of a bubble is related to its divergence:

$$\int_{\Omega_B} \nabla \cdot \mathbf{u} \, dV \approx \frac{v_B^{n+1} - v_B^n}{\Delta t} \tag{10}$$

Here, bubble collapse is driven by setting *v*<sub>B</sub><sup>*n*+1</sup> = 0.
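A minimal sketch of this implicit bubble tracking is given below: each FLIP particle carries the ID of its adjacent bubble region before and after advection, the resulting bipartite edges identify spurious new regions with no incoming edge, and those regions are collapsed via Equation (10) with *v*<sub>B</sub><sup>*n*+1</sup> = 0. The data layout and all names are illustrative assumptions, not the paper's code.

```python
# Hedged sketch of bubble ID mapping between time steps.
from collections import defaultdict

def build_bubble_map(old_ids, new_ids):
    """old_ids[p] / new_ids[p]: bubble region adjacent to particle p before
    and after advection. Returns the edges of the bipartite ID graph."""
    edges = defaultdict(set)
    for old, new in zip(old_ids, new_ids):
        if old is not None and new is not None:
            edges[old].add(new)
    return edges

def spurious_bubbles(all_new_ids, edges):
    """New regions with no incoming edge appeared 'without reason'; they are
    collapsed by driving Eq. (10) with v_B^{n+1} = 0 (negative divergence)."""
    mapped = set().union(*edges.values()) if edges else set()
    return set(all_new_ids) - mapped

# Example: 'A' advects to 'a', 'B' splits into 'b' and 'c'; 'd' has no source.
edges = build_bubble_map(['A', 'B', 'B'], ['a', 'b', 'c'])
print(spurious_bubbles(['a', 'b', 'c', 'd'], edges))   # {'d'}
```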

#### **5. Correction for Bubble Volume**

#### *5.1. Generation of SPH Particles*

Although solving the viscosity term separately after the pressure projection is a standard simplification, it undoubtedly increases the divergence and the volume drift. Therefore, a second pressure projection is required before the advection of the velocity field continues.

Instead of directly repeating the previous pressure projection step on the mesh, the volume correction is realized by combining FLIP particles with implicit incompressible SPH (IISPH) particles to solve the Poisson equation with a constant density as the source term.

Eight FLIP particles are interpolated in each time step to generate one SPH particle, and a kernel function is used to interpolate the FLIP particle velocities **u** to initialize the SPH particle velocity:

$$\mathbf{u}_p = \frac{\sum \mathbf{u} W}{\sum W} \tag{11}$$

The kernel function is *W* = *W*(‖*x*<sub>*i*</sub> − *x*<sub>*j*</sub>‖, 2*h*), where *x*<sub>*i*</sub> and *x*<sub>*j*</sub> are the positions of the two particles and *h* is the spacing between SPH particles.

The pressure Poisson equation is as follows:

$$\nabla^2 \mathbf{p} = \frac{\rho_0 - \rho^{n+1}}{\Delta t^2} \tag{12}$$

After discretization, solving for **p** updates **u**<sub>*p*</sub> and corrects the density change caused by divergence back to the rest density *ρ*<sub>0</sub>.

Next, we interpolate the SPH particle velocities back to the FLIP particles:

$$\mathbf{u}^* = \mathbf{u} + \frac{\sum \left[\mathbf{u}_p(t + \Delta t) - \mathbf{u}_p(t)\right] W}{\sum W} \tag{13}$$
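The kernel-weighted transfers of Equations (11) and (13) can be sketched as follows. The cubic spline kernel matches the 2*h* support used for the FLIP–SPH velocity interpolation below, while the function and variable names are illustrative assumptions.

```python
# Hedged sketch of Eqs. (11) and (13): kernel-weighted velocity transfer.
import numpy as np

def cubic_spline(r, h):
    """Standard 3D cubic spline kernel with support radius h."""
    q = r / h
    sigma = 8.0 / (np.pi * h**3)                       # 3D normalization
    w = np.where(q < 0.5, 6 * (q**3 - q**2) + 1,
         np.where(q < 1.0, 2 * (1 - q)**3, 0.0))
    return sigma * w

def interpolate_velocity(x_target, x_src, u_src, h):
    """Eq. (11): u_p = sum(u W) / sum(W) over nearby source particles."""
    r = np.linalg.norm(x_src - x_target, axis=1)
    w = cubic_spline(r, 2 * h)                          # support radius 2h
    return (u_src * w[:, None]).sum(axis=0) / max(w.sum(), 1e-12)
```

The same weighted average, applied to the SPH velocity increments instead of the velocities themselves, realizes the FLIP-style update of Equation (13).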

#### *5.2. Hybrid Particle Implementation of Flip*

SPH is usually affected by numerical diffusion, especially in low-resolution simulations, while FLIP particles help to maintain vorticity and add detail to the SPH particles. To keep the cost of coupling the two particle sets low, this paper uses a combination of Z-index sorting and compact hashing to accelerate the neighborhood search. This section briefly introduces how the neighborhood search is optimized in our framework and how the simulation efficiency and quality are improved by selecting an appropriate acceleration structure.

All FLIP and SPH particles are stored in the same uniform mesh with a cell side length of *r* = 2*h*, because IISPH uses the same cubic spline kernel with support 2*h* for the velocity interpolation between the FLIP and SPH particles. Before the pressure computation, the adjacent FLIP and SPH particles are queried for each SPH particle.

The efficiency of the neighborhood search contributes significantly to the computational cost of SPH particle simulation; optimizing and accelerating it is therefore an essential research topic for neighborhood search methods.

Generally, a method based on a regular grid is used to accelerate the search. As shown in Figure 5, the current search radius is *h*, and the side length of a standard grid cell is *r*. The neighbors of a specified particle *i* can then be found by searching $\left(\left\lfloor \frac{2h}{r} \right\rfloor + 1\right)^3$ cells. The advantage of the regular grid is that it is simple to implement, and the time complexity of the neighborhood search is constant in the ideal case.

**Figure 5.** Schematic diagram of particle search based on Grid.
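A minimal sketch of such a regular-grid neighborhood search follows: particles are binned into cells of side *r*, and a query around particle *i* visits only the cells overlapping the search radius *h*. All names are illustrative; the paper's actual implementation combines Z-index sorting with compact hashing.

```python
# Hedged sketch of uniform-grid neighbor search (Figure 5).
import numpy as np
from collections import defaultdict

def build_grid(positions, r):
    """Bin particle indices into cells of side length r."""
    grid = defaultdict(list)
    for idx, p in enumerate(positions):
        grid[tuple((p // r).astype(int))].append(idx)
    return grid

def neighbors(positions, grid, i, h, r):
    """Return indices of particles within distance h of particle i."""
    reach = int(np.ceil(h / r))                      # cells to scan per axis
    ci = (positions[i] // r).astype(int)
    out = []
    for dx in range(-reach, reach + 1):
        for dy in range(-reach, reach + 1):
            for dz in range(-reach, reach + 1):
                cell = (ci[0] + dx, ci[1] + dy, ci[2] + dz)
                for j in grid.get(cell, []):
                    if j != i and np.linalg.norm(positions[j] - positions[i]) <= h:
                        out.append(j)
    return out
```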

#### **6. Bubble and Rendering of Water Environment**

To capture the effect that a bubble in water is more transparent at its center than at its rim, we can examine the angle between the line of sight and the normal vector of the bubble patch.

First, the geometry is processed in the vertex shader, where the vertex coordinates and normal vectors are transformed by the world, view, and projection matrices without any special operation. The processed data are then interpolated into the fragment shader. There, the lighting calculation is performed first, and the Blinn–Phong model is used to compute the diffuse and specular reflection colors of the bubbles. The alpha value for blending the pixels with the background is then determined. To achieve the effect that the transparency in the middle of the bubble is higher than at its rim, we first calculate the angle between the line of sight and the normal vector:

$$\cos \theta = |\mathbf{n}_c \cdot \boldsymbol{\nu}| \tag{14}$$

Strictly speaking, what is calculated here is the cosine of the included angle, but this value is kept as-is for the convenience of computing the alpha value later. In Formula (14), **n**<sub>c</sub> is the normal vector of the pixel in the camera coordinate system, and *ν* is the line-of-sight direction in the camera coordinate system (i.e., the vector (0, 0, 1) pointing into the screen).

As shown in Figure 6, the opacity of the bubble is related to the angle between its normal vector and the line of sight; only the opacity of front-facing fragments is calculated here. The larger the absolute cosine of the included angle (i.e., the more directly the surface faces the viewer), the lower the pixel's opacity; conversely, the smaller the absolute cosine, the higher the opacity. Therefore, we set the *Alpha* value of the pixel to:

$$Alpha = 1 - \cos\theta \tag{15}$$

to achieve the desired effect.

**Figure 6.** Calculation Schematic.
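For clarity, the alpha computation of Equations (14) and (15) can be sketched outside the shader as follows; NumPy is used here purely for illustration, whereas the paper implements this step in the fragment shader in Cg.

```python
# Hedged sketch of Eqs. (14)-(15): view-angle-dependent bubble opacity.
import numpy as np

def bubble_alpha(n_c):
    """n_c: camera-space surface normal of the fragment."""
    v = np.array([0.0, 0.0, 1.0])              # line of sight in camera space
    n = n_c / np.linalg.norm(n_c)
    cos_theta = abs(np.dot(n, v))              # Eq. (14)
    return 1.0 - cos_theta                     # Eq. (15): clear center, opaque rim

print(bubble_alpha(np.array([0.0, 0.0, 1.0])))   # center of bubble: alpha ~ 0
print(bubble_alpha(np.array([1.0, 0.0, 0.0])))   # rim of bubble: alpha ~ 1
```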

The rendering of the water body is similar to that of the bubbles, except that there is no effect of the middle being more transparent than the rim. In the rendering, a simple Blinn–Phong model computes the lighting, and alpha blending is used for compositing.

#### **7. Results**

The system environment was a Windows 10 (64-bit) operating system. The C++ programming language and the Cg shading language were used for programming on the Houdini 3D computer graphics platform, supplemented by Autodesk 3ds Max for modeling, to achieve the simulation of complex bubble motion, namely the dynamic tracking simulation of bubbles. The hardware platform was an Intel Core i7-8550U CPU @ 1.80 GHz with 8 GB RAM, and the graphics card was an NVIDIA GeForce MX150.

#### *7.1. Experimental Design*

The experiments in this paper were designed around three aspects: bubble modeling, the underwater bubble motion process, and the large-size bubble effect, and were divided into five groups. The first group of experiments shows the effect of the constructed bubble model. The second group shows the bubble generation process in a straw. The third group shows the bubble movement process after the viscosity term is added. The fourth group shows the bubble effect under different viscosity coefficients. The fifth group shows the generation of large-scale bubbles. The algorithm simulated all five groups of experiments, and the results were compared with real bubbles or with current mainstream bubble simulation methods.

Experiment 1. Figure 7 shows the bubble generation process at the bottom of a container. The initial shape of the bubble is spherical because the Lagrange multiplier assigned when modeling the bubble is equivalent to a constant pressure value inside the bubble, and the initial shape is the result of the balance between the internal gas pressure and the surrounding liquid pressure. The fluid viscosity and inertia deform the bubble during the subsequent upward motion.

**Figure 7.** Bubble generation effect of the method in this paper.

Figure 7 shows the effects of upward motion at different times after bubble generation. Figure 7a shows the bubbles generated after 80 ms, Figure 7b reveals the bubbles generated after 120 ms, Figure 7c displays the bubbles generated after 160 ms, and Figure 7d indicates the bubbles generated after 200 ms. The objects similar to bubbles in Figure 7a–c below are bubble reflection images, which gradually decrease as the bubbles rise. Due to the constant change of internal and external pressure, the bubble no longer maintains a spherical shape and exhibits a certain degree of deformation when it rises continuously.

Experiment 2. Figure 8a shows the bubble generation process in a straw, captured by mobile phone video. Figure 8b shows the bubble in the water generated by the algorithm in this paper.

**Figure 8.** Comparison of the effectiveness of the bubble model. (**a**) real picture; (**b**) the method in this paper.

The bubble in the water appears very bright because it is surrounded by a ring of total internal reflection (TIR) with perfect reflectance. To render bubbles in water correctly, the bubble geometry should be assigned a higher priority than the water, with the index of refraction inside the bubble set to that of air, so that the refraction of light at the different boundaries between air, liquid, and glass is reproduced correctly. The experiment demonstrates that our method produces a visually credible bubble effect in water.

Experiment 3. Figure 9 compares a picture of a real bubble with the bubble motion of reference [28], reference [23], and this paper. As shown in Figure 9a, the overall shape of a real-world bubble in a low-viscosity liquid is elliptical, and the bottom surface is concave inward, forming a typical bubble jet. As shown in Figure 9b, reference [28] solves the typical Euler equation without considering the viscosity term, so the deformation of a bubble in motion cannot be expressed correctly. As shown in Figure 9c, reference [23] adopts an extended particle level set simulation method, which can show complex multiphase flow scenes including bubbles, but it also aggravates the volume loss of the original particle level set method. Figure 9d shows the simplified incompressible viscous fluid model proposed in this paper, which achieves a more realistic bubble motion than the simulation results in references [23,28].

**Figure 9.** Comparison of the bubble movement effect of real picture with the three methods. (**a**) real picture, (**b**) reference [28], (**c**) reference [23], (**d**) this paper.

Experiment 4. Figure 10 shows the bubble effect under different dynamic viscosity coefficients. The viscous force of the liquid hinders the upward movement of the bubble, and the bubble transforms from a spherical shape to an ellipsoidal one. Apart from the jet generated on the lower surface by inertia and viscous force, the surface undergoes only slight oscillating deformation. It can be observed that the larger the dynamic viscosity coefficient, the greater the surface deformation.

**Figure 10.** Comparison of different viscosity coefficients. (**b**) μ = 1.2.

There are more complex, large-scale bubble movements in natural scenes, such as boiling water, a propeller rotating at high speed in water, and gas generated by an underwater explosion. To ensure that such real, complex motions can be reproduced, a set of comparative experiments was designed to verify the correctness of the proposed method.

Experiment 5. Figure 11 compares the generation of large-scale bubbles in water by references [11,29] and by this paper. Reference [11] used the level set method to track the free surface, with divergence as the control variable; although the volume was well maintained, the bubble deformation is unnatural and inconsistent with reality. Reference [29] used SPH particles and a new bubble model, but an abnormal phenomenon arises in which high-velocity water particles pass through the bubble particles during the simulation. In our method, which accounts for the viscosity term, implicit incompressible SPH particles are generated by interpolation to correct the volume. Figure 11c shows the particle-form diagram simulating the generation of large-scale bubbles in a water tank, and Figure 11d is the rendered result, which is more realistic and natural than those in references [11,29].

The method in this paper simulates the movement process of large-sized bubbles with visible shapes in a liquid and achieves an accurate and effective simulation effect. In order to further verify the applicability of our method, we also designed an experiment to simulate the movement process of small-scale gas materials.

Experiment 6. Figure 12 shows the generation of bubbles at the bottom of a wine glass. Bubbles were generated as the liquid was poured. Because they are tiny, dense bubbles, their shape is a standard sphere. By setting priorities for the glass material, the liquid material, and the bubble shader, the refraction of light is reproduced correctly, making the scene more realistic and natural.

**Figure 11.** Comparison of generation effects of large-scale bubbles. (**a**) Reference [11], (**b**) Reference [29], (**c**) This paper (in particle form), (**d**) This paper (after rendering).

**Figure 12.** Small-size bubble renderings.

#### *7.2. Program Efficiency*

Table 1 lists the parameters of each scene implemented by the method in this paper. The number of particles is the total number of FLIP particles. The experimental data show that the proposed method can fully meet the needs of bubble-motion simulation in water under large-scale particle numbers.

**Table 1.** Experimental data for different experimental scenarios.


#### **8. Conclusions**

This paper proposes an optimization method for the bubble model and its solver. The viscosity term is considered in the solution of the momentum equation, simplified to a certain extent, and solved with the backward Euler method after the pressure projection. By interpolating FLIP particles to generate implicit incompressible SPH particles and using a second pressure projection on the new particles to maintain bubble volume, the accuracy of the simulation is improved, and the problem that bubbles in water could not deform correctly is solved. The motion of typical bubbles is reproduced, and the realism of underwater bubble-motion behavior is significantly improved.

The method in this paper can effectively simulate the formation and movement of large-size bubbles in water. However, the simulation of the extremely unstable free surfaces caused by bubbles requires improvement, and the accurate solution of the viscosity term also requires further study. In future work, we will continue to extend this method to simulate richer multiphase flow behavior in a realistic and stable way.

**Author Contributions:** Conceptualization, H.G., and H.W.; methodology, H.G., and H.W.; software, H.G., H.W.; validation, H.G., H.W. and J.Z.; formal analysis, H.G.; investigation, H.G.; resources, Y.T. and J.Z.; data curation, H.G., and H.W.; writing—original draft preparation, H.G., and H.W.; writing—review and editing, Y.T.; visualization, H.G., and H.W.; supervision, Y.T., and J.Z.; project administration, Y.T.; funding acquisition, J.Z. All authors have read and agreed to the published version of the manuscript.

**Funding:** This research was funded by the National Natural Science Foundation of China under Grant No. 61902340.

**Institutional Review Board Statement:** Not applicable.

**Informed Consent Statement:** Not applicable.

**Data Availability Statement:** Not applicable.

**Conflicts of Interest:** The authors declare no conflict of interest.

#### **References**


#### **Detection of COVID-19 Cases Based on Deep Learning with X-ray Images**

**Zhiqiang Wang \*, Ke Zhang and Bingyan Wang**

Beijing Electronic Science & Technology Institute, Department of Cyberspace Security, Beijing 100070, China **\*** Correspondence: wangzq@besti.edu.cn

**Abstract:** Since the outbreak of COVID-19, the coronavirus has posed a massive threat to people's lives. With the development of artificial intelligence technology, identifying key features in medical images through deep learning allows infection cases to be screened quickly and accurately. This paper uses deep-learning-based approaches to classify COVID-19 and normal (healthy) chest X-ray images. To effectively extract medical X-ray image features and improve the detection accuracy of COVID-19 images, this paper extracts the texture features of X-ray images based on the gray level co-occurrence matrix and then realizes feature selection with the principal component analysis (PCA) and t-distributed stochastic neighbor embedding (T-SNE) algorithms. To improve the accuracy of X-ray image detection, this paper designs a COVID-19 X-ray image detection model based on the multi-head self-attention mechanism and a residual neural network, applying the multi-head self-attention mechanism to the bottleneck layer of the residual network. The experimental results show that the multi-head self-attention residual network (MHSA-ResNet) detection model achieves an accuracy of 95.52% and a precision of 96.02%. It has a good detection effect and can realize the three-way classification of COVID-19 pneumonia, common pneumonia, and normal lungs, proving the effectiveness and practicability of the method in this paper.

**Keywords:** COVID-19 image detection; deep learning; attention mechanism; residual neural network

#### **1. Introduction**

As shown in Figure 1, as of 25 March 2022, COVID-19 had caused more than 400 million infections and more than 6 million deaths worldwide, according to WHO statistics. As the virus mutates, COVID-19 and its variant strains will continue to threaten the medical and health security of countries worldwide for a long time to come. Screening positive cases through nucleic acid testing is fast, but it has the disadvantage of low sensitivity. Medical imaging, such as computed tomography and chest X-rays, has the advantage of high accuracy in the diagnosis of COVID-19 patients and is often used to help doctors diagnose COVID-19 cases in clinical practice.

Different from traditional manual detection, artificial intelligence-assisted diagnosis and treatment technology enables the computer to learn image features automatically through deep learning, save the learned model, and then make judgments on the basis of that model. Its learning speed, efficiency, and accuracy exceed those of manual detection, and given sufficiently many accurate training samples, the judgment accuracy of machine learning is extremely high. Therefore, AI-assisted diagnosis and treatment technology can relieve the pressure on medical workers and help them diagnose COVID-19 quickly and accurately. This technology enables early detection, early isolation, and early treatment, thereby reducing the spread of COVID-19 and the risk of infection.

**Citation:** Wang, Z.; Zhang, K.; Wang, B. Detection of COVID-19 Cases Based on Deep Learning with X-ray Images. *Electronics* **2022**, *11*, 3511. https://doi.org/10.3390/ electronics11213511

Academic Editor: Antoni Morell

Received: 1 October 2022 Accepted: 20 October 2022 Published: 28 October 2022


**Figure 1.** Number of COVID-19 cases according to WHO (as of 25 March 2022). This figure is acquired from the WHO database, which is protected by CC BY 4.0.

To better apply AI-assisted diagnosis and treatment technology to the detection of COVID-19, a large number of COVID-19 medical image datasets need to be used to train the neural network model. This training process requires professional doctors to label image data, lung slice data, and infected areas, which is why the number of COVID-19 medical image datasets is small, resulting in low model detection accuracy. As the coronavirus continues to spread around the world, tens of thousands of chest X-rays are generated every day. Now that the epidemic has become a long-term reality, it is of practical significance to use AI-assisted diagnosis and treatment technology to relieve the pressure on medical resources and help diagnose COVID-19.

To solve the problem of the smaller medical image data samples and low detection accuracy, this paper uses the deep learning method to carry out research on COVID-19 image detection. This paper studies the feature extraction and neural network model design in the process of medical image detection, and carries out the experiment and analysis. The main contributions include the following:


• The texture features of lung X-ray images are extracted based on the gray level co-occurrence matrix, and feature selection is realized with the PCA and T-SNE algorithms.

• A COVID-19 X-ray image detection model based on the multi-head self-attention mechanism and a residual neural network (MHSA-ResNet) is designed, and the number of parameters and the operational efficiency of the neural network are evaluated by experiments.

• The effectiveness of the detection scheme and its detection effect on medical images were tested and analyzed. Firstly, we determined the optimal combination of hyperparameters, such as convolution kernel size, activation function, optimization algorithm, learning rate, and number of iterations, through experiments. Secondly, we verified the classification effect on COVID-19, common pneumonia, and normal lung images with a confusion matrix. Then, we conducted an ablation experiment to evaluate the multi-head self-attention mechanism and the GLCM-based feature extraction scheme. Finally, the proposed scheme was compared with the latest COVID-19 image detection methods, including the long short-term memory neural network, the twin (Siamese) neural network, and the convolutional neural network, to verify its effectiveness.

The rest of the paper is organized as follows: Section 2 discusses the relevant work. Section 3 introduces the materials and methods in detail. Section 4 introduces the relevant experiments, and it analyzes and discusses the experiments and results. Finally, we provide the conclusions in Section 5.

#### **2. Related Work**

At present, researchers have carried out substantial work on COVID-19 image detection, including detection based on convolutional neural networks and detection based on attention mechanisms. The following subsections introduce the research status and progress of COVID-19 image detection.

#### *2.1. Image Detection of COVID-19 Based on the Convolutional Neural Network*

Many scholars use the residual network (ResNet) for experiments, completing feature extraction by computing residual values. For instance, the InceptionV3, ResNet, and InceptionResNetV2 neural networks are used in [1] to detect chest X-ray images and determine whether a patient is infected with COVID-19. In [2], a neural network model based on X-ray images was designed to detect COVID-19. The model performs two tasks: the first is medical image classification, separating images of patients with COVID-19 from images of normal lungs; the second is detection, identifying COVID-19 by defining an abnormality score in the image detection task.

Unlike the ResNet-based methods, Ref. [3] used a transfer learning method for training based on the visual geometry group (VGG) network. In [4], a Bayesian convolutional neural network was used for training, with the model's weights adjusted dynamically. The authors collected chest X-ray images of COVID-19-positive patients and normal patients from a publicly available Kaggle dataset and found experimentally that the accuracy of the Bayesian convolutional neural network was 92.9%, a much better detection effect than VGG. They also annotated where the network focuses on heat maps of COVID-19 X-ray lungs and proposed optimization directions based on the heat maps to further improve detection accuracy.

In addition, on the basis of the U-Net neural network, researchers segmented the lesion regions of lung images and improved the encoder in the feature extraction process to achieve a mosaic feature effect, thereby completing training on a convolutional neural network [5–8]. Researchers have also used convolutional neural networks to perform multi-class classification, dividing datasets into normal lung, COVID-19, other pneumonia, and further types [9–12]. Based on U-Net, Ref. [13] adopted a metric-based small-sample learning method to perform semantic segmentation of the pulmonary infection area in COVID-19 cases and designed an algorithm to dynamically fine-tune the weights of U-Net. Compared with traditional methods, the main difference of this algorithm is that it adopts an online learning paradigm instead of U-Net's static supervised learning, which makes it more effective at segmenting the pulmonary infection region in X-ray images; the authors finally achieved 94.8% accuracy on their experimental dataset.

From the above literature, it can be seen that convolutional neural networks are widely used in medical image classification and medical image segmentation. However, the detection accuracy of a convolutional neural network is not very high when the experimental dataset is small; combining the convolutional neural network with image segmentation technology can improve detection accuracy.

#### *2.2. Image Detection of COVID-19 Based on an Attention Mechanism*

The memory aggregation network (MA-Net) is proposed in [14], which improves the model by adding an attention mechanism on the pooling layer based on the residual neural network. The authors also applied the multi-scale attention-guided deep network of soft distance regularization (MAG-SD) to automatically classify COVID-19 case images in lung X-ray images. This method can improve the robustness of the training model, solve the problem of lack of training data, and achieve the effect of enhancing image expression and reducing noise. Ref. [15] constructed deep supervised learning with a self-adaptive auxiliary loss (DSN-SAAL) based on the attention mechanism to identify images of COVID-19 cases. The authors added the attention mechanism module to the convolutional neural network to complete the model design, such as the pooling layer of the convolutional neural network. The attention mechanism is used to strengthen the receptive field area of the image for feature extraction to improve the effect of model feature recognition [16,17].

By combining the attention mechanism with the residual network, researchers can compute the residual values of the residual network through the attention mechanism and obtain high-level image features to improve extraction efficiency. Ref. [18] proposed a dual-sampling attention network model based on ResNet34 and added a dual-sampling strategy to alleviate dataset imbalance. In [19], a deep 3D multi-instance learning neural network based on an attention mechanism was proposed, and attention-based pooling was applied to 3D lung computed tomography (CT) images. Based on the residual neural network, Ref. [20] integrated the attention mechanism and used it to compute the residual values of the residual neural network, so as to identify COVID-19 X-ray images. In the neural network, the residual blocks capture the high-level features of the image and feed them into the attention module; stacking multiple attention modules between the residual blocks can prevent the model from overfitting. On the basis of image segmentation, the excised lesions were fused with the attention module to complete the detection. In [21], a deep neural network based on focal attention was proposed to identify positive COVID-19 cases from labeled lung CT images.

At present, most of the methods of applying deep learning technology to medical image detection are completed based on a convolutional neural network. The convolutional neural network uses a large number of labeled data sets to train the model to achieve the ideal detection effect. Although the technology of convolutional neural network is mature in image classification, image segmentation, and other problems, it still has the problems of losing feature information and low efficiency. In contrast, the neural network with an attention mechanism can make the model better extract the image's relevant features to improve the model's detection accuracy.

#### **3. Materials and Methods**

This section introduces the datasets and methods used to detect and classify lung X-ray medical images. It outlines the datasets, feature processing techniques, feature selection methods, classification methods, experimental environment and evaluation indicators in this study. The detection process of X-ray images of COVID-19 is shown in Figure 2.

In the detection process, it is necessary to convert the digital imaging and communications in medicine (DICOM) files of the X-ray images into portable network graphics (PNG) or other image-format files and then preprocess the obtained images. The preprocessing includes cubic interpolation and image cropping operations to ensure image normalization. In the feature extraction part, this paper adopts the gray level co-occurrence matrix to extract the texture features of the medical images and combines the PCA and T-SNE algorithms for feature dimension reduction and visualization. The neural network is then trained to complete the lung X-ray image detection task.

**Figure 2.** Detection process.

#### *3.1. Dataset*

The datasets used in this paper are open source lung X-ray datasets from GitHub, Kaggle, and other websites. We also use the image datasets collated in the COVID-Net literature [3,22]. The Kaggle dataset was assembled by research institutions such as universities and hospitals from Qatar and elsewhere, who created a database of chest X-ray images for identifying positive COVID-19 cases that is still being updated. The database on GitHub [23] is annotated and processed by radiologists from Tongji Hospital in Wuhan, China, and contains chest X-ray images of COVID-19 patients received by the hospital from January to April 2020. Another open source GitHub dataset [24], from universities in Canada, includes chest X-ray images of COVID-19 and other viral pneumonia.

In this paper, all pneumonia other than COVID-19 is classified as common pneumonia, because detecting other viral pneumonia is not the focus of this detection task. The datasets contain approximately 468 positive COVID-19 cases and 173 cases of other pneumonia. The sorted data are divided into three categories (COVID-19, common pneumonia, and normal lungs) and partitioned into training, test, and validation sets; the partitioning is shown in Table 1. To ensure that the model can be fully trained, the weight of COVID-19 samples was set higher than that of common pneumonia during training to balance the loss [25].

**Table 1.** Partitioning of datasets.


#### *3.2. Image Preprocessing*

X-ray medical images are usually stored in DICOM format. DICOM is an international standard for medical imaging and related information and is currently one of the most widely used medical and healthcare information standards in the world. The data collected in this paper are also stored in DICOM format. When deep learning technology is used for image detection, the DICOM files first need to be transformed into commonly used image-format files. The pseudo-code of the DICOM file transformation is shown in Algorithm 1.
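Algorithm 1 itself is not reproduced here; the following is a hedged sketch of such a conversion using the pydicom and OpenCV libraries. The file paths and the min–max normalization to 8 bits are assumptions for illustration, not the paper's exact procedure.

```python
# Hedged sketch: convert a DICOM file to an 8-bit PNG.
import cv2
import numpy as np
import pydicom

def dicom_to_png(dicom_path, png_path):
    ds = pydicom.dcmread(dicom_path)
    img = ds.pixel_array.astype(np.float32)
    # Rescale the raw intensities to the 8-bit range expected by PNG.
    img = (img - img.min()) / max(img.max() - img.min(), 1e-12) * 255.0
    cv2.imwrite(png_path, img.astype(np.uint8))
```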


After the DICOM file conversion, the open source computer vision library (OpenCV) is used to complete the image preprocessing operations, including image cropping and cubic interpolation. Image cropping removes unwanted information and determines the receptive-field area, which improves the accuracy and speed of processing. Cubic interpolation fits cubic polynomials to obtain a better interpolation function: a weight is computed for each of the 16 gray values surrounding the location to be calculated, and these weighted values are combined to obtain the gray value at that location. The preprocessing improves the speed of feature processing and allows the image features to be handled better. Because most medical images are grayscale, after preprocessing this paper also generates a gray histogram to describe the gray-level distribution in the X-ray images: all pixels in the digital image are counted by gray value, which makes it easy to see where features appear in the X-ray image. The gray histogram is shown in Figure 3; the horizontal axis represents the gray level (set to 256 levels), and the vertical axis represents the number of occurrences of each gray level. Figure 3 visually shows which parts of the medical image have larger gray values, helping to identify the location of lesions.

**Figure 3.** Gray histogram of lung X-ray images.
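A minimal sketch of this preprocessing pipeline with OpenCV is shown below; the crop box, the target size, and the assumption of an 8-bit grayscale input are illustrative, not the paper's exact settings.

```python
# Hedged sketch: crop, bicubic resize, and gray histogram (assumes uint8 input).
import cv2
import numpy as np

def preprocess(img, crop_box=(0, 0, 512, 512), size=(512, 512)):
    x, y, w, h = crop_box
    img = img[y:y + h, x:x + w]                                 # remove unwanted regions
    img = cv2.resize(img, size, interpolation=cv2.INTER_CUBIC)  # bicubic: 16 neighbors
    hist = cv2.calcHist([img], [0], None, [256], [0, 256])      # 256-bin gray histogram
    return img, hist
```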

#### *3.3. Feature Selection*

In this paper, the gray level co-occurrence matrix (GLCM) is used to extract image texture features. Because the GLCM of a lung X-ray image contains a large number of features, a dimensionality reduction algorithm is needed. This paper selects the widely used PCA algorithm and uses the T-SNE algorithm for visualization.

#### 3.3.1. Gray Level Co-Occurrence Matrix (GLCM)

The GLCM is one of the important methods for image feature analysis and extraction. It describes image texture by studying the spatial correlation of gray levels. In this paper, the GLCM is used to extract the texture features of medical images, with the aim of extracting image features from small-scale datasets. The transformation of the gray level co-occurrence matrix is shown in Figure 4.

The gray level co-occurrence matrix adopts four angles (0°, 45°, 90°, and 135°), which represent the different directions of each gray level. Taking Figure 4 as an example, the figure shows the process of converting a grayscale map into a GLCM for the horizontal direction 0° with an offset distance of 1. In the grayscale map on the left, there is one pair of horizontally adjacent numbers "0" and "0" in the upper-left corner, so the value at [0, 0] in the GLCM is 1. There are two pairs of horizontally adjacent numbers "0" and "1", so the value at [0, 1] in the GLCM is 2. By analogy, the whole GLCM can be obtained. Grayscale maps at the other angles of 45°, 90°, and 135° can be selected as needed; the matrices at different angles represent the different directions of each gray level, and different offset distances can be chosen to generate the GLCM as required.

**Figure 4.** Gray matrix conversion diagram. The left side is the gray map, and the right side is matrix.

After obtaining a GLCM, texture feature parameters can be further derived from it; these parameters describe the texture features of the image. Up to 14 kinds of eigenvalues [26] can be computed from the GLCM: energy, entropy, contrast, uniformity, correlation, variance, sum average, sum variance, sum entropy, difference variance, difference average, difference entropy, correlation information measure, and maximum correlation coefficient. This paper did not use all of these features, but selected some of them [27], including the angular second moment, correlation, entropy, contrast, inverse differential moment, and energy. Figure 5 shows a lung X-ray image processed with GLCM features.

**Figure 5.** Example of feature map extracted from the gray level co-occurrence matrix.

Angular second moment (*ASM*): the ASM (also called energy) is the sum of the squares of the elements of the gray level co-occurrence matrix; it describes the uniformity of the image's gray-level distribution and the coarseness of its texture. The calculation formula is:

$$ASM = \sum_{i} \sum_{j} p(i, j)^{2} \tag{1}$$

Correlation: The correlation degree reflects the similarity degree of spatial gray level co-occurrence matrix elements in the row or column direction, and reflects the local gray level correlation of the image.

Entropy: Measures the randomness (intensity distribution) of an image texture. Entropy is a random measure of the amount of information contained in an image, which shows the complexity of an image.

Contrast: Contrast reflects the sharpness of the image and the depth of the furrows in its texture. It measures how the values of the matrix are distributed and how much local variation there is in the image.

Inverse Differential Moment (IDM): The IDM reflects the homogeneity of texture (clarity and regularity) and measures the local changes in image texture.

Energy is the sum of the squares of each element value of the gray level co-occurrence matrix. It is a measure of the stability of the gray level change of the image texture and reflects the uniformity of the gray level distribution of the image and the thickness of the texture.

#### 3.3.2. Principal Component Analysis (PCA)

PCA [28] is a very widely used feature dimensionality reduction algorithm, which can reduce the complexity of data while retaining the original data information. PCA uses linear transformations such as eigenvalue decomposition or singular value decomposition of the covariance matrix to reduce the dimension of the original features. To reduce N-dimensional data to K dimensions, the PCA algorithm performs the following standard steps:

1. Center the data by subtracting the mean of each feature.
2. Compute the covariance matrix of the centered data.
3. Obtain the eigenvalues and eigenvectors of the covariance matrix (by eigenvalue decomposition or SVD).
4. Sort the eigenvectors by decreasing eigenvalue and keep the first K as the projection matrix.
5. Project the data onto the K selected eigenvectors to obtain the reduced representation.
#### 3.3.3. T-Distributed Random Neighborhood Embedding Algorithm (T-SNE)

T-SNE [29] is a widely used technique for dimensionality reduction and feature visualization. The principle of the algorithm is that points that are similar in the high-dimensional space are mapped to nearby points in the low-dimensional space. By using a T-distribution instead of a Gaussian distribution in the low-dimensional space, the gradient formula can be simplified. The distribution probability formula is:

$$q_{ij} = \frac{(1 + \|y_i - y_j\|^2)^{-1}}{\sum_{k \neq l} (1 + \|y_k - y_l\|^2)^{-1}} \tag{2}$$

Let *y*<sub>*i*</sub> and *y*<sub>*j*</sub> be the points mapped from the high-dimensional space to the low-dimensional space, with *k* and *l* indexing reference points; *q*<sub>*ij*</sub> is the probability distribution we define in the low-dimensional space. Because the T-SNE algorithm uses a T-distribution instead of a Gaussian, we set *σ* to 1/√2. The gradient of the T-SNE algorithm can then be calculated from the distribution probabilities:

$$\frac{\partial C}{\partial y_i} = 4 \sum_{j} (p_{ij} - q_{ij})(y_i - y_j)(1 + \|y_i - y_j\|^2)^{-1} \tag{3}$$

In the formula, *C* represents the cost function, *p*<sub>*ij*</sub> is the probability distribution in the high-dimensional space, and *q*<sub>*ij*</sub> is the probability distribution in the low-dimensional space. Gradient descent quickly computes the gradient at each point to reach a good local optimum and complete the T-SNE embedding.

The T-SNE procedure is shown in Algorithm 2. The algorithm calculates the probability distribution of each point in the dataset and compares the computed cross-entropy with the perplexity value to update the gradient descent rate. The perplexity can be understood as the number of effective neighboring points around a data point.
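Algorithm 2 itself is not reproduced here; the following hedged sketch shows an equivalent PCA reduction followed by T-SNE visualization using scikit-learn. The feature matrix `X` (one row of GLCM features per image) and the parameter values are assumptions for illustration.

```python
# Hedged sketch: PCA dimensionality reduction followed by T-SNE embedding.
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE

def reduce_and_embed(X, n_components=50, perplexity=30.0):
    # PCA first, to denoise and speed up the subsequent T-SNE step;
    # n_components must not exceed min(n_samples, n_features).
    X_pca = PCA(n_components=n_components).fit_transform(X)
    # Perplexity roughly sets the number of effective neighbors per point.
    X_2d = TSNE(n_components=2, perplexity=perplexity,
                init="pca").fit_transform(X_pca)
    return X_2d
```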


3.3.4. Design of the Feature Extraction Scheme

During feature extraction, in order to make the features of the medical images prominent, the GLCM was used to extract the texture features of the lung X-ray images. GLCMs in the four directions of 0°, 45°, 90°, and 135° were obtained with a sliding window of size 5 × 5, and their eigenvalues were computed. The feature extraction algorithm based on the GLCM is depicted in Algorithm 3.

#### **Algorithm 3** Pseudo-code of feature extraction scheme

```
Input: Source image Edge of size M×N, offset distance d
Output: GLCMs in four directions, and eigenvalues
// Initialize four T × T GLCMs (T = number of gray levels) to store
// GLCM0°, GLCM45°, GLCM90°, and GLCM135°
for i = 0 to M − 1 do
    for j = 0 to N − 1 do
        if j < N − d then // GLCM0°: neighbor at (i, j + d)
            GLCM0°[Edge[i][j], Edge[i][j + d]] ⇐ GLCM0°[Edge[i][j], Edge[i][j + d]] + 1
        end if
        if i < M − d and j < N − d then // GLCM45°: neighbor at (i + d, j + d)
            GLCM45°[Edge[i][j], Edge[i + d][j + d]] ⇐ GLCM45°[Edge[i][j], Edge[i + d][j + d]] + 1
        end if
        if i < M − d then // GLCM90°: neighbor at (i + d, j)
            GLCM90°[Edge[i][j], Edge[i + d][j]] ⇐ GLCM90°[Edge[i][j], Edge[i + d][j]] + 1
        end if
        if i < M − d and j ≥ d then // GLCM135°: neighbor at (i + d, j − d)
            GLCM135°[Edge[i][j], Edge[i + d][j − d]] ⇐ GLCM135°[Edge[i][j], Edge[i + d][j − d]] + 1
        end if
    end for
end for
// Use the greycoprops package to extract the feature values from the GLCMs
// Extract {'contrast', 'entropy', 'homogeneity', 'energy', 'correlation', 'ASM'}
greycoprops(GLCM)
```
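For comparison, the same GLCM features can be computed with scikit-image, whose graycomatrix/graycoprops functions (spelled greycomatrix/greycoprops in older releases) implement the matrix construction above. The parameter choices below are illustrative, not the paper's exact configuration.

```python
# Hedged usage sketch: GLCM texture features via scikit-image.
import numpy as np
from skimage.feature import graycomatrix, graycoprops

def glcm_features(gray_img):
    """gray_img: 2D uint8 array. Returns selected texture features."""
    # Four directions (0, 45, 90, 135 degrees) at offset distance 1, as in the text.
    glcm = graycomatrix(gray_img, distances=[1],
                        angles=[0, np.pi / 4, np.pi / 2, 3 * np.pi / 4],
                        levels=256, symmetric=False, normed=True)
    props = ["contrast", "homogeneity", "energy", "correlation", "ASM"]
    feats = {p: graycoprops(glcm, p) for p in props}
    # Entropy is absent from older graycoprops versions; compute it directly.
    feats["entropy"] = -np.sum(glcm * np.log2(glcm + 1e-12), axis=(0, 1))
    return feats
```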
In view of the large number of features generated by the gray level co-occurrence matrix, the PCA and T-SNE dimensionality reduction algorithms are used to improve operational efficiency, and the data are visualized to complete the feature extraction scheme. Figure 6 depicts the data visualized after PCA and T-SNE dimensionality reduction: the left figure shows the effect after PCA, and the right figure shows the T-SNE visualization. In the figure, red represents COVID-19 X-ray images, brown represents normal lungs, and purple represents common pneumonia. The X-ray images were thus reduced in dimension and visualized.

**Figure 6.** Renderings after PCA and T-SNE treatment.

#### *3.4. Classification Methods*

#### 3.4.1. Design of Detection Model

This section introduces the neural network model designed in this paper, named MHSA-ResNet after its characteristics, which is used to detect COVID-19 X-ray images. The model is composed of a convolutional layer, pooling layers, bottleneck layers, and fully connected layers. The convolution layers of the residual network are divided into several groups to facilitate the calculation of residual values. In the bottleneck layer, the 3 × 3 convolution kernel is replaced by the multi-head self-attention module to facilitate feature extraction, shorten the training time, and improve the operational efficiency of the neural network. This section also describes the model training process in detail. A sample diagram of the MHSA-ResNet neural network is shown in Figure 7.

Convolution layer: The convolution layer performs convolution operation on the input data and extracts useful information from it. The convolution parameters include the size of the convolution kernel, the step size of each movement of the convolution kernel, and other parameters.

Pooling layer: The pooling layer applies a nonlinear operation over a window of values in the input matrix and returns a single value. Its parameters include the kernel size, i.e., the window size of the pooling operation; in the model, 3 × 3 means the pooling window is three pixels long and three pixels wide. The stride parameter represents how far the pooling window moves at each step; in the model, a stride of 2 means the window moves two pixels after each pooling operation. The type parameter indicates the kind of pooling: maximum pooling retains the maximum value in the pooling window, while average pooling retains the average of all values in the window.

Fully connected layer: the parameter num output represents the number of neurons in the fully connected layer, and the parameter activation function represents its activation function. The ReLU activation function is used in fully connected layers 1 and 2, and the Softmax activation function is used in the last fully connected layer. L2 regularization, i.e., a weight-decay penalty based on the Euclidean norm, is applied in the fully connected layers: it adds a penalty term proportional to the sum of squares of the weights to the objective function so that the weight parameters stay closer to the origin.

**Figure 7.** The neural network model.

Bottleneck layer: The bottleneck layer has the structure of a bottleneck block. Specifically, it uses 1 × 1 convolution block, which is applied in the convolutional neural network and set between two convolution layers of different sizes and different channels.

The training process is as follows: in the first convolution layer, the size of the convolution kernel in this layer is set as 7 × 7, the stride length is 2, and the output image size is 512 × 512. The vector operated through the convolution layer is input into the 3 × 3 maximum pooling layer, and the stride length in the pooling layer is 2. The same convolution operation is used in the second convolutional layer, but the bottleneck structure is added in this layer. The bottleneck structure is composed of three convolution blocks. In the bottleneck layer, a 1 × 1 convolution block is first passed to change the number of channels, and then a 3 × 3 convolution block is used to complete the convolution operation. Finally, a 1 × 1 convolution block is used to restore the number of channels. There are multiple bottleneck blocks in the residual network, which can complete the task of calculating the residual value. The same operation is performed in the third convolutional layer, with the difference in the number of bottleneck blocks and the number of channels. The multi-head self-attention module is added to convolutional layer 4 and convolutional layer 5. The specific model architecture is shown in Table 2.

According to the model listed in Table 2, the residual network is divided into five convolutional layers. In addition to the first convolutional layer, the remaining four convolutional layers complete the operation of calculating the residual value through the bottleneck block, and there is no significant difference in parameters such as convolution kernel, convolution kernel movement step size, and pooling mode. Compared with ResNet50, ResNet152 thickens the third and fourth convolutional layers by increasing the number of bottleneck blocks based on its settings, and the number of output channels and the size of output images do not change.


**Table 2.** Model architecture and comparison.

The MHSA-ResNet model designed in this paper is based on ResNet152, with a multi-head attention module added to the bottleneck block in the last two convolution layers, so as to reduce the amount of computation and improve the training accuracy of the neural network model without changing the number of channels or the size of each convolution layer.

#### 3.4.2. Multi-Head Self-Attention Module Design

The attention mechanism adopted in this paper is the multi-head self-attention mechanism [30], which combines the characteristics of self-attention and multi-head attention; its structure is shown in Figure 8, where ⊕ denotes element-wise matrix addition and ⊗ denotes matrix multiplication. When executed on a 2D feature map, the height and width of the image feature are input to compute the range of the segmented receptive field, giving the relative position codes *h* and *w*. The relative position code on the left of the figure is the matrix computed from *h* and *w*. The attention computation is completed by combining the relative position matrix with the query and key matrices. The multi-head self-attention formula is:

$$MHSA = \mathrm{softmax}(QK^T + QR^T)V \tag{4}$$

*R* in the formula is the relative position matrix obtained by element-wise addition of the width and height encodings of the 2D feature map [31]. Using the relative position matrix to help the image recognize attention features can effectively improve the efficiency of attention, and positional encoding gives the model the ability to capture sequence order. In the multi-head self-attention, each matrix is produced by 1 × 1 pointwise convolution to ensure the recognition accuracy of the attention.

In order to fuse the multi-head self-attention module with the bottleneck layer in the residual network, the MHSA attention module is designed in the bottleneck layer.

In the design of the MHSA module, the feature map must be generated according to the length and width of the image, and the dimension and other features are also used as parameters of the feature map to calculate the position matrix. To combine the MHSA module with the residual network, this paper embeds the MHSA module in the bottleneck layers of the last two layers of the residual neural network. Figure 9 shows the structure of the bottleneck layer.

**Figure 8.** Schematic diagram of the multi-headed self-attention structure.

**Figure 9.** Bottleneck block of MHSA.

The calculation of the residual value passes through the convolution layer, where 1 × 1 convolution blocks modify the number of channels and the shape of the image. In the middle is the attention-based MHSA bottleneck block. After feature extraction through the attention mechanism, a final 1 × 1 convolution block restores the number of channels. The output of the bottleneck block's final operation is *F*(*x*); the part of the input that has not been processed by the convolution blocks is *x*, and their sum is *F*(*x*) + *x*. After the activation function and the regularization operation, we obtain the residual shortcut.
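As an illustration of the residual computation just described, here is a hedged Keras sketch of the MHSA bottleneck block of Figure 9, reusing the `MHSA2D` layer sketched earlier; the channel widths and the placement of batch normalization are our assumptions.

```python
def mhsa_bottleneck(x, mid_channels, heads=4):
    """Sketch of Figure 9: a 1x1 conv reduces channels, MHSA extracts features,
    a 1x1 conv restores channels, then the shortcut yields F(x) + x."""
    shortcut = x
    y = tf.keras.layers.Conv2D(mid_channels, 1, use_bias=False)(x)
    y = tf.keras.layers.BatchNormalization()(y)
    y = tf.keras.layers.ReLU()(y)
    y = MHSA2D(mid_channels, heads=heads)(y)      # attention in place of the 3x3 conv
    y = tf.keras.layers.BatchNormalization()(y)
    y = tf.keras.layers.ReLU()(y)
    y = tf.keras.layers.Conv2D(x.shape[-1], 1, use_bias=False)(y)
    y = tf.keras.layers.BatchNormalization()(y)
    return tf.keras.layers.ReLU()(y + shortcut)   # F(x) + x, then activation
```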

#### 3.4.3. Evaluating the Complexity of the Detection Model

In order to verify the effectiveness of the neural network detection model designed in this paper, this section evaluates the complexity of the neural network through experiments. The complexity of a model is generally evaluated by its computational quantity and parameter quantity. The algorithm complexity is:

$$Time \sim O\left(\sum_{l=1}^{D} M_l^2 \cdot K_l^2 \cdot C_{l-1} \cdot C_l\right) \tag{5}$$

$$Space \sim O\left(\sum_{l=1}^{D} K_l^2 \cdot C_{l-1} \cdot C_l + \sum_{l=1}^{D} M_l^2 \cdot C_l\right) \tag{6}$$

Formula (5) describes the computational time complexity (computational quantity), and Formula (6) describes the computational space complexity (parameter quantity), where *K* represents the kernel size, *C* the number of channels, *D* the number of layers of the network (with *l* indexing the layers), and *M* the feature map size. The time and space complexity calculated by these formulas still need to be assessed with concrete indicators, which include:

Params: the number of parameters of the model, i.e., the total number of trainable parameters in the neural network. It directly determines the size of the model and also affects the memory usage during inference. The unit is generally M (millions).

*FLOPs*: floating-point operations. FLOPs measures the number of floating-point operations and thus the time complexity of a network model, expressed in giga floating-point operations (*GFLOPs*).

In the convolution layer, because the weights in the convolution kernel are shared, the calculation formulas for the number of parameters and for *FLOPs* are:

$$Params = K_{in} \times K_{out} \times C_{in} \times C_{out} + C_{out} \tag{7}$$

$$FLOPs = 2K\_{\rm in} \times K\_{\rm out} \times C\_{\rm in} \times C\_{\rm out} \times H\_{\rm out} \times W\_{\rm out} \tag{8}$$

The number of parameters in a convolution layer is obtained by multiplying the two kernel dimensions by the numbers of input and output channels and then adding one bias per output channel ($C_{out}$). When computing *FLOPs*, the height, width, and number of channels of the output feature map are also needed to complete the calculation.

In the fully connected layer there is no weight sharing, so the *FLOPs* of the layer are on the same order as the number of parameters in the layer. The fully connected layer mainly performs the multiplication and addition operations in each neuron. The calculation formulas for the parameter quantity and *FLOPs* are:

$$Params = N_{in} \times N_{out} + N_{out} \tag{9}$$

$$FLOPs = (2N\_{in} - 1) \times N\_{out} \tag{10}$$
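As a quick check of Equations (7)–(10), the following Python sketch computes both quantities. The worked example uses the first 7 × 7 convolution layer from Section 3.4.1 with its 512 × 512 output; the 3 input channels and 64 output channels are our assumptions (the standard ResNet stem width), not values stated in the text.

```python
def conv_params(k_h, k_w, c_in, c_out):
    # Eq. (7): shared kernel weights plus one bias per output channel.
    return k_h * k_w * c_in * c_out + c_out

def conv_flops(k_h, k_w, c_in, c_out, h_out, w_out):
    # Eq. (8): one multiply and one add per weight at every output position.
    return 2 * k_h * k_w * c_in * c_out * h_out * w_out

def fc_params(n_in, n_out):
    # Eq. (9): full weight matrix plus biases.
    return n_in * n_out + n_out

def fc_flops(n_in, n_out):
    # Eq. (10): n_in multiplies and n_in - 1 adds per output neuron.
    return (2 * n_in - 1) * n_out

print(conv_params(7, 7, 3, 64))           # 9,472 parameters
print(conv_flops(7, 7, 3, 64, 512, 512))  # ~4.93 GFLOPs
```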

Using the above formulas, the number of module parameters of the neural network can be counted with the tensorboard package. Table 3 lists the parameter counts of the neural network designed in this paper.

**Table 3.** Comparison of the number of neural network parameters.


Experiments show that the MHSA-ResNet network designed in this paper reduces the number of parameters to be calculated and increases the operating speed of the network through the attention mechanism. This demonstrates the optimizing effect of the attention mechanism on the neural network and yields higher recognition efficiency on the COVID-19 images in this paper.

#### *3.5. Experimental Situation*

The experimental environment is shown in Table 4: the Windows 10 operating system with an Intel(R) Core(TM) i7-4790 CPU @ 3.6 GHz and 16 GB of RAM. Each part is written in Python, and the neural network model is implemented with the TensorFlow deep learning framework. An NVIDIA GeForce RTX 2070 GPU is used to accelerate the training and operation of the neural network model.

**Table 4.** Experimental environment.


#### *3.6. Evaluation Metrics*

Model performance is tested using performance metrics whose results can be arranged in a table called the confusion matrix, shown in Table 5, which has four entry types (TP, FN, FP, TN).

**Table 5.** Evaluation of classification results.


True positive rate (TPR), also known as sensitivity or recall, is the proportion of actual positive cases that are detected as positive. The calculation formula is:

$$TPR = \frac{TP}{P} = \frac{TP}{TP + FN} \tag{11}$$

False negative rate (FNR) is the proportion of actual positive cases that are detected as negative. It is calculated using the following formula:

$$FNR = \frac{FN}{P} = \frac{FN}{TP + FN} \tag{12}$$

False positive rate (FPR) is the proportion of actual negative cases that are detected as positive. It is calculated by the following formula:

$$FPR = \frac{FP}{N} = \frac{FP}{FP + TN} \tag{13}$$

True negative rate (TNR), also known as specificity, is the proportion of actual negative cases that are detected as negative. It is calculated using the following formula:

$$TNR = \frac{TN}{N} = \frac{TN}{FP + TN} \tag{14}$$

Accuracy refers to the ability of the model to classify samples over the whole dataset, that is, to detect positive samples as positive and negative samples as negative. The calculation formula is:

$$\text{ACC} = \frac{TP + TN}{P + N} = \frac{TP + TN}{TP + FN + FP + TN} \tag{15}$$

Precision refers to the proportion of samples detected as positive that are actually positive. The calculation formula is:

$$PRE = \frac{TP}{TP + FP} \tag{16}$$

Recall refers to the proportion of correctly predicted positive samples among all actual positive samples, where the actual positive samples comprise the correctly predicted positives and the positives incorrectly predicted as negative. The calculation formula is:

$$REC = \frac{TP}{TP + FN} \tag{17}$$

F-measure is the weighted harmonic mean of precision and recall and is closer to the smaller of the two. The general formula, and the *F*1 metric obtained with *a* = 1, are:

$$F_{measure} = \frac{(a^2 + 1) \times PRE \times REC}{a^2 \times PRE + REC} \tag{18}$$

$$F1 = \frac{2 \times PRE \times REC}{PRE + REC}, \quad a = 1 \tag{19}$$

Macro avg is the arithmetic mean of the precision, recall, and *F*1-score of each category in the confusion matrix, computed per category and averaged with equal weight. Micro avg instead builds a global confusion matrix over all categories and computes the metrics from the summed counts; the micro-averaged precision is:

$$P\_{\text{micro}} = \frac{\overline{TP}}{\overline{TP} + \overline{FP}} = \frac{\sum\_{i=1}^{n} TP\_i}{\sum\_{i=1}^{n} TP\_i + \sum\_{i=1}^{n} FP\_i} \tag{20}$$

In micro averaging, categories with many samples dominate those with few samples. The micro averages of recall and *F*1-score are calculated, respectively, as:

$$R\_{micro} = \frac{\overline{TP}}{\overline{TP} + \overline{FN}} = \frac{\sum\_{i=1}^{n} TP\_i}{\sum\_{i=1}^{n} TP\_i + \sum\_{i=1}^{n} FN\_i} \tag{21}$$

$$F\_{micro} = \frac{2 \times P\_{micro} \times R\_{micro}}{P\_{micro} + R\_{micro}} \tag{22}$$

Weighted avg: the proportion of samples in each class relative to the total number of samples across all classes is used as the weight. It can effectively address class imbalance in the training data.
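In practice, all of the per-class and averaged metrics defined in Equations (11)–(22) can be obtained from a confusion matrix in a few lines. The sketch below uses scikit-learn with made-up labels for the three classes introduced later ("0" COVID-19, "1" normal lungs, "2" common pneumonia); the label values are illustrative only.

```python
from sklearn.metrics import classification_report, confusion_matrix

# Hypothetical predictions for a three-class problem.
y_true = [0, 0, 0, 1, 1, 1, 2, 2, 2, 2]
y_pred = [0, 0, 1, 1, 1, 2, 2, 2, 2, 0]

print(confusion_matrix(y_true, y_pred))
# Per-class precision/recall/F1 plus the macro and weighted averages
# (micro averaging reduces to overall accuracy in the multi-class case).
print(classification_report(y_true, y_pred, digits=4))
```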

#### **4. Results and Discussion**

#### *4.1. Comparison Experiment of Different Hyperparameters*

The design of the deep neural network model is the core part of the implementation of the COVID-19 image detection model. Based on the backbone network, this section adjusts the hyperparameters, optimizes the model, and evaluates the model's accuracy on the experimental and validation datasets.

#### 4.1.1. Size of the Convolution Kernel and Pooling Mode

The backbone network used in this paper is an improvement of the residual network, modified only in the bottleneck blocks. Therefore, the kernel sizes and pooling method are consistent with the ResNet settings: [3, 3] kernels in the convolution layers, [1, 1] kernels in the bottleneck layers, and maximum pooling.

#### 4.1.2. Comparative Experiment of Activation Function

In this experiment, the neural network model is compared with three activation functions, and the evaluation results are shown in Table 6. The experiments show that, with the rectified linear unit (ReLU) activation function, the network reaches the highest accuracy of 95.52%. ReLU is the most common activation function in deep learning and alleviates the vanishing-gradient problem, so it is selected for this paper.

**Table 6.** Comparison experiment of different activation functions.


#### 4.1.3. Comparative Experiment of the Optimization Algorithm

In this paper, three optimization algorithms are evaluated through experiments, and the results are shown in Table 7. Momentum enhances the stability and convergence speed of gradient descent and is a method for optimizing gradient descent. Adam and root mean square prop (RMSProp) are adaptive learning rate algorithms: their learning rate is not fixed but adjusts automatically according to the task. The RMSProp algorithm uses an exponentially weighted moving average instead of a sum of squared gradients. The Adam algorithm combines the characteristics of Momentum and RMSProp, essentially RMSProp with momentum; it uses first-order and second-order moment estimates of the gradient to automatically adjust the learning rate of the model parameters. The experiments show that the Adam optimization algorithm reaches an accuracy of 95.52%; compared with Momentum and RMSProp, it achieves a better convergence rate and a more satisfactory learning effect.

**Table 7.** Comparison of different optimization algorithms.


#### 4.1.4. Learning Rate

We adopt the Adam adaptive learning-rate optimization algorithm to adjust the learning rate as needed. In many cases, the default learning rate of 0.001 achieves the desired effect; however, the learning rate should be tuned to the experimental situation. The results for different learning rates with the Adam optimizer are shown in Table 8.

According to the experimental results, the best detection effect, 95.52%, is achieved when the learning rate is set to 0.001. Too large a learning rate yields unsatisfactory results: with a learning rate of 0.1, the gradient cannot converge quickly and the accuracy is only 48.03%. Conversely, setting the learning rate to 0.0001 does not noticeably improve the results.

**Table 8.** Comparison of different learning rates.


#### 4.1.5. Number of Iterations

In this paper, the appropriate number of iterations is determined through comparative experiments, and the performance of the model under different iteration counts is analyzed; the results are shown in Table 9. Because the dataset is not very large, the number of iterations is modest, and the interval between settings in each experiment is 5. The experiments show that the highest accuracy, 95.52%, is obtained with 45 iterations, and the accuracy decreases as the number of iterations increases further.

**Table 9.** Comparison of different numbers of iterations.


#### *4.2. Parameter Settings and Final Experimental Results of Neural Networks*

Through the hyperparameter comparison experiments in the previous section, this paper determines the hyperparameter configuration of the neural network and also proves the effectiveness of adding the multi-head self-attention module. The Adam algorithm is used as the optimizer, the learning rate is set to 0.001, the dropout rate to 0.5, the batch size to 16 images per batch, and the number of iterations to 45, and the multi-head self-attention module is added to the bottleneck layer to determine the final neural network model. The model is trained on the training set and evaluated on the validation set; the resulting accuracy curve and loss curve are shown in Figures 10 and 11.
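For concreteness, the final configuration above corresponds to a training setup along the following lines in TensorFlow/Keras; `build_mhsa_resnet`, `train_ds`, and `val_ds` are hypothetical placeholders for the model constructor and the prepared datasets, not names from the paper.

```python
import tensorflow as tf

model = build_mhsa_resnet(input_shape=(512, 512, 1), num_classes=3)  # hypothetical constructor
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),  # Adam, lr = 0.001
    loss="categorical_crossentropy",                          # cross-entropy (Section 4.2)
    metrics=["accuracy"],
)
# train_ds and val_ds are assumed to be batched with 16 images per batch;
# the dropout rate of 0.5 is assumed to be set inside the model itself.
history = model.fit(train_ds, validation_data=val_ds, epochs=45)
```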

The detection model finally obtained in this paper achieved 95.52% accuracy on COVID-19 images, with a measured loss value of 0.0044.

Cross-entropy, one of the most classical loss functions for classification models, is used in this paper. When the number of training iterations increases beyond 45, the loss value neither decreases nor increases noticeably, which indicates that the model converges well to a relative optimum without overfitting. The loss curve shows that the loss converges to 0.0044 at 45 iterations. Because the dataset is relatively small, the loss does not decrease to a smaller value, but this does not affect the validity of the experiment.

**Figure 10.** Variation of accuracy curve.

**Figure 11.** Change curve of loss value.

#### *4.3. Evaluation of Classification Effect*

Classification effects were examined using 763 chest X-ray images from the test set, comprising 239 COVID-19 images, 256 normal-lung images, and 268 common pneumonia images. We use the confusion matrix to evaluate the classification effect: it contrasts the model's predictions with the true values, from which the accuracy, recall, F1-score, and other metrics can be calculated to evaluate the deep learning model. The confusion matrix is shown in Table 10, and its visualization in Figure 12, where rows represent predictions and columns represent the true labels. Both predictions and true labels have three categories: "0" for COVID-19, "1" for normal lungs, and "2" for common pneumonia. The legend on the right maps quantities to colors, with darker colors for smaller values and lighter colors for larger values.


**Table 10.** Classification of data sets by the confusion matrix.

**Figure 12.** Confusion matrix on the test set.

The F1-score, recall rate, and accuracy index of the model can be calculated from the confusion matrix, and the effectiveness of the model can be effectively evaluated through the above indexes. The calculated values are shown in Table 11.

**Table 11.** Comparison of recall, accuracy, and F1-score in different groups.


The curves of precision and recall versus the number of iterations are shown in Figure 13, where the blue curve represents COVID-19, the orange curve normal lungs, and the green curve common pneumonia. Due to the small sample size of the test set, the curves converge after roughly 20 iterations. From the classification results of the confusion matrix in Figure 12, the macro average, micro average, and weighted average can be calculated; the values are shown in Table 12.

**Figure 13.** Change curve of recall and precision of the classification effect.

**Table 12.** Macro average, micro average, and weight average of each value of confusion matrix.


To evaluate the overall effect of the model, we also use the weighted average as the evaluation condition. According to the statistics in Tables 11 and 12, in the classification process the weighted-average precision reaches 96.02%, the recall 95.47%, and the F1-score 95.93%. For the single COVID-19 class, the precision is 99.12%, the recall 97.38%, and the F1-score 98.32%.

#### *4.4. Ablation Study Results and Analysis*

To verify the effectiveness of each module, this paper conducted ablation experiments on the scheme, covering the multi-head self-attention module, the PCA dimension-reduction algorithm, and the T-SNE dimension-reduction algorithm. The ablation study comprised the following four experiments, and the results are shown in Table 13.


**Table 13.** Comparison results of different feature extraction schemes.

In Experiment 1, ResNet, a residual network without a multi-head self-attention module, was used for the experiment. ResNet also had the structure of a bottleneck layer, and the residual value was calculated through the bottleneck layer. It verifies the effectiveness of the multi-head self-attention module.

In Experiment 2, the PCA dimensionality reduction algorithm was removed from the feature extraction scheme, and the medical images were directly trained in the neural network after passing the gray level co-occurrence matrix.

In Experiment 3, the T-SNE dimensionality reduction algorithm was removed from the feature extraction scheme. After passing the gray level co-occurrence matrix, the medical images were processed by PCA dimensionality reduction and directly entered the neural network for training.

Experiment 4 uses all modules, that is, the method currently used in this paper.

Experimental results show that the neural network with a multi-head self-attention module improves detection accuracy compared with ResNet. In the feature extraction scheme, using the PCA algorithm for dimension reduction significantly improves detection accuracy. Comparing Experiment 4 with Experiment 3, adding T-SNE to the feature extraction scheme not only achieves data visualization but also slightly improves detection accuracy.

#### *4.5. Analysis of the Comparison between the Detection Results and the Latest Detection Methods*

To evaluate the effectiveness of the method adopted in this paper, network models proposed in other papers are used for comparison experiments: the convolutional neural network model COVID-Net, the long short-term memory (LSTM) network with an attention module, and the twin (Siamese) network with an attention mechanism.

In Experiment 1, COVID-Net [3], a model for examining chest X-ray images, was used. Based on the VGG network, this model uses projection-expansion-projection-extension (PEPX) blocks to reduce the dimensionality of the image through the convolutional neural network and finally performs classification through a Softmax layer.

The second experiment used a long short-term memory network with an added attention mechanism [16]. It computes the relationships among the output vectors after the last LSTM state passes through a linear filter, rescales them with Softmax, and obtains the image information by combining the previous output vectors. Combining the attention mechanism with the LSTM enhances the training effect of the LSTM.

Experiment 3 used a twin network with an added attention mechanism [32]. Through the weight-sharing mechanism of twin networks, pairs of data are input into the network for training. In this experiment, the common pneumonia images were removed to avoid data imbalance during training. COVID-19-positive and normal-lung images in the dataset were combined pairwise, a positive and a negative sample were input to the twin network at the same time, and the feature vectors generated by the model were compared.

In this paper, the organized datasets are respectively substituted into three different neural network models for training, and the experimental results are shown in Table 14.


**Table 14.** Comparison results between this paper and the latest detection methods.

Through comparative experiments, it is found that, compared with the other three neural network architectures, the neural network model designed in this paper achieves high performance on all four indexes, with accuracy reaching 95.52%, precision 96.02%, recall 95.47%, and F1-score 95.93%.

#### **5. Conclusions**

In this paper, a feature processing scheme based on a gray level co-occurrence matrix is proposed. The multi-head self-attention mechanism and residual neural network are combined to detect lung X-ray medical images, which has achieved good experimental results and completed the detection of COVID-19 images.

The main work of this paper is summarized as follows:


Although this study has carried out extensive exploration of COVID-19 image detection, there is still room for improvement. In view of the problems encountered in this paper, the following aspects can be addressed:


**Author Contributions:** Conceptualization, Z.W.; methodology, Z.W., K.Z. and B.W.; software, B.W.; validation, Z.W. and K.Z.; formal analysis, B.W.; investigation, K.Z.; resources, K.Z.; data curation, B.W.; writing—original draft preparation, Z.W., K.Z. and B.W.; writing—review and editing, Z.W. and K.Z.; visualization, B.W.; supervision, Z.W.; project administration, Z.W.; funding acquisition, Z.W. and B.W. All authors have read and agreed to the published version of the manuscript.

**Funding:** This work was supported in part by the National Key R&D Program of China (No. 2018YFB0803401), in part by the China Postdoctoral Science Foundation under Grant 2019M650606, and in part by the First-Class Discipline Construction Project of Beijing Electronic Science and Technology Institute (No. 3201012).

**Data Availability Statement:** The dataset is publicly available at https://github.com/lindawangg/COVID-Net (accessed on 2 September 2022), https://github.com/ieee8023/covid-chestxray-dataset (accessed on 27 August 2022), https://data.mendeley.com/datasets/rscbjbr9sj/2/files/f12eaf6d-6023-432f-acc9-80c9d7393433 (accessed on 1 September 2022), and http://openi.nlm.nih.gov/imgs/collections/ChinaSet_AllFiles.zip (accessed on 1 September 2022). The Kaggle dataset is available at https://www.kaggle.com/tawsifurrahman/covid19-radiography-database (accessed on 2 September 2022).

**Conflicts of Interest:** The authors declare no conflict of interest.

#### **Abbreviations**

The following abbreviations are used in this manuscript:


#### **References**


### *Article* **Super-Resolution of Remote Sensing Images for ×4 Resolution without Reference Images**

**Yunhe Li 1,\*, Yi Wang 2, Bo Li <sup>1</sup> and Shaohua Wu <sup>3</sup>**


**\*** Correspondence: liyunhe@zqu.edu.cn

**Abstract:** Sentinel-2 satellites can provide free optical remote-sensing images with a spatial resolution of up to 10 m, but the spatial details provided are not enough for many applications, so it is worth considering improving the spatial resolution of Sentinel-2 images through super-resolution (SR). Currently, the most effective SR models are mainly based on deep learning, especially the generative adversarial network (GAN). Models based on GAN need to be trained on LR–HR image pairs. In this paper, a two-step super-resolution generative adversarial network (TS-SRGAN) model is proposed. The first step is to train degradation models with a GAN: without supervised HR images, only the 10 m resolution images provided by Sentinel-2 are used to generate degraded images in the same domain as the real LR images, from which near-natural LR–HR image pairs are constructed. The second step is to design a super-resolution generative adversarial network with strengthened perceptual features, to enhance the perceptual effects of the generated images. In experiments, the proposed method obtained an average NIQE as low as 2.54 and outperformed state-of-the-art models on two other NR-IQA metrics, BRISQUE and PIQE. The comparison of the visual effects of the generated images also demonstrates the effectiveness of TS-SRGAN.

**Keywords:** remote-sensing image; super-resolution; generative adversarial network

### **1. Introduction**

The applications of satellite remote-sensing images are broad, involving agriculture, environmental protection, land use, urban planning, natural disasters, hydrology, climate, etc. [1]. With the continuous updating of optical instruments and other equipment, the spatial resolution of satellite images is constantly improving. For example, Worldview-3/4 satellites can collect eight bands of multi-spectral data with a ground resolution of 1.2 m [2]. However, Worldview-3/4 data are costly to acquire, and when covering a large area or performing multi-temporal analysis the cost can be restrictive. Therefore, open-access data with acceptable spatial quality can be considered, such as Landsat [3] or Sentinel [4] data. Sentinel-2 updates remote-sensing images of every location in the world for free approximately every five days, and these images are becoming a more and more important resource for applications. Sentinel-2 uses two satellites to achieve global remote-sensing coverage at the equator and provides a multi-resolution layer composed of 13 spectral bands: 10 m resolution images in four RGBN bands, 20 m resolution images in six bands, and 60 m resolution images in the other three bands [4]. The bands with 10 and 20 m resolution are usually used for land cover or water mapping, agriculture, or forestry, whereas the 60 m band is mainly used for water-vapor monitoring [5]. Due to the open data distribution strategy, the 10 m resolution remote-sensing images provided by Sentinel-2 are becoming important resources for some applications. However, such spatial resolution is still slightly insufficient for many applications. In order to make the most of freely available Sentinel-2 images, and to achieve a spatial resolution of about 2 m, it is worth considering post-processing methods that spatially enhance the LR images and recover high-frequency details to generate HR images. To improve the spatial resolution of Sentinel-2 images, some researchers [6–12] fused the data of the several Sentinel-2 bands with 60, 20 and 10 m spatial resolutions to obtain higher-resolution images; this paper, however, focuses on SR directly using the 10 m resolution images.

Yang et al. [13] and Gou et al. [14] studied a supervised SR model based on dictionary learning, and provided an effective solution by using sparse coding technology. Pan et al. [15] applied structural self-similarity and compressed sensing to SR tasks. Zhang et al. [16] and Li et al. [17] adopted several different image representation spaces in SR to achieve higher performance.

Deep learning has been attracting more and more attention in the field of SR [18]. Deep learning does not need to directly map the relationship between HR and LR domains. As long as there are enough training data, a deep learning network in principle can learn very complex non-linear relationships [19]. Among these techniques, the convolutional neural network (CNN) can make better use of the high-order features of the images to create HR images, and can significantly improve the performance of SR imaging [18]. Dong et al. [19] proposed the SRCNN, which has a strong learning ability. It is based on the CNN, and adopts pixel loss to optimize the network. The result was too smooth without consideration of the perceptual quality. Additionally, on this basis, Kim et al. [20] and Zhang et al. [21] introduced residual learning models, Tai et al. [22] introduced recursive learning models, and Hu et al. [23] introduced an attention mechanism to optimize the deep-learning architecture to improve performance, but these models also had the problem of over-smoothing, because they all relied solely on pixel loss to optimize the network.

Goodfellow et al. [24] proposed the GAN, which trains two models at the same time: one called the generator (G) and the other called the discriminator (D). SRGAN [25], proposed by Ledig et al., was a pioneering way to implement SR based on GAN theory, and because of its ability to generate images with rich texture and high quality, the GAN has been widely used in SR. Wang et al. [26] further improved the SRGAN model and proposed ESRGAN, which uses a more complex and denser residual-layer combination in the generator and deletes the batch normalization layer. As SR models based on GAN were gradually applied to the field of satellite remote sensing, Ma et al. [27] proposed the transfer GAN (TGAN) to address the limited quantity and quality of remote-sensing data. Haut et al. [28,29] and Lei et al. [30] designed networks that form LR–HR image pairs by downsampling public remote-sensing images, and tested different network architectures. Targeting the remote-sensing images provided by Sentinel-2, Gong et al. [31] proposed the Enlighten-GAN SR model, which adopts an internal inconsistency loss and a cropping strategy, and achieved good results in gradient similarity measurement (GSM) for the medium-resolution remote-sensing images of Sentinel-2. Sentinel-2 can provide images with a spatial resolution of up to 10 m. In the task of upgrading the resolution from 10 to 2 m, GAN-based SR models encounter a great challenge that mainly comes from the lack of real HR images at 2 m resolution. Some papers [32,33] used the 10 m resolution images of Sentinel-2 and 2 m HR images of Worldview satellites to form LR–HR image pairs for training datasets. Galar et al. [32] proposed an SR model based on the enhanced deep residual network (EDSR), and Salgueiro et al. [33] proposed the RS-ESRGAN model based on ESRGAN. All the proposed models could enhance the 10 m channels of Sentinel-2 to 2 m. However, with the unnatural low–high image pairs consisting of Sentinel-2 and Worldview images, and with other models using BiCubic downsampling to construct LR–HR image pairs [21,25,26,34–37], the track details related to frequency will be lost [38]. To solve this problem, inspired by the blind SR model KernelGAN [39] and the blind image denoising model [40], we explicitly estimated the degradation kernel of the LR–HR image pairs of natural images through a GAN, estimated the distribution of the degraded noise at the same time, and degraded the 10 m resolution

images of Sentinel-2 to construct near-natural LR–HR image datasets. On the basis of these datasets, and with reference to the structures of SRGAN, PatchGAN, and VGG-128, TS-SRGAN was designed to implement SR of Sentinel-2 images from 10 m to 2.5 m resolution.

#### **2. Dataset**

For the convenience of the following analysis, we first present the datasets used in training and testing. The model proposed in this paper is aimed at Sentinel-2 images, so we used the SEN12MS [41] dataset to train and test the models. SEN12MS contains complete multi-spectral information in geocoded images; it includes SAR and multi-spectral images provided by Sentinel-1 and Sentinel-2, and land cover information obtained by the MODIS system. This paper mainly focuses on the 10 m resolution images of the red (B4), green (B3) and blue (B2) bands of the multi-spectral images, namely RGB color images with 10 m resolution. SEN12MS gives Sentinel-2 cloudless images of the regions of interest (ROI) at specified time intervals. SEN12MS divides the images into patches of 256 × 256 pixels spaced 128 pixels apart, so that the overlap rate between adjacent patches is 50%. SEN12MS takes 50% overlap as an ideal compromise between the independence of the patches and the maximum number of samples. The SEN12MS dataset obtained randomly sampled ROIs based on four seeds (1158, 1868, 1970 and 2017), and the distribution of the ROIs is shown in Figure 1.

**Figure 1.** Distribution of regions of interest corresponding to four random seeds.

In this paper, TS-SRGAN used a part of SEN12MS named ROIs1158, which is composed of 56 regions of interest across the globe generated from seed 1158, covering 1 June 2017 to 31 August 2017. ROIs1158 is divided into 56 subsets by region, totaling 40,883 images of 256 × 256 pixels. We randomly selected the subset "ROIs1158\_spring\_106" as the test dataset (ROI\_Te), which contains 784 test images; the remaining 55 subsets, comprising 40,099 images, were used as the source image dataset (ROI\_Src), and ROI\_Src was degraded to generate the LR image dataset (ROI\_LR). The source images **I**src in ROI\_Src were directly used as the HR images **I**HR in training, forming the LR–HR image-pair dataset (ROI\_Tr) with the images **I**LR in the ROI\_LR dataset one by one. This paper compares the performances of recently proposed models, including EDSR8-RGB [32], RCAN [21], and RS-ESRGAN [33], with the traditional BiCubic model [42]. BiCubic directly uses ROI\_Te for interpolation testing without training; RCAN takes the images in ROI\_Src as LR images and generates HR images by BiCubic-interpolating every image to form an LR–HR image-pair dataset; the EDSR8-RGB and RS-ESRGAN models follow the constructions proposed in [32] and [33], respectively, to build datasets based on ROI\_Src.

#### **3. Methods**

#### *3.1. Structure of TS-SRGAN*

We used TS-SRGAN to generate 2.5 m resolution images **I**SR from the 10 m resolution source images **I**src of Sentinel-2 in two stages. In the first stage, KernelGAN was used to estimate the explicit degradation kernel of the **I**src images; then, along with injecting the degraded noise, the source images **I**src were degraded to LR images **I**LR, which were combined with the HR images **I**HR (equivalent to **I**src) to construct LR–HR image pairs (**I**LR, **I**HR). In the second stage, the dataset {(**I**LR, **I**HR)} was used to train the super-resolution generative adversarial network (SR-GAN), which consists of a super-resolution generator (SR-G), a super-resolution discriminator (SR-D), and a super-resolution perceptual-feature extractor (SR-F). TS-SRGAN denotes the Sentinel-2 image SR model proposed in this paper, and its structure is shown in Figure 2.

**Figure 2.** Structure of TS-SRGAN on Sentinel-2 remote-sensing images.

#### *3.2. Degraded Model*

Here, we introduce an image degradation model based on the GAN. The natural pairing mapping between HR and LR images can be approximately understood as the degradation relationship between them, and the degradation process can be expressed as follows:

$$\mathbf{I}\_{LR} = (\mathbf{I}\_{SRC} \ast \mathbf{k}^s) \downarrow\_{s} + \mathbf{n} \tag{1}$$

where **k***<sup>s</sup>* and **n** represent the degradation kernel and degraded noise, respectively, and *s* represents the scaling factor. The quality of the degradation kernel and degraded noise determines how close the LR–HR image pairs are to natural image pairs, and the accuracy of the mapping features extracted from the LR and HR images determines the quality of the images generated by SR.

#### 3.2.1. Degradation Kernel Estimation

Here, we first consider the noise-free degradation process, assuming that the noise-free LR image **I***LR*\_*cl* is the result of downsampling the HR image **I***SRC* with the degradation kernel through the scaling factor *s*:

$$\mathbf{I}\_{LR\\_cl} = (\mathbf{I}\_{SRC} \ast \mathbf{k}^s) \downarrow\_{\mathcal{S}} \tag{2}$$

In this paper, KernelGAN is used to estimate the image degradation kernel **k***<sup>s</sup>*. KernelGAN is a blind SR degradation-kernel estimation model based on Internal-GAN [43]; it is completely unsupervised, requiring no training data beyond the image **I***SRC* itself [39]. KernelGAN trains only on the images **I***SRC* to learn the distribution of internal pixel patches, with the goal of finding the image-specific degradation kernel that best preserves the distribution of pixel patches across the scales of the image **I***SRC*. More specifically, the goal is to "generate" downsampled images whose pixel-patch distribution is as close to that of the images **I***SRC* as possible. The essence of the model is the extraction of the cross-scale recursive characteristics between LR and HR images through deep learning, and the GAN in KernelGAN can be understood as a matching tool for pixel-patch distributions. The implementation process of KernelGAN is shown in Figure 3, which illustrates training with a single input image to learn the distribution of internal pixel patches of the cropped patches. There is a kernel generator (kernel-G) and a kernel discriminator (kernel-D). Both are fully convolutional, which means the networks are applied to pixel patches rather than to the whole image. Given the input images **I***SRC*, the kernel generator learns to downsample them to **I***LR*\_*cl*, with the goal of making the discriminator unable to distinguish the result from the input images **I***SRC* at the pixel-patch level.

**Figure 3.** Structure of KernelGAN.

The objective function of KernelGAN is defined as follows:

$$G^*(\mathbf{I}_{SRC}) = \underset{G}{\mathrm{argmin}}\, \max_{D} \left\{ \mathbb{E}_{\mathbf{x} \sim \mathrm{patches}(\mathbf{I}_{SRC})} \left[ |D(\mathbf{x}) - 1| + |D(G(\mathbf{x}))| \right] + \mathcal{R} \right\} \tag{3}$$

where *G* represents the generator and *D* represents the discriminator. Additionally, $\mathcal{R}$ is the regularization term on the degradation kernel **k***<sup>s</sup>*:

$$\mathcal{R} = \alpha\_{\text{s\\_1}} \mathcal{L}\_{\text{s\\_1}} + \alpha\_{\text{b}} \mathcal{L}\_{\text{b}} + \alpha\_{\text{s\\_P}} \mathcal{L}\_{\text{s\\_P}} + \alpha\_{\text{c}} \mathcal{L}\_{\text{c}} \tag{4}$$

where $\mathcal{L}_{s\_1}$, $\mathcal{L}_b$, $\mathcal{L}_{sp}$, and $\mathcal{L}_c$ represent losses, and $\alpha_{s\_1}$, $\alpha_b$, $\alpha_{sp}$, and $\alpha_c$ represent constant coefficients. In this study, the constant coefficients were set empirically as $\alpha_{s\_1}$ = 0.5, $\alpha_b$ = 0.5, $\alpha_{sp}$ = 5, $\alpha_c$ = 1. The losses are defined by the following equations, respectively:

$$\mathcal{L}\_{s\_-1} = \left| 1 - \sum\_{i,j} k\_{i,j} \right| \tag{5}$$

where $k_{i,j}$ represents the parameter value at each point of the degradation kernel; the goal of $\mathcal{L}_{s\_1}$ is for the sum of $k_{i,j}$ to be 1.

$$\mathcal{L}\_{\mathsf{b}} = \sum\_{i,j} |k\_{i,j} \cdot m\_{i,j}| \tag{6}$$

The goal of $\mathcal{L}_b$ is to penalize non-zero values near the boundary; $m_{i,j}$ is a constant weight mask that increases exponentially with the distance from the center of $k_{i,j}$.

$$\mathcal{L}\_{\rm sp} = \sum\_{i,j} |k\_{i,j}|^{1/2} \tag{7}$$

The goal of $\mathcal{L}_{sp}$ is the sparsity of $k_{i,j}$, to avoid an overly smooth interior kernel.

$$\mathcal{L}\_{\mathbf{c}} = \left\| \left( x\_0, y\_0 \right) - \frac{\sum\_{i,j} k\_{i,j} \cdot \left( i, j \right)}{\sum\_{i,j} k\_{i,j}} \right\|\_{2} \tag{8}$$

The goal of $\mathcal{L}_c$ is to keep the center of mass of $k_{i,j}$ at the center of the kernel, where $(x_0, y_0)$ represents the indices of the center.

Kernel-G can be regarded as an image downsampling model that implements linear downsampling mainly through convolution layers; the network contains no nonlinear activation units. A nonlinear generator is not used here because it could produce physically meaningless solutions for the optimization targets, for example an image that is not downsampled but still contains valid pixel patches. In addition, because a single convolution layer cannot converge accurately, we use a multi-layer structure of linear convolution layers, as in Figure 4.

**Figure 4.** Network structure of the kernel generator consisting of a multi-layer linear convolution layer.
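A hedged PyTorch sketch of such a generator follows. The exact channel widths and kernel sizes are illustrative (KernelGAN's published code uses a similar arrangement); the key properties are the absence of nonlinear activations and the strided final layer that performs the downsampling.

```python
import torch.nn as nn

class LinearKernelG(nn.Module):
    """Purely linear downsampling generator: a stack of convolution layers
    with no activation functions, whose composition is a single kernel."""

    def __init__(self, channels=3, scale=2):
        super().__init__()
        self.body = nn.Sequential(
            nn.Conv2d(channels, 64, 7, bias=False),
            nn.Conv2d(64, 64, 5, bias=False),
            nn.Conv2d(64, 64, 3, bias=False),
            nn.Conv2d(64, channels, 1, stride=scale, bias=False),  # stride = downsampling
        )

    def forward(self, x):
        return self.body(x)
```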

The goal of kernel-D is to learn the distribution of pixel patches in input images **I***SRC* and to distinguish between the real patches and fake patches in the distribution. The real patches are cropped from input images **I***SRC*, and the fake patches are cropped from **I***LR*\_*cl* generated by kernel-G. We use the fully convolutional pixel-patch discriminator introduced in [44] to learn the pixel patch distribution of every single image, as shown in Figure 5.

**Figure 5.** Discriminator network structure consisting of a multi-layer non-pooled convolution layer.

The convolution layers used in kernel-D perform no pooling operations; instead, the network implicitly acts on each pixel block and finally generates a heat map (D-map), in which each position corresponds to one cropped input patch. The heat map output by kernel-D represents the likelihood that the pixel patch around each pixel is drawn from the original patch distribution, and it is used to distinguish real patches from fake ones. The loss is defined as the pixel-wise mean-square error between the heat map and the label map, where the label map assigns 1 to all real patches and 0 to all fake patches.

After training KernelGAN, we do not keep the generator network itself; instead, we convolve the convolution layers of kernel-G with each other, with a stride of 1, to extract the explicit degradation kernel. Meanwhile, the training of KernelGAN is based on a single input image **I***SRC*, which means that each input image yields one degradation kernel; many degradation kernels generated from the training image set are later selected at random in the subsequent steps. Graphical examples of some degradation kernels are shown in Figure 6.

**Figure 6.** Graphical example of a degradation kernel extracted after KernelGAN training.

#### 3.2.2. Generation and Injection of Noise

In contrast to direct downscaling methods such as BiCubic, we explicitly inject additional noise into **I***LR*\_*cl* so as to keep the noise distributions of the **I***LR* and **I***SRC* images as consistent as possible. Because patches with rocky content have large variance [38], and inspired by [40,45], when extracting noise-mapping patches we control the variance within a specific range under the condition:

$$D(\mathbf{n}\_i) < \sigma\_{\max} \tag{9}$$

where *D*(·) represents the variance function and σ*max* the maximum allowed variance. The noise-mapping patches are extracted from images randomly selected from ROI\_Src, and a certain number of noise patches are extracted to construct the dataset ROI\_Noi. The noise patches used in the noise-injection process are randomly selected from ROI\_Noi.
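The following NumPy sketch illustrates one way to build ROI\_Noi under the constraint of Equation (9); the patch size, threshold value, and zero-mean normalization are our assumptions, not parameters stated in the paper.

```python
import numpy as np

def collect_noise_patches(images, patch=64, sigma_max=20.0, count=4972):
    """Crop random patches and keep the zero-mean ones whose variance
    stays below sigma_max, per the condition in Eq. (9)."""
    rng = np.random.default_rng(0)
    kept = []
    while len(kept) < count:
        img = images[rng.integers(len(images))]
        y = rng.integers(img.shape[0] - patch)
        x = rng.integers(img.shape[1] - patch)
        crop = img[y:y + patch, x:x + patch].astype(np.float32)
        noise = crop - crop.mean()        # keep only the fluctuation component
        if noise.var() < sigma_max:       # Eq. (9): bound the patch variance
            kept.append(noise)
    return kept
```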

To sum up, the process of generating the LR images in ROI\_LR from the source images in ROI\_Src can be expressed as Equation (10), where *i* and *j* are randomly selected:

$$\mathbf{I}\_{LR} = (\mathbf{I}\_{SRC} \ast \mathbf{k}\_i^s) \downarrow\_s + \mathbf{n}\_j \tag{10}$$
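Putting the pieces together, Equation (10) amounts to a blur, subsample, and noise-injection pipeline. Below is a hedged PyTorch sketch; the tensor layouts and the depthwise-convolution trick for applying one kernel per band are our implementation choices, not the authors' code.

```python
import torch
import torch.nn.functional as F

def degrade(i_src, kernel, noise, s=4):
    """Eq. (10): blur i_src with an image-specific kernel k_i, subsample by
    the scaling factor s, then inject a randomly drawn noise patch n_j.
    i_src: (1, C, H, W); kernel: (kh, kw); noise: broadcastable to the LR size."""
    c = i_src.shape[1]
    weight = kernel.expand(c, 1, *kernel.shape)          # same kernel for every band
    blurred = F.conv2d(i_src, weight, padding="same", groups=c)
    i_lr = blurred[:, :, ::s, ::s]                       # downsample by factor s
    return i_lr + noise
```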

#### *3.3. SR-GAN*

SR-GAN consists of a super-resolution generator (SR-G), a super-resolution discriminator (SR-D), and a perceptual-feature extractor (SR-F). SR-G generates a ×4 high-resolution image by learning the characteristics of the training set. SR-D and SR-F compare the generated image with the ground-truth image: SR-D feeds back the pixel-wise and adversarial losses to SR-G, and SR-F feeds back the perceptual loss, realizing SR-D's and SR-F's supervision of SR-G. SR-G was designed on the basis of ESRGAN [26]. As the ESRGAN discriminator may introduce more artifacts [38], SR-D was designed on the basis of PatchGAN [44]. The perceptual-feature extractor was designed on the basis of VGG-19 [46], so as to introduce the perceptual loss [47], which strengthens the extraction of low-frequency features and improves the visual effect.

The loss $\mathcal{L}_{SR}$ of SR-GAN consists of three parts: the pixel-wise loss $\mathcal{L}_x$ [26], the perceptual loss $\mathcal{L}_p$, and the adversarial loss $\mathcal{L}_a$.

$$\mathcal{L}\_{\text{SR}} = \alpha\_{\text{x}} \mathcal{L}\_{\text{x}} + \alpha\_{\text{P}} \mathcal{L}\_{\text{P}} + \alpha\_{\text{a}} \mathcal{L}\_{\text{a}} \tag{11}$$

where $\alpha_x$, $\alpha_p$, and $\alpha_a$ are constant coefficients, set empirically as $\alpha_x$ = 0.01, $\alpha_p$ = 1, $\alpha_a$ = 0.005. The losses $\mathcal{L}_x$, $\mathcal{L}_p$, and $\mathcal{L}_a$ are defined in Equations (12), (13), and (16).

$$\mathcal{L}\_{\mathbf{x}} = \mathbb{E}\_{\mathbf{I}\_{LR}} \| G(\mathbf{I}\_{LR}) - \mathbf{I}\_{HR} \|\_{1} \tag{12}$$

The pixel-wise loss $\mathcal{L}_x$ uses the L1 distance to evaluate the pixel-wise content loss between *G*(**I***LR*) and **I***HR*.

$$\mathcal{L}_p = \lambda_f \mathcal{L}_f + \lambda_t \mathcal{L}_t \tag{13}$$

The perceptual loss $\mathcal{L}_p$ evaluates the perceived differences in content and style among images. It consists of the feature reconstruction loss $\mathcal{L}_f$, related to content, and the style reconstruction loss $\mathcal{L}_t$, where $\lambda_f$ and $\lambda_t$ denote constant coefficients. $\mathcal{L}_f$ and $\mathcal{L}_t$ can be expressed as follows:

$$\mathcal{L}_f = \frac{1}{C_j H_j W_j} \left\| \phi_j(G(\mathbf{I}_{LR})) - \phi_j(\mathbf{I}_{HR}) \right\|_2^2 \tag{14}$$

$$\mathcal{L}_t = \left\| \frac{1}{C_j H_j W_j} \sum_{h=1}^{H_j} \sum_{w=1}^{W_j} \left[ \phi_j(G(\mathbf{I}_{LR}))_{h,w,c}\, \phi_j(G(\mathbf{I}_{LR}))_{h,w,c'} - \phi_j(\mathbf{I}_{HR})_{h,w,c}\, \phi_j(\mathbf{I}_{HR})_{h,w,c'} \right] \right\|_F^2 \tag{15}$$

where $\phi_j(\mathbf{I})$ represents the feature map obtained at level *j* of the convolution layers after the image **I** is input to SR-F; the shape of the feature map is $C_j \times H_j \times W_j$ (Channel × Height × Width), and $\|\cdot\|_F^2$ represents the squared Frobenius norm.

$$\mathcal{L}\_a = \sum\_{n=1}^{N} -D(G(\mathbf{I}\_{LR})) \tag{16}$$

The adversarial loss $\mathcal{L}_a$ can enhance the texture details of the image, making the visual effect of the generated image more realistic.
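As a summary of Equations (11)–(16), a hedged sketch of the total generator loss with the stated coefficients is given below. Here `feat_sr`/`feat_hr` stand for SR-F (VGG-19) feature maps, `disc_out` for SR-D's patch output, and the style term of Equation (15) is omitted for brevity; all names are our placeholders.

```python
import torch
import torch.nn.functional as F

def generator_loss(sr, hr, disc_out, feat_sr, feat_hr,
                   a_x=0.01, a_p=1.0, a_a=0.005):
    """Eq. (11): L_SR = a_x * L_x + a_p * L_p + a_a * L_a."""
    l_x = F.l1_loss(sr, hr)             # Eq. (12): pixel-wise L1 distance
    l_p = F.mse_loss(feat_sr, feat_hr)  # Eq. (14): feature reconstruction term
    l_a = -disc_out.mean()              # Eq. (16): adversarial term from SR-D
    return a_x * l_x + a_p * l_p + a_a * l_a
```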

The structure of SR-G is shown in Figure 7. It is based on the ESRGAN model with the RRDB structure [26], trained on the constructed LR–HR image pairs (**I***LR*, **I***HR*), and it magnifies the resolution of the generated images ×4.

**Figure 7.** Structure of the super-resolution generator (SR-G).

Because the discriminator in ESRGAN may introduce more artifacts, we used a patch discriminator instead of the VGG-128 discriminator of the ESRGAN model; SR-D was designed based on PatchGAN [44]. The patch discriminator was preferred for the following reasons: VGG-128 is applied to images of 128 pixels, which makes training at larger scales less powerful, and with its fixed fully connected layer VGG-128 is better at handling global features than local features [34]. The patch discriminator, by contrast, has a fully convolutional structure with a fixed receptive field. Each output value of SR-D depends only on a local patch, so local details can be optimized, and the average of all local errors is used as the final error to guarantee global consistency. The structure of SR-D is shown in Figure 8.

**Figure 8.** Structure of the super-resolution discriminator (SR-D).

Based on the VGG-19 model [46], this paper introduces the perceptual-feature extractor to compute the perceptual loss $\mathcal{L}_p$, that is, to extract pre-activation features from VGG-19. The perceptual loss can enhance the low-frequency features of the images and make the images generated by the generator look more realistic. The structure of the perceptual-feature extractor is shown in Figure 9.

**Figure 9.** Structure of the perceptual-feature extractor (SR-F).

#### **4. Experiments and Results**

#### *Training Details*

The proposed model, TS-SRGAN, and the other models, EDSR8-RGB, RCAN, RS-ESRGAN and RealSR, were run in a PyTorch environment, using the modules provided by the "sefibk/KernelGAN" project [39], the "xinntao/BasicSR" project [48] and the "Tencent/Real-SR" project [38] on GitHub. BiCubic results can be obtained directly, using Matlab functions to perform the interpolation.

TS-SRGAN first generates an LR–HR image-pair dataset (ROI\_Tr) from the training dataset (ROI\_Src) for training and testing. We randomly selected 2134 of the 40,099 images in ROI\_Src to generate a degradation-kernel dataset (ROI\_Ker) through KernelGAN training, one by one, namely $\mathbf{k}_i^s \in \{\mathrm{ROI\_Ker}\}$, $i \in \{1, 2, \cdots, 2134\}$; we then randomly selected 4972 of the 40,099 images in ROI\_Src to extract noise patches, one by one, forming a noise-patch dataset (ROI\_Noi), namely $\mathbf{n}_j \in \{\mathrm{ROI\_Noi}\}$, $j \in \{1, 2, \cdots, 4972\}$; finally, we used the degradation kernels and injected noise to degrade the images in ROI\_Src one by one. In the processing of each image, the degradation kernel and injected noise were randomly selected from ROI\_Ker and ROI\_Noi.

The network structural parameters of kernel-G and kernel-D and the constant loss coefficients of KernelGAN were given above, so we do not repeat them here. In the training phase, both the generator and the discriminator used the ADAM optimizer with parameters $\beta_1$ = 0.5, $\beta_2$ = 0.999; the learning rates of kernel-G and kernel-D were both set to 0.0002, decaying by ×0.1 every 750 iterations, and the network was iteratively trained for 3000 epochs.

SR-G used the "RRDBNet" model from the "BasicSR" project, and SR-D used the "NlayerDiscriminator" model from the "Real-SR" project. The networks' structural parameters and the constant loss coefficients were given above and are not repeated here. The image was magnified 4 times; during the training phase both the generator and the discriminator used the ADAM optimizer with parameters $\beta_1$ = 0.9, $\beta_2$ = 0.999; the learning rates of SR-G and SR-D were both set to 0.0001, and the network was iteratively trained for 60,000 epochs.
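In PyTorch, the stated optimizer settings correspond to configurations like the following; `sr_g`, `sr_d`, and `kernel_g` are placeholders for the instantiated networks, not names from the projects cited above.

```python
import torch

# SR-GAN stage: Adam with beta1 = 0.9, beta2 = 0.999, lr = 1e-4.
opt_g = torch.optim.Adam(sr_g.parameters(), lr=1e-4, betas=(0.9, 0.999))
opt_d = torch.optim.Adam(sr_d.parameters(), lr=1e-4, betas=(0.9, 0.999))

# KernelGAN stage: Adam with beta1 = 0.5, beta2 = 0.999, lr = 2e-4,
# decayed by x0.1 every 750 iterations.
opt_k = torch.optim.Adam(kernel_g.parameters(), lr=2e-4, betas=(0.5, 0.999))
sched_k = torch.optim.lr_scheduler.StepLR(opt_k, step_size=750, gamma=0.1)
```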

Several convolutional layers are used in TS-SRGAN, and they play a vital role. After many tests, we found that the parameters of the convolutional layers needed to be set as in Table 1 to achieve ×4 resolution with TS-SRGAN and obtain the desired image quality.


**Table 1.** Settings of specific parameters for convolutional layers of TS-SRGAN.

The EDSR8-RGB, RCAN, RS-ESRGAN and RealSR models implemented training and testing under the frameworks of BasicSR [48] and Real-SR [38], with the parameter-setting schemes which have been proven to achieve the best results in references [21,32,33]. The parameters used in the implementations are detailed in Table 2.

As the source images used are already the highest-resolution (10 m) images of Sentinel-2, there are no real ground-truth images (2.5 m resolution) to compare with the generated images, and commonly used image-quality assessment metrics such as PSNR and SSIM are not applicable in this scenario. Therefore, we adopted no-reference image quality assessment (NR-IQA) metrics, including NIQE [49], BRISQUE [50] and PIQE [51]. The evaluation values of NIQE, BRISQUE and PIQE can be calculated by the corresponding functions in Matlab; the outputs of the three functions all lie within the range [0, 100], where a lower score indicates a better perceptual effect.

We randomly selected one sub-dataset, "ROIs1158\_spring\_106", in ROIs1158 as the testing dataset (ROI\_Te), containing 784 images. The remote-sensing images in ROI\_Te were collected from the ground areas shown in Figure 10. In the figure, we marked eight regions with strong geographic features; the ×4 generated images of these regions are shown subsequently, to visually compare the differences among the models.


**Table 2.** Settings of specific parameters for the models implemented in the framework of BasicSR.

**Figure 10.** Ground map corresponding to sub-dataset "ROIs1158\_spring\_106".

We used the BiCubic, EDSR8-RGB, RCAN, RS-ESRGAN and TS-SRGAN models to process 784 images in ROI\_Te to generate ×4 HR images, and used Matlab to calculate the evaluation values of NIQE, BRISQUE and PIQE, one by one, for the images. The histograms were drawn according to the distributions of evaluation metric values, as shown in Figures 11–13, and the mean and extreme values based on the evaluation values are provided in Table 3. The histograms and table show that the TS-SRGAN model is superior to the other models in a variety of NR-IQA metrics.

**Figure 11.** Distribution of evaluation values of NR-IQA metric NIQE.

**Figure 12.** Distribution of evaluation values of NR-IQA metric BRISQUE.

**Figure 13.** Distribution of evaluation values of NR-IQA metric PIQE.



Figures 14–21 show the generated images of eight regions with strong geographic features selected in "ROIs1158\_spring\_106", to visually compare the differences among the models. The comparison of the images of various terrains in Figures 14–21 clearly shows that the images processed by the traditional BiCubic method are the blurriest and smoothest, due to the inherent deficiencies of the interpolation algorithm. The EDSR8-RGB, RCAN and RS-ESRGAN models cannot correctly distinguish noise from sharp edges, resulting in blurred results, and even indistinguishable houses and roads. As shown in the TS-SRGAN results, the dividing lines between objects and backgrounds, such as roads, bridges and houses, are much clearer, which indicates that the noise we estimated was closer to the real noise. By using different combinations of degradation (e.g., blur and noise), TS-SRGAN obtains LR images sharing the same domain as the real images, which prevents the generated LR images from being too smooth and fuzzy. Using the super-resolution network trained on the domain-consistent data, TS-SRGAN generates HR images with clearer boundaries and better perceptual quality. Compared with the EDSR8-RGB, RCAN and RS-ESRGAN models, TS-SRGAN's results are clearer and less ambiguous.

**Figure 14.** Comparison of visual effects of the generated images of the region containing mountain-road terrain. NR-IQA of images (NIQE, BRISQUE, PIQE): BiCubic (6.28, 54.73, 91.87), EDSR8-RGB (5.85, 50.25, 86.05), RCAN (4.64, 47.69, 62.20), RS-ESRGAN (3.31, 22.16, 8.42), RealSR (3.16, 15.75, 8.52), TS-SRGAN (2.43, 7.36, 7.48). The subfigures (**a**–**l**) in the figure represent an enlarged view of the local details in the green boxes in the generated image.

**Figure 15.** Comparison of visual effects of the generated images of the region with hilly terrain. NR-IQA of images (NIQE, BRISQUE, PIQE): BiCubic (6.35, 58.38, 92.36), EDSR8-RGB (5.88, 48.02, 81.31), RCAN (4.62, 46.94, 52.83), RS-ESRGAN (3.99, 27.62, 7.61), RealSR (2.83, 13.25, 7.84), TS-SRGAN (2.74, 3.83, 15.57). The subfigures (**a**–**l**) in the figure represent an enlarged view of the local details in the green boxes in the generated image.

**Figure 16.** Comparison of visual effects of the generated images of the region containing surface water terrain. NR-IQA of images (NIQE, BRISQUE, PIQE): BiCubic (6.54, 55.52, 100.00), EDSR8-RGB (5.29, 46.00, 76.57), RCAN (4.67, 46.28, 74.49), RS-ESRGAN (3.53, 27.75, 12.08), RealSR (2.56, 13.68, 9.84), TS-SRGAN (2.42, 13.97, 10.03). The subfigures (**a**–**l**) in the figure represent an enlarged view of the local details in the green boxes in the generated image.

**Figure 17.** Comparison of visual effects of the generated images of the region containing dry riverbeds and residential houses. NR-IQA of images (NIQE, BRISQUE, PIQE): BiCubic (6.30, 50.23, 100.00), EDSR8-RGB (5.17, 49.48, 73.97), RCAN (4.30, 46.78, 58.43), RS-ESRGAN (3.06, 23.08, 18.75), RealSR (3.23, 12.43, 9.51), TS-SRGAN (2.70, 28.73, 11.47). The subfigures (**a**–**l**) in the figure represent an enlarged view of the local details in the green boxes in the generated image.

**Figure 18.** Comparison of visual effects of the generated images of the region containing factories. NR-IQA of images (NIQE, BRISQUE, PIQE): BiCubic (6.28, 51.85, 100.00), EDSR8-RGB (5.45, 45.43, 82.85), RCAN (4.08, 46.86, 48.36), RS-ESRGAN (3.11, 12.89, 11.08), RealSR (3.00, 11.30, 11.11), TS-SRGAN (2.01, 10.28, 11.27). The subfigures (**a**–**l**) in the figure represent an enlarged view of the local details in the green boxes in the generated image.

**Figure 19.** Comparison of visual effects of the generated images of the region containing residential houses. NR-IQA of images (NIQE, BRISQUE, PIQE): BiCubic (6.26, 53.22, 100.00), EDSR8-RGB (4.82, 45.85, 74.28), RCAN (3.96, 45.61, 71.41), RS-ESRGAN (3.08, 22.47, 20.85), RealSR (3.12, 27.19, 15.42), TS-SRGAN (2.32, 18.27, 13.99). The subfigures (**a**–**l**) in the figure represent an enlarged view of the local details in the green boxes in the generated image.

**Figure 20.** Comparison of visual effects of the generated images of the region containing farmlands and sandy terrain. NR-IQA of images (NIQE, BRISQUE, PIQE): BiCubic (0.49, 55.53, 94.76), EDSR8-RGB (5.83, 49.99, 85.93), RCAN (4.37, 45.64, 46.14), RS-ESRGAN (3.35, 21.84, 8.59), RealSR (2.96, 26.74, 13.01), TS-SRGAN (2.46, 8.07, 8.84). The subfigures (**a**–**l**) in the figure represent an enlarged view of the local details in the green boxes in the generated image.

**Figure 21.** Comparison of visual effects of the generated images of the region containing overpasses. NR-IQA of images (NIQE, BRISQUE, PIQE): BiCubic (6.39, 54.32, 94.91), EDSR8-RGB (5.68, 49.28, 81.34), RCAN (4.83, 47.54, 68.32), RS-ESRGAN (2.95, 19.71, 12.84), RealSR (3.46, 13.37, 7.75), TS-SRGAN (2.68, 18.06, 10.60). The subfigures (**a**–**l**) in the figure represent an enlarged view of the local details in the green boxes in the generated image.

#### **5. Conclusions**

In this paper, based on the latest and most widely recognized GAN technologies, including KernelGAN, ESRGAN, PatchGAN, etc., we introduced degradation kernel estimation and noise injection to perform SR for Sentinel-2 satellites' remote-sensing images, improving the highest-resolution images from 10 m to 2.5 m resolution. Through the combination of the degradation kernel and injected noise, we obtained LR images in the same domain as the real images, yielding near-natural LR–HR image pairs. On the basis of these near-natural LR–HR image pairs, we used a GAN combining an ESRGAN-type generator, a PatchGAN-type discriminator and a VGG-19-type feature extractor, used the perceptual loss, and focused on the visual characteristics of the images, so that our results have clearer details and better perceptual effects. Compared with the SR models for Sentinel-2, such as EDSR8-RGB, RCAN, RS-ESRGAN and RealSR, the main difference in our model lies in the construction of the LR–HR image pairs for the training datasets. In the scenario of training with natural LR–HR image pairs, there was no significant difference in the effect of the SR images obtained with those models; however, in the scenario with only LR images and no HR prior information, compared with RCAN and EDSR8-RGB, which construct the image pairs through BiCubic, and RS-ESRGAN, which uses WorldView satellite HR images to construct the image pairs, TS-SRGAN has obvious advantages both in the quantitative comparison of the non-reference image quality assessment and in the intuitive visual effects.

**Author Contributions:** Conceptualization, Y.L. and S.W.; methodology, Y.L.; software, Y.L. and Y.W.; validation, Y.L. and Y.W.; formal analysis, Y.L.; resources, Y.L., S.W. and B.L.; writing—original draft preparation, Y.L., Y.W. and B.L.; writing—review and editing, B.L.; project administration, Y.L.; funding acquisition, Y.L. and S.W. All authors have read and agreed to the published version of the manuscript.

**Funding:** This research was funded by the National Natural Science Foundation of China (Grant No. 61871147); Natural Science Foundation of Guangdong Province, China (Grant No. 2018A030313346); the General University Key Field Special Project of Guangdong Province, China (Grant No. 2020ZDZX3078); the General University Key Field Special Project of Guangdong Province, China (Grant No. 2022ZDZX1035); the Research Fund Program of Guangdong Key Laboratory of Aerospace Communication and Networking Technology (Grant No. 2018B030322004).

**Data Availability Statement:** The ×4 images by models TS-SRGAN, BiCubic, EDSR8-RGB, RCAN, RS-ESRGAN and RealSR are available online at Baidu Wangpan (code: mbah). The trained models of TS-SRGAN, EDSR8-RGB, RCAN, RS-ESRGAN and RealSR are available online at Baidu Wangpan (code: 5xj6). Additionally, all the codes generated or used during the study are available online at github/TS-RSGAN.

**Acknowledgments:** This work has a preprint version [52], which has not been peer-reviewed.

**Conflicts of Interest:** The authors declare no conflict of interest. The funders had no role in the design of the study, in the collection, analyses, or interpretation of data, in the writing of the manuscript, or in the decision to publish the results.

#### **References**


## *Article* **Weather Forecast Based on Color Cloud Image Recognition under the Combination of Local Image Descriptor and Histogram Selection**

**Kiet Tran-Trung 1, Ha Duong Thi Hong 2 and Vinh Truong Hoang 2,\***


**Abstract:** Numerous researchers have used machine vision in recent years to identify and categorize clouds according to their volume, shape, thickness, height, and coverage. Because cloud images, as a type of naturally striated structure, are frequently characterized by significant variations in illumination, climate, and distortion, the Local Binary Patterns (LBP) descriptor and its variants have been proposed as feature extraction methods for characterizing natural texture images. Rotation invariance, low processing complexity, and resistance to monotonic brightness variations are characteristics of LBP. The disadvantage of LBP is that it produces binary data that are extremely noise-sensitive, and it struggles on "flat" regions of the image because it depends on intensity differences. This paper considers the Local Ternary Patterns (LTP) feature to overcome the drawbacks of the LBP feature. We also propose the fusion of color characteristics, LBP features, and LTP features for the classification of cloud/sky images. Moreover, this study proposes to apply the Intra-Class Similarity (ICS) technique, a histogram selection approach, with the goal of minimizing the number of histograms used for characterizing images. The proposed approach achieves better recognition performance with fewer features by fusing LBP and LTP features and using the ICS technique to choose potential histograms.

**Keywords:** cloud images; preprocessing techniques; classification; histogram selection; ICS; LTP; LBP; color image; weather forecast

#### **1. Introduction**

Any region's weather is closely tied to the presence of clouds. Clouds play a major role in all types of precipitation. Although not all clouds may result in precipitation, they are crucial for controlling the weather in some regions. Different types and heights of clouds exist in various geographical locations, such as the Earth's tropics or poles [1]. The sky always has clouds, and they are ever-changing. Clouds serve as indicators of atmospheric conditions and are crucial for weather forecasting and warnings, as well as for controlling the Earth's energy balance, temperature, and weather [2]. Before creating a weather prediction, meteorologists investigate the specifics of the cloud type, since clouds are constantly changing.

We can forecast changes in weather by studying and categorizing clouds. Future measurements and forecasts may be significantly impacted by changes in cloud classification. Additionally, locating and analyzing clouds can assist meteorologists in modifying weather forecasts, better comprehending the local ecology, and foretelling changes in the world's climate [3]. Ground-based cloud observations are remote sensing image materials, so they carry information such as cloud parameters, spatial resolution, temporal resolution, spectral resolution, etc.

**Citation:** Tran-Trung, K.; Duong Thi Hong, H.; Truong Hoang, V. Weather Forecast Based on Color Cloud Image Recognition under the Combination of Local Image Descriptor and Histogram Selection. *Electronics* **2022**, *11*, 3460. https://doi.org/10.3390/ electronics11213460

Academic Editors: Yanhui Guo, Deepika Koundal and Rashid Amin

Received: 19 September 2022 Accepted: 13 October 2022 Published: 26 October 2022


Classification of cloud information (cloud parameters, temporal and spatial resolution, etc.) is very important. These parameters, such as cloud height, cloud cover, and cloud type, need to be defined for ground-based cloud observations [4–6]; among them, cloud type is the first and most readily available parameter that humans can obtain.

The cloud is a product of nature, and it reflects the climate and different weather conditions on Earth. Clouds appear very diverse in different atmospheric conditions. Extracting useful information from huge amounts of image data by detecting and analyzing different entities in images is a major challenge. Today, the identification of cloud types is mainly based on the observations of experts. The results are subjective and cannot meet the actual requirements of meteorological observations. The automatic classification of cloud types has become a demanding problem that needs to be solved in this case.

Researchers are increasingly using imagery of the entire sky above the ground for a variety of purposes, including solar radiation, weather forecasting, and aviation. These images are usually higher in resolution than those obtained from satellites [7]. Additionally, the camera's vertical orientation makes it simple to photograph clouds at various low altitudes. Thus, they enrich satellite images with useful data. In most cases, cloud observations are performed manually by experts from meteorological institutes. They observe these clouds to better understand atmospheric phenomena and classify clouds into various categories according to the World Meteorological Organization standard [8]. Although the results obtained are quite accurate, such manual observations are expensive and time-consuming. Therefore, it is necessary to apply cloud classification algorithms automatically and systematically to save costs.

In recent years, several studies have approached computer vision techniques to identify and classify clouds based on their volume, shape, thickness, height, and coverage. For example, Kliangsuwan and Heednacram [9] used a new method for feature extraction in cloud classification. The authors propose three more types of features based on Fourier transform, namely Fast Fourier Transform Projection on the modified *x*-axis (*k*-FFTPX), half *k*-FFTPX, and *h* × *k*-FFT, and use an ANN-based classification technique with a tree algorithm to extract features. Li et al. [10] proposed a new approach to cloud pattern recognition, based on analyzing the image as a set of patches (set of changes), rather than a set of pixels, and through the Support Vector Machine (SVM) classifier for classification. Zhen et al. [11] used spectral and texture feature extraction by tonal statistical analysis and Gray Level Cooccurrence Matrix (GLCM), and an SVM classifier with the Radial Basis Function (RBF) multiplier to classify different clouds from sky images. Taravat et al. [12] used a neural network in conjunction with SVM for automatic cloud classification for the entire color terrestrial image.

Algorithms based on structural features such as cloud scale, edge sharpness, Fourier transform, etc., cannot effectively exploit the useful information in cloud images, because cloud images, as a type of natural striated structure, often exhibit very large variations in illumination, climate, and distortion [13]. Cheng and Yu [14] proposed a cloud classification method that deals with mixed cloud types in an image based on the segmentation of images into different blocks. In each block, the texture statistical feature and the Local Binary Patterns (LBP) feature are extracted. These features are then classified using a Bayesian classifier. Liu et al. [15] proposed a new feature extraction technique by improving the LBP technique, called the Salient Local Binary Pattern (SaLBP), for terrestrial cloud image classification. SaLBP utilizes the most frequently occurring patterns (prominent patterns) to obtain descriptive information. Liu and Zhang [13] presented a new feature extraction algorithm called Learning Group Patterns (LGP) to classify seven sky conditions; the proposed algorithm considers the resolution of the texture by using SaLBP and through LGP. Zhang et al. [16] focused both on designing appropriate feature representations and on learning distance metrics from sample pairs. The authors also propose a feature extraction technique called Transfer Deep Local Binary Patterns (TDLBP) and learn WML. Wang et al. [17] proposed a powerful feature extraction method, called SLBP, based on the average rank of occurrence frequencies of the rotation-invariant patterns defined in the LBP of the cloud image.

In recent years, different image classification methods based on deep learning have been proposed and have demonstrated their effectiveness [18,19]. In 2020, Wang et al. [20] proposed a convolutional neural network (CNN) with deep learning capabilities, called CloudA, as a ground-based cloud image recognition method. CloudA visualizes cloud features using TensorBoard visualization, and these features can help us to understand the terrestrial cloud classification process.

Therefore, feature extraction plays an important role, affecting the results of the classifier. Many feature extraction methods have been proposed in the last decade. The LBP feature and its variants have been proposed as an effective feature extraction method for classifying natural texture images [21]. Although the extraction of local features by LBP gives many positive results, there are still noisy and duplicated features. Indeed, many extracted features will decrease the performance due to the curse of dimensionality. Feature selection involves finding a subset of valid features. To improve the accuracy of the classifier and reduce the computational burden, the feature selection step is essential.

This paper presents an approach for color image classification based on LTP features and feature selection methods. Since color information is important for representing texture, we consider different color spaces for extracting local features. The rest of this paper is structured as follows. Section 2 presents feature extraction by the LBP and LTP descriptors. Section 3 introduces the histogram selection approach via the Intra-Class Similarity (ICS) score for selecting the most important features. Section 4 presents the cloud image classification process. Then, Section 5 shows the experimental results on the two benchmark datasets. Finally, Section 6 gives the conclusions and perspectives of this work.

#### **2. Feature Extraction**

Feature extraction is an important step in multimedia processing. How to extract ideal features that reflect, as fully as possible, the contents of the image remains a challenging problem in computer vision. In other words, feature extraction is the process of obtaining the most important data from the raw data. It is a key step in building any pattern classification model: relevant features are extracted from the objects to form feature vectors.

The most common image features include color, texture, and shape, and most feature spaces are built on these features. However, the performance of the model depends heavily on the image features used [22]. LBP feature extraction was first proposed by Ojala et al. [23] to describe the texture of an image. It compares neighboring pixels with the central pixel to obtain a binary pattern, generated as follows: each neighboring pixel takes the value 1 if its value is greater than or equal to the central pixel value, and 0 otherwise. The resulting bits are then multiplied by their respective weights and summed to obtain the LBP value of the center pixel.

The formula for calculating LBP is determined as follows:

$$\text{LBP}_{P,R}(x_c, y_c) = \sum_{p=0}^{P-1} s(g_p - g_c)\, 2^p \tag{1}$$

where

$$s(g_p - g_c) = \begin{cases} 1 & \text{if } g_p - g_c \ge 0 \\ 0 & \text{if } g_p - g_c < 0 \end{cases} \tag{2}$$

where *xc* and *yc* are the coordinates of the center pixel, *P* is the number of neighboring pixels, *R* is the neighborhood radius, *gc* is the grayscale value of the center pixel, and *gp* is the grayscale value of a neighboring pixel.

Figure 1 shows the process and encoding of the LBP operator for grayscale images with 3 × 3 pixels.

**Figure 1.** Describes the encoding of the LBP operator.

The LTP operator was developed from LBP and introduced by Tan and Triggs [24]. This proposal offers significantly higher efficiency than LBP and handles noise in homogeneous regions better. In LTP, *s*(*gp* − *gc*) is defined as follows:

$$s(g_p - g_c) = \begin{cases} 1 & \text{if } g_p - g_c \ge t \\ 0 & \text{if } |g_p - g_c| < t \\ -1 & \text{if } g_p - g_c \le -t \end{cases} \tag{3}$$

where *t* is the user-defined threshold. Figure 2 shows the working and encoding of the LTP operator for grayscale images with 3 × 3 pixels, with the parameter value *t* = 5.

**Figure 2.** Describes the encoding of the LTP operator.
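For readers who prefer code to formulas, the sketch below encodes one 3 × 3 neighborhood (*P* = 8, *R* = 1) with both operators; the neighbor ordering and function names are illustrative, not taken from the paper.

```python
# LBP (Equations (1)-(2)) and LTP (Equation (3)) for a single 3x3 patch.
import numpy as np

# Neighbors ordered clockwise from the top-left corner around the center pixel.
NEIGHBORS = [(0, 0), (0, 1), (0, 2), (1, 2), (2, 2), (2, 1), (2, 0), (1, 0)]

def lbp_code(patch: np.ndarray) -> int:
    """LBP code of the center pixel of a 3x3 grayscale patch."""
    gc = patch[1, 1]
    return sum((1 if patch[i, j] - gc >= 0 else 0) << p
               for p, (i, j) in enumerate(NEIGHBORS))

def ltp_codes(patch: np.ndarray, t: int = 5) -> tuple[int, int]:
    """LTP of the center pixel, split into the usual upper/lower binary patterns."""
    gc = patch[1, 1]
    upper = sum((1 if patch[i, j] - gc >= t else 0) << p
                for p, (i, j) in enumerate(NEIGHBORS))
    lower = sum((1 if patch[i, j] - gc <= -t else 0) << p
                for p, (i, j) in enumerate(NEIGHBORS))
    return upper, lower

patch = np.array([[54, 57, 60], [58, 56, 52], [61, 55, 50]], dtype=np.int32)
print(lbp_code(patch), ltp_codes(patch, t=5))
```

Because the ternary LTP code is conventionally split into an upper and a lower binary pattern, it yields two histograms where LBP yields one, which is why the LTP bin count in Section 5 is doubled relative to LBP.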

#### **3. Histogram Selection**

A histogram is used to describe discrete or continuous data and is one of the best ways to represent variables. In other words, it provides a visual interpretation of numeric data by displaying the number of data points that fall within a specified range of values (called a bin). In order to select the pertinent features, there are several methods, such as evaluating the individual features or groups of features. Several histogram selection methods based on graph construction or the measurement of similarity have been introduced [25,26]. The histogram selection methods can be considered as an evaluation of groups of features [27]. Histogram selection methods are usually grouped into three approaches: the filter method, the wrapper method, and the embedded method. The embedded method combines the reduced processing time of the filter method with the high efficiency of the wrapper method. The filter method calculates a score for each histogram to measure its effectiveness, and the histograms are then ranked according to the calculated score. In the wrapper method, each histogram is evaluated using a specific classification algorithm, and the selected histograms are those that maximize the classification rate.

To improve the classification performance, many methods have been proposed with the goal of reducing the dimensionality of the feature matrix. One such method is dimensionality reduction of the feature matrix based on the feature histogram, as proposed by Porebski et al. in 2013 [28]. In this method, the most important and significant histograms are selected based on the score value of each histogram. The approach of selecting characteristic histograms using the ICS technique has recently been extended to the multi-color-space domain. Considering a database with *N* textured color images, each image *Ii*, *i* ∈ {1, 2, ... , *N*}, has *δ* characteristic histograms. The entire set of data is represented by the matrix *H* as follows:

$$H = \begin{bmatrix} h_1^1 & \dots & h_1^r & \dots & h_1^\delta \\ \dots & \dots & \dots & \dots & \dots \\ h_i^1 & \dots & h_i^r & \dots & h_i^\delta \\ \dots & \dots & \dots & \dots & \dots \\ h_N^1 & \dots & h_N^r & \dots & h_N^\delta \end{bmatrix} = \begin{bmatrix} h_1 \\ \dots \\ h_i \\ \dots \\ h_N \end{bmatrix} = \begin{bmatrix} h^1 & \dots & h^r & \dots & h^\delta \end{bmatrix} \tag{4}$$

in which $h_i^r$ is the $r$th histogram of the textured color image $i$, defined as $h_i^r = [h_i^r(1), \dots, h_i^r(k), \dots, h_i^r(Q)]$, where $Q$ is the number of bins of the histogram.

The ICS technique is based on an in-class similarity method to evaluate the similarity between histograms extracted from images of the same class.

Let $I_j^k$ be the $k$th training image of class $j$, where class $j$ has $N_j$ images. Accordingly, the histogram intersection between two images of class $j$ is calculated as follows:

$$D(I_j^k, I_j^{k'}) = \sum_{i=1}^{Q} \min\!\left(h[I_j^k](i),\, h[I_j^{k'}](i)\right) \tag{5}$$

To measure the similarity of class $j$, let $SIM_j$ be the similarity measure, calculated as follows:

$$SIM_j = \frac{2}{N_j(N_j - 1)} \sum_{k=1}^{N_j - 1} \sum_{k'=k+1}^{N_j} D(I_j^k, I_j^{k'}) \tag{6}$$

Porebski et al. suggested that the higher the $SIM_j$ within a class, the more relevant the histogram $h^r$. Finally, the ICS score of a histogram $h^r$ is calculated by:

$$S_{ICS}^r = \frac{1}{C} \sum_{j=1}^{C} SIM_j \tag{7}$$

where $C$ is the number of classes considered. $S_{ICS}^r$ takes a value from 0 to 1. The most distinct histogram is the one with the highest $S_{ICS}^r$ score.
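A minimal sketch of Equations (5)–(7), assuming normalized $Q$-bin histograms and hypothetical variable names:

```python
# ICS score of one candidate histogram h^r across all classes.
from itertools import combinations

import numpy as np

def ics_score(class_histograms: list[np.ndarray]) -> float:
    """class_histograms[j] has shape (N_j, Q): the histograms of class j's images."""
    sims = []
    for hists in class_histograms:
        n_j = hists.shape[0]
        # Histogram intersection D for every pair of images in the class, Eq. (5).
        inter = [np.minimum(a, b).sum() for a, b in combinations(hists, 2)]
        sims.append(2.0 / (n_j * (n_j - 1)) * np.sum(inter))  # SIM_j, Eq. (6)
    return float(np.mean(sims))                               # S^r_ICS, Eq. (7)

rng = np.random.default_rng(0)
toy = [rng.dirichlet(np.ones(16), size=10) for _ in range(5)]  # 5 classes, Q = 16
print(ics_score(toy))  # in [0, 1] for normalized histograms; higher = more relevant
```

Computing this score once per candidate histogram and keeping the top-ranked histograms implements the selection step used in Section 5.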

#### **4. Cloud Image Classification Process**

#### *4.1. Data Preparation*

To our knowledge, the SWIMCAT dataset [29] is a benchmark for cloud image classification. It contains images taken with WAHRSIS, a calibrated ground-based whole-sky imager, and consists of 784 patches of 5 cloud types from images taken in Singapore between January 2013 and May 2014. The five sky types were identified based on visual features of sky/cloud conditions, consulting experts from the Singapore Meteorological Service. All patch images are 125 × 125 pixels; we take one representative image from each category. This dataset contains 784 images of sky/clouds, classified into 5 categories: clear sky, patterned clouds, thick dark clouds, thick white clouds, and cloud cover (Figure 3).

**Figure 3.** Selected images from five categories of the SWIMCAT dataset.

Although there have been many studies analyzing sky/cloud images captured by ground-based cameras by several research groups, publicly available standard databases are rare. Some datasets are released for cloud detection or segmentation purposes—for example, HYTA [30]. In addition, this paper also evaluates the proposed approach on Cloud-ImVN 1.0 [31]. This dataset was created based on inspiration from the SWIMCAT dataset, with a larger number of images and more cloud/sky types. Specifically, the Cloud-ImVN 1.0 dataset has 2100 images of sky/clouds, classified into 6 categories: clear blue sky, patterned clouds, thick dark clouds, thick white clouds, thin white clouds, and cloud cover (Figure 4).

**Figure 4.** Six categories of cloud/sky images in the Cloud-ImVN 1.0 dataset: (**a**) clear blue sky, (**b**) patterned clouds, (**c**) thick dark clouds, (**d**) thick white clouds, (**e**) thin white clouds, (**f**) cloud cover.

#### *4.2. Classification Process*

The classification process of cloud images is illustrated in Figure 5. It consists of four steps, as follows:


**Figure 5.** The cloud image classification process based on histogram selection using ICS technique.

#### **5. Experimental Results**

Previous studies focus on color features/color information of the RGB color space to classify cloud/sky images. However, a well-chosen color space can improve the classification performance [32,33]. This work considers 14 different color spaces, namely HLS, HSV, IHLS, Lab, rgb, YIQ, YUV, RGB, bwrgby, XYZ, YCbCr, Luv, I1I2I3 and ISH, for extracting features. Each extraction method is applied with a pair of parameters (*R*, *P*), with *R* ∈ {1, 2, 3, 4, 5} and *P* ∈ {4, 8, 12}; thus, the feature extraction techniques LBP, LTP, and LBP+LTP each have 15 pairs of parameters (*R*, *P*). Each specific feature extraction technique, with a specific pair of input parameters, is applied on the 14 color spaces. Finally, histogram selection is applied to those features. In summary, the parameters used to run the experiments and produce the cloud/sky image classification results include: (1) a pair of parameters (*R*, *P*), (2) 14 color spaces, (3) 3 feature extraction methods (LBP, LTP, and LBP + LTP), and (4) a selected number of histograms. Moreover, the dimension of each histogram depends on the value of *P*; for example, LBP features with (1, 8) have 3 histograms from the three color channels, and each histogram consists of 2^8 = 256 bins. When applying the LTP descriptor for feature extraction, the number of histograms is doubled (2 × 2^8 = 512 bins) compared with the LBP descriptor.
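To make the size of this experimental grid concrete, the following sketch (with illustrative names only) enumerates the configurations and per-histogram bin counts:

```python
# Enumerating the experimental grid: 14 color spaces x 3 descriptors x 15 (R, P) pairs.
from itertools import product

COLOR_SPACES = ["HLS", "HSV", "IHLS", "Lab", "rgb", "YIQ", "YUV", "RGB",
                "bwrgby", "XYZ", "YCbCr", "Luv", "I1I2I3", "ISH"]
DESCRIPTORS = ["LBP", "LTP", "LBP+LTP"]
RP_PAIRS = list(product([1, 2, 3, 4, 5], [4, 8, 12]))  # 15 (R, P) pairs

configs = list(product(COLOR_SPACES, DESCRIPTORS, RP_PAIRS))
print(len(configs))  # 14 * 3 * 15 = 630 configurations

for r, p in RP_PAIRS[:3]:
    lbp_bins = 2 ** p        # bins per channel for LBP, e.g., 2^8 = 256
    ltp_bins = 2 * lbp_bins  # LTP doubles this (upper + lower patterns)
    print((r, p), lbp_bins, ltp_bins)
```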

#### *5.1. Results on the SWIMCAT Dataset*

Table 1 summarizes selected results obtained when running experiments on the SWIMCAT dataset, applying the techniques to extract LBP, LTP, and LBP + LTP features over the 14 color spaces. The highest value for each feature is underlined. We observe that when classifying cloud/sky images in the SWIMCAT dataset, using more color variables gives better results than using grayscale images, specifically with the RGB color system, with ACC reaching 97.8 ± 1.3 for the LBP (1, 12) technique, 97.9 ± 1.2 for LTP (2, 12), and 98.1 ± 1.3 for LBP+LTP (3, 12).

Table 2 presents the results obtained on the SWIMCAT dataset incorporating the histogram selection method. The highest value for each feature is underlined. For color spaces with H and S color components, using the LTP feature gives better results than using the LBP feature. Specifically, the HLS color space reaches 98.9 ± 0.8 with the LBP (4, 12) feature and 99.0 ± 0.7 with the LTP (4, 8) feature; the ISH color system reaches 98.6 ± 0.6 with the LBP (4, 12) feature and 99.0 ± 0.4 with the LTP (4, 8) feature.

**Table 1.** The best results obtained for SWIMCAT dataset without using histogram selection method. Some of the best results for each color space, each feature, and the selected parameters (*R*, *P*).


**Table 2.** The best results obtained for SWIMCAT dataset while using histogram selection method. Some of the best results for each color space, each feature, and the selected parameters (*R*, *P*).


Figure 6 presents the selected results of the highest accuracy obtained from three types of features (LBP, LTP, LBP+LTP) in the two scenarios: without and with the histogram selection method.

**Figure 6.** Highest results obtained on the SWIMCAT dataset without (first row) and with (second row) histogram selection method.

Table 3 presents the comparison of the results obtained on the SWIMCAT dataset with previous studies. Thus, the highest results are obtained on the SWIMCAT dataset for three types of features (Table 3) using the ICS technique, as follows:


For the LBP feature, there is no significant difference between the results obtained with and without the ICS technique. However, for the LTP and LBP+LTP features, using the ICS technique gives better results, and the number of selected histograms is also smaller than when histograms are not selected with the ICS technique. Moreover, on the SWIMCAT dataset, the LTP feature gives better results than the LBP feature (Table 2). The proposed approach clearly outperforms LBP variants such as WLBP, SRBP, and SWOBP. For example, the SaLBP technique is based on uniform LBP, in which many of the seldom-occurring bins are eliminated.

**Table 3.** Comparison of experimental results on SWIMCAT dataset using ICS with previous studies.




#### *5.2. Results on the Cloud-ImVN 1.0 Dataset*

Table 4 shows the best results obtained with different types of features on the Cloud-ImVN 1.0 dataset. We observe the appearance of the HLS, HSV, and IHLS color spaces. This shows that color components with high dichroism, or with H (Hue) and S (Saturation) components, can still be good candidates for cloud/sky image classification. Moreover, we observe and confirm that the RGB space is not the best color space for characterizing cloud images. LTP features achieve higher accuracy than LBP features and have a lower standard deviation: LBP features achieve a highest accuracy of 85.4 ± 4.6, while LTP achieves a highest accuracy of 88.1 ± 2.2. Overall, the combination of LBP and LTP features achieves the best accuracy, at 92.2 ± 2.4.

Table 5 presents the results obtained on the Cloud-ImVN 1.0 dataset while using the histogram selection method for different parameters: color space, features used, and (*R*, *P*) values. When applying the ICS method, the results of cloud/sky image classification change significantly: the results are higher and the number of features is greatly reduced. Considering Table 5, the LBP (3, 12) feature in the RGB color system achieved the highest ACC of 89.3 ± 2.2 with one histogram selected, whereas, without ICS, the LBP (5, 12) feature in the RGB color system achieved the highest ACC of 81.9 ± 3.3.


**Table 4.** The best results obtained for Cloud-ImVN 1.0 dataset without using histogram selection method. The best results for each color space, each feature, and the selected parameters (*R*, *P*).

**Table 5.** The best results obtained for Cloud-ImVN 1.0 dataset while using histogram selection method. The best results for each color space, each feature, and the selected parameters (*R*, *P*).


Figure 7 presents the selected results of the highest accuracy obtained from three types of features (LBP, LTP, LBP+LTP) in the two scenarios: without and with the histogram selection method.

Table 6 presents the comparison of the results obtained on the Cloud-ImVN 1.0 dataset with previous studies. Thus, the highest results are obtained on this dataset for three types of features using the ICS technique, as follows:


**Figure 7.** Highest results obtained on the Cloud-ImVN 1.0 dataset without (first row) and with (second row) histogram selection method.


**Table 6.** Comparison of experimental results on the Cloud-ImVN 1.0 dataset.

Many image processing algorithms classify images mainly based on the RGB color space. However, for the classification of cloud images, the RGB color space carries many disadvantages and is not a good candidate. A color space with high dichroism and high luminance components is a good candidate, because the luminance in cloudy areas is higher than elsewhere. Similar to the results obtained on the SWIMCAT dataset, for Cloud-ImVN 1.0, the LTP descriptor gives better results than the other LBP variants.

#### **6. Conclusions**

The existing cloud features are very useful in determining color space and texture features to classify cloud types. In order to be able to classify sky/clouds, it is essential to distinguish between two types of pixels (sky and clouds); a suitable color space can facilitate this classification. High-dichroism color systems are good candidates for cloud/sky image classification. In addition, color systems with high luminance components are also good candidates because the luminance in cloudy areas is higher than in cloudless areas.

This paper presents and systematically analyzes various features developed for the task of cloud/sky image classification. We found that the LBP and LTP feature extraction techniques generalized well to this objective. We integrate the color space and texture structure with the LBP feature and LTP feature effectively to obtain higher classification accuracy. Integrating color features and texture into cloud/sky image classification also enhances the performance. In the experiment, this work integrates color features to increase the efficiency of the feature extraction process, and using the ICS technique to select potential histograms allows us to enhance the performance clearly, with fewer features. However, the parameter value *t*, obtained using exhaustive techniques for the RGB color system, may affect the results of other color spaces.

By exploiting different aspects of sky/cloud images through ground-based sky/cloud images, the proposed method has solved the basic problems of color system processing and feature selection. Below are several future directions of the topic:


**Author Contributions:** K.T.-T.: Data curation, Software, Writing—original draft, Investigation, Formal analysis; H.D.T.H.: Data curation, Investigation, Conceptualization, Methodology; V.T.H.: Conceptualization, Methodology, Validation, Writing—review & editing, Supervision. All authors have read and agreed to the published version of the manuscript.

**Funding:** This research is funded by Ho Chi Minh City Open University (HCMCOU) and the Ministry of Education and Training (Vietnam) under grant number B2021-MBS-07.

**Data Availability Statement:** The data used to support the findings of this study are available from the corresponding author upon request.

**Conflicts of Interest:** The authors declare that they have no conflict of interest.

#### **References**


### *Article* **A Robust Framework for Object Detection in a Traffic Surveillance System**

**Malik Javed Akhtar 1, Rabbia Mahum 1,\*, Faisal Shafique Butt 2, Rashid Amin 3, Ahmed M. El-Sherbeeny 4, Seongkwan Mark Lee 5,\* and Sarang Shaikh 6**


**Abstract:** Object recognition is the technique of specifying the location of various objects in images or videos. There exist numerous algorithms for the recognition of objects, such as R-CNN, Fast R-CNN, Faster R-CNN, HOG, R-FCN, SSD, SSP-net, SVM, CNN, YOLO, etc., based on machine learning and deep learning techniques. Although these models have been employed for various types of object detection applications, tiny object detection still faces the challenge of low precision. It is essential to develop a lightweight and robust model for object detection that can detect tiny objects with high precision. In this study, we suggest an enhanced YOLOv2 (You Only Look Once version 2) algorithm for object detection, i.e., vehicle detection and recognition in surveillance videos. We modified the base network of YOLOv2 by reducing the number of parameters and replacing it with DenseNet. We employed the DenseNet-201 technique for feature extraction in our improved model, which extracts the most representative features from the images. Moreover, our proposed model is more compact due to the dense architecture of the base network. We utilized DenseNet-201 as the base network due to the direct connections among all layers, which help to extract valuable information from the very first layer and pass it to the final layer. Datasets gathered from Kaggle and KITTI were used for training the proposed model, and we cross-validated the performance using the MS COCO and Pascal VOC datasets. To assess the efficacy of the proposed model, we performed extensive experimentation, which demonstrates that our algorithm beats existing vehicle detection approaches, with an average precision of 97.51%.

**Keywords:** CNN (convolution neural network); YOLO (You Only Look Once); intersection over union (IoU); mAP (mean average precision)

#### **1. Introduction**

Multimedia has deeply penetrated many realms of life in the present generation of rapidly emerging technologies. In daily life, many people utilize electronic devices for various video applications, e.g., animated videos, activity recognition, movies, etc. Cameras have been increasingly utilized for surveillance over the last century. A surveillance system is a systematic method of monitoring behavior, actions, or other changing information. This results in massive data accumulation in the form of images and video clips, and it can be a tiring task to extract relevant information from this multimedia content. The three essential phases in every surveillance system are object detection, tracking, and recognition.

**Citation:** Akhtar, M.J.; Mahum, R.; Butt, F.S.; Amin, R.; El-Sherbeeny, A.M.; Lee, S.M.; Shaikh, S. A Robust Framework for Object Detection in a Traffic Surveillance System. *Electronics* **2022**, *11*, 3425. https:// doi.org/10.3390/electronics11213425

Academic Editor: Byung Cheol Song

Received: 29 September 2022 Accepted: 19 October 2022 Published: 22 October 2022


Object identification and localization is the procedure of determining the location of objects in images and videos captured by surveillance cameras. The objects can be classified and detected in real time using various computer vision techniques [1]. Furthermore, the objects are specified by employing rectangular bounding boxes. There exist various applications of object detection in industry and scientific research based on ML (Machine Learning) and DL (Deep Learning), for example, face detection [2], text detection [3], pedestrian detection [4], logo recognition [5], object identification in video [6], vehicle detection [7], disease detection [8], medical imaging [9] and many more. Moreover, vehicle detection in autonomous driving systems still faces numerous challenges, such as the inability of algorithms to detect faraway vehicles due to their small size, blurry conditions, night views, and rainy seasons, because of the low localization precision of present studies.

Traditional techniques based on ML (Machine Learning) have mostly been used for object detection: the object area is computed first using a sliding window, then various features are mined, and finally, traditional classifiers such as SVM (Support Vector Machine) are used to classify the objects. Although the results are satisfactory, these methods are incapable of accurately detecting and classifying objects because they ignore the underlying deep features. One researcher used the Haar feature descriptor to extract the linear, center, diagonal, and edge features before classifying the objects using a Support Vector Machine. Moreover, employing a hand-crafted feature descriptor requires human effort. As a result, researchers are concentrating their efforts on deep learning algorithms like CNN, R-CNN, and YOLO, which have greatly improved object detection performance.

Existing deep learning methods for object detection are dedicated to simplifying the network and speeding up the detection process. Their results are heavily reliant on the accuracy of the generated proposals, for example, when a researcher employs a Faster R-CNN approach to detect and count vehicles. Although this technique can accelerate the detection process, it has lower detection accuracy than other traditional methods. Most importantly, these methods are incapable of detecting distant vehicles. We propose an improved YOLOv2 algorithm with DenseNet-201 as the base network in video surveillance systems to detect far-away vehicles that appear in small sizes.

The rest of the paper is structured in the following manner. Motivation is covered in Section 2, while the study's related work is presented in Section 3. The problem statement is discussed in Section 4, the methodology is discussed in Section 5 and the proposed approach is presented in Section 6. Section 7 presents the experiments, and Section 8 concludes the findings.

#### **2. Motivation**

The motivation of the proposed model is to investigate the issue of object detection (i.e., vehicle detection) in videos obtained from surveillance cameras employing the improved YOLOv2 technique. Moreover, the vehicles have various sizes in videos, and as a result, conventional approaches have a hard time detecting vehicles precisely. An improved YOLOv2 algorithm is developed in this paper to cope with this challenge.

The key advantages of doing this investigation are as below:


#### **3. Literature Review**

Employing computer vision techniques to accurately identify on-the-road vehicles is a thought-provoking issue that has been a hot research topic for the past two decades [10]. Surveillance videos of traffic have ever-changing backgrounds due to lighting effects. As a result, the exact size and location of vehicles are difficult to capture due to the simultaneous movements of the vehicles on the road.

Recently, DL (Deep Learning) models have piqued the interest of numerous scholars, and a plethora of deep learning object detection algorithms have been introduced. In comparison to traditional methods, manual feature extraction in machine learning object detection algorithms needs experts with years of experience in the associated domain, whereas deep learning models necessitate a large amount of data to automatically acquire the characteristics that can imitate differences in data, making them more demonstrative. At the same time, the feature extraction procedure in the CNN layers of visual recognition systems is similar to the human visual mechanism. Deep learning-based detection algorithms have achieved reasonable real-time performance compared to traditional algorithms in recent years, aided by continuously increasing data volumes and constant updates of device hardware, and have attained recognition worldwide. Due to its better real-time accuracy and performance in the academic field, the deep learning vehicle detection algorithm has gradually developed in two directions: one focused on accuracy and the other on complexity.

For more than a decade, researchers have studied vehicle detection and recognition extensively in the literature. Previously, numerous handcrafted features were extracted for vehicle detection, which requires manual intervention. Haar [11], HOG [12], and LBP [13] were the three most commonly used feature descriptors. The classification framework was evaluated for vehicle detection and found to be effective, i.e., a large number of vehicles were detected. Additionally, the HOG feature in conjunction with the Support Vector Machine classifier is commonly used with great success in vehicle detection. The mentioned features and classifiers have broad applications in vehicle detection tasks; statistical techniques using vertical and horizontal edge features were introduced for vehicle detection, and vehicles were tracked at night by locating their tail lights. Table 1 presents recent work on vehicle detection and classification.

**Table 1.** Summary of existing techniques for vehicle detection.



#### **4. Problem Statement**

A vast number of techniques have been developed in the past era that have attracted much-needed attention among researchers, as a result of the developments and improvements in the domain of CV and its vast surveillance applications. As a result of advances in computer vision, the detection of objects in images is becoming increasingly important, because it can benefit a vast number of applications, including human detection, face detection, vehicle detection, hammer detection, gun detection, knife detection, and many others. With the advancement of technology and the increased number of vehicles on the road around the world, the traffic system has become increasingly reliant on automatic vehicle recognition systems. Consequently, the vehicle detection and recognition system must perform well in terms of time complexity and accuracy.

The main aim of this research paper is to investigate the challenge of vehicle detection in surveillance videos using deep learning. Due to the low quality of surveillance images, lack of background information, low lighting, and distance, it is quite difficult to detect vehicles.

Hence, we deduce a technique, i.e., Improved YOLOv2, that distinguishes between vehicles and non-vehicles in surveillance videos based on deep learning. From various previous research studies, it is deduced that current surveillance systems cannot efficiently tell us which kind of model should be used for which kind of images. Surveillance systems usually fail because they rely heavily on human operators, who have physical restrictions in the form of lethargy or loss of attentiveness from monitoring several screens for long periods. These restrictions can be eased by enhancing surveillance systems to automatically detect the various objects present in an image. These proficiencies can then enable surveillance systems to detect objects in various images. To have such proficiencies, we need a mechanism that can not only capture images but also account for human emotion and behavior, i.e., a method like object detection has to be introduced that can detect the difference between different objects in an image.

Although various traditional methods are used to detect vehicles in surveillance videos, the main problem is that they are not accurate enough and are also very expensive. Nowadays, researchers focus on using deep learning methods to recognize vehicles. In this research, the Improved YOLOv2 algorithm, a type of DL (Deep Learning) technology, was utilized to detect various vehicles (e.g., car, bus, truck) observed in surveillance cameras.

#### **5. Materials and Methods**

Conventionally, deep learning contains numerous layers of nonlinear processing modules to obtain features. The layers are cascaded, each taking the output of the previous layer as input. Many researchers have attempted to make networks deeper and larger to investigate the potential of deep learning. However, this faces the challenge of the exploding or vanishing gradient problem (VGP). As a result, researchers have built multiple different deep learning structures.

A range of deep learning structures has been proposed, such as AlexNet [25], ResNet [26], DenseNet [26], GoogLeNet [27], and VGGNet [28]. The 2012 ImageNet Large Scale Visual Recognition Competition (ILSVRC) winner, AlexNet, is comparable to LeNet and has ReLU non-linearity and max pooling. In the 2014 ILSVRC, VGGNet came in second place, with deeper networks (19 layers) than AlexNet. To extract sparse correlating features in feature map stacks, GoogLeNet, the ILSVRC 2014 winner, uses 1 × 1 convolutions to reduce the dimensions of feature maps before the expensive convolutions, as well as parallel routes with variable receptive field sizes. ResNet, the ILSVRC 2015 winner, proposes a 152-layer network with skipped or shortcut connections that bypass a minimum of 2 layers. In DenseNet, each layer feeds forward the output of all preceding layers, providing N(N + 1)/2 connections in N layers, whereas traditional convolutional networks with N layers deliver only N connections. DenseNet is capable of performing better than the cutting-edge ResNet structure on the ImageNet classification test.

In this research, we proposed DenseNet201 as the base network in YOLOv2 for vehicle detection (e.g., Car, Bus, Truck) because of its remarkable performance. However, before going into detail about DenseNet201, the traditional convolution neural network (CNN) will be discussed first, followed by the distinctions between DenseNet and CNN.

#### *5.1. Convolution Neural Network (CNN)*

A standard convolution neural network (CNN) normally includes (i) a Convolution (CONV) layer, (ii) a Rectified Linear Unit (ReLU) layer, (iii) a pooling (POOL) layer, (iv) a Fully Connected (FC) layer, and (v) a Softmax layer [29]. The functions of these layers are as follows, with the convolution layer as the fundamental component of a CNN. Convolving the input with various kernels produces the feature maps; this operation is illustrated in Figure 1.

**Figure 1.** CNN Convolution operation (\* represents multiplication).

Succeeding the convolution layer, there exists the ReLU nonlinear activation function, which is used to extract nonlinear features. The goal of the ReLU layer is to impart nonlinearity to the network. It is mathematically defined as Equation (1).

$$relu(v) = \max(v, 0) \tag{1}$$

The pooling layer works by spatially resizing the feature maps to reduce the parameters, memory footprint, and network computation time. The pooling function is applied to each feature map; the most common pooling approaches are average pooling, as shown in Equation (2), and max pooling, as presented in Equation (3).

$$a_k = \frac{1}{|R_k|} \sum_{j \in R_k} M_j \tag{2}$$

$$a_k = \max_{j \in R_k} (M_j) \tag{3}$$

$M_j$ denotes an element of the pooling region, while $R_k$ represents the set of elements in the $k$th pooling region. The confidence scores are calculated through fully connected layers and stored in a 1 × 1 × $c$ volume, where each element represents a class score and $c$ refers to the number of categories.
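The layer operations of Equations (1)–(3) are simple enough to state directly in code; the sketch below (with illustrative helper names) applies ReLU and then pools non-overlapping 2 × 2 regions:

```python
# ReLU (Eq. (1)), average pooling (Eq. (2)) and max pooling (Eq. (3)) in numpy.
import numpy as np

def relu(v: np.ndarray) -> np.ndarray:
    return np.maximum(v, 0)  # Equation (1)

def pool2x2(fmap: np.ndarray, mode: str = "max") -> np.ndarray:
    h, w = fmap.shape
    # Group the map into non-overlapping 2x2 pooling regions R_k.
    blocks = fmap[: h - h % 2, : w - w % 2].reshape(h // 2, 2, w // 2, 2)
    if mode == "avg":
        return blocks.mean(axis=(1, 3))  # Equation (2)
    return blocks.max(axis=(1, 3))       # Equation (3)

fmap = np.array([[1., -2., 3., 0.],
                 [4., 5., -6., 7.],
                 [0., 1., 2., 3.],
                 [9., 8., 7., 6.]])
print(pool2x2(relu(fmap), "max"))
print(pool2x2(relu(fmap), "avg"))
```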

An individual neuron in the FC layer is linked to all neurons in the previous layer. In a typical CNN, all the layers are progressively connected, as shown in Equation (4).

$$m_r = F_r(m_{r-1}) \tag{4}$$

However, when the network grows deeper and larger, the gradient may explode or vanish. As a result, researchers have offered various network architectures to solve the problem. ResNet, for example, changed this behavior by using a shortcut connection, as shown in Equation (5).

$$m_r = F_r(m_{r-1}) + m_{r-1} \tag{5}$$

Rather than summing the layer's output feature maps with the incoming feature maps, DenseNet has direct connections among all layers, and each current layer takes input from all previous layers. The expression is rewritten as Equation (6).

$$m_r = F_r([m_0, m_1, m_2, \dots, m_{r-1}]) \tag{6}$$

where $r$ denotes the layer index, $F_r$ denotes a non-linear function and $m_r$ denotes the $r$th layer's output.
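The difference between Equations (4)–(6) is easiest to see on toy feature maps; the PyTorch sketch below (a sketch only, not the authors' network) contrasts a ResNet shortcut with DenseNet-style concatenation:

```python
# ResNet shortcut (Eq. (5)) versus dense connectivity (Eq. (6)).
import torch
import torch.nn as nn

c = 8  # channels produced by each layer
x0 = torch.randn(1, c, 16, 16)

# Equation (5): a ResNet block adds the input back: m_r = F_r(m_{r-1}) + m_{r-1}.
g = nn.Conv2d(c, c, kernel_size=3, padding=1)
resnet_out = g(x0) + x0

# Equation (6): layer r consumes the concatenation [m_0, m_1, ..., m_{r-1}],
# so its input width grows by c channels per preceding layer.
dense_layers = nn.ModuleList(
    [nn.Conv2d(c * (r + 1), c, kernel_size=3, padding=1) for r in range(3)]
)
features = [x0]
for layer in dense_layers:
    m_r = layer(torch.cat(features, dim=1))  # all preceding outputs as input
    features.append(m_r)
print(resnet_out.shape, features[-1].shape)  # both: torch.Size([1, 8, 16, 16])
```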

#### *5.2. Densenet-201*

Due to the capacity for feature reuse by succeeding layers, DenseNet-201 employs a condensed network, allowing a tremendously parameter-efficient model, which increases the diversity of the succeeding layers' input and enhances performance. DenseNet201 has performed admirably on a variety of datasets, including ImageNet [30] and CIFAR-100 [31]. Direct connections from all preceding layers to all future layers are introduced to boost connectivity in the DenseNet201 architecture, as shown in Figure 2.

**Figure 2.** Direct connections in DenseNet201.

The advantages of DenseNet201, which includes 201 convolutional layers, are a reduced vanishing-gradient problem, excellent feature propagation, feature reusability, and a smaller number of parameters.

Let us assume that an image $m_0$ is fed into a neural network with $R$ layers and non-linear transformations $F_r(\cdot)$, where $r$ is the index of the layer. ResNet's traditional skip connections are included in the feed-forward network, bypassing the non-linear transformation with an identity function, as shown in Equation (7).

$$m_r = F_r(m_{r-1}) + m_{r-1} \tag{7}$$

One advantage of ResNet is that the gradient can flow directly from the initial layers to the final layer through the identity function. In contrast, direct end-to-end connections are used in the dense network to maximize the amount of information in each layer. The $r$th layer receives all of the previous layers' information, as shown in Equation (8).

$$m_r = F_r([m_0, m_1, \dots, m_{r-1}]) \tag{8}$$

In DenseNet, down-sampling takes place between Dense Blocks, in Transition layers; each transition layer contains a 1 × 1 convolutional layer (CONV) and an average pooling layer with BN (batch normalization). The outputs of the transition layer ultimately feed the following dense layers. We transformed the entire average-pooling layer into a 2 × 2 max pool layer for network utility. BN (batch normalization) is performed before each of the convolutional layers, making the model less complex. The hyperparameter k denotes the network's growth rate, making DenseNet capable of producing cutting-edge results. Pooling layers are eliminated, and the proposed detection layers are fully integrated and connected to the classification layers for detection. Even deeper network designs than the 201-layer network exist, such as DenseNet-264 [32]. Because we do not want to cast an excessively wide network, the 201-layer structure is suitable for detecting vehicles. Due to its manner of reflecting feature maps as a global mechanism of the network, DenseNet201 performs well even with a smaller growth rate. Figure 3 exhibits the DenseNet201 architecture.

DenseNet-201 is based on the transfer learning concept, having 201 depth layers and 20 million parameters trained using more than one million images obtained from the ImageNet dataset.

#### *5.3. YOLO (You Only Look Once) Theory*

YOLO, an abbreviation of "You Only Look Once" [33], is an advanced one-stage algorithm for identifying objects in real time. The YOLO technique uses a CNN, and object recognition is performed as a regression problem. The CNN is employed to predict various bounding boxes and class probabilities simultaneously. In comparison to Faster R-CNN, YOLO obtains location and category predictions without a region proposal network (RPN).

#### *5.4. Working Principle of YOLO*

At the start, the network splits the input image into an R × R grid. When the central point of an object lies in a grid cell, that grid cell is responsible for detecting that object. Each grid cell predicts B bounding boxes and confidence scores for those boxes. Prob(*Object*) indicates whether a required object falls into the cell. The confidence C in YOLOv2 is defined in Equation (9).

$$C(\textit{Confidence}) = \text{Prob}(\textit{Object}) \times IoU\_{pred}^{truth} \tag{9}$$

Here, each grid cell predicts C conditional class probabilities Prob(*Class* | *Object*), and Prob(*Object*) is the probability that the bounding box contains a vehicle object. If an object is present, Prob(*Object*) equals 1; otherwise it equals 0.

A bounding box has five components (x0, y0, wd, ht, confidence). The confidence score reflects how certain the model is that the predicted box contains an object and how accurate the predicted box is. The (x0, y0) coordinates refer to the center of the box relative to the bounds of the grid cell, and these values lie between 0 and 1. The (wd, ht) dimensions are the width and height of the bounding box relative to the whole image and are also normalized to [0, 1]. The category probability is calculated as shown in Equation (10).

$$\text{Prob}(\textit{Class}\_i \mid \textit{Object}) \times \text{Prob}(\textit{Object}) \times IoU\_{pred}^{truth} = \text{Prob}(\textit{Class}\_i) \times IoU\_{pred}^{truth} \tag{10}$$

The confidence score is zero if no object lies in the cell. Otherwise, the confidence score should equal the intersection over union (*IoU*) of the actual and predicted boxes. Each grid cell makes B such predictions, so there are R × R × B × 5 outputs connected to bounding box predictions. The last layer of the pre-trained CNN model predicts a tensor of size R × R × (B × 5 + C), where C is the number of classes.
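As a concrete check of these shape formulas, the following Python sketch computes the output sizes for illustrative values of R, B, and C; the actual values depend on the configuration.

```python
# Numeric sketch of the YOLO output layout described above (values are illustrative).
R, B, C = 13, 5, 17                 # grid size, boxes per cell, number of classes

per_cell = B * 5 + C                # 5 = (x0, y0, wd, ht, confidence) per box
total_box_outputs = R * R * B * 5   # outputs tied to bounding box predictions
output_tensor_shape = (R, R, per_cell)

print(per_cell, total_box_outputs, output_tensor_shape)
# 42 4225 (13, 13, 42)
```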

When multiple objects fall in a single grid cell, we resolve the problem with anchor boxes. Anchor boxes enable YOLOv2 to identify several objects in a single grid cell: one extra dimension is added to the output labels by predefining several anchor boxes, and one object is then assigned to each anchor box. Figure 4 illustrates the framework of the YOLO methodology.

**Figure 4.** The framework of the YOLO methodology.

#### *5.5. Loss Function*

The loss is split into two parts: a localization loss for the predicted bounding box offsets and a classification loss for the predicted conditional class probabilities. Both parts are computed with a sum of squared errors. Two scale parameters control how much the loss from bounding box coordinate predictions is increased (*λcoord*) and how much the confidence predictions for boxes containing no objects are down-weighted (*λnoobj*); this weighting balances the different loss types. Generally, *λcoord* is set to 5 and *λnoobj* to 0.5. Otherwise, the losses would contribute very unevenly to the overall loss, rendering some of them ineffective for network training. The loss equation is shown in Equation (11):

$$\begin{split} &\lambda\_{coord} \sum\_{i=0}^{S^2} \sum\_{j=0}^{B} \mathbb{1}\_{ij}^{obj} \left[ \left( x\_i - \hat{x}\_i \right)^2 + \left( y\_i - \hat{y}\_i \right)^2 \right] + \lambda\_{coord} \sum\_{i=0}^{S^2} \sum\_{j=0}^{B} \mathbb{1}\_{ij}^{obj} \left[ \left( \sqrt{w\_i} - \sqrt{\hat{w}\_i} \right)^2 + \left( \sqrt{h\_i} - \sqrt{\hat{h}\_i} \right)^2 \right] \\ &+ \sum\_{i=0}^{S^2} \sum\_{j=0}^{B} \mathbb{1}\_{ij}^{obj} \left( C\_i - \hat{C}\_i \right)^2 + \lambda\_{noobj} \sum\_{i=0}^{S^2} \sum\_{j=0}^{B} \mathbb{1}\_{ij}^{noobj} \left( C\_i - \hat{C}\_i \right)^2 + \sum\_{i=0}^{S^2} \mathbb{1}\_{i}^{obj} \sum\_{c \in classes} \left( p\_i(c) - \hat{p}\_i(c) \right)^2 \end{split} \tag{11}$$

where *xi* and *yi* are the center coordinates, *wi* and *hi* the width and height of the box, *Ci* the confidence of the box, and *pi*(*c*) the class probability for the box of the i-th grid cell; the corresponding predictions are *x̂i*, *ŷi*, *ŵi*, *ĥi*, *Ĉi*, and *p̂i*(*c*). *λcoord* is the weight of the coordinate loss, and *λnoobj* is the weight of the confidence loss for bounding boxes without objects. *S2* indicates the S × S grid cells, B is the number of boxes per cell, and the indicator $\mathbb{1}\_{ij}^{obj}$ marks whether an object falls in the j-th bounding box of the i-th grid cell. In Equation (11),

$$\lambda\_{coord} \sum\_{i=0}^{S^2} \sum\_{j=0}^{B} \mathbb{1}\_{ij}^{obj} \left[ \left( x\_i - \hat{x}\_i \right)^2 + \left( y\_i - \hat{y}\_i \right)^2 \right]$$

is responsible for calculating the coordinate loss,

$$\lambda\_{coord} \sum\_{i=0}^{S^2} \sum\_{j=0}^{B} \mathbb{1}\_{ij}^{obj} \left[ \left( \sqrt{w\_i} - \sqrt{\hat{w}\_i} \right)^2 + \left( \sqrt{h\_i} - \sqrt{\hat{h}\_i} \right)^2 \right]$$

is responsible for computing the bounding box size loss,

$$\sum\_{i=0}^{S^2} \sum\_{j=0}^{B} \mathbb{1}\_{ij}^{obj} \left( C\_i - \hat{C}\_i \right)^2$$

is responsible for determining the bounding box confidence loss with objects,

$$\lambda\_{noobj} \sum\_{i=0}^{S^2} \sum\_{j=0}^{B} \mathbb{1}\_{ij}^{noobj} \left( C\_i - \hat{C}\_i \right)^2$$

will calculate the bounding box confidence loss without objects, and

$$\sum\_{i=0}^{S^2} \mathbb{1}\_{i}^{obj} \sum\_{c \in classes} \left( p\_i(c) - \hat{p}\_i(c) \right)^2$$

is responsible for calculating the class loss.
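The following NumPy sketch transcribes Equation (11) term by term. The tensor layout and masks are our own assumptions, not the authors' code, and for simplicity the class term is applied per responsible box here rather than per cell.

```python
# A minimal NumPy sketch of Equation (11); layout and masks are illustrative.
import numpy as np

def yolo_loss(pred, true, obj_mask, lambda_coord=5.0, lambda_noobj=0.5):
    """pred, true: (S*S, B, 5 + C) arrays laid out as [x, y, w, h, conf, classes...];
    obj_mask: (S*S, B) booleans marking the box responsible for an object.
    Widths and heights are assumed non-negative so the square roots are defined."""
    noobj_mask = ~obj_mask
    coord = lambda_coord * np.sum(
        obj_mask[..., None] * (pred[..., 0:2] - true[..., 0:2]) ** 2)
    size = lambda_coord * np.sum(
        obj_mask[..., None] * (np.sqrt(pred[..., 2:4]) - np.sqrt(true[..., 2:4])) ** 2)
    conf_obj = np.sum(obj_mask * (pred[..., 4] - true[..., 4]) ** 2)
    conf_noobj = lambda_noobj * np.sum(noobj_mask * (pred[..., 4] - true[..., 4]) ** 2)
    # Simplification: the class term is applied per responsible box, not per cell.
    cls = np.sum(obj_mask[..., None] * (pred[..., 5:] - true[..., 5:]) ** 2)
    return coord + size + conf_obj + conf_noobj + cls

S, B, C = 7, 2, 17
pred = np.random.rand(S * S, B, 5 + C)
true = np.random.rand(S * S, B, 5 + C)
mask = np.zeros((S * S, B), dtype=bool)
mask[0, 0] = True
print(yolo_loss(pred, true, mask))
```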

#### **6. Proposed Solution**

Our proposed network comprises (i) the input layer, (ii) a feature extraction network, and (iii) a detection network. The first stage resizes the input image to 224 × 224 pixels, after which the scaled data are passed to DenseNet-201 for feature extraction. As previously indicated, we replaced the YOLOv2 baseline network Darknet-19 with DenseNet-201 and its associated procedures; we now turn to the network's detection adjustments. The complete structure of our proposed system is depicted in Figure 5.

**Figure 5.** The overall structure of the proposed model.

#### **7. Experimental Evaluation**

#### *7.1. Dataset*

The dataset is the main foundation for estimating any model's performance, and improving the recognition rate of the proposed model requires sufficient data for vehicle detection training. More training data can enhance the recognition rate, generalization, and robustness of the model, whereas overfitting may occur with insufficient data. We used two datasets, the Kaggle vehicle dataset [34] and the KITTI dataset [35], for training and testing the model. Moreover, the MS COCO [36] and Pascal VOC [37] datasets were used to cross-validate the proposed model.

#### 7.1.1. Kaggle Vehicle Dataset

The vehicle dataset available on Kaggle is used for the experiments. The dataset is split into two parts, a train set and a test set: 22,852 training images and 5193 test images, for a total of 28,045 images. There are 17 classes (Ambulance, Car, Cart, Boat, Bus, Caterpillar, Helicopter, Barge, Bicycle, Segway, Limousine, Motorcycle, Tank, Taxi, Snowmobile, Truck, and Van). The class-wise distribution of the Kaggle dataset is presented in Table 2.

**Table 2.** Comprehensive overview of the Kaggle dataset.


#### 7.1.2. KITTI

The freely available KITTI dataset has 80,256 labeled objects in numerous images. We utilized 7481 training images and 2000 test images. All images are colored and saved as "png" files. There are 80 classes (Car, Bus, Truck, Train, Motorcycle, etc.). The class-wise distribution of the KITTI dataset is described in Table 3.



#### 7.1.3. Pascal VOC

Pascal VOC contains 20 different classes (Vehicle: train, bicycle, boat, bus, airplane, etc.), and 9963 images consisting of 24,640 annotated objects. For vehicle detection, we utilized various class samples from the Pascal VOC dataset. More precisely, we employed 800 images in total to evaluate our proposed classifier for the detection of vehicles.

#### 7.1.4. COCO

Common Objects in Context (COCO) is one of the most famous open-source datasets for object identification and segmentation. Microsoft sponsors the COCO dataset, which contains over 300,000 images and 90 object types. In recent years, it has become a standard benchmark for image-semantics understanding. We employed only 500 images exhibiting various vehicles from the COCO dataset. Various training samples are presented in Figure 6.

**Figure 6.** Various samples for training.

#### *7.2. Metrics*

To analyze the performance of the proposed system, we utilized *Accuracy* [9], Intersection over Union (*IoU*) [38], and mean Average Precision (*mAP*) [39]. *Accuracy* relies upon true positives (*TP*) [40], false positives (*FP*) [41], true negatives (*TN*), and false negatives (*FN*); it indicates the proportion of images correctly classified by the proposed system, as given in Equation (12).

$$Accuracy = \frac{TP + TN}{TP + TN + FP + FN} \tag{12}$$

We employed *mAP*, i.e., the mean of the average precision over all test images, to analyze the performance of our proposed detector. Equation (13) is shown below, where *Q* denotes the total number of test images.

$$mAP = \sum\_{i=1}^{Q} \frac{AP\ (q\_i)}{Q} \tag{13}$$
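For reference, a minimal Python sketch of these metrics follows; the box format and the per-query AP values are assumptions for illustration, not the authors' evaluation code.

```python
# Illustrative helpers for Equations (12)-(13) and the IoU metric.
import numpy as np

def iou(box_a, box_b):
    """Boxes as (x1, y1, x2, y2); returns intersection over union."""
    x1, y1 = max(box_a[0], box_b[0]), max(box_a[1], box_b[1])
    x2, y2 = min(box_a[2], box_b[2]), min(box_a[3], box_b[3])
    inter = max(0.0, x2 - x1) * max(0.0, y2 - y1)
    area_a = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1])
    area_b = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1])
    return inter / (area_a + area_b - inter)

def accuracy(tp, tn, fp, fn):
    """Equation (12)."""
    return (tp + tn) / (tp + tn + fp + fn)

def mean_average_precision(ap_per_query):
    """Equation (13): average of the per-image AP values."""
    return float(np.mean(ap_per_query))

print(iou((0, 0, 2, 2), (1, 1, 3, 3)))   # 1/7 ~= 0.1429
print(accuracy(tp=90, tn=85, fp=10, fn=15))
```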

#### *7.3. Environment*

We performed the experiments on an NVIDIA GeForce RTX 30-series GPU with 4 GB of memory. The operating system was Windows 10 with 16 GB of RAM, and the experiments were run in Matlab 2021a. We trained our classifier on various categories of vehicles with 100 epochs and a learning rate of 0.0001.

The primary goal of this paper is an approach that detects vehicles accurately. The various experiments provide insight into the method's robustness and its capacity to run in real-time scenarios. To achieve a reliable vehicle detector, we proposed an improved YOLOv2 with DenseNet-201 as the base network, employing a transfer learning (TL) mechanism; the choice is motivated by DenseNet's outstanding performance on ImageNet classification tasks. Figure 7 shows the detection results of the proposed model on the Kaggle vehicle dataset.

**Figure 7.** Results of the proposed model's detection by using the Kaggle Vehicle dataset.

#### *7.4. Class-Wise Performance*

The average precision (AP) for each vehicle class was used to measure recognition performance. The mean Average Precision (mAP) depicts the average recognition performance, whereas intersection over union (IoU) indicates the average localization performance; both are significant measures for evaluating an object detection model. Table 4 shows that the proposed improved YOLOv2 with DenseNet-201 has an mAP of 97.51% and an IoU of 97.06%. According to our findings, it worked well for both single and multiple vehicle identification. In our proposed method, the mAP for Taxi and Van reaches 98.9%, while the remaining classes range from 94.5% to 98.8%. In terms of localization and recognition accuracy, our proposed technique surpassed the others.


**Table 4.** Class-wise performance over Kaggle and KITTI dataset.

#### *7.5. Cross-Validation*

The Pascal VOC and MS COCO datasets were employed for cross-validation of the proposed model, using various vehicle samples from each. With DenseNet-201 as the base network, we determined the mAP of the improved YOLOv2 over all 20 classes of the PASCAL VOC dataset and achieved 81% mAP, approximately 2 percentage points higher than YOLOv2. Our proposed model thus achieved promising results and outperformed other detectors, as shown in Table 5. Training took around one hour for 1000 iterations. Fast RCNN [42] attained 70% mAP, YOLOv2 [43] achieved 76.8%, and Faster RCNN with ResNet [43] achieved 76.4% mAP. The highest mAP, 81%, was attained by our proposed model, whereas the lowest, 63.4%, was attained by YOLO [33]. Moreover, SSD300 [44] and SSD500 [44] achieved 74.3% and 76.8% mAP, respectively, while Faster RCNN with VGG-16 [45] and Improved YOLOv3-Net [46] achieved 73.2% and 77.4%. We conclude that our proposed algorithm surpasses the existing models thanks to the improved DenseNet-201 base network: it retrieves the most relevant features, and the dense connections keep the flow of information accurate up to the last layer, making the model robust for accurate detection. The comparison plot is depicted in Figure 8.


**Table 5.** Comparison of different Network Models using PASCAL VOC 2007.

**Figure 8.** Comparison of mAP between the different networks with the proposed network.

#### *7.6. Comparison with Existing Models*

To evaluate the performance of our proposed model, we conducted two separate experiments. In the first, we trained our detector for vehicle detection on Pascal VOC 2007 and compared the proposed technique with the predominant techniques on that dataset, using only three class samples: Bus, Car, and Truck. Training used a batch size of 64, a learning rate of 0.001, and an IoU threshold of 0.50. Four network models were evaluated: Improved YOLOv2, YOLOv3, YOLOv3-Net, and our proposed Improved YOLOv2-Net-201. The statistics are shown in Table 6. The best mAP of 82.7% was achieved by Improved YOLOv2-Net-201, owing to the dense architecture proposed as the base network of YOLOv2. Each layer receives data from all preceding layers and passes it to all subsequent layers; in particular, the classification layer connects directly to previous layers, extracting the most valuable features for vehicle detection. Our proposed model is thus capable of significant vehicle detection and outperforms the existing techniques. The comparison plot in Figure 9 exhibits its better performance against the existing models.



**Figure 9.** Comparison Graph with existing models for Car, Bus, and Truck samples using the PASCAL VOC dataset.

In the second phase, the COCO dataset was used to train the detector on vehicle classes such as bus, car, and truck. The statistics are shown in Table 7. The best mAP of 75.1% was achieved by our proposed model, and the lowest, 60%, by the original YOLOv2; YOLOv3 and Improved YOLOv3 achieved 66.2% and 71.2% mAP, respectively. Our proposed model thus attained the best performance among the existing models and identifies vehicles more effectively than the predominant ones. Moreover, being based on DenseNet, it overcomes the vanishing-gradient problem and is more compact than ResNet. The comparison graph for performance on the COCO dataset is shown in Figure 10.


**Table 7.** Comparison of Proposed Network with existing models on the COCO dataset.

**Figure 10.** Comparison Graph with existing models for Car, Bus, and Truck samples using the COCO dataset.

#### **8. Conclusions**

In this study, an innovative and robust system for vehicle detection is proposed using a deep neural network based on YOLOv2 (You Only Look Once). Our technique uses DenseNet-201 as the feature extraction network, replacing Darknet-19 in the original YOLOv2. We employed two benchmarks, the Kaggle vehicle dataset and the KITTI dataset, split 70% for training and 30% for testing, and utilized samples from 17 classes of vehicles such as buses, trucks, cars, carts, bikes, etc. Extensive experimentation shows that our model achieves better average precision than existing techniques. Moreover, it is more compact and exploits more representative features thanks to the dense connections among layers: each layer is directly connected with all previous layers up to the classification layer, which ensures a good flow of information from the input layer to the last one. Owing to this compactness of the base network, our model detects tiny vehicles with higher precision and calculates bounding boxes more accurately than the original YOLOv2. We also performed cross-validation on two prominent datasets, Pascal VOC and COCO, to determine the robustness of the technique, and attained excellent performance compared to state-of-the-art techniques, achieving 81% mAP. We believe our model is a robust and effective framework for detecting vehicles such as cars, buses, and trucks. In the future, we aim to fine-tune the model to attain better accuracy and mAP for vehicle detection along with classification, and to apply the framework to other object detection applications such as abnormal activity detection.

**Author Contributions:** Conceptualization, S.M.L. and R.M.; methodology, F.S.B.; software, R.A.; validation, A.M.E.-S., S.M.L. and S.S.; formal analysis, M.J.A.; investigation, R.M.; resources, R.M.; data curation, R.A.; writing—original draft preparation, F.S.B.; writing—review and editing, A.M.E.-S.; visualization, S.M.L.; supervision, S.S.; project administration, M.J.A.; funding acquisition, A.M.E.-S. All authors have read and agreed to the published version of the manuscript.

**Funding:** The authors extend their appreciation to King Saud University for funding this work through researchers supporting project number (RSP-2021/133), King Saud University, Riyadh, Saudi Arabia.

**Data Availability Statement:** Data sharing does not apply to this article as authors have used publicly available datasets, whose details are included in the "experimental results and discussions" section of this article. Please contact the authors for further requests.

**Acknowledgments:** The authors extend their appreciation to King Saud University and UET Taxila for supporting this work.

**Conflicts of Interest:** The authors declare no conflict of interest.

#### **References**


### *Article*

**A Machine Learning Method for Prediction of Stock Market Using Real-Time Twitter Data**

**Saleh Albahli 1,\*, Aun Irtaza 2,3, Tahira Nazir 2, Awais Mehmood 2, Ali Alkhalifah <sup>1</sup> and Waleed Albattah <sup>1</sup>**


**Abstract:** Finances represent one of the key requirements to perform any useful activity for humanity. Financial markets, e.g., stock markets, forex, and mercantile exchanges, provide anyone the opportunity to invest and generate finances. However, to reap maximum benefits from these financial markets, effective decision making is required to identify the trade directions, e.g., going long/short, by analyzing all the influential factors, e.g., price action, economic policies, and supply/demand estimation, in a timely manner. In this regard, analysis of financial news and Twitter posts plays a significant role in predicting the future behavior of financial markets, estimating public sentiment, and estimating systematic/idiosyncratic risk. In this paper, our proposed work analyzes Twitter posts and Google Finance data to predict the future behavior of the stock markets (one of the key financial markets) in a particular time frame, i.e., hourly, daily, weekly, etc., through a novel StockSentiWordNet (SSWN) model. The proposed SSWN model extends the standard opinion lexicon named SentiWordNet (SWN) with terms specifically related to the stock markets to train an extreme learning machine (ELM) and a recurrent neural network (RNN) for stock price prediction. The experiments are performed on two datasets, the Sentiment140 and Twitter datasets, and achieved an accuracy of 86.06%. Findings show that our work outperforms the state-of-the-art approaches with respect to overall accuracy. In the future, we plan to enhance the capability of our method by adding other popular social media, e.g., Facebook and Google News.

**Keywords:** machine learning; SentiWordNet; stock prediction; sentiment analysis

#### **1. Introduction**

Stock price fluctuation signifies existing market trends and company evolution and can be weighed when deciding to sell or buy stocks. Stock market estimation is considered one of the most challenging and essential tasks due to the market's nonlinear, dynamic behavior [1]. Stock prices move up and down every minute, or even every second, because of variations in demand and supply: if many individuals want to purchase a specific stock, its price rises, whereas when most people owning a specific stock want to sell it, its market price decreases. This association between supply and demand is tied to news, blogs, and sentiment analysis (SA). Stock market prediction using SA deals with automatically forecasting the performance of the stock market [2]. In this regard, Twitter is the most popular platform for gauging public opinion, so it can be useful for forecasting stock market prices [3].

Nowadays, there has been debate on the effectiveness of the sentiments conveyed via social media in forecasting changes in the stock market. Various researchers have revealed that sentiments might influence stock market movement and act as potential predictors for trade-off outcomes [4,5]. Furthermore, different methods of sentiment mining can be employed in numerous stock circumstances [6]. In other words, many considerations are involved in evaluating opinions about the traits and features of stocks [7,8]. However, the existing techniques do not suggest an absolute reliance on the number of tweets per unit of time. The amount of data gathered and analyzed in the existing studies remains inadequate, thus causing predictions with low accuracy [9,10].

**Citation:** Albahli, S.; Irtaza, A.; Nazir, T.; Mehmood, A.; Alkhalifah, A.; Albattah, W. A Machine Learning Method for Prediction of Stock Market Using Real-Time Twitter Data. *Electronics* **2022**, *11*, 3414. https://doi.org/10.3390/electronics11203414

Academic Editors: Yanhui Guo, Deepika Koundal and Rashid Amin

Received: 19 September 2022; Accepted: 4 October 2022; Published: 21 October 2022

Even though extensive techniques have been presented by the research community for stock market prediction, these approaches have some potential limitations. The existing methods are not robust enough to tackle the versatile nature of stocks. Furthermore, the massive size of the data requires methods that can learn a more reliable set of features to better capture the varying behavior of stocks over time. Hence, performance enhancement is needed both in stock trend prediction accuracy and in time complexity.

To deal with the issues of current approaches, we propose a technique, the SSWN with an ELM classifier, for stock market prediction. The presented method comprises three main steps: data gathering, sentiment computation along with model training, and finally the stock market prediction module. The main contributions of this paper are as follows:


The remainder of this paper is structured as follows: Section 2 shows the related work. The proposed method is presented in Section 3. Experiments and results are described in Section 4, while Section 5 concludes our work.

#### **2. Related Work**

Numerous studies [11–21] have employed electronic knowledge to forecast stock trends. For instance, Zhang et al. [22] proposed an LSTM-based method to estimate the stock market trend. In the first step, the input is partitioned into three parts: public opinion space, stock transaction data, and market transaction data. A one-layer LSTM was employed to learn long memory in the public opinion space, whereas two LSTM layers were applied to learn short memory in the stock and market series. The data were then combined by a merge layer, and a linear layer was utilized to enhance the model results. The method predicts the market behavior and evaluates the relationship between the emotions of investors and transaction data; however, it needs further improvement in its emotion abstraction technique. Xu et al. [23] presented a method for stock market forecasting by introducing SA. Initially, the dataset is gathered using a heuristic means-end process, and sentiments are then identified from the acquired data. SA was combined with an event study, and the result was used as the input of principal component analysis (PCA) for further analysis. The method predicts the market behavior using SA with an accuracy of 84.89%; however, it faces stability-related issues, and there exists a gap between the forecast and the real values.

Wu et al. [24] proposed a deep learning (DL) method for predicting dimensional valence-arousal sentiments in the stock market. The method uses the title, keywords, and overview of stock-market-related messages to estimate all vectors with a hierarchical attention approach. The method produced better results; however, it cannot identify words with multiple meanings and also needs some stability improvements. Similarly, a DL-based method was employed in [22] for extrapolating the stock market using sentiment analysis. The model is based on RNN and LSTM techniques and is used to classify sentiments into positive and negative classes; the increase or decrease in stock prices is then predicted from the sentiment analysis. Ren et al. [15] presented a framework for prediction by examining the sentiments of investors. Initially, financial review content was gathered from two sites, Sina Finance and Eastmoney. Then, an SVM was trained over the financial data to predict an essential index in China, the SSE 50 Index, using five-fold cross-validation. The method confirmed that merging sentiment key points with stock market data yields more robust results than using stock market data alone for estimating movement direction; however, the technique is not robust enough to analyze large data in real time. Bouktif et al. [14] introduced an approach to predict the stock market's future directions. First, stock data are gathered from online resources together with public tweets. Second, an NLP approach is applied to compute informative key points from the tweets. Then, several ML-based methods, namely naive Bayes, logistic regression, SVM, ANN, random forest, and XGBoost, were trained to classify the data. The technique needs further improvement for complex textual features.

Kelotra et al. [13] offered a DL-based technique, a Rider-monarch butterfly optimization (Rider-MBO)-based Deep-ConvLSTM framework, for stock market prediction. First, input data collected from the live stock market were passed to a key-point computation step to calculate a technical-indicator-based representative set of features. Next, a clustering technique, sparse fuzzy C-means (FCM), was employed over the extracted key points to group them. The most important key points were then passed to the presented Rider-MBO-based Deep-ConvLSTM network to perform prediction. Another sentiment-analysis-based stock market prediction approach was presented in [12], which makes use of computed textual deep features. After gathering the stock market data, a CNN and an RNN were employed to compute the deep features, PCA and LDA were applied to extract the significant features, and finally an SVM classifier was trained over the calculated features to predict stock market movements. The model performs well for stock market prediction but may not do so in real-world scenarios. Similarly, in [11], a DL-based framework employing sentiment analysis for stock market prediction was presented, in which an LSTM model forecasts the future closing values of a stock market. Although it supports English-only tweets, this method is robust in calculating stock market movements.

User responses to historic articles can be employed to predict consumer behavior over time. One such method was presented in [25], using a dual CNN approach with user behaviors to embed both the semantic and structural information of text articles. Another approach, employing Pillar 3 disclosed information, was presented in [26]; it investigated deposit users' interests and behavior using information from websites rooted deeply in commercial bank disclosures. The Pillar 3 regulatory framework aims to strengthen price stability by ensuring accountability and improving financial institutions' public disclosures. The work [26] performs well for analyzing consumer behavior; however, the model needs evaluation on a standard dataset.

#### **3. Proposed Methodology**

Our proposed technique encompasses three steps: data gathering and cleansing, sentiment extraction and model training, and stock market prediction.

#### *3.1. Data Gathering and Cleansing*

First, we gather data from Twitter; this social media platform is selected due to the conciseness of its posts. In addition to tweet data directly extracted from Twitter, we use the state-of-the-art dataset named Sentiment140 [27]. After data acquisition, we cleanse the collected data by removing spam, redundant, meaningless, or irrelevant tweets via a reduction system. The preprocessing step further includes the following:

• Conversion of tweets into word tokens by using bigrams, meaning that the model evaluates two tokens/words at the same time. This means that if a tweet describes something as "not good", that will be considered as a negative remark, rather than a positive one just because it contains the word "good".


After preprocessing, the cleansed dataset is used for feature extraction and sentiment identification by the ML algorithm. This process transforms the raw Twitter data into a standard dataset containing a feature set and tweets with their predicted sentiments, i.e., positive, negative, and neutral, denoted by 1, −1, and 0, respectively. Neutral tweets can imbalance the training process and degrade classifier performance, so we remove them with a simple routine that identifies them by their label (i.e., 0) and filters them out, producing a reduced version of the dataset with no neutral tweets. Their removal is necessary for two reasons: (i) neutral tweets contain no opinion or sentiment polarity and hence play no significant role in opinion mining, and (ii) including them yields a bigger dataset, adding unnecessary overhead for the classifier during model training [28–30]. The overall architecture is shown in Figure 1.

**Figure 1.** Flow diagram of proposed technique.
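A minimal Python sketch of these cleansing steps follows, assuming NLTK (with its punkt tokenizer data installed) and an illustrative (text, label) tweet structure; it is not the authors' code.

```python
# Sketch of bigram tokenization and neutral-tweet filtering; data are invented.
from nltk import bigrams                 # requires: pip install nltk
from nltk.tokenize import word_tokenize  # assumes punkt tokenizer data is installed

tweets = [("AAPL not good today", -1), ("great quarter for TSLA", 1),
          ("market opens at nine", 0)]   # (text, sentiment) with 0 = neutral

# Keep only polar tweets: neutral ones add no opinion signal.
polar = [(text, label) for text, label in tweets if label != 0]

# Bigram tokenization so that "not good" stays a single (negative) unit.
tokenized = [list(bigrams(word_tokenize(text))) for text, _ in polar]
print(tokenized[0])   # [('AAPL', 'not'), ('not', 'good'), ('good', 'today')]
```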

Secondly, we also make use of the stock market data provided by Google Finance, where global historical stock data are available. The price data of the chosen stocks are downloaded from the service provider as a CSV file. The collected data contain seven features: date, open, high, low, close, volume, and adjusted close. These indicate, respectively, the traded date, opening price, highest trading price, lowest trading price, closing price, traded shares, and the stock closing price after investors are paid their dividends. These data are also preprocessed by adding some calculated values based on the existing features (i.e., the 5-day price difference, 10-day price difference, prices extrapolated over holidays, and return of the market (RM)) and by removing some columns, including the adjusted close price, volume, and opening price. The reasons for adding the calculated values are as follows: the 5- and 10-day price differences summarize the recent past behavior of the stock under discussion; the closing prices for weekends are extrapolated to complete the timeline of the dataset, which may improve the overall accuracy of the model [4]; and the return of the market (RM) gives an investor a probabilistic idea of risk versus expected profit.
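A pandas sketch of this feature engineering is shown below; the column names and the CSV file are assumptions for illustration, not the authors' data layout.

```python
# Sketch of the added price features; "stock_prices.csv" and columns are hypothetical.
import pandas as pd

df = pd.read_csv("stock_prices.csv", parse_dates=["date"])

df["diff_5d"] = df["close"].diff(5)     # 5-day price difference
df["diff_10d"] = df["close"].diff(10)   # 10-day price difference
df["rm"] = df["close"].pct_change()     # simple return-of-market proxy

# Extrapolate weekend/holiday closes by forward-filling over a daily calendar.
df = (df.set_index("date")
        .asfreq("D")
        .ffill()
        .reset_index())

# Drop the columns the paper removes before training.
df = df.drop(columns=["adjusted_close", "volume", "open"], errors="ignore")
```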

After the preprocessing stage for both data sources is complete, the next step is model training and stock prediction. ELM- and RNN-based models are trained using the features extracted from the Twitter and Google Finance datasets. Both datasets are split into two subsets: the first 70% is reserved for training and the remaining 30% for testing/validation. More details about the incorporated datasets are provided in the results and discussion section.

#### *3.2. Feature Extraction*

Once the data are passed from the preprocessing stage, they are forwarded to the feature extraction stage, where further data processing is performed. For this purpose, we propose a novel approach, the SSWN, described in detail in the subsequent sections.

#### 3.2.1. SWN

Several lexical resources are widely used in the literature; a summary of the most applied ones is given in Table 1. The first, SenticNet, is a publicly available semantic resource for performing SA at the concept level; rather than standard graph mining techniques, it uses a custom-devised concept of 'energy flows' for common-sense knowledge representation. AFINN is one of the simplest and most popular lexicons, containing hundreds of synsets and words associated with a polarity score ranging from −5 to 5. Similarly, SO-CAL is a lexical resource with more than six thousand synsets, each word assigned a polarity score from −5 to 5. Another popular lexical resource is WordNet, which superficially resembles a thesaurus, grouping words together based on their meanings; it is a freely available large lexical database that groups nouns, verbs, adverbs, and adjectives into synsets, also known as cognitive synonyms. WordNet-Affect extends WordNet by further including a subset of synsets appropriate for representing affective concepts in correlation with affective words. SWN has several applications in SA that can be employed for stock market prediction, as the structure of its entries is convenient for mathematical modeling. SWN is a lexical resource for opinion mining [23] in which every synset of WordNet is assigned a triple of polarity scores: a positivity, a negativity, and an objectivity score. SWN was built automatically by a mixture of linguistic and statistical classifiers, and it has been employed in various opinion-related tasks, e.g., bias analysis and SA, with encouraging findings.

#### 3.2.2. SSWN

To predict future stock market trends, we introduce SSWN, which is based on SWN 3.0 and contains a set of feature words specifically helpful for identifying and scoring tweets related to the stock market. The SSWN creation procedure starts with two seed sets: the first comprises positive terms, the other negative terms. The seed groups are extended by adding all synsets from SWN related to the seed words, with a particular radius value chosen for seed expansion; a third set of objective words is also introduced. In the second step, the computed seeds are used to classify the SSWN synsets into positive and negative classes. In the presented approach, we employ classifiers with four choices of radius (0, 2, 4, and 6) and average the outputs of all classifiers to decide the final value of a synset. Table 2 shows a sample in which every tuple specifies a synset comprising dialogue data, an identifier linking the synset with WordNet, the scores, and a gloss that records the denotation together with the usage of the values in the synset. All words/tokens in each row of the cleansed data are replaced with the calculated scores, yielding a feature matrix aligned with the input requirements of the ELM classifier. The objective score (*OS*) can be calculated as:

$$OS = 1 - (PS + NS) \tag{1}$$

where *PS* is the positive score while *NS* is negative. The sentiment score (*SS*) can be calculated using Equation (2):

$$SS = PS - NS\tag{2}$$

The strength of sentiment (*ST*) can be found through Equation (3), in which *r* is the rank of the feature.

$$ST = \sum\_{r=1}^{n} SS(r)/r \tag{3}$$
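Equations (1)-(3) translate directly into code; the following sketch uses illustrative scores.

```python
# Direct transcription of Equations (1)-(3); the example values are made up.
def objective_score(ps, ns):          # Equation (1): OS = 1 - (PS + NS)
    return 1.0 - (ps + ns)

def sentiment_score(ps, ns):          # Equation (2): SS = PS - NS
    return ps - ns

def sentiment_strength(ss_ranked):    # Equation (3): ss_ranked[r-1] = SS of rank-r feature
    return sum(ss / r for r, ss in enumerate(ss_ranked, start=1))

print(objective_score(0.625, 0.25))     # 0.125
print(sentiment_score(0.625, 0.25))     # 0.375
print(sentiment_strength([0.5, 0.25]))  # 0.5/1 + 0.25/2 = 0.625
```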

Table 3 demonstrates the relationship between a term t and a class c.

**Table 1.** A summary of lexical resources.


**Table 2.** A sample from the lexical resource named SWN.


**Table 3.** Association between t and c.


#### Information Gain (IG)

IG, also termed expected mutual information, is an ML-based measure of term goodness for a given classification task [23]. It works by computing the bits of information gained from the presence or absence of a word in a document. Let the set of classes in the target space be denoted *ci*, *i* = 1, ..., *m* [30]. Then, the IG of a term t is computed with the formula in Equation (4).

$$G(t) = -\sum\_{i=1}^{m} P\_r(c\_i) \log P\_r(c\_i) + P\_r(t) \sum\_{i=1}^{m} P\_r(c\_i \mid t) \log P\_r(c\_i \mid t) + P\_r(\bar{t}) \sum\_{i=1}^{m} P\_r(c\_i \mid \bar{t}) \log P\_r(c\_i \mid \bar{t}) \tag{4}$$

This is a generalization of the binary case [21], as text categorization approaches typically use an n-ary classification space, where n can reach tens of thousands. Furthermore, the goodness of a term is measured globally with respect to all classes on average. The IG value is computed for every distinct term in a specified corpus, and a threshold on the IG score determines which terms are eliminated from the corpus. The computational complexity of IG is O(Vn), where V is the vocabulary size and n the number of classes. Using the correlation table, the IG value can be approximated by Equation (5); the greater the IG value, the stronger the association.

$$IG(t,c) \approx B \times N \times \log \frac{B}{(B+D) \times (A+B)}\tag{5}$$
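A direct Python transcription of Equation (4) follows, with made-up probabilities for a two-class example; it is a sketch of the measure, not the authors' implementation.

```python
# Equation (4): information gained about the classes from observing term t.
import math

def information_gain(p_c, p_t, p_c_given_t, p_c_given_not_t):
    """p_c: class priors; p_t: probability of the term appearing;
    p_c_given_t / p_c_given_not_t: class probabilities given term presence/absence."""
    g = -sum(p * math.log(p) for p in p_c if p > 0)
    g += p_t * sum(p * math.log(p) for p in p_c_given_t if p > 0)
    g += (1 - p_t) * sum(p * math.log(p) for p in p_c_given_not_t if p > 0)
    return g

# Two classes; the term appears in 30% of documents (illustrative numbers).
print(information_gain(p_c=[0.5, 0.5], p_t=0.3,
                       p_c_given_t=[0.9, 0.1],
                       p_c_given_not_t=[0.33, 0.67]))   # ~0.15
```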

#### Sentiment Knowledge Base (SKB) Generation Procedure

To produce the SKB, the presented approach follows these steps:


The SKBs produced via this procedure are domain-independent, as sentiment strength is computed with a generic sentiment lexicon that does not require training on a specific domain. The presented SKBs are capable of dealing with the problems of data absence and data diversity; moreover, they can easily locate the sentiment orientation, weight, and sense of words based on their usage. These sentiment resources are used in the introduced technique to improve SA specifically for stock market prediction and for SA in general. Table 4 shows a sample from the proposed lexical resource SSWN. Another challenging problem for effective SA is the constant appearance of new words or phrases; hence, a method is needed that can deal with frequent out-of-vocabulary (OOV) words. In natural language processing, words that appear in the testing/real data but are unavailable in the training dataset are called OOV words. The main issue is that the model mistakenly assigns zero probability to OOV words, making the likelihood of such a word equal to zero. There are multiple solutions to this problem, including tokenization, smoothing techniques, and semantic representations [42,43]. As OOV terms belong to a specific domain, intensive domain information is needed to specify their strength. To cope with this issue, active learning is usually employed, in which a polarity score is assigned by humans. To avoid bias, we chose only those OOV words for which at least ten persons have voted; the final sentiment score is the average of all ten scores.


**Table 4.** A sample from the proposed lexical resource SSWN.

#### *3.3. Prediction Phase*

The link between stocks and sentiments is definitely nonlinear, and financial markets often follow nonlinear trends. Hence, after discovering a causal association between the moods of the past 3 days and present-day stock prices, we applied two techniques (ELM and RNN) to discover and examine this association [44]. As discussed earlier, the proposed technique incorporates two datasets, i.e., data extracted from Twitter and the state-of-the-art Sentiment140 dataset. The features extracted from the Twitter data by SSWN are combined with the past three days of stock data extracted from Google Finance and then used to predict the current day's stock trends for a set of specific brands.

#### 3.3.1. Extreme Learning Machine

Text classification is characterized by a large number of training samples and high text dimensionality, and high dimensionality increases the computational burden on the ELM. A traditional and effective way to resolve this issue is to reduce the text dimensionality with a text representation that also helps increase classification accuracy. Researchers often use the vector space model (VSM) for text representation in text classification; compared with other methods, word vector representation has proven to have better representational ability. Word vectors deal with the dimensionality problem by mapping each term (a distinct word in the textual dataset) to a low-dimensional real vector trained on an unlabeled corpus. We consider the open, high, and low prices as inputs to the ELM and the closing price as its output. The ELM classifier [45] was originally introduced as a single-hidden-layer feed-forward neural network that requires no tuning of its hidden layer. The output with L hidden nodes for a training set is given in Equation (6):

$$f\_L(x) = \sum\_{i=1}^{L} \beta\_i h\_i(x) = h(x)\beta \tag{6}$$

Here, *β* = {*β*1, ..., *βL*}*<sup>T</sup>* denotes the output weights between the hidden-layer nodes and the output layer, while *h*(*x*) = {*h*1(*x*), ..., *hL*(*x*)} is the hidden-layer output vector. The decision function of the ELM classifier is:

$$f\_L(\mathbf{x}) = \operatorname{sign}\left(h(\mathbf{x})\boldsymbol{\beta}\right) \tag{7}$$

To obtain the robust performance, the ELM aims to deal with the lowest training error and reach the minimum norm of the resultant weights by reducing the given objective function:

$$Minimize: ||H\beta - T||\_2 \text{ and } ||\beta||\tag{8}$$

Here, H is the hidden-layer output matrix.

$$H = \begin{bmatrix} h(x\_1) \\ \vdots \\ h(x\_N) \end{bmatrix} = \begin{bmatrix} h\_1(x\_1) & \cdots & h\_L(x\_1) \\ \vdots & \ddots & \vdots \\ h\_1(x\_N) & \cdots & h\_L(x\_N) \end{bmatrix} \tag{9}$$

To reduce the norm of the output weights, ELM draws an optimal hyperplane to classify the samples into different classes by maximizing the margin 2/||*β*||, employing the minimal-norm least-squares approach:

$$\beta = H^{\dagger} T \tag{10}$$

Here, *H*† denotes the Moore–Penrose generalized inverse of the matrix H, which can be computed by orthogonalization, orthogonal projection, or singular value decomposition.

The desired output of ELM is:

$$T\_{test} = H\_{test}\, \beta \tag{11}$$
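A compact NumPy sketch of ELM training and prediction via Equations (6)-(11) follows; the sigmoid activation, hidden size, and random data are assumptions rather than the authors' configuration.

```python
# Minimal ELM sketch: random hidden layer, closed-form output weights.
import numpy as np

rng = np.random.default_rng(0)

def elm_fit(x, t, hidden=64):
    """x: (N, d) features; t: (N, k) targets. Returns (w, b, beta)."""
    w = rng.normal(size=(x.shape[1], hidden))   # random input weights, never tuned
    b = rng.normal(size=hidden)
    h = 1.0 / (1.0 + np.exp(-(x @ w + b)))      # hidden-layer output matrix H
    beta = np.linalg.pinv(h) @ t                # beta = H^dagger T  (Equation (10))
    return w, b, beta

def elm_predict(x, w, b, beta):
    h = 1.0 / (1.0 + np.exp(-(x @ w + b)))
    return h @ beta                             # T = H beta  (Equation (11))

x = rng.normal(size=(100, 8))
t = rng.normal(size=(100, 1))
w, b, beta = elm_fit(x, t)
print(elm_predict(x[:3], w, b, beta).shape)     # (3, 1)
```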

#### 3.3.2. Recurrent Neural Network

A recurrent neural network (RNN) is suited to problems that involve a sequence of data, and many researchers recommend RNNs for time series analysis [8–10]. In this kind of work, the model learns from its recent observations, the short-term memory of the network, resembling the frontal lobe of the brain. The reason for using an RNN with sequential data is that the model uses this short-term memory to predict the upcoming data more accurately. Rather than using a fixed deadline for deleting past data, the weights allotted to past data determine how long those data are kept in memory. RNNs are therefore well suited to problems such as sequence labeling, sentiment analysis, and speech tagging [46,47].

Time series analysis is an important problem that can be addressed with an RNN, since it requires working with data in sequential order and learning from the most recent observations, alternatively called short-term memory. This work primarily focuses on text classification, so the RNN in this research is used to classify Twitter data, and we propose a model to predict the closing price of the stock market.

Twitter data are not in a uniform format: the number of words in a tweet may vary from 3–5 words to 17–20 words, for example, but our neural network does not accept input of variable length. We therefore convert the data into a uniform format, most appropriately by embedding and padding the data rows. The embedding process represents words as vectors using the procedure mentioned in the discussion of the ELM: the position of a term in a vector space is determined and represented in the feature vector. The embedded data must then be of uniform length, so we pad the data with zeros.
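A small sketch of the padding step is shown below, with an invented embedding dimension and maximum length; it illustrates the idea rather than the authors' exact preprocessing.

```python
# Right-pad variable-length sequences of word vectors with zero vectors.
import numpy as np

def pad_sequences(seqs, max_len, dim):
    out = np.zeros((len(seqs), max_len, dim))
    for i, seq in enumerate(seqs):
        n = min(len(seq), max_len)      # truncate if longer than max_len
        out[i, :n] = seq[:n]
    return out

vecs = [np.ones((4, 16)), np.ones((19, 16))]   # tweets of 4 and 19 word vectors
batch = pad_sequences(vecs, max_len=20, dim=16)
print(batch.shape)                              # (2, 20, 16)
```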

An RNN [48] employs links among nodes to build a directed graph over a timeframe, which enables it to exhibit dynamic sequential behavior. The RNN uses its memory to handle variable-length input sequences, which makes it appropriate for stock prediction. Every processing unit in an RNN has a time-varying, real-valued activation and adaptable weights generated by applying the same set of weights in a loop over a graph-like structure. Equation (12) specifies the values of the hidden units.

$$h^{t} = f\left(h^{t-1}, x^{t}; \theta\right) \tag{12}$$

In an RNN, the input size remains the same at each step, since the computation is expressed as a transition from one state to another; moreover, the structure employs the same transition function with the same parameters at each time step. An RNN stores the output of previous layers to make predictions, which enables it to work with sequential data. In this work, we tested the RNN for predicting stock market behavior.

#### **4. Experimental Results**

This section describes the demographics of datasets used, an overview of the evaluation metrics, and a comprehensive discussion of the results achieved along with a comparison with state-of-the-art techniques.

#### *4.1. Experimental Setup*

The test bed is a workstation with an x64 Intel Core i7-6700 CPU clocked at 3.40 GHz, 16 GB of DDR4 RAM, and an NVIDIA GeForce graphics card with 4 GB of memory. Storage consists of a 1 TB HDD and a 256 GB SSD. The 64-bit operating system, Microsoft Windows 10 Professional, is installed on the SSD, and the datasets and working environments are also stored on the SSD to avoid the mechanical delay of the HDD and to speed up model training and testing.

Python 3.7.15, along with the necessary libraries (NLTK, Stanford NER Tagger, BeautifulSoup, NumPy, Scikit-learn, etc.), is installed in an Anaconda environment. We used the ReLU activation function and a learning rate of 0.001 for model training. For performance evaluation, we employed several metrics: accuracy, precision, recall, and F-measure.

#### *4.2. Datasets*

As described in the previous sections, we incorporated two datasets: Sentiment140, a state-of-the-art dataset widely used for SA tasks, and a dataset collected directly from the Twitter platform using a Twitter API, i.e., Tweepy.

#### 4.2.1. The Sentiment140 Dataset

The Sentiment140 dataset contains a total of 1.6 M tweets extracted by using a Twitter API [49]. All the tweets have been annotated as negative = 0, neutral = 2, and positive = 4 and are utilized to discover their sentiments. The dataset contains six columns described in Table 5. A detailed description of the Sentiment140 dataset can be found here [27].


**Table 5.** Description of the Sentiment140 dataset.

We filtered the dataset, keeping the tweets that mention one of the specified brand names in the tweet body. This filtration produced a new subset of the Sentiment140 dataset consisting of 56 K tweets in total. The set of neutral tweets was excluded from the dataset, as neutral tweets do not play any significant role in stock prediction.

#### 4.2.2. Direct Data from Twitter

This dataset was collected with custom code that uses a Twitter API. Tweets mentioning the brands shown in Table 6 and posted between 1 March 2021 and 21 March 2021 were extracted/downloaded using the Python library Tweepy. After the preprocessing and cleansing steps mentioned in the previous sections, the gathered data were ready to be processed and used for predicting the stock market value of specific brands. Table 6 gives the demographics of the data directly collected from Twitter.
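For illustration, a hedged sketch of brand-tweet collection with Tweepy's v4-style Client API follows; the bearer token and query are placeholders, and note that the recent-search endpoint only covers the last seven days, so a historical window such as March 2021 would require the full-archive endpoint instead.

```python
# Sketch of brand-tweet collection with Tweepy (v4 Client API assumed).
import tweepy

client = tweepy.Client(bearer_token="YOUR_BEARER_TOKEN")   # placeholder credential

resp = client.search_recent_tweets(
    query="TSLA -is:retweet lang:en",   # one of the studied brands (illustrative)
    max_results=100,
)
tweets = [t.text for t in (resp.data or [])]
print(len(tweets))
```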


**Table 6.** Details of data directly extracted from Twitter.

Similar to the steps performed on the Sentiment140 dataset, we downloaded the tweets mentioning one of the brands under study by using the previously mentioned custom code. This resulted in a new dataset of approximately 506 K tweets. Additionally, we calculated the term frequency; Figure 2 depicts a word cloud of frequently used words in the dataset.

The set of neutral tweets was excluded from the dataset, as neutral tweets do not play any significant role in stock prediction. After preprocessing and subtraction of the neutral tweets, the dataset was reduced to a total of 224.2 K tweets belonging to the positive and negative classes only.

**Figure 2.** Word cloud diagram of stock words.

#### 4.2.3. Proposed Method Results

This section provides a detailed discussion of the results achieved by the proposed approach. For stock price prediction, we trained two models, ELM and RNN, over both datasets and report the average results. Figure 3 shows the results of the proposed method in terms of precision, recall, and F-measure. The figure shows that the proposed model's performance varies from stock to stock; the reason is the availability of training data. Some stocks are mentioned less than others on Twitter, resulting in fewer tweets (i.e., less training data) for those brands. Thus, the more data available for certain stocks, the more accurately the model can predict the output values (stock prices in our case) for those stocks.

From the results, we can say that our method achieved good results for predicting stock market behavior. The average precision, recall, and F-measure of our proposed system are 0.8603, 0.811, and 0.8537, respectively. The column graph shows the brand-wise results of our method, in which blue, red, and gray bars show precision, recall, and F-measure, respectively. We can therefore say that our method can precisely predict the stock market behavior of all brands.

To further evaluate our method, we plotted the accuracies of all brands in the boxplots of Figure 4. Figure 4a describes the results of ELM classification. The prediction accuracy for AAPL, TSLA, MSFT, WMT, PYPL, NVDA, INTC, FB, TWTR, and AMZN is 90.3%, 85.01%, 88.21%, 85.18%, 84.716%, 87.35%, 80.733%, 79.25%, 91.05%, and 88.78%, respectively, so the average accuracy of our proposed technique is 86.06%, which is impressive and can be used to precisely predict stock market behavior. Figure 4b shows the results of the RNN classifier for all brands; here, our method achieved an average accuracy of 81.4%, which is less than that of the ELM classifier. According to these results, our proposed approach accurately predicts the stock market trends of any brand.

**Figure 4.** Accuracies of All brands using (**a**) ELM and (**b**) RNN.

4.2.4. Classifiers' Performance Evaluation

We selected nine ML algorithms and compared their prediction accuracy. We trained these algorithms and then tested them on both datasets to predict future stock market trends. Before applying the algorithms, we split the final datasets into two portions: 70% of the samples as training data and the remaining 30% as testing data. Training and testing were performed using the Python ML library Scikit-learn [31]. Table 7 lists the ML algorithms used in this experimentation along with their optimal parameters.
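The 70/30 split can be sketched with scikit-learn's train_test_split; X and y below are placeholders standing in for the feature matrix and sentiment labels.

```python
# Sketch of the 70/30 split described above; X and y are placeholder data.
import numpy as np
from sklearn.model_selection import train_test_split

X = np.random.rand(1000, 20)            # placeholder tweet-feature matrix
y = np.random.randint(0, 2, 1000)       # placeholder sentiment labels

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.30, random_state=42)
print(X_train.shape, X_test.shape)      # (700, 20) (300, 20)
```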

**Table 7.** Selected ML algorithms with their optimal parameter values.


#### 4.2.5. Performance of Algorithms before and after SSWN

We evaluated the performance of the chosen techniques with SSWN and without it, i.e., employing the standard SWN. Figure 5 demonstrates an overall increase in the accuracy of all algorithms after employing SSWN.

**Figure 5.** Accuracy comparison of algorithms before and after SSWN.

#### *4.3. Performance of Algorithms on Both Datasets*

Along with other comparisons, we also compared the performance of selected algorithms on the standard Sentiment140 dataset, as shown in Table 8.


**Table 8.** Performance of the algorithms on the data set of sentiment140.

It is evident from Table 8 that the ELM classifier outperforms the other algorithms in terms of accuracy and precision, although its recall and F-measure do not remain on top. It can also be observed that RNN shows the second-best performance in terms of accuracy and precision while remaining on top in terms of F-measure.

Table 9 demonstrates the performance of these algorithms on the Twitter dataset and shows the significant improvement obtained by the majority of the classifiers. By incorporating SSWN, the performance of all algorithms except NB and DT increased; these two algorithms did not show any significant improvement. RNN here again shows the second-best performance in terms of accuracy and precision and the best performance in terms of F-measure.


**Table 9.** Performance of the algorithms on the data set extracted from Twitter.

#### *4.4. Time Complexity*

The performance of all selected algorithms was also compared with respect to the time taken for training and for assigning sentiment scores. Figure 6 shows a detailed comparison in seconds. NB took the minimum time for training, while ELM was the fastest at sentiment scoring. Overall, ELM and RNN performed well with respect to accuracy combined with time complexity.

**Figure 6.** Performance Comparison of algorithms in terms of time taken.
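Such timings can be collected with a simple wrapper; a minimal sketch (`clf`, `X_train`, and the other names are placeholders for any fitted Scikit-learn-style model and data):

```python
import time

def timed(fn, *args, **kwargs):
    """Run fn once and return (result, elapsed seconds)."""
    start = time.perf_counter()
    result = fn(*args, **kwargs)
    return result, time.perf_counter() - start

# Hypothetical usage:
# _, train_time = timed(clf.fit, X_train, y_train)   # training time
# _, score_time = timed(clf.predict, X_test)         # sentiment-scoring time
```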

#### *4.5. Classification Performance of the Selected Algorithms*

Figure 7 demonstrates a performance comparison of the selected algorithms in the form of an ROC curve plot, which shows that ELM outperforms the others in the correct classification of input samples.

#### *4.6. Comparison with State-of-the-Art Techniques*

Several researchers have presented ML-based work to predict future stock market trends. Therefore, in this section, to assess the prediction robustness of our approach, we performed a comparative analysis of our framework against the latest ML-based approaches, evaluated in terms of the employed technique, data, and the obtained accuracy and precision.

**Figure 7.** ROC curves depicting the performance of the algorithms in terms of true positive rate.

The comparative results are reported in Table 10. Zhou et al. [33] presented an approach for stock market prediction that uses people's online emotions to assess their behavior; it employs SVM for classification and attained an accuracy of 64.15%. Nguyen et al. [34] introduced a framework that uses sentiments on specific company-related topics and employed the SVM classifier for prediction, showing an average accuracy of 54.41%. The work in [35] presented a data mining technique for stock market prediction with an accuracy of 66.48%. Khan et al. [25] introduced a framework using social sites along with political events and attained an accuracy of 75.38%. Similarly, the technique in [36] used the same concept with an ANN classifier and attained an average accuracy of 77.12%. Khan et al. [37] presented another approach using financial news data and obtained an accuracy of 80.6%. The technique in [38] employed people's sentiments from social sites with naive Bayes and SVM classifiers and likewise showed a best average accuracy of 80.6%. Table 10 shows that the presented approach achieved an accuracy of 85.7%, higher than all the comparative methods. Moreover, the comparative methods attained an average accuracy of 71.24%, against 85.7% in our case, a performance gain of 14.46 percentage points.



The reported values demonstrate that the introduced approach outperforms the comparative techniques [25,33–39] by introducing the SSWN sentiment lexicon, which assists in selecting a more representative set of stock market-related features. Moreover, the methods in [25,33–39] are computationally more expensive and prone to over-fitting, whereas the robustness of the ELM classifier against over-fitting helps attain high accuracy with less processing time. The proposed method can therefore be described as more effective and efficient for stock market prediction.

#### *4.7. Discussion*

The prediction of stock market prices is an interesting research topic, and it is a challenging task due to the volatility, diversity, and dynamic behavior of the stock market. Recent research has revealed that sentiments and news can influence stock market movement and act as potential predictors of trading outcomes. Social media platforms can therefore be considered an important source from which to extract useful chunks of information from posts already published by users. In this regard, Twitter is a particularly suitable source due to the concise nature of the tweets posted there. However, this conciseness also makes the job more challenging because of shortened words, duplication, and the different types of noise residing in tweets. Combined with the power of machine learning, tweets can contribute significantly to the prediction of stock market prices. In this work, we introduced a novel approach for predicting stock prices using SA. For this purpose, we implemented two distinct classifiers, RNN and ELM, along with other popular ones, based on the proposed sentiment lexicon SSWN and two datasets: data directly acquired from Twitter and the standard Sentiment140 dataset. We performed the experimentation on data for ten US market stocks obtained from Google Finance. Firstly, we compared the performance of nine different machine learning algorithms on the stock data, where ELM remained on top. Secondly, we compared our work with the state of the art and achieved superior overall accuracy due to the use of a sentiment lexicon proposed specifically for stock market prediction. The scope and capability of the proposed technique can be further enhanced by considering other DL-based approaches.

#### **5. Conclusions**

People use social media to share their personal ideas and opinions regarding a brand, entity, person, or affair. Twitter is a globally recognized, modern social media platform for sharing ideas and opinions in a very concise way. Using the power of SA and ML, social media posts such as tweets can play a significant role in predicting stock market behavior. This work introduces a novel approach for stock market prediction using SA. The model is based on the proposed SSWN sentiment lexicon along with RNN and ELM classifiers. We used Twitter data and the Sentiment140 dataset for the performance evaluation of the ML models, considering ten different brands for stock market prediction. We achieved average accuracies of 81.40% for the RNN and 86.06% for the ELM classifier. We compared our approach with various ML models as well as state-of-the-art methods and achieved remarkable results. In the future, we plan to enhance the capability and coverage of this approach by adding other popular platforms, e.g., Facebook and Google News. Furthermore, we may evaluate the proposed approach on other challenging datasets while considering more stocks.

**Author Contributions:** Conceptualization, formal analysis, data analysis, data interpretation, literature search, funding acquisition, project administration, S.A.; conceptualization, software, resources, methodology, writing—original draft, T.N.; validation, visualization, writing—original draft, A.M.; supervision, validation, writing—review and editing, A.I.; literature search, investigation, validation, A.A.; conceptualization, supervision, writing—review and editing, proofreading, W.A. All authors have read and agreed to the published version of the manuscript.

**Funding:** This research was funded by the Deanship of Scientific Research, Qassim University.

**Acknowledgments:** The researchers would like to thank the Deanship of Scientific Research, Qassim University for funding the publication of this project.

**Conflicts of Interest:** The authors declare no conflict of interest.

#### **References**


## *Article* **Multimodal CT Image Synthesis Using Unsupervised Deep Generative Adversarial Networks for Stroke Lesion Segmentation**

**Suzhe Wang \*, Xueying Zhang, Haisheng Hui, Fenglian Li and Zelin Wu**

College of Information and Computer, Taiyuan University of Technology, Taiyuan 030024, China

**\*** Correspondence: wangsuzhe@tyut.edu.cn

**Abstract:** Deep learning-based techniques can obtain high precision for multimodal stroke segmentation tasks. However, this performance often requires a large number of training examples, and existing data extension approaches for segmentation are inefficient at creating realistic images. To overcome these limitations, an unsupervised adversarial data augmentation mechanism (UTC-GAN) is developed to synthesize multimodal computed tomography (CT) brain scans. In our approach, CT sample generation and cross-modality translation differentiation are accomplished simultaneously by integrating a Siamesed auto-encoder architecture into the generative adversarial network. In addition, a Gaussian mixture translation module is further proposed, which incorporates a translation loss to learn an intrinsic mapping between the latent space and the multimodal translation function. Finally, qualitative and quantitative experiments show that UTC-GAN significantly improves the generation ability. The stroke dataset enriched by the proposed model also provides a superior improvement in segmentation accuracy compared with the performance of current competing unsupervised models.

**Keywords:** stroke lesion segmentation; generative adversarial network; unsupervised data augmentation

### **1. Introduction**

Stroke is caused by the blockage of blood supply in cerebral vessels and is the most prevalent cause of mortality and acquired disability [1,2]. Among the various types of stroke, ischemic stroke accounts for a large proportion and mainly induces brain cell death and fatal paralysis. Hence, early diagnosis and quantification of the lesions could help stroke patients achieve effective recovery, which also benefits clinicians in optimizing therapeutic schedules. Quantitative stroke lesion segmentation from medical imaging is a necessary procedure for clinical decision-making. Computed tomography (CT) is a typical, effective non-invasive technique to evaluate the lesion regions of stroke patients [3,4]; it has the merits of speed, wide availability and low cost in imaging brain structure with ionizing radiation. Moreover, CT perfusion modalities, including cerebral blood volume (CBV), cerebral blood flow (CBF), mean transit time (MTT) and time to peak of the residue function (Tmax), are also successfully used to assess stroke infarct core size [5]. However, precise segmentation based on these diagnostic means requires rich experience and a significant amount of time from physicians.

Recently, deep neural network-based methods have shown a remarkable impact on segmentation accuracy for various medical images [6,7]. However, the scarcity of labelled multimodal images, due to the enormous time cost and complex acquisition procedures, often leads to low segmentation accuracy. Although previous image augmentation strategies such as rotation, flipping and elastic deformation have been widely applied to expand the volume of a dataset [8–10], they cannot produce a wide diversity of new features in terms of texture, shape and location. These disadvantages also limit the learning capacity of the deep medical image segmentation model.


Generative Adversarial Network (GAN) is a solution for verisimilar image generation and domain translation [11,12]. Nonetheless, this technique usually relies on fully annotated paired images and supervised training, and it is impractical to collect all modalities for each patient; the cross-modality translation process is also arduous. Several recent GAN variants try to tackle this problem by encouraging extra encoders to capture domain features [13–16]. However, the quality of the synthetic image may be negatively affected as the difference between domains increases.

In this work, we propose a GAN-based data augmentation architecture for CT ischemic stroke lesion segmentation. An unsupervised translation cycle generative adversarial network (UTC-GAN) is presented to improve segmentation accuracy. The main contributions include the following:


#### **2. Related Work**

#### *2.1. Medical Image Segmentation*

Numerous deep network designs have been exploited for medical image segmentation in recent years [17–19]. Albert Clèrigues et al. [20] introduced a symmetrical residual auto-encoding U-Net to perform lesion segmentation on CT images, with modality augmentation utilized to provide more symmetric samples. Liu et al. [21] embedded an attention component in a deep CNN architecture to improve predictive quality for white matter hyperintensity lesions. Furthermore, Zhang et al. [22] employed a 3D DenseNet model with dense blocks and multi-scale units to localize stroke lesions under severe noise and low image quality. Among convolutional neural networks, U-Net is an influential architecture for biological image segmentation. For example, in [23], a U-Net model was applied to segment interwoven neurons and neurites. Instead of only including contracting and expanding paths, Cui et al. [24] exploited a Bi-Directional ConvLSTM U-Net for blood vessel segmentation to fuse higher-resolution features and semantic information.

Traditional works mainly focus on supervised learning, which requires sufficient training images with pixel-wise annotations. Therefore, some segmentation studies based on data augmentation have been reported [25,26]. In these works, supervised generative adversarial networks (GANs) are used to expand the data amount by modelling a proper data distribution in a two-player game framework. Jelmer et al. [27] employed DCGAN to synthesize brain CT images from MRI. A similar method was employed in [28] to convert T1 MRI scans into the T2 modality. Moreover, a multi-stage GAN method was designed to form image–mask pairs for the segmentation task, using a U-Net-like WGAN-GP as the central architecture [29]. To gain higher augmentation quality and segmentation performance, a multi-scaled GAN framework, which also preserves the boundary of the tumor core, was composed to collaborate with the U-Net [12]. In addition, GAN has been utilized for unbalanced semantic segmentation tasks to balance the data distribution [30]. However, most GANs struggle to translate across more than two domains or require domain labels for all training images.

#### *2.2. Unsupervised Generative Adversarial Network*

Several works use GANs to shift domains by unsupervised transfer of information between multiple modalities. Andrade et al. [31] trained Cycle-GAN to transform skin images from the macroscopic domain into the dermoscopic domain, which helps to acquire better segmentation capacity. Chen et al. [32] adapted a shape- and appearance-disentangled augmentation GAN for unannotated cardiac CT image segmentation, synthesizing data from annotated MRI slices. A unified generative adversarial network was established to perform 3D multimodal segmentation [33]. The ideas above use auxiliary information to guide the generator in a GAN to master the domain translation mapping; however, it is hard to capture semantic information across multimodal domains. To improve shape transformation and focus on differences among domains, U-GAT-IT [13] introduces an attentional component and a new layer-instance normalization technique in the least-squares GAN to complete the unsupervised domain transfer task. More recently, tuples of concurrent GANs have been designed to perform multi-class unsupervised image domain translation through a conditional image generator and a multi-task adversarial discriminator [14–16], where the generator encodes content and class; in other words, an image from a given domain carries content and type information simultaneously.

#### **3. Network Implementation**

In this section, the overall UTC-GAN framework is first given. We then describe the Gaussian mixture–translation representation module (GM-TRM) and the objective loss functions.

#### *3.1. Model Architecture*

Image synthesis is realized by the UTC-GAN module, which translates CT slices from one modality domain to another. The model inherits the benefits of Cycle-GAN [31], consisting of a forward and a backward cycle. The forward cycle of UTC-GAN is illustrated in Figure 1, and the backward cycle is its mirror image. Generator *GA* is designed to transfer CT images to their perfusion modalities in the latent space, and *GB* provides the reverse mapping. The discriminator *DB* tries to discriminate whether the synthesized images are real or fake. Inspired by the autoencoding transformations approach [34], a paralleled auto-encoding structure is embedded in discriminator *DB* to extract the representation of modality transformation automatically. Therefore, the discriminator is partitioned into several components: encoder *E*, decoder *De* and classifier *C*; the architecture of each part is illustrated in Figure 2. Firstly, two encoders *EA* and *EB* are trained to extract the desired parameters of the training images from the two domains, using a Siamese structure to share weights and obtain co-training information. Each encoder uses AlexNet [17] as the backbone, with an Inception v1 block embedded to improve network convergence. The decoder *De*, a network with one convolutional layer and one fully connected layer, is coupled with the encoders to estimate the modality translation from the fused features. Finally, one classifier *CAdv* is added on top of the encoder *EA* of the generated domain to decide whether a synthesized image is real, and another, *CTran*, is built upon the decoder *De* to distinguish which transformation was applied. Each classifier contains two convolutional layers followed by two fully connected layers.

**Figure 1.** Illustration of the UTC-GAN architecture.
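A minimal PyTorch sketch of the weight-sharing (Siamese) encoder idea; the small convolutional stack is a stand-in for the AlexNet backbone with Inception v1 block described above:

```python
import torch
import torch.nn as nn

class SiameseEncoder(nn.Module):
    """Weight-sharing encoder applied to slices from both modality domains."""
    def __init__(self, in_ch=1, feat_dim=256):
        super().__init__()
        # Stand-in backbone; the paper uses AlexNet with an Inception v1 block.
        self.backbone = nn.Sequential(
            nn.Conv2d(in_ch, 64, 3, stride=2, padding=1), nn.ReLU(),
            nn.Conv2d(64, 128, 3, stride=2, padding=1), nn.ReLU(),
            nn.AdaptiveAvgPool2d(1), nn.Flatten(),
            nn.Linear(128, feat_dim))

    def forward(self, x_a, x_b):
        # The same weights encode both domains, giving co-trained features.
        return self.backbone(x_a), self.backbone(x_b)
```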

#### *3.2. Gaussian Mixture–Translation Representation Module*

In the UTC-GAN, the auto-encoding structure, also called the Gaussian mixture–translation representation module (GM-TRM), is proposed to learn the modality transformation representations automatically. The GM-TRM module first encodes the CT images from different modalities into a latent space. Then, the translation representation is extracted in a self-supervised manner in the decoding part by estimating the mutual information across the latent features. The GM-TRM module leads the generator to build highly entangled translation representations and provides the discriminator with additional supervision on image generation. Here, we elaborate on the principle of the GM-TRM, shown in Figure 3.

Let *d*(*x*) denote the data distribution of an image *x* from the original modality A. When it is translated to another modality B, a translation function *t* is assumed to turn *d*(*x*) into *t*(*d*(*x*)). The encoder *EA*, with network parameters *θ*, produces a representation from the image *t*(*x*), mapping the low-dimensional input *t*(*x*) to a high-level latent variable *r*. The classical formulation of the latent variable *r* is statistical: it is specified by a mean *m<sub>θ</sub>* and variance *σ*<sup>2</sup><sub>θ</sub>, together with noise *ε* ~ *N*(*ε*|0, *I*), such that:

$$r = m_{\theta}(t(\mathbf{x})) + \sigma_{\theta}(t(\mathbf{x})) \cdot \varepsilon, \qquad \varepsilon \sim \mathcal{N}(0, I) \tag{1}$$
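Equation (1) is the standard reparameterization trick; a sketch in PyTorch, assuming the encoder outputs a mean and a log-variance:

```python
import torch

def reparameterize(mean: torch.Tensor, log_var: torch.Tensor) -> torch.Tensor:
    """Sample r = m_theta + sigma_theta * eps with eps ~ N(0, I), per Equation (1)."""
    eps = torch.randn_like(mean)
    return mean + torch.exp(0.5 * log_var) * eps
```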

**Figure 2.** Illustration of the component architecture in discriminator for (**a**) Encoder; (**b**) Decoder; (**c**) Classifier.

To better approximate the transformation in latent space, the probabilistic density of the encoded images can be assumed to follow a Gaussian mixture model instead of a single Gaussian. The resulting probabilistic representation of the true posterior can be formulated as:

$$d_{\theta}(r \mid t, \mathbf{x}) = \sum_{i=1}^{k} \phi_i\, \mathcal{N}\!\left(r_i \,\middle|\, m_{i\theta}(t(\mathbf{x})),\, \sigma^2_{i\theta}(t(\mathbf{x}))\right), \qquad \sum_{i=1}^{k} \phi_i = 1 \tag{2}$$

where *N* denotes a normal distribution, *k* is the total number of transformations, *φ<sub>i</sub>* is a weight vector, and *m<sub>iθ</sub>* and *σ*<sup>2</sup><sub>iθ</sub> are the mean and variance of the *i*th Gaussian component, respectively.
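A sketch of how such a mixture posterior can be assembled with `torch.distributions` (shapes and parameter names are illustrative, not the paper's implementation):

```python
import torch
from torch.distributions import Normal, Categorical, MixtureSameFamily, Independent

def mixture_posterior(means, log_vars, logits):
    """Build the k-component Gaussian mixture of Equation (2).

    means, log_vars: (k, d) per-component parameters from the encoder.
    logits: (k,) unnormalized mixture weights (softmax gives the phi_i).
    """
    components = Independent(Normal(means, torch.exp(0.5 * log_vars)), 1)
    return MixtureSameFamily(Categorical(logits=logits), components)
```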

**Figure 3.** The workflow of Gaussian mixture translation representation module (GM-TRM).

Meanwhile, we train a decoder *De* to estimate the applied transformation parameters *t* by comparing the representations encoded from the original and target modalities. The probabilistic density of the decoder is defined as *p<sub>ϕ</sub>*(*t*|*r*, *r̃*) with parameter *ϕ*, where *r̃* denotes the latent representation of the sample from the original domain.

From an information-theoretic viewpoint [35], the transformation *t* can be characterized by the joint mutual information *I*(*t*, (*r*, *r̃*)) between itself and the encoded latent space (*r*, *r̃*), and maximizing *I*(*t*, (*r*, *r̃*)) yields the optimized representation. However, the posterior *p<sub>ϕ</sub>*(*t*|*r*) cannot be calculated directly, so a moment-matching approximation is introduced to compute *p<sub>ϕ</sub>*(*t*|*r*) conveniently, which leads to the following formulation:

$$\begin{aligned} I(t,r) &= H(t) - H(t \mid r) \\ &= H(t) + \sum_{i=1}^{k} \phi_i\, E \log d_{i\theta}(t_i \mid r_i) \\ &= H(t) + \sum_{j=1}^{k} \delta_j\, E \log p_{j\phi}(t_j \mid r_j) + \sum_{i=1}^{k} \sum_{j=1}^{k} E\!\left[ KL\!\left(d_{i\theta}(t_j \mid r_i) \,\|\, p_{j\phi}(t_j \mid r_i)\right) \right] \\ &\ge H(t) + \sum_{j=1}^{k} \delta_j\, E_{d_{\theta}(t,r \mid x)} \log p_{j\phi}(t_j \mid r_j) \end{aligned} \tag{3}$$

where *KL*(*d<sub>iθ</sub>*(*t<sub>j</sub>*|*r<sub>i</sub>*) ∥ *p<sub>jϕ</sub>*(*t<sub>j</sub>*|*r<sub>i</sub>*)) is the non-negative Kullback–Leibler divergence between the densities *d<sub>iθ</sub>* and *p<sub>jϕ</sub>*, and *δ<sub>j</sub>* is the mixture weight. With this representation, the intractable variational posterior *d<sub>θ</sub>*(*r*|*t*, *x*) can be replaced approximately by the upper-bounded parameterized model *p<sub>ϕ</sub>*(*t*|*r*). In addition, the entropy *H*(*t*) is independent of the parameters *θ* and *ϕ* of the GAN model, so we can maximize the lower variational bound of *I*(*t*, *r*) by calculating only log *p<sub>ϕ</sub>*(*t*|*r*, *x*).

In the meantime, the decoded transformation vector can be associated with the input image pairs, which forces the generator to utilize the classified transformation information as well. Hence, the corresponding generated image will learn different attributes from the other domains.

#### *3.3. Loss Function*

The performance of UTC-GAN relies not only on the model architecture described above but also on an appropriate loss function. Our loss includes three components, of which the translation representation loss is newly proposed to learn the style transfer from pairs of CT slices automatically.

#### 3.3.1. Adversarial and Cycle-Consistency Loss

The adversarial and cycle-consistency losses from Cycle-GAN [30] are employed in both the generation cycle and the discriminator. For the image translation *A* → *B*, the adversarial loss is expressed as:

$$L_{Adv}(G_A, D_B) = E_{v \sim p_B(v)}[\log D_B(v_B)] + E_{u \sim p_A(u)}[\log(1 - D_B(G_A(u_A)))] \tag{4}$$

where *u* and *v* are training slices from the source and target modalities, respectively. Similarly, the adversarial loss of the translation *B* → *A* is denoted as:

$$L_{Adv}(G_B, D_A) = E_{u \sim p_A(u)}[\log D_A(u_A)] + E_{v \sim p_B(v)}[\log(1 - D_A(G_B(v_B)))] \tag{5}$$

To ensure that the generated slices can simultaneously be reconstructed to their previous modality, a cycle-consistency loss *LCyc*(*GA*, *GB*) is incorporated into the architecture to associate the reconstructed image *GB*(*GA*(*uA*)) with the input image *u*. Thus, the loss function with forward–backward consistency is defined as:

$$L(G_A, G_B, D_A, D_B) = L_{Adv}(G_A, D_B) + L_{Adv}(G_B, D_A) + \lambda_1 \left( L_{Cyc}(G_A, G_B) + L_{Cyc}(G_B, G_A) \right) \tag{6}$$

where *λ*<sub>1</sub> weighs the cycle-consistency loss relative to the GAN loss.

#### 3.3.2. Translation Representation Loss

Since the adversarial loss cannot detect the transformation directly from the encoders, we use the joint mutual information *I*(*t*, *v*|*u*) in the decoder to predict the domain translation. The lower variational bound of *I*(*t*, *v*|*u*) can be maximized by learning the expectation over the posterior distribution *p<sub>ϕ</sub>*(*t*|*v*) according to Equation (3). Thus, the translation forecasting loss *LTran* can be described as:

$$L_{Tran} = \max I(t, v \mid u) = \max_{\theta, \phi} E \log p_{\phi}(t \mid r) = \max_{\theta, \phi} \sum_{j=1}^{k} \delta_j \log \mathcal{N}\!\left(t \,\middle|\, u_j(r),\, v_j^2(r)\right) \tag{7}$$

where the mean *u<sub>j</sub>* and variance *v<sub>j</sub>*<sup>2</sup> are derived from the encoder, and *δ<sub>j</sub>* is a weight vector. This loss, *LTran*, is added to the discriminator *DB* to learn the transformation functions between domains. The whole augmented objective function is then given by:

$$L\_{Total} = L\_{Adv} + \lambda\_1 L\_{Cyc} + \lambda\_2 L\_{Tran} \tag{8}$$

where *λ*<sub>2</sub> is a hyper-parameter controlling the relative weight of the translation loss in the total objective.
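A hedged PyTorch sketch of the generator-side objective in Equation (8), using the λ1 = 15 and λ2 = 10 values given later in Section 4.1.3; it assumes `D_B` outputs probabilities, that the GM-TRM decoder supplies a translation log-likelihood, and uses an L1 cycle term, a common choice that the text does not explicitly specify:

```python
import torch
import torch.nn.functional as F

def generator_loss(G_A, G_B, D_B, u_A, translation_log_prob=None,
                   lambda_1=15.0, lambda_2=10.0):
    """Sketch of L_Total = L_Adv + lambda_1 * L_Cyc + lambda_2 * L_Tran (A -> B)."""
    fake_B = G_A(u_A)
    score = D_B(fake_B)
    # Adversarial term (Equation (4)): push D_B to score fake_B as real.
    adv = F.binary_cross_entropy(score, torch.ones_like(score))
    # Cycle-consistency term: reconstruct u_A through the backward generator.
    cyc = F.l1_loss(G_B(fake_B), u_A)
    # Translation term (Equation (7)): maximize the decoder's log-likelihood
    # of the applied transformation, i.e., minimize its negative.
    tran = -translation_log_prob if translation_log_prob is not None else 0.0
    return adv + lambda_1 * cyc + lambda_2 * tran
```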

#### **4. Experiments**

*4.1. Experimental Settings*

#### 4.1.1. Dataset

The Ischemic Stroke Lesion Segmentation Challenge (ISLES) 2018 dataset is used to execute the training and assessment of our augmenting-based segmentation. The dataset contains multiple modalities, including CT and four derived perfusion maps, i.e., mean transit time (MTT), time to peak of the residue function (Tmax), cerebral blood flow (CBF) and cerebral blood volume (CBV) [36,37]. In our experiments, the 94 labelled cases with the CT and its perfusion modalities serve as input for the UTC-GAN network.

Additionally, all images in the dataset are preprocessed via skull stripping and augmented with traditional operations such as flipping, scaling and rotating. Eighty percent of the image scans are used as the training and validation set, and the remainder as the testing set; that is, we use 10,980 slices for training, 3660 for validation and 3660 for testing.

#### 4.1.2. Baseline Model and Evaluation Measures

For comparison in augmentation, we use the ISLES2018 dataset to compare UTC-GAN with five existing unsupervised baseline models: Cycle-GAN [30], DRIT++ [14], EGSC-IT [15], FUNIT [16] and U-GAT-IT [13]. The Cycle-GAN and U-GAT-IT models have a mechanism and effectiveness similar to the proposed method. DRIT++ introduces a disentangled representation model to learn the mapping from CT to the other perfusion modalities. EGSC-IT realizes unsupervised image-to-image translation by adopting a weight-sharing architecture and feature masks. FUNIT is chosen as another comparable baseline because it has a multi-task synthesis structure with reconstruction and feature-matching loss functions. At the segmentation phase, a multiple-scale minus network (MSNet) [38] is trained to evaluate the improvement provided by the generative model. To input the multi-modality CT slices, the encoder path of the MSNet is duplicated five times and concatenated to the decoding part. Due to computational constraints, every slice is resized to 256 × 256.

We verify the synthesis methods using three widely used metrics: Peak Signal-to-Noise Ratio (PSNR), Normalized Mean Squared Error (NMSE) and Structural Similarity Index Measurement (SSIM) [39].

Six performance metrics are chosen to analyze the improvement in the segmentation accuracy for the testing set, including Dice Coefficient (DC), Intersection-over-Union (IoU) score, Precision, Accuracy, Recall and Hausdorff Distance (HD) [40].
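Several of these metrics reduce to short formulas; a sketch in NumPy (SSIM is usually taken from a library such as scikit-image rather than re-implemented):

```python
import numpy as np

def psnr(real, fake, data_range=1.0):
    """Peak Signal-to-Noise Ratio between two images."""
    mse = np.mean((real - fake) ** 2)
    return 10.0 * np.log10(data_range ** 2 / mse)

def nmse(real, fake):
    """Normalized Mean Squared Error."""
    return np.sum((real - fake) ** 2) / np.sum(real ** 2)

def dice(pred, truth):
    """Dice coefficient between two binary masks."""
    inter = np.logical_and(pred, truth).sum()
    return 2.0 * inter / (pred.sum() + truth.sum())

def iou(pred, truth):
    """Intersection-over-Union between two binary masks."""
    inter = np.logical_and(pred, truth).sum()
    return inter / np.logical_or(pred, truth).sum()
```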

#### 4.1.3. Implementation Details

In the data augmentation stage, we adopt the ResNet architectures from Miyato et al. [20] as the backbone of the UTC-GAN generator. In the UTC-GAN discriminator, each encoder branch (*EA* and *EB*) consists of Inception v1 blocks, and the output features from the decoder *De* are concatenated to a convolutional classifier *CTran*; classifier *CAdv* has the same structure as *CTran*. During the segmentation stage, a standard U-Net framework is trained for segmenting stroke lesion areas. Back-propagation for both networks used the ADAM optimization algorithm [41] with *β*<sub>1</sub> = 0.5 and *β*<sub>2</sub> = 0.999. The initial learning rate was set to 2 × 10<sup>−4</sup> for all networks. The loss-balancing weight parameters *λ*<sub>1</sub> and *λ*<sub>2</sub> were set to 15 and 10, respectively. The learning rate decays exponentially at a rate of 0.001 every 30 epochs, and the model is trained for a total of 200 epochs with a batch size of 24. Each paired input contains an original CT slice and its perfusion-modality counterpart.

The synthesis and segmentation networks were implemented in PyTorch on an NVIDIA Titan XP GPU. All experiments are conducted five times with varying random seeds, and the average value is reported.
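A sketch of the stated optimization setup (the placeholder networks stand in for the actual generators and discriminators, and the schedule below is a literal reading of "decays ... at a rate of 0.001 every 30 epochs"):

```python
import torch
import torch.nn as nn

# Placeholder networks; the real generators/discriminators are described above.
G, D = nn.Linear(8, 8), nn.Linear(8, 1)

opt_G = torch.optim.Adam(G.parameters(), lr=2e-4, betas=(0.5, 0.999))
opt_D = torch.optim.Adam(D.parameters(), lr=2e-4, betas=(0.5, 0.999))
sched = torch.optim.lr_scheduler.StepLR(opt_G, step_size=30, gamma=0.001)

for epoch in range(200):        # 200 epochs, batch size 24
    # ... alternate generator and discriminator updates here ...
    sched.step()
```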

#### *4.2. The Impact of GM-TRM Module*

First, we compare generation quality to assess the effectiveness of the components in our UTC-GAN model. The main components, i.e., the GM-TRM module and the classifier *CTran* concatenated to decoder *De*, are replaced in the architecture and compared with existing schemes in sequence. As competitors, the AET module also employs an auto-encoding architecture to learn unsupervised transformation representations, while the AVT module creates the translation information by applying a constrained variational approach to a similar auto-encoder network. Table 1 presents the quantitative synthesis performance with GM-TRM and the other modules. We observe that GM-TRM enhances synthesis performance, leading to a 1% average improvement in PSNR and 0.3% in NMSE. This result demonstrates that the mixture density can derive a better transformation representation. Moreover, we also quantitatively compare the results obtained by varying the classifier on top of the GM-TRM. The traditional non-linear classifier adopts fully connected layers to discriminate transformations, whereas the convolutional type takes advantage of the convolution kernel to narrow the range of the transformation prediction. From the results, we can see that the convolutional classifier is consistently better than the non-linear classifier, and our synthesis model achieves nearly the best PSNR regardless of which classifier is used.


**Table 1.** Synthesizing performance of the UTC-GAN composed of different components.

#### *4.3. Comparison with Other Unsupervised Data Augmentation Methods*

Next, we evaluate UTC-GAN against various corresponding synthesis models.

Table 2 reports the quantitative synthesis results for all baselines. As underlined in the table, UTC-GAN surpasses the other five unsupervised models, raising the PSNR from 23.13 to 26.40, with an NMSE of approximately 0.096 and an SSIM of 0.918. This indicates that our model captures the transformation characteristics across modalities, which contributes to its superior synthesis effectiveness.


**Table 2.** Synthesizing performance comparison with different baselines.

Figure 4 shows a visual comparison across modalities between the proposed UTC-GAN model and the other unsupervised baselines. As can be seen, UTC-GAN generates much more realistic synthesis images, while samples from Cycle-GAN and U-GAT-IT contain unclear regions or fail to create detailed features. In contrast, samples from FUNIT, EGSC-IT and DRIT++ show similar results: they supply more complex brain attributes to a certain extent but also contain some unwanted artifacts. Overall, the proposed method yields higher visual realism for all CT modalities than the others, as indicated by both qualitative and quantitative measures.

#### *4.4. Segmentation Using Data Augmentation*

To investigate the influence of UTC-GAN on segmentation improvement, we first feed both real and synthetic data to the MSNet segmentation model; the performance evaluation is described in Figure 5. The synthetic images have a positive effect on segmentation accuracy when the synthetic data ratio is less than 60%. Increasing the proportion of synthetic images leads to 7% and 8% improvements in Dice score and precision, respectively, over the results obtained with real images alone. However, we found that too many synthetic images could degrade performance.


**Figure 4.** The visualizations of real and generated CT slices under different modalities. From left to right: CT, Tmax, CBF, MTT, CBV modality. The top-down results: Input ground truth, Cycle-GAN, U-GAT-IT, FUNIT, EGSC\_IT, DRIT++ and our augmentation method.

**Figure 5.** The performance of stroke segmentation by changing the percentage of synthetic images.

The second experiment applies all the data augmentation methods mentioned earlier together with the MSNet model to the ISLES2018 data. The quantitative segmentation results are given in Table 3. As the table shows, all of the synthesis models provide significant improvement on all evaluation metrics over using the MSNet segmentation method with only the randomness added by data pre-processing. Our approach produces the most significant improvement by generating transformation-detectable images. For example, the proposed UTC-GAN raises the Dice score from 0.675 to 0.768, increases precision by 10% and reduces the Hausdorff distance by 9.3 mm; the IoU score and accuracy also achieve the highest values. This improvement likely benefits from the high contrast and variation of the generated images. Furthermore, the boxplot of the Dice coefficient in Figure 6 is used to analyze segmentation robustness. Overall, the mean value and median line of UTC-GAN express relatively higher improvement than the other six competitors.


**Table 3.** Segmentation performance comparison with different strategies.

Additionally, three representative cases of visual segmentation for the various combinations are displayed in Figure 7. Compared with other data augmentation means, UTC-GAN produces considerable improvements in the segmentation mask, bringing segmented regions closer to their manually annotated counterparts. This intuitively demonstrates that stroke CT images augmented by the UTC-GAN model help obtain a more accurate lesion region.

**Figure 6.** The Dice coefficient in boxplot by different methods, where the asterisk \* and number in black represent the respective mean value.

**Figure 7.** Visual results of lesion segmentation processed via the MSNet and data augmentation models. The boundaries of expertise-based and unsupervised learning-based segmentation are painted with red and green curves, respectively.

#### **5. Conclusions**

In this paper, a GAN-based data augmentation paradigm is presented to improve the accuracy of ischemic stroke segmentation. By integrating Siamesed auto-encoders and an information-theoretic loss into a cycle-generative adversarial framework, the architecture can learn sufficient representations of the transformation from the original to the target domain in an unsupervised fashion. The sampling problem in the latent space is further addressed by introducing a Gaussian mixture probability distribution to better approximate the characteristics of the transformation. The experimental evaluation and comparison demonstrate that the proposed method outperforms alternative structures and provides higher-quality generated stroke images. Meanwhile, the augmentation method yields greater segmentation quality improvement when cooperating with traditional segmentation methods.

**Author Contributions:** Conceptualization and methodology, S.W. and X.Z.; formal analysis, S.W., X.Z., H.H., F.L. and Z.W.; software, validation and writing—original draft, S.W., X.Z. and H.H.; writing review and editing and data curation, S.W., X.Z., H.H. and F.L.; supervision and funding acquisition, S.W., X.Z. and F.L. All authors have read and agreed to the published version of the manuscript.

**Funding:** This work is supported by the General Program under grant funded by the National Natural Science Foundation of China (NSFC) (No. 62171307), and the Basic Research Program of Shanxi Province under grant funded by the Department of Science and Technology of Shanxi Province (China) (No. 202103021224113).

**Data Availability Statement:** The ISLES 2018 dataset is publicly available at http://www.isleschallenge.org/ (accessed on 15 July 2022).

**Conflicts of Interest:** The authors declare no conflict of interest.

**Ethical Approval:** The dataset was obtained from the ISLES challenge; there is no live interaction with subjects, and the dataset is anonymous. Hence, this article does not contain any studies with live human participants or animals performed by any of the authors.

#### **References**


## *Article* **An Effective Deep Learning-Based Architecture for Prediction of N7-Methylguanosine Sites in Health Systems**

**Muhammad Tahir 1, Maqsood Hayat 1,\*, Rahim Khan <sup>1</sup> and Kil To Chong 2,3,\***


**Abstract:** N7-methylguanosine (m7G) is one of the most important epigenetic modifications found in rRNA, mRNA, and tRNA, and plays a promising role in gene expression regulation. Owing to its significance, well-equipped traditional laboratory-based techniques have been used for the identification of m7G; however, these approaches are time-consuming and cost-ineffective. To move beyond these traditional approaches and predict N7-methylguanosine sites with high precision, the concept of artificial intelligence has been adopted. In this study, an intelligent computational model called N7-methylguanosine-Long short-term memory (m7G-LSTM) is introduced for the prediction of N7-methylguanosine sites. One-hot encoding and word2vec feature schemes are used to express the biological sequences, while LSTM and CNN algorithms are employed for classification. The proposed "m7G-LSTM" model obtained an accuracy of 95.95%, a specificity of 95.94%, a sensitivity of 95.97%, and a Matthew's correlation coefficient (MCC) of 0.919, significantly better outcomes than previous models across all evaluation parameters. The proposed m7G-LSTM computational system aims to support the drug industry and to help bioinformatics researchers enhance innovation for predicting the behavior of N7-methylguanosine sites.

**Keywords:** deep learning; pattern recognition; LSTM; RNA; natural language processing

### **1. Introduction**

To date, about 150 different types of RNA modification have been recognized. These changes in RNA perform important functions in regulating gene expression at different levels; for example, it has been confirmed that changes in RNA can affect RNA transport, processing, mRNA translation, and stability [1,2]. One of the most abundant RNA modifications is N7-methylguanosine (m7G), which occurs at the 5' cap of mRNA molecules, in the transfer RNA (tRNA) loop, and in eukaryotic ribosomal RNA (rRNA); these modifications are conserved across the three kingdoms. This modification plays a critical role in the regulation of RNA function, post-transcriptional modification, and metabolism [3]. Within the context of aging, Zago et al. have discussed the emerging importance of microRNAs as biomarkers for Parkinson's disease [4]. Unfortunately, data on the functional mechanisms are very limited. Recently, it has been revealed that m7G sites can be efficiently identified by modern sequencing techniques [5–7]. Using deep sequencing technology, Marchand et al. developed AlkAniline-Seq for identifying m7G in RNA at single-nucleotide resolution in yeast, human, and bacterial mitochondrial and cytoplasmic rRNAs and tRNAs [5]. Furthermore, by differentially modifying internal m7G sites to certain basic sites, Zhang et al. established a MeRIP-seq approach to predict m7G sites at single-base resolution; the resulting single-base-resolution m7G-seq data revealed the profile of m7G in human mRNA and tRNA, enhancing knowledge of m7G distribution in human cells [6].


However, the transcriptome-wide dissemination and dynamic regulation of m7G within internal mRNA regions are still unknown. According to Zhao et al., the internal mRNA m7G methyltransferase METTL1, and not WDR4, is a key responder to post-ischemic insults, resulting in a global reduction in m7G methylation inside mRNA [8]. In addition, Liu et al. introduced m7GPredictor for predicting internal m7G modification sites using sequence properties [9]; in this model, the authors used various numerical descriptor methods, and a random forest was used for the selection of optimal feature sets. Likewise, Bi et al. developed a computational model for the identification of m7G sites [10], using different types of sequence encoding schemes in combination with the XGBoost algorithm. Further, Shoombuatong et al. proposed a new predictor, THRONE, for the discrimination of human RNA N7-methylguanosine sites [11]; THRONE was designed as a three-step ensemble learning predictor. Likewise, the m7G-DPP web predictor was introduced by Zou and Yin, using the physicochemical properties of RNA for the prediction of m7G sites [12]. Here, the Pearson correlation coefficient, dynamic time warping, and distance correlation were utilized to extract numerical features, and the LASSO algorithm was then employed to select highly discriminative features [12]. Similarly, Zhang et al. introduced a predictor, BERT-m7G, utilizing a stacking ensemble approach for the identification of RNA m7G sites [13]; in this model, a BERT-based multilingual model was utilized to represent the information in RNA sequences. To specifically detect internal mRNA m7G modifications, Malbec et al. developed the m7G individual nucleotide-resolution cross-linking and immunoprecipitation with sequencing (miCLIP-seq) approach [7]; this group of researchers determined that m7G modifications are enriched in AG-rich contexts, which are highly preserved across different mouse tissues and human cell lines. Although these advanced sequencing techniques have revealed significant findings in this area, they remain costly for transcriptome-wide detection. In this context, computational m7G site predictors have been introduced, namely m7GFinder [14], iRNA-m7G [15], and m7G-IFL [16]. Among these, Yang et al. introduced the m7GFinder tool, which predicts m7G sites in H. sapiens RNA using a sequence-based approach; the optimal feature subset was determined using mRMR, F-score, and Relief, and a support vector machine (SVM) was used as the classifier. Similarly, the iRNA-m7G model was developed by Chen et al. to identify N7-methylguanosine sites by fusing multiple feature spaces. In this model, sequence- and structure-based features were integrated to form a hybrid space: three types of features, including secondary structure components, pseudo-nucleotide composition, and nucleotide property and frequency, were combined using a feature fusion method to extract important RNA sequence features. Experiments showed that the feature fusion technique outperforms the use of a single feature type in detecting m7G sites [15]. Similarly, Ning et al. presented the m7G-DLSTM predictor, based on an LSTM model together with natural language processing (NLP), nucleotide chemical property, and binary code feature extraction methods [17].
Most recently, the m7G-IFL computational model for identifying m7G sites was developed by Dai et al. [18]. This model uses an iterative feature representation approach for RNA sequence encoding, discovering probabilistic distribution information from various sequential models and improving feature representation in a supervised, iterative manner. The m7G-IFL predictor uses various feature extraction techniques, including ring-function-hydrogen properties (RFH), physical–chemical properties (PCP), and binary k-mer frequency (BKF); extreme gradient boosting (XGBoost) is then applied as the classifier. Feature analysis further showed that the proposed iterative feature method can improve feature representation capability during the iterative phase [16].

Nevertheless, the efficiency of existing computational models still needs enhancement, and there is a dire need for novel computational methods for the accurate, fast, and precise detection of m7G modification. In this study, we focus on deep learning-based prediction methodologies and develop an accurate computational system called "m7G-LSTM" to predict N7-methylguanosine sites directly from sequence information. The proposed m7G-LSTM system contains two stages: distributed attribute representation and a long short-term memory (LSTM) model. In the representation stage, the NLP-based approach word2vec is applied to fragment each RNA instance into words (3-mers). In the second stage, N7-methylguanosine sites are identified using the LSTM model. The proposed m7G-LSTM prediction model shows better performance and obtains promising outcomes.

#### **2. Methods and Materials**

#### *2.1. Benchmark Dataset*

Here, we select and download a benchmark dataset from Chen et al. [15,16] to train the proposed computational system. The benchmark dataset consists of 741 positive sequences (m7G sites) and 741 negative sequences (non-m7G sites), all of the same length (41 nucleotides). The benchmark dataset is expressed mathematically in Equation (1).

$$\mathcal{S} = \mathcal{S}^+ \cup \mathcal{S}^- \tag{1}$$

The benchmark dataset *S* consists of m7G sites and non-m7G sites: *S*<sup>+</sup> contains the positive m7G site sequences and *S*<sup>−</sup> the negative non-m7G site sequences. To examine the performance of the proposed model, cross-validation is used. The dataset is split into three sections: 70% for training, 10% for validation, and 20% for testing.

#### *2.2. Encoding Scheme*

The one-hot encoding approach is a simple but useful feature extraction technique, frequently used in deep learning, that shows effective performance in bioinformatics [19] and computer science [20]. It is employed to illustrate the nucleotide composition along an RNA/DNA sequence and has been used in previous studies [21–25]. In this technique, 'A', 'C', 'G', and 'U' are represented by the binary vectors (1, 0, 0, 0), (0, 1, 0, 0), (0, 0, 1, 0), and (0, 0, 0, 1), respectively. As a result, an *n*-nucleotide RNA/DNA sequence is encoded as a 4 × *n* binary matrix, which serves as input to the CNN and LSTM models in this study; here, *n* = 41 nucleotides. Figure 1 gives a graphical representation of the one-hot encoding scheme.


**Figure 1.** Description of the one-hot encoding scheme.
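A minimal sketch of this encoding in Python:

```python
import numpy as np

NUC = {"A": 0, "C": 1, "G": 2, "U": 3}

def one_hot(seq: str) -> np.ndarray:
    """Encode an RNA sequence as the 4 x n binary matrix of Figure 1."""
    mat = np.zeros((4, len(seq)), dtype=np.float32)
    for j, base in enumerate(seq):
        mat[NUC[base], j] = 1.0
    return mat

# Example: a hypothetical 41-nt window around a candidate m7G site.
x = one_hot("AUGC" * 10 + "G")   # shape (4, 41)
```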

#### *2.3. Distributed Attributes Formulation*

The distributed attribute formulation scheme reduces the classification error of the computational model by obtaining less noisy data. Genetic data are usually expressed as biological sequences; hence, they may be thought of as a language through which information moves between cells. Natural language processing (NLP) has been used for a variety of biological problems in this area, such as EP2vec [26], alternative splicing sites [27], G2Vec [28], and iN6-Methyl (5-step) [29]. We approach this sequence analysis problem from a new angle inspired by NLP. Indeed, there are several effective deep learning applications in NLP, e.g., word2vec, which embeds words into a vector space; the paragraph vector builds on word2vec and embeds whole phrases into vectors that encode their semantic content. In this regard, treating an RNA/DNA sequence as a sentence rather than an image is more natural, since DNA sequences are one-dimensional data, whereas images are typically two-dimensional. Consequently, we consider a DNA/RNA sequence to be a sentence made up of k-mers (words) [26]. Here, the NLP-based method word2vec is applied to obtain decipherable representations of RNA sequences. The RNA sequences are first fragmented into multiple words represented by overlapping k-mers; here, k = 3 (3-mers). Genomes are commonly collected from the Genbank databank via http://hgdownload.soe.ucsc.edu. The genome is split into 21 distinct chromosomes (C1, C2, C3, ..., C21), each chromosome is fragmented into sentences of 100 nt, and the words are created by cutting each sentence into overlapping 3-mers. The word2vec model is trained using the continuous bag-of-words (CBOW) technique, in which the current word w(t) is predicted from the context words around it within a predetermined window. Table 1 shows the training parameters of the word2vec model. Finally, each 3-mer word is expressed by a 100-dimensional vector, and each sequence of length L is represented by an array of shape (L − 2) × 100.

**Table 1.** Training parameters of word2vec.
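A sketch using gensim's `Word2Vec` (the corpus here is a placeholder; the authoritative training settings are those of Table 1, while `vector_size=100` matches the 100-dimensional word vectors described above and `sg=0` selects CBOW):

```python
from gensim.models import Word2Vec

def to_kmers(sentence: str, k: int = 3):
    """Split a sequence 'sentence' into overlapping k-mer 'words'."""
    return [sentence[i:i + k] for i in range(len(sentence) - k + 1)]

# Hypothetical corpus of 100-nt genome 'sentences'.
corpus = [to_kmers(s) for s in ["ACGUACGUAC" * 10, "UGCAUGCAUG" * 10]]

model = Word2Vec(corpus, vector_size=100, window=5, sg=0, min_count=1)

# Each 3-mer maps to a 100-d vector; a length-L RNA sequence then becomes
# an (L - 2) x 100 array of stacked word vectors.
vec = model.wv["ACG"]            # shape (100,)
```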


#### *2.4. Convolutional Neural Networks (CNN)*

A CNN is a deep learning algorithm frequently used in image processing, natural language processing, and bioinformatics studies [30–37]. For image data, CNNs work with two dimensions; however, they can also be employed with three-dimensional and one-dimensional data, and 1D CNN models have been applied effectively in bioinformatics [24,38–41]. A CNN comprises an input layer, multiple hidden layers such as convolutional layers, ReLU (activation) layers, pooling layers, normalization layers and fully connected layers, and an output layer. In this study, various hyper-parameters, such as the mask size [3, 5, 7, 9, 11, 13, and 15], the number of masks [4, 6, 8, 10, 12, 14, and 16], and the number of convolution layers [1, 2, and 3], were explored for training the CNN model; the dropout probability range was [0.2, 0.25, 0.3, 0.35, 0.4, 0.5, 0.6, 0.7, and 0.75]. Hyper-parameters were selected according to the best success rates across all performance metrics for discriminating N7-methylguanosine sites. The convolution layer, ReLU layer, and sigmoid function are defined mathematically as follows:

$$\mathrm{Conv1D}(R)_{jf} = \mathrm{ReLU}\left(\sum_{s=0}^{S-1} \sum_{n=0}^{N-1} W^{f}_{sn}\, R_{j+s,\,n}\right) \tag{2}$$

In Equation (2), *R*, *f*, and *j* stand for the input, filter index, and output index position, respectively. *N* shows the number of input channels, and *S* denotes the size of the window.

Dense layer with dropout: the dense layer transforms the feature vector *z* into a scalar output score.

$$f = w\_{d+1} + \sum\_{k=1}^{d} w\_k z\_k \tag{3}$$

$$f = w\_{d+1} + \sum\_{k=1}^{d} m\_k w\_k z\_k \tag{4}$$

In Equations (3) and (4), *w<sub>d+1</sub>* is the additive bias term, and *w<sub>k</sub>* is the weight applied to the previous layer's output *z<sub>k</sub>*. The rectified linear function (ReLU) is stated mathematically in Equation (5).

$$\text{ReLU}(z) = \max(0, z) \tag{5}$$

As its output is scaled to [0, 1], the sigmoid function is responsible for predicting whether a given sequence is an m7G site. Equation (6) expresses the sigmoid function mathematically.

$$\mathrm{Sigmoid}(z) = \frac{1}{1 + e^{-z}} \tag{6}$$
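A hedged PyTorch sketch consistent with Equations (2)–(6); the filter count, kernel size, and dropout rate are drawn from the search ranges above, not the final tuned values:

```python
import torch
import torch.nn as nn

class M7G_CNN(nn.Module):
    """1D CNN over one-hot RNA input of shape (batch, 4, 41)."""
    def __init__(self, n_filters=16, kernel=7, dropout=0.35):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv1d(4, n_filters, kernel_size=kernel, padding=kernel // 2),
            nn.ReLU(),                      # Equation (5)
            nn.AdaptiveMaxPool1d(1),
            nn.Flatten(),
            nn.Dropout(dropout),
            nn.Linear(n_filters, 1),        # dense layer, Equations (3)-(4)
            nn.Sigmoid())                   # Equation (6)

    def forward(self, x):
        return self.net(x)
```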

#### *2.5. Long Short-Term Memory Layer (LSTM)*

The recurrent neural network (RNN) is a type of deep learning model that learns from sequential data such as time series and text [42]. However, it suffers from the vanishing gradient problem, so its parameters are not updated effectively during back-propagation [36,43–46]. LSTM is a type of RNN that addresses this gradient issue by adding gating functions and can store information about long-distance dependencies in the data [30,47,48]. The LSTM gating mechanisms enable the network to decide what to remember and what to ignore. LSTM has also contributed greatly to speech recognition and language translation [49–51]. Figure 2 illustrates our proposed model, which is composed of an input layer, two LSTM layers, and a dense layer.

**Figure 2.** Schema of the proposed m7G-LSTM computational model.

The first LSTM layer has an output size of 32, which is fed into the second LSTM layer with 64 output channels. A dropout rate of 35% is applied to the input connections within the LSTM layers. The output of the second LSTM layer is flattened and passed to the dense layer. The dense layer is a fully connected layer with x output channels, followed by a sigmoid activation function. Finally, the sigmoid function generates the outcome.

The proposed model is trained as follows. Let *x<sub>i</sub>* be a vector representing the input RNA sequence (Equation (7)) with label *y<sub>i</sub>* (Equation (8)). The LSTM computes *z<sub>i</sub>* from *x<sub>i</sub>* (Equation (9)), and the sigmoid (Equation (10)) maps *z<sub>i</sub>* to a value between 0 and 1. The loss is the binary cross-entropy of the prediction, which is used to update the hidden neurons via the Adam optimization algorithm with a learning rate of 0.0005.

$$x_i = \text{RNA sequence, where } x_i \in \{\mathrm{A}, \mathrm{C}, \mathrm{G}, \mathrm{U}\} \tag{7}$$

$$y_i = \begin{cases} 0 & \text{if } x_i = \text{non-m7G site} \\ 1 & \text{if } x_i = \text{m7G site} \end{cases} \tag{8}$$

$$z\_{i} = lstm(\mathbf{x}\_{i})\tag{9}$$

$$\mathrm{sigmoid}(z_i) = \frac{1}{1 + e^{-z_i}} \tag{10}$$
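A PyTorch sketch of this architecture; the layer sizes (32 and 64) and the 35% dropout follow the text, while the flatten-then-dense wiring is one plausible reading of the description:

```python
import torch
import torch.nn as nn

class M7G_LSTM(nn.Module):
    """Two stacked LSTM layers over word2vec features, per Figure 2."""
    def __init__(self, emb_dim=100, seq_len=39):   # 39 = 41 - 2 overlapping 3-mers
        super().__init__()
        self.drop = nn.Dropout(0.35)               # dropout on input connections
        self.lstm1 = nn.LSTM(emb_dim, 32, batch_first=True)
        self.lstm2 = nn.LSTM(32, 64, batch_first=True)
        self.dense = nn.Linear(64 * seq_len, 1)

    def forward(self, x):                          # x: (batch, seq_len, emb_dim)
        h, _ = self.lstm1(self.drop(x))
        h, _ = self.lstm2(h)
        return torch.sigmoid(self.dense(h.flatten(1)))   # Equation (10)

# Training (sketch): binary cross-entropy with Adam at the stated learning rate.
# opt = torch.optim.Adam(model.parameters(), lr=5e-4)
# loss = nn.BCELoss()(model(x_batch), y_batch)
```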

#### *2.6. Evaluation Parameters*

Following the literature [15,52–63], the following metrics are employed to measure the prediction performance of the computational method: specificity (Sp), sensitivity (Sen), accuracy (Acc), Matthew's correlation coefficient (MCC), and auROC. In the equations below, 'TP' is a true positive, 'TN' a true negative, 'FP' a false positive, and 'FN' a false negative.

$$\begin{cases} Sp = \dfrac{TN}{TN + FP} \times 100 \\ Sen = \dfrac{TP}{TP + FN} \times 100 \\ Acc = \dfrac{TP + TN}{TP + TN + FP + FN} \times 100 \\ MCC = \dfrac{TP \times TN - FP \times FN}{\sqrt{(TP + FP)(TP + FN)(TN + FP)(TN + FN)}} \end{cases} \tag{11}$$

Accuracy assesses the precision of a computational algorithm in distinguishing m7G sites from non-m7G sites. Sensitivity and specificity are the true positive (TP) and true negative (TN) rates of a test. MCC quantifies the correlation between predicted and target classes and remains informative for imbalanced datasets; here, the ratio of the two classes is the same. The area under the ROC curve (auROC) is another measurement metric that summarizes the predicted outcomes and indicates the quality of the model. In the above equation, FN and FP denote false negative and false positive, respectively.
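As a small worked example, the metrics in Equation (11) can be computed directly from the four counts; the counts below are made up for illustration.

```python
# Compute Sp, Sen, Acc (percentages) and MCC from confusion-matrix counts.
import math

def evaluate(tp: int, tn: int, fp: int, fn: int):
    sp = tn / (tn + fp) * 100
    sen = tp / (tp + fn) * 100
    acc = (tn + tp) / (fn + tp + tn + fp) * 100
    mcc = (tp * tn - fp * fn) / math.sqrt(
        (tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
    return sp, sen, acc, mcc

print(evaluate(tp=95, tn=96, fp=4, fn=5))  # illustrative counts only
```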

#### **3. Results and Discussion**

An intelligent computational method, m7G-LSTM, is designed based on a natural language processing approach, word2vec, in combination with the deep learning algorithm LSTM. Its efficiency is reported using the measuring metrics described above. The proposed m7G-LSTM model achieved an accuracy of 95.95%, specificity of 95.97%, sensitivity of 95.94%, MCC of 0.919, and auROC of 0.980, whereas the CNN model achieved 94.94% accuracy, 93.28% specificity, 96.62% sensitivity, 0.899 MCC, and 0.979 auROC. Table 2 shows the detailed outcomes of the proposed predictor for the LSTM and CNN models.

**Table 2.** Performance of CNN and LSTM models using cross-validation.


Figures 3 and 4 show the auROC as well as a graphical illustration of the confusion matrix.

**Figure 3.** The auROC curve of the m7G-LSTM model.

We compared our proposed m7G-LSTM model to state-of-the-art models, such as m7GFinder, m7G-IFL, and iRNA-m7G, to assess its predictive performance. In this study, we developed two deep learning-based approaches, LSTM and CNN, to build the proposed m7G-LSTM model, and we compared the highest-performing of the two to the most recent m7G-IFL model. To achieve a valid comparison, we executed and evaluated the proposed model on the same benchmark dataset as the existing models. Table 3 summarizes the prediction performance. As shown, our computational model outperforms the other three models, with an accuracy of 95.95%, sensitivity of 95.94%, specificity of 95.97%, and MCC of 0.919. The proposed m7G-LSTM model outperformed the latest m7G-IFL predictor by 3.54% in specificity, 3.37% in sensitivity, 3.45% in accuracy, and 0.069 in MCC, and it achieved the best auROC of 0.980. Based on these results, our m7G-LSTM model improves upon state-of-the-art models for predicting m7G site modification.


**Table 3.** Model comparison between the proposed m7G-LSTM and current models.

The graphical depiction of the execution outcomes is presented in Figure 5, in which the m7G-LSTM method achieves remarkable results compared to the current prediction models. This shows the significance of our proposed model.

**Figure 5.** Comparison of proposed m7G-LSTM computational model with existing methods.

#### **4. Conclusions**

The proposed m7G-LSTM method is a reliable and novel deep learning-based prediction model for m7G sites. The model exploits distributed feature representations through the LSTM network. Two feature encoding schemes were used, i.e., word2vec and one-hot encoding. In feature representation, the input RNA sequence is divided into 3-mers, or words, and each word is mapped to its corresponding feature representation using the NLP method word2vec. One-hot encoding converts categorical data to binary data that computational models can process efficiently. The LSTM and CNN models were then applied to identify N7-methylguanosine sites, with the LSTM model producing better performance than the CNN model. According to the prediction results, the m7G-LSTM model outperforms state-of-the-art models in all performance measures for discriminating N7-methylguanosine sites. The predicted outcome demonstrates that the proposed m7G-LSTM computational system is efficient and reliable and that it may be useful in drug-related applications and academic research.

**Author Contributions:** Conceptualization, M.T.; methodology, M.T.; software, M.H. and R.K.; validation, M.T., M.H., R.K. and K.T.C.; resources, K.T.C.; writing—original draft, M.T. and R.K.; writing—review & editing, M.T., M.H. and K.T.C.; visualization, M.T. and R.K.; supervision, M.H. and K.T.C.; project administration, K.T.C.; funding acquisition, K.T.C. All authors have read and agreed to the published version of the manuscript.

**Funding:** This work was supported by the National Research Foundation of Korea (NRF) grant funded by the Korean government (MSIT) (No. 2020R1A2C2005612).

**Conflicts of Interest:** The authors declare no conflict of interest.

#### **References**


### *Article* **Transfer Learning-Based Automatic Hurricane Damage Detection Using Satellite Images**

**Swapandeep Kaur 1, Sheifali Gupta 1, Swati Singh 2, Vinh Truong Hoang 3, Sultan Almakdi 4, Turki Alelyani 4,\* and Asadullah Shaikh <sup>4</sup>**


**Abstract:** After the occurrence of a hurricane, assessing damage is extremely important for emergency managers so that relief aid can be provided to afflicted people. One method of assessing the damage is to distinguish damaged from undamaged buildings post-hurricane. Normally, damage assessment is performed by conducting ground surveys, which are time-consuming and involve immense effort. In this paper, transfer learning techniques have been used for determining damaged and undamaged buildings in post-hurricane satellite images. Four transfer learning techniques, VGG16, MobileNetV2, InceptionV3 and DenseNet121, have been applied to 23,000 satellite images of Hurricane Harvey, which struck the Texas region. A comparative analysis of these models has been performed on the basis of the number of epochs and the optimizers used. The VGG16 pre-trained model performed better than the other models, achieving an accuracy of 0.75, precision of 0.74, recall of 0.95 and F1-score of 0.83 with the Adam optimizer. When the best-performing models were compared across various optimizers, VGG16 produced the best accuracy of 0.78 with the RMSprop optimizer.

**Keywords:** hurricane; damage; undamaged; emergency managers; transfer learning; satellite images

### **1. Introduction**

There has been a steady global increase in the occurrence of natural disasters since 1980, and the number of people exposed to disasters is also increasing [1]. Amongst natural disasters, some of the most catastrophic are hurricanes, which occur over warm seawater in tropical and subtropical areas. The sun heats the seawater, leading to the formation of enormous clouds, which cause excessive rainfall, floods and very fast-moving winds [2]. Damage of approximately USD 265 billion was estimated in the US in 2017 due to three major hurricanes (Harvey, Maria and Irma). These hurricanes affected thousands of people and caused many fatalities. During such difficult times, the affected people require assistance. Hence, it is essential to assess the destruction brought about by hurricanes [1].

Satellite images have been used to determine whether there has been damage inflicted by the hurricane or not. Satellite images have been gaining immense popularity for monitoring hurricanes. Ground surveys are time-consuming and also labor-intensive [3].

In artificial intelligence, transfer learning is a technique that involves reusing an already trained model on a different but related problem. This technique is now being popularly used in deep learning when the dataset is not large.


Transfer learning helps reduce the resources and labeled data required for training newer models, and it helps reduce training times [4].

A convolutional neural network consists of feature extraction as the first stage and classification as the final stage. In transfer learning, the classification stage is altered. The network is initialized with weights learned from the ImageNet dataset [4]. The convolutional layers and the max-pooling layers are frozen so that no modification of their weights takes place; only the dense, or fully connected, layers are left free to be altered. After this, the model is retrained. The feature extraction stage is thus reused and only the final classifier is tuned, which works better with smaller datasets. This is why it is called transfer learning: knowledge gained on one problem helps solve a second problem [5].
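A minimal sketch of this freeze-then-retrain setup, using Keras with VGG16 purely as an example base (the paper applies the same idea to all four models):

```python
# Freeze the pre-trained convolutional/max-pooling stack so its ImageNet
# weights stay fixed; only layers added on top will be trained.
from tensorflow.keras.applications import VGG16

base = VGG16(weights="imagenet", include_top=False, input_shape=(224, 224, 3))
base.trainable = False  # no weight modification in the feature-extraction stage

for layer in base.layers[:5]:
    print(layer.name, "trainable:", layer.trainable)
```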

This paper involves the study of hurricane damage detection using satellite images. In earlier work, the intensity of hurricanes was estimated using deep convolutional neural networks applied to infrared images obtained from a satellite source. The adopted method, known as Deep Phurie, produced a very low root mean square (RMS) error compared with the earlier method known as Phurie. Deep Phurie is completely automatic, but that work does not evaluate post-disaster damage [2]. Furthermore, deep convolutional neural networks were used to estimate the intensity of tropical cyclones, or hurricanes, that took place from 1998 to 2012. Regularization techniques were used along with many convolutional and dense layers. This technique helped extract features from hurricane images effectively, and a low RMS error and improved accuracy were obtained [6]. However, those data were noisy and not of good quality.

A multilayer perceptron was proposed for the determination of the connection between the appearance of hurricanes and the high-energy particles that flow out from the sun. A multilayer perceptron is an artificial neural network accompanied by backpropagation. It was found that hurricane appearances could take place a few days before the breakout of a solar wind [7].

As a deep learning method, a single-shot multibox detector (SSD) was employed to estimate the destruction inflicted on buildings by Hurricane Sandy, which occurred in 2012. The VGG16 and SSD models were used, and improvements of 72% and 20% in mAP and mF1, respectively, were observed [8]. A CNN model was used to determine areas severely affected by Hurricane Harvey. Satellite images were used to extract man-made features such as roads and buildings before and after the disaster, and an F1 score of 81.2% was achieved [9].

Damage assessment after a hurricane is of utmost importance. The authors of [10] created a benchmark dataset for property damaged by Hurricane Harvey. The dataset consisted of both undamaged and damaged building images obtained from satellite imagery, with FEMA and TOMNOD as the sources.

The destruction brought about by Hurricane Dorian has been determined using satellite imagery and artificial intelligence. The severity of the destruction caused by the hurricane was estimated, and an accuracy of 61% was achieved [11].

Earlier studies focused on finding the intensity of hurricanes and providing a benchmark dataset for damage detection. Fewer studies have focused on classifying hurricane images into damaged and undamaged classes. In this paper, a comparative analysis of four transfer learning models, DenseNet121 [12], VGG16 [13], MobileNetV2 [14] and InceptionV3 [15], has been performed with respect to confusion matrix parameters. These models have also been used for determining the destruction inflicted on buildings by Hurricane Harvey.

The objectives of this study include the following:


The rest of the paper is organized as follows: proposed methodology in Section 2; results and discussion in Section 3; conclusion and future scope in Section 4.

#### **2. Proposed Methodology**

The model that has been presented for automatic damage detection due to hurricanes is shown in Figure 1. The platform used to create and run the algorithm is Kaggle. The model classifies satellite images into damaged and undamaged categories. The methodology comprises two main steps: The first is preprocessing [16,17], which is further divided into normalization and data augmentation, and the second is classification using the pre-trained CNN models. Each stage has been described below.

**Figure 1.** Block diagram of the adopted methodology.

#### *2.1. Preprocessing*

The satellite images of the Houston region used in this study were captured by optical sensors. The images could be partially or fully covered with clouds, which implies that images obtained from satellites may be corrupted by noise. The nature of the noise is unknown; it could result from fluctuations in light, the camera sensor, or artifacts. Improving the quality of the images so that good results can be obtained is imperative. For this purpose, a denoising operation needs to be performed, which could be based on wavelets [18] or acquired from a compressive sensing method [19].

Pre-processing steps such as resizing were used to suppress unwanted distortions or enhance some features of the images. The original satellite images of the hurricane are 128 × 128 pixels. The images were resized to 224 × 224 when the VGG16, MobileNetV2 and DenseNet121 transfer learning techniques were applied, and to 299 × 299 for the InceptionV3 technique.

The two main steps of the preprocessing stage, which include normalization and data augmentation, have also been explained in this section.

#### 2.1.1. Normalization

Normalization is a very important step for maintaining numerical stability in a model. It helps the model learn faster and stabilizes gradient descent. The input image pixels have thus been normalized to values between 0 and 1 by multiplying each pixel value by 1/255.

#### 2.1.2. Data Augmentation

Data augmentation is a technique utilized to generalize the model by applying random transformations to the input images [20,21]. It increases the variability seen by the model and its robustness, as the model receives new, modified versions of the input data. An image data generator is utilized for augmenting the data; this is an on-the-fly data augmentation method because augmentation is performed at training time. The image data generator returns only the randomly modified images and not the original images. Data augmentation has been applied only to training images and not to testing images.

The techniques adopted for data augmentation in this study are rotation, width shifting, height shifting, horizontal flipping and zoom operation.
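A minimal Keras sketch of such a generator is shown below, combining the 1/255 rescaling from Section 2.1.1 with the listed transformations; the specific ranges are assumptions, as the text does not state them.

```python
# On-the-fly augmentation: the generator yields randomly transformed copies
# of the training images only; test images are just rescaled.
from tensorflow.keras.preprocessing.image import ImageDataGenerator

train_datagen = ImageDataGenerator(
    rescale=1.0 / 255,       # normalization to [0, 1]
    rotation_range=20,       # rotation (range assumed)
    width_shift_range=0.1,   # width shifting (range assumed)
    height_shift_range=0.1,  # height shifting (range assumed)
    horizontal_flip=True,    # horizontal flipping
    zoom_range=0.2,          # zoom operation (range assumed)
)
test_datagen = ImageDataGenerator(rescale=1.0 / 255)  # no augmentation at test time
```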

#### *2.2. Hurricane Damage Detection Using Pre-Trained CNN Models*

In this paper, four pre-trained models, which include VGG16 [22], MobileNetV2 [23], InceptionV3 [24] and DenseNet121 [25], have been used for classifying satellite images into damaged and undamaged classes.

Transfer learning models are models trained on very large datasets comprising millions of images. Because the models have been trained on such large datasets, they generalize well. The features learned from the larger datasets help in solving a different problem with less data or a smaller dataset, eliminating the need to train a model from scratch.

The description of the architectures of these models is shown in Table 1.


**Table 1.** Description of architecture of pre-trained CNN models.

The VGG16 model comprises 16 weight layers and approximately 138 million parameters: 13 convolutional layers and 3 fully connected layers. The VGG16 model is widely used because of its ease of implementation [22]. MobileNetV2 consists of 53 layers and 3.4 million parameters. It is derived from MobileNetV1, which uses depth-wise convolution as its building block; the additional feature over the previous model is an inverted residual layer [23]. This model is used because it is smaller in size and cost-effective, and it has nineteen residual bottleneck layers. InceptionV3 has 42 layers and 24 million parameters. It is an advanced version of InceptionV2 that reduces the amount of computation through factorization methods [24]; for InceptionV3, the input image size is (299, 299, 3). DenseNet121 consists of 121 layers with trainable weights and 8 million parameters. In this model, the network proceeds deeper as each layer is connected to all subsequent layers; for example, the first layer is connected to the second, third, fourth, and so on. This leads to improved information flow amongst the layers [25].

#### *2.3. Tuning the Hyper-Parameters*

The four models have been trained for 40 epochs with a batch size of 100. The number of epochs refers to how many times the learning algorithm works through the complete dataset, and the batch size refers to the number of training examples utilized in a single iteration [26]. A batch size of 100 implies that 100 samples from the training dataset are used to estimate the error gradient before the model weights are updated. The learning rate (LR) is another important hyperparameter that should be neither too big nor too small [27]; it controls the learning speed of the proposed models. If the LR is too small, the model takes much longer to reach the minimum loss; if it is too high, it can overshoot low-loss regions. A learning rate of 0.0001 has been chosen in this paper. The batch size, number of epochs [28] and learning rate were all decided empirically.

Furthermore, the activation function [29] used is the rectified linear unit (ReLU) [30]. The fully connected head, used with all four pre-trained models, is shown in Figure 2. The pre-trained block is followed by a flattening layer and two dense layers. The flattening layer size is 50,176 for the DenseNet121 model, 25,088 for VGG16, 62,720 for MobileNetV2, and 131,072 for InceptionV3. After the flattening layer, a dense layer of size 256 is applied. Finally, a dense layer with two classes, damaged and undamaged, is used.

**Figure 2.** Proposed new fully connected head.
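A hedged sketch of this head on top of a frozen VGG16 base, compiled with the hyperparameters from Section 2.3 (learning rate 0.0001, batch size 100, 40 epochs); the softmax output and loss choice are assumptions, since the text names only the layer sizes.

```python
# Pre-trained block -> Flatten -> Dense(256, ReLU) -> Dense(2), per Figure 2.
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.applications import VGG16

base = VGG16(weights="imagenet", include_top=False, input_shape=(224, 224, 3))
base.trainable = False

model = models.Sequential([
    base,
    layers.Flatten(),                       # 7 x 7 x 512 -> 25,088 for VGG16
    layers.Dense(256, activation="relu"),   # dense layer of size 256
    layers.Dense(2, activation="softmax"),  # damaged vs. undamaged (activation assumed)
])
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
    loss="categorical_crossentropy",        # loss choice assumed
    metrics=["accuracy"],
)
# model.fit(train_generator, epochs=40)  # batch size of 100 set in the generator
```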

The block diagram of the four pre-trained models, which include DenseNet121, VGG16, MobileNetV2 and InceptionV3, is displayed in Figure 3.

**Figure 3.** Block diagram of pre-trained models: (**a**) DenseNet121; (**b**) VGG16; (**c**) MobileNet; (**d**) InceptionV3.

The block diagram of DenseNet121 is shown in Figure 3a. The input is of size 224 × 224 × 3. This is applied to the DenseNet121 model, and the output obtained is of size 7 × 7 × 1024. This is then applied to the new fully connected head, which comprises the flattening and dense layers. The output of the flattening layer is of size 50,176. The output of the first dense layer is 256, and the last dense layer classifies the images into two classes, damaged and undamaged.

The block diagram of the VGG16 model is demonstrated in Figure 3b. An input size of 224 × 224 × 3 was applied to the model, and an output of 7 × 7 × 512 was obtained. The output of the flattening layer is 25,088, and the outputs of the dense layers are 256 and 2 in size.

The block diagram of MobileNetV2 has been demonstrated in Figure 3c, for which its input image size is 224 × 224 × 3. This is applied to the model and an output of 7 × 7 × 1280 is obtained. The output after the application of the flattening layer is 62,720, and the outputs of the two dense layers are 256 and 2, respectively.

Figure 3d presents the InceptionV3 model, for which its input image size is 299 × 299 × 3. The output when this input is applied to the model is 8 × 8 × 2048. The output obtained after the flattening layer is of size 131,072.

#### **3. Results and Discussion**

This section presents the results obtained with the four pre-trained models, DenseNet121, VGG16, MobileNetV2 and InceptionV3, considering various parameters, and then compares the results of all four models.

#### *3.1. Result Analysis in Terms of Loss and Accuracy*

The results of the four transfer learning models, which include DenseNet121, VGG16, MobileNetV2 and InceptionV3, in terms of training performance parameters are shown in Table 2. Training loss, training accuracy, training recall, validation loss, validation accuracy, and validation recall values for the models have been shown for various epochs in Table 2.


**Table 2.** Training Performance of the four pre-trained models.

As per Table 2, the highest training accuracy of 0.9727 and training recall of 0.9735 were obtained by the DenseNet121 model at the 40th epoch with a learning rate of 0.0001. The highest validation accuracy of 0.9670 and validation recall of 0.9658 were obtained by the InceptionV3 model at the 40th epoch. The lowest training loss of 0.0666 and validation loss of 0.0956 were obtained by the DenseNet121 model at the 40th epoch.

The training and validation accuracies for all the proposed models are demonstrated in Figure 4. The model accuracy for the DenseNet121 model has been displayed in Figure 4a. The model accuracy for the VGG16 model has been shown in Figure 4b; Figure 4c demonstrates the model's accuracy for the MobileNetV2 model, and Figure 4d demonstrates the model accuracy for the InceptionV3 model.

**Figure 4.** *Cont*.

**Figure 4.** Training and validation accuracy: (**a**) DensNet121; (**b**) VGG16; (**c**) MobileNetV2; (**d**) InceptionV3.

Figure 5 demonstrates the training and validation loss for all four models. Figure 5a shows the model loss for DenseNet121, Figure 5b shows the model loss for the VGG16 model, Figure 5c shows the model loss for the MobileNetV2 model and Figure 5d displays the model loss for InceptionV3.

**Figure 5.** Training and validation loss: (**a**) DensNet121; (**b**) VGG16; (**c**) MobileNetV2; (**d**) InceptionV3.

#### *3.2. Confusion Matrix Parameter Result Analysis*

Figure 6 shows the confusion matrix for all four models, which include DenseNet121, VGG16, MobileNetV2 and InceptionV3. The confusion matrix parameters are accuracy, precision, recall and F1-score [31,32], and they can be calculated with the help of the confusion matrix.

**Figure 6.** Confusion Matrix for different models: (**a**) DensNet121; (**b**) VGG16; (**c**) MobileNetV2; (**d**) InceptionV3.

The results of the classification report, i.e., the confusion matrix parameters of the four original and modified models in terms of precision, recall, F1-score and accuracy, are displayed in Table 3. From Table 3, it is concluded that the best precision of 0.92 is obtained by the modified DenseNet121 model, compared with the 0.82 precision obtained by the original DenseNet121 model. The best recall of 1.00 is obtained by the modified InceptionV3 model. The best F1-score of 0.83 and accuracy of 0.75 are obtained by the modified VGG16 model, whose recall improved to 0.95 compared with the 0.85 obtained by the original VGG16 model.


**Table 3.** Comparison of confusion matrix parameters of the original and modified pre-trained models.

Figure 7 compares the original and modified models in terms of the classification report parameters, which include precision, recall, F1-score and accuracy. The best accuracy of 0.75 is obtained by the VGG16 model followed by an accuracy of 0.65 obtained by the InceptionV3 model. The highest F1 score of 0.83 is obtained by the VGG16 model. The highest recall of 1.00 was obtained from the InceptionV3 model followed by a recall of 0.95 in the VGG16 model.

#### *3.3. Comparison of Results of Various Optimizers*

Different optimizers, which include SGD [33], Adadelta [34], Adam [35] and RMSprop [36], have been compared for the two best performing models, which include VGG16 and InceptionV3, as shown in Tables 4 and 5 and Figures 8 and 9.

**Table 4.** Comparison of the original and modified VGG16 model for the various optimizers.


**Table 5.** Comparison of the original and modified InceptionV3 model for the various optimizers.


**Figure 8.** Comparison of original and modified VGG16 model.

**Figure 9.** Comparison of original and modified InceptionV3 model.

#### 3.3.1. Comparison of Original and Modified VGG16 Model

Table 4 shows the comparison of the original and modified VGG16 model for the four optimizers, which include SGD, Adadelta, Adam and RMSprop. The highest precision of 0.98 is obtained for the modified VGG16 model for the SGD optimizer. The highest recall of 0.95 is obtained by the modified model for the Adam optimizer. An improved F1-score of 0.83 is obtained for the Adam optimizer. The highest accuracy of 0.78 is obtained by the RMSprop optimizer.

Figure 8 displays the comparison of the original and modified VGG16 model. There is an improvement in accuracy from 0.77 for the original model to 0.78 for the modified model. An improved F1-score of 0.83 is obtained by the modified VGG16 model. The highest recall of 0.95 and precision of 0.98 were obtained by the modified model.

#### 3.3.2. Comparison of Original and Modified InceptionV3 Model

Table 5 presents the comparison of the original and modified InceptionV3 model for four optimizers. It can be inferred from the table that, in the case of the InceptionV3 model, all optimizers produced almost the same results in terms of precision, F1-score and accuracy. For the modified model, an equal precision of 0.65, F1-score of 0.79 and accuracy of 0.65 were obtained for all four optimizers.

Figure 9 compares the results of the original and the modified InceptionV3 model, and it was observed that almost the same results were obtained for the model for all four optimizers.

#### *3.4. Classification and Misclassification Results*

Figure 10 shows the classification and misclassification results for the VGG16 model. Figure 10a shows that the actual class is undamaged and the predicted class is also undamaged. Figure 10b shows that the actual class is damaged and the predicted class is also damaged.

**Figure 10.** Classification and misclassification results. (**a**) True: undamaged predicted: undamaged; (**b**) True: damaged predicted: damaged; (**c**) True: undamaged predicted: damaged; (**d**) True: damaged predicted: undamaged.

Figure 10c,d display the misclassification results. For Figure 10c, the actual class is undamaged and the predicted class is damaged. For Figure 10d, the actual class is damaged and the predicted class is undamaged.

#### *3.5. Comparison with Present State-of-Art Deep Learning Models*

Table 6 presents the comparison of the best transfer learning model obtained in this paper (VGG16) with state-of-the-art deep learning models. The study was performed on 23,000 satellite images of the hurricane. VGG16 obtained an accuracy of 78%, which is higher than the other deep learning models presented in Table 6. In reference [37], work was performed on 1128 hurricane images using a VGG16 model, and an accuracy of 64.61% was obtained. A stacked CNN model was used in reference [11], and an accuracy of 61% was obtained. Work on 61,000 hurricane images was performed in reference [38] using a VGG16 model, and an accuracy of 74% was achieved. An accuracy of 77.85% was obtained by the CNN model comprising five convolutional layers in reference [6], which used 48,828 hurricane images.

**Table 6.** Comparison with present state-of-the-art deep learning models.



#### *3.6. Comparison with Present State-of-Art Machine Learning Models*

This section compares the best transfer learning model obtained in this paper (VGG16) with commonly used machine learning algorithms. VGG16 achieved its best accuracy of 78% with the RMSprop optimizer on 23,000 satellite images. Hurricanes are often accompanied by floods, and most authors have applied machine learning algorithms to floods. Naive Bayes achieved an accuracy of 78.51% and Support Vector Machine achieved an accuracy of 91% when applied to 7500 images [39]. Random forest attained an accuracy of 82% when applied to 201 images [40] and an accuracy of 92% on 255 flood images [41]. The machine learning results appear better because the analysis was conducted on far fewer images, whereas the transfer learning models proposed in this paper worked on a much greater number of images, namely 23,000.

#### **4. Conclusions and Future Scope**

In this paper, four pre-trained models based on transfer learning, DenseNet121, VGG16, MobileNetV2 and InceptionV3, have been put forward for the detection of destruction inflicted on buildings by Hurricane Harvey, which took place in the Greater Houston region in 2017. The four models have been compared based on training accuracy, training recall, training loss, validation accuracy, validation recall and validation loss. The highest training accuracy of 0.9727 and training recall of 0.9735 were obtained by the DenseNet121 model at the 40th epoch with a learning rate of 0.0001. The highest validation accuracy of 0.9670 and validation recall of 0.9658 were obtained by the InceptionV3 model at the 40th epoch. The lowest training loss of 0.0666 and validation loss of 0.0956 were obtained by the DenseNet121 model at the 40th epoch.

A comparison was also performed in terms of the classification report's parameters, and it was found that VGG16 outperformed other models by obtaining an accuracy of 0.75, an F1 score of 0.83 and a recall of 0.95.

When the comparison was performed for the best-performing models for various optimizers in terms of the classification report parameters, it was found that VGG16 performed better by obtaining an accuracy of 0.78 for the RMSprop optimizer.

Furthermore, the values of the confusion matrix parameters could be improved in future work. Moreover, the model could be made more generalizable by including images of other hurricanes.

**Author Contributions:** Conceptualization and methodology, S.K., S.G. and S.S.; formal analysis, V.T.H., S.A., T.A. and A.S.; software, validation and writing—original draft, S.K., S.G. and S.S.; writing—review and editing and data curation, V.T.H., S.A., T.A. and A.S.; supervision and funding acquisition, S.S., A.S. and T.A. All authors have read and agreed to the published version of the manuscript.

**Funding:** The authors are thankful to the Deanship of Scientific Research at Najran University for funding this work under the Research Collaboration Funding program grant code NU/RC/SERC/11/8.

**Data Availability Statement:** https://www.kaggle.com/datasets/kmader/satellite-images-of-hurricane-damage.

**Conflicts of Interest:** The authors declare no conflict of interest.

#### **References**


### *Article* **Machine Learning Algorithms for Depression: Diagnosis, Insights, and Research Directions**

**Shumaila Aleem 1, Noor ul Huda 1, Rashid Amin 1,\*, Samina Khalid 2, Sultan S. Alshamrani <sup>3</sup> and Abdullah Alshehri <sup>4</sup>**


**Abstract:** Over the years, stress, anxiety, and modern-day fast-paced lifestyles have had immense psychological effects on people's minds worldwide. Global technological development in healthcare digitizes copious data, enabling the various forms of human biology to be mapped more accurately than with traditional measuring techniques. Machine learning (ML) has been accredited as an efficient approach for analyzing the massive amount of data in the healthcare domain. ML methodologies are being utilized in mental health to predict the probabilities of mental disorders and, therefore, inform potential treatment outcomes. This review paper enlists different machine learning algorithms used to detect and diagnose depression. The ML-based depression detection algorithms are categorized into three classes: classification, deep learning, and ensemble. A general model for depression diagnosis involving data extraction, pre-processing, training an ML classifier, detection classification, and performance evaluation is presented. Moreover, the paper presents an overview of the objectives and limitations of different research studies in the domain of depression detection, and it discusses future research possibilities in the field of depression diagnosis.

**Keywords:** depression; machine learning (ML); deep learning (DL); regression

### **1. Introduction**

The modern-age lifestyle has a psychological impact on people's minds that causes emotional distress and depression [1]. Depression is a prevailing mental disturbance affecting an individual's thinking and mental development. According to the WHO, approximately 1 billion people have mental disorders [2] and over 300 million people suffer from depression worldwide [3]. Depression can instill suicidal thoughts in an individual, and around 800,000 people commit suicide annually. Therefore, a comprehensive response is required to deal with the burden of mental health issues [4,5]. Depression may harm the socioeconomic status of an individual, and people suffering from depression are more reluctant to socialize. Counseling and psychological therapies can help fight depression. Machine learning (ML) aims at creating algorithms that can train themselves to perceive complex patterns; this ability helps to find solutions to new problems by using previous data and solutions. ML algorithms implement processes with regulated and standardized outcomes [6,7]. Broadly, ML algorithms are categorized into supervised learning, unsupervised learning, semi-supervised learning, and reinforcement learning algorithms. Supervised ML algorithms [8] utilize main inputs to predict known values, whereas unsupervised ML algorithms [9] divulge unidentified patterns and clusters within the given data.


Semi-supervised learning [10] combines both labeled and unlabeled data and lies between supervised and unsupervised learning. Reinforcement learning [11] is concerned with interpreting the environment to undertake desired actions, exhibiting outcomes through trial and error.

The applications of ML techniques in healthcare have proven to be pragmatic, as they can process a huge amount of heterogeneous data and provide efficient clinical insights. ML-based approaches provide an efficient understanding of mental conditions and assist mental health specialists in predictive decision making [12]. ML techniques benefit prediction and diagnosis in the healthcare domain by generating information from unstructured medical data. The prediction outcomes help to identify high-risk medical conditions in patients for early treatment [13]. In mental disorders, ML techniques help arbitrate potential behavioral biomarkers [14] to assist healthcare specialists in predicting the contingencies of mental disorders and administering effective treatment outcomes. These techniques also aid the visualization and interpretation of complex healthcare data, which helps develop an effective hypothesis regarding the diagnosis of mental disorders. The traditional clinical diagnostic approach for depression does not accurately capture the complexity of depression. The composition of symptoms related to mental disorders such as depression can more easily be detected and anticipated by utilizing ML methods. Therefore, the ML-based diagnostic approach seems an efficient choice for predictive analysis.

In the healthcare sector, the major domains used for extracting observations associated with mental disorders through ML can be classified as sensors, text, structured data, and multimodal technology interactions [14]. Sensor data can be collected from mobile phones and audio signals. Text sources can be extracted through social media platforms, text messages, and clinical records. Structured data constitute the data extracted from standard screening scales, questionnaires, and medical health records. Multimodal technology interactions include data from human interactions with everyday technological equipment, robots, and virtual agents.

ML approaches can assist in diagnosing mental health conditions. The majority of studies analyze Twitter data [15–17] and sensor data from mobile devices [18,19] to identify mood disorders. Analyzing textual data can help extract diagnostic information from an individual's psychiatric records [20]. ML approaches can also help predict risk factors in patients with mental disorders: the analysis of sensor data [20], clinical health records [21,22], and text message data [23] can help predict the severity of mental disorders and suicidal behaviors. Various studies have been put forward to aid medical specialists in identifying depression and multiple other mental disorders. Although the domain of mental disorders comprises a diverse range of mental illnesses, this review focuses on the methods presented for the detection of depression. It elaborates the ML approaches and algorithms used to diagnose and detect depression in individuals and briefly presents the objectives and limitations of the reviewed studies, which will help analyze and recognize the best ML approach for a depression diagnosis.
The analysis presented in this review paper can help medical specialists and clinicians choose a suitable diagnosis approach for patients with depression. This review paper presents the following:

- Significant studies that extract mental health-related insights.
- A general model for depression diagnosis involving data extraction, pre-processing, training an ML classifier, detection classification, and performance evaluation.
- An overview of different ML algorithms used to diagnose depression, categorizing these depression detection algorithms into three classes, i.e., classification, deep learning, and ensemble.
- A discussion of the limitations of the reviewed studies in the depression diagnosis domain, giving clinicians and healthcare professionals a better understanding of the choice of ML approach for depression diagnosis.
- Future research possibilities in the domain of depression diagnosis.

The organization of the remaining sections of this paper is as follows: Section 2 consists of a brief description of past studies. The methodology for depression diagnosis is explained in Section 3. Section 4 describes the depression detection model. Section 5 explains future directions in the domain of depression diagnosis. Section 6 concludes this review.

#### **2. Related Work**

Over the years, there have been numerous studies on the use of ML to amplify the scrutiny of mental disorders. In [24], the authors present a history of depression, imaging, and ML approaches, and review studies that have used imaging and ML to study depression. The algorithms under review are SVM (linear kernel), SVM (nonlinear kernel), and relevance vector regression. Only one mental health domain (MHD) is analyzed in this survey; the study did not mention depression screening scales, and there is no comprehensive comparison of algorithms. Garcia et al. [25] surveyed mental health monitoring systems (MHMS) using ML and sensor data for mental disorders. This study analyzed supervised, unsupervised, semi-supervised, transfer, and reinforcement learning applied in the domains of mental well-being, including depression, anxiety, bipolar disorder (BD), migraine, and stress. However, the study only presents a brief review of cases of MHMS and applications. Gao et al. [26] compared ML-based brain imaging classification and prediction studies for diagnosing major depressive disorder (MDD) and BD using MRI data. SVM, LDA, GPC, DT, RVM, NN, and LR algorithms are reviewed in that study. However, the depression screening scales used in the different studies are not mentioned, and it only focuses on MDD- and BD-based research. Gyeongcheol et al. [27] analyzed five ML algorithms, SVM, Gradient Boosting Machine (GBM), RF, Naïve Bayes, and KNN, applied in the domains of mental disorders, including PTSD, schizophrenia, depression, ASD, and BD studies. This study reviewed a limited number of ML algorithms and did not specify the advantages of using a particular ML approach.

In [28], the authors analyzed Facebook data to detect depression-relevant factors. The Facebook users' data were analyzed using LIWC. Four supervised ML approaches were applied to the acquired data: DT, KNN, SVM, and an ensemble model. Experimental results indicated that DT yielded the best classification accuracy. Liu et al. [29] presented a brief review of generic AI-based applications for mental disabilities and an illustration of AI-based exploration of biomarkers for psychiatric disorders. The study [30] reviewed three major approaches to brain analysis for psychiatric disorders, magnetic resonance imaging (MRI), electroencephalography (EEG), and kinesics diagnosis, along with five AI methods: the Bayesian model, LR, DT, SVM, and DL. In [31], the authors used a DL methodology to extract a representation of depression cues in audio and video to detect depression.

That review introduced the databases and described objective markers for automatic depression estimation (ADE) to organize and summarize the work in the area. Furthermore, the authors reviewed the DL methods (DCNN, RNN, and LSTM) used for automatic depression detection to extract the representation of depression from audio and video. Finally, they discussed challenges and promising directions related to the automatic diagnosis of depression using DL approaches. Table 1 illustrates the overview of different studies.


#### **Table 1.** Overview of different studies.

#### **3. Methodology for Depression Diagnosis**

The detection methodology involves a series of processes, including the data extraction, the pre-processing of the extracted data, feature extraction methods for selecting the required set of features for identifying symptoms of depression, and ML classifiers for classifying the input data into defined data categories. This section discusses each of these steps and the different methods and approaches used for implementing each step.

#### *3.1. Pre-Processing Algorithms*


(4) Hidden Markov Model (HMM): HMM is a probabilistic model used to capture and describe information from observable sequential symbols. In HMM, the observed data are modeled as a series of outputs generated by several internal states [34].
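A minimal sketch of fitting such a model, using the hmmlearn library as one possible implementation; the dummy data and the number of hidden states are illustrative.

```python
# Fit a Gaussian HMM: observations are modeled as outputs of hidden states.
import numpy as np
from hmmlearn.hmm import GaussianHMM

X = np.random.randn(200, 4)        # 200 observations with 4 features (dummy data)
model = GaussianHMM(n_components=3, covariance_type="diag", n_iter=100)
model.fit(X)                       # learn transition and emission parameters
states = model.predict(X)          # most likely internal state per observation
```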

#### *3.2. Feature Extraction Methods*

Feature selection is a technique in which the features that are the most accurate predictors of the target variable are selected.
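As a small illustration (not from any reviewed study), univariate selection with scikit-learn's SelectKBest scores each feature against the target and keeps the top k:

```python
# Keep the 10 features that best predict the target, scored by ANOVA F-value.
from sklearn.datasets import make_classification
from sklearn.feature_selection import SelectKBest, f_classif

X, y = make_classification(n_samples=300, n_features=30, random_state=0)
selector = SelectKBest(score_func=f_classif, k=10)
X_selected = selector.fit_transform(X, y)
print(X_selected.shape)  # (300, 10)
```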


#### *3.3. Supervised Learning Classifiers*

In supervised learning, a specific format is used for the training dataset: each instance is assigned a label. Datasets are labeled as pairs (x, y) ∈ X × Y, where x denotes a data point and y its label. The problem is a classification task if the output y belongs to a discrete domain; if the output belongs to a continuous domain, it is a regression task. Both tasks predict the value of the dependent attribute from the independent variables.
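A toy illustration of this distinction on synthetic data:

```python
# Discrete target y -> classification; continuous target y -> regression.
from sklearn.datasets import make_classification, make_regression
from sklearn.linear_model import LinearRegression, LogisticRegression

Xc, yc = make_classification(n_samples=100, n_features=5, random_state=0)
clf = LogisticRegression().fit(Xc, yc)   # y from a discrete domain

Xr, yr = make_regression(n_samples=100, n_features=5, random_state=0)
reg = LinearRegression().fit(Xr, yr)     # y from a continuous domain
```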

#### 3.3.1. Classification


#### 3.3.2. Regression

Regression is used to understand the connection between dependent and independent variables. It is generally used to make projections, for example, forecasting sales income for a given business. Linear regression and logistic regression are popular regression algorithms.


#### 3.3.3. Deep Learning

Deep learning is a type of ML that enables computers to learn from experience and understand the world in terms of a hierarchy of concepts. This hierarchy permits the computer to learn complicated concepts by building them out of simpler ones; a graph of these hierarchies would be many layers deep. In image processing and computer vision, with applications such as scene understanding, clinical image investigation, robotic perception, augmented reality, video surveillance, and image compression, image segmentation is a key idea. Because of the success of DL models in a wide scope of vision applications, a substantial amount of work has been aimed at creating image segmentation approaches utilizing DL models.

Neural Networks:

The neural network is a classifier that simulates the human brain and its neurons; neural networks (NNs), or artificial neural networks (ANNs), are based on a collection of processing units (e.g., nodes, neurons, or process layers). A processing unit receives signals from other neurons, combines and transforms them, and generates results.
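An illustrative NumPy sketch of one such processing layer, with made-up weights and inputs:

```python
# A layer of 3 units: combine incoming signals, transform, generate results.
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

rng = np.random.default_rng(0)
x = rng.normal(size=4)         # signals received from other neurons
W = rng.normal(size=(3, 4))    # connection weights of the 3-unit layer
b = np.zeros(3)                # biases
print(sigmoid(W @ x + b))      # combined, transformed outputs
```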


#### **4. Depression Detection Models**

Depression is a type of mental illness that places a serious burden on individuals, families, and society. According to the WHO, depression will be the most common mental illness by 2030 [44]. In severe cases, depression leads to suicide. Currently, there is no efficient clinical characterization of depression, which makes the diagnosing process restricted and biased. Diagnosing depression is complicated, depending not only on the educational background, cognitive ability, and honesty of the subject in describing the symptoms but also on the experience and motivation of the clinicians. Comprehensive information and thorough clinical training are needed to diagnose the severity of depression accurately [10]. Hence, in recent years, numerous automatic depression estimation (ADE) systems have been introduced to automatically estimate the severity scale of depression by using different ML algorithms. Figure 1 illustrates various ML algorithms for the diagnosis of depression.

#### *4.1. Classification Models*

This section highlights the classification supervised learning models used in several studies for diagnosing depression. A mobile application, Mood Assessment Capable Framework (Moodable), has been presented in [45] to interpret voice samples, data from smartphone and social media handles, and Patient Health Questionnaire (PHQ-9) data for assessment of an individual's mood, mental health, and inferring symptoms of depression by using ML classifiers SVM, KNN, and RF. The framework achieved 76.6% precision for depression assessment. The authors used six ML classifiers, KNN, Weighted Voting classifier, AdaBoost, Bagging, GB, and XGBoost, in [46], to predict depression. SelectKBest, mRMR, and Boruta feature selection techniques were used for feature extraction. For reducing imbalanced classes, SMOTE was applied. They used a dataset of 604 individuals, including the sociodemographic and psychosocial data and the Burns Depression Checklist (BDC) data, among which 65.73% depression prevalence was identified. The analysis indicated that the AdaBoost classifier achieved the highest classification accuracy of 92.56% when used with the SelectKBest algorithm.
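A hedged sketch of the kind of pipeline described for [46], chaining SMOTE, SelectKBest, and AdaBoost with scikit-learn and imbalanced-learn; the synthetic data and parameter values are illustrative, not those of the study.

```python
# SMOTE balances classes during training only; SelectKBest then filters
# features before the AdaBoost classifier is fit.
from imblearn.over_sampling import SMOTE
from imblearn.pipeline import Pipeline
from sklearn.datasets import make_classification
from sklearn.ensemble import AdaBoostClassifier
from sklearn.feature_selection import SelectKBest, f_classif
from sklearn.model_selection import cross_val_score

X, y = make_classification(n_samples=604, n_features=20,
                           weights=[0.34, 0.66], random_state=0)
pipe = Pipeline([
    ("smote", SMOTE(random_state=0)),
    ("select", SelectKBest(f_classif, k=10)),
    ("clf", AdaBoostClassifier(random_state=0)),
])
print(cross_val_score(pipe, X, y, cv=5).mean())
```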

An ML model using the RF algorithm has been implemented for the prognosis of depression among Korean adults in [47]. SMOTE was applied for class balancing between the two classes, depression and non-depression. CES-D-11 was used as the depression screening scale, and 10-fold cross-validation was utilized to tune the hyperparameters. A total of 6588 Korean citizens' data were included in the study; the AUROC value was 0.870, and an accuracy of 86.20% was achieved. However, in this study, biomarkers were not included in the dataset. The authors of [48] used three ML algorithms, KNN, RF, and SVM, to diagnose depression among Bangladeshi students. The study aimed at predicting depression at early stages using related features to avoid drastic incidents. The analysis, performed over 577 students' data, indicated that the Random Forest algorithm detected the symptoms of depression in the students with 75% accuracy and a 60% f-measure.

In [49], ensemble learning and DL approaches were applied to electroencephalography (EEG) features for detecting depression. Deep Forest (DF) and SVM classifiers were used for feature transformation, and image conversion with CNN was used for feature recognition from the EEG spatial information. The ensemble model with DF and SVM obtained 89.02% classification accuracy, and the DL approach achieved 84.75% accuracy. In [50], the ML algorithms DT, RF, Naïve Bayes, SVM, and KNN were used to predict stress, anxiety, and depression. The Depression, Anxiety, and Stress Scale questionnaire (DASS 21) was used to analyze 348 individuals' data. The analysis indicated that Naïve Bayes achieved the highest accuracy of 85.50% for predicting depression; based on F1 scores, the RF algorithm was more efficient in the case of imbalanced classes. In [51], the authors used sentiment and linguistic analysis with ML to discriminate between depressive and non-depressive social content. RF with the RELIEFF feature extractor, the LIWC text-analysis tool, the Hierarchical Hidden Markov Model (HMM), and the ANEW scale were used to analyze 4026 social media posts, achieving 90% accuracy for depressive post classification, 92% for depression degree classification, and 95% for depressive community classification. However, this study treats all depression categories as a single class. Sharma et al. [52] used the XGBoost algorithm to diagnose mental disorders in the given data. Different sampling techniques were applied to the dataset, which had imbalanced classes. The study achieved values above 0.90 for accuracy, precision, recall, and F1 score.

Generalized Anxiety Disorder (GAD) is difficult to perceive and distinguish from major depression (MD) in a clinical framework. In [53], a multi-model ML algorithm was presented to distinguish GAD from MD using structural MRI data and clinical and hormonal information. The MRI data provided additional information for the GAD classification. However, the sample size and accuracy needed to be increased, and the groups were unbalanced. Xiang et al. [54] used a multikernel SVM with a minimum spanning tree (MST) and the Kolmogorov–Smirnov test for feature selection. The proposed approach provided a conducive network analysis. A total of 38 MDD patients and 28 healthy controls were included in the dataset, and the presented approach achieved 97.54% accuracy. Table 2 presents a comparison of different classification models used for the diagnosis of depression.


**Table 2.** Comparison of different classification models for depression diagnosis.

Discussion of Classification Models

The multikernel SVM proposed in [54] with a high-order MST achieved the highest MDD classification accuracy among the reviewed studies, 97.54%. The multikernel SVM model captures dynamic changes in the functional association between brain fragments, and the integration of multiple kernels can enhance classification. Another model with efficient classification accuracy was presented in [46], which achieved 92.56% classification accuracy using AdaBoost with the SelectKBest feature selection method and SMOTE for balancing the classes. AdaBoost falls under the category of DT ensembles. Comparing the two studies [46,54]: in [46], no biomarker was included in the dataset, while in [54], the dataset used was limited and no depression screening scale was identified. Considering the studies [45,48–50,53,54], SVM has been the most used classifier for the detection of depression, as it works well on unstructured and high-dimensional data. SVM is also resistant to overfitting, and for data with an anonymous and irregular distribution, SVM can prove to be an efficient algorithm.

Random Forest (RF) is the second most used classifier in the reviewed studies [45,47,48,50,51], as it is a computationally efficient algorithm. In [51], the RF model achieved 90, 95, and 92% accuracy for classifying depressive posts, depressive communities, and depression degrees, respectively. RF enhances the classification accuracy of continuous data by reducing the overfitting of decision trees. As RF is based on ensemble learning, it can determine complex and straightforward functions more accurately. Figure 2 shows the comparison of classification models used for a depression diagnosis.

**Figure 2.** Comparison of classification models for depression diagnosis.

#### *4.2. Deep Learning Models*

This section highlights the deep learning models presented in multiple studies to detect depression. An artificial intelligence mental evaluation (AiME) framework [55] has been presented for detecting symptoms of depression using multimodal deep network-based human–computer interactive evaluation. The framework was applied to audio, video, and speech responses of 671 participants along with PHQ-9 data. The authors of [56] discuss multimodal stress detection using a fusion of machine learning algorithms; in the same work, a DL framework based on EEG data has been suggested for the automatic analysis of depression. The framework includes two DL models: a one-dimensional convolutional neural network (1DCNN) and a combination of the 1DCNN and an LSTM model. The dataset used in the study contained EEG data and quantitative information from 30 healthy and 33 MDD patients, with BDI-II and HADS used as the assessment scales. The framework achieved an overall classification accuracy of 98.32%. Erguzel, Sayar et al. [57] presented a hybridized methodology using PSO and ANN to distinguish between unipolar and bipolar depression based on EEG recordings. The presented ANN–PSO approach discriminated 31 bipolar and 58 unipolar subjects with 89.89% accuracy. SCID-I, the HDRS 17-item version, YMRS, DSM-IV, and HADS were used as the assessment scales. However, this study used limited datasets.
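A hedged Keras sketch of a combined 1DCNN–LSTM classifier of the kind described in [56]; the input shape, channel counts, and kernel sizes are assumptions, not the study's configuration.

```python
# Conv1D layers extract local EEG patterns; the LSTM models their temporal
# dynamics; the sigmoid output separates MDD from healthy controls.
from tensorflow.keras import layers, models

N_SAMPLES, N_CHANNELS = 1024, 19  # assumed EEG window length and electrode count

model = models.Sequential([
    layers.Conv1D(32, kernel_size=7, activation="relu",
                  input_shape=(N_SAMPLES, N_CHANNELS)),
    layers.MaxPooling1D(pool_size=2),
    layers.Conv1D(64, kernel_size=5, activation="relu"),
    layers.MaxPooling1D(pool_size=2),
    layers.LSTM(64),
    layers.Dense(1, activation="sigmoid"),
])
model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])
```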

Feng et al. [58] presented the X-A-BiLSTM model for diagnosing depression from social media data. The XGBoost component helped reduce imbalanced classes, and the Attention-BiLSTM neural network component enhanced the classification capacity. The RSDD dataset with approximately 9000 depressed users and 107,000 control users was used in the study. However, no standard screening scale for depression was used in their work. In [59], a novel approach was presented to optimize word embedding for classification. The proposed approach outperformed the previous state-of-the-art models on the RSDD dataset. The comparative evaluation was performed on some DL models for diagnosing depression from tweets on the user level. The experiments were performed on two publicly available datasets, CLPsych 2015 and Bell Let's Talk. Results showed that CNN-based models performed better than RNN-based models. However, the word embedding models did not perform efficiently with larger datasets.

Zogan et al. [59] presented the interpretable Multimodal Depression Detection with Hierarchical Attention Network (MDHAN) to detect depressed people on social media. User posts along with Twitter-based multimodal features were considered, and semantic sequence features were captured from individuals' profiles. MDHAN outperformed the other baseline methods, showing that combining DL with multimodal features can be effective; it achieved excellent performance, with an accuracy of 89.5%, while providing adequate evidence to explain its predictions. However, this study needs a standard dataset of Twitter users, because social media data may be vague and can distort the experimental outcome. In [60], deep convolutional neural networks (DCNNs) were designed to learn deep features from spectrograms and from raw voice waveforms. To improve depression recognition performance, the authors suggest using joint fine-tuning layers to merge the raw-waveform and spectrogram DCNNs.
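
A hedged sketch of the joint fine-tuning idea in [60]: two small DCNN branches, one over the raw waveform and one over a spectrogram, are merged by shared fully connected layers. All shapes and layer sizes are illustrative assumptions, not the published configuration.

```python
# Sketch of merging a raw-waveform DCNN and a spectrogram DCNN with
# joint fully connected ("fine-tuning") layers, in the spirit of [60].
import torch
import torch.nn as nn

class JointDCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.raw = nn.Sequential(          # branch over the raw waveform
            nn.Conv1d(1, 16, 64, stride=8), nn.ReLU(),
            nn.AdaptiveAvgPool1d(32), nn.Flatten())      # -> 16*32 = 512
        self.spec = nn.Sequential(         # branch over the spectrogram
            nn.Conv2d(1, 16, 3, padding=1), nn.ReLU(),
            nn.AdaptiveAvgPool2d((8, 4)), nn.Flatten())  # -> 16*8*4 = 512
        self.joint = nn.Sequential(        # joint layers merge both views
            nn.Linear(512 + 512, 128), nn.ReLU(), nn.Linear(128, 1))

    def forward(self, wave, spec):
        return self.joint(torch.cat([self.raw(wave), self.spec(spec)], dim=1))

# Dummy 1 s waveform at 16 kHz and a 128x64 spectrogram (assumptions).
score = JointDCNN()(torch.randn(2, 1, 16000), torch.randn(2, 1, 128, 64))
print(score.shape)  # torch.Size([2, 1]) -- e.g., a depression score
```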

He and Cao [60] used DCNNs to enhance depression classification. DCNNs with low-level descriptors (LLD) and MRELBP texture descriptors were applied to 100 training, 100 development, and 100 testing samples, with the AVEC2013 and AVEC2014 datasets combined; the results were an MAE of 8.1901 and an RMSE of 9.8874 on the combined dataset. In [61], the authors presented a model for diagnosing mild depression by processing EEG signals with a CNN. The model used four functional connectivity metrics (coherence, correlation, PLV, and PLI) and obtained a classification accuracy of 80.74%. Only functional connectivity matrices were used in that research; other metrics still need to be evaluated. Ahmed et al. [62] discussed early depression diagnosis by analyzing the posts of Reddit users with a DL-based hybrid model. A BiLSTM with GloVe, Word2Vec, and FastText embeddings, metadata features, and LIWC was applied to 486 training users and 401 testing users with 531,453 posts, and the Beck Depression Inventory (BDI) was used as the assessment scale. The proposed model obtained an F1 score, precision, and recall of 81%, 78%, and 86%, respectively. Table 3 presents a comparison of different deep learning models used for the diagnosis of depression.
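
As a concrete example of one of the functional connectivity metrics named in [61], the following short sketch computes the phase-locking value (PLV) between EEG channel pairs with NumPy and SciPy; the synthetic signals are an assumption.

```python
# Sketch: phase-locking value (PLV), one of the functional connectivity
# metrics used in [61]. PLV measures how consistently two signals'
# instantaneous phases stay aligned (1 = perfectly locked, 0 = random).
import numpy as np
from scipy.signal import hilbert

def plv_matrix(eeg):
    """eeg: (n_channels, n_samples) -> (n_channels, n_channels) PLV."""
    phase = np.angle(hilbert(eeg, axis=1))      # instantaneous phase
    n = eeg.shape[0]
    plv = np.ones((n, n))
    for i in range(n):
        for j in range(i + 1, n):
            diff = phase[i] - phase[j]
            plv[i, j] = plv[j, i] = np.abs(np.mean(np.exp(1j * diff)))
    return plv

# Two phase-coupled sinusoids plus one independent noise channel (toy data).
t = np.linspace(0, 2, 512)
eeg = np.vstack([np.sin(2 * np.pi * 10 * t),
                 np.sin(2 * np.pi * 10 * t + 0.8) + 0.1 * np.random.randn(512),
                 np.random.randn(512)])
print(np.round(plv_matrix(eeg), 2))
```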


**Table 3.** Comparison of deep learning models for depression diagnosis.

Discussion of Deep Learning Models

The studies reviewed in this section used various DL models with different feature extraction and word embedding techniques. The DL models presented in [56] showed efficient discrimination between depressed and healthy controls: the 1DCNN achieved the highest classification accuracy of 98.32%, and the combined 1DCNN–LSTM model achieved 95.97%. These DL models discriminate EEG signal patterns automatically.

In the majority of the studies [56,57,61], EEG data were utilized to diagnose the symptoms of depression in the participants. EEG patterns can help to indicate abnormalities in brain function and irregular emotional alterations; the EEG signals resemble waves with peaks and valleys, from which irregularities can be identified. In [56], a CNN variant was applied to EEG signals to diagnose unipolar depression. In [57], a hybrid model combining an ANN with the PSO algorithm was used to discriminate unipolar and bipolar disorders based on EEG recordings, achieving 89.89% accuracy. In [61], a CNN classification model for diagnosing mild depression by processing EEG signals achieved 80.74% accuracy using the coherence functional connectivity metric. It can be concluded that EEG-based diagnosis is an efficient and cost-effective method for understanding brain activity and the neural correlates of social anxiety. Figure 3 presents the comparison of DL models for depression.

**Figure 3.** Comparison of deep learning models for depression diagnosis.

#### *4.3. Ensemble Models*

This section briefly highlights the different ensemble models presented in the reviewed studies for the diagnosis of depression. In [64], ML and statistical models were used to predict clinical depression and MDD among individuals suffering from immune-mediated inflammatory disease (IMID) by identifying patient-reported outcome measures (PROMs); LR, NN, and RF algorithms were used to analyze a dataset of 637 IMID patients. In [65], long short-term memory (LSTM) and six ML models, including LR, logistic regression with lasso regularization, RF, gradient boosted decision trees (GBDT), SVM, and a deep neural network (DNN), were used. The LSTM was applied to predict the levels of different depression risk factors over the course of two years. The dataset contained records of 1538 elderly people in China, drawn from the Chinese Longitudinal Healthy Longevity Study (CLHLS). The results indicated that logistic regression with lasso regularization achieved a higher AUC value than the other ML algorithms.
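
For reference, logistic regression with lasso (L1) regularization, the model that achieved the highest AUC in [65], can be sketched in scikit-learn as follows; the synthetic data and the regularization strength `C` are illustrative assumptions.

```python
# Sketch: logistic regression with lasso (L1) regularization, evaluated
# by AUC as in [65]. Data and regularization strength are illustrative.
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import train_test_split

X, y = make_classification(n_samples=1500, n_features=50, n_informative=8,
                           random_state=0)
X_tr, X_te, y_tr, y_te = train_test_split(X, y, random_state=0)

lasso_lr = LogisticRegression(penalty="l1", solver="liblinear", C=0.1)
lasso_lr.fit(X_tr, y_tr)

# L1 drives uninformative coefficients to exactly zero (feature selection).
print("nonzero coefficients:", (lasso_lr.coef_ != 0).sum())
print("AUC:", roc_auc_score(y_te, lasso_lr.predict_proba(X_te)[:, 1]))
```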

Tao, Chi et al. [66] proposed an ensemble binary classifier to analyze health survey data against ground truth from the SF-20 Quality of Life scales. With the ensemble model (DT, ANN, KNN, SVM) applied to the NHANES dataset, the classifier demonstrated an F1 score of 0.976 in prediction, without any incorrectly identified depression instances. This study has some limitations: it does not exploit rich online social media sources for feature extraction, and the dataset range is not defined. Karoly and Ruehlman [67] proposed an algorithm to distinguish between MDD and BD patients based on clinical variables. LR with Elastic Net and XGBoost were applied to 103 MDD and 52 BD patients, and the LR with Elastic Net model achieved an accuracy of 78%. This paper has some limitations, such as the small and unbalanced sample, the lack of external validation, some misclassifications, and a limited range of evaluation features.
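
A minimal sketch of the kind of four-learner majority-vote ensemble described in [66], using scikit-learn's `VotingClassifier`; the hyperparameters and synthetic data are assumptions, not the NHANES configuration.

```python
# Sketch: majority-vote ensemble of DT, ANN (MLP), KNN, and SVM, the
# learner combination described in [66]. Hyperparameters and data are
# illustrative assumptions, not the NHANES configuration.
from sklearn.datasets import make_classification
from sklearn.ensemble import VotingClassifier
from sklearn.model_selection import cross_val_score
from sklearn.neighbors import KNeighborsClassifier
from sklearn.neural_network import MLPClassifier
from sklearn.svm import SVC
from sklearn.tree import DecisionTreeClassifier

X, y = make_classification(n_samples=600, n_features=20, random_state=0)

ensemble = VotingClassifier(
    estimators=[
        ("dt", DecisionTreeClassifier(max_depth=5, random_state=0)),
        ("ann", MLPClassifier(hidden_layer_sizes=(32,), max_iter=1000,
                              random_state=0)),
        ("knn", KNeighborsClassifier(n_neighbors=7)),
        ("svm", SVC(random_state=0)),
    ],
    voting="hard",  # each learner casts one vote; the majority wins
)
print("5-fold accuracy:", cross_val_score(ensemble, X, y, cv=5).mean())
```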

Zhao, Feng et al. [68] evaluated the depression status of Chinese recruits using ML algorithms. NN, SVM, and DT were applied to 1000 participants, achieving 86%, 86%, and 73% accuracy, respectively; BDI-II was used as the assessment scale. This study still needs to incorporate complex socio-demographic and career variables into the model. Ji et al. [69] diagnosed bipolar disorder in a Chinese population by developing a BDCC using ML algorithms. SVR, RF, LASSO, LR, and LDA were applied to data from 255 MDD patients, 360 BPD patients, and 228 healthy controls, and the experiments obtained an accuracy of 92% for both MDD and BPD detection. However, this model requires larger datasets, and its cross-sectional nature needs to be addressed. Table 4 presents a comparison of different ensemble models used for the diagnosis of depression.

**Table 4.** Comparison of different ensemble models for depression diagnosis.




Discussion of Ensemble Models

Among the reviewed studies, the ensemble model of [66] obtained the highest accuracy, 95.4%. In this study, the NHANES dataset was used for evaluation, and the model misclassified only 4% of cases. The ensemble model achieved an F1 measure, accuracy, and precision of 97%, 95%, and 95%, respectively, on the whole dataset, and it also showed that the ensemble method remains stable and resilient when identifying depression on a partial dataset. The method and experiments showed that combining a classification methodology with binary ground truth may provide better prediction results than baseline standards. The ensemble technique is a straightforward approach similar to the bagging and majority-voting ensemble methods. Using five machine learning algorithms and Chinese multicenter cohort data, the ensemble model described in [69] obtained the second-highest classification accuracy, 92%. The higher AUC obtained in this study, compared to other studies, supports the acceptance of the research and the validity of the Chinese version of the BDCC. In addition, the BDCC cuts the time it takes to gather clinical data roughly in half: the ADE takes more than 30 min to complete, while the BDCC takes 10–15 min. The present findings show that the BDCC is just as reliable as the previous form but much easier to deploy. Considering the studies [64,65,67,69], regression has been the most used ML technique for the detection of depression. Regression is simple to implement, and its output coefficients are easy to interpret. Regression is susceptible to overfitting, but this can be mitigated using dimensionality reduction techniques, regularization (L1 and L2), and cross-validation.

#### **5. Future Research Possibilities**

Based on the review of prior research in the preceding section, we propose some possible future research directions in this section.

(1) A larger data sample is required:

The majority of prior depression detection research utilized small sample sizes. A small sample may suffice for building a prediction model, but a bigger sample is important for constructing a more accurate model that works well across the population. Training a model on a large sample allows a greater diversity of depressed patients to be included, potentially leading to models with real therapeutic value. As more studies use bigger datasets, the methods will most likely evolve and report more mature validation metrics. The k-fold cross-validation technique, in particular, may be employed with higher k-values, so that each model is trained on a larger share of the data while every sample is still used for testing, increasing generalizability.
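
As a small illustration of the k-fold suggestion above, the following sketch (on synthetic data, with illustrative k values) estimates generalization at several values of k; each sample is held out exactly once per run, and larger k lets each model train on a larger share of the data.

```python
# Sketch: estimating generalization with k-fold cross-validation at
# several k values. Synthetic data and k values are illustrative.
from sklearn.datasets import make_classification
from sklearn.model_selection import KFold, cross_val_score
from sklearn.svm import SVC

X, y = make_classification(n_samples=300, n_features=25, random_state=0)

for k in (5, 10, 20):
    cv = KFold(n_splits=k, shuffle=True, random_state=0)
    scores = cross_val_score(SVC(), X, y, cv=cv)
    print(f"k={k:2d}: mean accuracy {scores.mean():.3f} "
          f"(+/- {scores.std():.3f})")
```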

(2) Learning method(s):

Various learning techniques give better outcomes in different situations; therefore, choosing the right one is crucial. Unlabeled data may sometimes help develop a prediction model when the sample is large but labeled data are scarce. The first step, therefore, is to determine whether the incoming data are labeled, unlabeled, or a combination of both; this determines whether an unsupervised, supervised, or semi-supervised learning technique should be employed. The second step depends on the objective that the learning method must address. The third step is to identify whether the problem is linear or nonlinear: linear models are helpful when the dataset is small, to prevent overfitting, whereas nonlinear models are valuable when the dataset is big. The last step is to narrow down the options: assess factors such as complexity, flexibility, computation time, and optimization ability, and then choose the best technique. If there are too many candidate methods, evaluate the performance of each on the provided data; if there are only a few, simply tune the default model to make it more appropriate for learning the given data.

(3) Clinical application:

In the long term, the aim of creating a predictive model is to find a method that improves accuracy. However, such a scenario is unlikely to arise in the next few years, since SVM and a few other supervised learning algorithms are presently trustworthy and seem likely to remain dominant in this area of research. Regardless, once a sufficiently strong method has been thoroughly validated through preliminary studies, demonstrating its efficacy and establishing whether it benefits patients, its progression to clinical trials will be critical. Future clinical trials should ensure that machine learning methods efficiently identify depressed individuals who are unlikely to respond to the treatment under investigation. Clinicians' use of this information can improve patient outcomes (for example, a reduced delay between diagnosis and remission).

(4) Collaboration of research groups:

With the significant progress being made across disciplines, collaboration with other fields is crucial for ADE. For affective computing, the relevant fields include psychology, physiology, computer science, ML, and others; researchers should therefore borrow each other's strengths to advance ADE. For audio-based ADE, deep models estimate the depression scale only from audio, and for video-based ADE, deep models capture patterns only from facial expressions. Notably, physiological signals also contain significant information closely related to depression estimation. Accordingly, researchers from different fields should work together to build multimodal DL approaches for clinical application.

(5) Availability of databases:

Because of the sensitivity of depression data, it is difficult to obtain diverse data for estimating the scale of depression; hence, data availability is a major issue. First, in contrast to the facial expression recognition task, database availability remains scarce to this day. The literature shows that the widely used depression databases are AVEC2013, AVEC2014, and DAIC-WOZ; notably, AVEC2014 is a subset of AVEC2013. Second, there is no fully multimodal (i.e., audio, video, text, and physiological signals) database from which to learn comprehensive depression representations for ADE; the existing databases contain only two or three modalities. Although the DAIC database comprises three modalities (audio, video, and text), the organizers have not released the original videos, which causes some inconvenience for ADE. Third, the limited size of the datasets constrains research in depression prediction, especially when using DL technologies; for instance, AVEC2013 contains only 50 samples each for the training, development, and test sets. Effective methods to augment the limited amount of annotated data are needed to address this bottleneck. Fourth, the criteria for data collection should be standardized: at present, different organizers adopt a range of conditions, equipment, and configurations to collect multimodal data.

#### **6. Conclusions**

ML approaches can be used to assist in diagnosing mental health conditions; PTSD, schizophrenia, depression, ASD, and bipolar disorders all lie in the domain of mental disorders. Social media data, clinical health records, and mobile device sensor data can be analyzed to identify mood disorders. In this paper, we surveyed state-of-the-art research on the diagnosis of depression using ML-based approaches. The purpose of this review is to provide information about the basic concepts of the ML algorithms frequently used in the mental health domain, specifically for depression, and their practical application. Among the reviewed studies, SVM has been the most used classifier for detecting depression, as it works well with unstructured, high-dimensional data, is resistant to overfitting, and can be an efficient algorithm for data with an unknown and irregular distribution. As anticipated, most of the SVM classifiers developed in the reviewed articles achieved a high accuracy of greater than 75%. Because data in the mental health area are scarce, SVM often outperforms other machine learning methods for diagnosis. We also discussed some of the research difficulties of MHMSs and potential advancements in mental health and depression. According to the research reviewed, machine-learning-based applications offer significant potential for progress in mental healthcare, including the prediction of outcomes and therapies for mental illnesses and depression.

**Author Contributions:** Data curation, S.A.; Formal analysis, N.u.H. and S.K.; Funding acquisition, S.S.A.; Methodology, S.A.; Resources, A.A.; Software, N.u.H.; Supervision, R.A.; Writing—review & editing, R.A. All authors have read and agreed to the published version of the manuscript.

**Funding:** Taif University Researchers Supporting Project number (TURSP-2020/215), Taif University, Taif, Saudi Arabia.

**Data Availability Statement:** The data supporting this study's findings are available from the corresponding author upon request.

**Conflicts of Interest:** The authors declare that there is no conflict of interest regarding the publication of this paper.

#### **References**

