**Applied Machine Learning**

Editor

**Grzegorz Dudek**

MDPI • Basel • Beijing • Wuhan • Barcelona • Belgrade • Manchester • Tokyo • Cluj • Tianjin

*Editor* Grzegorz Dudek Czestochowa University of Technology Poland

*Editorial Office* MDPI St. Alban-Anlage 66 4052 Basel, Switzerland

This is a reprint of articles from the Special Issue published online in the open access journal *Applied Sciences* (ISSN 2076-3417) (available at: https://www.mdpi.com/journal/applsci/special_issues/Applied_Machine_Learning).

For citation purposes, cite each article independently as indicated on the article page online and as indicated below:

LastName, A.A.; LastName, B.B.; LastName, C.C. Article Title. *Journal Name* **Year**, *Volume Number*, Page Range.

**ISBN 978-3-0365-7906-1 (Hbk) ISBN 978-3-0365-7907-8 (PDF)**

© 2023 by the authors. Articles in this book are Open Access and distributed under the Creative Commons Attribution (CC BY) license, which allows users to download, copy and build upon published articles, as long as the author and publisher are properly credited, which ensures maximum dissemination and a wider impact of our publications.

The book as a whole is distributed by MDPI under the terms and conditions of the Creative Commons license CC BY-NC-ND.

## **Contents**





## **About the Editor**

### **Grzegorz Dudek**

Grzegorz Dudek received his PhD in electrical engineering from Czestochowa University of Technology (CUT), Poland, in 2003 and his habilitation in computer science from Lodz University of Technology, Poland, in 2013. Currently, he is an associate professor at the Department of Electrical Engineering, CUT. He is the author of four books on machine learning for forecasting and evolutionary algorithms for unit commitment, in addition to over 100 scientific papers. He came third in the Global Energy Forecasting Competition 2014 (price forecasting track). His research interests include machine learning and artificial intelligence, and their application to practical classification, regression, forecasting, and optimization problems.

## *Editorial* **Special Issue on Applied Machine Learning**

**Grzegorz Dudek**

Department of Electrical Engineering, Częstochowa University of Technology, 42-200 Częstochowa, Poland; grzegorz.dudek@pcz.pl

### **1. Introduction**

Machine learning (ML) is one of the most exciting fields of computing today. Over the past few decades, ML has become an entrenched part of everyday life and has been successfully used to solve practical problems. The application area of ML is very broad, including engineering, industry, business, finance, medicine, and many other domains. ML covers a wide range of learning algorithms, including classical ones such as linear regression, k-nearest neighbors, decision trees, support vector machines, and neural networks, as well as newly developed algorithms such as deep learning and boosted tree models. In practice, it is quite challenging to determine an appropriate architecture and parameters for an ML model so that the resulting learner achieves sound performance in both learning and generalization. Practical applications of ML bring additional challenges, such as dealing with big, missing, distorted, and uncertain data. In addition, interpretability is a paramount quality that ML methods should aim to achieve if they are to be applied in practice. Interpretability allows us to understand how an ML model operates and raises confidence in its results.

This book compiles 41 papers published in the Special Issue titled "Applied Machine Learning". The papers focus on applications of ML models in a diverse range of fields and problems. They report substantive results on a wide range of learning methods, discuss conceptualization of problems, data representation, feature engineering, ML models, critical comparisons with existing techniques, and interpretation of results.

### **2. Summary of the Contributions**

There were 116 papers submitted to this Special Issue, and 41 papers were accepted. Although each paper covers a different topic, the papers can be classified into six categories according to their main focus: computer vision, teaching and learning, social media, forecasting, basic problems of ML, and other topics.

### *2.1. Computer Vision*

Image processing and analysis are the basis of computer vision problems such as semantic segmentation; object classification, localization, and detection; optical character recognition; facial recognition; etc. An appropriate representation of image content is a crucial problem. To deal with this problem, a novel type of representation is proposed in [1], where an image is reduced to a set of highly sparse matrices representing detected keypoints. Every keypoint carries a value expressing the intensity of a feature extracted by a dedicated convolutional neural network (CNN) autoencoder. The new features have many advantages: they are not manually designed but learned, they are expected to minimize information loss, and they are relatively easy to interpret.
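As a rough, self-contained illustration of such a keypoint-based representation (not the authors' actual network; random matrices stand in for real encoder feature maps), each feature map can be reduced to a highly sparse matrix that keeps only its k strongest activations, each with coordinates and an intensity value:

```python
import numpy as np

rng = np.random.default_rng(5)
feature_maps = rng.random((4, 32, 32))  # 4 feature types on a 32x32 spatial grid

def to_sparse_keypoints(fmap, k=10):
    """Keep the k strongest responses of one feature map as a sparse matrix
    plus a list of (row, col, intensity) keypoints."""
    flat = np.argsort(fmap, axis=None)[-k:]
    rows, cols = np.unravel_index(flat, fmap.shape)
    sparse = np.zeros_like(fmap)
    sparse[rows, cols] = fmap[rows, cols]
    keypoints = list(zip(rows.tolist(), cols.tolist(), fmap[rows, cols].tolist()))
    return sparse, keypoints

sparse_maps = [to_sparse_keypoints(fm)[0] for fm in feature_maps]
print(sum(int((m != 0).sum()) for m in sparse_maps))  # total retained keypoints
```

The sparse matrices are equivalent to lists of typed, localized features, which is what makes the representation compact and interpretable.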

In [2], a fast self-adaptive digital camouflage method based on deep learning is proposed. It is designed for the new generation of adaptive optical camouflage, which can change with the environment in real time. The system is composed of a YOLOv3 model that identifies military targets, a pre-trained deepfillv1 model that designs the preliminary camouflage texture, and a k-means algorithm for standardization of the texture.

The experimental results show that the camouflage pattern designed by the proposed method is consistent with the background in texture and semantics, and has excellent camouflage performance.

**Citation:** Dudek, G. Special Issue on Applied Machine Learning. *Appl. Sci.* **2022**, *12*, 2039. https://doi.org/10.3390/app12042039

Received: 13 January 2022; Accepted: 13 February 2022; Published: 16 February 2022

The problem of classifying remote sensing images for disaster investigation, traffic control, and land-use resource management is considered in [3]. A new remote sensing scene classification network is proposed, and a two-stage cyclical learning strategy is developed to speed up model training and enhance accuracy. The t-distributed stochastic neighbor embedding (t-SNE) algorithm was used to verify the effectiveness of the proposed model, and the local interpretable model-agnostic explanations (LIME) algorithm was applied to interpret the results.
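The t-SNE verification step can be sketched as follows; the 64-dimensional features and the two-class structure are invented stand-ins for the activations of the trained scene classification network:

```python
import numpy as np
from sklearn.manifold import TSNE

rng = np.random.default_rng(0)
# Hypothetical deep features for two scene classes; in the paper these would
# be activations extracted from the trained network.
features = np.vstack([rng.normal(0.0, 1.0, (50, 64)),
                      rng.normal(3.0, 1.0, (50, 64))])

# Project the 64-D features to 2-D so class separation can be inspected visually.
embedding = TSNE(n_components=2, perplexity=30, random_state=0).fit_transform(features)
print(embedding.shape)  # one 2-D point per image
```

Well-separated clusters in the 2-D embedding indicate that the learned features discriminate between scene classes.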

In [4], tracking pedestrian workers on construction sites is considered with the aim of improving efficiency and safety management. Vision-based tracking approaches, suitable in this case, require a large amount of data originating from construction sites. As these data are hardly available, the authors propose to use a small general dataset and combine a deep learning detector with an approach based on classical ML techniques. They use the YOLOv3 detector for identifying workers and compare its performance with an approach based on a soft cascaded classifier. They found that both approaches generally yield satisfactory tracking performance but feature different characteristics. To augment a self-recorded real-world dataset for learning the vision-based tracking system, in [5], virtual construction site scenarios are modeled using 3D computer graphics software. The detector's performance is examined when using synthetic data of various environmental conditions for training. The findings showed that a synthetic extension is beneficial for otherwise small datasets. It is also an alternative way to evaluate vision-based tracking systems on hazardous scenes without exposing workers to risks.

In [6], the problem of environment classification for unmanned aerial vehicles (UAVs) is addressed. Images obtained from video and photographic cameras mounted on a UAV are recognized to detect ground, sky, and clouds. The proposed recognition system includes a CNN trained with a dataset generated by both a human expert and a support vector machine (SVM) to capture context and precise localization.

#### *2.2. Teaching and Learning*

Student grade prediction is an important educational problem for designing personalized strategies of teaching and learning. To solve this problem, a graph-regularized robust matrix factorization, optimized by a majorization-minimization algorithm, is proposed in [7]. This method integrates two side graphs, built on the side data of students and courses, into the objective of robust low-rank matrix factorization. As a result, the learned features of students and courses can grasp more priors from educational situations to achieve better grade prediction results. This facilitates personalized teaching and learning in higher education.

For developing adaptive e-learning systems, it is very helpful to provide information on how students recognize, process, and store information. To improve students' learning evaluation, a method based on a deep multi-target prediction algorithm using the Felder-Silverman learning styles model is proposed in [8]. It uses feature selection, learning styles models, and multiple target classification to investigate the possibility of improving the accuracy of automatic learning styles identification. The obtained results show that learning styles allow adaptive e-learning systems to improve the learning processes of students.

Students' performance prediction in higher education is considered in [9]. The authors use transfer learning to exploit the knowledge retrieved from one problem to improve the predictive performance of a learning model on a different but related problem. The experimental results demonstrate that the prognosis of students at risk of failure can be achieved with satisfactory accuracy in most cases, provided that datasets of students who have attended other related courses are available.

The study reported in [10] aimed to identify the interrelationships among topics based on the understanding of various bodies of knowledge. It provides a foundation of topic compositions for constructing an academic body of knowledge of AI. To this end, ML-based sentence similarity measurement models used in machine translation, chatbots, and document summarization were applied to the body of knowledge of AI. Consequently, several similar topics related to agent design in AI were identified. The results of this study can be applied in the edutech field.

Predicting the academic standing of a student at graduation time can be very useful for institutions when selecting among candidates or when helping potentially weak students overcome educational challenges. In [11], this problem is solved using several ML algorithms based on different student data, including individual course grades and grade point averages. This approach can be applied to any dataset to determine when to use which college performance representation for enhanced prediction.

For predicting the grades of undergraduate students in final exams, multi-view learning is applied in [12] to exploit the knowledge retrieved from data represented by multiple feature subsets known as views. A semi-supervised regression algorithm is proposed which exploits three independent and naturally formed feature views derived from different sources. The experimental results demonstrate that the early prognosis of students at risk of failure can be accurately achieved and could highly benefit the educational domain.

### *2.3. Social Media*

One prominent dark side of online information behavior is the spreading of rumors on social media. Paper [13] analyses the association between user features and rumor-refuting behavior in different rumor categories. Natural language processing (NLP) techniques are applied to quantify a user's sentiment tendency and recent interests. The users' personalized features are then used to train an XGBoost classification model to identify potential refuters. The results revealed significant differences between rumor stiflers and refuters, as well as between refuters for different rumor categories.
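A hedged sketch of this pipeline, using scikit-learn's `GradientBoostingClassifier` as a stand-in for XGBoost; the user features (sentiment tendency, topical interest, follower count, posting frequency) and labels are invented for illustration, not the paper's data:

```python
import numpy as np
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.model_selection import train_test_split

rng = np.random.default_rng(1)
n = 400
# Hypothetical personalized user features: sentiment tendency, interest match
# with the rumor topic, follower count (log-scaled), posting frequency.
X = rng.normal(size=(n, 4))
# Assumed label: 1 = refuter, 0 = stifler; here refuting correlates with
# sentiment and topical interest purely for illustration.
y = ((0.8 * X[:, 0] + 1.2 * X[:, 1] + rng.normal(scale=0.5, size=n)) > 0).astype(int)

X_tr, X_te, y_tr, y_te = train_test_split(X, y, test_size=0.25, random_state=0)
clf = GradientBoostingClassifier(random_state=0).fit(X_tr, y_tr)
print(round(clf.score(X_te, y_te), 2))  # held-out accuracy of refuter detection
```

A trained model of this kind can score unseen users by their probability of refuting a given rumor category.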

The objective of [14] is to detect variables that allow organizations to manage their social network services efficiently. This study, applying ML algorithms and multiple linear regression, reveals which aspects of published content increase the recognition of publications through retweets and favorites. The findings provide new knowledge about trends and patterns of use in social media, giving academics and professionals the guidelines needed to efficiently manage these technologies in the organizational field.

Paper [15] concerns tourists' sentiments regarding travel destinations based on online travel review texts. The authors transformed sentiment analysis into a multi-classification problem based on ML methods and further designed a keyword semantic expansion method based on a knowledge graph. The method extracts keywords from online travel review texts and obtains a concept list of keywords through the knowledge graph. This list is then added to the review text to facilitate the construction of semantically expanded classification data. The results of sentiment analysis form an important basis for tourism decision making.

Micro-blogs, such as Twitter, have become important tools for sharing opinions and information among users. The authors of [16] ask how a user can discover influencers concerned with their interests. They propose a classification model trained on messages labeled with topical classes, so that the model is able to classify unlabeled messages. The model can thus be used to reveal the hidden topic that messages are about.

With the widespread use of over-the-top (OTT) media, such as YouTube and Netflix, network markets are changing and innovating rapidly, making it essential for network providers to quickly and efficiently analyze OTT traffic with respect to pricing plans and infrastructure investments. In [17], a time-aware deep learning method for analyzing OTT traffic to classify users for this purpose is presented. A novel framework is proposed that maintains accuracy while dramatically reducing classification time. The resultant approach provides a simple method for customizing pricing plans and load balancing by classifying OTT users more accurately.

Recommendation systems aim to decipher user interests, preferences, and behavioral patterns automatically. The credibility of recommendations is of great importance in crowdfunding project recommendation. Paper [18] devises a hybrid ML-based approach for credible crowdfunding project recommendation by wisely incorporating backers' sentiments and other influential features. The proposed model has four modules: a feature extraction module, a hybrid latent Dirichlet allocation (LDA) and LSTM-based latent topic evaluation module, a credibility formulation module, and a recommendation module. The model's evaluation shows that credibility assessment based on the hybrid ML approach gives more efficient results than existing recommendation models.

### *2.4. Forecasting*

Stock performance prediction is one of the most challenging issues in time series data analysis. Paper [19] proposes to build an automated trading system by integrating AI with proven methods invented by human stock traders. The knowledge and experience of successful stock traders are extracted from their related publications. An LSTM-based deep NN is then developed to use the human stock traders' knowledge in the automatic trading system. Experimental results indicate that the proposed ranking-based stock classification strategy, which considers historical volatility, outperforms conventional methods.

In [20], the authors study volatility forecasting in the Bitcoin market, which has become popular globally in recent years. To improve the forecasting accuracy of Bitcoin's volatility, they develop hybrid forecasting models combining GARCH family models with ML approaches, including NNs.

Paper [21] is about forecasting key performance indicators (KPIs), usually in the form of time series data, related to the COVID-19 pandemic. Making reliable predictions of these indicators, particularly for emergency departments, can facilitate acute unit planning, enhance the quality of care, and optimise resources. The authors compare KPI forecasting models including classical ARIMA, Prophet, and a general regression NN.

The development of intelligent transport systems has created conditions for solving the supply-demand imbalance of public transportation services. In [22], a method to forecast real-time online taxi-hailing demand is introduced. It is based on NNs and extreme gradient boosting. The proposed method can help to schedule online taxi-hailing resources in advance.

Climate change increases the frequency and intensity of heatwaves, causing significant human and material losses every year. Big data, whose volumes are rapidly increasing, are expected to be used for preemptive responses. In [23], a random forest model for the weekly prediction of heat-related damages was developed using statistical, meteorological, and floating population data. The results show that the proposed model outperforms existing ones.
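A minimal sketch of such a random forest predictor on synthetic weekly data; the features (maximum temperature, humidity, heatwave duration, floating population density) and the damage response are illustrative assumptions, not the paper's dataset:

```python
import numpy as np
from sklearn.ensemble import RandomForestRegressor

rng = np.random.default_rng(2)
n = 300
# Hypothetical weekly inputs: max temperature (C), humidity (%),
# heatwave duration (days), floating population density (normalized).
X = rng.uniform([25, 30, 0, 0.1], [40, 90, 7, 1.0], size=(n, 4))
# Assumed response: damage counts rise with temperature, duration, and
# exposed population (purely illustrative relationship).
y = 2.0 * (X[:, 0] - 25) + 3.0 * X[:, 2] + 10.0 * X[:, 3] + rng.normal(scale=2.0, size=n)

model = RandomForestRegressor(n_estimators=200, random_state=0).fit(X, y)
week = np.array([[38.0, 60.0, 5.0, 0.8]])  # a hot week with a long heatwave
print(round(float(model.predict(week)[0]), 1))  # predicted weekly damage count
```

The fitted forest also exposes `feature_importances_`, which is useful for checking which inputs drive the predicted damages.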

One of the hottest topics in today's meteorological research is weather nowcasting, i.e., the weather forecast for a short time period such as one to six hours. With the main goal of helping meteorologists analyze radar data when issuing nowcasting warnings, a regression model based on an ensemble of deep NNs for predicting the values of radar products is proposed in [24]. The model is intended as a proof of concept that patterns relevant for predicting future values of radar products can be learned from their historical values.

### *2.5. Basic Problems of ML*

Paper [25] deals with the problem of instance selection for classifiers. The main goal is to improve the performance of a classifier (its speed and accuracy) by eliminating redundant and noisy samples. The obtained results indicate that, for most classifiers, compressing the training set affects prediction performance, and only a small group of instance selection methods can be recommended as a general-purpose preprocessing step. These are the learning vector quantization-based algorithms, along with Drop2 and Drop3.
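A minimal condensed-nearest-neighbour-style sketch of instance selection (a classic baseline, not the LVQ or Drop2/Drop3 algorithms evaluated in the paper): an instance is retained only if the subset kept so far misclassifies it, so mostly boundary-region samples survive:

```python
import numpy as np
from sklearn.datasets import make_blobs
from sklearn.neighbors import KNeighborsClassifier

X, y = make_blobs(n_samples=300, centers=3, cluster_std=1.5, random_state=0)

keep = [0]  # start from a single prototype
for i in range(1, len(X)):
    knn = KNeighborsClassifier(n_neighbors=1).fit(X[keep], y[keep])
    if knn.predict(X[i:i + 1])[0] != y[i]:
        keep.append(i)  # misclassified: retain this instance

# Compare the reduced set's 1-NN accuracy on the full data.
knn_reduced = KNeighborsClassifier(n_neighbors=1).fit(X[keep], y[keep])
print(len(keep), len(X), round(knn_reduced.score(X, y), 2))
```

The trade-off the paper studies is visible even here: the retained set is much smaller, at some cost in accuracy relative to training on all instances.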

Support vector machines (SVMs) are well-known classifiers due to their superior classification performance. To decrease the complexity of large-scale SVMs, a novel data reduction method that reduces the training time by combining decision trees and relative support distance is proposed in [26]. The method selects good support vector candidates in each partition generated by the decision trees. The selected candidates reduced the training time while maintaining good classification performance in comparison with existing approaches.

Paper [27] deals with solving partial differential equations, a hot topic of mathematical research. The authors introduce an improved physics-informed neural network (PINN) for solving partial differential equations. The PINN takes the physical information contained in partial differential equations as a regularization term, which improves the performance of NNs. The experimental results show that the PINN is effective in solving partial differential equations and deserves further research.
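In general form, the physics-informed regularization can be written as a composite loss (the notation here is generic, not necessarily the paper's):

$$\mathcal{L}(\theta) = \frac{1}{N_d}\sum_{i=1}^{N_d}\bigl|u_\theta(x_i)-u_i\bigr|^2 \;+\; \lambda\,\frac{1}{N_c}\sum_{j=1}^{N_c}\bigl|\mathcal{N}[u_\theta](x_j)\bigr|^2,$$

where $u_\theta$ is the network approximation of the solution, $\mathcal{N}[u]=0$ is the PDE, the first term fits the data and boundary conditions at $N_d$ points, and the second term penalizes the PDE residual at $N_c$ collocation points, weighted by $\lambda$.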

Machine learning of automata and grammars has a wide range of applications in fields such as syntactic pattern recognition, computational biology, systems modeling, natural language acquisition, and knowledge discovery. In [28], an approach to non-deterministic finite automaton inductive synthesis based on answer set programming (ASP) solvers is proposed. It consists of preparing logical rules before the search process starts. The authors show how the proposed ASP solvers help to tackle the regular inference problem for large-size instances and compare their approach with existing ones. Experiments indicated that the proposed approach clearly outperforms the current state-of-the-art satisfiability-based method and all backtracking algorithms proposed in the literature.

Paper [29] sits in the scientific field known as grammatical inference (GI), automata learning, grammar identification, or grammar induction. The matter under consideration is the set of rules that lie behind a given sequence of words, and the main task is to discover the rules that can help to evaluate new, unseen words. The authors propose a new grammatical inference method and apply it to a real bioinformatics task, i.e., the classification of amyloidogenic sequences. In the experimental evaluation, they showed that the new grammatical inference algorithm gives the best results in comparison with other automata or grammar learning methods, as well as with an ML approach combining an unsupervised data-driven distributed representation and an SVM.

Paper [30] is about anticipatory classifier systems, i.e., classifier systems that learn by using a cognitive mechanism of anticipatory behavioral control, which was introduced in cognitive psychology. The authors note that learning classifier systems face many real-world sequential decision problems where the preferred objective is the maximization of the average of successive rewards. To address such problems, they propose a modification of the learning component: a new average reward criterion. In the experimental study, they showed that anticipatory classifier systems with an averaged reward criterion can be used successfully in multi-step environments.

#### *2.6. Other Topics*

A medical care application of ML is considered in [31]. This work addresses sleep apnea, a common sleep-related disorder that significantly affects the population and is characterized by repeated breathing interruptions during sleep. The authors propose a new probabilistic algorithm based on the oronasal respiration signal for the automated detection of apnea events during sleep. Unlike classical threshold-based classification models, they use a Gaussian mixture probability model for detecting sleep apnea based on the posterior probabilities of the respective events. The results show a significant improvement in the ability to detect sleep apnea events compared with a rule-based classifier that uses the same classification features, and also compared with previously published studies.
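The posterior-probability idea can be sketched with scikit-learn's `GaussianMixture` on synthetic per-epoch respiration features; the two features (airflow amplitude and variability) and the event structure are assumptions for illustration, not the paper's signal processing:

```python
import numpy as np
from sklearn.mixture import GaussianMixture

rng = np.random.default_rng(3)
# Hypothetical per-epoch features of the oronasal respiration signal:
# (airflow amplitude, variability); apnea epochs show reduced amplitude.
normal = rng.normal([1.0, 0.3], 0.1, size=(300, 2))
apnea = rng.normal([0.2, 0.1], 0.05, size=(60, 2))
X = np.vstack([normal, apnea])

# Fit a two-component mixture; epochs are then scored by posterior
# probability rather than by a hard amplitude threshold.
gmm = GaussianMixture(n_components=2, random_state=0).fit(X)
posteriors = gmm.predict_proba(X)

# The component with the lower mean amplitude is the apnea-like one.
apnea_comp = int(np.argmin(gmm.means_[:, 0]))
flagged = posteriors[:, apnea_comp] > 0.5
print(int(flagged[300:].sum()), int(flagged[:300].sum()))  # flagged apnea vs normal epochs
```

Working with posteriors instead of a fixed threshold lets the detection confidence be reported per event, which is what distinguishes this approach from rule-based classifiers.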

Paper [32] deals with a discrete optimization problem concerning product placement and order picking routes in a warehouse. The authors propose a genetic algorithm that minimizes the sum of the order picking times. The product placement is optimized by another genetic algorithm. To improve and accelerate the optimization process, several ideas are proposed, such as a multi-parent crossover, a caching procedure, multiple restarts, and order grouping. The proposed solution significantly decreases the total order picking time.

ML techniques have been actively applied in the meteorology and climatology fields in recent years. They are used for forecasting over different horizons, modelling climatic data, and quality control and correction of observed weather data. Paper [33] deals with the topic of climate change. It presents a framework for selecting general circulation models (GCMs) in homogeneous climatic zones and detecting future climate change trends. With the support of ML techniques, long records of climate data from numerous gauging sites and web sources were analyzed and used to determine historical and projected trends of climate change. In [34], three ML models were applied to detect weather phenomena such as precipitation and fog from backscatter data obtained from a lidar ceilometer: random forest, SVM, and NN. The prediction results showed potential for precipitation detection, but fog detection was found to be very difficult.

The emission of carbon dioxide caused by various sectors, including construction and industrial processes, has emerged as a severe problem that dramatically affects global climate change. Portland cement production accounts for a large part of global anthropogenic CO2 emissions. Fly ash-based geopolymer concrete (FAGP) offers a favourable alternative to conventional Portland concrete due to its reduced embodied carbon dioxide content. In [35], ML methods, including an artificial NN, a deep NN, and ResNet, were employed to predict the mechanical properties of FAGP concrete. The obtained results indicate that the proposed approaches offer reliable methods for FAGP design and optimisation.

Paper [36] deals with vibration tests in space structure testing. During physical tests, the structure must not be overtested to avoid any risk of damage. To solve the issues associated with existing methods of live monitoring of the structure's response, the authors investigated the use of artificial NNs to predict the system's responses during the test. The conducted research constitutes a novel method for the live prediction of stresses, allowing failure to be evaluated for different types of material via yield criteria.

Software vulnerabilities are one of the main causes of cybersecurity problems, resulting in huge losses. Existing solutions for automated vulnerability detection are mostly based on features defined by human experts, which can lead to missed potential vulnerabilities. In [37], deep learning is proposed as an effective method for automating the extraction of vulnerability characteristics. By comparing classification results, word2vec continuous bag-of-words, multiple structural CNNs, and stacking classifiers were found to be the best combination for automated vulnerability detection.

Paper [38] is on information privacy, which is a critical design feature for any exchange system, with privacy-preserving applications requiring the identification and labelling of sensitive information. The authors propose a predictive context-aware model based on a bidirectional long short-term memory network with conditional random fields (BiLSTM + CRF) to identify and label sensitive information in conversational data. The results demonstrate that the BiLSTM + CRF model architecture with BERT embeddings and WordShape features is very effective and outperforms competitive solutions.

Natural language processing has an enormous range of applications, including sentiment analysis, machine translation, and text classification and extraction. In [39], the problem of developing a deep learning-based language model that helps software engineers write code faster is considered. This research proposes a hybrid approach that harnesses the synergy between ML techniques and advanced design methods, aiming to develop a code auto-completion framework that helps firmware developers write code more efficiently. The proposed framework can save numerous hours of productivity by eliminating tedious parts of writing code.

In [40], the problem of predicting the movement of a drifter on the ocean is considered. The authors estimated drifter tracking over seawater using ML and evolutionary search techniques, including differential evolution, particle swarm optimization, multi-layer perceptrons, SVMs, deep NNs, LSTMs, and others. Extensive comparative research allows the suitability of various ML algorithms for solving this type of problem to be evaluated.
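As an illustration of one of the listed techniques, the sketch below uses SciPy's differential evolution to calibrate a toy two-parameter leeway drift model against a synthetic observed displacement; the model form, parameters, and forcing vectors are invented for the example, not taken from the paper:

```python
import numpy as np
from scipy.optimize import differential_evolution

true_params = np.array([0.8, 0.03])   # hidden (current gain, wind gain)
current = np.array([0.5, -0.2])       # assumed ocean current (m/s)
wind = np.array([4.0, 1.0])           # assumed wind vector (m/s)

def displacement(params):
    a, b = params
    return a * current + b * wind     # simple linear leeway drift model

observed = displacement(true_params)  # synthetic "observed" drifter displacement

def loss(params):
    return float(np.sum((displacement(params) - observed) ** 2))

# Global search over the parameter bounds; no gradients are needed.
result = differential_evolution(loss, bounds=[(0, 2), (0, 0.1)], seed=0)
print(np.round(result.x, 3))  # recovered (current gain, wind gain)
```

Because the loss is evaluated as a black box, the same setup extends directly to non-differentiable drift models, which is where evolutionary search is attractive.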

Salesperson performance measurement is a process that occurs multiple times per year in a company. During this process, the salesperson is evaluated on how he or she performed on numerous KPIs. In [41], several data mining techniques are proposed to allow managers to make better decisions about salesperson performance measurement based on metrics defined by the business. The authors applied a naive Bayes model to classify salespeople into pre-defined categories provided by the business. They showed that the proposed approach can be applied in many companies using different KPIs.
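A minimal sketch of the naive Bayes step on synthetic KPI data; the feature names (quota attainment, new accounts, average deal size) and the three performance categories are assumptions for illustration:

```python
import numpy as np
from sklearn.naive_bayes import GaussianNB

rng = np.random.default_rng(4)
# Hypothetical normalized KPIs per salesperson: quota attainment,
# new accounts opened, average deal size. Labels are business-defined
# categories: 0 = needs support, 1 = on target, 2 = top performer.
X = np.vstack([rng.normal(0.4, 0.1, (40, 3)),
               rng.normal(0.7, 0.1, (40, 3)),
               rng.normal(1.0, 0.1, (40, 3))])
y = np.repeat([0, 1, 2], 40)

model = GaussianNB().fit(X, y)
new_rep = np.array([[0.95, 1.05, 1.0]])  # a strong performer's KPIs
print(int(model.predict(new_rep)[0]))    # predicted category
```

Because naive Bayes only needs per-feature class statistics, the same code works unchanged when a company swaps in its own KPI set.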

**Conflicts of Interest:** The authors declare no conflict of interest.

### **References**


## *Article* **Automatic Identification of Local Features Representing Image Content with the Use of Convolutional Neural Networks**

### **Paweł Tarasiuk, Arkadiusz Tomczyk \* and Bartłomiej Stasiak**

Institute of Information Technology, Lodz University of Technology, ul. Wolczanska 215, 90-924 Lodz, Poland; pawel.tarasiuk@p.lodz.pl (P.T.); bartlomiej.stasiak@p.lodz.pl (B.S.)

**\*** Correspondence: arkadiusz.tomczyk@p.lodz.pl; Tel.: +48-42-631-39-57

Received: 30 June 2020; Accepted: 23 July 2020; Published: 28 July 2020

**Abstract:** Image analysis has many practical applications, and proper representation of image content is its crucial element. In this work, a novel type of representation is proposed where an image is reduced to a set of highly sparse matrices. Equivalently, it can be viewed as a set of local features of different types, as the precise coordinates of detected keypoints are given. Additionally, every keypoint has a value expressing feature intensity at a given location. These features are extracted from a dedicated convolutional neural network autoencoder. This kind of representation has many advantages. First of all, the local features are not manually designed but are automatically trained for a given class of images. Second, as they are trained in a network that restores its input on the output, they may be expected to minimize information loss. Consequently, they can be used to solve similar tasks in place of the original images; this ability is illustrated with an image classification task. Third, the generated features, although automatically synthesized, are relatively easy to interpret. Taking the decoder part of our network, one can easily generate a visual building block connected with a specific feature. As the proposed method is entirely new, a detailed analysis of its properties for a relatively simple data set was conducted and is described in this work. Moreover, to assess the quality of the trained features, the method is compared with convolutional neural networks having a similar working principle (sparse coding).

**Keywords:** image representation; local features; autoencoder; convolutional neural network; machine learning

### **1. Introduction**

Images are typically represented using regular grids of pixels. The information about image content is kept both in the pixels' attributes (color channels) and, which seems to be even more important, in their spatial distribution. This kind of representation, although natural for humans, has at least one crucial drawback: it significantly complicates the design of effective computer algorithms able to accomplish tasks that are relatively easy for our visual system. The nature of this problem lies not only in the huge number of image elements and the variety of their possible distributions, but also in the fact that humans do not consciously operate on individual pixels. In most cases, the latter reason makes it impossible to directly write computer programs imitating the unconscious process of image understanding.

There are two typical approaches for overcoming the above problem. The first group of methods aims at changing and simplifying the representation of image content. The second, instead of direct implementation, engages machine learning for this purpose. Although these methods can be used separately (simplified representations may allow the design of algorithms ready for direct implementation, and there are trainable models that can operate on raw pixels), they are usually combined. The need for this combination originates from the limitations of existing models (e.g., a specific input format) as well as from the necessity of selecting optimal model parameters (even carefully designed algorithms require problem-specific fine-tuning).

An algorithm solving a given problem, whether created with the use of machine learning techniques or without them, requires domain knowledge, either to encode it in a computer program or to design the training objective. Unfortunately, domain experts usually have problems with sharing their knowledge in the form of ready-to-use mathematical formulas. They prefer to express it in imprecise natural language (which must somehow be adapted to become applicable in the source code of the program), or they provide it in the form of a training set (for a given input they specify the expected output). In both cases, it may be easier for them to operate on simplified representations rather than on millions of pixels. Training set preparation may be especially troublesome for some specific tasks. In particular, when precise image segmentation is the goal of the analysis, knowledge acquisition at the pixel level can be tiresome and time-consuming. Moreover, in, e.g., medical applications, where the number of experts is limited, acquiring a huge and representative set of examples, required by most machine learning algorithms, becomes almost impossible.

Another important aspect of the choice of image representation is the interpretability of the algorithm results. Nowadays, many artificial intelligence techniques have become a part of our lives. Some of them are, or will soon become, responsible for our health and even our lives. Consequently, their authors must be prepared to at least explain their general principles to potential consumers to convince them that their product is safe. Operating on individual pixels makes this practically impossible. If, instead, the applied representation is significantly reduced and its components can be assigned a meaningful interpretation, such an explanation becomes plausible.

To conclude, the search for alternative image representations constitutes an important task from the image analysis point of view. Moreover, it is very interesting in itself, as it may also allow a better understanding of the principles of human visual system operation. There is much evidence that this system also tries to organize the recognition process by creating several intermediate representations corresponding to elements of the observed scene [1]. Such intermediate representations are also observed in trained convolutional neural networks (CNNs) [2–6], as they imitate, to some extent, the activity of the visual cortex. This was the main reason behind the choice of a CNN to automate the process of image representation construction in our research.

In this paper, we propose a new CNN-based method for generating a general representation of image content. The image is described as a tuple of feature maps that describe the localization and intensity of selected visual features on the image plane. A feature map is a matrix that describes the intensity of a certain visual feature at every point of the image. This means that instead of global image features, such as *the image is mostly blue* or *there are many vertical lines*, we focus on local features, such as *there is a vertical segment centered at a point with given coordinates*. Due to the local nature of the selected features, the proposed method encodes each feature occurrence as a single matrix element. The exact coordinates of the selected pixel reflect the precise localization of the feature in the image plane. As a result, the relation between feature maps and actual image content is much more direct than in the case of classic CNNs. This greatly simplifies the semantic analysis of features and feature maps, which originally required specific approaches such as deconvolutional neural networks [3].

Naturally, the aforementioned features have to be non-trivial. If visual features consisting of a single pixel were allowed, the trained feature extractor could take the form of an identity function and the feature maps would correspond to the original image. In order to select semantically meaningful features, our method ensures that the feature maps at a selected level are sparse matrices, where the neighborhood of each activated (non-zero) element is zeroed. The resulting image representation, based on the visual features of the image, should contain sufficient information to ensure that the goals of the analysis are met. The encoding is obviously lossy, but the most common patterns found in the training set are expected to be preserved. Because of its ubiquity in the field of machine learning, the MNIST data set of handwritten digits [7] was employed for the purpose of the present study. Not only does it allow for comparisons of the obtained results with those produced by similar techniques, but it is also, due to its simplicity, easily interpretable. The latter asset makes it possible to understand the meaning of the generated local features.

The remaining part of this paper is organized as follows. Section 2 discusses related works. In Section 3, the concept of convolutional neural networks is outlined and our novel contributions are defined. Section 4 describes the experimental framework for providing as sparse an image representation as possible, without losing the key information. The results are discussed in Section 5. A detailed analysis of the neural network models is illustrated with various visualizations. Finally, Section 6 presents the conclusions and directions for further research.

#### **2. Related Works**

The method of local feature identification proposed in this work uses the specific properties of CNNs. To enforce sparse coding, which allows us to determine the precise localization of these features, we propose a specific neural network architecture with additional filtering layers and a unique adjustment of the training objective. This is a novel approach, and thus it is hard to compare it with existing works. Nevertheless, in this section we present some related works, grouping them into three categories: works devoted to other methods of generating alternative representations of image content; works trying to automatically find semantic interpretations of features emerging in CNNs; and works having, to some extent, working principles similar to our approach.

### *2.1. Image Representation*

As mentioned in Section 1, the change of image representation (extraction of features) is a crucial step in image analysis. It naturally depends on the type of task considered and, consequently, on the techniques that will be used. In the case of image classification tasks, global representations may be sufficient. However, for pattern localization, object detection and, in particular, image segmentation, local feature extraction is essential. Global representations treat the image as a whole and try to generate descriptors (usually feature vectors) which summarize the colors, textures, shapes, etc. visible in the image. The features presented in this work are local. This means that the descriptors are assigned not to the whole image, but to specific locations (keypoints) within it. Naturally, raw pixels are also local descriptors of this kind, but what we look for are reduced representations, where the number of descriptors is significantly smaller than the number of pixels. A reduced representation does not necessarily mean a loss of information. The descriptors are fewer in number, but as they describe properties of image regions, they may contain more information than the color channel values assigned to single pixels. Moreover, additional information may also be kept in the data structure reflecting relations (including spatial relations) between these descriptors.

In the literature, there are two typical strategies for finding local descriptors. The first one uses segmentation techniques to define homogeneous regions of the image. Having found them, descriptors may be assigned either to these regions [8] or to their borders [9]. The regions are associated with information about their precise location (e.g., the centroid of the region) and, consequently, their spatial relations can be discovered as well. The second strategy achieves a similar goal in the opposite way. First, characteristic points (keypoints) are sought in the image plane, and the local region around them is identified afterwards. In this group, techniques such as SIFT [10] or SURF [11] can be mentioned. They are particularly interesting, as they provide scale and orientation invariance. In all the above cases, after region or keypoint detection, descriptors (local features) must be computed. These descriptors can take into account the shape and the color of a region, or they can be based on local gradients (SIFT) or wavelet responses (SURF). All of them, however, are designed manually by the author of a specific application.

The local features can be used both to classify image content and to solve more complex tasks. In the simplest case, clustered descriptors allow the identification of a visual vocabulary, on which the bag-of-visual-words (BoVW) technique can be applied. In this approach, image content is transformed into a real vector (one-hot or frequency encoding) and consequently most classic pattern recognition techniques can be employed. If spatial relations between local features need to be taken into account or more complex tasks (object localization, segmentation) are to be solved, other methods must be used. For SIFT and SURF descriptors, a dedicated efficient matching algorithm was designed to find a correspondence between local features extracted from different images [10]. Other possible approaches construct a graph describing the image content, where local features are related to its nodes and spatial relationships are reflected in the edges. In such a case, geometric deep learning (GDL), which generalizes the CNN concept to non-Euclidean domains, can be applied [12–14]. Alternatively, active partitions [15], an extension of classic active contours, can be of use here as well.

### *2.2. Semantic Interpretation*

There is much evidence that feature maps generated by successive convolutional layers of a CNN correspond to some semantically important parts of the analyzed images [3]. The identification of relationships between these parts and the feature maps is not, however, a simple task. First of all, CNNs have always been treated as trainable black boxes (similarly to other neural networks), and while designing their architecture no attention was paid to how the intermediate outputs can be interpreted. The resulting feature maps are usually blurred, and it is hard to understand the relation between them and the content of the input images. Moreover, in classic CNN architectures (pooling layers and no padding), the size of the feature maps is reduced in consecutive layers. This leads to further problems with identifying the precise location of semantically important regions.

Several techniques trying to reveal the aforementioned relationships can be found in the literature. In [4], the peaks of the feature maps are first mapped onto visual (receptive) fields within the input image, and their correspondence with known semantic parts is then checked. As more than one feature map may be connected with a given part, a genetic algorithm is then applied to find the most appropriate subset of the feature maps from all the convolutional layers. In [5], instead of the layer outputs, their gradient maps, calculated with the backpropagation algorithm, are used to find activation centers. In [16], the authors introduce class activation mapping (CAM), which can be used for the identification and visualization of discriminative image regions, as well as for weakly supervised object localization. In that approach, a CNN must be trained to classify images (supervised learning), and the typical fully connected layers are replaced by global average pooling (GAP) followed by a fully connected soft-max layer. The CAM for a given class can be found as a linear combination of the final feature maps generated by the convolutional layers, with weights corresponding to a specific network output. As the size of the class activation map is equal to the size of the final feature maps, it must be upsampled to be comparable with the input image. Finally, in [6], the authors assume that the top layers of a network correspond to bigger parts or whole objects, while the lower layers reflect smaller parts which are building blocks for the more complex ones. They propose a method that is able to automatically discover a graph describing these relationships. It should be noted that all these methods, although interesting, are quite complex. Moreover, they try to find correspondence with known parts (a supervised process), which need not be optimal in every application.

### *2.3. Working Principles*

The key part of the proposed method involves using sparse matrices as an intermediate step of image processing with CNNs. This should not be confused with sparse convolutional neural networks proposed in [17], as that work was founded upon using sparse filter matrices in multiple convolutional layers, whereas our approach is based on sparse outputs. Another method that applied sparse coding to CNNs was presented in [18] and addressed the problem of image super-resolution. This approach, however, also differs from ours in terms of both the main goal and the motivation behind using sparse matrices. In the present study, sparse matrices are utilized to generate image descriptions based on visual features.

These examples show that it is hard to find works with objectives similar to ours. Nevertheless, we were able to identify two groups of research areas which can be considered related and which will be used as a comparison base for our results.

Sparse coding in feed-forward neural networks has been considered in multiple works as a tool for improving the performance of typical CNNs with dense matrices. This includes both theoretical analysis [19] and practical application to image reconstruction [20]. There are also multiple ways to generate sparse image representations for the tasks of image reconstruction and classification. One of the notable approaches is based on the Fisher discrimination criterion [21]. Multiple related works describe solutions based on CNNs [22–24], which makes them similar and potentially comparable to the method presented in this paper. However, these studies are concerned with either computational speed or accuracy, and do not consider the problem of intermediate feature extraction. As presented in [25], sparse coding can simplify the classification task by maximizing the margin from the decision boundary in a selected metric space. It must be emphasized, however, that none of these works was dedicated to the automatic detection of visual image features.

Sparse representation of the hidden layer outputs is also typical for spiking convolutional neural networks (SCNNs) [17]. The key component of SCNNs is intended to simulate the electro-physiological processes that occur in synapses. Another notable advantage of SCNNs is the possibility of implementing them on FPGA-based hardware [26]. The actual solutions are usually based on leaky integrate-and-fire (LIF) neurons [27,28] or spike-timing-dependent plasticity (STDP) learning [29]. Both of the above-mentioned methods involve the introduction of additional types of neurons that simulate spiking of the electrical charge, according to the selected model of synapse behavior. Distinct peaks related to the presence of specific patterns are rare enough to generate sparse data, which can be further reshaped into a sparse matrix. Thus, SCNN-based methods belong to the field of sparse coding. In the proposed approach, the learning process known from basic CNNs is enhanced only with additional cost function components (which can be considered a form of model regularization) and activation functions, but no additional neuron types are introduced.

#### *2.4. Contribution*

The goal of our research was to create a tool that is able to automatically (in an unsupervised way) discover spatially located visual features for a given class of images. These features should lead to a reduced representation of the image content without the loss of the information contained therein. Such a representation should make it possible to create image analysis algorithms that are easier to interpret, allowing the use of simpler models in which external expert knowledge can be incorporated in a more natural way.

The image representation proposed in this work enables the above goal to be achieved. Unlike the SIFT or SURF methods mentioned in Section 2.1, the feature identification does not rely on a manually designed algorithm, but can be trained for a specific class of images. The role of a keypoint extractor is played by the encoder part of the proposed convolutional autoencoder. The value assigned to a given keypoint corresponds to the intensity of a visual feature. This value, together with the number of the sparse feature map where the keypoint was found, constitutes a form of keypoint descriptor. It need not be more complex, as no further matching is required when two images are compared. The feature map number directly identifies keypoints of the same type.

Although the types of the features (the numbers of the successive feature maps) seem very abstract, our approach allows us to discover and understand the nature of these features without the necessity of using algorithms as complex as those presented in Section 2.2. Their form can be revealed using the decoder part of our network. All of this would not be possible if we could not precisely locate these features in the original images, which is problematic for typical, blurred feature maps. Our approach solves this problem thanks to the novel training objective component, which enforces a leptokurtic distribution of specific layer outputs, and thanks to the new filtering step added to the network architecture.

It should also be emphasized that the proposed representation differs significantly from the reduced representations which can be obtained using classic feature reduction algorithms such as PCA or a non-convolutional autoencoder. The convolutional autoencoder used in this work takes into account the spatial relationships between the reduced features and (as an additional advantage) performs feature reduction in a local way. The resulting features are calculated only on the basis of the pixels belonging to the respective receptive field. Such weight sharing, typical for a CNN, reduces the number of trainable parameters and allows us to detect the same features in different places of the image plane, regardless of the image size.

Finally, we use matrices, and not vectors, to encode images, because we want not only to compress the information, but also to extract information about the localization of the visual building blocks. This is a very specific application, which requires matrices only because the images are represented as regular grids of pixels. Nevertheless, the proposed approach is general. A CNN is designed to work with matrix-like structures, but the interpretation of these structures is irrelevant. If an autoencoder is used, one can obtain an encoder which generates sparse matrices preserving the whole information about the encoded data. In order to use the resulting sparse matrices, another processing unit must be designed; an example is the classifier described in Section 4.3. An alternative solution, which is not presented in this work, would be to train a CNN directly performing a specific task (e.g., classification) with sparsity enforced inside. In this case, however, the sparse representation would not preserve the whole information about the input, but only the part required to accomplish a given goal.

### **3. Method**

### *3.1. Method Overview*

As proposed by LeCun [2], CNNs are feed-forward neural networks that typically consist of the following.


In this paper, two applications of CNNs are considered. The most important architecture proposed in this study is a CNN-based autoencoder. Its primary goal is to minimize the difference between the input data and the output obtained for any input from the considered data set. The difference is calculated as the Euclidean distance between vectors of pixels. This implies that the resolutions of both input and output are expected to be the same. Thus, in our approach, no pooling layers are used and each convolutional layer is complemented by appropriate padding. As the convolution of *mw* × *mh* matrix with *fw* × *fh* filter yields (*mw* − *fw* + 1) × (*mh* − *fh* + 1) as a result, (*fw* − 1)/2 zero-padding is added to the sides of the matrix, and (*fh* − 1)/2 to the top and bottom. This is possible when both filter dimensions (*fw* and *fh*) are odd. This property is illustrated in Figure 1.
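The padding arithmetic described above can be illustrated with a short NumPy sketch (a naive reference implementation under our own names, not the GPU code used in the paper): with (*f* − 1)/2 zero-padding on each side, the convolution output keeps the input resolution.

```python
import numpy as np

def conv2d_same(image, kernel):
    """2-D convolution with zero-padding chosen so that the output keeps
    the input resolution, as required by an autoencoder with no pooling.
    Assumes odd kernel dimensions, as in the text."""
    fh, fw = kernel.shape
    ph, pw = (fh - 1) // 2, (fw - 1) // 2          # (f - 1)/2 padding per side
    padded = np.pad(image, ((ph, ph), (pw, pw)))   # zero-padding
    h, w = image.shape
    out = np.zeros((h, w), dtype=float)
    for i in range(h):
        for j in range(w):
            # plain convolution: flip the kernel before the elementwise product
            out[i, j] = np.sum(padded[i:i + fh, j:j + fw] * kernel[::-1, ::-1])
    return out

image = np.random.rand(28, 28)
kernel = np.ones((3, 3)) / 9.0                     # simple averaging filter
assert conv2d_same(image, kernel).shape == image.shape
```

Without the padding, the same convolution would shrink the 28 × 28 input to 26 × 26, which is why every convolutional layer of the autoencoder is complemented by the matching padding.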

Without setting additional requirements, it would be easy to construct a perfect CNN-based image autoencoder. It would be sufficient for each layer to generate an output equal to its input. This could be achieved with a convolution filter that has 1 in a single matrix element and 0 everywhere else. In order to avoid a meaningless result like that, we force one selected hidden layer to consist only of sparse matrices. The sparsity is guaranteed in the following way. For each non-zero element (*i*, *j*) of the output matrix, all other elements in the *s* × *s* square centered at (*i*, *j*) are reduced to zeros. This step of data processing is further referred to as local-maximum filtering (LMF) (Figure 2) and its details are described in Section 3.3.

**Figure 1.** The operation illustrated in this figure is a superposition of 1 × 1 zero-padding and matrix convolution with 3 × 3 filters. As the padding size matches the filter size properly, the output matrix has the same size as the input matrix.

**Figure 2.** The full architecture of the autoencoder consists of two major parts: encoder and decoder. The encoder includes convolutional layers that can either be adjusted to the data set in the learning process or use some fixed weights. The result is further processed with local-maximum filtering (Section 3.3). In the case of adjustable convolutional layers in the encoder, the cost function related to the encoder's CNN output is modified in order to reduce the output kurtosis, as described in Section 3.2. The encoder output is fed into the decoder which consists of convolutional layers that participate in the learning process. The learning objective is to reproduce the original input image, while minimizing the reconstruction error, which is measured in terms of the Euclidean distance.

#### *3.2. Leptokurtic Feature Maps*

In order to obtain satisfactory results of local-maximum filtering in the selected hidden layer during CNN training, the learning objective is enhanced with a kurtosis-based adjustment. It is based on the following observation: splitting the convolutional layer outputs into small subsets of highly activated points and low activations of the other elements may be equated to a leptokurtic distribution of the outputs. A leptokurtic distribution (related to high kurtosis) means that the elements are concentrated closer to the mean value than in the case of a normal distribution. Forcing a leptokurtic distribution may thus be considered equivalent to kurtosis maximization. The kurtosis function is continuous and differentiable almost everywhere, which makes it possible to apply gradient-based learning. Consequently, it can constitute an additional component of the cost function.

The kurtosis [30] of a vector *X* = (*X*1, *X*2, ... , *Xn*) (this notation is valid for both random variables and fixed numbers) is defined as

$$\text{Kurt}\,X = \frac{\mu\_4(X)}{\sigma(X)^4} - 3, \tag{1}$$

where *μ*4(*X*) is the fourth central moment and *σ* is the standard deviation. In order to perform gradient-based learning, we need to calculate the actual gradient. As the formulas are symmetric in terms of the elements of *X*, the only expression we need is

$$\frac{\partial(\text{Kurt}\,X)}{\partial X\_{i}} = \frac{4\left((X\_{i} - \text{E}X)^{3} - \text{E}((X - \text{E}X)^{3})\right)\text{Var}(X) - 4\,\text{E}((X - \text{E}X)^{4})\,(X\_{i} - \text{E}X)}{n\,\text{Var}(X)^{3}}. \tag{2}$$

Formula (2) has one important disadvantage when used for gradient learning. As kurtosis (1) is indifferent to the magnitude of the inputs, the differential decreases as the magnitude of the inputs grows. As a result, big values would be modified more slowly by the learning process. In order to reverse this effect and obtain a change that is proportional to the current value of convolutional layer outputs (and to the corresponding weights—as in the case of CNNs these terms are proportional), the differential (2) is multiplied by Var(*X*). This means using exponent 2 instead of 3 in the denominators of expression (2).
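The kurtosis objective and its variance-rescaled gradient can be sketched in NumPy for a 1-D vector of layer outputs (function names are ours; the gradient is written in terms of the centered values, and the rescaling by Var(*X*) corresponds to using exponent 2 instead of 3 in the denominator, as described above):

```python
import numpy as np

def kurtosis(x):
    """Excess kurtosis: mu_4 / sigma^4 - 3."""
    x = np.asarray(x, dtype=float)
    c = x - x.mean()                     # centered values
    return np.mean(c ** 4) / np.var(x) ** 2 - 3.0

def kurtosis_grad_rescaled(x):
    """Gradient of the kurtosis w.r.t. each element, multiplied by Var(X)
    so that the update is proportional to the magnitude of the outputs
    (illustrative sketch of the adjustment described in the text)."""
    x = np.asarray(x, dtype=float)
    n = x.size
    c = x - x.mean()
    var, mu3, mu4 = np.var(x), np.mean(c ** 3), np.mean(c ** 4)
    return (4.0 * (c ** 3 - mu3) * var - 4.0 * mu4 * c) / (n * var ** 2)

x = np.array([1.0, 2.0, 3.0, 4.0])
print(round(kurtosis(x), 4))   # -1.36 for this vector (platykurtic)
```

Adding this term to the cost function rewards distributions with a few strong activations and many near-zero ones, which is exactly the shape that local-maximum filtering expects.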

#### *3.3. Local-Maximum Filtering*

The additional gradient component, which makes the selected part of the CNN yield leptokurtic outputs, does not guarantee the desired properties of the sparse output. In order to achieve literal sparsity, we need to make sure that some matrix elements are replaced with zeros. This could be easily achieved by thresholding—a process similar to that described in [31]. In order to limit the number of the remaining outputs, the threshold level could be defined as a quantile of the output of either the whole layer or a single resulting matrix. In this work, however, instead of using the global statistics of the CNN layer output, we propose a method that focuses on local properties.

The proposed approach, which leaves only one non-zeroed element in each *s* × *s* matrix minor, has two major advantages. First, this operation is easy to implement for parallel computations, which is important, as the present CNN solution was implemented on a GPU using the Caffe framework [32]. Each element is considered separately, and is zeroed if it is not strictly the greatest element in the surrounding square. The other advantage of the proposed way of forcing the sparse representation is related to the interpretation of visual features. Typically, the input image has continuous content, which means that the same visual feature is likely to be detected in multiple neighboring locations. Let us consider a horizontal edge visible in the image as an example. In the case of a long horizontal line in the image, the same local feature is obviously present in all the points of this line. If the sparse representation were constrained only by the number of activated pixels, there would be a strong probability of obtaining a subset of pixels forming a single connected component around the most visible feature (or, as in the example, along the line). LMF provides a direct solution to this problem. This situation is illustrated in Table 1.

**Table 1.** Local-maximum filtering (LMF) is a method that generates a sparse output, but is more practical than the standard thresholding. The points are designed to reflect the selected visual features of the image, and the local-maximum filtering makes it possible for each point to reflect a different occurrence of a feature. Thus, it is necessary to employ a mechanism preventing non-zeroed points from being located too close to each other. The strongest activated points are chosen in a greedy way, with only one point allowed in each *s* × *s* square. The presented illustration shows the result for *s* = 3.
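Local-maximum filtering can be sketched in a few lines of NumPy (a naive illustrative implementation; the function name and the restriction to positive activations are our assumptions, and a production version would vectorize the loop for GPU execution):

```python
import numpy as np

def local_maximum_filtering(m, s=3):
    """Zero every element that is not strictly the greatest within the
    s x s square centered on it. Each element is examined independently,
    which is why the operation parallelizes easily."""
    h, w = m.shape
    r = s // 2
    out = np.zeros_like(m)
    for i in range(h):
        for j in range(w):
            window = m[max(0, i - r):i + r + 1, max(0, j - r):j + r + 1]
            # the element survives only if nothing in its window ties or exceeds it
            if m[i, j] > 0 and np.count_nonzero(window >= m[i, j]) == 1:
                out[i, j] = m[i, j]
    return out

m = np.array([[1.0, 2.0, 0.0],
              [0.0, 5.0, 0.0],
              [0.0, 0.0, 3.0]])
print(local_maximum_filtering(m, s=3))   # only the 5 in the center survives
```

On the small example above, the 2 and the 3 are suppressed because the 5 dominates their 3 × 3 neighborhoods, which is the behavior that prevents a long line from producing a connected blob of activations.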


### *3.4. Additional Thresholding*

Local-maximum filtering, which was described in the previous section, generates an output that can be regarded as sparse. In the case of the MNIST data set [7], where input data consists of 28 × 28 images, local-maximum filtering with radius *s* = 3 ensures that at most 49 elements of each 28 × 28 matrix remain non-zeroed. This means that either for the original data set or any larger input images, at most 6.25% elements of the output data have values other than zeros. It may be expected, however, that many of the 49 elements have insignificant, near-zero values anyway. Local-maximum filtering yields such an element in each isolated region of the image plane, even if the corresponding visual feature is not present in that region.

It is difficult to suggest any general purpose threshold for the selection of the significant points, as it may depend on the weights of the convolutional layer and on the context of the considered image. In some of the experiments that involve the original MNIST data set, where each image presents exactly one object, we manually limit the number of points in each matrix that are used to encode that object. After local-maximum filtering, which prevents the points from being located too close to each other, all the points except the *k* highest values are replaced with zeros. For *k* = 5 and *k* = 3, it yields 0.64% and 0.38% non-zero values, respectively. This is equivalent to image thresholding, with the threshold value dependent on the appropriate quantile of the values from the processed matrix. This highly sparse representation can be used to experimentally determine how much information is actually preserved in the small number of points.
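The additional thresholding step reduces to keeping the *k* largest activations of a matrix and zeroing the rest; a minimal NumPy sketch (the function name is ours):

```python
import numpy as np

def keep_top_k(m, k):
    """Retain only the k largest activations of the matrix and zero the
    rest. Equivalent to thresholding at a per-matrix quantile."""
    out = np.zeros_like(m)
    idx = np.argsort(m, axis=None)[-k:]   # flat indices of the k largest values
    out.flat[idx] = m.flat[idx]
    return out

m = np.array([[0.1, 0.9],
              [0.5, 0.2]])
print(keep_top_k(m, 2))   # keeps 0.9 and 0.5, zeroes the rest
```

For a 28 × 28 matrix, *k* = 5 and *k* = 3 give 5/784 ≈ 0.64% and 3/784 ≈ 0.38% non-zero values, matching the figures quoted above.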

### *3.5. Properties*

The sparse output obtained from the selected hidden CNN layer can be considered as a form of image encoding, as it is supposed to be used by the subsequent layers to reconstruct the input image. Thus, a neural network architecture that meets the presented assumptions can be considered a general purpose tool for sparse image encoding. The encoding is based on local visual features of the image, which may be explained as follows. One of the commonly known properties of CNNs is the invariance to translation of the visual features of objects on the input image plane [2]. The translation of an object automatically results in a similar translation of its representation generated by the convolutional layer. This is true for a single matrix convolution and remains valid for sums and superpositions thereof. The outputs of the convolutional layers are known as feature maps, as CNNs take biological inspiration from the visual cortex [2,33]. An important aspect is the size of the feature or object visible in the image under examination. The calculations performed to obtain each element of the feature map involve data from a specific range of the input image, known as the visual field. In the case of the initial convolutional layers, the visual field is significantly smaller than the image itself: for the first layer, it is simply equal to the filter size. If the visual fields of the layer with a sparse output contained the whole input image, whole objects could be encoded as single pixels. However, this would be equivalent to an image classification task, without the analysis of particular elements of the recognized object. In order to split the original image into more basic features, we use visual fields that are smaller than the image itself. In one of the examples presented in the following section, we use 14 × 14 visual fields selected from 28 × 28 input images. The sizes of visual fields of a neural network are easy to estimate, particularly if the network consists of convolutional layers only; an example is shown in Figure 3.

**Figure 3.** Each output pixel depends on multiple elements from the previous layers. The scope of the related pixel from the previous convolutional layer matches the size of the convolutional filter applied. By tracking the dependencies back to the input data matrix, we can determine the size of the visual (receptive) field. The convolutions involve one-pixel padding from each side, so the image size does not change. A single output pixel is calculated on the basis of 3 × 3 minor of the hidden layer, and the size of the visual field is 5 × 5.
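For a stack of stride-1 convolutional layers with no pooling, the visual (receptive) field estimate mentioned above reduces to a simple accumulation: each *f* × *f* layer grows the field by *f* − 1. A small sketch (the function name is ours):

```python
def receptive_field(filter_sizes):
    """Visual (receptive) field of a stack of stride-1 convolutional
    layers without pooling: each f x f layer grows the field by f - 1."""
    rf = 1
    for f in filter_sizes:
        rf += f - 1
    return rf

# Two 3 x 3 layers give a 5 x 5 visual field, as in Figure 3:
print(receptive_field([3, 3]))  # 5
```

With pooling or strided layers the accumulation would have to account for the stride product, but for the architectures considered here the simple sum suffices.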

#### **4. Experiments**

### *4.1. Feature Identification*

The experiments described in this section were performed on the original MNIST data set [7], which offers the advantages of a large number of images, a resolution that keeps the computational cost low (28 × 28 pixels), and a simple semantic interpretation of the results, as the samples contain handwritten digits.

Three experiments were performed in three set-ups that implemented the idea presented in Figure 2. According to the original partition of the MNIST data set, the autoencoder models were trained with the 60,000 training samples, while the separate 10,000 samples were used for evaluation. The models differed in the visual features used to encode the image. The architecture of the decoder part, described in Table 2, was common to all the models. None of the models used any form of pooling, and the chosen combination of filter sizes and paddings kept the matrix size unchanged throughout the layers. The presented models applied typical techniques associated with CNNs, such as the dropout method [34] and PReLU activation functions [35]. The encoders were designed as follows.

• MF4: Four manually designed features were used. The filters were fixed, and no learning was performed on this encoder. The features were related to vertical, horizontal, and diagonal lines (in both diagonal directions). The contents of the proposed feature-detecting convolution filters are presented in Table 3. The filters applied in this experiment are of a very generic character, so no advantage may be drawn from using specific filters that fit the data set. Absolute values of the matrix convolution results are used, followed by local-maximum filtering with radius 3, as described in Section 3.3.
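The local-maximum filtering mentioned above can be sketched as follows. This is an illustrative reconstruction, not the authors' code: it assumes absolute values have already been taken, and that ties within a neighbourhood keep all maxima. Note that with radius *s* surviving points must be at least *s* + 1 pixels apart, so at most 1/(*s* + 1)² of the elements remain non-zero, i.e., 6.25% for *s* = 3:

```python
def local_max_filter(m, s=3):
    """Keep only elements that are the maximum within an s-radius
    square neighbourhood; zero out the rest (sparsifying step)."""
    h, w = len(m), len(m[0])
    out = [[0.0] * w for _ in range(h)]
    for i in range(h):
        for j in range(w):
            lo_i, hi_i = max(0, i - s), min(h, i + s + 1)
            lo_j, hi_j = max(0, j - s), min(w, j + s + 1)
            nb_max = max(m[a][b] for a in range(lo_i, hi_i)
                                 for b in range(lo_j, hi_j))
            if m[i][j] == nb_max and m[i][j] > 0:
                out[i][j] = m[i][j]
    return out
```

With two activations placed far enough apart, both survive; a weaker activation inside a stronger one's neighbourhood is suppressed.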


The encoder and decoder could be considered as separate utilities, but combining them into one neural network model made it possible to actually train the feature detectors in AF4 and AF5 experiments. The training was aimed at minimizing the total square error of the autoencoder.

The MF4 features are the most natural approach, as the features were designed manually in order to approximate any pattern that consists of thin lines. The four basic directions, shown in Table 3, fit the structure of a filter matrix precisely. Any change to this approach, such as a set of 3 or 5 segment-based features, would require an arbitrary choice of a direction and involve a specific approximation when described as convolutional filters.

As MF4—a solution with 4 kinds of features—was selected for its simplicity, the most direct comparison based on automatic feature identification involves 4 features as well, which is demonstrated by the AF4 set-up. However, automatic detection of features does not directly indicate any specific number of features as correct. The design of 5 equally important features for MNIST is unintuitive, but the potential gain can easily be investigated using the automatically trained encoder. The AF5 set-up was introduced for this purpose. The number of features could be expanded arbitrarily further, but as it grew, the features would become increasingly difficult to distinguish visually. For the purpose of visual presentation of the results, we focus on a maximum of 5 features. However, if the data set were more complex than MNIST or involved color images, it might be crucial to introduce more features.

The specifics of the automatic encoder training process, described in Table 4, were proposed as a compromise. This architecture is complex enough to identify potentially useful image features while avoiding the possible disadvantages of overly complex models, such as high resource usage and duplicated filters. The number of layers and the filter sizes were determined by the requirements on the visual fields, while the number of filters in each layer was selected by trial and error. The results were roughly convergent around the chosen values. The possible variations obviously include permutations of the feature detectors in the encoder output. The full training time was long enough to make detailed parameter tuning remarkably difficult, but we believe that the presented models are sufficient to demonstrate the properties of the proposed methods.

The complete encoder architecture from Table 4 involves 28,570 adjustable parameters for AF4 and 29,570 for AF5. The slight difference is related to the last convolutional layer in the sequence. Remaining in a similar order of magnitude, the total number of decoder weights was 114,450 for AF4/MF4 and 115,200 for AF5, due to the additional filter group in the first convolutional layer. While the training process was relatively complex, we believe that the final model can be described as lightweight.

It must be emphasized that obtaining the optimal autoencoders for this method would require much more detailed fine-tuning and repeated experiments. However, the presented demonstration of the method does not require such exhaustive optimization effort. We have defined three different set-ups, which will be useful for the analysis, and we use fixed training conditions for all of them, so we can adequately compare them with one another.

The autoencoder error was calculated as the difference between the input and the reconstructed output. Data from the MNIST data set [7] could be considered as a set of 8-bit grayscale images with brightness levels varying from 0 to 255. However, the presented results refer to normalized values from the [0, 1] range. This applies to the average errors in Table 5. The error for a single sample is half of the sum of the quadratic errors over all the pixels. The table presents average errors for a given set of samples—both for the whole data set and for each digit considered separately. The autoencoder itself, in accordance with the previously described architecture, did not use any information on object classes during training.
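The per-sample error defined above can be written directly; the helper below is illustrative only, assuming pixel values already normalized to [0, 1]:

```python
def sample_error(original, reconstruction):
    """Half the sum of squared per-pixel differences (Euclidean loss)
    between an input image and its autoencoder reconstruction."""
    return 0.5 * sum((a - b) ** 2
                     for row_o, row_r in zip(original, reconstruction)
                     for a, b in zip(row_o, row_r))

# A fully wrong 28x28 reconstruction gives the maximum halved loss:
worst = sample_error([[1.0] * 28 for _ in range(28)],
                     [[0.0] * 28 for _ in range(28)])
print(worst)  # 392.0
```

The table averages are then simply the mean of this quantity over the chosen subset of samples.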

**Table 2.** All the experiments (namely, MF4, AF4, and AF5) related to the general architecture of the autoencoder shown in Figure 2 use the same layout of the decoder part. This table includes a detailed layout of the convolutional layers and the activation functions used in the decoder, such as PReLU [35]. The rows of the table describe consecutive layers of the decoder CNN, denoted as Dec1–Dec4.


**Table 3.** One of the autoencoder-related experiments, labeled as MF4, uses fixed encoder filters (the encoder part of the overall architecture is described in Figure 2). This table presents the predefined values of 7 × 7 convolutional filters, visualized as bitmaps.

**Table 4.** Automatic feature extraction experiments, labeled as AF4 and AF5, use adjustable encoders (Figure 2) with multiple convolutional layers. The layout of the layers and the corresponding activation functions (including PReLU [35]) are presented. The ENCODE activation function is a short term for a sequence of operations: the PReLU activation function, the absolute value, the layer that modifies gradients with relation to kurtosis (Section 3.2), and local-maximum filtering with radius *s* = 3.



**Table 5.** The autoencoder errors, measured in terms of the Euclidean loss function, were calculated for all the proposed network architectures. In addition to the general error on the test set from the MNIST data set [7], specific values were calculated for each class separately. Thus, it was possible to evaluate how well the selected visual features described each of the digits.

The results presented in Table 5 demonstrate the relative success of all the experiments. It is worth noting that the maximum quadratic error between 28 × 28 matrices is 784, and the expected quadratic difference between matrices of uniform random [0, 1] elements is 130.67. The errors obtained from the presented experiments are lower by a whole order of magnitude (per-subset average errors are more than 11 times smaller than the mentioned estimate). The only limitation, which leads to the presumption that an error of zero is impossible for the MNIST data set, stems from the sparse encoding that has to be used as an intermediate sample representation. Due to the specific properties of this sparse representation, no direct external baseline is available for comparison. The training of the decoders was performed in a unified way for all the set-ups, so the results in Table 5 reflect the usefulness of the features selected by the encoders. Therefore, in the absence of a more general ground truth, the MF4 results can be considered as reference values for the evaluation of the AF4- and AF5-based features.
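The 130.67 estimate quoted above follows from a one-line expectation computation:

```python
# For independent U, V ~ Uniform[0, 1]:
#   E[(U - V)^2] = E[U^2] - 2 E[U] E[V] + E[V^2]
#                = 1/3 - 2 * (1/2) * (1/2) + 1/3 = 1/6,
# so the expected squared difference between two random 28 x 28
# matrices is 784 / 6.
expected = 28 * 28 / 6
print(round(expected, 2))  # 130.67
```

Note that this figure, like the 784 maximum, refers to the raw sum of squared differences, not the halved per-sample loss.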

The first conclusion is that the features specific to the data set performed better than the generic manual suggestion—the MF4 experiment resulted in the highest autoencoder errors. The difference between the results of the automatic variants with 4 and 5 features appears slight when compared to MF4. It may also be concluded that using a higher number of features makes the encoding more precise, i.e., it preserves more information about the exact contents of the original image. Surprisingly, the results for digit 2 are slightly better in the case of AF4 than in AF5, which is an exception to this rule.

The differences between classes can be explained by the geometric properties of the digits. Digit 0, which is round, generated a particularly high error in MF4, as lines of fixed directions made it difficult to recognize round shapes. The error for 0 in MF4 was even higher than for 8, which contains crossing diagonal line segments in the center—the direction of these segments apparently fits the designed filters. Remarkably, the lowest errors for MF4 were obtained for digits that literally consist of straight segments, namely, 1 and 7. While digit 9 was the third best, 4 was the close fourth, which fits the pattern, as 4 consists of long segments and 9 has a small circular head and a straight, long tail.

The comparison of the AF4 and AF5 error rates provided a number of other important observations. In both experiments, 8 was the worst case, which can be attributed to its shape being the most visually complex—a single line that crosses itself and forms two circles is especially difficult to describe with features obtained as the result of convolutional filters. The other digits with significantly high error rates were 0, 2, and 6. For AF4, the digits that contained circles (0 and 6) produced high error rates, while for AF5, the second worst case was 2. This suggests that AF5 was able to handle the features characterized by small circle shapes better than AF4, partly at the cost of the segments specific to digit 2. The difference between the errors obtained in AF4 and AF5 is the highest for 0, 9, 5, and 6; remarkably, three of these digits have shapes containing circles.

As was the case in MF4, and also in the other experiments, the lowest error rates were generated for 1 and 7. The property of these digits, which can be summarized as *having a simple shape*, seems to be fairly universal, as confirmed by the results obtained for AF4 and AF5.

### *4.2. Feature Reduction*

The experiments presented in the previous section involve local-maximum filtering, which ensures that at most 6.25% of the matrix elements are non-zero. In this section, however, results related to even higher levels of sparsity are considered. The number of zeroed elements in the encoding is increased, but exactly the same decoders, trained in Section 4.1, are used to generate the results presented below.
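The additional reduction to 3 or 5 points per matrix can be sketched as a top-k filter applied to each encoding matrix. This is an illustrative reconstruction; the paper's exact selection rule is not specified here, so keeping the k largest positive elements is our assumption:

```python
import heapq

def keep_top_k(matrix, k):
    """Zero out all but the k largest positive elements of a single
    encoding matrix (the extra sparsification of this section)."""
    flat = [(v, i, j) for i, row in enumerate(matrix)
                      for j, v in enumerate(row) if v > 0]
    kept = heapq.nlargest(k, flat)
    out = [[0.0] * len(matrix[0]) for _ in matrix]
    for v, i, j in kept:
        out[i][j] = v
    return out
```

With k = 5 and four feature matrices (MF4), a whole digit is then described by at most 20 points, as discussed later in this section.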

Figures 4–6 present the results of decoding sparse matrices for the MF4, AF4, and AF5 experiments, respectively.


**Figure 4.** The pretrained decoder from the MF4 model can be used either with the original data without a specific limit of non-zero elements in the encoding, or with modified encodings, where each matrix contains up to 3 or 5 non-zero elements. The plot presents decoding errors for images showing individual digits and for the whole test set.

**Figure 5.** The pretrained decoder from AF4 used for decoding of both the original and the highly sparse data, as in Figure 4.

**Figure 6.** The pretrained decoder from AF5 used for decoding of both the original and the highly sparse data, as in Figure 4.

As we can conclude from Figures 4–6, the AF4 experiment seems to be the most sensitive to additional thresholding, which is particularly evident in the case of encoding digits 2, 3, and 5. However, the other experiments, namely, MF4 and AF5, behave in quite a similar way, yielding slightly more than 20% greater average loss when 3 points per matrix are used, and only a few percent more in the case of 5 points.

The most remarkable phenomenon related to experiment AF5 is the sensitivity of digit 4 to sparse autoencoding. The autoencoder error increased by 5% for 5 points and above 50% for 3 points. This leads to the conclusion that digit 4 consists of a greater number of visual feature occurrences than any other digit, and omitting some of these features generates a significant error.

The most important conclusion is that further sparsity enforcement is generally acceptable, unless the features are too specific (the AF4 case) and the reduction too great (the 3-point case). With 5 points per matrix, both the MF4 (error increase: up to 4%, 3% on average) and AF5 (error increase: up to 5%, only 1% on average) models yielded acceptable results. This means that a whole 28 × 28 digit can be compressed into 20 points (in the case of MF4) or 25 points (AF5), with the encoding errors presented in Figures 4 and 6.

#### *4.3. Classification*

In order to determine how much information was preserved in the encoding, we attempted to decode the original image, as described in the previous sections. However, it is not the only possible approach. It is debatable whether the Euclidean distance between the autoencoder output and the original image may serve as a reliable tool for measuring the loss of significant information in the encoding process. However, regardless of the Euclidean distance value, the encoding can be considered as useful if it is sufficient to determine the originally encoded digit. This property can be tested in the image classification task using pre-generated encodings. Another reason for performing this experiment is the possibility to discuss the relation of our results to the numerous classification results from the literature, where a similar task was performed on the same data set.

The sparse representations obtained from the encoder (according to the description shown in Figure 2) can be used not only to decode the original digit, but also directly in the image classification task. All the experiments (MF4, AF4, and AF5) were performed on the basis of a CNN classifier architecture proposed by the authors of this study. The classifier consisted of 6 convolutional layers and two hidden fully connected layers. The last convolutional layer and the hidden fully connected ones were trained using the dropout method [34]. This approach was chosen because it provides an adequately complex classifier model, achieving good results without defining a very deep neural network, which would require a specific treatment of the vanishing gradient problem. A model with 30 convolutional filters in each layer and 500 neurons in the hidden fully connected layer was selected as the point beyond which no further extension improved the result significantly. These values imply that the trained classifier models consisted of fewer than 50,000 convolutional parameters and approximately 12 million weights in the fully connected layers. It must be emphasized that finding the optimal classifier model was not the key objective of this paper. The same classifier configuration was used for all the encodings, which makes the results well suited to comparisons between the set-ups. Further effort to optimize such a classifier remains possible, but this issue alone definitely exceeds the scope of this paper.
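The quoted parameter counts can be sanity-checked under assumed layer shapes. The 3 × 3 kernel size and the 4-channel input (the AF4 encoding) are our assumptions, not stated in the text, so this is only a plausibility check of the figures above:

```python
def conv_params(in_ch, out_ch, k):
    """Weights plus biases of one 2-D convolutional layer."""
    return in_ch * out_ch * k * k + out_ch

# Hypothetical reconstruction: six conv layers with 30 filters each.
conv_total = conv_params(4, 30, 3) + 5 * conv_params(30, 30, 3)

# Fully connected part: flattened 30-channel 28 x 28 feature map
# -> 500 hidden neurons -> 10 classes (weights and biases).
fc_total = 30 * 28 * 28 * 500 + 500 + 500 * 10 + 10

print(conv_total, fc_total)  # 41760 11765510
```

Under these assumptions the convolutional part indeed stays below 50,000 parameters and the fully connected part lands near 12 million, consistent with the figures in the text.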

For the classification tests, the data set was divided into a testing set (10,000 samples) and a training set (60,000 samples), as proposed in the original MNIST [7] database. Each classifier was tested with a representation obtained by a specific autoencoder. This architecture made it possible to perform additional experiments. Instead of a raw encoder output, where up to 49 pixels from each matrix could have positive values, manually thresholded matrices were used in order to eliminate near-zero values. The data prepared in this way are used in the same tasks as described in the previous sections. It must be emphasized that the same classifier models were used for both the original encodings and the thresholded versions. All the results are presented in Table 6.

**Table 6.** Encodings from Section 4.2 were tested in the image classification task. Each encoding was used both in the original form, obtained as a result of local-maximum filtering, and in the reduced form that guarantees additional sparsity (Section 3.4). For each encoding (MF4, AF4, and AF5) a separate classifier was trained. The results for the additionally thresholded data were generated with the same neural networks that were trained for the original encodings.


As we can deduce from Table 6, the accuracy of the classifier seems to reflect the autoencoder error from the previous tables. Thus, the MF4 results are clearly the worst—the generic features are not nearly as useful as the automatically calculated ones used in experiments AF4 and AF5. The only surprise is that the AF4 classifier on the full encoder results was the best in the whole table (98.67%)—the difference is slight, but the classifier related to AF5 made more mistakes. However, when the reduced representations are considered, the sensitivity of the AF4 representations to the additional thresholding is clearly visible, as was the case with the autoencoder. The AF5 representations, when reduced to five points per matrix, yielded results almost as good as the original encodings, providing an accuracy of 98.40%. Surprisingly, the representations reduced to 3 points per matrix, despite generating an over 20% higher autoencoder error, can still be regarded as acceptable for practical applications, as even with such drastically reduced information the classifier is able to recognize the digit correctly in 97.88% of the cases.

As the results in Table 6 are given as classification accuracies that can easily be compared with each other, we can seek comparison with other MNIST classifiers from the literature as well. However, it must be emphasized that in this paper we treat image classification merely as an analytic tool, not as the key objective. Presumably, using the raw MNIST images to train a classifier, without the added difficulty of sparse representation, could only improve the achieved accuracy. The general problem of MNIST classification can be solved with accuracy as high as 99.79% [36]. We do not aim to beat this result. For a broader perspective, we can discuss the relation of our results to other state-of-the-art MNIST classifiers that involve sparse representations in some way. Due to the varying objectives and circumstances, such comparisons require analysis that goes beyond a straightforward competition for the best accuracy.

The results from Table 6 are clearly better than the classification results obtained with the classical approach to sparse representations and dictionary learning. This includes particularly the convolutional sparse coding for classification (CSCC) method presented in [23], which achieved an accuracy of 84.5% on the MNIST data set, outperforming many previous approaches to sparse representations and dictionary learning. It must be emphasized, however, that the problem statement of that paper was not the same as ours. Dictionary-based methods are more computationally complex. Moreover, in [23], only the training subsets of 300 images were used. Thus, while that work may be regarded as an interesting reference for the present study, a direct comparison would be inappropriate.

Another remarkable work on sparse representation was based on the idea of maximizing the margin between classes in the sparse representation-based classification (SRC) task [25]. The sparse representations in that model were tied strictly to the classification task. In contrast to that approach, the method presented in this paper does not use any information on object labels when training the encoder. On the other hand, no convolutional neural networks were used in [25], and some solutions used in that paper might be outdated. The best classifier presented there reached a 98.13% accuracy rate. This result is lower than that of AF5 with the 5-point reduction, which is already very sparse.

The CNN-based architecture ensures that the image features are detected in a translation-invariant manner; a translation of a feature entails a correspondingly translated encoding. A similar concept was applied in [22], which proposed another approach to CNN-related sparse coding. The results of MNIST classification were generated for both the unsupervised and the supervised approach to sparse coding, with 97.2% and 98.9% accuracy rates, respectively. It must be emphasized, however, that the method presented in this paper should be considered unsupervised, as the autoencoder does not use information on the object labels. The size of the input data is not fixed—the method works for any input data, irrespective of the number of rows and columns. Thus, we cannot always speak of a single class that an input belongs to, and some valid input images can contain multiple digits, which makes it impossible to assign them a single label.

Results of MNIST classification that are related in some way to the idea of using sparse coding in hidden layers are also known from the works on spiking neural networks. A solution which involved weight and threshold balancing [31] performed reasonably well, resulting in a 99.14% accuracy rate for a method that combined spiking neural networks and CNNs. However, the method proposed in [31] was very complicated, and the image representations that it produced were not as sparse as those presented in this paper. Similar remarks hold with respect to the work in [28] (a non-CNN spiking network with LIF neurons) and the work in [31] (bio-inspired spiking CNNs with sparse coding), which achieved accuracies of 98.37% and 97.5%, respectively. The latter approach is particularly interesting, as it was coupled with a visual analysis of the features recognized by the neural network. The accuracy rates achieved were slightly lower than those obtained in this paper. However, the results in [31] cannot be directly compared with those achieved in this study because of differences in the proposed architectures. Moreover, the work in [31] involved an additional learning objective—the classifier was designed and trained to handle noisy input.

#### *4.4. Larger Images*

All our previous experiments were related to the original MNIST data set [7], where each sample was a 28 × 28 image displaying a single centered digit. The autoencoder was designed to encode each digit in a way that enabled as accurate a reconstruction of that digit as possible. As the solution is based on CNNs (both the encoder and the decoder, as shown in Figure 2), the whole mechanism is translation invariant—a translated digit would simply yield a translated sparse representation. What is more, as no pooling layers are used, the model can be successfully applied to images of any size. Both the matrix convolutions and the element-wise operations can still be computed.

The modified data set with larger images was prepared to illustrate this property, as shown in Figure 7. The digits were placed on an 80 × 80 plane in a greedy way, for as long as placing another non-intersecting 28 × 28 square was possible. The test set consisted of 2000 images: 68 with a single digit, 119 with two digits, 888 with three digits, and 925 with four digits each.
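A greedy generator of this kind can be sketched as follows. The placement order, the attempt limit, and the random seed are our assumptions; the original generation procedure is not specified beyond the greedy non-intersection rule:

```python
import random

def place_digits(plane=80, digit=28, limit=4, seed=0):
    """Greedily place non-intersecting digit squares on the plane:
    draw random positions and accept a 28 x 28 square whenever it
    overlaps no previously placed one."""
    rng = random.Random(seed)
    boxes = []
    for _ in range(500):                      # placement attempts
        if len(boxes) == limit:
            break
        x = rng.randrange(plane - digit + 1)
        y = rng.randrange(plane - digit + 1)
        # Two axis-aligned squares of side `digit` do not intersect
        # iff they are separated along at least one axis.
        if all(abs(x - bx) >= digit or abs(y - by) >= digit
               for bx, by in boxes):
            boxes.append((x, y))
    return boxes
```

Depending on the random draws, such a procedure naturally yields between one and four digits per 80 × 80 image, matching the composition of the test set described above.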

The features described by the proposed sparse representation are deliberately smaller than the whole digits, so our model should not be considered a digit classifier, in particular for larger, more complex images. Nevertheless, digits should be reconstructed equally well regardless of position and context. The experiment introduced in this section is intended to demonstrate this property.

Table 7 shows the average per-digit autoencoder errors for the extended data set. In the case of images with multiple digits, the error was divided by their count. A division into separate classes was impossible, as a single large image was likely to contain digits from different classes. The overall conclusion is that the MF4 model is quite sensitive to the behavior of the image boundaries and, while still useful, produces almost 10% greater errors in this atypical application. The models with automatically calculated features—AF4 and AF5—showed only a very slight error increase compared to the original task. This confirms the universal nature of the presented autoencoders. As expected, translational invariance makes it possible to describe translated objects as easily as the original inputs. The extended input sizes do not create any technical difficulties either.

**Figure 7.** In order to demonstrate that the proposed autoencoder architecture is size-independent, 80 × 80 images containing multiple digits from the MNIST data set [7] were generated. The extended images contain up to four objects.

**Table 7.** The autoencoders trained in Section 4.1 were used with larger images, as shown in Figure 7. An average per-object error was calculated and compared to the original results, related to the objects with a single centered object. The information on the relative error increase is included in the table as well.


### **5. Analysis**

The autoencoder architectures presented in Section 4.1 can be further analyzed in terms of errors and the semantic understanding of the related visual features. The results from Sections 4.1 and 4.2 can be used to compare the overall quality of the selected solutions and to analyze the dependency between the autoencoder errors and particular image classes. This can be related to how well the set of selected features is adjusted to the data set; for example, manually selected line directions are particularly ill-suited to digits 0 and 5. Another important aspect is the inner complexity of the digits. Digit 1, which usually consists of one or two segments, is especially easy to model. The exact directions of the segments, however, do not fit any of the manually selected filters. Thus, only an automatically computed feature extractor was able to take full advantage of the simplicity of this digit's shape, producing a remarkably low error rate for this class. The errors for the other nine classes differ only slightly in the case of the automatically chosen features, which means that this solution actually reflects the properties of the data set. The manually selected features were not digit-specific, which resulted in generally higher error rates and a greater variance of errors among different digits in the MF4 experiment.

In the case of the manually designed features, the convolution filters were created arbitrarily, as illustrated in Table 3. However, another way of visually presenting the features can be achieved by using the decoder part of the autoencoder architecture (see Figure 2). The results of decoding single points related to each of the manually chosen features are presented in Table 8—the relation to the filters shown in Table 3 is apparent. This visualization technique can also be applied to the models with the automatically constructed features. The results of this approach are reported in Tables 10 and 12—apparently, this time the features are implicit and difficult to categorize semantically.

As the shape of each visual feature is complex, and possibly context-dependent, we need an analysis that goes beyond simple visualization. Tables 11–13 provide information on the intensity of each feature in the data set, including both averaged results and those obtained for separate subsets consisting of different digits. The intensity is calculated as a sum of the elements of the respective encoding matrix (output of the encoder, as illustrated in Figure 2). It must be emphasized that because of different weights in neural network models related to each decoder/classifier, it is possible to compare only the values within the same table. However, the results presented still enable a thorough analysis that otherwise would be difficult to perform.
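The intensity measure used in Tables 11–13 amounts to a per-channel sum over the sparse encoding; a minimal sketch, assuming the encoding is a list of 2-D matrices, one per visual feature:

```python
def feature_intensities(encoding):
    """encoding: one 2-D matrix per visual feature (encoder output).
    The intensity of a feature is the sum of the elements of its
    encoding matrix, computed here per channel."""
    return [sum(v for row in channel for v in row)
            for channel in encoding]
```

Averaging these per-channel sums over all samples of a given digit yields the per-digit rows of the intensity tables; as noted above, the values are only comparable within a single model.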

The MF4 results (see Table 9) are especially easy to understand: as digits are generally taller than they are wide, vertical lines (feature #2) are the most prominent across the data set. Remarkably, digit 1 contains almost no other features. The least intense feature is related to the backslash segments (#4), which occur mostly as part of an arc (digits 8, 3, and 0). As expected, digit 8 is especially rich in all kinds of features. However, because of the typically skewed writing style, even in this particular case the backslash lines are less intense than segments in the other directions.

The features selected in the AF4 model are particularly interesting. It is relatively clear that no feature is dedicated solely to backslash lines, which is demonstrated in Table 10. Instead, we get feature #1 that seems to address the right side of a small arc. This is reflected in Table 11—digits 8, 6, and 3 exhibited particularly intense occurrences of this feature. While features #2 and #3 seem to reflect vertical and horizontal lines, respectively (digit 1 is strongly correlated with feature #2), the relation is not nearly as straightforward as it was in MF4. Feature #4, related to slash lines, seems to reflect a part of digit 2 especially well.

The results of the experiment AF5, which yielded the best autoencoder (Section 4.1) and classification accuracy for highly sparse data (Section 4.3), are less intuitive. The idea of horizontal lines is divided between features #2 and #4. Features #1 and #5 seem to handle both slash/backslash lines and sections of small arcs as well—both these features are important for encoding digit 8. Feature #3 clearly describes some cases of curves that are oriented vertically (digits 2, 8, and 6), but it does not involve straight segments. Digit 1 is described mostly with feature #4. Visual features detected by AF5 model are mostly implicit and difficult to describe semantically.


**Table 8.** MF4 decoder results for synthetic encoding, where only one point is activated. This is intended to show the shape of the visual features that are described by each matrix.

**Table 9.** For each MF4 visual feature, the sum of the occurrence intensities (values in the encoding) was calculated. The result was considered separately for each digit, as different digits consist of different visual features. It must be emphasized that the results in this table are relative and, while comparing them with each other is informative, they should not be compared with the results from the other tables.


**Table 10.** AF4 decoder results for synthetic encoding, where only one point is activated.


**Table 11.** The sums of the occurrence intensities (values in the encoding) for each AF4 visual feature. It must be emphasized that the results in this table are relative, so they can only be compared with the numbers from the same table.



**Table 12.** AF5 decoder results for synthetic encoding, where only one point is activated.

**Table 13.** The sums of the occurrence intensities (values in the encoding) for each AF5 visual feature. It must be emphasized that the results in this table are relative, so they can only be compared with the numbers from the same table.


The conclusions drawn from Tables 9, 11, and 13 can be further explained with suitable illustrations. Decoding a single point did not always provide a satisfactory understanding of a visual feature (as was the case with feature #3 in AF5), which motivates the search for a better visualization method. Instead, we can use the decoder on the actual multiple-pixel combinations generated for inputs from the test set. The sparse representation can be split into separate matrices related to different visual features. Decoding a single matrix can be considered a partial reconstruction, consisting only of occurrences of the corresponding feature. For example, using only the first channel of the model with manually selected features results in a partial reconstruction of a digit that consists of horizontal lines only. Additionally, in order to illustrate the limitations of the presented methods, specific digits with especially low and especially high autoencoder errors were chosen for this visualization. Selected results for the discussed models are presented in Figures 8–10.
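The partial-reconstruction mechanics (zero out all feature channels except one, then decode) can be sketched as follows. `toy_decoder` is a deliberately simplified stand-in for the trained CNN decoder, used only to make the channel-masking step concrete:

```python
import numpy as np

def partial_reconstruction(encoding, decoder, channel):
    """Decode only one feature channel: all other sparse matrices are zeroed."""
    synthetic = [m if i == channel else np.zeros_like(m)
                 for i, m in enumerate(encoding)]
    return decoder(synthetic)

def toy_decoder(matrices):
    """Toy stand-in for the CNN decoder: paints a 3x3 blob at each active point."""
    out = np.zeros_like(matrices[0], dtype=float)
    for m in matrices:
        ys, xs = np.nonzero(m)
        for y, x in zip(ys, xs):
            out[max(0, y - 1):y + 2, max(0, x - 1):x + 2] += m[y, x]
    return out

enc = [np.zeros((28, 28)) for _ in range(4)]
enc[0][5, 5] = 1.0     # one occurrence of feature #0
enc[2][20, 20] = 0.5   # one occurrence of feature #2

full = toy_decoder(enc)                              # all features decoded
only_f0 = partial_reconstruction(enc, toy_decoder, 0)  # feature #0 only
```

With the real model, `only_f0` would show only the digit segments associated with the chosen feature; note that for a nonlinear CNN decoder the per-channel reconstructions do not simply sum to the full output.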

The rows described as encoding in Figures 8–13 contain the visualizations of tuples of sparse matrices. For the sake of readability, the active elements, which are naturally rare, were magnified threefold. The features row provides information on the distribution of particular types of visual features on the image plane. This may be associated with a particular digit segment that possesses that feature. It must be emphasized, however, that the decoder output cannot be described as a sum of separate features—the CNN-based decoder function is nonlinear and non-additive. The encoded information on the presence of a selected feature indicates not only the presence of specific digit segments associated with that feature, but it may also be indicative of the absence of other features, as is the case with digit 8 (Figure 8).

A similar approach was applied to larger images. In order to include even more information in the presented visualization, images containing both digits and other symbols were used. The results are shown in Figures 11–13. The autoencoders were trained in a digit-specific way, but the test images used in this case contain other symbols as well. The selected non-digit shapes, however, generated visible errors—some segments were erroneously enlarged, merged, broken into pieces or blurred.

An additional demonstration of the proposed method on a longer text fragment, a 512 × 128 scan of a postal address, is presented in Figures 14–16. They depict the encoding contents and decoder outputs for the three proposed models. The number of active elements in the sparse encoding is proportional to the image size and is therefore visibly larger than in the other examples. Digits are clearly readable in the decoder output, as their size is roughly similar to that of the MNIST samples. Some of the most significant errors occur for pairs of letters that are especially close to each other, as visible in the second line of text.


**Figure 8.** Sample encoding of selected digits performed by MF4 model. Apart from the input and output data and highly-sparse encoding (up to 5 non-zero elements in each encoding matrix), visualizations of single features are presented. Each visualization was acquired by decoding a synthetic code, where one of the visualized matrices was used, and all the other matrices were filled with zeros.


**Figure 9.** Sample encoding of selected digits performed by AF4 model. Highly sparse encoding (up to 5 non-zero elements in each encoding matrix) was used. Feature-specific decodings are presented, similar to those in Figure 8.


**Figure 10.** Sample encoding of selected digits performed by AF5 model. Highly sparse encoding (up to 5 non-zero elements in each encoding matrix) was used. Feature-specific decodings are presented, similar to those in Figure 8.


**Figure 11.** Sample encoding of a larger image that contains both digits and other symbols, performed with the MF4 model. The presented feature-specific decodings were generated in the same way as those presented in Figure 8.


**Figure 12.** Sample encoding of a larger image that contains both digits and other symbols, performed with the AF4 model. The presented feature-specific decodings were generated in the same way as those presented in Figure 8.



**Figure 14.** Sample encoding of a postal address scan that contains both digits and letters, performed with the MF4 model. The presented feature-specific decodings were generated in the same way as those shown in Figure 8.


**Figure 15.** Sample encoding of a postal address scan that contains both digits and letters, performed with the AF4 model. The presented feature-specific decodings were generated in the same way as those shown in Figure 8.


**Figure 16.** Sample encoding of a postal address scan that contains both digits and letters, performed with the AF5 model. The presented feature-specific decodings were generated in the same way as those shown in Figure 8.

The presented illustrations involved 28 × 28 MNIST samples, 80 × 80 images with multiple symbols, and a 512 × 128 scan of a postal address. However, the method scales to input images of any size. The execution time is proportional to the number of pixels: the relations presented in Figure 17 are roughly quadratic in the image side length. The measured processing time is very small, even though a standard GPU set-up processed a single image at a time. Processing multiple images of similar size in batches would reduce the average per-image processing time even further.

**Figure 17.** Processing times of MF4, AF4, and AF5 models for square images of different sizes. As in the case of the standard training and test process, these results were achieved using GPU.

### **6. Conclusions**

This paper has presented a novel method of image content representation. In our approach, we propose to encode an image as a tuple of sparse matrices that describes the intensity and position of selected visual features occurring in the image. The method was validated through a series of experiments on the MNIST data set [7]. The presented simple variant, in which each matrix was reduced to local maxima, provided the ability to generate sparse matrices where no more than 6.25% of elements were preserved. However, as revealed by further analysis, very sparse matrices with no more than five elements preserved in each 28 × 28 matrix (which is less than 0.64% of elements preserved) were sufficient to keep a low autoencoder error rate and obtain a classifier accuracy of 98.40%.
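The sparsification described above (keeping only the strongest activations in each encoding matrix) can be sketched in a few lines. The function name and the use of NumPy are our own illustration, not the paper's code; only the top-k selection mirrors the described variant:

```python
import numpy as np

def sparsify_topk(feature_map, k=5):
    """Keep the k strongest activations of a 2-D feature map, zeroing the rest."""
    flat = feature_map.ravel()
    if k >= flat.size:
        return feature_map.copy()
    keep = np.argpartition(flat, -k)[-k:]   # indices of the k largest values
    sparse = np.zeros_like(flat)
    sparse[keep] = flat[keep]
    return sparse.reshape(feature_map.shape)

m = np.random.default_rng(0).random((28, 28))
s = sparsify_topk(m, k=5)
density = np.count_nonzero(s) / s.size      # 5 / 784, i.e., under 0.64%
```

For a 28 × 28 matrix, keeping five elements indeed preserves less than 0.64% of the entries, matching the figure quoted above.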

The application of the method to the classification task provided satisfactory results, outperforming the classical sparse coding solutions [23,25]. In our approach, the method of encoding was the same, regardless of the image class, because no class-dependent information was used in the training process. Thus, the presented method can be considered unsupervised. This is a relevant factor that has to be taken into account when comparing the results of the present study with those in [22]—our results are not as good as those of the supervised variant presented in that work but outperform those of the unsupervised variant.

The presented classifier also performs better than most solutions based on spiking neural networks [28,29]. It is not as good as some models presented in [31], but it must be emphasized that classification accuracy is not the key point of comparison in this particular case. Spiking neural networks possess different properties, such as the ability to handle noise [29]. Some of the models, such as LIF neurons [27,28] or STDP learning [29], suffer from excessive complexity or an unsatisfactory sparsity level of the generated representation (especially in [31]).

The sparse representation based on visual features made it possible to perform a detailed analysis of the nature of the selected features. We presented both context-free feature visualizations (see Tables 8, 10, and 12) and digit-specific distribution of particular features (see Figures 8–13). Additionally, per-digit statistics of feature occurrences were discussed (see Tables 9, 11, and 13).

The presented method is based solely on convolutional layers and point-wise operations (activation functions). This makes the feature detection invariant to translation. Consequently, the pretrained autoencoder can process the input images of any size, which was demonstrated on the basis of a data set with larger images, each containing multiple digits (see Section 4.4). Experiments using non-digit symbols were also performed (see Figures 11–13), in order to show that the generated features were adjusted to the selected data set—not surprisingly, the autoencoder trained in the proposed way performed markedly worse on non-digit characters.
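As a rough illustration of why a purely convolutional pipeline is size-agnostic, the sketch below applies the same 'same'-padded convolution, with a toy 3 × 3 kernel, to a 28 × 28 and an 80 × 80 image. The kernel values and function are illustrative only, not the trained model:

```python
import numpy as np

def conv2d_same(image, kernel):
    """Naive 'same' 2-D convolution with zero padding; works for any image size."""
    kh, kw = kernel.shape
    ph, pw = kh // 2, kw // 2
    padded = np.pad(image, ((ph, ph), (pw, pw)))
    out = np.empty(image.shape)
    for i in range(image.shape[0]):
        for j in range(image.shape[1]):
            out[i, j] = np.sum(padded[i:i + kh, j:j + kw] * kernel)
    return out

edge = np.array([[-1.0, 0.0, 1.0]] * 3)       # toy edge-like 3x3 filter
small = conv2d_same(np.ones((28, 28)), edge)  # MNIST-sized input
large = conv2d_same(np.ones((80, 80)), edge)  # larger input, identical weights
```

The same kernel weights slide over both inputs, so a pretrained all-convolutional autoencoder needs no architectural change to process larger images.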

The study was based on the MNIST data set [7], which is relatively easy to analyze. However, further tests are needed to verify the applicability of the method to more complex databases, such as the ImageNet data set [37] or data from the Pascal Visual Object Classes Challenge [38]. In order to be applied to images of real-life objects, the method would have to be enhanced with the ability to recognize more complex and more numerous visual features.

The different visual features used in the presented approach are sensitive to object rotation and size. In particular, the images from the MNIST data set were easy to describe with lines, and the rotation of the line was the key element that identified the features. However, the task of feature detection in different rotations can be approached in multiple ways. Recent works on CNNs have provided new advances to transformation-invariant CNNs [14,39,40]. The prospect of combining those methods with the proposed representation, resulting in additional information about orientation and scale of detected keypoints, is another promising option worth pursuing for further development.

**Author Contributions:** Conceptualization, A.T., P.T., and B.S.; methodology, P.T. and A.T.; software, P.T.; validation, P.T.; formal analysis, P.T.; investigation, P.T., A.T., and B.S.; resources, P.T. and A.T.; data curation, P.T.; writing—original draft preparation, P.T, A.T., and B.S.; writing—review and editing, P.T., A.T., and B.S.; visualization, P.T.; supervision, A.T.; project administration, A.T.; funding acquisition, A.T. All authors have read and agreed to the published version of the manuscript.

**Funding:** This project has been partly funded with support from the National Science Centre, Republic of Poland, Decision Number DEC-2012/05/D/ST6/03091.

**Conflicts of Interest:** The authors declare no conflicts of interest.

### **References**


© 2020 by the authors. Licensee MDPI, Basel, Switzerland. This article is an open access article distributed under the terms and conditions of the Creative Commons Attribution (CC BY) license (http://creativecommons.org/licenses/by/4.0/).

## *Article* **Fast Self-Adaptive Digital Camouflage Design Method Based on Deep Learning**

### **Houdi Xiao, Zhipeng Qu \*, Mingyun Lv, Yi Jiang, Chuanzhi Wang and Ruiru Qin**

School of Aeronautic Science and Engineering, Beihang University, No. 37 Xueyuan Road, Haidian District, Beijing 100191, China; xhdbuaa@buaa.edu.cn (H.X.); lv503@buaa.edu.cn (M.L.); jiangyi312@buaa.edu.cn (Y.J.); wangchuanzhi@buaa.edu.cn (C.W.); qrraxsy@buaa.edu.cn (R.Q.)

**\*** Correspondence: quzhipeng@buaa.edu.cn

Received: 20 June 2020; Accepted: 28 July 2020; Published: 30 July 2020

**Abstract:** Traditional digital camouflage is mainly designed for a single background and state. Its camouflage performance is good at the specified time and place, but it is greatly weakened as place, season, and time change. Camouflage technology that can adapt to the environment in real time is therefore the inevitable development direction of the military camouflage field. In this paper, a fast self-adaptive digital camouflage design method based on deep learning is proposed for the new generation of adaptive optical camouflage. First, we trained a YOLOv3 model that identifies four typical military targets with a mean average precision (mAP) of 91.55%. Second, a pre-trained deepfillv1 model was used to design the preliminary camouflage texture. Finally, the preliminary camouflage texture was standardized with the k-means algorithm. The experimental results show that the camouflage pattern designed by the proposed method is consistent with the background in texture and semantics and has excellent optical camouflage performance. Moreover, the whole pattern generation process takes less than 0.4 s, which meets the design requirements of future near-real-time camouflage.

**Keywords:** adaptive camouflage; convolutional neural network (CNN); k-means; object detection; image completion; machine learning; saliency detection

### **1. Introduction**

Camouflage is the most common and effective means to counter military reconnaissance [1,2]. It conceals military equipment in natural environments. With the development of camouflage technology, optical camouflage has evolved from deformable camouflage to digital camouflage [3]. However, traditional digital camouflage is mainly designed for a single background and state [4]. In the traditional design method, the colors are only the main colors of a specific environment, and the texture is formed by a non-random arrangement of a finite set of pattern templates [5]. Traditional digital camouflage is realized by coating the equipment surface with camouflage paint according to the designed texture, or by wearing or covering the equipment with fabric bearing that texture. The limitation is that camouflage performance is good at the specified time and place but is greatly weakened when the location, season, or time changes. Cross-region, multi-season, multi-period camouflage has become a new military demand for modern weapons and equipment. Therefore, camouflage technology that can change with the environment in real time has become the inevitable direction of development in the military camouflage field, and during the last decades much effort has been directed toward this goal. To realize multi-region adaptive camouflage, texture and color must not be fixed; instead, the camouflage texture must be designed in real time according to changes in the environment. Unlike traditional camouflage, the implementation requires a controllable multi-color variable material. By fabricating the controllable color-changing material into color-changing units embedded on the surface of the camouflaged target, much like the cells that make up a chameleon's skin, the whole device can be referred to as camouflage skin. An external control system controls the camouflage skin to display the designed camouflage texture.

Researchers have found inspiration in nature. Many animals have excellent camouflage abilities, such as cephalopods (e.g., squid and cuttlefish) [6–9], chameleons [10], and some insects [11]. These animals have remarkable control over their appearance (color, contrast, pattern). The principle of animal camouflage is to use the nervous system to sense changes in the environment and to control the cells of the skin so that they assume different colors and textures accordingly. Inspired by this principle, researchers have designed various types of color-changing devices and camouflage samples [12–14] that mimic the cells on the surface of an animal's skin. A single material for all colors has been reported [15–17]: an inverse polymer-gel opal prepared from an electroactive material, which changes color when stimulated by an electric field. Devices that use magnetic field stimulation to achieve color changes have also been reported [18], and a mechanical chameleon based on dynamic plasmonic nanostructures has been designed [14]. At present, most researchers focus on the technology of controllable color change; there are few reports on design methods for the new generation of self-adaptive camouflage texture [19].

As one of the key technologies of adaptive optical camouflage, the design of adaptive camouflage texture has important theoretical and practical significance. In this paper, a design method of adaptive camouflage texture for typical military targets in natural environments is studied. With the development of computer vision, deep learning has been applied to various image processing scenarios, including image classification [20–22], object detection [23–25], image segmentation [25–28], and image completion [29–31], and has achieved a series of remarkable results. However, there are few reports on applying these achievements to the field of military camouflage. We wondered whether deep learning could mimic the way an animal's nervous system senses the environment and controls skin cells to change color and texture. Hence, we propose a fast self-adaptive digital camouflage design method based on deep learning. The method recognizes the target to be camouflaged and designs an adaptive camouflage pattern in near real time. First, we used the YOLOv3 algorithm to recognize typical military targets. Second, we used the deepfillv1 algorithm for the preliminary design of the adaptive camouflage texture. Finally, the k-means algorithm was used to standardize the texture. The experimental results show that the camouflage pattern designed by the proposed method is consistent with the background in texture and semantics and has excellent optical camouflage performance. The whole process took less than 0.4 s. All experiments were implemented with Python 3.6, TensorFlow v1.6, CUDNN v7.1, and CUDA v9.2, and run on hardware with an Intel Core i7-9700F CPU (3.0 GHz) and an RTX 2080 Ti GPU. The proposed camouflage pattern design method has potential application value in future real-time optical camouflage.

#### **2. Literature Review**

Military camouflage colors and patterns have evolved throughout history to improve their effectiveness, with each variant designed for a specific environment. Therefore, camouflage patterns are only effective in areas where the local background remains relatively constant. For a military system to operate in a variety of environments, its camouflage must be adjusted accordingly [32]. As a result, researchers around the world have begun to design adaptive camouflage techniques that can change the color and texture of surfaces, like chameleons or octopuses, depending on the environment. Some researchers propose projecting the collected background image onto the surface of the target to achieve camouflage. Inami et al. [33,34] designed an active camouflage system: it first obtains a real-time background image through an image acquisition device installed on the back of the target and then projects a display scheme, calculated from the observer's perspective, onto the target surface covered with reflective materials. Morin et al. [13] used microfluidic network technology to prepare a soft camouflaging robot that could change its color, contrast, pattern, apparent shape, luminescence, and surface temperature; the researchers changed the robot's color and pattern by filling its tiny tubes with liquids of different colors. Pezeshkian et al. [32] proposed using gray-level co-occurrence matrices to synthesize a texture similar to the background, which is then displayed on the surface of a battlefield reconnaissance robot by electronic paper display technology. Inspired by the skin discoloration principle of cephalopods, Yu et al. [35] used a thermochromic material to prepare a photoelectric camouflage demonstration system that could switch between black and white. The color-changing material is colorless and transparent when the temperature is higher than 47 °C, and black when the temperature is lower than 47 °C. By controlling the temperature of each unit, the researchers could display different patterns; unfortunately, only black and white patterns are possible. Wang et al. [14] used adjustable plasmon technology to prepare a color-changing device covering the whole visible band and then developed a bionic mechanical chameleon, which could sense color changes in its environment and actively change its own color to match. However, the authors did not study the design of the camouflage texture itself. So far, researchers have focused on how to design and implement color change, but few have studied how to design appropriate adaptive camouflage textures. Existing research on camouflage texture mainly concerns traditional camouflage. For example, Xue et al. [5] designed digital camouflage textures based on recursive overlapping of pattern templates. Zhang et al. [36] proposed a digital camouflage design method based on a space color mixing model, which can simulate the color-mixing process in terms of the order, shape, and position of color-mixing spots. Jia et al. [37] proposed a camouflage pattern design method based on spot combination. The core idea of these design methods is the random or non-random arrangement of finite templates. Due to the simplicity of the templates, such traditional methods cannot meet the needs of the new generation of adaptive camouflage texture design. Therefore, it is necessary to study a design method for adaptive camouflage texture.

### **3. Methodology**

### *3.1. Outline of Proposed Method*

The essence of optical camouflage is to make it impossible for human eyes or optical cameras to distinguish a target from its environment. This is similar to target removal in image processing. Inspired by this, we applied an image completion algorithm to the design of camouflage texture. In this paper, we provide a quick method, based on convolutional neural networks, to generate adaptive digital camouflage. The flowchart of our method is shown in Figure 1. To camouflage a target, we first need to identify it. First, we used the YOLOv3 [38] algorithm to train a recognition model for four typical military targets; after tuning the hyperparameters, we obtained a model with good recognition performance. Second, we masked the target area and fed the image into the deepfillv1 [39] model, pre-trained on Places2 [40], for image completion; this step yields the preliminary camouflage texture. Third, we used the k-means algorithm to calculate the main colors of the filled area and compared each with the military standard colors; the most similar standard color was selected and substituted. This yields the final adaptive camouflage texture. The digital camouflage generated by this method is consistent with the texture of the surrounding environment and produces visually plausible camouflage pattern structures and textures.
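The standard-color substitution at the end of the pipeline reduces to a nearest-neighbour lookup in color space. A minimal sketch (the palette values below are hypothetical placeholders, not actual military standard colors):

```python
import numpy as np

def nearest_standard_color(color, palette):
    """Return the palette color closest to `color` (Euclidean distance)."""
    d = np.linalg.norm(palette.astype(float) - np.asarray(color, dtype=float), axis=1)
    return palette[int(d.argmin())]

# hypothetical 4-color "standard" palette (color triples, e.g., Lab or RGB)
palette = np.array([[60, 70, 40], [120, 110, 80], [40, 45, 35], [200, 200, 190]])
chosen = nearest_standard_color([55, 68, 42], palette)   # -> [60, 70, 40]
```

Each main color extracted by k-means would be replaced by its nearest standard color in this way before rendering the final texture.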

**Figure 1.** Outline of proposed digital camouflage design method.

### *3.2. Dataset*

In this paper, the ImageNet2012 [41] and Places2 [40] datasets were used. The ImageNet2012 [41] dataset consists of over 1.28 million images in 1000 categories, with the number of images per category ranging from 732 to 1300. Images of four typical military targets were selected from ImageNet2012 [41]: airships, aircraft carriers, tanks, and uniformed soldiers. A total of 2187 images were selected, of which 1970 were used for training and 217 for testing, with no fewer than 500 pictures in each category. Table 1 shows the number of training and test images for the different categories. Figure 2 shows one sample each of the airship, aircraft carrier, tank, and uniformed soldier images selected from ImageNet2012.

**Figure 2.** Data samples from the Imagenet2012 dataset. (**a**) airship; (**b**) aircraft carrier; (**c**) tank; (**d**) uniformed soldier.

The Places2 [40] dataset contains more than 400 different types of scene environments and 10 million images, covering most everyday scenes. Figure 3 shows one sample each of the forest, desert, grassland, and snowfield environment images selected from Places2.


**Table 1.** Details of the training and test set.

**Figure 3.** Data samples from the Places2 dataset. (**a**) forest; (**b**) desert; (**c**) grassland; (**d**) snowfield.

### *3.3. Military Target Detection Based on YOLOv3*

To achieve camouflage of a target, we first needed to identify it. The YOLOv3 [38] algorithm was used to identify military targets. The YOLO family of algorithms detects objects quickly [38,42,43], and YOLOv3 is the latest version [38]. YOLOv3 achieves high-precision real-time detection, which is very well suited to our application. Its network structure is shown in Figure 4. The resolution of the input picture in the network structure diagram is 416 × 416 × 3 (in fact, any resolution can be used), and there are four labeled classes. It uses darknet-53, with the fully connected layer removed, as the backbone network. YOLOv3 is a fully convolutional network that makes extensive use of residual structures. As shown in Figure 4, YOLOv3 consists of DBL, resn, up-sample, and concat blocks. DBL stands for convolution (conv), batch normalization (BN), and Leaky ReLU activation. Resn represents the *n* residual units (res unit) in a residual block (res\_block). Zero padding fills the edge of the image with zeros, up-sample denotes up-sampling, concat merges tensors, DBL\*n denotes *n* consecutive DBL blocks, and add denotes element-wise addition. The outputs y1, y2, and y3 are feature maps at three different scales.

The network structure of darknet-53 is shown in Table 2. It uses successive 3 × 3 and 1 × 1 convolutional layers and some shortcut connections. The application of the shortcut connection layer allows the network to be deeper. It has 53 convolutional layers [38].

**Figure 4.** YOLOv3 network structure.



The loss function of YOLOv3 consists of the localization loss $Loss_l$, the confidence loss $Loss_c$, and the classification loss $Loss_p$:

$$Loss = Loss\_l + Loss\_c + Loss\_p \tag{1}$$

$$Loss_l = \lambda_{coord} \sum_{i=0}^{S^2} \sum_{j=0}^{B} I_{ij}^{obj} \left[ \left( x_i^j - \hat{x}_i^j \right)^2 + \left( y_i^j - \hat{y}_i^j \right)^2 \right] + \lambda_{coord} \sum_{i=0}^{S^2} \sum_{j=0}^{B} I_{ij}^{obj} \left[ \left( \sqrt{w_i^j} - \sqrt{\hat{w}_i^j} \right)^2 + \left( \sqrt{h_i^j} - \sqrt{\hat{h}_i^j} \right)^2 \right] \tag{2}$$

$$Loss_c = -\sum_{i=0}^{S^2} \sum_{j=0}^{B} I_{ij}^{obj} \left[ \hat{C}_i^j \log C_i^j + \left( 1 - \hat{C}_i^j \right) \log\left( 1 - C_i^j \right) \right] - \lambda_{noobj} \sum_{i=0}^{S^2} \sum_{j=0}^{B} I_{ij}^{noobj} \left[ \hat{C}_i^j \log C_i^j + \left( 1 - \hat{C}_i^j \right) \log\left( 1 - C_i^j \right) \right] \tag{3}$$

$$Loss_p = -\sum_{i=0}^{S^2} I_{ij}^{obj} \sum_{c \in classes} \left[ \hat{p}_i^j(c) \log p_i^j(c) + \left( 1 - \hat{p}_i^j(c) \right) \log\left( 1 - p_i^j(c) \right) \right] \tag{4}$$

where $\lambda_{coord} = 5$; $I_{ij}^{obj} = 1$ when the $j$-th bounding box in cell $i$ is responsible for detecting the object, and 0 otherwise; $x$, $y$, $w$, $h$ denote the bounding box parameters; $C_i^j$ is the confidence score of box $j$ in cell $i$; $\lambda_{noobj} = 0.5$; $I_{ij}^{noobj}$ is the complement of $I_{ij}^{obj}$; and $p_i^j(c)$ denotes the conditional class probability of box $j$ in cell $i$ for class $c$. Symbols with a hat represent the corresponding ground truth (GT) values.

The learning rate adopts cosine attenuation:

$$\alpha_{decayed} = \alpha_{end} + 0.5 \left( \alpha_{initial} - \alpha_{end} \right) \left( 1 + \cos\left( \pi \, s_{global} / s_{train} \right) \right) \tag{5}$$

where $\alpha_{decayed}$ denotes the decayed learning rate, $\alpha_{initial}$ the initial learning rate, $\alpha_{end}$ the final learning rate, $s_{train}$ the total number of training steps, and $s_{global}$ the current global step.
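Equation (5) translates directly into code; a small sketch (the function name is ours, not part of the paper's implementation):

```python
import math

def cosine_decayed_lr(initial_lr, end_lr, global_step, train_steps):
    """Cosine learning-rate attenuation, as in Eq. (5)."""
    cosine = 0.5 * (1.0 + math.cos(math.pi * global_step / train_steps))
    return end_lr + (initial_lr - end_lr) * cosine

start = cosine_decayed_lr(1e-3, 1e-6, 0, 1000)     # equals initial_lr
mid = cosine_decayed_lr(1e-3, 1e-6, 500, 1000)     # roughly halfway
end = cosine_decayed_lr(1e-3, 1e-6, 1000, 1000)    # equals end_lr
```

The schedule starts at the initial rate, decays smoothly, and flattens out near the final rate at the end of training.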

More details about YOLOv3 can be found in reference [38]. The basic code is available online at https://github.com/YunYang1994/tensorflow-yolov3. Thanks to Yun Yang for sharing the code.

### *3.4. Preliminary Camouflage Texture Design Based on Deepfillv1*

The essence of optical camouflage is to make it impossible for human eyes or optical cameras to distinguish a target from its environment. This is similar to target removal in image processing. Inspired by this, we applied the image completion algorithm to the design of camouflage texture. This method could be used to design the camouflage pattern consistent with the real-time background texture.

Deepfillv1 is a generative image inpainting model based on a contextual attention mechanism [39]. It can quickly generate novel image structures consistent with the surrounding environment. The framework of deepfillv1 is shown in Figure 5. Deepfillv1 consists of two stages. The first stage is a simple dilated convolutional network, trained with a reconstruction loss, that roughs out the missing contents. The second stage trains the contextual attention layer. The core idea is to use the features of known image patches as convolution kernels to process the generated patches and thus refine the initially blurry repair results. It is implemented with convolution for matching generated patches with known contextual patches, channel-wise softmax to weigh relevant patches, and deconvolution to reconstruct the generated patches from contextual patches. A spatial propagation layer is used to improve the spatial consistency of the attention module. In order to make the network produce novel contents, another convolution path parallel to the contextual attention path is added; the two paths are merged and fed to a single decoder for the final output. The entire network is trained end-to-end. The coarse network is trained explicitly with the reconstruction loss, while the refinement network is trained with the reconstruction loss as well as global and local Wasserstein GAN with gradient penalty (WGAN-GP) [44,45] losses. The reconstruction loss is a spatially weighted, pixel-wise $l_1$ loss, where the weight of each pixel is computed as $\gamma^l$, $l$ being the distance of the pixel to the nearest known pixel, and $\gamma$ is set to 0.99. The WGAN-GP uses the Earth-Mover distance $W(P_r, P_g)$ to compare the generated and real data distributions. Its objective function is constructed by applying the Kantorovich–Rubinstein duality:

$$\min_{G} \max_{D \in \mathcal{D}} \mathbb{E}_{\mathbf{x} \sim P_r}\left[ D(\mathbf{x}) \right] - \mathbb{E}_{\tilde{\mathbf{x}} \sim P_g}\left[ D(\tilde{\mathbf{x}}) \right] \tag{6}$$

where $\mathcal{D}$ is the set of 1-Lipschitz functions and $P_g$ is the model distribution implicitly defined by $\tilde{\mathbf{x}} = G(\mathbf{z})$, with $\mathbf{z}$ being the input to the generator.

**Figure 5.** The framework of deepfillv1.

The Earth-Mover distance *W*(*P<sub>r</sub>*, *P<sub>g</sub>*) is defined as follows:

$$W(P_r, P_g) = \inf_{\gamma \in \Pi(P_r, P_g)} \mathbb{E}_{(\mathbf{x}, \mathbf{y}) \sim \gamma}\left[\left\| \mathbf{x} - \mathbf{y} \right\|\right] \tag{7}$$

where Π(*P<sub>r</sub>*, *P<sub>g</sub>*) denotes the set of all joint distributions γ(**x**, **y**) whose marginals are *P<sub>r</sub>* and *P<sub>g</sub>*, respectively.

More details about deepfillv1 can be found in reference [39]. The basic code is available online at https://github.com/JiahuiYu/generative\_inpainting; we thank Jiahui Yu for sharing it.
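The spatially discounted reconstruction weight described above can be sketched in a few lines. The following is a minimal NumPy version (the function name and the brute-force nearest-pixel search are ours; the official implementation uses a precomputed weight mask for efficiency):

```python
import numpy as np

def spatial_discount_mask(mask, gamma=0.99):
    """Weight each missing pixel by gamma**l, where l is the distance to the
    nearest known pixel; known pixels keep weight 1 (they are excluded from
    the hole loss anyway).

    mask: 2-D array, 1 = missing (hole) pixel, 0 = known pixel.
    """
    known = np.argwhere(mask == 0)            # coordinates of known pixels
    weights = np.ones(mask.shape)
    for i, j in np.argwhere(mask == 1):       # brute force; fine for small masks
        l = np.sqrt(((known - (i, j)) ** 2).sum(axis=1)).min()
        weights[i, j] = gamma ** l
    return weights
```

A hole pixel adjacent to the known region thus gets weight 0.99, while pixels deeper inside the hole are discounted more strongly, so the *l*<sup>1</sup> loss concentrates on content near the hole boundary.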

#### *3.5. Standardization of Camouflage Texture Based on K-Means*

The initial camouflage texture generated in the previous step, although visually well integrated into the background, cannot be applied directly to actual camouflage. On the one hand, it contains too many colors, which makes it difficult to manufacture in practical projects; on the other hand, the patterns generated by this method change with the environment, which further increases the number of colors. It is therefore necessary to select a limited number of representative colors, according to certain standards, to replace similar colors, striking a balance between camouflage performance and engineering practicality. We call this process the standardization of the camouflage texture.

Based on the traditional digital camouflage color extraction method, we used the k-means clustering algorithm to extract the main colors of the camouflage area. Note that the primary color region extracted here differs from that of the traditional method: we extract colors from the preliminarily designed camouflage area, whereas the traditional method extracts them from the whole background. The flowchart of the primary color extraction process is shown in Figure 6. The extracted primary colors must also meet the following restrictions [5]:


The red-green-blue (RGB), hue-saturation-value (HSV), and Lab color spaces are commonly used in image processing. The RGB color space is device-dependent and does not reflect the true nature of human vision. The Lab model, in contrast, is a device-independent color system based on physiological characteristics; it is a digital way of describing human vision. In this paper, we chose the Lab color space because it mimics the human visual system more closely.

**Figure 6.** Standardization of camouflage texture results.

Figure 6 shows the standardization process for camouflage textures. First, we converted the filled area from the RGB color space to the Lab color space. Second, we initialized the cluster centroids C and the number of clusters K; normally, C is set randomly and K is set to 4 or 5. Third, the pixels in the filled area of the image were classified into categories according to their distances from the cluster centroids. Then, the most representative color in each pixel category was selected as the representative color of that category. According to the geographical environment of the background, digital camouflage is usually divided into four types: woodland, desert, ocean, and urban camouflage. Based on a large amount of data and actual production experience, standard primary colors have been determined for each type of digital camouflage. In this study, after determining the representative colors of the target area, we selected the standard color closest to each representative color as a primary color for designing the target camouflage. Finally, we replaced the colors of the pixels in the filled area with the standard colors to obtain the self-adaptive digital camouflage texture.
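The clustering and color-replacement steps above can be sketched as follows, assuming the pixels have already been converted to Lab. This is a minimal NumPy sketch (the function names and the toy handling of the standard palette are ours; a production version would use an optimized k-means implementation):

```python
import numpy as np

def extract_primary_colors(pixels, k=5, iters=20, seed=0):
    """Plain k-means over an (N, 3) array of Lab pixels; returns the k cluster
    centroids, i.e., the representative colors of the filled area."""
    rng = np.random.default_rng(seed)
    centroids = pixels[rng.choice(len(pixels), k, replace=False)].astype(float)
    for _ in range(iters):
        # assign each pixel to its nearest centroid (Euclidean distance in Lab)
        d = np.linalg.norm(pixels[:, None, :] - centroids[None, :, :], axis=2)
        labels = d.argmin(axis=1)
        for c in range(k):
            if (labels == c).any():
                centroids[c] = pixels[labels == c].mean(axis=0)
    return centroids

def to_standard_colors(centroids, standard_palette):
    """Replace each representative color with its nearest standard color."""
    d = np.linalg.norm(centroids[:, None, :] - standard_palette[None, :, :], axis=2)
    return standard_palette[d.argmin(axis=1)]
```

Each pixel in the filled area is then painted with the standard color assigned to its cluster, yielding the standardized texture.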

We used the Euclidean distance to calculate the distance between the representative colors and the standard colors. The distance of one color pair *d*(*r*,*s*) is computed as:

$$d(r,s) = \sqrt{(L\_r - L\_s)^2 + (a\_r - a\_s)^2 + (b\_r - b\_s)^2} \tag{8}$$
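Equation (8) translates directly into code; a minimal version (function name ours):

```python
import math

def color_distance(c_r, c_s):
    """Euclidean distance between two Lab colors, per Equation (8)."""
    (L_r, a_r, b_r), (L_s, a_s, b_s) = c_r, c_s
    return math.sqrt((L_r - L_s) ** 2 + (a_r - a_s) ** 2 + (b_r - b_s) ** 2)
```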

### **4. Results**

#### *4.1. Military Target Detection*

We used the method described in Section 3.3 to detect typical military targets. Images of four typical military targets were selected from ImageNet2012 [41] and segmented into a training and a testing set. The four typical targets are airships, aircraft carriers, tanks, and uniformed soldiers. A total of 2187 images were selected, of which 1970 were used for training and 217 for testing. Each category contained no fewer than 500 images. Unless otherwise noted, all original images of the four typical targets in this article are from ImageNet2012.

The initial training parameters are shown in Table 3. *IOUthreshold* is the intersection over union (IOU) threshold. We used multi-scale training methods.



We used k-means clustering to determine our nine bounding box priors. On the selected dataset, the nine clusters were: (55 × 69), (151 × 91), (84 × 261), (200 × 188), (331 × 137), (179 × 346), (358 × 223), (350 × 303), (373 × 387).
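The nine priors above come from k-means over box shapes. A common choice, following YOLO, is the distance d = 1 − IoU between (width, height) pairs anchored at a shared corner, so that large and small boxes cluster on shape rather than absolute error. A minimal sketch (function names, initialization, and iteration count are illustrative, not the authors' exact procedure):

```python
import numpy as np

def iou_wh(boxes, centroids):
    """IoU between (w, h) pairs, treating all boxes as sharing a corner."""
    inter = (np.minimum(boxes[:, None, 0], centroids[None, :, 0]) *
             np.minimum(boxes[:, None, 1], centroids[None, :, 1]))
    areas = boxes[:, 0] * boxes[:, 1]
    c_areas = centroids[:, 0] * centroids[:, 1]
    return inter / (areas[:, None] + c_areas[None, :] - inter)

def anchor_kmeans(boxes, k=9, iters=30, seed=0):
    """Cluster (w, h) box shapes with distance d = 1 - IoU."""
    rng = np.random.default_rng(seed)
    centroids = boxes[rng.choice(len(boxes), k, replace=False)].astype(float)
    for _ in range(iters):
        labels = (1.0 - iou_wh(boxes, centroids)).argmin(axis=1)
        for c in range(k):
            if (labels == c).any():
                centroids[c] = boxes[labels == c].mean(axis=0)
    return centroids
```

Running this over the training-set ground-truth boxes yields shape priors analogous to the nine clusters listed above.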

We used the weights pre-trained on the COCO dataset as the initialization weights. After 17 k training steps, the model converged. The training loss is shown in Figure 7. As shown in Figure 7a, after 17 k steps of training, the loss was reduced to 0.6, which is basically convergent. The mean average precision (mAP) at IOU = 0.5 (mAP@.5) was 91.55%, as shown in Figure 7b. The results show that the trained model had high precision and could meet our application requirements. The total loss was calculated on the training set, and the mAP on the test set.

**Figure 7.** The total training loss and mean average precision. (**a**) the total loss; (**b**) the mAP.

Through training, the recognition precision for these four typical targets in different environments reached 91.55% (mAP@.5 = 91.55%), which met our application requirements. Moreover, this recognition process is only a demonstration: in practical applications, specific databases could be added according to actual needs, and the number of recognition classes and the precision could be increased through retraining. The recognition results of the trained model are shown in Figure 8. As can be seen from Figure 8, our model identified the four typical military targets well. When the resolution of the input picture was 416 × 416, the model detection time was less than 25 ms.

**Figure 8.** Object detection results. (**a**) airship; (**b**) aircraft carrier; (**c**) tank; (**d**) uniformed soldier.

### *4.2. Preliminary Camouflage Texture*

We used the method described in Section 3.4 to design the initial camouflage texture. First, we masked the target region detected by YOLOv3, as shown in Figure 9b. Then, we input the masked image into the pre-trained deepfillv1 model. The weights adopted by deepfillv1 were trained on the Places2 [40] dataset. Places2 is a dataset of scene images containing 10 million pictures and more than 400 different types of scene environments; it supports visual cognitive tasks whose application content is scenes and environments, which meets our application requirements. We used the weight files from the literature [39] that were pre-trained on Places2 [40]. Finally, the complete image was obtained, which we call the preliminary camouflage texture, as shown in Figure 9c. It can be seen from Figure 9 that the generated preliminary camouflage pattern has a texture consistent with the environment and blends into it well. When the input image resolution is 416 × 416, the preliminary camouflage texture generation process takes less than 0.2 s.

**Figure 9.** Preliminary camouflage texture design. (**a**) original image; (**b**) masked image; (**c**) completed image.

More detailed experimental results are shown in Figure 10, where the first column shows the original images, with rectangular bounding boxes indicating the targets to be camouflaged, and the second column shows the results of the preliminary camouflage texture design.

**Figure 10.** More detailed experimental results. (**a**) airship; (**b**) preliminary camouflage texture of the airship; (**c**) aircraft carrier; (**d**) preliminary camouflage texture of the aircraft carrier; (**e**) tank; (**f**) preliminary camouflage texture of the tank; (**g**) uniformed soldiers; (**h**) preliminary camouflage texture of the uniformed soldiers.

### *4.3. Standardization of Camouflage Texture*

We standardized the initial camouflage texture using the method described in Section 3.5. As shown in Figure 11b, the camouflage texture we designed corresponds to the rectangular area where the target is located. Note that K is set to 5 throughout the rest of this article. In addition, camouflage textures need to be designed for the camouflaged target's forward, backward, left, right, and top views, respectively. In practice, the texture of the camouflage area must be mapped onto the actual target surface. This step can be accomplished with 3D rendering software, such as Maya or OpenGL, and is not described in this article, which focuses on the design approach. Here, we simply mapped the camouflage texture onto the camouflaged target surface through a mask to observe its camouflage effect, as shown in Figure 11c. The output camouflage textures in Figure 11c have an overall optical camouflage effect; their textures and semantics are consistent with the environment and look very natural.
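The mask-based mapping used here amounts to simple compositing: inside the mask, take the texture pixel; outside it, keep the original image. A minimal sketch (function name ours):

```python
import numpy as np

def apply_texture(image, texture, mask):
    """Paste the camouflage texture into the target region: where the binary
    mask is 1, use the texture pixel; elsewhere keep the original pixel."""
    region = mask.astype(bool)
    out = image.copy()
    out[region] = texture[region]
    return out
```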

**Figure 11.** Adaptive digital camouflage texture. (**a**) Original image; (**b**) Rectangular texture area; (**c**) Camouflaged images using textures.

More detailed experimental results are shown in Figure 12. The camouflage texture generation process in this paper differs from the traditional one. A traditional camouflage texture is obtained through a random or non-random distribution of a finite set of pattern templates or structural elements. In this paper, the camouflage texture is generated by image completion, producing a texture consistent with the current environment. The texture composition has no fixed structural elements: it may be random for a natural environment such as a forest, or regular for an artificial environment such as a city, depending on the state of the current environment. The texture features come from extensive training on the Places2 [40] dataset using the deepfillv1 [39] algorithm. The Places2 [40] dataset contains more than 400 different types of scene environments and 10 million images, essentially covering people's common scenes. After training on Places2 [40], the deepfillv1 [39] algorithm is able to generate a meaningful image consistent with the background texture of the incomplete image. As shown in Figure 12, camouflage textures were designed with this method on both natural and artificial backgrounds. In Figure 12, the first column shows the original images in natural and artificial environments, and the second column shows the corresponding camouflaged images; the color of the camouflage texture in the second column was not replaced by the standard colors. The third column shows the camouflaged images with textures in the standard colors. The first row shows the camouflage texture in the natural environment using our method, and the second row shows the camouflage texture in the artificial environment.
As can be seen from Figure 12, the camouflage texture designed using our method has an excellent camouflage effect in both natural and artificial environments. The camouflage texture in the natural environment is irregular, while the camouflage texture in the artificial environment has a certain regularity, consistent with the current environment. Comparing Figure 12e,f, we find that the camouflage performance decreased after filling the camouflage texture with the most similar standard colors. This is because we can choose from only the 30 colors specified in the standard, which differ slightly from the main colors of the current environment. This does not affect the validity of the design method we provide. With the development of controllable color-change technology, far more than 30 colors may be available in the future, and the camouflage performance of textures designed with this method would then improve further.

When the input image resolution was 416 × 416, the standardization of the camouflage texture took 0.1 s. All tests were implemented in Python 3.6 with TensorFlow v1.6, cuDNN v7.1, and CUDA v9.2, and run on hardware with an Intel Core i7-9700F CPU (3.0 GHz) and an RTX 2080 Ti GPU. The whole camouflage texture generation process took less than 0.4 s. This time could be shortened significantly with improvements in image completion algorithms or hardware performance, and we firmly believe that real-time methods will be proposed in the near future. This shows that the method provided in this paper can generate camouflage texture in near real time, which could be used in the adaptive camouflage design of future combat equipment and personnel.

**Figure 12.** Camouflage textures in natural and artificial environments. (**a**) the original image in natural environment; (**b**) the camouflaged image corresponding to (**a**) whose color wasn't replaced by standard color; (**c**) the camouflaged image corresponding to (**a**) whose color was replaced by standard color; (**d**) the original image in artificial environment; (**e**) the camouflaged image corresponding to (**d**) whose color wasn't replaced by standard color; (**f**) the camouflaged image corresponding to (**d**) whose color was replaced by standard color.

#### **5. Discussion**

Visual saliency is the perceptual quality that makes an object, person, or pixel stand out relative to its neighbors and thus capture our attention. It is therefore reasonable and effective to evaluate the camouflage performance of a camouflaged target with a saliency detection algorithm, as has been done in many studies [5,46,47]. In this paper, the frequency-tuned salient region detection (FT) [48] algorithm, a classic saliency detection algorithm, was used to quantitatively evaluate the performance of the camouflage texture. The saliency map of the camouflaged target was obtained by FT detection. The higher the saliency value, the more conspicuous the foreground target and the weaker the camouflage effect.
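FT saliency reduces to a per-pixel color distance: the saliency of a pixel is the Euclidean distance, in Lab space, between the (slightly blurred) pixel and the mean image color. A minimal NumPy sketch (the 3 × 3 box blur here stands in for the Gaussian blur of the original algorithm, and the function name is ours):

```python
import numpy as np

def ft_saliency(lab_image):
    """Frequency-tuned saliency: per-pixel Euclidean distance between the
    blurred pixel and the mean Lab color of the whole image."""
    h, w, _ = lab_image.shape
    mean_color = lab_image.reshape(-1, lab_image.shape[-1]).mean(axis=0)
    # cheap 3x3 box blur standing in for the Gaussian blur of the original
    padded = np.pad(lab_image, ((1, 1), (1, 1), (0, 0)), mode="edge")
    blurred = sum(padded[i:i + h, j:j + w]
                  for i in range(3) for j in range(3)) / 9.0
    return np.linalg.norm(blurred - mean_color, axis=-1)
```

A well-camouflaged target produces low values in this map, i.e., its colors stay close to the global mean of the scene.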

To verify the validity and effectiveness of our proposed design method, we compared the saliency map of the camouflage texture designed using the traditional method in the literature [5] with that of the camouflage texture designed using our method, as shown in Figure 13. There are five images in Figure 13: (a) shows an original image with the foreground targets highlighted with red rectangles; (b) shows the results of camouflaging the targets using the camouflage texture designed by the existing method [5]; (c) shows the results of manually camouflaging the targets using the camouflage texture designed by our method; (d) shows the saliency map corresponding to image (b); and (e) shows the saliency map corresponding to image (c). The camouflage texture designed with our method clearly has better camouflage performance: in Figure 13d, the target contour and the mosaic-like stripes of the camouflage texture can be clearly distinguished, while in Figure 13e, the camouflaged target can hardly be distinguished. Note that all images above have the same resolution. This difference arises because the camouflage texture designed by the method in the literature [5] is inconsistent with the background texture and semantics, while the texture designed by our method is consistent with both. This is because we are able to learn features from the surrounding environment using the deepfillv1 algorithm, whereas the existing method relies on texture templates based on empirical design.

**Figure 13.** Comparison between our method and an existing method. (**a**) original image; (**b**) traditional camouflage texture; (**c**) adaptive camouflage texture; (**d**) the saliency map corresponding to image (**b**); (**e**) the saliency map corresponding to image (**c**).

As shown in Figure 14, in order to evaluate the camouflage performance of the camouflage texture designed by our method, we used the FT saliency algorithm to calculate the saliency maps of the images before and after camouflage. In Figure 14, the first column shows the original images, the second column shows the saliency maps corresponding to the first column, the third column shows the images camouflaged using textures designed by our method, and the fourth column shows the saliency maps corresponding to the third column. As can be seen from Figure 14, the designed camouflage texture satisfies the color conditions: (1) the main colors have different brightness, and (2) the main colors do not differ much from the background colors. At the same time, the camouflage texture and the background have texture and semantic consistency. Therefore, the camouflage texture designed with our method blends well into the background.

**Figure 14.** Evaluation of the camouflage performance using saliency maps.

In contrast with the original images, the foreground targets are almost indistinguishable in the camouflaged images' saliency maps, suggesting that the camouflaged targets blend well into the background and are invisible both to human eyes and to visible-light reconnaissance equipment. At the same time, we fed the camouflaged images into the previously trained YOLOv3 model for re-identification. The results show that the targets in the camouflaged images could not be recognized by the YOLOv3 model. The experimental results show that adaptive digital camouflage with excellent camouflage performance can be designed quickly with our method.

### **6. Conclusions**

Adaptive optical camouflage technology is the inevitable direction of future optical camouflage development. As one of its key technologies, the study of adaptive camouflage texture design has important theoretical and practical significance. In this paper, a fast self-adaptive digital camouflage design method based on neural networks is proposed for the new generation of self-adaptive optical camouflage. First, we used the YOLOv3 algorithm to train a recognition model for four typical military targets. After tuning the hyper-parameters, we obtained a model with good recognition performance, whose mean average precision (mAP) was 91.6%. Then, we used the deepfillv1 algorithm to perform the preliminary camouflage texture design for the recognized area. Finally, a clustering algorithm was used to extract the main colors of the camouflage target region, and the most similar standard colors were used to standardize the colors of the initial texture. The camouflage texture designed by our method is consistent with the texture and semantics of the real-time background. The whole texture generation process is very short, less than 0.4 s, which could meet the requirements of near-real-time camouflage design in the future. The saliency detection results showed that the camouflage texture generated by the proposed method provides good optical camouflage performance. At present, the method is effective for camouflage design in forests, grasslands, deserts, and other natural environments, but in artificial environments, such as urban environments, the camouflage design effect is not yet ideal. In addition, not many typical target images are available, and the relevant datasets need to be collected further. In the future, we will, on the one hand, study how to improve the camouflage design performance of this method in artificial environments; on the other hand, we will further study the implementation of adaptive camouflage systems, such as their control systems. Nevertheless, this paper proposes and implements a new idea for adaptive camouflage texture design, which has important potential application value in future real-time optical camouflage.

**Author Contributions:** Conceptualization, H.X.; Data curation, H.X.; Formal analysis, H.X.; Investigation, Z.Q.; Methodology, H.X. and Z.Q.; Project administration, M.L.; Supervision, M.L.; Visualization, Y.J., C.W. and R.Q.; Writing—original draft, H.X.; Writing—review & editing, Y.J. All authors have read and agreed to the published version of the manuscript.

**Funding:** This research received no external funding.

**Conflicts of Interest:** The authors declare no conflict of interest.

### **References**


© 2020 by the authors. Licensee MDPI, Basel, Switzerland. This article is an open access article distributed under the terms and conditions of the Creative Commons Attribution (CC BY) license (http://creativecommons.org/licenses/by/4.0/).

## *Article* **Remote Sensing Scene Classification and Explanation Using RSSCNet and LIME**

### **Sheng-Chieh Hung 1, Hui-Ching Wu <sup>2</sup> and Ming-Hseng Tseng 1,3,4,\***


Received: 23 July 2020; Accepted: 2 September 2020; Published: 4 September 2020

**Abstract:** Classification of remote sensing imagery is needed in disaster investigation, traffic control, and land-use resource management, and how to classify such imagery quickly and accurately has become a popular research topic. However, training classifiers with large, deep neural network models in the hope of obtaining good classification results is often very time-consuming. In this study, a new CNN (convolutional neural network) architecture, RSSCNet (remote sensing scene classification network), with high generalization capability was designed. Moreover, a two-stage cyclical learning rate policy and a no-freezing transfer learning method were developed to speed up model training and enhance accuracy. In addition, the manifold learning t-SNE (t-distributed stochastic neighbor embedding) algorithm was used to verify the effectiveness of the proposed model, and the LIME (local interpretable model-agnostic explanations) algorithm was applied to analyze the cases where the model made wrong predictions. Comparing the results on three publicly available datasets with those obtained in previous studies, the experimental results show that the model and method proposed in this paper achieve better scene classification more quickly and more efficiently.

**Keywords:** neural network; deep learning; cyclical learning rate; remote sensing; scene classification

### **1. Introduction**

With the gradual advancement of technology today, smart mobile devices and aerial cameras are beginning to appear on the market. As the performance of the hardware improves, aerial photography technology is constantly improving along with it, and rapid breakthroughs in imaging technology have made it possible to acquire imagery quickly. There are more types of imagery than ever before and the imagery is also clearer than was possible previously. Remote sensing images can be used in many technical aspects of scene classification, such as land-cover detection, urban planning, disaster relief, traffic control, etc. Therefore, how to classify large amounts of remote sensing image data covering land areas is an important research topic [1–8].

In the past, when using image data for classification research, the primary focus was on how to effectively perform the task of feature extraction [9,10]. The deep-learning methods used today make use of different convolutional neural network models to automatically perform feature recognition, extract the required image details, and then train the classification model to recognize the scene. The popularization of graphics cards in recent years has reduced the amount of time spent training deep neural network models. It has also enabled the rapid development of deep-learning research and studies related to the classification of high-resolution remote sensing images [11–13].

Related parameters used in deep-learning models include the training methods, network architecture, optimizer design, hardware operation, etc. Adjusting the model training hyper-parameters is a very important aspect of the model design. Smith [14] proposed a new learning-rate method; instead of a fixed value, a cyclical learning policy was set for the model being trained. The results showed that this could effectively reduce the number of training iterations and improve the accuracy of the classification. Smith and Topin [15] then proposed a new super-convergence one-cycle cyclical learning policy and suggested the usage of a large learning rate for model training, which, according to them, can improve the model's generalization capability. More recently, Leclerc and Madry [16] explored the impact of different learning rates on deep learning and found that the low and high values of cyclical learning rates [14,15] concur with their two regimes.
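Smith's original triangular cyclical learning-rate policy [14], on which the two-stage policy of this paper builds, can be sketched in a few lines (the function name and parameter values below are illustrative):

```python
def triangular_clr(step, step_size, base_lr, max_lr):
    """Triangular cyclical learning rate (Smith): the rate ramps linearly
    from base_lr up to max_lr and back down over one cycle of
    2 * step_size training steps."""
    cycle = step // (2 * step_size) + 1
    x = abs(step / step_size - 2 * cycle + 1)   # 1 at cycle ends, 0 at midpoint
    return base_lr + (max_lr - base_lr) * max(0.0, 1.0 - x)
```

The schedule thus oscillates between a low and a high rate, matching the two regimes discussed in [16], instead of holding a single fixed value.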

Training is the process of learning and adjusting the parameters of a model. During training, iteration methods such as the gradient descent learning algorithm are commonly used. Early stopping [17] is a method often used during training to prevent overfitting. This method calculates the accuracy of the test dataset during model training. When the accuracy of the test data no longer improves, the training stops and is terminated early. This not only helps to prevent overfitting of the training model but also improves the model's generalization ability.
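The early stopping rule described above can be sketched as a small stateful helper (a minimal illustration; class and attribute names are ours, and real frameworks such as Keras expose an equivalent callback):

```python
class EarlyStopping:
    """Stop training once the monitored accuracy has not improved for
    `patience` consecutive evaluations; keeps the best value seen so far."""
    def __init__(self, patience=5):
        self.patience = patience
        self.best = float("-inf")
        self.bad_checks = 0

    def step(self, accuracy):
        """Record one evaluation; return True when training should stop."""
        if accuracy > self.best:
            self.best = accuracy
            self.bad_checks = 0
        else:
            self.bad_checks += 1
        return self.bad_checks >= self.patience
```

The training loop calls `step` after each evaluation and restores the weights associated with `best` when it returns `True`.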

This study aimed to produce an improved model with enhanced detection ability using a small number of training iterations. A two-stage cyclical learning method for training was, therefore, proposed. First of all, the image features were obtained by transfer learning; the classification framework designed in this paper was then used for training. The two-stage cyclical learning rate was used to reduce the number of iterations and, finally, the best model weights were obtained using the early stopping method. The main contributions of this paper are as follows:


The remainder of this paper is organized as follows: the dataset is introduced in the second section. In the third section, the steps of the developed research method are explained in detail. The data results and results from other studies are discussed in the fourth section, along with an analysis of why improved results were obtained using the proposed method. The fifth section is the conclusion.

### **2. Datasets**

In this study, three publicly available remote sensing image data sets—the UC Merced land-use dataset [19], RSSCN7 [20], and WHU-RS19 [18]—were used for testing the performance of the proposed method.

### *2.1. UC Merced Land-Use Dataset*

The pixel resolution of the UC Merced land-use dataset is 1 ft (0.3048 m), and the dataset contains a total of 2100 images. The images are composed of 21 different land-use types; each class contains 100 RGB color images. The land-use types include agricultural, airplane, baseball diamond, beach, buildings, chaparral, dense residential, forest, freeway, golf course, harbor, intersection, medium residential, mobile home park, overpass, parking lot, river, runway, sparse residential, storage tanks, and tennis court. All the images contain different textures and colors, as shown in Figure 1. The UC Merced land-use dataset images were resized to 224 × 224 pixels for transfer learning.

**Figure 1.** Example images of UC Merced dataset: (**a**) agriculture; (**b**) airplane; (**c**) baseball diamond; (**d**) beach; (**e**) buildings; (**f**) chaparral; (**g**) dense residential; (**h**) forest; (**i**) freeway; (**j**) golf course; (**k**) harbor; (**l**) intersection; (**m**) medium residential; (**n**) mobile home park; (**o**) overpass; (**p**) parking lot; (**q**) river; (**r**) runway; (**s**) sparse residential; (**t**) storage tanks; (**u**) tennis court.

### *2.2. RSSCN7 Dataset*

The RSSCN7 dataset is a public dataset released by Wuhan University in 2015. There are seven different scene categories: grass, field, industry, river/lake, forest, residential, and parking lot. The entire dataset includes a total of 2800 images. Each scene category contains 400 images, drawn from four different sampling ratios (1:700, 1:1300, 1:2600, and 1:5200), with 100 images at each ratio. In the original dataset, all of the images have a size of 400 × 400 pixels. The data were acquired in different seasons and under various weather conditions. Because of the sampling differences at different scales, classification of this dataset is more challenging. The different image categories are shown in Figure 2. The RSSCN7 dataset images were resized to 224 × 224 pixels for transfer learning.

**Figure 2.** Example images of RSSCN7 dataset: (**a**) grass; (**b**) field; (**c**) industry; (**d**) river lake; (**e**) forest; (**f**) resident; (**g**) parking.

### *2.3. WHU-RS19 Dataset*

The WHU-RS19 dataset was extracted from Google Earth satellite imagery. The spatial resolution of these satellite images is up to 0.5 m, and the spectral bands are red, green, and blue. There are 19 scene categories, including airport, beach, bridge, commercial, desert, farmland, football field, forest, industrial, meadow, mountain, park, parking, pond, port, railway station, residential, river, and viaduct. There are about 50 images corresponding to each category, and the entire dataset contains a total of 1005 images. The original image size is 600 × 600 pixels. Because the resolution, scale, direction, and brightness of the imagery vary greatly, processing this dataset is somewhat challenging. These data are also widely used in evaluating various scene classification methods. Figure 3 shows some samples from the dataset. The WHU-RS19 dataset images were converted into 224 × 224 pixel size for transfer learning.

**Figure 3.** Example images of WHU-RS19 dataset: (**a**) airport; (**b**) beach; (**c**) bridge; (**d**) commercial; (**e**) desert; (**f**) farmland; (**g**) football field; (**h**) forest; (**i**) industrial; (**j**) meadow; (**k**) mountain; (**l**) park; (**m**) parking; (**n**) pond; (**o**) port; (**p**) railway station; (**q**) residential; (**r**) river; (**s**) viaduct.

### **3. Method**

In this section, the model training method used in the experiments is introduced along with the convolutional neural network model used in this study and the two-stage cyclical learning-rate numerical design method used for the training.

### *3.1. Convolutional Neural Network Model*

To start, the VGG [21] neural network trained on ImageNet was used as the model for image feature extraction. This method adjusts the structure of a trained network model by using transfer learning to perform other image training tasks. In the experiments carried out in this study, the structure of the original neural network was adjusted by freezing the weights in one or more layers of the original model, which saves the time needed to retrain the deep model. The original model was used for feature extraction, and the newly embedded model layers were trained for use in classification. It was, therefore, only necessary to update and modify the weights of the newly added network layers during training; the frozen layer weights transferred from the learned model could be kept, as shown in Figure 4 [22].

**Figure 4.** Convolutional neural network model.

After removing the top-level fully connected layers of the pre-trained model, a new classification network was added. In our model, we chose an exponential linear unit (ELU) [23] as the activation function because it can reduce the vanishing gradient problem by acting as the identity for positive values. The ELU can also take negative values; hence, it allows the mean unit activation to approach zero in a deep neural network model and obtain faster convergence than the rectified linear unit (ReLU). Regularization was also added as a tool to reduce overfitting in the network. Regularization can be considered a penalty term in the loss function; the so-called "penalty" refers to restrictions applied to some parameters in the loss function to prevent overfitting. Moreover, we added a batch normalization layer [24] after the convolution layer in our model, which can act as a regularizer to decrease overfitting. Batch normalization also allows a larger learning rate to be used in model training, achieving faster convergence.

Considering the balance between network capacity and test accuracy, and taking into account the influence of different regularization and optimization strategies, we designed a new deep-learning network architecture with a high generalization capability. The proposed CNN architecture is called RSSCNet (Figure 5). In RSSCNet, the depth of the weight layers is 17, and the mathematical formulation is written as follows:

$$\left\{ \begin{array}{l} X = f^{(16)}\left(f^{(15)}\left(\cdots f^{(2)}\left(f^{(1)}(\mathbf{x}, \mathbf{w}, \mathbf{b})\right)\cdots\right)\right) \\ Y = f^{(17)}(X) \end{array} \right. \tag{1}$$

where *x* is the input data of each image, *w* is the weight matrix, *b* is the bias, *X* is the representative feature of *x*, and *Y* is the output probability. The RSSCNet architecture includes 15 convolutional layers, five max pooling layers, one global average pooling layer, two batch normalization layers [24], one dropout layer, and two fully connected layers with a softmax classification. Note that the activation function of the last two convolutional layers is an ELU [23], while the other weight layers use the ReLU activation function. The convolution filter size is 3 × 3. The dropout rate is 0.5. The L1 and L2 regularization coefficients are 0.01 and 0.02, respectively.
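A condensed tf.keras sketch of such a classifier tail (not the full 17-layer RSSCNet; layer counts, filter sizes, and the 19-class output are illustrative) could read:

```python
import tensorflow as tf
from tensorflow.keras import layers, regularizers

# L1/L2 penalties as stated in the text: 0.01 and 0.02.
reg = regularizers.l1_l2(l1=0.01, l2=0.02)

inputs = tf.keras.Input(shape=(224, 224, 3))
# Early layers use ReLU; the full model stacks 15 convolutional layers.
x = layers.Conv2D(64, 3, padding="same", activation="relu")(inputs)
x = layers.MaxPooling2D()(x)
# The last convolutional layers use ELU, each followed by batch normalization.
x = layers.Conv2D(128, 3, padding="same", activation="elu",
                  kernel_regularizer=reg)(x)
x = layers.BatchNormalization()(x)
x = layers.Conv2D(128, 3, padding="same", activation="elu",
                  kernel_regularizer=reg)(x)
x = layers.BatchNormalization()(x)
x = layers.GlobalAveragePooling2D()(x)
x = layers.Dropout(0.5)(x)  # dropout rate 0.5, as in the text
outputs = layers.Dense(19, activation="softmax")(x)
model = tf.keras.Model(inputs, outputs)
```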

**Figure 5.** CNN classifier network architecture of our RSSCNet model.

### *3.2. Image Data Augmentation*

When training deep-learning models, large amounts of data are needed, and overfitting must be avoided during training. Proper use of deep-learning regularization strategies, such as L1/L2 regularization, dropout, batch normalization, early stopping, and data augmentation, is therefore needed. Among these strategies, data augmentation is regarded as an effective method for training a generally applicable model from limited training data [25]. In data augmentation, the original images in the dataset are rotated, resized, scaled, or flipped, or have their brightness or color temperature changed, to create additional images from which the model can continue learning. To make up for the lack of data in this study, an augmented training method was included: after horizontal and vertical flipping, the training images were further varied by small-scale translation and scaling in order to enhance the generalization ability of the model.
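The augmentation described above (flips plus small translations and scalings) can be sketched with Keras preprocessing layers; the 10% factors below are illustrative, not the paper's exact settings:

```python
import tensorflow as tf

# Random augmentations applied on the fly during training: horizontal and
# vertical flips, small translations, and small zooms.
augment = tf.keras.Sequential([
    tf.keras.layers.RandomFlip("horizontal_and_vertical"),
    tf.keras.layers.RandomTranslation(0.1, 0.1),  # small-scale translation
    tf.keras.layers.RandomZoom(0.1),              # small-scale rescaling
])
```

In practice this block is placed at the front of the model (or mapped over the dataset) so that each epoch sees slightly different variants of every training image.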

### *3.3. Cyclical Learning Rate*

The learning-rate method proposed by Smith [9] sets a cyclical learning rate for the model instead of a fixed value and uses this to train the model. Results show that this can improve classification accuracy and reduce the need for trivial adjustments, and it usually reduces the number of iterations required during training. Three different cyclical learning-rate policies are available: "triangular", "triangular2", and "exp\_range". In our research, a two-stage cyclical learning-rate method was used to train the model. In the first stage, the "triangular" policy was used to quickly find the best solution; with the large learning rates of this policy, it was possible to avoid falling into a local solution. In the second stage, using the "triangular2" policy, the amplitude of the learning-rate cycle was gradually reduced to confine the model results until, finally, the solution stayed at a fixed position with no large swings (see Figure 6).

**Figure 6.** Cyclical learning rate during training.

The proposed two-stage cyclical learning rate method is calculated as shown in Equation (2).

$$\left\{ \begin{array}{l} z = \left| 1 + \dfrac{i}{\Delta} - 2\left\lfloor 1 + \dfrac{i}{2\Delta} \right\rfloor \right| \\ D = \left\{ \begin{array}{ll} 1, & \text{for stage 1} \\ \left( \dfrac{lr\_{\min}}{lr\_{\max}} \right)^{i/i\_{\max}}, & \text{for stage 2} \end{array} \right. \\ lr = lr\_{\min} + \left( lr\_{\max} - lr\_{\min} \right) \cdot D \cdot \max(0, 1 - z) \end{array} \right. \tag{2}$$

where *z* is an auxiliary variable, *i* is the training epoch index, *imax* is the total number of training epochs, Δ is the step size, equal to half the cycle length, *D* is the damping factor, *lr* is the cyclical learning rate, *lrmin* is the minimum learning rate, and *lrmax* is the maximum learning rate.
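Under one reading of Equation (2), in which stage 1 leaves the triangular cycle undamped and stage 2 decays its amplitude by the factor (*lrmin*/*lrmax*) raised to *i*/*imax*, the schedule can be sketched directly; the step size and learning-rate bounds used below are illustrative:

```python
import math

def two_stage_clr(i, i_max, step_size, lr_min, lr_max, stage):
    """Two-stage cyclical learning rate, one reading of Eq. (2):
    stage 1 keeps full amplitude (D = 1); stage 2 damps the amplitude
    by (lr_min / lr_max) ** (i / i_max)."""
    z = abs(1 + i / step_size - 2 * math.floor(1 + i / (2 * step_size)))
    d = 1.0 if stage == 1 else (lr_min / lr_max) ** (i / i_max)
    return lr_min + (lr_max - lr_min) * d * max(0.0, 1.0 - z)
```

At the start of a cycle (z = 1) the schedule sits at *lrmin*; halfway through (z = 0) it reaches the (possibly damped) maximum.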

#### *3.4. t-Distributed Stochastic Neighbor Embedding (t-SNE) Analysis Method*


The t-SNE analysis method is a nonlinear dimensionality-reduction algorithm used for exploring high-dimensional data. Laurens van der Maaten and Geoffrey Hinton proposed this technique for visualizing similarity data in 2008 [26]. It not only retains the local structure of the data but can also display clusters at multiple scales at the same time. The t-SNE algorithm can project data into two-dimensional or three-dimensional space and provides good visualization for verifying the effectiveness of a dataset or algorithm; it has been used in various fields as a visualization method for evaluating classification quality [27,28]. It uses conditional probabilities under a Gaussian distribution to define the similarity between sample points in the high-dimensional space, measures the mismatch between the high- and low-dimensional conditional probability distributions with the KL (Kullback–Leibler) divergence, and minimizes this cost function by gradient descent. A t-distribution is used to define the probability distribution in the low-dimensional space to alleviate the crowding problem caused by the curse of dimensionality.
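As an illustration (not the authors' pipeline), the projection step can be sketched with scikit-learn's `TSNE`; the random features below stand in for the deep-layer activations analyzed in the paper:

```python
import numpy as np
from sklearn.manifold import TSNE

# Stand-in "deep features": 200 samples with 64 dimensions each.
rng = np.random.default_rng(0)
features = rng.normal(size=(200, 64))

# Project the features into 2-D for visualization.
embedding = TSNE(n_components=2, perplexity=30.0,
                 init="random", random_state=0).fit_transform(features)
print(embedding.shape)  # (200, 2)
```

The 2-D embedding can then be scatter-plotted with one color per class to judge how well the learned features separate the scene categories.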

### *3.5. LIME Model Explanation Kit*

Although a deep-learning model can obtain quite good classification results, its black-box characteristics make it difficult to understand how those results are derived. How to interpret the reasoning of a deep-learning model has therefore become an important research topic. Among recent approaches, LIME is a new method for evaluating the interpretability of a model [29], i.e., for understanding which parts of an image the deep-learning model considers important for its subsequent classification and prediction. The difficulty with model interpretability is that the decision boundary of the model is hard to define in a way that humans can understand. As shown in Figure 7, LIME is a Python library that generates the local super-pixel regions on which the model focuses. These can be used to explain the basis of the model's decision, which is usually difficult to describe, and to help judge whether that basis is appropriate. Figure 7a shows that the super-pixel region of greatest interest to the RSSCNet model contains an airplane; hence, the image is correctly classified. In Figure 7b, the super-pixel region does not contain the storage tank, causing the RSSCNet model to misclassify the image.

**Figure 7.** Image explanation using LIME (local interpretable model, agnostic explanation): (**a**) example of correct classification; (**b**) misclassified example.

### **4. Results and Analysis**

### *4.1. Experimental Set-Ups*

### 4.1.1. Implementation Details

In this study, the TensorFlow 2.0 framework in Python was used as the platform for the experiment. The hardware and system configuration included a Windows 10 version 1703 system. An NVIDIA GTX 1080 TI graphics card was used; the computing core was an i7-6700 3.40 GHz 8-core central processing unit (CPU) with 32 GB memory. Different pre-trained model parameter settings were used, and their results were compared, and attempts were made to adjust the settings for training methods with different stages. The batch size in the experiment was set to a uniform 64, in line with the memory capacity of the graphics card; the image length and width were set to 224 pixels in all cases. This study designed the training set size based on earlier studies. For the UC Merced land-use dataset, two training configurations—80% training and 50% training—were used. For the RSSCN7 dataset, 50% training and 20% training were used, and, for the WHU-RS19 dataset, the two modes used were 60% training and 40% training. For each dataset, the two modes and 10 repeated random training cycles were used to verify and evaluate the experimental results.

### 4.1.2. Evaluation Methods

In this experiment, the confusion matrix and the overall accuracy were used to evaluate the classification performance, and the results were compared with those obtained using other, recently developed methods. The confusion matrix can be applied to the performance analysis of two-class or multi-class classification. After the model made its predictions, the predictions for each class were tabulated so that the detailed classification results for each category could be displayed and inspected. To evaluate the accuracy of the classification results, the overall accuracy was used: the total number of images that were correctly classified divided by the number of test images. The accuracy ranges from 0 to 1, where a value closer to 1 denotes better classification performance.
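The overall accuracy described above can be computed directly from a confusion matrix; the 2 × 2 matrix below is a hypothetical example:

```python
import numpy as np

def overall_accuracy(confusion):
    """Overall accuracy: correctly classified images (the diagonal of the
    confusion matrix) divided by the total number of test images."""
    m = np.asarray(confusion)
    return np.trace(m) / m.sum()

# Hypothetical 2-class example: 48 + 47 correct out of 100 test images.
print(overall_accuracy([[48, 2], [3, 47]]))  # 0.95
```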

In addition to the confusion matrix, the kappa coefficient is often used to analyze the differences in the classification results as an indicator of multi-category classification quality. This coefficient is a statistical method for evaluating consistency, combining the overall consistency with the per-class classification consistency into a single index. Its value range is [−1, 1], and a higher coefficient value denotes a higher accuracy of the classification achieved by the model. The kappa coefficient (*K*) is calculated as follows:

$$K = \frac{P\_0 - P\_c}{1 - P\_c},\tag{3}$$

where *P0* is the observed proportion of agreement (the overall accuracy) and *Pc* is the proportion of agreement expected by chance.
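Equation (3) can be evaluated from the confusion matrix as a sketch, taking the observed agreement from the diagonal and the chance agreement from the row and column marginals (the standard reading of Cohen's kappa):

```python
import numpy as np

def kappa(confusion):
    """Cohen's kappa per Eq. (3): P0 is the observed agreement and Pc the
    agreement expected by chance from the row/column marginals."""
    m = np.asarray(confusion, dtype=float)
    n = m.sum()
    p0 = np.trace(m) / n                                # observed agreement
    pc = (m.sum(axis=0) * m.sum(axis=1)).sum() / n**2   # chance agreement
    return (p0 - pc) / (1 - pc)

# Hypothetical 2-class confusion matrix: P0 = 0.95, Pc = 0.5, K ~= 0.9.
print(kappa([[48, 2], [3, 47]]))  # approximately 0.9
```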

### *4.2. Results and Analysis*

### 4.2.1. Analysis of Experimental Parameters

In this study, different pre-trained models were tested so that the model offering the best feature extraction could be found. Once it was found, the weight layers in the pre-trained model were adjusted, and the effect of the pre-trained weights on the transfer-learning results was examined, in order to find the best training-layer plan for the model parameter configuration used in the subsequent experiments of this study.

### 1. Different pre-trained CNN models

Different pre-trained models have different degrees of influence on image feature extraction. In this experiment, the WHU-RS19 dataset was used to embed four different common pre-trained models—VGG16 [21], VGG19 [21], ResNet50 [30], and InceptionV3 [31]—into the classification model. A classification performance test was carried out to help decide which of the four models was the most suitable for use as the pre-trained model. The training used an Adam optimizer; the batch size was 64, and the number of iterations was 150. The results of the training carried out using the four different models are shown for comparison in Figure 8. The results show that the pre-trained VGG16 model had the best overall accuracy; thus, in subsequent testing, this was used as the image feature extraction model.

**Figure 8.** Comparison of accuracy using different pre-trained models.

### 2. Different numbers of fine-tuning layers during training

After choosing the VGG16 pre-trained model, this study further explored whether freezing some of the network layers in the pre-trained model would give the model a better generalization performance. This experiment was also conducted using the WHU-RS19 dataset, and the results are shown in Figure 9. Based on the convolutional blocks of the VGG16 model, two, four, seven, 10, and 13 layers were frozen. The results show that, when no layers were frozen, all the layers of the pre-trained model were retrained and fine-tuned; although this requires more resources and a longer training time, it produced a better overall accuracy. This shows that, in the classification of remote sensing images, the required feature representations can be obtained by further training.

**Figure 9.** Comparison of accuracy using different fine-tuning layers.

For the different fine-tuning layers discussed, the main consideration was whether to retrain the weights in the pre-trained model. Retraining the entire network inevitably takes a lot of time. However, in the process of feature extraction, in addition to training, we believe that it is necessary to focus on the training of the classifier and, hence, in this study, we aimed to strengthen the model's ability to classify features by using a two-stage training method. In Figure 10, the WHU-RS19 dataset is used as the comparison dataset for the proposed method.

A two-stage training method (shown as "2-stage" in Figure 10) in our research indicates that two different optimizers were used, one for each stage of the training. In the first stage, the SGD (stochastic gradient descent) optimizer was used to carry out 100 training iterations to train the entire neural network model. In this stage, the pre-trained model and classifier model could be adjusted at the same time, thus strengthening the model's feature extraction and its classification ability. In the second stage of the training, the model weights with the best accuracy learned in the first stage were loaded, and the Adam optimizer was used to carry out the next 50 training iterations. We also compared the result with that of training using only the Adam optimizer for 150 iterations (shown as "1-stage" in Figure 10). As can be seen from Figure 10, the test accuracy obtained using the two-stage training (97.76%) was better than that obtained using the one-stage training (96.33%).
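A minimal tf.keras sketch of this two-stage procedure (not the authors' code; the learning rates, loss, and checkpoint filename are illustrative) might look like:

```python
import tensorflow as tf

def two_stage_train(model, train_ds, val_ds, epochs1=100, epochs2=50):
    """Stage 1: SGD fine-tunes the whole network, keeping the best
    validation weights. Stage 2: reload those weights and continue with
    Adam (100 + 50 epochs in the paper)."""
    ckpt = tf.keras.callbacks.ModelCheckpoint(
        "best_stage1.weights.h5", save_best_only=True, save_weights_only=True)
    # Stage 1: SGD over all layers, checkpointing the best weights.
    model.compile(optimizer=tf.keras.optimizers.SGD(learning_rate=0.01),
                  loss="sparse_categorical_crossentropy", metrics=["accuracy"])
    model.fit(train_ds, validation_data=val_ds, epochs=epochs1, callbacks=[ckpt])
    # Stage 2: reload the best stage-1 weights, then fine-tune with Adam.
    model.load_weights("best_stage1.weights.h5")
    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
                  loss="sparse_categorical_crossentropy", metrics=["accuracy"])
    return model.fit(train_ds, validation_data=val_ds, epochs=epochs2)
```

Reloading the checkpoint between stages is what carries the best stage-1 solution into the second optimizer, as the text describes.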

**Figure 10.** Comparison of overall accuracy with and without two-stage training.

### 3. Different classification architectures

For a performance comparison, we compare the RSSCNet architecture proposed herein with other architectures in the literature, such as VGG-16-Net [21] and the GSB + LOB model [32]. Figure 11 shows the test accuracy of each model (RSSCNet: 0.978, VGG-16-Net: 0.973, and GSB + LOB model: 0.97), indicating that the proposed RSSCNet is the best of the three classification models.

**Figure 11.** Comparison of testing accuracy using different classification architectures.

### 4. Different cyclical learning-rate methods

In this section, we compare the performance of the two-stage cyclical learning-rate method (2-stage CLR) with Smith and Topin's (2019) one-cycle cyclical learning-rate method (1-cycle CLR). The results in Figure 12 show that the two-stage method achieved its highest test accuracy of 98.0% at epoch 134, whereas the one-cycle method reached its highest test accuracy of only 97.0% at epoch 27. Although the latter reached its best test accuracy more quickly, the former achieved a better test accuracy; hence, it was used for model training and performance testing in the subsequent experiments.

**Figure 12.** Comparison of testing accuracy using different cyclical learning rate (CLR) methods: (**a**) cyclical learning-rate policy; (**b**) testing accuracy.

### 4.2.2. Experimental Results

1. Classification of UC Merced land-use dataset

In this section, the classification results for the UC Merced land-use dataset are discussed. The t-SNE analysis method, a non-linear dimensionality-reduction algorithm for exploring high-dimensional data, was used to visualize the learned features; it can map multi-dimensional data into two or three dimensions suitable for visual presentation. In this study, extracting the features of the deep layers for analysis helped in understanding the differences between the features obtained by the model before and after training. In Figure 13a, which shows the results before training, the features extracted by the model show little structure; after training, as shown in Figure 13b, the features are highly clustered, showing that the model does help to improve the classification performance.

**Figure 13.** Visual analysis on UC Merced dataset using t-distributed stochastic neighbor embedding (t-SNE): (**a**) before training; (**b**) after training.

Figure 14 shows the confusion matrix obtained using VGG16 feature extraction together with the classification model proposed in this study, using the best classification weights from the early-stopping strategy; the training rate was 80%. The matrix contains the individual classification results for the 21 categories. The kappa coefficients of the UC Merced dataset were 0.9985 and 0.9895 at 80% and 50% training, respectively.

In Table 1, the classification results found in this study are compared with those obtained using other classification methods. It can be seen that, by using the two-stage cyclical learning-rate training method, this study achieved the best overall accuracy out of the results shown.


**Table 1.** Comparison of the overall accuracy and standard deviations using 80% and 50% training ratios on UC Merced dataset.

**Figure 14.** Classification confusion matrix of our method on UC Merced dataset.

### 2. Classification of RSSCN7 dataset

In this section, the classification results obtained by applying the proposed model to the RSSCN7 dataset are discussed. The t-SNE analysis method was used to extract and analyze the deep features of the model. As shown in Figure 15b, after training, the features became highly clustered, which shows that the model proposed in this research helps to improve the scene classification.

**Figure 15.** Visual analysis on RSSCN7 dataset using t-SNE: (**a**) before training; (**b**) after training.

Figure 16 shows the confusion matrix obtained using VGG16 feature extraction together with the classification method proposed in this study and the optimal classification weights from the early-stopping strategy. The training rate used was 50%. The matrix contains the individual classification results for the seven categories. Among these, agricultural land and grassland are the most likely to be confused, perhaps because the two categories share similar characteristics (both containing a large proportion of green ground), which easily leads to errors. The kappa coefficients of the RSSCN7 dataset were 0.9737 and 0.9329 at 50% and 20% training, respectively.

**Figure 16.** Classification confusion matrix of our method on RSSCN7 dataset.

Table 2 shows a comparison of the results obtained for the RSSCN7 dataset classification using different recently proposed methods, including the one proposed in this paper. The proposed two-stage cyclical learning-rate training method achieved the best overall accuracy for two different training ratios. With a 50% training ratio, it produced an increase in accuracy of about 3% compared with other methods.

**Table 2.** Comparison of the overall accuracy and standard deviations using 50% and 20% training ratios on RSSCN7 dataset.


### 3. Classification of WHU-RS19 dataset

In this section, the classification results obtained by applying the proposed model to the WHU-RS19 dataset are discussed. The t-SNE analysis method was used to extract and analyze the deep features of the model. As shown in Figure 17b, as a result of the training, the features became highly clustered, showing that the proposed model helps to improve the classification of the scene.

**Figure 17.** Visual analysis on WHU-RS19 dataset using t-SNE: (**a**) before training; (**b**) after training.

Figure 18 shows the confusion matrix for the WHU-RS19 dataset extracted using VGG16 with the classification model proposed in this study and the optimal classification weights from the early-stopping strategy. The training rate used was 60%. The matrix contains the individual classification results for the 19 categories in the dataset. Among these categories, the combinations of football field and park and of forest and mountain were the most easily confused during classification. Table 3 shows a comparison between the results of the WHU-RS19 dataset classification obtained using the proposed method and methods proposed in other recent papers. Using the two-stage cyclical learning-rate training method proposed in this study, our method achieved the best overall accuracy. The kappa coefficients of the WHU-RS19 dataset were 0.9968 and 0.9874 at 60% and 40% training, respectively.


**Table 3.** Comparison of the overall accuracy and standard deviations using 60% and 40% training ratios on WHU-RS19 dataset.

From the confusion matrix classification results shown in Figure 18, it can be seen that the categories misclassified in the WHU-RS19 dataset include "residential", "forest", "farmland", and "bridge" (Figure 19a). We first applied the LIME analysis to the four misclassified images and generated the super-pixel feature regions in which the model was most interested (Figure 19b). By observing the super-pixel areas in Figure 19b, we can understand why the model misclassifies "residential" as "industrial", "forest" as "park", "farmland" as "river", and "bridge" as "pond".

**Figure 18.** Classification confusion matrix of our method on WHU-RS19 dataset.

**Figure 19.** Misclassified images on WHU-RS19 dataset: (**a**) original image; (**b**) super-pixel explanation by LIME analysis.

This research attempted to correct the four images misjudged by the model in Figure 19a in the hope of improving the classification performance. First, regarding the misjudged "residential" image, we believe that the other images of the residential category in the dataset contain a variety of bright colors overall, whereas the roofs of buildings in industrial areas tend to be mostly white; the white house colors in this image may thus have led to the classification error. Therefore, the saturation of the "residential" image was increased to make it more similar to the other "residential" images in the dataset. In the super-pixel area of the "forest" image, the cut block contained a patch of bare land, unlike the other images in this category, which mostly contain only forest; the colors and details were also blurred. Therefore, the color contrast was enhanced, the overall sharpness of the image was increased, and the shadows between the trees were intensified, to prevent the image from again being judged as a "park". In the "farmland" image, the bright horizontal road is quite prominent, and a green straight line is included in the captured image features. Therefore, we reduced the brightness of the overexposed part of the image and applied slight sharpening to the "farmland" areas to remove noise and strengthen the details of the boundaries between fields. In the last image, the extracted features did not contain the "bridge" at all. Therefore, we increased the brightness of the "bridge" itself, as well as the color saturation and sharpness of the image, to increase the chance of a bridge edge being captured as a feature. Figure 20a,b present the four corrected images and their corresponding super-pixel feature regions, respectively.
Finally, the four corrected images were replaced with the original images, and the category prediction of the entire dataset was again performed. Consequently, the result reached an overall accuracy of 100%. Figure 21 displays the corrected classification matrix.

**Figure 20.** After correction of misclassified images on WHU-RS19 dataset: (**a**) corrected image; (**b**) super-pixel explanation by LIME analysis.

**Figure 21.** Classification confusion matrix of WHU-RS19 dataset after image correction.

#### 4.2.3. Further Explanation and Discussion

In this section, we further discuss how fine-tuning, the cyclical learning rate, and an increase in the amount of data can improve the classification performance. Expanding on the classification results obtained in this study helps in understanding the possible impact of these factors on model training.

### 1. The effectiveness of fine-tuning

In Section 4 of this paper, two-stage training using the proposed model was described, and the proposed method was found to significantly improve the training. We also wanted to understand whether freezing layers at one or both stages, rather than leaving all layers unfrozen, would produce different results. Therefore, we investigated three situations: no freezing of any layer, freezing of the top seven layers, and freezing of the top 13 layers; the results for different combinations of these three situations are shown in Figure 22.

The five combinations investigated were no freezing at either stage (shown as "0 + 0"), the top seven layers frozen at the second stage only ("0 + 7"), the top 13 layers frozen at the second stage only ("0 + 13"), the top 13 layers frozen at the first stage only ("13 + 0"), and the top 13 layers frozen at both stages ("13 + 13"). From Figure 22, it can be seen that two-stage training with no freezing ("0 + 0") achieved the best test accuracy, while the test accuracy was worst when the top 13 layers were frozen in both training stages ("13 + 13"). According to the results in Figure 22, the test accuracy increases as the number of fine-tuned layers increases.

**Figure 22.** Overall accuracy of fine-tuning using different frozen combinations.

#### 2. Effectiveness of image data augmentation

In this study, the number of training images was increased by flipping them and applying a small amount of panning and zooming, and this augmentation was included in the training. As shown in Figure 23, doing this also successfully increased the training accuracy. The results shown are for no data augmentation (shown as "1-stage" in Figure 23), single-stage training with data augmentation ("1-stage DA"), and two-stage training with data augmentation ("2-stage DA").

**Figure 23.** Accuracy of different training methods.

3. Effectiveness of using a two-stage cyclical learning-rate method

In two-stage cyclical learning-rate training, two different optimizers can be used. In this study, an SGD optimizer with a cyclical learning rate was used in the first stage, and an Adam optimizer was used for the training in the second stage. The method obtains the best weights in the first stage and then enters the second stage; used together with the cyclical learning rate, this can greatly accelerate convergence. Figure 24 shows a comparison of the number of iterations needed to achieve the best level of accuracy for three different training methods.

**Figure 24.** Test accuracy with and without the two-stage cyclical learning-rate method.

When only a single-stage fixed learning rate (shown as "1-stage" in Figure 24) was used for the training, the convergence speed was the slowest and the lowest test accuracy was obtained. The use of a single-stage cyclical learning rate ("1-stage CLR") sped up the convergence and gave a greater chance of avoiding local optima, so better results could be obtained. When the two-stage cyclical learning rate ("2-stage CLR") was used, the accuracy fluctuated early in the second training stage due to the change of optimizer; however, overall, the best accuracy was reached the most quickly, using the smallest number of iterations of the three methods.

### **5. Conclusions**

As a result of continued advances in technology, better-quality and higher-resolution data can be obtained, leading to improvements in remote sensing image classification and in predictions based on it. However, training and tuning a classifier takes a lot of time. To reduce the training time and to enable the model to converge quickly while retaining a high generalization capability, we recommend the RSSCNet model combined with a two-stage cyclical learning-rate training policy and no-freezing transfer learning, which requires only a small number of iterations. In this way, an excellent level of accuracy can be obtained. Data augmentation technology, regularization, and an early stopping strategy can then be used to deal with the problem of limited generalization encountered in the rapid training of deep neural networks. The experimental results that were obtained also confirm that the use of the model and training strategies proposed in this paper can outperform current models in terms of accuracy.

In this study, by using the LIME super-pixel explanation, the root causes of model classification errors were made clearer and a better understanding was obtained. This made it easier to carry out subsequent processing and adjustment of the data or models. After the image correction preprocessing on the four misclassified images using the RSSCNet model in the WHU-RS19 dataset, this image correction procedure was found to improve the overall classification accuracy. This investigation is only a preliminary study.

In future research, we will try to establish universal image correction preprocessing for the case of suspected outliers and merge different XAI (explainable artificial intelligence) analysis technologies to improve interpretation capabilities so that they can be applied to a more diverse range of imagery with different classification issues.

**Author Contributions:** M.-H.T. and S.-C.H. conceptualized and designed the whole framework and the experiments, as well as wrote the manuscript. S.-C.H. performed the experimental analysis. M.-H.T. contributed to the discussion of the experimental results. H.-C.W. helped to organize and revise the manuscript. All authors have read and agreed to the published version of the manuscript.

**Funding:** This research was funded by the Ministry of Science and Technology, Taiwan, grant numbers MOST 108-2621-M-040-002 and MOST 109-2121-M-040-001. The support is greatly appreciated.

**Conflicts of Interest:** The authors declare no conflict of interest.

### **References**


© 2020 by the authors. Licensee MDPI, Basel, Switzerland. This article is an open access article distributed under the terms and conditions of the Creative Commons Attribution (CC BY) license (http://creativecommons.org/licenses/by/4.0/).

## *Article* **Comparing Classical and Modern Machine Learning Techniques for Monitoring Pedestrian Workers in Top-View Construction Site Video Sequences**

### **Marcel Neuhausen, Dennis Pawlowski \* and Markus König**

Computing in Engineering Department, Ruhr-University Bochum, 44801 Bochum, Germany; marcel.neuhausen@ruhr-uni-bochum.de (M.N.); koenig@inf.bi.ruhr-uni-bochum.de (M.K.) **\*** Correspondence: dennis.pawlowski@ruhr-uni-bochum.de

Received: 29 October 2020; Accepted: 24 November 2020; Published: 27 November 2020

**Abstract:** Keeping an overview of all ongoing processes on construction sites is almost unfeasible, especially for the construction workers executing their tasks. It is difficult for workers to concentrate on their work while paying attention to other processes. If their workflows in hazardous areas do not run properly, this can lead to dangerous accidents. Tracking pedestrian workers could improve the productivity and safety management on construction sites. For this, vision-based tracking approaches are suitable, but the training and evaluation of such a system requires a large amount of data originating from construction sites. These are rarely available, which complicates deep learning approaches. Thus, we use a small generic dataset and juxtapose a deep learning detector with an approach based on classical machine learning techniques. We identify workers using a YOLOv3 detector and compare its performance with an approach based on a soft cascaded classifier. Afterwards, tracking is done by a Kalman filter. In our experiments, the classical approach outperforms YOLOv3 on the detection task given a small training dataset. However, the Kalman filter is sufficiently robust to compensate for the drawbacks of YOLOv3. We found that both approaches generally yield satisfying tracking performance but feature different characteristics.

**Keywords:** cascaded classifier; computer vision; construction site management; deep learning; tracking

### **1. Introduction**

Construction sites constitute highly dynamic environments in which workers execute diverse orders simultaneously. Workers need to perform tasks, interact with heavy construction equipment and keep an eye on their surroundings, which is difficult for complex tasks. This requires construction workers to maintain a high level of concentration to avoid mistakes. If the construction site is noisy and congested, it can be difficult for construction workers to concentrate on their work and the environment at the same time. In addition, the continuous change of a construction site often leads to hazardous situations. Heavy construction machines move across the site to execute their jobs. Construction vehicles cross the workers' paths and cranes lift loads over their heads. In addition, pedestrian workers inevitably share the same workspaces with construction machines or interact with them in order to accomplish their orders [1]. As a result, workers' activities often happen in close proximity to heavy machinery, and hazardous situations such as close calls can occur as a consequence [2]. Furthermore, in certain cases, people can misjudge the danger. Given these facts, workflows on construction sites cannot be expected to always run ideally, and the resulting incidents create hazardous situations in which pedestrian workers can be injured or even suffer a fatal accident. For these reasons, construction workers undergo regular training to raise their awareness of hazardous situations [3] as well as to develop their knowledge and skills [4] to improve their workflows. Despite this effort, working in the surroundings of heavy construction machines remains a hazardous job. Identifying a reasonable workflow while executing it in a steadily changing environment likewise remains a challenge. Accordingly, productivity and safety problems continue to occur on construction sites.

Monitoring pedestrian workers on site from a top-view perspective could improve this situation. This way, workers' trajectories could be analyzed and adaptations to their workflows could be made during their work [5]. In addition, a monitoring system would help machine operators in their work, as they usually do not have a complete overview of the environment [1]. In this case, the positions and movement directions of workers who are nearby could be provided to the machine operators. This would allow them to recognize hazardous situations at an early stage and take appropriate action to prevent an accident. The mentioned scenarios represent possible applications of a monitoring system on construction sites, for which a suitable method has to be investigated. Therefore, in this paper, we focus solely on the detection and tracking of construction workers, since this is the basis of a surveillance system.

Different approaches have already been made to tracking pedestrian workers on construction sites. Depending on the surroundings, various technologies are used for monitoring workers in construction-related scenarios [6]. The literature often refers to global navigation satellite system (GNSS) tags for outdoor localization on construction sites [7,8]. These are mainly applied to track construction machines and equipment, but some approaches make use of GNSS to locate pedestrians on site [9,10]. Near buildings, walls or large construction elements, GNSS-based localization is affected by multipath effects caused by reflections of the signal off these objects [11,12]. Since construction workers often work near such objects, tracking them with GNSS becomes unreliable. Less sensitive approaches rely on radio-frequency technologies. These also involve attaching tags to the workers' gear. Corresponding readers can either be stationary on the construction site or attached to construction machinery. Employing a setup that utilizes radio-frequency identification (RFID) technology enables the warning of workers and machine operators whenever a tagged worker gets into the range of a machine's reader [13] or enters specific zones [14]. Other approaches develop systems that notice pedestrian workers entering certain zones by localizing them using ultra-wideband (UWB) tags [7]. Although the accuracy can be improved by combining RFID technology with ultrasound [15], precise localization and tracking of workers in real time constitutes a challenging task [7] which has not yet been sufficiently solved in general [16]. Besides technical deficiencies, deploying such a system for a tracking application on a real construction site requires a tag for each worker. Additionally, at least three receivers per monitored area have to be installed in the case of two-dimensional tracking, which results in high acquisition costs [6]. Furthermore, workers perceive the attachment of tags with unique identifiers (IDs) to their gear to be obtrusive and to cause discomfort [17].

In contrast to tag-based methods, camera-based tracking approaches provide cost-effective surveillance alternatives. The costs amount to one camera per monitored area, and no additional equipment is needed for the workers. In research, some effort has already been made to detect construction workers in camera images [18]. Park and Brilakis [8] built a real-time capable detection scheme to recognize construction workers in camera images for initializing a visual tracker. They used background subtraction via a median filter to find moving objects within the images. Then, pedestrian detection is performed on these objects using a support vector machine (SVM) which operates on Histogram of Oriented Gradients (HOG) features. Exploiting the high-visibility colors of the workers' safety vests, construction workers are identified among the pedestrian detections by clustering hue, saturation, value (HSV) color histograms using a k-Nearest Neighbors (k-NN) algorithm. Chi and Caldas [19] also decided on background subtraction to identify moving objects before classifying them with a neural network approach. In [20], the background subtraction is replaced by a sliding window approach. Similar to Park and Brilakis [8], they used HOG and color features but concatenated them into a single vector which is passed to an SVM in order to identify workers. Using color and spatial models classified by a Gaussian kernel, Yang et al. [21] decided on a similar strategy. In recent years, approaches based on deep neural networks have become prominent. To recognize workers' activities, Luo et al. [22] applied temporal segment networks. Son et al. [23] used a region-based convolutional neural network (R-CNN) based on ResNet to detect workers under changing conditions. Luo et al. [24] proposed an approach which detects construction workers in oblique images from cameras mounted at height. They applied YOLO for detection but only achieved a precision of about 75%. Vierling et al. [25] proposed a convolutional neural network (CNN)-based concept detecting workers in top-view images. To cope with large mounting heights, their approach relies on several zoom levels, each with a separate CNN for detection. An additional meta CNN is used to choose the correct zoom level for a certain height.

Despite the use of transfer learning techniques, CNN-based approaches usually require large amounts of training data. Corresponding data representing construction sites from a top-view perspective are rarely available. In addition, generating a sufficiently large dataset is a time-consuming and demanding task. Beyond that, these networks commonly operate on small image sizes of about 400 × 400 px. High-resolution images cannot be processed in real time without disproportionately large amounts of computational power or without drastically downscaling the images. Since surveillance cameras monitor large areas of construction sites, workers occupy only very small parts of the image. Besides the fact that detecting small objects with CNNs is a challenge in itself, reducing the size of the image makes detection even more difficult, as downscaling may eliminate relevant image features required for reliable detection.

As it is unclear whether CNNs can outperform classical machine learning methods under these conditions, we conduct a comparison in this paper. The goal of the comparison is to find out whether one of the two approaches can satisfactorily track several construction workers when the amount of training data is small. In doing so, we restrict our focus to one view within one construction site at first. In this case, a camera could be mounted, e.g., on the mast of a tower crane. Alternatively, several cameras could be used to completely monitor a construction site, but this is not covered in this paper. Because no suitable dataset is publicly available for comparison, we assemble a small generic dataset ourselves. It shows pedestrian construction workers from a top-view perspective under different conditions. As a representative CNN approach, we choose YOLOv3 [26], since it is a state-of-the-art detection network. For its counterpart from the field of classical machine learning, we rely on preliminary work [27] discussing diverse computer vision techniques for monitoring construction workers. In that work, we juxtaposed eligible methods and developed a theoretical concept relying on a classical machine learning method. Based on these results, we implement a tracking approach based on a soft cascaded classifier in the course of this paper. A simple Kalman filter is applied to both approaches in order to track the detected workers within the recorded video sequences. In our experiments, we compared the detection and tracking results of our implemented approach with those of YOLOv3 trained on the same data. We found that our classical machine learning approach yields substantially better detection results on the small dataset than the CNN. However, the Kalman filter proves to be sufficiently robust to compensate for the lower detection quality of YOLOv3. Owing to this, both approaches generally perform similarly well on the tracking of pedestrian workers. Nevertheless, each approach possesses different tracking characteristics. A general recommendation can, thus, not be made; the appropriate tracking solution has to be determined with respect to the demands of the particular application.

### **2. Materials and Methods**

To compare the performance of CNN-based and classical computer vision approaches to the monitoring of pedestrian workers, we assembled a small characteristic dataset. This dataset includes different scenarios as well as various environmental conditions and is described in detail in Section 2.1. It is used for both the training and the evaluation of the two approaches chosen for our comparison. The classical approach is composed of a soft cascaded classifier and a background subtraction which preprocesses the input images to enable real-time detection. Its detailed structure and the parameter optimization are described in Section 2.2. YOLOv3 also possesses several hyperparameters; in Section 2.3, we elaborate on our choice of those hyperparameters. For a better comparison, the detections of both approaches are tracked by the same method over the course of the video sequences. We chose Kalman filtering for this, as it is a simple but robust method which is sufficient for the purpose of comparing the two approaches. Details about the Kalman filter's motion model and other required parameters are given in Section 2.4. The implementation and testing of the two approaches were done in the programming language C++ on a standard computer.

#### *2.1. Dataset Acquisition*

Top-view scenes of construction sites have rarely been recorded. Accordingly, no dataset suitable for our purpose has been published yet. For that reason, we recorded video sequences to train and test our approach explicitly for the scope of this paper. It is important to know whether different backgrounds and lighting conditions influence the tracking of construction workers. In addition, we want to test whether our approach can distinguish between different moving objects. The videos were therefore taken in two scenarios with different levels of difficulty for our approach: in the first one, construction workers act on uniformly plastered terrain, whereas a mixture of gravelled and plastered areas is chosen for the second scenario. While the first scene is illuminated well, the second scene is slightly overexposed. Both sequences are recorded by a non-pivoting camera at a height of 20 m, which is aligned to the ground and has a fixed position (see Figure 1). This results in a top-view perspective in the center of the image which becomes oblique towards the borders of the image. In accordance with a typical crane camera setup [28], all videos were recorded with a frame rate of 25 fps and a resolution of 1920 × 1080 px. In both scenarios, construction workers wearing safety vests and helmets walk randomly through the camera's field of view, including sudden stops and directional changes. They also interact with static construction-specific elements such as pylons and barrier tapes. Of course, intentionally exposing workers to hazardous situations or heavy construction machinery is not justifiable. Such hazardous situations are instead substituted by smaller moving vehicles, which still allows for evaluating the correct classification of moving objects. In both video sequences, the construction workers are labeled manually in order to prepare both the ground truth and the training data. Rectangular areas surrounding the workers' heads and shoulders indicate the positions of the workers.

Datasets for training, validation and evaluation are generated from the labeled sequences. For this, we divide each sequence into three shares of equal length. From the first share of both sequences, we generate the training dataset, which includes 1000 pedestrian construction worker samples (see Figure 2). Samples are only extracted from every 10th frame to reduce the similarity among the samples. For generating the validation dataset, we proceed analogously with the second share of both sequences; it also contains 1000 pedestrian construction worker samples. From the third share, we generate evaluation data consisting of video sequences with an average length of 12 s. From each sequence, we choose a representative scene which contains settings commonly occurring on construction sites. Both scenes show up to four pedestrian construction workers simultaneously, along with other objects colored similarly to the workers' safety vests that are either static or moving. The static objects are red and white barrier posts, barrier tapes and a red barrier on a gravel area; these elements resemble the colors of the construction workers' gear. Moving objects include a red vehicle that moves linearly in the same direction as a construction worker, its red color being similar to that of the worn safety vests. The workers walk through the scene while sometimes passing each other closely. In addition to these two scenes, we choose a third one to evaluate our approach with regard to moving vehicles. This scene shows a single worker walking while a car approaches him from behind.
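As a minimal sketch of the split described above, the following Python snippet partitions a sequence's frame indices into three equal shares and subsamples the training and validation shares at every 10th frame. The function name and frame counts are illustrative, not taken from the paper's C++ implementation:

```python
# Sketch: split a labeled sequence into training / validation /
# evaluation shares of equal length; samples for the first two shares
# are drawn from every 10th frame only, the third share stays a video.
def split_sequence(n_frames, sample_stride=10):
    third = n_frames // 3
    train_frames = list(range(0, third, sample_stride))
    val_frames = list(range(third, 2 * third, sample_stride))
    eval_frames = list(range(2 * third, n_frames))  # kept as a sequence
    return train_frames, val_frames, eval_frames
```

For a 90-frame sequence, this yields three training frames (0, 10, 20), three validation frames (30, 40, 50) and a 30-frame evaluation share.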

**Figure 1.** The construction workers are captured by a stationary camera which is located 20 m above the ground. The camera is pointed at the ground.

**Figure 2.** The training dataset contains construction workers taken from the top view. The workers are in different orientations and illuminated in different ways.

#### *2.2. Classical Detection Approach*

Based on our theoretical concept previously proposed in [27], we implemented a concrete approach to the detection of construction workers in top-view images. As shown in Figure 3, detection is done by first extracting relevant regions of interest (RoIs) from the current camera image. Afterwards, each region is classified to determine whether it contains a worker or not.

**Figure 3.** Concept of the classical detection approach. Each time step *t* corresponds to an image frame. The current frame in time step *ti* is subtracted from a background model which is generated from images of the previous time steps. This results in foreground blobs which represent RoIs. Image patches of the current image corresponding to these regions are fed into the classifier determining the presence of construction workers.

Most previously implemented image-based approaches in civil engineering are based on a frontal camera perspective. In our case, however, the camera is mounted at height in order to monitor a large area and to satisfy the requirements of different applications. This leads to a top-view perspective in which most features that make up a pedestrian, such as legs and arms, are usually occluded by the body itself. Although there are already approaches for the general detection of pedestrians in top-view images, detecting pedestrian workers can be eased by leveraging typical features that characterize those workers. On construction sites, workers wear helmets and safety vests in high-visibility colors that constitute clearly visible and prominent features. Similar to the findings of Park and Brilakis [8], a combination of motion, color and shape features is more reasonable for a proper detection in our scenario than the use of common pedestrian detection features. While motion is a fundamental characteristic of pedestrians, it is not unique to them: other objects such as construction vehicles may also move through the camera's field of view, so that classification based on motion alone is unreliable. Notwithstanding, constraining the detection to image regions containing movement spares the expense of investigating the entire image. This improves the subsequent detection in two ways. On the one hand, applying the classifier to only a few regions of the image significantly speeds up the detection process, which allows for the monitoring of substantially larger areas of a site without an appreciable time lag. On the other hand, this preselection excludes most image regions not containing any construction workers from the further detection process, which greatly reduces false positive detections in advance.

To find regions of motion, moving foreground objects have to be separated from the background. For this, Gaussian mixture models are frequently used, and well-established methods for estimating a scene's background are available [8]. We use an improved adaptive Gaussian mixture model (GMM) approach [29], because it is insensitive to background movements that often occur in outdoor scenarios, e.g., wobbling bushes or leaves blown by the wind [30]. To adapt to changes in the scene, such as varying illumination conditions or newly positioned and removed static objects, the mixtures of Gaussians are learned and adjusted over time. Accordingly, the Gaussians' parameters, their number per pixel, as well as the learning rate are updated online by the adaptive approach while applying the model to consecutive video frames.

Comparing the current frame (see Figure 4a) to the learned background model results in a binary segmentation image which indicates fore- and background pixels. In the next step, all connected foreground pixels are aggregated to regions of motion by blob detection. For this purpose, an algorithm is used that finds contours in the binary image [31]. In a further step, rectangles enclosing each contour are formed. These rectangles identify the RoIs for the further detection process, as shown in Figure 4b.

Since the background subtraction identifies the relevant areas within the image which are likely to contain pedestrian construction workers, the detector only has to focus on those image regions. Each of those RoIs contains a single moving object in the scene. Distinguishing between a worker and any other object can be done by a binary classification for each RoI.
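The RoI extraction stage can be sketched as follows: given a binary foreground mask produced by the background subtraction, connected foreground pixels are grouped into blobs, and each blob's bounding rectangle becomes an RoI. This is a simplified stand-in (a flood fill instead of the contour algorithm of [31]), and the function name is hypothetical:

```python
# Toy sketch of RoI extraction: group foreground pixels of a binary
# mask into connected blobs and return one bounding box per blob.
def extract_rois(mask):
    """mask: 2-D list of 0/1. Returns bounding boxes (x0, y0, x1, y1)."""
    h, w = len(mask), len(mask[0])
    seen = [[False] * w for _ in range(h)]
    rois = []
    for sy in range(h):
        for sx in range(w):
            if mask[sy][sx] == 1 and not seen[sy][sx]:
                # Flood-fill one connected blob (4-connectivity).
                stack = [(sy, sx)]
                seen[sy][sx] = True
                x0, y0, x1, y1 = sx, sy, sx, sy
                while stack:
                    y, x = stack.pop()
                    x0, y0 = min(x0, x), min(y0, y)
                    x1, y1 = max(x1, x), max(y1, y)
                    for ny, nx in ((y - 1, x), (y + 1, x),
                                   (y, x - 1), (y, x + 1)):
                        if 0 <= ny < h and 0 <= nx < w and \
                           mask[ny][nx] == 1 and not seen[ny][nx]:
                            seen[ny][nx] = True
                            stack.append((ny, nx))
                rois.append((x0, y0, x1, y1))
    return rois
```

On a toy mask containing two separate blobs, the function returns one bounding rectangle per blob, which would then be passed to the classifier.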

As described above, construction workers in top-view images are characterized best by their motion, color and shape. Since we already use motion to find candidate regions, the classification of color and shape features provides a sound basis for decision-making. Following this, we use color histograms as they are simple but effective color feature descriptors. The histograms are computed on the hue and saturation channels of the HSV color space. Since the value channel encodes brightness, it is neglected in favor of the histograms' invariance to changes in illumination conditions [8], which is a common issue in outdoor scenes. For determining shape information, we decided on Haar-like features. The choice of these features allows us to design both feature descriptors as low-order integral channel features [32]. This guarantees the efficient computation of the feature responses, which increases the detection speed. However, it remains unclear how to arrange those feature descriptors' positions and sizes on an image patch so that they respond optimally to a construction worker.
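The efficiency argument behind integral channel features can be illustrated in a few lines: once an integral image has been built, any rectangle sum, and hence a two-rectangle Haar-like feature, costs only four array lookups regardless of its size. This is a generic sketch with hypothetical helper names, not the paper's C++ code:

```python
# Integral image: ii[y][x] holds the sum of all pixels above-left of
# (x, y), so any rectangle sum is four lookups, independent of size.
def integral_image(img):
    h, w = len(img), len(img[0])
    ii = [[0] * (w + 1) for _ in range(h + 1)]
    for y in range(h):
        for x in range(w):
            ii[y + 1][x + 1] = (img[y][x] + ii[y][x + 1]
                                + ii[y + 1][x] - ii[y][x])
    return ii

def rect_sum(ii, x, y, w, h):
    # Sum over img[y:y+h][x:x+w] via four integral-image lookups.
    return ii[y + h][x + w] - ii[y][x + w] - ii[y + h][x] + ii[y][x]

def haar_two_rect(ii, x, y, w, h):
    # Left-right two-rectangle Haar feature: left half minus right half.
    half = w // 2
    return rect_sum(ii, x, y, half, h) - rect_sum(ii, x + half, y, half, h)
```

A two-rectangle feature spanning a patch whose left half is bright and right half is dark then yields a strongly positive response.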


**Figure 4.** Background subtraction example: (**a**) video frame of walking construction workers; and (**b**) background image with white foreground pixels aggregated to blobs of motion (red) and the corresponding RoIs (green) for further processing.

For this reason, we apply a soft cascaded classifier [33] which learns the optimal set of features using AdaBoost [34]. For this, we deduce weak classifiers from the feature descriptors by thresholding their responses. Then, we generate a weak classifier pool containing thresholded feature descriptors at all conceivable sizes and positions in an image patch. During training, AdaBoost iteratively draws from the pool the weak classifier which best separates the set of samples. This way, a strong classifier emerges from the set of chosen weak classifiers. Figure 5 visualizes the process by the example of the first five Haar-like feature descriptors chosen by AdaBoost.
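One boosting round of this selection process can be sketched as follows: each pool entry is a thresholded feature response (a decision stump), the stump with the lowest weighted error is drawn, and the sample weights are updated so that misclassified samples count more in the next round. The stump responses here are pre-computed ±1 values; this is an illustrative reduction of AdaBoost, not the paper's implementation:

```python
import math

def best_stump(responses, labels, weights):
    """responses[j][i]: +1/-1 output of stump j on sample i.
    Returns the index and weighted error of the best stump."""
    best_j, best_err = None, float("inf")
    for j, resp in enumerate(responses):
        err = sum(w for r, y, w in zip(resp, labels, weights) if r != y)
        if err < best_err:
            best_j, best_err = j, err
    return best_j, best_err

def adaboost_round(responses, labels, weights):
    j, err = best_stump(responses, labels, weights)
    alpha = 0.5 * math.log((1 - err) / max(err, 1e-12))  # stump weight
    # Re-weight: misclassified samples gain weight, correct ones lose it.
    new_w = [w * math.exp(-alpha * y * responses[j][i])
             for i, (w, y) in enumerate(zip(weights, labels))]
    z = sum(new_w)
    return j, alpha, [w / z for w in new_w]
```

Repeating this round, with each chosen stump removed from (or re-weighted within) the pool, yields the sequence of weak classifiers that forms the strong classifier.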

**Figure 5.** Iterative feature selection by AdaBoost (from left to right). A shape model is learned from Haar-like features. The first three iterations approximate the worker's head while the following two focus on the worker's shoulders.

By a subsequent calibration, the weak classifiers within the strong classifier are arranged into cascading stages. A sample is only passed further along the cascade if the evidence of belonging to the positive class is sufficiently strong. As a result, negative samples are rejected early in the cascade which drastically speeds up the classification process for a vast number of samples to be verified.
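The early-rejection behavior of the calibrated cascade can be sketched as follows: the boosted score is accumulated weak classifier by weak classifier and checked against a per-stage rejection threshold, so samples with insufficient positive evidence exit the cascade early. The weights and thresholds below are made up for illustration:

```python
# Soft-cascade evaluation sketch: accumulate the boosted score and
# reject as soon as it drops below the stage's rejection threshold.
def soft_cascade_classify(responses, alphas, rejection_thresholds):
    """responses[i]: +1/-1 output of weak classifier i on the sample.
    Returns (is_positive, stages_evaluated)."""
    score = 0.0
    for i, (r, a, thr) in enumerate(zip(responses, alphas,
                                        rejection_thresholds)):
        score += a * r
        if score < thr:          # evidence too weak: reject early
            return False, i + 1
    return True, len(responses)
```

Because most candidate regions are clear negatives, the average number of evaluated weak classifiers per sample is far below the cascade's full length, which is what makes the classification of many RoIs fast.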

#### *2.3. Hyperparameter Setting of YOLOv3*

Training YOLOv3 requires the adjustment of several hyperparameters. Despite extensive use of YOLO's built-in data augmentation features, the training dataset is too small for training a reasonable detector from scratch. For this reason, we base our training on the Darknet53 model, which has been pre-trained on ImageNet [35]. To train YOLO on pedestrian workers, we pass the 1000 sample images of our training dataset to the network in mini-batches of 64. In the beginning, we use a high learning rate of 0.001 in order to quickly adjust the detector towards the new class. Every 3800 epochs, the learning rate is scaled down by a factor of 0.1, helping the learning process converge towards an optimal result. For regularization, we adapt the weights with a momentum of 0.9 and a weight decay of 0.0005. After about 10,400 epochs, the validation error stops decreasing, so the training is stopped.
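The step schedule described above can be written down compactly. The following sketch only illustrates the policy (in the actual Darknet configuration, such a schedule would typically be expressed via the framework's own configuration entries):

```python
# Step learning-rate schedule: start at 1e-3 and multiply by 0.1
# every 3800 epochs, matching the training description above.
def learning_rate(epoch, base_lr=1e-3, step=3800, gamma=0.1):
    return base_lr * gamma ** (epoch // step)
```

At epoch 0 this gives 0.001, at epoch 3800 it drops to 0.0001, and by the stopping point around epoch 10,400 it has reached 0.00001.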

#### *2.4. Tracking Using Kalman Filtering*

Knowing the current position of construction workers as provided by a detector can already be advantageous for productivity management and safety applications. However, this can be further improved by tracking the workers over time. This allows keeping track of entire workflows as well as anticipating the workers' movement directions.

Modern tracking approaches rely on a motion model which predicts an object's trajectory and an appearance model to recognize the tracked object in the following video frames. However, it is sufficient for our purpose to rely on a motion model only: the applied detector compensates for the omitted appearance model by implicitly taking over the recognition task.

In this paper, we apply Kalman filtering. Its simplicity and robustness make it a good choice for comparing the performance of the two proposed detection methods. The Kalman filter is based on a motion model which only describes the relationship between the tracked object's current state and its predicted state at the next time step. For this, the model consists of the tracked object's current state and a prediction matrix which models the transition from one time step to another. A tracked worker's current state [*x*, *y*, *vx*, *vy*] can be described by the worker's position *x*, *y* and velocity *vx*, *vy* in the *x* and *y* directions. To predict the worker's state in the following time step *t* + 1, the prediction matrix is applied to the current state at time step *t* with a time duration Δ*t*,

$$
\begin{pmatrix} x_{t+1} \\ y_{t+1} \\ v_{x,t+1} \\ v_{y,t+1} \end{pmatrix} = \begin{pmatrix} 1 & 0 & \Delta t & 0 \\ 0 & 1 & 0 & \Delta t \\ 0 & 0 & 1 & 0 \\ 0 & 0 & 0 & 1 \end{pmatrix} \begin{pmatrix} x_{t} \\ y_{t} \\ v_{x,t} \\ v_{y,t} \end{pmatrix}. \tag{1}
$$

The predicted state is then set as the new current state of the Kalman filter's model. Afterwards, the state can be updated using the actual measurements of the worker's position [*xt*, *yt*]. In our case, the detected construction workers serve as both tracker initializations and observed measurements. The detections are spatially correlated with already existing tracks: detections in close spatial proximity to a track are assigned to it, while a new track is set up for detected workers not matching any pre-existing track. The Kalman filter then predicts the location of a worker frame by frame. Detections close to the predicted locations, which we assign to the track, give evidence to correct the Kalman filter's motion model and prevent the predictions from drifting off.

To determine optimal values for the noise covariance matrices *Q* and *R*, we performed a grid search using both detection algorithms. Values on the matrices' diagonals in the range of [1, 50] were considered for this. We found that the optimal values for the diagonals of *Q* and *R* are 1 and 50, respectively.
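A per-axis version of this constant-velocity filter fits in a few lines. The sketch below tracks a single coordinate with the state [position, velocity], a position-only measurement (H = [1, 0]) and the diagonal values found by the grid search (Q = 1, R = 50); an identical filter runs on the other axis. This is an illustrative reduction in Python, not the paper's C++ implementation:

```python
# Constant-velocity Kalman filter for one axis.
# State x = [pos, vel]; P is the 2x2 state covariance.
def kalman_predict(x, P, dt=1.0, q=1.0):
    # Prediction matrix F = [[1, dt], [0, 1]]; process noise Q = q * I.
    x = [x[0] + dt * x[1], x[1]]
    p00 = P[0][0] + dt * (P[1][0] + P[0][1]) + dt * dt * P[1][1] + q
    p01 = P[0][1] + dt * P[1][1]
    p10 = P[1][0] + dt * P[1][1]
    p11 = P[1][1] + q
    return x, [[p00, p01], [p10, p11]]

def kalman_update(x, P, z, r=50.0):
    # Measurement z observes position only (H = [1, 0]); noise R = r.
    s = P[0][0] + r                        # innovation covariance
    k0, k1 = P[0][0] / s, P[1][0] / s      # Kalman gain
    y = z - x[0]                           # innovation
    x = [x[0] + k0 * y, x[1] + k1 * y]
    P = [[(1 - k0) * P[0][0], (1 - k0) * P[0][1]],
         [P[1][0] - k1 * P[0][0], P[1][1] - k1 * P[0][1]]]
    return x, P
```

Feeding the filter equally spaced detections of a worker moving at constant speed drives the velocity estimate towards the true value, so the prediction step alone can bridge frames in which the detector misses the worker.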

### **3. Optimization**

Although some applications such as workflow optimization are commonly executed asynchronously to the recordings, other applications such as the assistance of machine operators require up-to-date worker trajectories. Accordingly, a reasonable approach to the monitoring of pedestrian construction workers demands real-time capability, so that workers' locations and trajectories are provided for every video frame delivered by the camera. To properly compare the CNN-based and classical computer vision approaches, we optimize their detection quality while maintaining a sufficient speed with regard to the camera's specification.

The detection speed of the soft cascaded classifier is mainly determined by the number of weak classifiers constituting the strong classifier. They define the depth of the cascade and, thus, the number of feature evaluations required to correctly classify a given sample. In addition, the number of training samples and the specific feature variants provided in the feature pool strongly affect this method's detection quality. In Section 3.1, we optimize the soft cascaded classifier with regard to these parameters and discuss the chosen values.

YOLO's detection speed is restricted by a single hyperparameter: the size of the input images. The larger the input images, the more convolution operations are involved, which results in substantially slower processing. Accordingly, Section 3.2 addresses the optimal choice of the input image size for a proper detection with respect to real-time capability.

#### *3.1. Optimization of the Soft Cascaded Classifier*

A considerably high precision of the detector is desirable in order to initialize the trackers with actual construction workers only [8]. This prevents the initialization of false positive tracks as well as tracker updates based on false detections. However, to be suitable for a monitoring application, the detector must first and foremost be able to identify workers in real time. To obtain a satisfying classifier in terms of detection speed with a preferably high precision at an acceptable recall, we conduct a grid search by varying different training parameters.

In preliminary experiments, we already found that subdividing the color histograms into five bins is sufficient for the classifier to properly recognize the characteristics of a construction worker's helmet and safety vest. Other parameters to be investigated include the minimal size of the feature descriptors provided by the feature pool. Feature descriptors that are too small may impair the classification: their sensitivity to noise increases with decreasing size. The same holds for slight translations, rotations and scalings of the object to be detected within the RoI provided by the background subtraction, by which too small feature descriptors may easily be positioned off the corresponding feature. On the other hand, large feature descriptors may miss small features. For these reasons, we investigate the effect of the feature descriptors' sizes on the detection quality. We vary the minimal feature size of the feature descriptors provided in the feature pool for training between 10% and 30% of the given image patch size. Additionally, we vary the quantity of training samples from 200 to 2000 to find the optimal generalization behavior. Finally, to speed up the recognition process, we determine the minimum number of weak classifiers required for a reliable classification. For that, we vary the number of weak classifiers constituting our cascade from 50 to 350. To compare the detection results while varying a parameter, we juxtapose the classifiers' receiver operating characteristic (ROC) curves. For this, we apply the trained classifiers to the validation dataset and plot their true positive rate (TPR) against their false positive rate (FPR) with respect to the particular classifier's confidence. Figure 6 illustrates the ROC curves for all three parameters. In addition, we calculate the accuracy of the classifier while varying the number of weak classifiers. During the variation, we apply the classifier to both the validation and the training dataset and plot the accuracy curves. The results are shown in Figure 7.
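A single TPR/FPR point of such a ROC curve is obtained by thresholding the classifier confidences; sweeping the threshold over all confidence values traces the full curve. A minimal sketch (labels use 1 for worker, 0 for background; names are illustrative):

```python
# Compute one ROC point: the (TPR, FPR) pair obtained when all samples
# with confidence >= threshold are declared positive.
def roc_point(confidences, labels, threshold):
    tp = sum(1 for c, y in zip(confidences, labels)
             if c >= threshold and y == 1)
    fp = sum(1 for c, y in zip(confidences, labels)
             if c >= threshold and y == 0)
    pos = sum(1 for y in labels if y == 1)
    neg = len(labels) - pos
    return tp / pos, fp / neg    # (TPR, FPR)
```

Lowering the threshold moves the point towards (1, 1): more workers are found, but more background regions are falsely accepted as well.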

As can be derived from Figure 6a, providing a feature pool in which feature descriptors have a minimal size of 10% of the image patch size yields the best classification results. The classification quality decreases with increasing minimal feature size. This shows that the classification is not impaired but even improved by rather small feature descriptors. Noise and the worker's positioning seem to have little effect, if any.

The effect of different amounts of training samples is shown in Figure 6b. As can be seen, the classification quality increases up to a total number of 1600 training samples. From then on, the quality starts to decrease again. This suggests that the classifier loses essential features that describe a construction worker in general terms if more than 800 positive and 800 negative samples are used.

As Figure 6c depicts, the general classification quality increases with the number of cascaded weak classifiers. Nevertheless, the ROC curves of classifiers consisting of more than 100 weak classifiers resemble each other closely for very small FPRs. Since a lower FPR corresponds to a higher precision, we focus our further evaluation on those small FPRs to ensure a preferably high precision, which is mandatory for the proper functioning of our monitoring approach. Although the general performance of the largest cascade exceeds that of all others, the classifier consisting of 152 weak classifiers yields the lowest FPR up to a TPR of 96.2%.

The findings in Figure 6c are also supported by the results in Figure 7. Up to 150 weak classifiers, the accuracy of the classifier generally increases. With 152 weak classifiers, it reaches an accuracy of 98% on the validation set. Above 150 weak classifiers, the accuracy decreases slightly on the validation set while increasing slightly on the training set. This suggests slight overfitting.
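The speed benefit of limiting the number of weak classifiers comes from the soft cascade's early-rejection mechanism. The sketch below illustrates the general principle of a soft cascaded classifier; the response values, threshold calibration, and function names are illustrative assumptions, not the authors' code.

```python
def soft_cascade_predict(weak_responses, rejection_thresholds):
    """Evaluate a soft cascade on one sample.

    weak_responses: per-weak-classifier scores h_t(x) in evaluation order
    rejection_thresholds: calibrated threshold r_t per stage; the sample
    is rejected as soon as the running score falls below r_t.
    Returns (accepted, stages_evaluated).
    """
    score = 0.0
    for t, (h, r) in enumerate(zip(weak_responses, rejection_thresholds), 1):
        score += h
        if score < r:
            return False, t  # early rejection saves computation
    return True, len(weak_responses)
```

Most negative image regions are rejected after only a few stages, which is why halving the cascade length roughly halves the classification time while the TPR on accepted samples remains acceptable.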

Based on these findings, we conclusively train a soft cascaded classifier for the application in our approach to the monitoring of workers on construction sites. We provide a pool of feature descriptors consisting of Haar-like features and five-bin color histograms. We add feature descriptors beginning with a size corresponding to 10% of the samples' sizes and iteratively increase their sizes by a further 10% up to 100% of the samples' sizes. For training, we initially provide 800 positive and 800 negative samples, randomly chosen from our training dataset. We then let AdaBoost choose the 152 weak classifiers during training which classify the set of training samples best. This halves the time required for the later classification process while retaining an acceptable detection rate. After calibrating this soft cascaded classifier and integrating it into our classical detection setup, we are able to process images at about 28 fps, which even exceeds the frame rate of the camera used for our experiments.
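The size schedule of the feature pool described above can be sketched as follows. The 10%-step size progression follows the text; the half-overlap position grid is an assumption made for brevity and may differ from the actual pool generation.

```python
def build_feature_pool(patch_w, patch_h, steps=10):
    """Enumerate candidate descriptor windows (x, y, w, h) for the pool.

    Window sizes grow from 10% to 100% of the patch in 10% steps.
    Each window would host a Haar-like feature or a five-bin
    color-histogram descriptor.
    """
    pool = []
    for i in range(1, steps + 1):
        frac = i / steps                      # 10%, 20%, ..., 100%
        w, h = int(patch_w * frac), int(patch_h * frac)
        for y in range(0, patch_h - h + 1, max(1, h // 2)):
            for x in range(0, patch_w - w + 1, max(1, w // 2)):
                pool.append((x, y, w, h))
    return pool
```

AdaBoost then selects the 152 most discriminative descriptors from such a pool during training.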

**Figure 6.** ROC curves indicating the effect of different hyperparameters on the detection quality: (**a**) minimal feature size; (**b**) number of training samples; (**c**) number of weak classifiers.

**Figure 7.** Calculated accuracy of the classifier when varying the number of weak classifiers.

#### *3.2. Optimization of YOLOv3*

Prior experiments showed that YOLO's detection results are considerably robust with respect to changes in hyperparameters such as the learning rate, momentum or weight decay. According to these experiments, an optimization of these parameters would not significantly improve the detection results. Therefore, we use the already empirically collected values in our work and do not perform any fine tuning. The most important factor, however, is the size of the input image, since it influences both the detection speed and the precision of the detector. Thus, we conduct a grid search to find the optimal input image size for our purpose. We measure the detection speed in milliseconds (ms) for each image size and calculate the mean Average Precision (mAP) to determine the detection accuracy. Due to YOLO's downsampling architecture, the input image size should be a multiple of 32 px in width and height. We begin the grid search at a factor of 13 for both image dimensions, resulting in an image size of 416 × 416 px, and gradually increase the factor by 3 up to a final input image size of 800 × 800 px. All these instances of YOLO are trained on the entire training dataset described in Section 2.1 and are evaluated on the validation dataset.
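The grid over candidate input sizes can be sketched as follows; `evaluate` stands in for training and validating a YOLO instance at a given square input size and is assumed, for illustration, to return an (mAP, ms per image) pair.

```python
def grid_search_input_size(evaluate, factors=range(13, 26, 3)):
    """Sketch of the input-size grid search described in the text.

    Factors 13, 16, 19, 22, 25 give the sizes 416, 512, 608, 704
    and 800 px, each a multiple of YOLO's downsampling stride of 32.
    """
    results = {}
    for f in factors:
        size = 32 * f                  # width = height = 32 * factor
        results[size] = evaluate(size)
    return results
```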

We found that the detection accuracy generally rises with an increasing image size, although the detection speed drops simultaneously (see Table 1). At the maximum image size of 800 × 800 px, YOLO yields the best detection accuracy. The processing time for a single image, however, is 97 ms, which corresponds to about 10 fps when processing each image of a video. With a processing time of 46 ms per image (22 fps), even the minimal input image size is too slow to run the detection on every single frame of our video sequences in real-time. Since our camera captures frames at 25 fps, we decided to use an input image size of 608 × 608 px as a tradeoff. This way, we are able to process at least every second frame with a desirable precision.


**Table 1.** The results of detection accuracy and detection speed for different image sizes.

### **4. Results**

In our experiments, we examined the detection quality as well as the capabilities of the approaches, as the basis of a tracking system. To determine their detection quality, we applied both optimized approaches to our test data (see Section 4.1). Afterwards, we used both detection methods for tracker initialization and the recognition of already tracked workers. We examined the precision of the resulting tracks as well as the general ability to accurately identify construction workers, as shown in Section 4.2.

#### *4.1. Detection Quality*

To contrast the performances of the optimized detection methods introduced in this paper, we applied both methods to the evaluation dataset described in Section 2.1. For the comparison, we report the precision and recall of each method on this dataset. Since tracking the workers in a centimeter-perfect manner is usually not required, we defined an Intersection over Union (IoU) value of at least 0.6 as sufficient to indicate a true positive detection. On our evaluation data, the classical approach exhibits a recall of 96.2% with 99.8% precision. In contrast, YOLO only achieves a recall of 88.2% with a similar precision of 99.2% using a confidence threshold of 0.9. Even reducing the threshold to 0.5 results in a recall of only 93.5% while precision drops to 97.0%.
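The IoU criterion used above can be made concrete with a short sketch. The (x, y, w, h) box convention is an assumption for illustration; the matching criterion follows the text.

```python
def iou(box_a, box_b):
    """Intersection over Union of two axis-aligned boxes (x, y, w, h).

    A detection counts as a true positive here if
    iou(detection, ground_truth) >= 0.6.
    """
    ax, ay, aw, ah = box_a
    bx, by, bw, bh = box_b
    iw = max(0, min(ax + aw, bx + bw) - max(ax, bx))  # overlap width
    ih = max(0, min(ay + ah, by + bh) - max(ay, by))  # overlap height
    inter = iw * ih
    union = aw * ah + bw * bh - inter
    return inter / union if union else 0.0
```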

#### *4.2. Tracking*

Keeping track of the workers is the main purpose of monitoring applications. Accordingly, the tracking should be as accurate as possible to provide satisfying results. The detectors of such monitoring systems are the key components for reliable tracking as their detections serve as initializations and updates for the tracks. In this experiment, we compared the performances of both detectors developed in this paper with respect to the aforementioned requirements. We applied Kalman filtering to their detections to implement a simple and easily evaluated tracking system.

Each generated track uses a bounding box of the same fixed dimensions during evaluation. This size was set manually in advance such that the rectangle completely encloses a construction worker in the center of the image. If the detector does not detect a construction worker who is already tracked, the Kalman filter determines the next position using the prediction only. This allows tracking a construction worker who is in motion but temporarily occluded, for example, by an object. Since the accuracy of the position determination decreases without corrections of the motion model, these predictions are executed for a maximum of one second. After that, the track is erased and a new one is set up once the detector detects the construction worker again. The prediction is also used to mark the direction of the worker's movement in the image. For this, the worker's position in the next frame is predicted. This makes it easier to identify construction workers moving towards a hazardous area.
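A minimal version of this predict/update scheme can be sketched as a constant-velocity Kalman track with a one-second coast limit. All noise covariances, the class name, and the state layout are illustrative assumptions; only the constant-velocity model, the fixed 25 fps frame rate and the one-second coast limit follow the text.

```python
import numpy as np

class WorkerTrack:
    """Constant-velocity Kalman track; state = [x, y, vx, vy]."""

    def __init__(self, x, y, fps=25, max_coast_s=1.0):
        self.x = np.array([x, y, 0.0, 0.0])
        self.P = np.eye(4) * 10.0                 # initial uncertainty
        self.F = np.eye(4)                        # motion model
        self.F[0, 2] = self.F[1, 3] = 1.0 / fps
        self.H = np.eye(2, 4)                     # we only observe (x, y)
        self.Q = np.eye(4) * 0.01                 # process noise (assumed)
        self.R = np.eye(2) * 1.0                  # measurement noise (assumed)
        self.misses, self.max_misses = 0, int(max_coast_s * fps)

    def predict(self):
        self.x = self.F @ self.x
        self.P = self.F @ self.P @ self.F.T + self.Q
        return self.x[:2]                         # predicted center

    def update(self, z):
        """z = detected center (x, y), or None on a missed detection.
        Returns False once the track should be erased."""
        if z is None:                             # coast on prediction only
            self.misses += 1
            return self.misses <= self.max_misses
        self.misses = 0
        y = np.asarray(z, float) - self.H @ self.x
        S = self.H @ self.P @ self.H.T + self.R
        K = self.P @ self.H.T @ np.linalg.inv(S)
        self.x = self.x + K @ y
        self.P = (np.eye(4) - K @ self.H) @ self.P
        return True
```

At 25 fps, the coast limit of one second corresponds to 25 consecutive prediction-only steps before the track is erased.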

The precision and continuity of the underlying detector primarily determine the accuracy and robustness of the tracks. To measure these properties, we relied on the approach of Xiao and Zhu [36], who proposed the average sequence overlap score (AOS) and center location error ratio (CER) as measures of accuracy, and the track length (TL) as a measure of robustness. Ambiguity errors caused by tracks crossing each other are not taken into account since this is a property of the tracker rather than of the detector. We adapt these metrics to fit our experimental setup, as shown in Equations (2)–(4).

$$\text{AOS} = \frac{1}{n} \sum_{t=1}^{n} \frac{A_t^G \cap A_t^T}{A_t^G \cup A_t^T} \,, \tag{2}$$

$$\text{CER} = \frac{1}{n} \sum_{t=1}^{n} \frac{\|\mathbf{C}_t^G - \mathbf{C}_t^T\|_2}{\text{size}(A_t^G)} \,, \tag{3}$$

$$\text{TL} = \frac{n}{N} \,, \tag{4}$$

where *t* represents the current time step, *n* denotes the number of video frames in which a worker is tracked and *N* is the number of frames in which the worker is present. The workers' bounding box areas are indicated by *A* and their centers by **C**, where the superscripts *G* and *T* mark ground truth and tracked bounding boxes, respectively. Finally, $\|\cdot\|_2$ denotes the Euclidean norm and $\text{size}(\cdot)$ represents the size of an area.
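Equations (2)–(4) can be computed per track as sketched below. Taking size() as the ground truth box diagonal is an assumption made for the sketch, since the exact normalization is not spelled out here; the box convention (x, y, w, h) is likewise illustrative.

```python
import numpy as np

def tracking_metrics(gt_boxes, tr_boxes, frames_present):
    """AOS, CER and TL following Equations (2)-(4).

    gt_boxes, tr_boxes: per-frame (x, y, w, h) boxes for the n frames
    in which the worker is tracked; frames_present is N.
    """
    aos, cer = [], []
    for (gx, gy, gw, gh), (tx, ty, tw, th) in zip(gt_boxes, tr_boxes):
        iw = max(0, min(gx + gw, tx + tw) - max(gx, tx))
        ih = max(0, min(gy + gh, ty + th) - max(gy, ty))
        inter = iw * ih
        union = gw * gh + tw * th - inter
        aos.append(inter / union if union else 0.0)     # Eq. (2) summand
        cg = np.array([gx + gw / 2.0, gy + gh / 2.0])   # GT box center
        ct = np.array([tx + tw / 2.0, ty + th / 2.0])   # tracked center
        cer.append(float(np.linalg.norm(cg - ct)) / np.hypot(gw, gh))
    n = len(aos)
    return float(np.mean(aos)), float(np.mean(cer)), n / frames_present
```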

For an eligible comparison of their performances, we evaluated the tracking systems, each consisting of one of the detectors combined with Kalman filtering, on identical scenes of construction sites. These scenes were chosen as described in Section 2.1. None of them was shown to either detector before, neither during training nor during calibration and validation. In the following, we refer to Scene 1 as the well-illuminated scene of four pedestrian workers from the first video sequence, while Scene 2 denotes its overexposed counterpart from the other sequence. The scene showing a car approaching a worker is referred to as Scene 3. The tracking results using the classical detection approach are given in Section 4.2.1. Section 4.2.2 discusses the results of the tracker relying on YOLOv3.

#### 4.2.1. Tracking Results Using a Classical Machine Learning Detector

We applied the tracking system based on the classical machine learning detector to all three evaluation scenes. To each resulting track, a random ID was assigned. Further, we determined the AOS, CER and TL for each track separately. The performance of this system on each of the three evaluation scenes is summarized in Table 2.

**Table 2.** Results of the classical machine learning tracking system applied to all three evaluation scenes. For each scene the resulting measures according to Equations (2)–(4) are given. Tracking IDs are assigned randomly to each particular track. The last row highlights the performance of this system averaged over all scenes.


The results indicate a significant decrease in tracking quality for overexposed scenes. This becomes clear when comparing the results of Track 5 with those of the other tracks of the first scene. As shown in Figure 8, Track 5 is located in a highly overexposed section of the scene, whereas the other tracks traverse well-illuminated sections only. This complicates the detection, and the tracking consequently results in an inaccurate location, as illustrated in Figure 8a. This figure shows the AOS measure by comparing ground truth data in green to the actual tracks in red, and the CER measure by blue lines indicating the distance between the labels' centers. As can be seen, Tracks 0 and 8 in the well-illuminated area match the ground truth closely and their centers are also close to each other. In contrast, the overlapping area of Track 5 in the highly overexposed section and its corresponding ground truth label is substantially smaller, whereas the distance between their centers is considerable. Such deviations during the tracking impair its accuracy, which becomes visible from the jagged walking path determined for Track 5 in Figure 8b. This is also supported by the TL of only 0.92. Workers in overexposed areas are harder to detect so that the tracking starts with a delay, which results in a shortened track length.

In Scene 2, these effects are much less pronounced since the scene is only slightly overexposed. As the comparison of the AOS shows, the difference in the labels' overlap between Scenes 1 and 2 is only 0.025 on average, and, even if the outlier (Track 5) in Scene 1 is disregarded, the difference rises to only about 0.086. The average TL in this scene also decreases only moderately compared to Scene 1. By depicting an example of Scene 2, Figure 9a illustrates that the quality of our tracking approach is sufficient even on slightly overexposed scenes. This finding is supported by the TL and CER measurements; neither varies significantly from Scene 1 to Scene 2. All construction workers are consistently tracked almost throughout the entire duration of both scenes. By applying the tracking system to Scene 3, we show its behavior in the presence of moving objects which are not pedestrian construction workers. The results for Scene 3 in Table 2 show that only one track is set up. As graphically confirmed by Figure 9b, this track belongs to the only worker in this scene. The worker is tracked successfully and no further tracks are mistakenly generated for the non-worker objects. Again, the AOS achieved by the classical system is within an acceptable range, and the CER and TL confirm that the worker is tracked consistently throughout the whole scene's duration.

**Figure 8.** Example of the tracking performance of the classical approach on Scene 1. (**a**) Illustration of the AOS and CER measures. Ground truth labels (green) are compared to the actual tracks (red). Blue lines indicate the distance between the labels' centers. (**b**) Tracking result. The worker's current positions are given by randomly colored squares. Equally colored lines illustrate their previous walking paths. Arrows show their predicted walking direction.

**Figure 9.** Examples of the classical tracking system applied to different scenes. (**a**) An example of the tracking system applied to Scene 2 shows that the tracking is sufficient even on slightly overexposed scenes. (**b**) On Scene 3, the tracking results show that the worker is tracked successfully while the moving vehicle remains untracked as intended.

#### 4.2.2. Tracking Results Using a Deep Learning Detector

To ensure a proper comparison, the evaluation of the deep learning-based tracking system is done analogously to the evaluation of the classical approach. Again, IDs are assigned randomly to the tracks, and the accuracy and robustness metrics are determined per track. Table 3 summarizes the performance results of the tracking.

**Table 3.** Results of the deep learning-based tracking system applied to all three evaluation scenes. For each scene, the resulting measures according to Equations (2)–(4) are given. Tracking IDs are assigned randomly to each particular track. The last row highlights the performance of this system averaged over all scenes.


The tracking system exhibits a satisfactory performance on average over all sequences. Especially on well-illuminated scenes, a high accuracy is provided, as indicated by the AOS and CER of Scene 1. The system's performance on this scene is very robust since all workers were tracked almost over the entire length of this video sequence. However, Track 1, which is located in a heavily overexposed area, shows that overexposure affects the quality of the tracking. As the TL shows, the track remains robust, but the accuracy measured by the AOS decreases by about 10%. This is also confirmed by a slight increase in the CER. Nevertheless, the system performs very well if the scene is only slightly overexposed, as is the case in Scene 2. Here, the AOS barely indicates any negative effects despite the overexposure. Only the CER shows a minor increase. Similar to the highly overexposed case, the TL measurements again emphasize the tracking system's robustness, although slight overexposure is present across the entire scene.

Besides this, the results of this experiment also show that the system does not confuse construction workers with any other object in the scenes. As Table 3 shows, the correct number of tracks was set for each scene. Alongside the four construction workers, Scenes 1 and 2 include static construction related objects, such as traffic cones and barriers. These stay correctly undetected and untracked during the course of the video sequences. Moving objects such as the car in Scene 3 also remain correctly unnoticed by the tracker.

#### **5. Discussion**

As shown by our experiment regarding the detection quality, both methods outperform the approach of Luo et al. [24] by far. While Luo et al. achieved a precision of 75.1%, our approaches reach 99.8% and 99.2%, respectively. A proper comparison of the results is difficult since the datasets used for training and evaluation differ from each other. However, tests using YOLO with the same input image size as proposed by Luo et al. yield similar results, which indicates at least a weak comparability of the setups. This confirms that the input image size is a crucial parameter for a proper detection when using YOLO, as the objects to be detected shrink proportionally to the size of the images. Nevertheless, this experiment also shows that the results of the classical method exceed those of YOLO by about 5%. As mentioned above, the image size may be one reason for this. The classical approach operates on the original high resolution, exhibiting more detailed features than the drastically downscaled images used by YOLO. A second reason may arise from the limited dataset since it is known that CNNs require a vast amount of training data. The number of provided training samples may have been too small to sufficiently adapt to the class of workers, although we deliberately used a pre-trained network which had already developed general feature descriptors for classification. In contrast, the findings in Section 3.1 highlight that classical computer vision approaches cope significantly better with fewer samples than CNNs.

Both reasons emphasize common practical issues concerning data gathered in the context of construction site monitoring. Since computer vision approaches are rarely applied in the field, dataset generation has rarely been regarded as a subject until now. Beyond this, generating a reasonable dataset covering all environmental and lighting conditions is extremely time-consuming and tedious. Given such a dataset, the small size of the pedestrian workers within the images remains a limiting factor. Thus, choosing a classical detection approach over a CNN is desirable to obtain the best detection results for the purpose of monitoring pedestrian workers.

However, our tracking experiments showed that both approaches yield suitable results. As can be seen from the results of Scene 1, the evaluated systems perform similarly well in a well-illuminated environment. Both approaches exhibit an excellent accuracy with very low CER and high AOS of about 90%. Furthermore, the high TL rates indicate their robustness. The slight deviations from the optimum are mainly due to workers entering and leaving the scene. In these situations, workers are located at the image borders and may be only partly visible. This complicates the detection of the workers so that tracks may be set up late or may terminate early. Apart from these effects, the tracking performed by both approaches is very precise despite the CNN's limitations in detection quality and speed. This shows that even simple tracking methods can compensate for lower detection rates. However, Scene 1 also shows that both approaches have issues with high overexposure. In both cases, the AOS of the respective track drops noticeably, accompanied by a rise in CER. The worker is still tracked sufficiently by both systems but the accuracy suffers dramatically from the overexposure. The CNN-based approach seems to cope with this issue slightly better than the classical system. However, this is not yet a reliable conclusion, as only a single instance of evidence is available in the data. The results of Scene 2 provide further insights into the subject of overexposure. On this scene, the TL rates remain almost stable, indicating a satisfying robustness. This ensures reliable tracking with both tracking systems despite a certain degree of overexposure. However, the CNN-based approach achieves significantly better AOS rates in this slightly overexposed environment, while its CER rates fluctuate more than those of the classical approach. On the one hand, the considerable decrease in AOS shows that the classical system has difficulties in identifying the worker's dimensions precisely under overexposure. This is also supported by the findings regarding the overexposed Track 5 of Scene 1. At the same time, its stable CER rates indicate that the workers' centers, and thereby their locations, are recognized with high accuracy. On the other hand, the CNN-based system reveals its weakness concerning a precise localization, as shown by the varying CER. Instead, it accurately determines a worker's dimensions despite aggravated conditions. With this knowledge, the size of the detected area could in the future be passed to the tracker so that the worker is for the most part completely marked on the image. For the classical approach, a fixed size of the tracking box, determined in advance depending on the camera height, would be more suitable.

This points out that both approaches possess certain pros and cons. Under ideal conditions both perform similarly well but as conditions change, they start focusing on different aspects. While the classical approach precisely keeps track of the worker's locations, the CNN-based system accurately recognizes their dimensions. Accordingly, a conclusive decision has to be made with respect to the demands of particular applications.

#### **6. Conclusions**

In this study, we investigated whether deep learning methods surpass classical approaches in construction worker monitoring despite their limitations. We chose YOLOv3 and a classical approach based on a soft cascaded classifier as representatives for our comparison. The trained detectors were then embedded in a tracking system to track construction workers in video sequences. To evaluate the tracking systems under various conditions, we generated different video sequences. These contain different environmental and lighting conditions as well as stationary and moving non-worker objects. Both tracking systems were applied to the same sequences to ensure a proper comparison.

As our experiments showed, the classical approach clearly outperforms the CNN on the detection task in terms of quality and speed. The lack of quality is most likely due to an insufficient amount of training data and the heavy downscaling of the images. The low detection speed of substantially less than 22 fps is attributable to the high computational cost. However, the tracking experiment showed that the CNN's drawbacks are fully compensated for even by a simple tracking method. Both approaches showed satisfying results when tracking workers under ideal conditions. They were even able to avoid falsely tracking any stationary or moving non-worker objects. Nevertheless, both tracking systems reveal deficiencies when applied under overexposed conditions. While the CNN keeps precise track of the workers' dimensions, its localization becomes inaccurate. The classical approach behaves exactly vice versa. According to these findings, there is no generally optimal monitoring approach. A suitable approach has to be chosen with respect to the demands of the particular application.

With our example scenarios, we showed that both approaches allow a satisfactory tracking of construction workers from the top view. Nevertheless, it would be an advantage to know whether satisfactory tracking is also achieved with more complex work processes involving different moving machines, different construction site constellations and different weather conditions. Therefore, future work should compare the presented tracking approaches on scenes that have a different focus. In this case, optical flow techniques can additionally be used to enable tracking with cameras mounted, e.g., on the crane boom. The background subtraction would then consider known crane movements and thus take dynamically changing backgrounds into account. Furthermore, future work should concentrate on substantially increasing the size of the datasets. By this, the training of CNN-based approaches can be improved. This helps the detector to develop a better generalization ability, which may advance the tracking results especially under difficult conditions. Reasonable strategies for the extension may be alternative data augmentation techniques as well as computer generated data.

**Author Contributions:** Conceptualization, M.N.; data curation, M.N.; investigation, M.N.; methodology, M.N.; project administration, M.N.; software, M.N. and D.P.; supervision, M.K.; validation, D.P.; visualization, D.P.; writing—original draft, M.N.; and writing—review and editing, M.K. All authors have read and agreed to the published version of the manuscript.

**Funding:** This research received no external funding.

**Conflicts of Interest:** The authors declare no conflict of interest.

### **References**



**Publisher's Note:** MDPI stays neutral with regard to jurisdictional claims in published maps and institutional affiliations.

© 2020 by the authors. Licensee MDPI, Basel, Switzerland. This article is an open access article distributed under the terms and conditions of the Creative Commons Attribution (CC BY) license (http://creativecommons.org/licenses/by/4.0/).

## *Article* **Using Synthetic Data to Improve and Evaluate the Tracking Performance of Construction Workers on Site**

### **Marcel Neuhausen \*, Patrick Herbers and Markus König**

Chair of Computing in Engineering, Ruhr-University Bochum, 44801 Bochum, NRW, Germany; patrick.herbers@ruhr-uni-bochum.de (P.H.); koenig@inf.bi.ruhr-uni-bochum.de (M.K.) **\*** Correspondence: marcel.neuhausen@ruhr-uni-bochum.de

Received: 19 March 2020; Accepted: 16 July 2020; Published: 18 July 2020

**Abstract:** Vision-based tracking systems enable the optimization of productivity and safety management on construction sites by monitoring the workers' movements. However, training and evaluating such a system requires a vast amount of data. Sufficient datasets rarely exist for this purpose. We investigate the use of synthetic data to overcome this issue. Using 3D computer graphics software, we model virtual construction site scenarios. These are rendered for use as a synthetic dataset which augments a self-recorded real world dataset. Our approach is verified by means of a tracking system. For this, we train a YOLOv3 detector identifying pedestrian workers. Kalman filtering is applied to the detections to track them over consecutive video frames. First, the detector's performance is examined when using synthetic data of various environmental conditions for training. Second, we compare the evaluation results of our tracking system on real world and synthetic scenarios. With an increase of about 7.5 percentage points in mean average precision, our findings show that a synthetic extension is beneficial for otherwise small datasets. The similarity of synthetic and real world results allows for the conclusion that 3D scenes are an alternative for evaluating vision-based tracking systems on hazardous scenes without exposing workers to risks.

**Keywords:** construction productivity; construction safety; deep learning; synthetic data; tracking

### **1. Introduction**

Vision-based detection and tracking have already found their way into a wide range of applications. Pedestrian detection has become an essential topic in the modern automotive industry and is also a relevant part of various surveillance systems. Even sports analytics makes use of these techniques for the assistance of referees as well as for the automated generation of game statistics. Despite the huge potential, such approaches are rarely employed in the construction sector today. Especially processes on construction sites could benefit from computer vision methods. As construction sites are complex and continuously changing environments, keeping track of the ongoing processes can be challenging. Various trades are involved at each stage of construction, and different work orders are executed by workers and construction machines simultaneously. On the one hand, this can affect an individual worker's workflow since the complex processes are hard to grasp. On the other hand, workers can easily lose track of the ongoing processes in their surroundings while concentrating on their own tasks. This inadvertence may result in hazardous situations. Excavator drivers may fail to notice unaware pedestrian workers crossing their paths. Likewise, crane operators may lift their loads over workers standing in blind spots.

By monitoring pedestrian workers on site, their current workflows can be optimized [1]. Furthermore, assistance systems can support machine operators in avoiding hazards involving workers [2]. However, a reliable and precise tracking system is a prerequisite for both applications. A multitude of approaches in this field apply tag-based methods where tags are attached to all objects to be monitored. Depending on the specific use case, Radio-Frequency Identification (RFID), Ultra-Wideband (UWB), or Global Navigation Satellite System (GNSS) technology is employed [3]. Approaches using GNSS are usually applied to the localization of equipment, workers, and machines in spacious outdoor construction environments [4,5]. Near large structures such as buildings, walls, and other construction components, GNSS becomes unreliable as it suffers from multipath effects [6]. Tracking workers in the proximity of heavy construction machinery or cranes carrying large loads may thus become too inaccurate for operator assistance systems. Accordingly, RFID- and UWB-based approaches are commonly used to improve safety during construction. By equipping workers with RFID tags, operators are warned if a worker enters the range of a machine's tag reader [7]. Warning workers of entering hazardous areas can be achieved by positioning the workers using UWB tags and readers [8]. Nevertheless, precise localization by means of radio-frequency technologies generally remains challenging [9]. Beyond that, all tag-based approaches involve high costs for the required amount of tags and readers [3]. Additionally, wearing such tags causes discomfort for the workers [10]. Camera-based monitoring approaches overcome these deficiencies. They allow for the tracking of workers even close to large structures as they suffer neither from multipath effects nor from other signal interferences leading to inaccurate positioning. Additionally, they constitute an affordable alternative to tag-based solutions since only a few cameras are required to monitor large areas of construction sites.

In recent years, some effort has already been made to monitor pedestrian workers using computer vision techniques. Firstly, pose estimation is used to recognize specific sequences of movement as well as unsafe actions [11,12]. Secondly, workers are detected in video images using Support Vector Machine (SVM) and k-Nearest Neighbors (k-NN) classifiers [13] and are then tracked over time [14]. More recent approaches employ Convolutional Neural Networks (CNNs) for both detection and tracking purposes [15,16]. However, the detection results of such approaches are in need of improvement. Park and Brilakis [13] state recall rates of 87.1% to 81.4% at precision rates ranging from 90.1% to 99.0%, respectively. Luo et al. [15] report an Average Precision (AP) of 75.1% with an Intersection over Union (IoU) threshold of only 0.5. Although these results are as yet insufficient for the implementation of a reliable safety-relevant system, their considerable potential can already be inferred. To improve the detection and thus allow for an uninterrupted tracking, the training of the underlying machine learning algorithms has to be enhanced. Furthermore, the subsequent evaluation of those systems has to be extended in order to model a broader variety of scenarios which may occur in a real application. This ensures reliable tracking during productive operation even under challenging conditions. Training and evaluation both require a vast amount of data. This especially holds for deep learning approaches as they usually require a multitude of training samples compared to classical machine learning techniques. However, sufficient datasets rarely exist for construction site scenarios, and gathering a comprehensive dataset is time consuming and tedious. Although common data augmentation techniques, such as randomly cropping, flipping, or rotating samples, enable the extension of existing datasets, their effect on the training of a detector is limited. While such augmentation increases the variability in the dataset, adapting the training towards scenarios not contained in the initial dataset is infeasible. Moreover, assembling realistic evaluation scenarios might also require exposing workers to hazards to determine the behavior of a tracking system in such situations.

A solution to the data problem may be synthetic data. Parker [17] defines synthetic data as "any production data applicable to a given situation that are not obtained by direct measurement". While this definition was originally meant for mathematical models, it can also be applied to machine learning datasets. The complexity of synthetic data generation can range from simple mathematical equations to fully simulated virtual environments. Synthetic data was mainly used in software development processes, but has lately found application in machine learning research. Generating required data synthetically might be desirable as synthetic data has the ability to extend existing datasets or to create new datasets with a significant reduction in effort. While traditional datasets have to be aggregated and labeled by hand, synthetic data can be accurately labeled automatically. As the ground truth is already available in a simulated environment, labeling becomes trivial. Synthetic data can be created completely from scratch (fully synthetic data) or based on real datasets for data augmentation (partially synthetic data). For example, Jaderberg et al. [18] created a dataset for text recognition by synthetically generating distorted, noisy images of text from a ground truth dictionary. Gupta et al. [19] improved on the idea of synthetic text recognition by rendering text onto real images. Tremblay et al. [20] used synthetic data as a form of domain randomization where random objects are placed in various environments. The amount of possible combinations enables a CNN to recognize the important part of a picture with minimal effort in dataset generation. The size of a synthetic dataset based on a 3D environment is nearly unlimited, as can be seen in the SYNTHIA dataset [21]. Utilizing dynamic objects and render settings, scenarios can be varied indefinitely. One popular technique is co-opting video games with realistic graphics for various imaging tasks, such as depth estimation [22] or image segmentation [23]. Furthermore, synthetic datasets can be augmented by creating variations of lighting or environmental conditions like rain or fog. Tschentscher et al. [24] created a synthetic dataset of a car park in Unreal Engine 4 where camera angles and weather conditions can be changed at will. Through the use of this dataset, Horn and Houben [25] improved a k-NN based parking space detection algorithm by six percentage points while simultaneously reducing the time needed for generating the dataset.

As this paper focuses on the improvement of the detection and tracking of construction workers, a few particularities have to be taken into consideration. Construction sites constitute more complex and harder-to-model environments than the well-structured car park scenarios in [25]: they are essentially less structured and subject to greater changes. Construction material, machinery, and workers move and interact with each other almost everywhere on the site. In the course of this, occlusions of the workers by construction machines and their loads occur frequently. Additionally, modeling human shapes and behaviors realistically is still a challenging task. Compared to the simple trajectories of cars, a worker's movement is significantly more complex and harder to predict. Dynamic environments such as these require a diversified dataset that covers all aspects of working on a construction site, including dangerous situations which would be unethical to re-enact in a real world scenario.

Lately, image style transfer learning such as Cycle-Consistent Generative Adversarial Networks (CycleGANs) [26] has been used for data augmentation. CycleGANs are able to learn a bidirectional image style mapping between two unlabeled image domains. The two image sets for training the GAN do not need to be paired, but a significant amount of training data is still required. Wang et al. [27] use this technique to enhance their synthetic dataset for crowd counting by augmenting it with image styles from existing datasets. Since the synthetic data in [27] is collected from a video game, the image style is significantly different from actual camera footage. However, the small size and low variation in environmental conditions of common real world datasets of construction sites complicate the training of a GAN. Additionally, this work requires full control over the synthetic scenes' setup, including the image style, to sufficiently model potentially hazardous situations. Thus, this paper does not use CycleGANs for data augmentation.

Hence, the aim of this paper is to investigate the usage of synthetically generated data for improving both the training and the evaluation of construction site monitoring systems. For this, several 3D scenes of construction sites with varying lighting and environmental conditions are modeled. Sequences of these scenes are extracted to extend a self-recorded real world dataset. In order to identify the effects of synthetic data, a monitoring system is built that tracks pedestrian construction workers over time. The system detects workers in consecutive video frames using YOLOv3 [28] and tracks them by Kalman filtering. To prove the concept, the detector's performance is first evaluated when trained on the real world dataset only. Then, the results are compared to the performance when synthetic data is gradually added to the set of training samples. We found that providing well-chosen computer generated scenes to the training set improves the detection performance by 7.5 percentage points in mean average precision (mAP). In the end, the detector achieves a mAP of 94.0 %. If inaccuracies resulting from manual labeling and image scaling are taken into account, the mAP even increases to 99.2 %. In another setting, the tracking performance of the proposed monitoring system on real world and synthetic scenarios is compared. The experiment shows that the tracking system yields similar results on both kinds of datasets. This illustrates that synthetic data provides a reasonable alternative for evaluating the tracking performance on hazardous situations without exposing workers to risks.

The remainder of this paper is organized as follows: Section 2 introduces the datasets as well as the tracking system. Section 3 summarizes two experiments and their results to prove the proposed concept. First, the use of synthetic data for the training of the detector is investigated in Section 3.1. Second, Section 3.2 examines the comparability of real world and synthetic data in terms of the evaluation of a tracking system. The results of these experiments are discussed in Section 4 and an exemplary use case evaluating such a monitoring system is shown. Finally, Section 5 summarizes the findings and provides an outlook to future work.

### **2. Materials and Methods**

In order to investigate the effect of synthetically generated data on the training and evaluation of a vision-based tracking system in the scope of construction sites, we created a dataset from computer generated 3D scenes. This dataset is used to augment a small self-recorded dataset (see Section 2.1) of pedestrian workers walking across a construction site. As described in detail in Section 2.2, we modeled multiple 3D scenarios which depict various lighting and environmental conditions as well as different movement patterns for the represented workers. To evaluate our approach, we built a monitoring system that tracks pedestrian workers in video sequences. The underlying detector's hyperparameter settings are described in Section 2.3 while Section 2.4 outlines the tracking process.

### *2.1. Real World Dataset*

Since video datasets of construction sites including pedestrian workers from a top view perspective are rarely publicly available, we assembled a small generic dataset on our own. For this, we recorded several sequences of walking construction workers using a camera mounted at a height of around 20 m. This results in a bird's-eye-like view in the center of the images whereas the view becomes oblique near the borders. A reasonable dataset should exhibit a large variation in the data samples. This includes different lighting and environmental conditions of the scenes as well as various objects and differences in their appearance. To realize variation in our dataset, we chose two environmental scenarios. These contain different ground textures and lighting conditions as well as different types of construction sites and present objects. While the ground in one scene is uniformly paved (see Figure 1a), the other scene exhibits sandy, paved, and graveled parts, as shown in Figure 1b. The scenes are recorded at different times of day to include different but realistic lighting conditions. This results in well-exposed to overexposed illumination in the scene recorded at noon, whereas long shadows and lateral illumination are present in the scene recorded in the evening. The scenes also vary in their setup and item selection. One shows a roadwork-like setup containing parked and driving cars as well as striped delineators, poles, and garbage cans. The other is a building work-like scene including a construction barrier and barrier tape. In both scenes, there are areas of different illumination conditions ranging from slightly underexposed to highly overexposed. The pedestrian workers wear safety vests and helmets throughout and walk randomly across both scenes, including sudden stops and directional changes. They also interact with and relocate static objects like pylons and barriers.

The resulting dataset is labeled manually to provide ground truth for training and evaluation purposes. For this, workers are identified by rectangular bounding boxes closely framing their heads, shoulders, and safety vests. In Figure 1, these labels are represented by green rectangles.

**Figure 1.** Comparing the different environmental scenarios of the real world dataset in (**a**,**b**) to their 3D generated counterparts in (**c**,**d**), respectively. Ground truth labels are marked as green rectangles.

The dataset is split into two separate subsets. A share of the recorded video sequences is assigned to each subset. This way, the construction worker detector can be trained on one subset while the evaluation is carried out on the other subset, which remains unknown to the detector until then. The shares are assigned in a way that both environmental scenarios are equally distributed within each subset. Each subset contains only every tenth frame of the assigned sequences in order to ensure a sufficient amount of variation among the images. The number of workers visible in each frame varies as workers occasionally entered and left the camera's field of view during recording. Although we do not identify particular workers, the number of visible workers per image and their locations are known since the sequences were manually labeled beforehand. In the following, each visible worker per image is referred to as a worker instance. In conclusion, the subsets are organized as follows: the training set consists of 660 images of both environmental conditions containing 1000 pedestrian worker instances while the evaluation set shows 2750 worker instances in 1300 images of different sections of the recorded sequences.
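
The assembly of the subsets described above can be sketched as follows. This is a hypothetical illustration, assuming sequences are stored as dictionaries with a scenario tag and a frame list; none of these names or the 50/50 split come from the paper.

```python
# Hypothetical sketch: assign whole sequences to the training and
# evaluation subsets so that both environmental scenarios are
# represented in each, then keep only every tenth frame.

def split_sequences(sequences, eval_fraction=0.5):
    """Assign whole sequences to train/eval, balanced per scenario."""
    train, evaluation = [], []
    by_scenario = {}
    for seq in sequences:
        by_scenario.setdefault(seq["scenario"], []).append(seq)
    for seqs in by_scenario.values():
        n_eval = max(1, round(len(seqs) * eval_fraction))
        evaluation += seqs[:n_eval]
        train += seqs[n_eval:]
    return train, evaluation

def subsample(seq, step=10):
    """Keep only every tenth frame to avoid near-duplicate images."""
    return seq["frames"][::step]
```

Splitting by whole sequences (rather than by individual frames) avoids near-identical frames appearing in both subsets.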

### *2.2. Synthetic Data Generation*

The synthetic dataset was built to emulate the environment and technical specifications of the real world dataset. Figure 2 summarizes the synthetic data creation process. Blender 2.80 was used for the creation of the different scenes, which were rendered using the Cycles rendering engine. All scenes were built by hand to depict a varied set of realistic construction sites. Using a physically-based renderer allows for more realistic images than game engines, but also increases render time. For the comparison of real and synthetic data, the two real scenes were recreated (see Figure 1c,d). Overall, 8 different scenes with a total of 3835 frames and 32 tracked subjects were created for this work. Table 1 lists all created scenes, with sample frames shown in Figure 3. Similar to the real world dataset, the synthetic dataset incorporates different lighting and weather conditions, ground surface types, and surrounding clutter. As for the real dataset, this increases the variation of the data samples and models the most common conditions and construction site types. The weather conditions include sunny, cloudy, and rainy. The incidence angle of the sun varies to simulate the different lighting conditions at different times of the day, including long shadows in the morning and evening. Overexposure emulates the effect present in the real world dataset (compare Figure 1a,c). The virtual workers wear different safety vests (in either green, yellow, or orange) and different helmets (in either white or yellow) to ensure variability in the dataset. To ensure that the pose and movement of the virtual workers are realistic, motion capture data from the Carnegie Mellon University Motion Capture Database was used. The motion capture data was rigged onto the construction worker model to create realistic animations.
Virtual workers conduct multiple different actions in the created scenes, including walking, picking up objects, interacting with each other, or repairing or using equipment. Several other moving objects are added, such as cars, forklifts, and construction material. Dangerous situations are also included, where construction material is lifted above workers' heads (see Figure 3a,d). Workers may also be partly occluded by objects.

**Figure 2.** Process for creating synthetic scenes.



Ground truth labels were extracted automatically from the 3D data. For this, a spherical hull encompassing the head and part of the shoulders was placed around each virtual construction worker. The points of the hull for each subject in each frame were transformed into image coordinates, with the bounding box of the image coordinates then serving as the rectangular label. The labeling process was thus carried out completely automatically from the source 3D data without any manual effort. Resulting labels are depicted in Figure 1.
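
The projection step can be sketched with a pinhole camera model. The intrinsic matrix `K` and the world-to-camera transform `(R, t)` below are illustrative placeholders; the paper does not report the actual Blender camera parameters.

```python
import numpy as np

def project(points_3d, K, R, t):
    """Project 3D hull points (world frame, N x 3) into image
    coordinates using a pinhole camera with intrinsics K and
    extrinsics (R, t)."""
    cam = (R @ points_3d.T).T + t        # world -> camera frame
    uv = (K @ cam.T).T                   # camera -> image plane
    return uv[:, :2] / uv[:, 2:3]        # perspective division

def bbox(uv):
    """Rectangular label: axis-aligned box around the projected hull."""
    return (uv[:, 0].min(), uv[:, 1].min(), uv[:, 0].max(), uv[:, 1].max())
```

Because the hull is attached to the rigged worker model, repeating this for every subject in every frame yields pixel-accurate labels with no manual effort.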

**Figure 3.** Synthetic scenes 3–8: (**a**) Scene 3; (**b**) Scene 4; (**c**) Scene 5; (**d**) Scene 6; (**e**) Scene 7; (**f**) Scene 8. See Table 1 for descriptions.

### *2.3. Hyperparameter Setting*

In the course of this paper, we developed a simple construction worker tracking system using deep learning. We did not aim to build an optimal, ready-to-use safety device but rather to investigate the possibilities of synthetic data for the training and evaluation of such approaches in construction site scenarios.

Our system detects pedestrian construction workers in the camera images and tracks them over time in consecutive image frames. For detection, we apply YOLOv3 [28] as it combines classification and detection of the workers in a single CNN. The network was trained on parts of the labeled datasets described in Sections 2.1 and 2.2. Training the network from scratch is unfeasible due to the limited amount of real world data even though we have already made extensive use of YOLO's built-in data augmentation features. Hence, we rely on transfer learning to avoid overfitting the network to our small dataset during training. For this, we trained our network with a mini batch size of 64 based on the Darknet53 model, which has been pre-trained on ImageNet [29]. To quickly adjust the network towards the newly introduced class of pedestrian workers, we began the training with a high learning rate of 0.001. The learning rate was then scaled down by a factor of 0.1 every 3800 epochs until training ended after 10,400 epochs. This facilitates a finer adjustment of the weights so that they converge towards an optimal result. For regularization, we used a momentum of 0.9 and a weight decay of 0.0005.
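
The step schedule described above can be written down directly. The function below is a minimal sketch of the stated decay, not code from the authors' training pipeline.

```python
def learning_rate(epoch, base_lr=1e-3, scale=0.1, step=3800):
    """Step schedule from the text: start at 0.001 and multiply by
    0.1 after every 3800 epochs (training stops after 10,400)."""
    return base_lr * scale ** (epoch // step)
```

This yields three plateaus over the course of training: 1e-3, then 1e-4 from epoch 3800, and 1e-5 from epoch 7600 until the end.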

Besides the hyperparameters directly affecting the learning behavior, YOLO's detection quality and especially its speed highly depend on the given input image size. Scaling down high resolution images may eliminate smaller image features relevant for proper detection. Conversely, the larger the input image is, the longer it takes for YOLO to process the entire image. A suitable scaling of high resolution images thus has to be determined beforehand with regard to the application at hand. By comparing the detector's performance on different input image sizes, we identify the size which yields optimal detection results while preserving real-time applicability. For this, we train and evaluate the detector's performance on various scales of the real world datasets described in Section 2.1. Since the input image size should be a multiple of 32 px because of YOLO's downsampling architecture, we start our tests with a size of 416 × 416 px, which is equivalent to a factor of 13. For further tests, we gradually increase that factor by 3 each time up to an image size of 800 × 800 px. The performances of these test runs are compared by the mAP score as proposed for the evaluation of the COCO Detection Challenge [30]. As depicted in Equation (1), besides averaging over recall values at *i* = [0, 0.1, 0.2, ... , 1.0] with *ni* values, this score also averages over different IoU thresholds *j* = [0.50, 0.55, ... , 0.95] in *nj* steps. As a consequence, more accurate detection results are rewarded with a higher score.

$$\text{mAP} = \frac{1}{n_i n_j} \sum_{i} \sum_{j} \text{Precision}(\text{Recall}_{i,j}) \tag{1}$$

For reasonable monitoring or assistance systems, an accurate localization of pedestrian workers on construction sites is desirable. Thus, we choose the COCO mAP score as this score implicitly indicates the localization accuracy.
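
Equation (1) can be sketched as a plain double average. The callback `precision_at`, standing in for the detector's interpolated precision at a given recall level and IoU threshold, is an assumption for illustration and not part of the paper.

```python
def coco_map(precision_at):
    """COCO-style mAP: average precision over recall levels
    i = 0.0, 0.1, ..., 1.0 (n_i = 11) and IoU thresholds
    j = 0.50, 0.55, ..., 0.95 (n_j = 10)."""
    recalls = [i / 10 for i in range(11)]
    ious = [0.50 + 0.05 * j for j in range(10)]
    total = sum(precision_at(r, t) for t in ious for r in recalls)
    return total / (len(recalls) * len(ious))
```

Averaging over the IoU axis is what makes the score sensitive to localization accuracy: a detector whose precision decays at high IoU thresholds scores lower than one with tightly localized boxes.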

Although we found that detection accuracy generally rises with increasing image size, the detection speed drops sharply. At the base image size, we measure a single-image detection time of 46 ms with a mAP of 76.9 % on the evaluation dataset. At our maximum image size, the detection time increases to 97 ms. This corresponds to about 10 frames per second (fps), which would inhibit a reliable tracking as addressed in Section 2.4. With a mAP of 86.5 % at still about 15 fps, we choose an image size of 608 × 608 px as a reasonable tradeoff between detection time and mAP. Using this setup, we are able to analyze at least every second frame at frame rates of up to 30 fps. On the one hand, this speed is still sufficient for the tracking while, on the other hand, we can take advantage of a reasonable detection quality.
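
The scanned input sizes and the frame rates implied by the measured detection times can be reproduced with a few lines of arithmetic (a plain sanity check, not tooling from the paper):

```python
# Input sizes: multiples of 32 px, scaling factor 13 to 25 in steps of 3.
sizes = [32 * f for f in range(13, 26, 3)]   # [416, 512, 608, 704, 800]

def fps(detection_time_ms):
    """Frames per second implied by a per-image detection time."""
    return 1000.0 / detection_time_ms
```

For the measured detection times, `fps(97)` is about 10.3, matching the "about 10 fps" quoted above for the largest input size.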

### *2.4. Tracking*

After the workers are detected in a video frame, a track is assigned to each of them. For this, we apply Kalman filtering, which relies on a motion model only. An appearance model as it is provided by other tracking approaches would be redundant since our YOLO detector already takes care of this. For the purpose of tracking pedestrian workers, a simple motion model that only includes a worker's two-dimensional location (*x*, *y*) and velocity (*vx*, *vy*) is sufficient. According to this, our motion model describes the transition from one state *t* to the next state *t* + 1, as shown in Equation (2).

$$
\begin{pmatrix} x_{t+1} \\ y_{t+1} \\ v_{x,t+1} \\ v_{y,t+1} \end{pmatrix} = \begin{pmatrix} 1 & 0 & \Delta t & 0 \\ 0 & 1 & 0 & \Delta t \\ 0 & 0 & 1 & 0 \\ 0 & 0 & 0 & 1 \end{pmatrix} \begin{pmatrix} x_t \\ y_t \\ v_{x,t} \\ v_{y,t} \end{pmatrix} \tag{2}
$$

Beginning with an initial detection of a worker, the motion model is used to predict this worker's location in the ensuing video frame. Afterwards, the subsequent frame is evaluated. The workers' locations predicted by the Kalman filter in the preceding frame and the detections made by YOLO in the current frame are, then, matched using the Hungarian method. We use the locations of the matching detections to update the estimation of each track towards the actual worker's position. This prevents the tracks from drifting off. Detections that do not match any pre-existing track create a new track. This way, our detector serves both to initialize new tracks and to provide evidence for existing ones. Tracks for which no evidence is found persist for a maximum of 12 frames. From this, we can derive the workers' walking path trajectories as well as their prospective walking directions.
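
A minimal sketch of the constant-velocity Kalman filter from Equation (2) is given below. The noise covariances `Q` and `R` are illustrative placeholders, as the paper does not report them, and the matching of predictions to YOLO detections via the Hungarian method (e.g. with `scipy.optimize.linear_sum_assignment`) is omitted.

```python
import numpy as np

dt = 1.0
F = np.array([[1, 0, dt, 0],      # state transition from Eq. (2)
              [0, 1, 0, dt],
              [0, 0, 1,  0],
              [0, 0, 0,  1]], dtype=float)
H = np.array([[1, 0, 0, 0],       # only the position (x, y) is observed
              [0, 1, 0, 0]], dtype=float)
Q = np.eye(4) * 0.01              # process noise (assumed value)
R = np.eye(2) * 1.0               # measurement noise (assumed value)

def predict(x, P):
    """Propagate the state estimate (x, y, vx, vy) to the next frame."""
    return F @ x, F @ P @ F.T + Q

def update(x, P, z):
    """Correct the estimate with a matched YOLO detection z = (x, y)."""
    y = z - H @ x                          # innovation
    S = H @ P @ H.T + R                    # innovation covariance
    K = P @ H.T @ np.linalg.inv(S)         # Kalman gain
    return x + K @ y, (np.eye(4) - K @ H) @ P
```

On frames without a matched detection, only `predict` is called, which is what lets a track coast for up to 12 frames before it is dropped.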

Although the Kalman filter was originally developed for linear systems, it copes with non-linearity to some extent [31]. However, the pedestrian workers' movements can be highly non-linear. In the short time period between only a few consecutive video frames, though, the workers' movements can be assumed to be almost linear. For this reason, the tracks should be updated on at least every second frame. This drastically decreases the probability of encountering non-linear motion between updates. Considering this, the Kalman filter has to deal with non-linear state transitions only if workers are not detected for several video frames. An adequate detection beforehand is, thus, a necessary requirement for reliable tracking.

### **3. Results**

We conduct different experiments in order to examine the benefits of synthetically generated data. Our experiments investigate both the training and the evaluation of computer vision-based construction site monitoring systems. For this, we developed a tracking system based on YOLOv3 and Kalman filtering. The data generation and the tracking approach are both described in Section 2.

Our first experiment focuses on the effects of synthetic data on the training of the detector. For this, we contrast the performance of YOLOv3 trained on our real world dataset (see Section 2.1) to a detector additionally trained on different subsets of our synthetic dataset, as illustrated in Section 2.2. The suitability of synthetically generated data for the purpose of evaluating the tracking system's performance is, then, determined in a second experiment, shown in Section 3.2. We examine the similarity between the tracking performances of our system when applied to real world video sequences and to their synthetically generated counterparts.

### *3.1. Effects of Synthetic Data on the Training*

In Section 2.3, we determined the baseline mAP of 86.5 % on real world data at our desired input image size. This allows us to examine the effect on the detector's performance of adding supplementary synthetic data during training. For this, we successively add subsets of our synthetic dataset to the set of training samples. These subsets exhibit different characteristics, as described in the following. In the first trial, we add the synthetic scenes 1, 2, and 3 to the existing training set. In total, these scenes show 2130 pedestrian worker instances in 750 images. The amount of provided information is not significantly enhanced by these scenes; they only increase the sheer number of samples, as they imitate the real world samples that already made up the training dataset beforehand. This results in an increase of the mAP by 1.7 percentage points to a value of 88.2 %. Next, we further extend the training dataset by adding the synthetic scenes 4, 5, and 6, which consist of 1510 images showing 5170 worker instances in total. The compositions of these scenes differ from those of the scenes before, but the environmental conditions remain similar. Trained on the resulting dataset, our detector yields a mAP of 88.7 %. This is an increase of only 0.5 percentage points in comparison to the previously trained detector. By subsequently adding scene 7, containing 2490 worker instances in 620 images, we incorporate different lighting conditions into our training dataset. The scene is brightly illuminated, which leads to large cast shadows. Workers are consistently located in over- and underexposed areas. As a result, the mAP increases drastically by 5.2 percentage points to 93.9 %. Further training samples from scene 8, which exhibit environmental conditions similar to scene 7, have only little effect on the detection quality. Although this scene provides another 2490 worker instances in 620 images, the mAP only increases to 94.0 %. The bar graph in Figure 4 summarizes the increase of the mAP over the successively added subsets of our synthetically generated dataset.

**Figure 4.** Increase of YOLO's detection quality by successively adding synthetic data samples.

### *3.2. Tracking System Evaluation*

The final detector generated in the previous experiment (see Section 3.1) is embedded into a simple tracking system. Its detections are passed to a Kalman filter to estimate the detected workers' movements, as described in Section 2.4. By means of the developed tracking system, we investigate the suitability of computer generated 3D scenes for the evaluation of computer vision-based tracking systems. For this, we compare its tracking results on real and synthetic construction site scenarios. Our tracker is evaluated on the two different real world scenarios (see Section 2.1) and their computer generated counterparts described in Section 2.2. From each scenario we extract a continuous 10 second video sequence in which four pedestrian workers move across the scene, varying their pace and direction. All video sequences also include sudden stops as well as pedestrian workers coming close to and crossing each other. In the real world sequences, ground truth is labeled manually whereas this is done automatically (see Section 2.2) for the computer generated scenes.

We contrast the accuracy of our tracker on each of these video sequences by the average overlap score (AOS) and the center location error ratio (CER) metrics while its robustness is measured by the track length (TL). These metrics are adapted from [32] as shown in Equations (3) to (5).

$$\text{AOS} = \frac{1}{n} \sum_{t=1}^{n} \frac{A_t^G \cap A_t^T}{A_t^G \cup A_t^T},\tag{3}$$

$$\text{CER} = \frac{1}{n} \sum_{t=1}^{n} \frac{\lVert C_t^G - C_t^T \rVert_2}{\text{size}(A_t^G)},\tag{4}$$

$$\text{TL} = \frac{n}{N},\tag{5}$$

where *n* and *N* denote the number of video frames in which a worker is tracked and the number of frames in which the worker is present, respectively. The workers' bounding box areas are indicated by *A* and their centers by *C*. The superscripts *G* and *T* indicate ground truth and tracked bounding boxes, respectively. Lastly, $\lVert \cdot \rVert_2$ denotes the two-dimensional Euclidean distance and size(·) represents the two-dimensional size of an area.
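
The three metrics can be sketched directly from Equations (3) to (5). Boxes are assumed here to be (x1, y1, x2, y2) tuples, and size(·) is taken as the box diagonal, an assumption since the paper does not spell out its exact definition.

```python
import math

def iou(a, b):
    """Overlap term of Eq. (3) for axis-aligned boxes (x1, y1, x2, y2)."""
    w = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    h = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = w * h
    area = lambda r: (r[2] - r[0]) * (r[3] - r[1])
    return inter / (area(a) + area(b) - inter)

def aos(gt_boxes, tracked_boxes):
    """Eq. (3): mean IoU over the n tracked frames."""
    return sum(iou(g, t) for g, t in zip(gt_boxes, tracked_boxes)) / len(tracked_boxes)

def cer(gt_boxes, tracked_boxes):
    """Eq. (4): mean center distance normalized by ground truth box size."""
    total = 0.0
    for g, t in zip(gt_boxes, tracked_boxes):
        gx, gy = (g[0] + g[2]) / 2, (g[1] + g[3]) / 2
        tx, ty = (t[0] + t[2]) / 2, (t[1] + t[3]) / 2
        size = math.hypot(g[2] - g[0], g[3] - g[1])   # assumed size(.)
        total += math.hypot(gx - tx, gy - ty) / size
    return total / len(tracked_boxes)

def tl(n_tracked, n_present):
    """Eq. (5): share of frames in which the worker is tracked."""
    return n_tracked / n_present
```

A perfectly tracked worker thus yields AOS = 1, CER = 0, and TL = 1.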

Tables 2 and 3 summarize the tracking results on the first and second pair of scenes, respectively. As can be seen, in each of the sequences our tracker identifies four tracks. Each of these corresponds to one of the pedestrian workers, as illustrated in Figure 5 for all sequences. No tracks are mistakenly set up for other objects in the scenes. With an average of 96 % in AOS and a very low CER on both real world scenes, our tracking system yields an accurate performance. In addition, the averages in TL of 97 % and 98 % highlight the system's robustness. Deviations from the optimal TL are mainly due to the fact that it takes a few frames until sufficient consecutive detections are identified to start a track. In isolated cases, starting a track was delayed since the initial detection was complicated by odd lighting conditions. The metric averages on the synthetic scenes are slightly lower than those on the real world scenes, but still demonstrate an accurate performance of the tracking system.

**Figure 5.** Tracking results on the real world sequences in (**a**,**c**) and their synthetic counterparts in (**b**,**d**). Ground truth labels are marked as green rectangles whereas differently colored rectangles illustrate tracked workers' positions. Each worker's trajectory is represented by a colored line and the prospective movement is depicted as a colored arrow.

**Table 2.** Comparison of our system's tracking performance on the first scene from our real dataset and its synthetic counterpart. IDs are assigned randomly and correspond to a single tracked worker each.


**Table 3.** Comparison of our system's tracking performance on the second scene from our real dataset and its synthetic counterpart. IDs are assigned randomly and correspond to a single tracked worker each.


### **4. Discussion**

Our first experiment, shown in Section 3.1, highlights the benefits of using computer generated data for training a detector on construction related scenarios when only little real world data is available. Making use of such data, we boosted our detector's precision from 86.5 % in mAP using only a small real world dataset to 94.0 % by augmenting the dataset with synthetic data. This amounts to an increase of 7.5 percentage points. Nevertheless, the experiment also shows that synthetic data samples have to be chosen carefully in order to properly increase the detection quality. Adding more data samples of similar kind to a small set of training data generally improves the detection quality. This becomes apparent when extending our dataset with the synthetic scenes 1, 2, and 3, which increases the mAP by 1.7 percentage points (see Figure 4). The detection results in Figure 6a,b illustrate that the detected regions become more accurate, which results in a higher mAP score. However, the benefit of more samples of one specific kind saturates at a certain amount. The model learned by the detector then already considers all relevant features provided by such samples, and no further insights can be drawn from them. As a consequence, the detection quality stagnates, as occurred when adding the scenes 4, 5, and 6 to the training dataset. A further addition of such samples could cause the model to overfit to these training samples, resulting in an even worse performance on the evaluation set. The dataset extension with scene 7 shows that new samples possessing different conditions can further improve the detection results even though they are synthetic. As shown by the comparison of Figure 6c,d, this enables the detector to precisely identify even workers that were only coarsely detected before due to unusual lighting conditions. Furthermore, the large amount of different training samples aggregated from the synthetic scenes 1–7 enhances the detector's generalization ability so that it is able to cope with slight occlusions and partial color changes due to overexposed lighting conditions (see Figure 6e,f). When too many samples of a kind are added, the results once more begin to stagnate, as illustrated by the extension with scene 8. These findings show that computer generated data is particularly advantageous for augmenting a dataset with environmental and lighting conditions that occur rarely. Recording training data on construction sites over a long period in order to cover all conditions is a tedious task. Instead, recording only a short real world dataset under certain conditions would be sufficient while particular weather and lighting conditions could be simulated using a 3D environment.

Considering the severe downscaling of the input images, the achieved mAP is already a remarkable performance. Due to the scaling factor of about 0.317, the ground truth regions shrink from the original 40–50 px to a size of only 13–16 px in the scaled images. Apart from the fact that CNNs generally struggle with the detection of very small objects, achieving a precise detection in terms of IoU is difficult at this scale. Each pixel offset between the detection and the ground truth label in the scaled image implies an offset of about 3 px in the original image. Accordingly, if the detection is only 1 px off, the best possible IoU decreases to 0.94. An offset of 1 px in both x- and y-direction further decreases the IoU to 0.88 at best. Taking into account that manually labeling ground truth is never accurate to a single pixel, minor deviations between detection and ground truth are inevitable. Thus, an IoU of more than 0.85 can already be assumed to be optimal in this case. With respect to this, we can adapt Equation (1) so that the last bin of *j* contains the results for all IoU values ranging from 0.85 to 0.95. By this modified measure, our detector achieves a mAP of 99.2 %.
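
The quoted IoU bounds can be verified arithmetically, assuming a 15 px square ground truth box (within the stated 13–16 px range) whose detection boundary is off by a single pixel on one edge:

```python
# IoU of a 15x15 px ground truth box against a detection whose
# boundary is 1 px off on one edge (assumed geometry for illustration).
inter = 15 * 15            # overlap of ground truth and detection
union_x = 16 * 15          # detection edge 1 px off in x only
union_xy = 16 * 16         # detection edges 1 px off in x and y
iou_x = inter / union_x    # 0.9375 -> "0.94" at best
iou_xy = inter / union_xy  # 0.8789 -> "0.88" at best
```

This reproduces both figures quoted above and illustrates how quickly IoU degrades for boxes this small.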

**Figure 6.** Comparing the detection results of YOLO trained on different datasets. Red rectangles display YOLO's detections whereas green rectangles indicate manually labeled ground truth. (**a**,**c**,**e**) show the results of a detector trained on our real world dataset only. The right column illustrates the detection improvements for the same image sections if particular synthetic data is additionally provided for training. In (**b**), the general accuracy is improved by extending the training dataset by similar data. In (**d**), detections in under- and overexposed areas are improved by adding samples with different lighting conditions. In (**f**), workers can be detected despite slight occlusions and dramatic overexposure after adding a large amount of data samples.

In conclusion, we showed in this experiment that computer generated data is capable of successfully augmenting a construction site related dataset for the purpose of training a CNN. A reasonable choice of synthetic training samples can considerably increase the detection quality. These findings correspond to those in other application areas and show that such data is not only advantageous for well structured scenarios like car parks, but also yields reasonable results in considerably more complex environments like construction sites. This is also confirmed by the tracking results on the real world scenes of the second experiment. The simple tracking system, based on a detector that was trained on only sparse real world data augmented with synthetic samples, already enables suitable monitoring. This further substantiates the use of synthetic data for the training of CNNs in the context of construction sites.

Beyond these findings, the second experiment emphasizes the comparability of real world and synthetic scenes in terms of evaluating vision-based detectors and trackers. The comparison of the tracking results given in Tables 2 and 3 reveals that our tracker behaves similarly on both kinds of data. On both scenes, the accuracy measured by AOS and CER on the synthetic sequences is slightly lower than on the real world sequences. This is not necessarily due to the nature of synthetic data but rather associated with the more precise ground truth labels of the synthetic dataset, which result from the automatic labeling method. These labels enclose the workers even more closely than those in the manually labeled sequences, so that minor deviations during detection result in a lower AOS and CER. The comparison of the green ground truth rectangles in Figure 5a,b as well as in Figure 5c,d illustrates this graphically. Nevertheless, on average there is only a deviation of about four percentage points in AOS and a deviation of 0.0005 in CER on the first scene. Similarly low deviations are given for the second scene, with two percentage points in AOS and 0.0005 in CER. These results indicate a reasonable comparability of the tracking performance on real world and synthetic scenarios. Very similar track lengths on both scenes additionally confirm this finding.

The comparison has shown that the evaluation results on real world and computer generated video sequences resemble each other closely. Accordingly, the quality of a tracking system can be deduced on the basis of computer generated video sequences if sufficient real world data cannot be acquired. For construction site scenarios, this is often the case for hazardous situations, since intentionally endangering workers must be avoided. Furthermore, weather conditions may occur on which the detector was not explicitly trained. On the basis of the similar evaluation results on synthetic and real world data, we exemplarily demonstrate the capabilities offered by a virtual environment such as the one developed in this paper. In this example, we use a computer generated video to show the evaluation of a tracking system in a risky situation without exposing any worker to a hazard. Furthermore, it illustrates that various environmental conditions can be simulated without the need for tedious repetitive recordings on multiple days. In order to highlight these benefits, we apply our tracking system to a modified version of synthetic scene 1. We change the weather conditions from a sunny to a heavily rainy day and include a crane lifting and pivoting its load above the heads of pedestrian workers. As depicted in Figure 7, all four workers are again tracked by our system despite the rainy weather conditions, which were not explicitly trained. Table 4 shows that only four tracks were assigned, each corresponding to one of the workers. However, the tracker's accuracy decreases slightly. This indicates that the performance of our tracker in a similar real world scenario should still be sufficient, although its precision might decrease slightly. Additionally training the detector on data samples of rainy days might counteract this. Furthermore, Table 4 reveals that the tracking is not interrupted even though the workers are temporarily occluded by the crane's load.
This demonstrates that the tracking system proposed in this paper is even capable of dealing with certain hazardous situations. By identifying the hazardous area around the crane's load, the crane operator could be warned when the load approaches pedestrian workers, so that risky situations could be prevented. Further tests have to verify whether the system can also cope with larger loads and more difficult weather conditions.

**Figure 7.** Tracking results on a computer generated hazardous situation on a heavily rainy day. Green rectangles denote ground truth labels whereas differently-colored rectangles illustrate the tracked pedestrian workers.

**Table 4.** Tracking results of our system on synthetic scene 3. IDs are assigned randomly and correspond to a single tracked worker each.


### **5. Conclusions**

Productivity and safety management on construction sites could benefit from monitoring pedestrian workers. Based on recorded walking trajectories of workers provided by a tracking system, the workers' current paths could be assessed and optimized to solve certain tasks more efficiently. These trajectories also reveal the workers' attention with respect to hazards like falling edges or hazardous materials and areas. Owing to this, safety training could be tailored explicitly to the needs of the workers. If the localization of workers is conducted live, workers and machine operators could even be warned of looming hazardous situations, enabling them to react early. Computer vision methods are more suitable for such a precise tracking of workers all over the site, due to various shortcomings arising from the radio-frequency technology used by tag-based alternatives. However, computer vision approaches have to be trained and evaluated on a vast amount of images. Appropriate datasets are rarely publicly available, and recording a sufficient amount of data is extremely time-consuming. For this reason, we investigated the use of synthetic data generated from 3D environments.

Besides a small real world dataset, we generated a synthetic dataset that covers diverse environmental and illumination conditions. In order to analyze the usability of data generated from 3D scenarios, we built a simple tracking system. It consists of a YOLOv3 detector identifying pedestrian workers and a Kalman filter tracking those workers in video sequences. In our experiments, we examined the suitability of synthetic data for the training and evaluation of a computer vision tracking system. First, we iteratively added more and more synthetic data samples to the training dataset of our YOLO detector. Second, we compared the performance of our tracking system on real world and corresponding 3D generated video sequences.

We found that training on synthetic data samples significantly enhances the detection quality. In our experiments, we were able to improve our detector by 7.5 percentage points over a detector trained on a small real world dataset only. However, the quality of the resulting detector depends highly on the choice of suitable training samples. As for real world datasets, synthetic samples should also cover various environmental and lighting conditions. Furthermore, we found that a computer vision-based tracker performs very similarly on real world and 3D generated video sequences. Accordingly, an evaluation on synthetically generated scenes can already provide reliable insights into the strengths and weaknesses of a tracking system, since its performance can be estimated quite precisely. As a result, a vision-based tracking system can be tested on a variety of synthetically generated situations before being employed on a real construction site. In this way, tracking performance can be verified even for rare or exceptional situations.

The findings of our experiments are in accordance with those from other application areas, but additionally highlight that synthetic data is capable of realistically modeling even the complex and dynamic environments of construction sites. For construction site related applications, this validation is particularly relevant. Since datasets are rare in this field, data augmentation using synthetic data could advance the use of vision-based approaches in the future. In particular, it facilitates incorporating rarely occurring lighting and weather conditions into a dataset. By simulating, for example, snowfall or bright sunlight in a virtual environment, data acquisition can be drastically accelerated and existing datasets can easily be completed. The use of synthetic data makes it possible to model any weather and lighting conditions and to change the construction site setup at will. Otherwise, recordings throughout the year and on different sites would be necessary to acquire a reasonable dataset. Further savings in time and human resources result from the possibility of automatic labeling. Since the positions of all objects in a virtual scene are known exactly, labeling can be done fully automatically and with substantially higher precision than manual labeling. By covering various environmental conditions and providing precise labels, datasets can be prepared such that optimal training is ensured. The resulting system can deal with manifold upcoming conditions without having seen them in the real world before. Besides the training of detectors, computer generated data is also valuable for the evaluation of vision systems. Again, virtual environments can be used to simulate a variety of scenarios on which to test the system. This is particularly advantageous for events that rarely occur or are hazardous to replicate in the real world. Risky situations can be evaluated this way without exposing anyone to a serious hazard.

In summary, our major conclusions are as follows: We have shown the applicability of synthetically generated data for vision systems in the area of construction sites. Furthermore, we highlighted the benefits of such data for training and evaluation purposes for the underlying machine learning algorithms. As this paper gives an overview of the possibilities offered by synthetic data, future work should investigate either the training or the evaluation phase in more detail. For now, the limits of using synthetic data for both training and evaluation remain unclear. It should also be determined to what extent the use of synthetic data is beneficial and whether certain scenarios cannot be modeled sufficiently. The impact of such data on different machine learning algorithms could be another topic to focus on. Finally, a tracking system optimally trained and tested on a combination of real world and synthetic data should be employed on a real construction site for a large-scale case study.

**Author Contributions:** Conceptualization, M.N. and P.H.; data curation, P.H.; investigation, M.N. and P.H.; methodology, M.N.; project administration, M.N.; resources, M.N. and P.H.; software, M.N.; supervision, M.K.; validation, M.N.; visualization, M.N. and P.H.; writing—original draft, M.N. and P.H.; writing—review and editing, M.K. All authors have read and agreed to the published version of the manuscript.

**Funding:** This research received no external funding.

**Acknowledgments:** The motion capture data used in this project was obtained from mocap.cs.cmu.edu. The CMU database was created with funding from NSF EIA-0196217. Furthermore, we acknowledge the support by the DFG Open Access Publication Funds of the Ruhr-Universität Bochum.

**Conflicts of Interest:** The authors declare no conflict of interest.

### **References**


© 2020 by the authors. Licensee MDPI, Basel, Switzerland. This article is an open access article distributed under the terms and conditions of the Creative Commons Attribution (CC BY) license (http://creativecommons.org/licenses/by/4.0/).

## *Article* **Environment Classification for Unmanned Aerial Vehicle Using Convolutional Neural Networks**

### **Carlos Villaseñor 1, Alberto A. Gallegos 2, Javier Gomez-Avila 1, Gehová López-González 2, Jorge D. Rios <sup>1</sup> and Nancy Arana-Daniel 1,\***


Received: 18 June 2020; Accepted: 16 July 2020; Published: 20 July 2020

### **Featured Application: The approach presented in this paper is implemented in an autonomous UAV to provide the ability to change its path according to ground position and weather conditions, since sustaining an aircraft when flying through a dense cloud is not possible.**

**Abstract:** Environment classification is one of the most critical tasks for Unmanned Aerial Vehicles (UAV). Since water accumulation may destabilize a UAV, clouds must be detected and avoided. In a previous work presented by the authors, Superpixel Segmentation (SPS) descriptors with low computational cost were used to classify ground, sky, and clouds. In this paper, an enhanced approach to classify the environment into those three classes is presented. The proposed scheme consists of a Convolutional Neural Network (CNN) trained with a dataset generated by both a human expert and a Support Vector Machine (SVM) to capture context and precise localization. The advantage of using this approach is that the CNN classifies each pixel, instead of a cluster as in SPS, which improves the resolution of the classification. Moreover, it is less tedious for the human expert to generate a few training samples than the amount normally required. This proposal is implemented for images obtained from video and photographic cameras mounted on a UAV facing in the flight direction of the vehicle. Experimental results and comparisons with other approaches are shown to demonstrate the effectiveness of the algorithm.

**Keywords:** cloud detection; superpixel segmentation; convolutional neural networks; support vector machines

### **1. Introduction**

Unmanned Aerial Vehicles (UAVs) have gained popularity in the last decades due to their capability of moving in three-dimensional space. UAVs were first used for military purposes. However, they are now used for surveillance, research, monitoring, and search and rescue activities [1]. These kinds of vehicles are suited for situations that are too dangerous or hazardous for direct human monitoring [2].

One of the challenges of UAVs is the loss of communication with the remote pilot. For this reason, it is necessary to provide the vehicle with a certain level of autonomy to maintain flight in such scenarios. A UAV must be able to adapt and change its path according to ground position and weather conditions, since sustaining an aircraft when flying through a dense cloud is not possible [3]. Given that clouds can be seen from long distances [4], it is possible to develop an intelligent system capable of avoiding them.

Cloud detection is a very challenging task: each large water cluster has a unique amorphous shape that changes continuously, making it impossible to extract characteristic features that could be tracked with a descriptor such as Speeded Up Robust Features (SURF) [5]. Therefore, other methods to extract information are needed, such as segmentation based on color, texture, and illumination [6–9].

In Reference [10], several simple-to-implement descriptors with linear computational cost are presented, showing good training and generalization performance. Results from a video camera mounted on a UAV were satisfactory for two- and three-class classification in real time.

Our proposed scheme describes and implements an approach to classify three elements of the environment (ground, sky, and clouds), using Superpixel Segmentation (SPS) and a Support Vector Machine (SVM) to pre-train a Convolutional Neural Network (CNN), a deep learning model trained end-to-end from raw pixel intensity values to classifier outputs. The spatial structure of images makes them well suited to this kind of network, which exploits local connectivity between the filters (or layers), parameter sharing, and discrete convolutions [11].

The images used in this work were captured by a camera mounted on a UAV provided by Hydra Technologies of Mexico®; an example of an obtained image is shown in Figure 1.

**Figure 1.** Captured image of a video stream from a camera mounted on an unmanned aerial vehicle (UAV). The image resolution is 720 × 480 pixels at 30 frames per second.

The outline of this paper is as follows: Section 2 presents related work. Section 3 gives a brief description of the SVM whose output is used to pre-train the CNN. Section 4 presents the descriptors based on the SPS methodology. In Section 5, the CNN architecture is described. Experimental results are presented in Section 6, and conclusions are discussed in Section 7.

### **2. Related Work**

Most of the research done on cloud detection is ground-based, where clouds are captured with instruments that obtain continuous all-sky images at pre-defined time intervals [12,13]. For a UAV, it is impossible to keep these conditions, since the information update intervals need to be shorter. Moreover, the algorithms should not have a high computational cost, because onboard computers may not have the same processing power and memory capabilities as an off-board station. Also, a computer with high processing power on a UAV would demand more energy, requiring batteries of higher capacity, which increases the UAV's weight and negatively affects the fuel consumption of the aircraft.

Other works solve the problem of object identification using an undirected graph [14]. Computing the graph association matrix can be computationally expensive; in the worst case, it is a problem of *O*(*n*<sup>2</sup>) complexity [7,14]. Such approaches are not suitable for real-time applications working with high definition images [9]. In Reference [13], an automatic cloud detection method for all-sky images using SPS is presented; the result of implementing this algorithm is shown in Figure 2. It can be seen that, even though it is a good approximation, some information is lost in the final result. Considering these results and the computational complexity of the algorithm, it may not be suitable for this kind of application.

**Figure 2.** Steps of the algorithm described in Reference [11].

On the other hand, algorithms based on image matting [6,8,15] try to reduce computational complexity. These algorithms extract foreground objects from images, but they are not easy to implement and require long processing times [9]. In these approaches, the algorithm distinguishes only between two classes (sky or ground), and it is difficult to add more classes.

Recently, deep learning techniques have been used to solve many computer vision tasks [14,16–20]. In particular, CNNs are good image classifiers [21–26]. Approaches like the ones presented in References [27,28] use CNNs that are trained to predict a class for each pixel. In contrast, this paper employs a segmentation on top of a CNN to label these clusters of pixels as clean sky, clouds, and ground.

### **3. Support Vector Machines**

Vapnik introduced support vector machines in 1995, and they are widely used in classification tasks because of their simplicity and the convexity of the function to optimize [29]. Classification is treated as an optimization problem; the aim is to minimize a risk function *R* and maximize the separation between classes, as represented by

$$w^\* = \underset{w \in \mathbb{R}^D}{\text{arg min}} F\left(w\right) = \frac{1}{2} \|w\|^2 + \zeta R\left(w\right),\tag{1}$$

where *w* is a normal vector orthogonal to the separating hyperplane, $\frac{1}{2}\|w\|^2$ is a quadratic regularization term, and *ζ* > 0 is a fixed constant that weights the risk function. Equation (1) can be expressed using Lagrange multipliers as follows

$$\begin{aligned} \arg\max\_{\alpha} & \sum\_{i=1}^{n} \alpha\_i - \sum\_{i=1}^{n} \sum\_{j=1}^{n} \alpha\_i \alpha\_j \psi\_i \psi\_j \Omega \left(\beta\_i, \beta\_j\right) \\ & w^\* = \sum\_{i=1}^{n} \alpha\_i \psi\_i \beta\_i, \end{aligned} \tag{2}$$

subject to $0 \le \alpha\_i \le \zeta$ and $\sum\_{i=1}^{n} \alpha\_i \psi\_i = 0$, where $\left(\beta\_i, \psi\_i\right)\_{i=1}^{n}$ is the training set, with $\beta\_i$ an *n*-dimensional input vector and $\psi\_i$ its corresponding label. Here, $\alpha\_i$ are the Lagrange multipliers and $\Omega\left(\beta\_i, \beta\_j\right)$ is the value of the kernel matrix $\Omega$, defined by the inner product $\varphi\left(\beta\_i\right) \cdot \varphi\left(\beta\_j\right)$, where $\varphi$ is a non-linear mapping to a high dimensional space. The advantage of using this dual formulation is the use of kernels, which introduce a feature space by implicitly mapping the input data into a higher-dimensional space where non-linearly separable data can become linearly separable [30,31].
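As a concrete illustration of the kernel matrix Ω, the sketch below computes a Gaussian (RBF) kernel between a few toy descriptor vectors with NumPy. The kernel choice and the bandwidth value are assumptions for illustration and are not taken from the paper.

```python
import numpy as np

def rbf_kernel_matrix(B, gamma=0.5):
    """Kernel matrix Omega with Omega[i, j] = exp(-gamma * ||beta_i - beta_j||^2),
    computed without forming the high-dimensional mapping phi explicitly."""
    sq = np.sum(B**2, axis=1)
    d2 = sq[:, None] + sq[None, :] - 2.0 * B @ B.T  # pairwise squared distances
    return np.exp(-gamma * np.maximum(d2, 0.0))

B = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 2.0]])  # toy descriptors beta_i
Omega = rbf_kernel_matrix(B)
print(Omega)
```

The resulting matrix is symmetric with ones on the diagonal, exactly the structure the dual formulation in Equation (2) operates on.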

A CNN requires a massive amount of training data, and producing it is usually tedious for a human. For this reason, the data used to pre-train the network was created by a human expert and an SVM that classifies an image segmented into superpixels, that is, sub-areas each represented by a single descriptor instead of individual values for every pixel.

#### **4. Descriptors**

Most descriptors are developed to classify only two classes and cannot naturally be scaled to *m* different classes. The descriptors presented in this section have linear complexity *O*(*n*), and a descriptor capable of increasing the number of classes to three is proposed.

#### *4.1. Descriptors Based on Superpixel Segmentation and Histogram*

In this section, three descriptors that use histograms as features are described. Three images must be obtained to construct the required descriptors: let (*R*, *G*, *B*) be the red, green, and blue channels, respectively; the descriptors are obtained from the *R* − *B*, *R*/*B*, and *RGB* images. Cloud detection algorithms commonly use color to determine whether a region of the image is a cloud: cloud particles have a similar dispersion of *B* and *R* intensity, whereas clear sky presents more *B* than *R* intensity [12,13].

For *N* pixels, *M* superpixels are generated based on color similarity and proximity using Simple Linear Iterative Clustering (SLIC) [32] in the CIELAB color space. SLIC initializes *M* cluster centers *C<sub>m</sub>* = [*l<sub>m</sub>*, *a<sub>m</sub>*, *b<sub>m</sub>*, *x<sub>m</sub>*, *y<sub>m</sub>*]<sup>*T*</sup> on a regular grid, where (*l*, *a*, *b*) is the color vector in CIELAB space and (*x*, *y*) are the pixel coordinates. Each superpixel has an approximate size of *N*/*M*, and the centers are located every $S = \sqrt{N/M}$ pixels.

SLIC computes a distance *D* between each pixel *i* and its nearest cluster center *C<sub>m</sub>*:

$$D = \sqrt{d\_c^2 + \left(\frac{d\_s}{S}\right)^2 r^2},\tag{3}$$

where *r* ∈ [1, 40] is a constant that balances color similarity against spatial proximity, and *d<sub>c</sub>* and *d<sub>s</sub>* are defined by

$$d\_c = \sqrt{\left(l\_j - l\_i\right)^2 + \left(a\_j - a\_i\right)^2 + \left(b\_j - b\_i\right)^2} \tag{4}$$

$$d\_s = \sqrt{\left(x\_j - x\_i\right)^2 + \left(y\_j - y\_i\right)^2}.\tag{5}$$

The cluster centers are adjusted to the mean vector of the pixels assigned to *C<sub>m</sub>*, and a residual error *E* between the new and previous cluster centers is computed using the *L*<sup>2</sup> norm. The algorithm stops when *E* falls below a certain threshold.
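The combined distance of Equations (3)–(5) can be written directly in code. The sketch below assumes the reconstructed form $D = \sqrt{d\_c^2 + (d\_s/S)^2 r^2}$ from the SLIC formulation; the values of *S* and *r* are illustrative only.

```python
import math

def slic_distance(p, c, S, r):
    """Distance D between pixel p and cluster center c, each given as
    (l, a, b, x, y): CIELAB color plus image coordinates."""
    d_c = math.sqrt((c[0] - p[0])**2 + (c[1] - p[1])**2 + (c[2] - p[2])**2)  # color term, Eq. (4)
    d_s = math.sqrt((c[3] - p[3])**2 + (c[4] - p[4])**2)                     # spatial term, Eq. (5)
    return math.sqrt(d_c**2 + (d_s / S)**2 * r**2)                           # combined, Eq. (3)

pixel = (50.0, 10.0, -5.0, 12.0, 40.0)   # hypothetical pixel
center = (52.0, 12.0, -4.0, 15.0, 44.0)  # hypothetical cluster center
print(slic_distance(pixel, center, S=10.0, r=10.0))
```

Larger *r* gives more weight to spatial proximity, producing more compact superpixels; smaller *r* lets color similarity dominate.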

The descriptor *β* of superpixel *k* is obtained from a 16-bin histogram for each superpixel in the *R* − *B* and *R*/*B* images. The intensity value of each pixel *i* ∈ *k* is divided by 16 and rounded down to the nearest integer to obtain its bin. In the case of the *RGB* image, a histogram is obtained for each channel.
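A minimal sketch of the 16-bin histogram descriptor, assuming 8-bit intensity values (already shifted to 0–255) and a precomputed superpixel label map; both inputs here are toy examples:

```python
import numpy as np

def superpixel_histogram(img, labels, k, bins=16):
    """16-bin intensity histogram of superpixel k: each 8-bit pixel value is
    divided by 16 and rounded down to obtain its bin index."""
    values = img[labels == k].astype(np.int64) // 16
    return np.bincount(values, minlength=bins)

img = np.array([[10, 20], [200, 250]], dtype=np.uint8)  # toy R-B image
labels = np.array([[0, 0], [1, 1]])                     # two toy superpixels
print(superpixel_histogram(img, labels, k=0))
```

For the *RGB* image the same routine would be applied once per channel and the three histograms concatenated into the descriptor.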

### *4.2. Superpixel Segmentation with Gabor Filter*

For this approach, a pre-processing step is needed, as shown in Figure 3. Since clouds enhance the *R* − *B* difference, this image is used and its histogram is normalized. Gaussian blur is applied to reduce noise before binarization with Otsu's method [33], which is obtained by solving

$$\sigma\_b^2 \left( t \right) = P\_0 P\_1 \left( \mu\_0 - \mu\_1 \right)^2,\tag{6}$$

where *P*<sub>0</sub> and *P*<sub>1</sub> are the class probabilities obtained from a histogram with *L* bins, separated by a threshold *t*; *μ*<sub>0</sub> and *μ*<sub>1</sub> are the means of the classes. These are given by Equations (7)–(10):

$$P\_0\left(t\right) \quad = \sum\_{i=0}^{t-1} p\left(i\right) \tag{7}$$

$$P\_1\left(t\right) \quad = \sum\_{i=t}^{L-1} p\left(i\right) \tag{8}$$

$$\mu\_0\left(t\right) \quad = \sum\_{i=0}^{t-1} i \frac{p(i)}{P\_0} \tag{9}$$

$$\mu\_1\left(t\right) \quad = \sum\_{i=t}^{L-1} i \frac{p(i)}{P\_1}.\tag{10}$$
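Equations (6)–(10) can be implemented by scanning all candidate thresholds and keeping the one that maximizes the between-class variance. The NumPy sketch below is illustrative and not the authors' implementation:

```python
import numpy as np

def otsu_threshold(img, L=256):
    """Return the threshold t maximizing sigma_b^2(t) = P0 * P1 * (mu0 - mu1)^2."""
    hist = np.bincount(img.ravel(), minlength=L).astype(float)
    p = hist / hist.sum()                      # normalized histogram p(i)
    best_t, best_var = 0, -1.0
    for t in range(1, L):
        P0, P1 = p[:t].sum(), p[t:].sum()      # class probabilities, Eqs. (7)-(8)
        if P0 == 0 or P1 == 0:
            continue
        mu0 = (np.arange(t) * p[:t]).sum() / P0        # class mean, Eq. (9)
        mu1 = (np.arange(t, L) * p[t:]).sum() / P1     # class mean, Eq. (10)
        var = P0 * P1 * (mu0 - mu1) ** 2               # between-class variance, Eq. (6)
        if var > best_var:
            best_t, best_var = t, var
    return best_t

img = np.array([10, 12, 11, 10, 200, 205, 198, 202])  # toy bimodal intensities
print(otsu_threshold(img))
```

For a clearly bimodal intensity distribution such as the toy array above, the selected threshold falls between the two modes.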

At this step, it is easy to separate clean sky from clouds; however, as can be seen in Figure 4, it is not possible to distinguish between clouds and ground. Because of this, another descriptor capable of distinguishing between them is necessary. In this case, the Gabor filter [34] is applied to the original image because of its ability to represent and discriminate texture. The filter responds strongly to structures in the image that have the same orientation [35]. The following two-dimensional Gabor functions are used:

$$g\_{\lambda, \Theta, \rho} \left( x, y \right) = e^{-\left( x'^2 + \gamma^2 y'^2 \right) / \left( 2\sigma^2 \right)} \cos \left( 2\pi \frac{x'}{\lambda} + \rho \right), \tag{11}$$

$$x' = x\cos\Theta + y\sin\Theta \tag{12}$$

$$y' = -x\sin\Theta + y\cos\Theta,\tag{13}$$

where *λ* is the wavelength, Θ is the orientation, *ρ* is the phase offset, *γ* is the aspect ratio, and *σ* = 0.56*λ* is the standard deviation.

(**a**) Original image (**b**) Binary image

**Figure 4.** Clouds and ground classes cannot be easily distinguished.

Four Gabor filters are calculated for Θ ∈ {*π*/4, *π*/2, 3*π*/4, *π*}. The filtered images are converted to grayscale, and the mean value of each image is added to the descriptor. The variance of superpixel *k* in each Gabor filtered image is also calculated and added to the descriptor *β<sub>k</sub>*. Moreover, spatial information is included in the descriptor, since ground superpixels have lower spatial values while cloud superpixels have higher spatial values.
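A Gabor kernel following Equations (11)–(13) with σ = 0.56λ can be generated as follows. The kernel size and parameter values are illustrative assumptions, and the rotation y' = −x sin Θ + y cos Θ is the standard form assumed here:

```python
import numpy as np

def gabor_kernel(lam, theta, rho=0.0, gamma=0.5, size=9):
    """Gabor kernel per Eq. (11), with sigma = 0.56 * lambda."""
    sigma = 0.56 * lam
    half = size // 2
    y, x = np.mgrid[-half:half + 1, -half:half + 1].astype(float)
    xp = x * np.cos(theta) + y * np.sin(theta)    # x', Eq. (12)
    yp = -x * np.sin(theta) + y * np.cos(theta)   # y', standard rotation (assumption)
    envelope = np.exp(-(xp**2 + gamma**2 * yp**2) / (2.0 * sigma**2))
    carrier = np.cos(2.0 * np.pi * xp / lam + rho)
    return envelope * carrier

# Bank of four orientations as used in the descriptor
bank = [gabor_kernel(lam=4.0, theta=t) for t in (np.pi/4, np.pi/2, 3*np.pi/4, np.pi)]
print(bank[0].shape)
```

Each kernel in the bank would then be convolved with the image, and the per-superpixel mean and variance of the responses appended to the descriptor.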

### **5. Convolutional Neural Networks**

CNNs are commonly used for processing data contained in a matrix or grid, such as images, that are represented by a 2D matrix. Their name comes from the mathematical operation called convolution, which is an operation on two functions to produce a third function that expresses how one of them is modified by the other. In computer vision and image processing, the convolution operation is used to reduce noise and enhance features in images.

Let us suppose that *s*(*t*) is the output of the convolution; the operation is given by

$$s\left(t\right) = \int l\left(a\right)h\left(t-a\right)da,\tag{14}$$

where the function *l* is the output of a sensor (usually referred to as the input in CNN terminology), *h* is a weighting function (also known as the kernel), and *a* is the age of the measurement. The convolution is commonly denoted with an asterisk as follows

$$s\left(t\right) = \left(l\*h\right)\left(t\right).\tag{15}$$

These data are usually discretized, and if time *t* can take only integer values, it is possible to define the convolution as a discrete operation as follows

$$s\left(t\right) = \sum\_{a = -\infty}^{\infty} l\left(a\right)h\left(t - a\right). \tag{16}$$

The input and the output are multidimensional arrays, and every element must be stored explicitly. It is assumed that every element outside the set of points for which values are stored is zero; therefore, the infinite summation can be implemented over a finite number of array elements, and the operation can be applied over more than one axis at a time. Let *I* be a two-dimensional image and *K* a two-dimensional kernel; the convolution for images is given by

$$S\left(i,j\right) = \left(I\*K\right)\left(i,j\right) = \sum\_{m}\sum\_{n} I\left(m,n\right)K\left(i-m,j-n\right) \tag{17}$$

and can graphically be described, as shown in Figure 5.


**Figure 5.** Graphical description of convolution operation.
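Equation (17) can be implemented directly as a nested loop over output positions. The sketch below is a naive reference implementation (no padding, "valid" output region) for illustration only; practical CNN libraries use far more optimized routines:

```python
import numpy as np

def conv2d(I, K):
    """Direct implementation of Eq. (17): S(i, j) = sum_m sum_n I(m, n) K(i-m, j-n),
    returning the 'valid' region where the kernel fits entirely inside the image."""
    kh, kw = K.shape
    ih, iw = I.shape
    Kf = K[::-1, ::-1]  # flip the kernel: true convolution, not cross-correlation
    out = np.zeros((ih - kh + 1, iw - kw + 1))
    for i in range(out.shape[0]):
        for j in range(out.shape[1]):
            out[i, j] = (I[i:i + kh, j:j + kw] * Kf).sum()
    return out

I = np.arange(16.0).reshape(4, 4)        # toy 4x4 image
K = np.array([[0.0, 1.0], [2.0, 3.0]])   # toy 2x2 kernel
print(conv2d(I, K))
```

Note the kernel flip: many deep learning frameworks actually compute cross-correlation and still call it convolution, since the learned weights absorb the flip.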

The convolution has two properties that help improve a machine learning system: sparse interactions and parameter sharing [36].

Due to its sparse interactions, fewer parameters need to be stored and fewer operations are required; moreover, units in deeper layers may indirectly interact with a larger portion of the input and describe more complicated interactions between pixels, as described in Figure 6.

**Figure 6.** Description of sparse interactions. Even if direct connections seem to be sparse, more units at deeper layers are indirectly connected.

A CNN layer consists of three stages. First, several convolutions in parallel produce a set of linear activations. Then, in a detector stage, nonlinear activation functions take the linear activations as their argument. Finally, a pooling function modifies the output of the layer, making the representation invariant to small translations of the input [36].

### *Environment Classification with CNN*

CNNs have demonstrated effectiveness in image recognition, segmentation, and detection [11]. The architecture of the network is shown in Figure 7. Each layer uses a Rectified Linear Unit (ReLU) activation function, except for the last one, whose activation function is the sigmoid, given by *f*(*x*) = 1/(1 + *e*<sup>−*x*</sup>).
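The two activation functions mentioned above can be sketched in a few lines of NumPy (illustrative only; in practice they are provided by the deep learning framework):

```python
import numpy as np

def relu(x):
    """Rectified Linear Unit used in the hidden layers: max(0, x)."""
    return np.maximum(0.0, x)

def sigmoid(x):
    """Sigmoid used in the output layer: f(x) = 1 / (1 + e^(-x))."""
    return 1.0 / (1.0 + np.exp(-x))

x = np.array([-2.0, 0.0, 2.0])
print(relu(x), sigmoid(x))
```

The sigmoid maps each output-layer activation to (0, 1), which is what allows the three channel values per pixel to be read as class probabilities.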

**Figure 7.** Convolutional Neural Network (CNN) architecture. In all cases stride = 1 and a zero-padding = 1. The output image has three channels (for sky, cloud, and ground), and each pixel is labeled based on its three channels values.

CNNs are a class of deep learning models that require a large quantity of training data. In practice, access to large datasets is relatively rare, and generating them manually is a tedious task [21]. In this work, one part of the data is generated by the classification of the superpixels made by the SVM; nevertheless, training the CNN only with SVM output would make the CNN merely imitate the Support Vector Machine. To avoid this behavior, another set of training data was provided by a human expert. Finally, the training data was artificially enlarged using data augmentation.

#### **6. Experimental Results**

In this section, the results of the proposal are presented. The pre-training step is carried out with 1000 images labeled by the SVM. Then, only twenty ground truth images classified by a human expert are used for supervised training. Table 1 shows ten test images used to demonstrate the effectiveness of the proposed algorithm. These photos were taken from three flights, each at a fixed but different altitude, and under different weather conditions. Although they are not consecutive frames, the pictures in rows 5 to 7 were taken from a straight and level flight, and there is little difference between them; however, the SPS-SVM clearly produces different classifications for these images. Additionally, the dataset is artificially enlarged by applying geometric transformations to the training set.

**Table 1.** Test results. The first column shows the original image. The second column shows the Superpixel Segmentation (SPS)–Support Vector Machine (SVM) classification. The ground truth, generated by a human, is shown in the third column. In the fourth column, the classification made by the CNN is presented.

For each pixel, the CNN outputs the probability of belonging to each class. By using these probabilities as pixel intensities, we form the grayscale images in Figure 8; their histograms are also shown. Moreover, the probabilities of each class are scaled and presented in Figure 9 to show which pixels activate the output layer for each class.

**Figure 8.** Pixel distribution for each class.

**Figure 9.** Scaled probabilities for each class.

To better visualize performance, Table 2 shows the confusion matrices of both approaches. These matrices compare the prediction of each algorithm with the ground truth; the closer a matrix is to the identity, the less the algorithm confuses the classes.
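A confusion matrix of this kind can be accumulated from per-pixel labels as in the sketch below (illustrative, with assumed class codes 0 = sky, 1 = cloud, 2 = ground):

```python
import numpy as np

def confusion_matrix(y_true, y_pred, n_classes=3):
    # Rows index the ground-truth class, columns the predicted class;
    # a matrix close to (a scaled) identity means little confusion.
    cm = np.zeros((n_classes, n_classes), dtype=int)
    for t, p in zip(y_true, y_pred):
        cm[t, p] += 1
    return cm
```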


**Table 2.** Confusion Matrices comparison. (S: Sky, C: Cloud, G: Ground).

As seen in Table 1, adding a few images from a human expert prevents the CNN from simply imitating the SVM. The advantage is that the human expert needs to generate only twenty training images, from which the network generalizes well and corrects mistakes made by the SVM, for example, in rows 5 and 6 of Table 1.

From the matrices in Table 2, recall, precision, and F1 score are computed to measure the effectiveness of the algorithm and to compare it with the SPS-SVM. These results are shown in Table 3. There is no entry for SVM in test 8 because, in such an experiment, only two classes were found (missing sky).

The confusion matrices of the two techniques are very close. To better summarize each matrix, the macro versions of recall, precision, and F1-score are reported in Table 3. The same scores are plotted in Figure 10 for a better visual understanding of the proposal's performance. The proposal outperforms the SPS-SVM on almost all samples (except sample seven).
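The macro scores can be obtained from a confusion matrix as sketched below (a standard computation, assuming rows are ground truth and columns are predictions):

```python
import numpy as np

def macro_scores(cm):
    # Per-class precision = TP / column sum, recall = TP / row sum;
    # the macro versions average the per-class scores over the classes.
    cm = np.asarray(cm, dtype=float)
    tp = np.diag(cm)
    precision = tp / np.maximum(cm.sum(axis=0), 1e-12)
    recall = tp / np.maximum(cm.sum(axis=1), 1e-12)
    f1 = 2 * precision * recall / np.maximum(precision + recall, 1e-12)
    return precision.mean(), recall.mean(), f1.mean()
```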


**Table 3.** Test measures for both approaches. Bold letters denote the winning technique.

**Figure 10.** Test measures for both approaches.

For both schemes, we found the hyperparameters heuristically, guided by the training and test scores obtained from 30 execution episodes. Finally, we do not present a run-time comparison because the CNN was implemented in the TensorFlow-Keras framework and consequently runs on the Graphics Processing Unit (GPU), whereas the SPS-SVM system was implemented as a sequential algorithm executed on the CPU, owing to the complexity of its parallelization. The CNN has a lower run-time than the SPS-SVM, but the comparison will not be fair until a parallelized implementation of the SPS-SVM is available.

### **7. Conclusions**

As shown in the previous section, the approach gives good results, not only classifying the parts of the environment that must be segmented into classes, but also reducing the tedious labor of generating a data set by hand. As seen in the result images, the proposal classifies with more detail than an SVM or a human using basic image-editing tools.

A CNN for pixel classification commonly needs a large data set for training; in this paper, the CNN is pre-trained with the predictions of an SPS–SVM. The SPS–SVM can therefore be considered a data augmentation process that generates synthetic labeled data.

The approach is fast enough to provide sensitive information in a short time, so a UAV can make decisions based on recent information. Future work will focus on improving the classification by adding estimates of the different types of clouds that can be found in the environment and the risk they could represent for a UAV.

**Author Contributions:** Conceptualization, C.V. and N.A.-D.; methodology, C.V.; software, A.A.G.; validation, G.L-G., C.V. and J.G.-A.; formal analysis, N.A.-D. and A.A.G.; investigation, J.D.R.; writing—original draft preparation, J.G.-A.; writing—review and editing, G.L.-G., J.D.R. and J.G.-A.; visualization, A.A.G. and G.L.-G.; supervision, J.D.R.; project administration, C.V. and N.A.-D. All authors have read and agreed to the published version of the manuscript.

**Funding:** This research was funded by CONACYT México grants numbers CB256769, CB258068, and PN-4107.

**Acknowledgments:** The authors would like to thank Hydra Technologies de México for providing the data for the development of this work.

**Conflicts of Interest:** The authors declare no conflict of interest. The funders had no role in the design of the study; in the collection, analyses, or interpretation of data; in the writing of the manuscript; or in the decision to publish the results.

### **Abbreviations**

The following abbreviations are used in this manuscript:


### **References**


© 2020 by the authors. Licensee MDPI, Basel, Switzerland. This article is an open access article distributed under the terms and conditions of the Creative Commons Attribution (CC BY) license (http://creativecommons.org/licenses/by/4.0/).

## *Article* **Graphs Regularized Robust Matrix Factorization and Its Application on Student Grade Prediction**

### **Yupei Zhang, Yue Yun, Huan Dai, Jiaqi Cui and Xuequn Shang \***

School of Computer Science, Northwestern Polytechnical University, Xi'an 710129, Shaanxi, China; ypzhaang@nwpu.edu.cn (Y.Z.); yundayue@mail.nwpu.edu.cn (Y.Y.); daihuan@mail.nwpu.edu.cn (H.D.); cuijiaqi@nwpu.edu.cn (J.C.)

**\*** Correspondence: shang@nwpu.edu.cn

Received: 9 January 2020; Accepted: 21 February 2020; Published: 4 March 2020

**Abstract:** Student grade prediction (SGP) is an important educational problem for designing personalized strategies of teaching and learning. Many studies adopt the technique of matrix factorization (MF). However, these methods often focus on the grade records alone, ignoring side information such as backgrounds and relationships. To this end, in this paper, we propose a new MF method, called graph regularized robust matrix factorization (GRMF), based on the recent robust MF version. GRMF integrates two side graphs, built on the side data of students and courses, into the objective of robust low-rank MF. As a result, the learned features of students and courses can grasp more priors from educational situations and achieve higher grade prediction results. The resulting objective problem can be effectively optimized by the Majorization Minimization (MM) algorithm. In addition, GRMF not only can yield features specific to the education domain but can also deal with missing, noisy, and corrupted data. To verify our method, we test GRMF on two public data sets for rating prediction and image recovery. Finally, we apply GRMF to educational data from our university, comprising 1325 students and 832 courses. The extensive experimental results clearly show that GRMF is robust to various data problems and learns more effective features than other methods. Moreover, GRMF also delivers higher prediction accuracy than the other methods on our educational data set. This technique can facilitate personalized teaching and learning in higher education.

**Keywords:** robust matrix factorization; student grade prediction; educational data mining; side information graph; personal teaching and learning

### **1. Introduction**

In higher education, student grade prediction (SGP) can greatly aid all stakeholders in the education process. For students, SGP can help them to choose suitable courses or exercises to increase their knowledge, and even to make pre-plans for academic periods. For instructors, SGP can help them to adjust learning materials and teaching programs to student ability, and to find the students that are at risk of failing a course in progress. For educational managers, SGP can help them to check the curriculum program and to arrange the courses in a scientific order. All stakeholders of the educational process could thus plan better, improving education outcomes and, in turn, graduation rates. SGP is an important problem for scientific education in STEM (Science, Technology, Engineering, Mathematics), as discussed in the work of G. Shannon et al. [1].

Student grade prediction aims to predict the final score/grade of a course enrolled in by a target student in the next academic term. SGP provides a useful reference for evaluating educational outputs in advance and is thus highly necessary for various tasks in personalized education, such as ensuring on-time graduation [2] and improving learning grades [3,4]. Over the past years, many studies have paid attention to SGP and have developed many methods [5].

Existing methods can be principally divided into three categories according to their formulation: (1) Classification. SGP is recast as labeling the target student with predefined grade tags and solved by classification models, such as decision trees [6], logistic regression [7,8], and support vector machines [9,10]. (2) Regression. By taking the grade as the response variable, SGP is rewritten as assigning scores based on the features of the student or course, using models such as linear regression [5,11,12], neural networks [13–15], and random forests [9]. (3) Matrix completion. Since grade records can be poured into a matrix, SGP is also formulated as predicting the missing values of the student–course matrix, with each element being a course grade [16]. This formulation is usually solved by the popular method of matrix factorization and has been extensively studied, leading to many effective approaches [17,18]. In particular, on the same data set, Thai-Nghe et al. compared matrix completion with traditional regression methods such as logistic/linear regression, and the experimental results show that matrix completion can improve prediction results [19].

MF-based methods aim to learn the latent features of students and courses from the given grade data and then use these features for SGP [20]. Here, we review related works that use MF techniques. Traditional MF was employed to implicitly encode the "slip rate" (the probability that the student knows how to solve a question but makes a mistake) and the "guess rate" (the probability that the student does not know how to solve a question but guesses correctly) of the student in an examination, resulting in excellent performance on the educational data set of the KDD (Knowledge Discovery and Data Mining) Cup 2010 [21]. In References [22,23], Non-negative Matrix Factorization (NMF) was used to incorporate the nonnegativity of student grades. Tensor factorization (TF) was exploited in Reference [24] to account for temporal effects, due to the improvement of students' abilities over time. Since the grade matrix is implicitly low rank, low-rank matrix factorization (LRMF) was investigated on data sets from an online learning platform in the work of Lorenzen et al. [25]. However, existing MF-based methods often suffer from missing data, corrupted data, and data noise. In particular, they fail to consider the side information contained in other readily available educational data, such as background data and daily behavior data in school.

Since the *L*2-norm based reconstruction is sensitive to outliers and data corruption, Lin et al. proposed to use the *L*1-norm instead of the *L*2-norm to enhance robustness [26–28]. Besides, massive side information data are often available in real-world applications. Rao et al. proposed graph regularized alternating least squares (GRALS) to integrate two graphs built from the side information of movies and viewers [29]. More specifically, in the real context of higher education, the data set usually has the following properties: (1) The grade matrix is heavily incomplete because of course selection and is corrupted by some human factors. (2) Students with similar backgrounds are likely to perform similarly in a course [30]. For example, if two students both have extensive practice in computer programming, then both are likely to obtain a good grade in a *C Language* course. (3) Courses with similar knowledge tend to give rise to similar grades for a student. For instance, *C Language* is similar to *Data Structure* but not to *History*; thus, a student who is good at *C Language* is likely to perform well in *Data Structure* but not necessarily in *History*.

To this end, we put forth a novel MF method, called double graph regularized robust matrix factorization (GRMF), and apply it to SGP as shown in Figure 1. GRMF not only uses the robust loss function from RMF-MM but also integrates two side information graphs constructed from the background data of students and courses. The MM algorithm can effectively solve the resulting optimization problem. The two-fold contributions of our paper are summarized as follows:


**Figure 1.** The proposed workflow of student grade prediction using GRMF.

The rest of this paper is organized as follows. In Section 2, we formulate the problem of SGP and briefly review the MF technique. We present GRMF in Section 3 and the GRMF algorithm in Section 4. Section 5 shows the experimental results on movie rating prediction, image recovery, and SGP. Section 6 concludes this paper.

### **2. Related Works**

In this section, we formulate the problem of SGP in mathematical form and then introduce the promising technique of matrix factorization.

### *2.1. Student Grade Prediction (SGP)*

In current university education, teachers provide a "one-size-fits-all" curriculum, while students enroll in many courses to obtain academic credit. To graduate on time, a student wants to know which courses he/she can pass with a high score/grade, while a teacher wants to know which students are at risk of failing his/her course. Hence, the problem of predicting a student's grade in a course is significant for improving educational outcomes.

Generally speaking, the grade of a student in a target course can be inferred from his/her learning records, including historical grades in enrolled courses, academic behaviors, and his/her background [31,32]. In this paper, we make the following assumption: the grade can be determined by the latent features of the student and the course, where those features can be derived from the data of students and courses. We explicitly define the task of SGP as follows:

**Problem 1 (Student Grade Prediction):** Let *g*(*s*, *c*) be the grade of student *s* at course *c*. Denote by **u***s* the feature of student *s* and by **v***c* the feature of course *c*. Given the grade matrix **M**, SGP aims to seek a mapping H(**u***s*, **v***c*) such that *g*(*s*, *c*) = H(**u***s*, **v***c*) for all grades in **M**.

To solve Problem 1, we should extract **u** and **v** and design a mapping using the given data matrix **M**. Most research designs or learns the features by using background information [33,34], such as student age and credit time, or the student's grades on all finished courses. Since both kinds of information are helpful, in this paper we combine them for SGP by developing the MF [26].

#### *2.2. Matrix Factorization*

Letting **M** ∈ R*m*×*n* be the given matrix, MF aims to seek two latent feature matrices **U** ∈ R*m*×*k* and **V** ∈ R*n*×*k* that approximate **M**. The traditional MF can be written as:

$$\min\_{\mathbf{U},\mathbf{V}} ||\mathbf{M} - \mathbf{U}\mathbf{V}^T||\_F, \tag{1}$$

where *k* is the predefined number of latent features in **U** and **V**, and || · ||*F* is the Frobenius norm. Optimization problem (1) can be solved by various algorithms, such as Majorization Minimization (MM) [35], the alternating direction method of multipliers (ADMM) [36], and simulated annealing (SA) [37]. Besides, many variants of MF have been proposed, including LRMF [25], NMF [22], and TF [24].
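As a minimal sketch (not from the paper), objective (1) can be minimized by plain gradient descent on **U** and **V**; the learning rate and iteration count below are illustrative choices:

```python
import numpy as np

def mf_gradient_descent(M, k=2, lr=0.01, iters=3000, seed=0):
    # Minimize ||M - U V^T||_F^2 by gradient descent; U is m x k, V is n x k.
    rng = np.random.default_rng(seed)
    m, n = M.shape
    U = 0.1 * rng.standard_normal((m, k))
    V = 0.1 * rng.standard_normal((n, k))
    for _ in range(iters):
        R = M - U @ V.T      # residual
        U = U + lr * R @ V   # gradient step on U
        V = V + lr * R.T @ U # gradient step on V
    return U, V
```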

To enhance robustness, robust matrix factorization via majorization minimization (RMF-MM) employs the *L*1-norm instead of the Frobenius norm as the reconstruction term [26]. The objective problem of RMF-MM is:

$$\min\_{\mathbf{U},\mathbf{V}} ||\mathbf{W} \odot \left(\mathbf{M} - \mathbf{U}\mathbf{V}^T\right)||\_1 + \frac{\lambda}{2}||\mathbf{U}||\_F^2 + \frac{\lambda}{2}||\mathbf{V}||\_F^2, \tag{2}$$

where || · ||1 is the *L*1-norm of a matrix, *λ* > 0 is a regularization parameter, and **W** is defined as follows:

$$\mathbf{W}\_{ij} = \begin{cases} 0, & \text{the value of } \mathbf{M}\_{ij} \text{ is missing} \\ 1, & \text{otherwise.} \end{cases} \tag{3}$$

The problem above can be effectively optimized by the MM algorithm. The experimental results of Lin et al. show that RMF-MM is robust to high missing rates and severe data corruption [26].
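For illustration only, the masked objective (2)–(3) can be evaluated as in the sketch below, where missing entries of **M** are coded as NaN (an assumed convention, not from the paper):

```python
import numpy as np

def rmf_objective(M, U, V, lam):
    # W is 0 where M is missing and 1 otherwise, as in (3).
    W = (~np.isnan(M)).astype(float)
    R = np.nan_to_num(M) - U @ V.T
    data_term = np.abs(W * R).sum()                      # L1 data fidelity
    reg = 0.5 * lam * ((U ** 2).sum() + (V ** 2).sum())  # Frobenius penalties
    return data_term + reg
```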

Since RMF-MM can effectively learn features from noisy data and then use them for prediction, we reformulate Problem 1 to employ this novel technique, as follows:

**Problem 2. (SGP-MF):** Given a student grade matrix **M**, SGP-MF aims to extract **U** for students and **V** for courses such that **M** ≈ **UV***T*. Then the target grade is predicted by

$$g(s,c) = \mathbf{M}\_{s,c} = \mathbf{u}\_s^T \mathbf{v}\_c, \tag{4}$$

where *g*(*s*, *c*) is the grade of student *s* on course *c*, **u***s* is the *s*-th row of **U**, and **v***c* is the *c*-th row of **V**. **M***s*,*c* is the element in the *s*-th row and *c*-th column of matrix **M**.

The reason we consider Formula (4) is the fact that a student enrolls on a course and obtains a grade. This fact motivates us to derive the student's features and the course's features from the grade matrix. In this paper, we address this problem using the matrix factorization (MF) method: as in Formula (4), each grade **M***s*,*c* is modeled as the inner product of the latent features **u***s* and **v***c*.

However, RMF-MM fails to consider the side information data that are often available. Graph matrix factorization (GMF) is an approach that integrates the neighborhood structure of **M**, but it does not work for matrix completion [38]. Based on GMF, we solve SGP by combining two side information graphs with RMF-MM.

#### **3. Double Graph Regularized Robust Matrix Factorization**

In this section, we present our motivation for considering side information data in SGP and encode them into two graphs, followed by our objective problem and its detailed optimization with MM.

#### *3.1. Motivation*

In real-world education, various related information can be obtained about a student, such as background, daily life, and behavior, as well as about a course. These side information data contain relationships among students and courses that can be used to enhance prediction performance. Hence, in this paper, we propose to encode them in two graphs and then integrate those graphs into RMF.

More specifically, we list some *observations*: (1) Family background, such as the economic situation and educational level of the parents, influences the scope of student knowledge [39]. (2) The background of students, such as majors and ages, may affect their habits of thinking and learning. (3) Related courses contain much overlapping knowledge or similar skills. (4) Courses taught by the same teacher are similar in the style of teaching and testing [40].

From the above observations, we conclude the following: on the one hand, students with a similar background are believed to obtain similar performance; on the other hand, two similar courses tend to have similar grade distributions.

### *3.2. Side Information Graph*

Considering the row/column vectors of **M** as data points, each row vector of **U**/**V** is the low-rank representation of the corresponding row/column of **M**. Note that each row in both **M** and **U** corresponds to a student, while each column in both **M** and **V***T* corresponds to a course. Besides, we have side information feature matrices for students and courses, denoted by **S***u* and **S***v*. Following the above, if two students/courses are close in terms of **S***u*/**S***v*, then the corresponding rows of **U**/**V** should also be close to each other [41,42].

To simultaneously integrate the side information of students and courses, we build two similarity graphs using **S***u* and **S***v* instead of **M** [38,43,44]; for this reason, the graphs are called side information graphs. The graphs are built as follows. Denote by **Q** = {**S**, **E**|**G**} a side information graph, where **S** includes all data points from students or courses, **E** is the set of edges, and **G** contains the weights on all edges. **G** is constructed by:

$$\mathbf{G}\_{ij} = \begin{cases} e^{-\frac{\left\|\mathbf{s}\_i - \mathbf{s}\_j\right\|^2}{\sigma}}, & \mathbf{s}\_i \in N\_k\left\{\mathbf{s}\_j\right\} \text{ or } \mathbf{s}\_j \in N\_k\left\{\mathbf{s}\_i\right\} \\ 0, & \text{otherwise} \end{cases} \tag{5}$$

where **s***i* is the data point corresponding to a row of **S***u* or **S***v*, *σ* is the kernel parameter, and *Nk*{**x**} denotes the set of the *k* nearest neighbors of sample **x**. The details can be found in the literature [41].
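A small sketch of the graph construction in (5) follows (an illustrative implementation; the paper does not give its exact neighborhood code):

```python
import numpy as np

def side_graph(S, k, sigma=1.0):
    # Heat-kernel weights on a symmetrized k-nearest-neighbor graph, Eq. (5).
    n = S.shape[0]
    d2 = ((S[:, None, :] - S[None, :, :]) ** 2).sum(-1)  # squared distances
    G = np.zeros((n, n))
    for i in range(n):
        nbrs = np.argsort(d2[i])[1:k + 1]  # k nearest, excluding i itself
        G[i, nbrs] = np.exp(-d2[i, nbrs] / sigma)
    return np.maximum(G, G.T)  # keep edge if i in N_k(j) or j in N_k(i)
```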

Since the similarity relationships encoded in the side information graphs are constructive for learning the latent features, we hope to preserve them in **U** and **V**. Taking **U** as an example, we employ, as usual, the following objective [41]:

$$\mathcal{R}\_1 = \frac{1}{2} \sum\_{i,j} \mathbf{G}\_{i,j} ||\mathbf{u}\_i - \mathbf{u}\_j||\_2^2 = tr\left(\mathbf{U}^T \mathbf{H}\_u \mathbf{U}\right),\tag{6}$$

where *tr*(·) denotes the trace of a matrix, **u***i* is the *i*-th row of **U**, and **H***u* = **D** − **G** with **D***ii* = ∑*j* **G***i*,*j*. Similarly, we can build the side information graph of courses and then obtain two Laplacian regularization terms.
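The equality in (6) between the pairwise form and the trace form can be checked numerically (a sanity check we add for illustration, with random symmetric weights):

```python
import numpy as np

def laplacian_reg(G, U):
    # tr(U^T H U) with H = D - G and D_ii = sum_j G_ij, as in Eq. (6)
    H = np.diag(G.sum(axis=1)) - G
    return float(np.trace(U.T @ H @ U))

rng = np.random.default_rng(0)
G = rng.random((5, 5))
G = (G + G.T) / 2          # symmetric weights
np.fill_diagonal(G, 0.0)   # no self-loops
U = rng.standard_normal((5, 2))
pairwise = 0.5 * sum(G[i, j] * np.sum((U[i] - U[j]) ** 2)
                     for i in range(5) for j in range(5))
```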

### *3.3. The Objective Problem of GRMF*

With the idea of integrating the side information, we combine the objective of RMF-MM and the two Laplacian regularizations, as follows:

$$\begin{aligned} \min\_{\mathbf{U},\mathbf{V}} & ||\mathbf{W} \odot (\mathbf{M} - \mathbf{U}\mathbf{V}^T)||\_1 + \frac{\lambda}{2}\left(||\mathbf{U}||\_F^2 + ||\mathbf{V}||\_F^2\right) + \\ & \frac{\alpha}{2}\left(tr\left(\mathbf{U}^T\mathbf{H}\_u\mathbf{U}\right) + tr\left(\mathbf{V}^T\mathbf{H}\_v\mathbf{V}\right)\right), \end{aligned} \tag{7}$$

where *λ* > 0 and *α* ≥ 0 are two trade-off parameters, and **H***u*/**H***v* are defined in the section above. From (7), we can expect GRMF to reach better performance than RMF-MM, since GRMF degenerates into RMF-MM when *α* is zero.

The main difference between GRMF and RMF-MM lies in the graph Laplacian regularizers of (6), through which GRMF integrates more data priors. While GRALS uses the *L*2-norm for data fidelity [29], GRMF adopts the *L*1-norm and is thus more robust to data noise and pollution.

The SGP problem is first described as a machine learning problem in Problem 1. We assume the grade is determined by the student's and the course's latent features; this assumption is general. The problem is then reformulated via matrix factorization (MF), since we plan to adopt MF to learn the latent features. To account for the noise in the given grade matrix, which arises from grade tampering, slips, and so forth, we reformulate the objective of MF with the *L*1-norm. Finally, for better prediction results, we incorporate the relationships among students and among courses into our robust MF model through two graph regularization terms. Our objective is thus as shown in Equation (7).

### **4. GRMF Algorithm**

In this section, we use a majorization-minimization algorithm to solve problem (7). Suppose that we have already obtained (**U***k*, **V***k*) after the *k*-th iteration. We split (**U**, **V**) into the sum of (**U***k*, **V***k*) and the unknown residue (Δ**U***k*, Δ**V***k*):

$$(\mathbf{U}\_{k+1}, \mathbf{V}\_{k+1}) = (\mathbf{U}\_k, \mathbf{V}\_k) + (\Delta\mathbf{U}\_k, \Delta\mathbf{V}\_k). \tag{8}$$

The task now becomes finding a small increment (Δ**U***k*, Δ**V***k*) in the *k*-th iteration such that the objective function keeps decreasing. To seek the best (Δ**U***k*, Δ**V***k*), we employ the Linearized Alternating Direction Method with Parallel Splitting and Adaptive Penalty (LADMPSAP) [45]. The detailed procedure of this optimization is given in Appendix A. To make the paper self-contained, we summarize the main flow of GRMF in Algorithm 1:

**Algorithm 1** Graph regularized Robust Matrix Factorization (GRMF) by Majorization Minimization

```
Input: M ∈ Rm×n, α, and λ
Output: U and V
Method:
  Initialize U0 and V0 using SVD on M;
  ΔU0 = ΔV0 = 0; ε1 = ε2 = 1e−6.
  While not converged at (Uk, Vk), do
      Let t = 1;
      While not converged, do
          Update ΔUt and ΔVt via LADMPSAP;
          t = t + 1;
      End while
      (ΔUk, ΔVk) = (ΔUt, ΔVt);
      Update U and V in parallel:
          Uk+1 = Uk + ΔUk;
          Vk+1 = Vk + ΔVk;
      Check the convergence conditions:
          ||Vk+1 − Vk|| < ε1 and ||Uk+1 − Uk|| < ε2;
  End while
```

#### **5. Experimental Results**

In order to evaluate the performance of GRMF, we conducted the following experiments: (1) testing GRMF, RMF-MM, and GRALS on the MovieLens 100*k* data set and a public image data set; (2) comparing GRMF for student grade prediction with several popular methods, including RMF-MM [26], GRALS [29], MF [46], NMF [22], PMF [47], KNN (*k*-Nearest Neighbor) [48], and column mean [49], using the real educational data set from our university. Note that MF is the standard matrix factorization solved with gradient descent; column mean uses the mean of the historical grades of the target course; and KNN-mean finds the *k* neighbor students and then computes the mean of their grades. The code and data sets are available on our website, https://github.com/ypzhaang/student-performanceprediction.

### *5.1. Evaluation Metric*

Four metrics are used for evaluating the results: Root Mean Squared Error (RMSE), *L*1-norm Error (Err1) [26], Peak Signal to Noise Ratio (PSNR), and Accuracy rate (Acc). In particular, in our paper, Acc is computed as follows:

$$Acc = \frac{\sum\_{i=1}^{n} \Delta \mathbf{g}\_i}{n},\tag{9}$$

where

$$\Delta \mathbf{g}\_i = \begin{cases} 1, & |\mathbf{g}\_{\rm tr} - \mathbf{g}| < 0.5 \\ 0, & |\mathbf{g}\_{\rm tr} - \mathbf{g}| \ge 0.5 \end{cases} \tag{10}$$

in which *g*tr is the predicted grade, *g* is the true grade, and *n* is the number of grades.
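These metrics can be computed as sketched below (our illustrative code, not the authors'; Acc counts a prediction as correct when it lies within 0.5 of the true grade):

```python
import numpy as np

def rmse(g_pred, g_true):
    # Root Mean Squared Error
    return float(np.sqrt(np.mean((g_pred - g_true) ** 2)))

def err1(g_pred, g_true):
    # L1-norm error over the evaluated entries
    return float(np.abs(g_pred - g_true).sum())

def acc(g_pred, g_true):
    # Fraction of predictions within 0.5 of the true grade, Eqs. (9)-(10)
    return float(np.mean(np.abs(g_pred - g_true) < 0.5))
```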

### *5.2. Test on a Toy Data from Movie Dataset*

The MovieLens data sets were collected by the GroupLens Research Project at the University of Minnesota. They consist of 100,000 ratings (1–5) from 943 users on 1682 movies, together with background information on users (e.g., age, occupation, and zip code) and movies (e.g., title, release date, and genre). Users with fewer than 20 ratings or incomplete demographic information were removed. In this test experiment, we draw a toy data set from MovieLens to probe the effectiveness, convergence, and parameter effects of GRMF; in the toy data set, the user ids are less than 200 and the movie ids are less than 300.

#### 5.2.1. Rating Prediction and Algorithm Convergence

We divided the toy data set into a training set and a test set by random sampling. To evaluate the small toy data set, we employed five-fold cross validation, training models on four folds and testing on the remaining fold. We then constructed two five-nearest-neighbor graphs from the background data of both users and movies. For all the mentioned methods, we chose suitable parameters to achieve the best performance. Note that the optimal parameters of GRALS were those selected in Reference [29].

Table 1 shows the prediction results of the four methods on the toy data set. It is easy to observe that: (1) MF is better than RMF-MM and GRALS in terms of RMSE, but worse than the latter two in terms of Err1. (2) RMF-MM performs better than GRALS on Err1, which is the more robust evaluation measure. (3) Overall, our method delivers the best results under either RMSE or Err1. All of the above indicates that GRMF can benefit from the side information data to enhance rating prediction performance.

In addition, Figure 2 displays the convergence process of GRMF on the toy data. As shown, GRMF converges to a stable Err1 after about 16 iterations. Together with further observations on other data sets, this suggests that Algorithm 1 converges quickly and arrives at an effective solution.

**Table 1.** Err1 and RMSE on toy dataset.


**Figure 2.** The value of Err1 versus iterations of Graph Regularized Robust Matrix Factorization (GRMF) on toy dataset.

5.2.2. The Effects of Parameters on Rating Prediction

The objective (7) of GRMF has the graph regularization parameter *α*, the regularization parameter *λ*, and the rank *k* of the factorized matrices. We here discuss the effects of these three parameters on prediction performance using the above toy data set in our prediction experiment.

The two parameters *α* and *λ* were varied over wide ranges, as shown in Figure 3b, which presents the 3D surface of Err1 under the effects of *α* and *λ*. From the surface, we observe that a broad range of parameter pairs produces decent prediction results. Besides, we also probe the effect of the parameter *k*, shown in Figure 3a. The results show that GRMF has the most stable performance under varying *k*, while RMF-MM has the worst.

**Figure 3.** The effects of the parameters of GRMF on the Err1.

### *5.3. Evaluation on Image Data Set*

The problem of image recovery is often formulated as matrix completion. Since the top singular values dominate the main information, most images can be regarded as low-rank matrices. Hence, we apply the proposed method to recover an image from its noisy version. This test replicates the experiment conducted in the work of Lin et al. [26]. Concretely, we pollute the images (https://sites.google.com/site/zjuyaohu/) with Gaussian noise or salt-and-pepper noise and then recover the images from the noisy versions, comparing with the methods of RMF-MM and GRALS.

#### 5.3.1. Gaussian Noise

We added Gaussian noise with variance 1 and mean 0 to *g* percent of the observed pixels, where *g* is the corruption ratio. Figure 4b shows an example image corrupted with Gaussian noise. *g* was varied in the range [45, 90] to observe performance in various situations. We ran the three methods to recover the corrupted image in Figure 4b, where the side information data consist of the rows and columns of the corrupted image. Figure 5 shows the PSNRs (Peak Signal to Noise Ratios) of the three compared methods. From the curves, GRMF consistently achieves the highest PSNRs on all test cases. As the corruption ratio increases, GRMF delivers a much better result than RMF-MM. Note that GRALS has weak performance because its reconstruction term is very sensitive to data pollution.
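The corruption step can be mimicked as follows (an illustrative sketch with an assumed random seed, not the paper's code):

```python
import numpy as np

def corrupt_gaussian(img, g, seed=0):
    # Add zero-mean, unit-variance Gaussian noise to g percent of the pixels.
    rng = np.random.default_rng(seed)
    out = img.astype(float).copy()
    mask = rng.random(img.shape) < g / 100.0  # select g% of the pixels
    out[mask] += rng.standard_normal(img.shape)[mask]
    return out
```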

Figure 4c–e depicts the images recovered by the three methods for the case of *g* = 80. GRMF clearly produces the best visual result, while the other methods suffer from horizontal or vertical artifact lines.
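The corruption step and the PSNR metric used in this comparison can be sketched in a few lines of NumPy. This is a minimal illustration under stated assumptions, not the paper's implementation: `corrupt_gaussian` is our hypothetical helper name, and an 8-bit peak of 255 is assumed for PSNR.

```python
import numpy as np

def psnr(reference, recovered, peak=255.0):
    """Peak Signal-to-Noise Ratio in dB between two images."""
    mse = np.mean((reference.astype(float) - recovered.astype(float)) ** 2)
    if mse == 0:
        return float("inf")
    return 10.0 * np.log10(peak ** 2 / mse)

def corrupt_gaussian(image, g, rng=None):
    """Add zero-mean, unit-variance Gaussian noise to g percent of the pixels."""
    rng = np.random.default_rng(rng)
    noisy = image.astype(float).copy()
    mask = rng.random(image.shape) < g / 100.0  # which pixels to corrupt
    noisy[mask] += rng.standard_normal(mask.sum())
    return noisy
```

A recovery method is then scored by `psnr(clean, recovered)`, which is how the curves in Figure 5 are obtained.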

**Figure 4.** The results of image recovery with Gaussian noise.

**Figure 5.** Evaluation of image recovery with Gaussian noise in terms of PSNR. The black line marked "Corrupted" indicates the PSNR of the corrupted images.

#### 5.3.2. Salt-And-Pepper Noise

We added salt-and-pepper noise with density varying from 0.05 to 0.65 in steps of 0.05 to the image, obtaining corrupted versions such as Figure 6b. Then, we ran the three compared methods to denoise the corrupted image, where the side information consists of the rows and columns of the corrupted image. Figure 7 shows the denoising results of GRALS, RMF-MM, and GRMF. It is clear from Figure 7 that GRMF delivers the best PSNR when the noise density is below 0.4, but its performance drops once the density exceeds 0.5, where the other two methods also obtain worse results. The reason is that most pixels of the image are corrupted, so reliable graphs are difficult to build in such a noisy situation. In addition, Figure 6 shows the recovered images at a noise density of 0.4, where our method reaches the highest PSNR of 22.88 dB.
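The salt-and-pepper corruption step can be sketched as follows. This is a common implementation (half the corrupted pixels set to 0, half to the peak value 255), not necessarily the exact variant used in the paper; `salt_and_pepper` is our hypothetical helper name.

```python
import numpy as np

def salt_and_pepper(image, density, rng=None):
    """Corrupt a `density` fraction of pixels: half pepper (0), half salt (255)."""
    rng = np.random.default_rng(rng)
    noisy = image.copy()
    r = rng.random(image.shape)
    noisy[r < density / 2] = 0.0                       # pepper
    noisy[(r >= density / 2) & (r < density)] = 255.0  # salt
    return noisy
```

Sweeping `density` from 0.05 to 0.65 and scoring each recovery with PSNR reproduces the setup behind Figure 7.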

### *5.4. Application on Educational Data Set*

The data were collected from the School of Computer Science, Northwestern Polytechnical University (NPU), covering students who enrolled over the five years from 2013 to 2017. We collected all score/grade records before the fall of 2017, together with the side information.

**Figure 6.** The results of image recovery with salt-and-pepper noise.

**Figure 7.** Evaluation of image recovery with salt-and-pepper noise.

More specifically, our dataset contains the grades, the student side data, and the course side data, denoted NPU-G, NPU-S, and NPU-C for short. NPU-G is a 1325 × 832 matrix of grade records from 1325 students in 832 courses. NPU-S contains 25 descriptive features of the 1325 students, such as age, gender, and department. NPU-C includes 18 descriptive features of the 832 academic/elective courses, such as hours, type, and course credit. In addition, at least 15 students enrolled and obtained grades in each course, and students who started university in 2013 and 2014 had already completed their programs.

SGP on our educational data set faces the following challenges: (1) Data sparsity: there are 832 courses in NPU-G, but each student enrolls in only a small number of them, about 85 in our data. (2) Data corruption: many subjective factors, e.g., subjectively graded questions, affect the final grade. (3) Missing data: a few students do not attend the final exam and thus leave an empty grade in the information system. All this noise makes our problem very challenging.

#### 5.4.1. Educational Data Preprocessing

For NPU-G, we removed the students who had lost most of their data records or had taken fewer than 4 courses, then removed the courses taken by fewer than 15 students, and finally deleted secondary courses to ensure a single record per student. We formulated the remaining records from 882 students and 82 courses into the matrix **M**, ordered by academic term. In addition, we transformed the scores into grades 1–6 using the following piecewise function:

$$y = \begin{cases} 1 & 0 < x < 60; \\ 2 & 60 \le x < 70; \\ 3 & 70 \le x < 80; \\ 4 & 80 \le x < 90; \\ 5 & 90 \le x < 100; \\ 6 & x = 100. \end{cases} \tag{11}$$

where *x* is the score in the grade record while *y* is its corresponding grade.
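Equation (11) translates directly into a small helper. This is a sketch; `score_to_grade` is a hypothetical name, and scores are assumed to lie in (0, 100].

```python
def score_to_grade(x):
    """Map a raw score x in (0, 100] to a grade 1-6 per Equation (11)."""
    if x < 60:
        return 1
    if x < 70:
        return 2
    if x < 80:
        return 3
    if x < 90:
        return 4
    if x < 100:
        return 5
    return 6  # x == 100
```

Applying this element-wise to the score records yields the grade matrix **M**.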

Corresponding to $\mathbf{M} \in \mathbb{R}^{882 \times 82}$, we also removed the deleted students and courses from NPU-S and NPU-C. From all collected side descriptions, we selected 15 and 12 important features for NPU-S and NPU-C, respectively, based on teaching experience. Finally, we formulated them into the matrices $\mathbf{S}\_u \in \mathbb{R}^{882 \times 15}$ for students and $\mathbf{S}\_v \in \mathbb{R}^{82 \times 12}$ for courses.

#### 5.4.2. Implementation Details

We predict student grades term by term, following the usual academic stages at the university. Hence, we used historical records as the training set to predict the grades of the next term. That is, for the SGP task of the *t*-th term, our model was trained on the records from the 1st to the (*t* − 1)-th terms and tested on the *t*-th term. Concretely, we built the *k*-nearest-neighborhood graphs **G***u*/**G***v* on the side data of students and courses, **S***u*/**S***v*, respectively. Then, we learned the latent features of students and courses on the training data using our model, and computed the evaluation metrics Err1 and Acc. We conducted this experiment on six data splits, whose training- and test-set sizes are listed in Table 2.
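As a concrete illustration of the graph-construction step, the following sketch builds a symmetric *k*-nearest-neighbour adjacency from the rows of a side-information matrix. Euclidean distance is an assumption (the paper does not specify the metric), and `knn_graph` is our hypothetical helper name.

```python
import numpy as np

def knn_graph(S, k):
    """Symmetric k-NN adjacency over the rows of side-feature matrix S."""
    d = np.linalg.norm(S[:, None, :] - S[None, :, :], axis=2)  # pairwise distances
    np.fill_diagonal(d, np.inf)          # exclude self-loops
    A = np.zeros_like(d)
    idx = np.argsort(d, axis=1)[:, :k]   # k closest rows for each sample
    rows = np.repeat(np.arange(len(S)), k)
    A[rows, idx.ravel()] = 1.0
    return np.maximum(A, A.T)            # symmetrise the directed k-NN edges
```

In the experiments, this would be applied to **S***u* and **S***v* to obtain **G***u* and **G***v*, from which the graph regularizers are derived.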

**Table 2.** The size of training sets and test sets.


To compare with other SGP methods, we conducted experiments using MF (S. Rendle, 2010 [46]), NMF (C.S. Hwang et al., 2015 [22]), PMF (B. Jiang et al. [47]), KNN (N.C. Wong et al., 2019 [48]), and the column mean (M. Sweeney et al. [49]). Besides, we also implemented SGP using RMF-MM (Z. Lin et al., 2018 [26]) and GRALS (N. Rao et al. [29]). For each method, we selected the optimal parameters from the wide ranges suggested in the corresponding reports.

#### 5.4.3. Experimental Results and Discussion

Figure 8a,b show the prediction results of the various methods across all six terms in terms of Err1 and Acc. From the curves, we observe that: (1) as the semester progresses, Err1 decreases and the accuracy rate increases; (2) both GRMF and GRALS are better than the other compared methods; (3) GRMF not only beats RMF-MM but also outperforms GRALS; (4) the performance of colMean can be regarded as a baseline for SGP: both GRMF and GRALS beat colMean over all terms, while the other methods, including RMF-MM, perform worse than colMean in most cases.

From these observations, we draw the following conclusions: (1) As the semester progresses, we obtain more information about the students/courses, which is reflected in better prediction performance. (2) The side information of students and courses is helpful for SGP. (3) The combination of side information and the robust *L*1 regularizers in our GRMF effectively improves the prediction performance. (4) The methods using side information perform well, while the other compared methods cannot handle the prediction task in the real educational context due to the complexity of real educational data. (5) Our proposed method outperforms traditional classification and regression methods. (6) GRMF achieves an accuracy of 65.4% in the sixth term, higher than all other methods.

**Figure 8.** Prediction results across the six terms in terms of Err1 and Acc.

### **6. Discussion and Conclusions**

In this paper, we address the student grade prediction (SGP) problem by proposing a novel matrix factorization method dubbed GRMF. GRMF integrates side information into a robust matrix factorization objective, which can be effectively solved by an MM optimization algorithm. Extensive experiments were conducted on movie data, image data, and our educational data to test performance on rate prediction, image recovery, and SGP. The evaluation metrics show that GRMF delivers better performance than all compared methods. In SGP, GRMF achieves the highest accuracy, about 65.4%; however, this remains modest on our challenging data. We will improve GRMF and explore other modern methods to pursue higher accuracy while supporting personalized education.

In addition, a function *f* mapping **U** and **V** to the grade matrix **G** could yield a better prediction model, since a gap remains between predicted and real grades. This gap arises because grades contain noise caused by accidental events, such as slips and guesses during exams; our study is limited in that it does not model this noise in grade prediction. Adding such a map *f* may produce more accurate results in real-world environments. We leave this for future work.

**Author Contributions:** Conceptualization, Y.Z. and Y.Y.; Data curation, Y.Y. and J.C.; Formal analysis, Y.Z.; Funding acquisition, Y.Z. and X.S.; Investigation, Y.Y. and H.D.; Methodology, Y.Z.; Resources, J.C. and X.S.; Software, Y.Z. and Y.Y.; Supervision, X.S.; Validation, H.D. and J.C.; Writing—original draft, Y.Z. and Y.Y.; Writing—review & editing, H.D., J.C. and X.S. All authors have read and agreed to the published version of the manuscript.

**Funding:** This research was funded by the National Natural Science Foundation of China (Grants No. 61802313, 61772426, U1811262) and the Fundamental Research Funds for Central Universities (Grant No. G2018KY0301).

**Acknowledgments:** We thank the editors and any reviewers for their helpful comments.

**Conflicts of Interest:** The authors declare no conflict of interest.

### **Appendix A. Objective Minimization**

Suppose we have already obtained (**U***k*, **V***k*) after the *k*-th iteration. We split (**U**, **V**) into the sum of (**U***k*, **V***k*) and an unknown residual (Δ**U**, Δ**V**):

$$(\mathbf{U}, \mathbf{V}) = (\mathbf{U}\_k, \mathbf{V}\_k) + (\Delta \mathbf{U}, \Delta \mathbf{V}) \tag{A1}$$

In a similar way, the graph regularization of (6) can be rewritten as follows:

$$\begin{aligned} L(\Delta \mathbf{U}, \Delta \mathbf{V}) &= \operatorname{tr}\left( \left( \mathbf{U}\_k^T + \Delta \mathbf{U}^T \right) \mathbf{H}\_u \left( \mathbf{U}\_k + \Delta \mathbf{U} \right) \right) \\ &\quad + \operatorname{tr}\left( \left( \mathbf{V}\_k^T + \Delta \mathbf{V}^T \right) \mathbf{H}\_v \left( \mathbf{V}\_k + \Delta \mathbf{V} \right) \right) \end{aligned} \tag{A2}$$

With (7) and (8), our task is to minimize the following:

$$\begin{aligned} \min\_{\Delta\mathbf{U},\Delta\mathbf{V}} H\_k\left(\Delta\mathbf{U}, \Delta\mathbf{V}\right) = \min\_{\Delta\mathbf{U},\Delta\mathbf{V}}\ & \parallel \mathbf{W} \odot \left(\mathbf{M} - \left(\mathbf{U}\_k + \Delta\mathbf{U}\right) \left(\mathbf{V}\_k^T + \Delta\mathbf{V}^T\right)\right) \parallel\_1 \\ & + \frac{\lambda}{2} \left(\parallel \mathbf{U}\_k + \Delta\mathbf{U} \parallel\_F^2 + \parallel \mathbf{V}\_k + \Delta\mathbf{V} \parallel\_F^2\right) \\ & + \frac{\alpha}{2} L(\Delta\mathbf{U}, \Delta\mathbf{V}) \end{aligned} \tag{A3}$$

Now our task is to find a small increment (Δ**U**, Δ**V**) such that the objective function keeps decreasing. Inspired by [26], we try to relax (9) to a convex surrogate.

By using the triangular inequality of norms, we arrive at the following inequality:

$$\begin{aligned} H\_k\left(\Delta \mathbf{U}, \Delta \mathbf{V}\right) &\leq \|\mathbf{W}\odot\left(\mathbf{M} - \mathbf{U}\_k \mathbf{V}\_k^T - \Delta \mathbf{U} \mathbf{V}\_k^T - \mathbf{U}\_k \Delta \mathbf{V}^T\right)\|\_1 \\ &\quad + \frac{\lambda}{2}\left(\|\mathbf{U}\_k + \Delta \mathbf{U}\|\_F^2 + \|\mathbf{V}\_k + \Delta \mathbf{V}\|\_F^2\right) \\ &\quad + \frac{\alpha}{2}L\left(\Delta \mathbf{U}, \Delta \mathbf{V}\right) + \|\mathbf{W}\odot\Delta \mathbf{U} \Delta \mathbf{V}^T\|\_1. \end{aligned} \tag{A4}$$

Besides, we can introduce the following relaxation:

$$\| \mathbf{W}\odot\left(\Delta\mathbf{U}\Delta\mathbf{V}^{T}\right)\|\_{1} \leq \frac{1}{2}\left\| \mathbf{\Lambda}\_u \Delta\mathbf{U}\right\|\_{F}^{2} + \frac{1}{2}\left\| \mathbf{\Lambda}\_v \Delta\mathbf{V}\right\|\_{F}^{2}. \tag{A5}$$

For simplicity, we define *Jk*(Δ**U**, Δ**V**) as follows:

$$\begin{aligned} J\_k(\Delta\mathbf{U}, \Delta\mathbf{V}) &= \|\mathbf{W}\odot\left(\mathbf{M} - \mathbf{U}\_k\mathbf{V}\_k^T - \Delta\mathbf{U}\mathbf{V}\_k^T - \mathbf{U}\_k\Delta\mathbf{V}^T\right)\|\_1 \\ &\quad + \frac{\lambda}{2}\left(\|\mathbf{U}\_k + \Delta\mathbf{U}\|\_F^2 + \|\mathbf{V}\_k + \Delta\mathbf{V}\|\_F^2\right) \\ &\quad + \frac{\alpha}{2}L(\Delta\mathbf{U}, \Delta\mathbf{V}). \end{aligned} \tag{A6}$$

Then the relaxed surrogate of *Hk*(Δ**U**, Δ**V**) follows, and our optimization problem can be recast as:

$$F\_k\left(\Delta\mathbf{U}, \Delta\mathbf{V}\right) = J\_k\left(\Delta\mathbf{U}, \Delta\mathbf{V}\right) + \frac{1}{2}\parallel \mathbf{\Lambda}\_u \Delta\mathbf{U} \parallel\_F^2 + \frac{1}{2}\parallel \mathbf{\Lambda}\_v \Delta\mathbf{V} \parallel\_F^2. \tag{A7}$$

Thus, our optimization problem (9) can be further rewritten as:

$$\begin{aligned} \min\_{\mathbf{E},\Delta\mathbf{U},\Delta\mathbf{V}}\ & \parallel \mathbf{W} \odot \mathbf{E} \parallel\_1 \\ & + \left(\frac{\lambda}{2} \parallel \mathbf{U}\_k + \Delta\mathbf{U} \parallel\_F^2 + \frac{1}{2} \parallel \mathbf{\Lambda}\_u \Delta\mathbf{U} \parallel\_F^2\right) \\ & + \left(\frac{\lambda}{2} \parallel \mathbf{V}\_k + \Delta\mathbf{V} \parallel\_F^2 + \frac{1}{2} \parallel \mathbf{\Lambda}\_v \Delta\mathbf{V} \parallel\_F^2\right) \\ & + \frac{\alpha}{2} L(\Delta\mathbf{U}, \Delta\mathbf{V}) \\ \text{s.t.}\quad & \mathbf{M} - \mathbf{U}\_k \mathbf{V}\_k^T = \mathbf{E} + \Delta\mathbf{U} \mathbf{V}\_k^T + \mathbf{U}\_k \Delta\mathbf{V}^T, \end{aligned} \tag{A8}$$

where **Λ***u*, **Λ***<sup>v</sup>* are diagonal matrices.

We optimize the objective by the Linearized Alternating Direction Method with Parallel Splitting and Adaptive Penalty (LADMPSAP) [45], as follows.

### *Appendix A.1. Updating E*

Fixing other variables, updating **E** is equivalent to the following problem:

$$\min\_{\mathbf{E}} \parallel \mathbf{W} \odot \mathbf{E} \parallel\_1 + \frac{\delta\_e^{(i)}}{2} \parallel \mathbf{E} - \mathbf{E}^i + \hat{\mathbf{Y}}^i / \delta\_e^{(i)} \parallel\_F^2, \tag{A9}$$

where

$$\mathbf{\hat{Y}}^i = \mathbf{Y}^i + \boldsymbol{\beta}^i (\mathbf{E}^i + \boldsymbol{\Delta} \mathbf{U}^i \mathbf{V}\_k^T + \mathbf{U}\_k \boldsymbol{\Delta} \mathbf{V}^{iT} - \mathbf{M} + \mathbf{U}\_k \mathbf{V}\_k^T), \tag{A10}$$

and $\delta\_e^{(i)} = \eta\_e \beta^i$, $\eta\_e = 3L\_e + \varepsilon$, where 3 is the number of variables updated in parallel, namely **E**, Δ**U***i*, and Δ**V***i*. Specifically, $L\_e$ is the squared spectral norm of the linear mapping on **E**, which equals 1, and $\varepsilon$ is a small positive scalar. Then we update **E** by:

$$\mathbf{E}^{i+1} = \mathbf{W} \odot \mathbf{S}\_{1/\delta\_e^{(i)}}\left(\mathbf{E}^i - \hat{\mathbf{Y}}^i / \delta\_e^{(i)}\right) + \bar{\mathbf{w}} \odot \left(\mathbf{E}^i - \hat{\mathbf{Y}}^i / \delta\_e^{(i)}\right),\tag{A11}$$

where **S** is the shrinkage operator [50]:

$$\mathbf{S}\_{\varepsilon}(\mathbf{x}) = \max(|\mathbf{x}| - \varepsilon, 0)\mathbf{sgn}(\mathbf{x}),\tag{A12}$$

where **w**¯ is the complement of **W**.
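The shrinkage operator of (A12) is the standard element-wise soft-thresholding and is a one-liner in NumPy (a direct transcription of the formula, offered here only as an illustration):

```python
import numpy as np

def shrink(x, eps):
    """Soft-thresholding S_eps(x) = max(|x| - eps, 0) * sign(x), per Eq. (A12)."""
    return np.maximum(np.abs(x) - eps, 0.0) * np.sign(x)
```

Entries with magnitude below `eps` are zeroed; larger entries are pulled toward zero by `eps`, which is what makes the *L*1 subproblem for **E** cheap to solve.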

### *Appendix A.2. Updating* **ΔU**

Updating **ΔU** is to solve the following problem:

$$\begin{aligned} \min\_{\Delta\mathbf{U}}\ & \frac{\lambda}{2} \parallel \mathbf{U}\_k + \Delta \mathbf{U} \parallel\_F^2 + \frac{\alpha}{2} \operatorname{tr}\left(\left(\mathbf{U}\_k^T + \Delta \mathbf{U}^T\right) \mathbf{H}\_u \left(\mathbf{U}\_k + \Delta \mathbf{U}\right)\right) + \frac{1}{2} \parallel \mathbf{\Lambda}\_u \Delta \mathbf{U} \parallel\_F^2 \\ & + \frac{\delta\_u^{(i)}}{2} \parallel \Delta \mathbf{U} - \Delta \mathbf{U}^i + \hat{\mathbf{Y}}^i \mathbf{V}\_k / \delta\_u^{(i)} \parallel\_F^2, \end{aligned} \tag{A13}$$

where $\delta\_u^{(i)} = \eta\_u \beta^i$ and $\eta\_u = 3\|\mathbf{V}\_k\|\_2^2 + \varepsilon$. Since all terms in (A13) are convex, (A13) is a convex problem and its closed-form solution is given by:

$$\begin{aligned} \Delta \mathbf{U}^{i+1} &= \left(\lambda \mathbf{I}\_n + \alpha \mathbf{H}\_u + \mathbf{\Lambda}\_u^T \mathbf{\Lambda}\_u + \delta\_u^{(i)} \mathbf{I}\_n\right)^{-1} \\ &\quad \left(-\lambda \mathbf{U}\_k - \alpha \mathbf{H}\_u \mathbf{U}\_k + \delta\_u^{(i)} \Delta \mathbf{U}^i - \hat{\mathbf{Y}}^i \mathbf{V}\_k\right), \end{aligned} \tag{A14}$$

Further details of this derivation can be found in [26].

### *Appendix A.3. Updating* **ΔV**

Similar to **ΔU** , updating **ΔV** can be achieved by:

$$\begin{aligned} \Delta \mathbf{V}^{i+1} &= \left(\lambda \mathbf{I}\_m + \alpha \mathbf{H}\_v + \mathbf{\Lambda}\_v^T \mathbf{\Lambda}\_v + \delta\_v^{(i)} \mathbf{I}\_m\right)^{-1} \\ &\quad \left(-\lambda \mathbf{V}\_k - \alpha \mathbf{H}\_v \mathbf{V}\_k + \delta\_v^{(i)} \Delta \mathbf{V}^i - \hat{\mathbf{Y}}^{iT} \mathbf{U}\_k\right). \end{aligned} \tag{A15}$$

### *Appendix A.4. Updating* **Y** *and β*

We update **Y** and *β* as follows:

$$\mathbf{Y}^{i+1} = \mathbf{Y}^i + \beta^i \left(\mathbf{E}^{i+1} + \Delta \mathbf{U}^{i+1} \mathbf{V}\_k^T + \mathbf{U}\_k \Delta \mathbf{V}^{(i+1)T} + \mathbf{U}\_k \mathbf{V}\_k^T - \mathbf{M}\right), \tag{A16}$$

$$\boldsymbol{\beta}^{i+1} = \min(\boldsymbol{\beta}^{\max}, \boldsymbol{\rho}\boldsymbol{\beta}^{i}),\tag{A17}$$

where *ρ* is defined by:

$$\rho = \begin{cases} \rho\_0, & \text{if } \mathbf{Q} < \varepsilon\_1, \\ 1, & \text{otherwise}, \end{cases} \tag{A18}$$

and

$$\mathbf{Q} = \beta^i \max\left(\sqrt{\eta\_e}\,\|\mathbf{E}^{i+1} - \mathbf{E}^i\|\_F,\ \sqrt{\eta\_u}\,\|\Delta \mathbf{U}^{i+1} - \Delta \mathbf{U}^i\|\_F,\ \sqrt{\eta\_v}\,\|\Delta \mathbf{V}^{i+1} - \Delta \mathbf{V}^i\|\_F\right) / \|\mathbf{M} - \mathbf{U}\_k \mathbf{V}\_k^T\|\_F. \tag{A19}$$

In addition, the stopping criterion of iteration can be derived from KKT condition [45]:

$$\beta^i \max\left(\sqrt{\eta\_e}\,\|\mathbf{E}^{i+1} - \mathbf{E}^i\|\_F,\ \sqrt{\eta\_u}\,\|\Delta \mathbf{U}^{i+1} - \Delta \mathbf{U}^i\|\_F,\ \sqrt{\eta\_v}\,\|\Delta \mathbf{V}^{i+1} - \Delta \mathbf{V}^i\|\_F\right) / \|\mathbf{M} - \mathbf{U}\_k \mathbf{V}\_k^T\|\_F < \varepsilon\_1, \tag{A20}$$

$$\|\mathbf{E}^{i+1} + \Delta \mathbf{U}^{i+1} \mathbf{V}\_k^T + \mathbf{U}\_k \Delta \mathbf{V}^{(i+1)T} + \mathbf{U}\_k \mathbf{V}\_k^T - \mathbf{M}\|\_F / \|\mathbf{M} - \mathbf{U}\_k \mathbf{V}\_k^T\|\_F < \varepsilon\_2. \tag{A21}$$

Finally, Algorithm A1 is detailed as follows:

**Algorithm A1** Graph Regularized Robust Matrix Factorization (GRMF) by Majorization Minimization

**Input**: **M** ∈ R*n*×*m*, *α*, and *λ*

**Output**: **U** and **V**

**Method**:

Initialize **U***0* and **V***0* using the SVD of **M**; set **E***0* = **M** − **U***0***V***0*^T, Δ**U***0* = Δ**V***0* = **Y***0* = 0, *ρ*0 = 1.5, and *ε* = *ε*1 = *ε*2 = *ε*3 = 1e−5.

**While** not converged at [**U***k*, **V***k*] **do**

Let *t* = 1 and *β*0 ∝ (*m* + *n*)*ε*1;

**While** (A20) and (A21) are not satisfied **do**

Update **E***t* by (A11); update Δ**U***t* and Δ**V***t* via (A14) and (A15); update **Y***t* by (A16); update *β*t by (A17); *t* = *t* + 1;

**End while**

Update **U** and **V** in parallel: **U***k*+1 = **U***k* + Δ**U***t*; **V***k*+1 = **V***k* + Δ**V***t*;

Check the convergence conditions: ‖**V***k*+1 − **V***k*‖ < *ε*2 and ‖**U***k*+1 − **U***k*‖ < *ε*3;

**End while**

### **References**


© 2020 by the authors. Licensee MDPI, Basel, Switzerland. This article is an open access article distributed under the terms and conditions of the Creative Commons Attribution (CC BY) license (http://creativecommons.org/licenses/by/4.0/).

## *Article* **Use of Deep Multi-Target Prediction to Identify Learning Styles**

### **Everton Gomede 1,\*, Rodolfo Miranda de Barros <sup>2</sup> and Leonardo de Souza Mendes <sup>1</sup>**


Received: 28 January 2020; Accepted: 26 February 2020; Published: 4 March 2020

### **Featured Application: Our results can be applied to identifying students' learning styles, providing adaptation in e-learning systems.**

**Abstract:** Students can be classified according to the manner in which they recognize, process, and store information. This classification should be considered when developing adaptive e-learning systems, and it builds an understanding of the different styles students demonstrate while learning, which can help adaptive e-learning systems offer advice and instructions to students, teachers, administrators, and parents in order to optimize students' learning processes. Moreover, e-learning systems that use computational and statistical algorithms to analyze students' learning may offer the opportunity to complement traditional learning evaluation methods with new ones based on analytical intelligence. In this work, we propose a method based on a deep multi-target prediction algorithm and the Felder–Silverman learning styles model to improve the evaluation of students' learning, using feature selection, learning styles models, and multi-target classification. As a result, we present a set of features and a model based on an artificial neural network to investigate the possibility of improving the accuracy of automatic learning style identification. The obtained results show that learning styles allow adaptive e-learning systems to improve the learning processes of students.

**Keywords:** deep multi-target prediction; Felder–Silverman learning style; adaptive e-learning systems; artificial neural network; deep learning

### **1. Introduction**

According to Willingham [1], people are naturally curious but are not naturally good thinkers; unless the cognitive conditions are adequate, humans avoid reasoning. This behavior is attributed to three properties. To begin with, reasoning is used with moderation: the human visual system instantly takes in a complex scene, but the mind is not inclined to instantly solve a puzzle. Additionally, reasoning is tiresome because it requires focus and concentration. Finally, because we ordinarily make mistakes, reasoning is uncertain. In spite of these aspects, humans like to think. Solving problems produces pleasure because there is an overlap between the brain areas and chemicals that are important in learning and those related to the brain's natural reward system [1]. Thus, adjusting to a student's cognitive style might help to improve the student's reasoning capacity.

Moreover, according to Felder and Silverman [2], learning styles (a part of cognitive styles) describe students' preferences on how a subject is presented, how to work with that subject matter, and how to internalize (acquire, process, and store) information [2]. According to Willingham [1], students may have diverse preferences on how to learn. Thus, knowing a student's learning style can help in finding the most suitable way to improve the learning process. Some studies show that learning styles allow adaptive e-learning systems to improve the learning processes of students [3–7].

We characterize learning as the procedure through which information is encoded and stored in long-term memory. Customizing content to students' learning styles is seen as beneficial to learning in several ways, for example, improving satisfaction, increasing learning outcomes, and decreasing learning time [3]. Several research works on learning systems have suggested that students' learning improves when the instructor's teaching style matches the students' learning style [5].

According to Bernard et al. [5], there are several methods to classify learning styles. One of the best known is the Felder–Silverman learning styles model (FSLSM). This model proposes four dimensions to classify learning styles, each scored on an interval ranging from 0 to 11. The first dimension is active/reflective, which determines whether someone prefers first experimenting with a subject and then reasoning about it (active) or reasoning first and then experimenting (reflective). The second dimension is sensing/intuitive, which determines whether someone prefers touching things to learn (sensing) or observing things to infer information (intuitive). The third dimension, visual/verbal, determines whether someone prefers to see charts, tables, and figures (visual) instead of reading or listening to texts (verbal), or the contrary. Finally, the sequential/global dimension determines whether someone prefers to get information in a successive manner, learning step by step (sequential), or to get an outline of the information first and go into detail afterwards, without a predefined sequence (global).

According to Willingham [1], the prediction of any learning styles theory is that a particular teaching method might be good for some students but not so good for others. Therefore, it is possible to take advantage of different types of learning. It is important to understand the difference between learning ability and learning style. Learning ability is the capacity for or success in certain types of thought (math, for example). In contrast to ability, learning style is the tendency to think in a particular way, for example, thinking sequentially or holistically, which is independent of context [1]. There is an enormous contrast between the popularity of learning styles in education and evidence of their usefulness. According to Pashler [4], the reasonable utility of using learning styles to improve student learning still needs to be demonstrated. However, in their study, Kolb and Kolb [6] examined recent research advancements in experimental learning in higher education and analyzed how it can improve students' learning. In their studies they concluded that learning styles can be based on research and clinical observation of learning styles' score patterns and applied throughout the educational environment through an institutional development program.

As indicated by Willingham [1], psychologists have devised several approaches to test this proposition, leading to some hypotheses. First, learning style is considered stable within an individual: if a student has a particular learning style, that style ought to be a stable part of that student's cognitive makeup. Second, learning style ought to be consequential; therefore, using a specific learning style should have implications on the outcomes of the student's learning. Thus, learning styles theory must have the following three features: (1) a specific learning style should be a stable characteristic of a person; (2) individuals with different styles should think and learn differently; and (3) individuals with different styles do not, on average, differ in their ability. Traditionally, learning styles are mainly measured through surveys and questionnaires, where students are asked to self-evaluate their own behaviors. However, this approach presents some problems. First, external interference can disturb the results during its application. Second, the outcomes are influenced by the quality of the survey or questionnaire. Finally, different students may interpret questions in different ways [5].

According to Bernard et al. [4], the characterization of learning styles is a problem that deals with many descriptors and many outputs. The descriptors may arise from many sources, such as logs, questionnaires, and databases. Moreover, the descriptors are usually associated with learning objects, such as forums, contents, outlines, quizzes, self-assignments, examples, and other types of resources. The outputs are used to permit the comprehension of learning style resulting from a combination of descriptors, which may indicate whether a student can be classified as active/reflective, sensing/intuitive, visual/verbal, or sequential/global, based on his/her approach to recognize, process, and store information. This problem is relevant because it is the first step to understand the cognitive condition to improve learning using e-learning systems [5].

In this context, our research aims to investigate the use of computational intelligence (CI) algorithms to analyze and improve the accuracy of autonomic approaches to identify learning style. Our hypothesis is that if the learning style can be correctly identified using CI then the student's learning preference may also be predicted. Thus, we conducted this research to identify features that may represent a student's learning style based on massive information (big data) collected in a massive open online course (MOOC) environment and use these features to classify these learning styles. We also investigate whether a theory of learning style might be more suited to classification than others. Finally, we investigate algorithms to overcome limitations found in contemporary works.

This paper is organized as follows: This first section presents basic considerations and justifications for this work and defines its main objectives. Section 2 presents an overall review of the key topics treated here and the main definitions upon which this work is based. Section 3 presents the main concepts behind learning styles classification and describes the proposed model. Section 3 also presents the data structures used to characterize the subjects, along with their materials and methods. Section 4 presents the results obtained from the data analysis and specifies recommendations to stakeholders. In the last section, conclusions and future developments are presented.

### **2. Related Work and Concepts**

According to Truong [7] and Normadhi [3], researchers have been searching for mechanisms to automatically detect student's learning styles based on different models. The process of automatic learning style detection can be divided into three subproblems: (a) select a suitable learning model, (b) select the descriptors and targets to represent a student's online behavior (in a MOOC), and (c) select the algorithm (and hyperparameters) which fit to the multi-target prediction issue. This procedure is shown in Figure 1.

**Figure 1.** The process to build a model for automatic detection of learning style [5]. In this paper, we aim to investigate the steps a, b, and c. MOOC = massive open online course.

As shown in Figure 1, to classify students' learning styles, some researchers focused on the use of algorithms while others focused on the application of this model using traditional methods, such as questionnaires (dashed line). In this section, we compare papers that used these approaches.

### *2.1. Learning Styles Model Selection*

According to Willingham [1], learning styles theory predicts that a particular teaching method may be good for one person, but not good for another. Therefore, in order to optimize a student's capacity to learn we need to exploit these different methods of learning. As previously said, it is imperative to comprehend the difference between learning ability and learning style. Learning ability is the capacity for, or success in, certain types of subjects (math, for example). In contrast to ability, learning style is a tendency to reason in a particular way, for example, sequentially or holistically. As already pointed out, there is a differentiation between the popularity of learning styles approaches within education and the lack of credible evidence for its utility. As indicated by Pashler [4], whether characterization of students' learning styles has any reasonable utility has yet to be determined. However, an investigation by Kolb and Kolb [6] examined ongoing advancements in the hypotheses and research on experiential learning and explored how it can help improve learning in higher education. In addition, Kolb and Kolb presumed that learning styles are based on both research and clinical observation of the patterns of learning styles' scores and can be applied throughout the educational environment by an institutional development program.

Various components of learning styles have been researched, both conceptually and empirically [3]. Numerous hypotheses and multiple taxonomies attempting to describe how people reason and learn have been proposed, arranging individuals into distinct groups. Moreover, as indicated by Omar et al. [8], different learning style instruments for research and pedagogical purposes have been produced. According to Truong [7] and Normadhi [3], many researchers developed automatic detection models based on the FSLSM, whose theory describes learning styles along four dimensions: processing, perception, input, and understanding. The processing dimension characterizes active and reflective learners, who are identified by their interest in performing physical or theoretical activities. Active students are individuals who prefer to work in groups and perform numerous activities, whereas reflective students prefer to work alone and perform fewer exercises. Sensing and intuitive learners are characterized by the perception dimension. Sensing learners are more attentive and careful and normally achieve their goals within few trials, presenting a high rate of exercise completion and high performance in exams. On the other hand, intuitive learners often become bored by details, show carelessness, and only achieve their goals after several trials, presenting a low rate of exercise completion and low performance in exams. The input dimension distinguishes students by their inclination toward visual or verbal content and processes when studying and participating in group activities. Finally, the understanding dimension determines whether students incline toward a sequential or global approach to understanding subjects of study. Sequential students prefer to approach study material in a sequential manner, similar to following a road map, whereas global students prefer to get an overview and afterward dive into details, attempting to comprehend specific points and link that information with other knowledge [9].

To start the automatic learning style determination, the initial step is to choose a suitable learning styles model. More than 70 models have been proposed, with some overlap in their approaches. According to Truong [7], these models present some issues in terms of validity and reliability, with most of them presenting similar performance. Among these, the Felder–Silverman learning styles model (FSLSM) is frequently used for automatic learning styles identification. Graf et al. [10] presented three reasons to select the FSLSM: (a) it uses four dimensions, allowing for more detailed classification; (b) it describes the preference to gather, process, and store information; and (c) it deals with each dimension as a tendency instead of an absolute type. These dimensions can be seen as a continuum with one learning inclination on the extreme left and the other on the extreme right, as per Saryar [11].

### *2.2. Descriptors and Target Selection*

Three primary sources of features are recognized: log files, static information, and other personalization sources. The potential sources of data and the corresponding characteristics can be summarized as follows:


Regardless of the several predictors considered, none of the studies addressed how different attributes contribute to predicting learning styles. The findings of such comparisons could play an important role in improving the efficiency of different predictions.

### *2.3. Classification Algorithm Development and Evaluation*

One of the most popular strategies used to classify and evaluate is the rule-based algorithm, in which researchers translate different styles, according to the hypotheses, into different statistical rules. This method is used in Bayesian networks and naïve Bayes rules. Moreover, other algorithms such as artificial neural networks, ant colony optimization, particle swarm optimization, genetic algorithms, and decision trees can also be applied for classification. Among these algorithms, the one that achieves the best accuracy is the artificial neural network (ANN) [5]. The common way to evaluate the models is by splitting the dataset into training and test sub-datasets [17].

### *2.4. Related Works*

In a review paper, Truong [7] presents a study summarizing several works in an overview of models used for learning styles classification. This paper analyzed 51 works, dividing learning style classification into three subproblems. According to the author, the models can be categorized into those that change over time, those that change across situations, and those that do not change. In addition, the utilization of learning styles provides instructors with a tool to comprehend their students. Truong also shows that there is an association between learning styles and career choices; based on this, suggestions and guidance to support career path planning can be developed. The author also divides the studies into those that only classify learning style and those that make predictions based on descriptors provided by user behavior. The latter is used for personalization and recommendation in e-learning systems.

In another survey paper, Normadhi et al. [3] stated that the techniques used to recognize personality characteristics can be divided into three categories: (a) questionnaires, (b) computer-based detection, and (c) both. Computer-based identification strategies are most often used to enrich the personality trait data in a student profile by analyzing implicit user input. These techniques are considered more accurate than questionnaire techniques because they respond quickly to changes in the learner's personality characteristics. Computer-based recognition techniques can be categorized as machine learning, non-machine learning, and hybrid. Additionally, computer-based recognition techniques can be important for new students, since information is initially insufficient to construct an appropriate student profile. In addition, the authors state that most researchers use personality traits in the cognition learning domain category (62.82%), in the adaptive learning environment, and in the model dimension, which is frequently based on the Felder–Silverman model (FSLSM). The authors also claim that the results of identification techniques have a positive and large influence on adaptive learning environments. For example, exploring observational assessment for adaptive e-learning environments is especially relevant. Research that conducts experiments to compare the effectiveness and efficiency of identification techniques is also highly encouraged. Finally, future examinations ought to explore and investigate the strengths and weaknesses of personality traits that map onto the selected learning objects and materials [3].

Bernard et al. [5] investigated four computational intelligence algorithms (artificial neural network, ant colony optimization, genetic algorithm, and particle swarm optimization) to improve the accuracy of learning style detection. The authors achieved an average accuracy of 80% using an artificial neural network. They also pointed out the drawbacks of using questionnaires: (a) it is assumed that learners are motivated to fill out the questionnaire; (b) that they will fill it out fully (without influence); and (c) that they understand how they prefer to learn. The authors used the FSLSM and relevant behavior descriptors from Graf et al. [10]. They also linked these descriptors with learning styles, indicating that each descriptor is associated with a learning style. These descriptors are based on different types of learning objects, including content, outlines, examples, exercises, self-assessments, quizzes, and forums. They consider the time a student spends on a certain type of learning object (e.g., content\_stay) and how often a student visits a certain type of learning object (e.g., content\_visit). Moreover, questions were classified based on whether they are about concepts, whether they require details or a general view of knowledge, whether they include graphics, or whether they use text only. These questions also deal with developing or interpreting solutions. Further, the authors presented metrics to evaluate the results. The performance of the proposed approaches was measured using four metrics: (a) SIM (similarity), (b) ACC (accuracy), (c) LACC (the lowest accuracy), and (d) % Match (percentage of students identified with a reasonable accuracy).

In another original paper, Sheeba and Krishnan [9] proposed an approach to classifying students' learning styles based on their learning behavior. The approach relies on a decision tree classifier to develop the significant rules required for accurately distinguishing learning styles. It was tested on 100 students in an online course created in the Moodle learning management system (LMS). In this experiment the authors achieved an average accuracy of 87% in the processing, perception, and input dimensions. The authors also presented two methods used for automatic recognition of learning styles: data-driven and literature-based approaches. The data-driven approach uses sample data to build a classifier that imitates a learning style instrument. This approach predominantly uses an artificial intelligence (AI) classification algorithm which takes the learner model as input and returns the learner's learning style preferences as output. The literature-based approach applies simple rules to calculate learning styles from the number of matching hints. The authors used a dataset from web log files containing all the behaviors that the learners performed in the Moodle LMS. These logs were created automatically when the students used the system, recording all the activities of forums, chats, exercises, assignments, quizzes, exam delivery, and the frequency of accessing course materials [10].

Thus, our work aims to contribute to the papers analyzed here by proposing methods and procedures to overcome the current limitations. First, we identify that the descriptors of previous studies are tied to specific dimensions in the computational model. However, the psychological model, FSLSM [2], does not follow the same approach. For example, a student visits a course outline, activating the descriptor "outline\_visit", which is interpreted as a feature of the perception dimension (sensing/intuitive) alone [4,5]. Therefore, we investigate the influence of all the descriptors on the four dimensions using a multiple classification technique. Second, the strategy for labeling the dataset is vague. The logs provided by a MOOC do not label learning style, only behavior, and the authors do not provide a clear method to label the dataset used to train the model [4,5]. Moreover, they do not address common problems related to datasets, such as imbalanced datasets [18]. Third, the context of the dataset is not described in a comprehensible way. For example, the authors present the students' level (undergraduate students); however, they do not indicate the average age, type of course, duration of course, frequency, or results (pass/fail) [4,5]. Moreover, with respect to computational intelligence techniques, the authors do not provide an explicit strategy to overcome overfitting, to find optimal parameters (such as the number of hidden layers in an artificial neural network), or to train and test the built models. Finally, there is a lack of performance metrics, such as F-score, recall, precision, sensitivity, and others [4,5]. This harms the possibility of comparison with related works (current and future) and prevents a deeper analysis of the test results.

### **3. Materials and Methods**

The integration of learning styles into an adaptive e-learning system can be divided into two essential parts: building a learning styles prediction model using online data (the online learning styles classification model) and applying this model in an adaptive e-learning system. The development begins with choosing the learning styles model, for example, the FSLSM. This is followed by determining the data sources and the learning styles attributes, and by selecting the classification algorithm. After the evaluation, the suitable classification models and their outcomes are applied to specific factors of the adaptive e-learning system.

The first step in building a model based on a computational intelligence algorithm is to collect and prepare a dataset. The students' behavior was collected from an LMS (learning management system) developed specifically for this experiment. The learning objects used were content, outlines, self-assessments, exercises, quizzes, forums, questions, navigation, and examples. The behavior was collected as described in Table 1. The 100 students had graduated in Computer Science and were enrolled in a post-graduate program in Computer Science and Project Management. The 26 descriptors were based on the Sheeba and Krishnan [9] and Bernard et al. [5] models. These descriptors were grouped into nine learning objects which are presented to the student in an LMS course. The dataset was composed of three types of measure: (a) "count", which represents the number of times a student visits a learning object; (b) "time", which represents the time the student spends on a learning object; and (c) "Boolean", which represents the student's results when responding to questions on a quiz. These records were collected for 15 days and, to summarize all the results obtained by the students, each descriptor was represented by the average of the student's logs. The questions on the self-assessment quizzes were categorized based on whether they are about facts or concepts, require detailed or overview knowledge, include graphics, charts, or text only, and address building or interpreting solutions. Table 1 shows the descriptors that were collected from the LMS. These descriptors are also considered the independent variables used to build our model.

The resulting dataset does not provide a description of a learning style for each student. This information is necessary to train an algorithm based on supervised learning [5]. To overcome this problem, we used an adaptation of the Felder–Silverman questionnaire (the original questionnaire we used can be viewed at [19]) to collect each student's learning style. This questionnaire classifies a student in the FSLSM using four dimensions: (a) processing (active/reflective); (b) perception (sensing/intuitive); (c) input (visual/verbal); and (d) understanding (sequential/global). The classification is constructed by defining a range for each dimension (for example, processing) from (−11:0) (active) to (0:11) (reflective), and so forth. The dataset's labels are shown in Table 2. These labels are also considered the dependent variables in our model.


**Table 1.** Descriptors of behavior collected from massive open online course (MOOC) (or independent variables).

<sup>1</sup> The time and count are the averages of everything accessed by a student over 15 days. The range of count is >0 and the range of time is from 0 to 120 s (this limit is due to the session expiration time of the MOOC).



<sup>1</sup> The count is the result of Felder–Silverman's questionnaire (0 values are not present).

The labels shown in Table 2 represent the students' learning behavior in the FSLSM. In these cases, the 0 values (or absence of preference) are not present in the labels because, when a student fills out the questionnaire, he/she needs to choose the options which represent a dimension. The overall workflow for this process is shown in Figure 2, below.

**Figure 2.** The process to build the dataset (independent and dependent variables). The questionnaire was used to label the dataset for each student's observation.

As shown in Figure 2, step 1 collects data from MOOC when a student interacts with a course. Then, in step 2 the system fills out the questionnaire for this student. In step 3, the results from the questionnaires based on the FSLSM are fed into a dataset. Finally, in step 4, the descriptors (independent variables) and labels (dependent variables) are combined with the raw dataset to produce an extended student classification dataset.

Since the dataset scale may differ for each student measure (count, time, and Boolean), the next step in the dataset construction is to normalize the data so that information can be suitably compared among students. Normalizing the data can also improve the accuracy of neural networks [5]. When analyzing two or more attributes it is often necessary to normalize their values (for example, content\_stay and content\_visit), especially in those cases where the values differ vastly in scale. We use the range normalization [17] described in Equation (1):

$$x\_i' = \frac{x\_i - \min\_j \{x\_j\}}{\max\_j \{x\_j\} - \min\_j \{x\_j\}} \tag{1}$$
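Equation (1) can be sketched as a short Python function (the function name and sample values are illustrative, not taken from the study's dataset):

```python
def range_normalize(values):
    """Min-max range normalization (Equation (1)):
    rescale each value to the interval [0, 1]."""
    lo, hi = min(values), max(values)
    if hi == lo:  # constant attribute: avoid division by zero
        return [0.0 for _ in values]
    return [(v - lo) / (hi - lo) for v in values]

# Hypothetical content_stay times, in seconds
content_stay = [10, 60, 120, 35]
print([round(v, 3) for v in range_normalize(content_stay)])
# [0.0, 0.455, 1.0, 0.227]
```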

After this transformation, the new attribute takes on values in the range (0, 1). Moreover, we converted the range of each dimension (processing, perception, input, and understanding) from (−11:0) and (0:11) to (0, 1). This transformation is required for two reasons: (a) the learning styles are a tendency [2,4]; thus, to represent a student as active/reflective, we used a binary variable (e.g., active or reflective, instead of 11 times active or 11 times reflective) as a problem-relaxation strategy, and (b) it improves the accuracy of the algorithm when classifying four outputs. This operation is shown in Equation (2).

$$\begin{array}{l} \text{IF (active\_reflective} < 0) \text{ THEN active\_reflective = true ELSE false}\\ \text{IF (sensing\_intuitive} < 0) \text{ THEN sensing\_intuitive = true ELSE false}\\ \text{IF (visual\_verbal} < 0) \text{ THEN visual\_verbal = true ELSE false}\\ \text{IF (sequential\_global} < 0) \text{ THEN sequential\_global = true ELSE false} \end{array} \tag{2}$$

In this case, each dimension receives TRUE for the element on the left and FALSE for the element on the right. For example, if a student's processing dimension score is <0, the student receives TRUE, denoting that he/she is active. On the other hand, if a student's processing dimension score is >0, the student receives FALSE, denoting that he/she is reflective.
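The binarization of Equation (2) can be sketched in Python (the dictionary keys mirror the dataset's target names; the questionnaire scores below are hypothetical):

```python
def binarize_dimensions(scores):
    """Map each FSLSM dimension score in [-11, 11] to a boolean:
    True for the left pole (e.g., active) when the score is negative,
    False for the right pole (e.g., reflective) otherwise (Equation (2))."""
    return {dim: score < 0 for dim, score in scores.items()}

# Hypothetical questionnaire results for one student
scores = {"active_reflective": -7, "sensing_intuitive": 3,
          "visual_verbal": -5, "sequential_global": 9}
print(binarize_dimensions(scores))
# {'active_reflective': True, 'sensing_intuitive': False,
#  'visual_verbal': True, 'sequential_global': False}
```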

In addition, we investigated whether the dataset is imbalanced for each target. An imbalanced dataset means that the instances of one class outnumber the instances of another class (for example, more sequential than global students in the understanding dimension), where the majority and minority class or classes are taken as negative and positive, respectively [11]. Figure 3 shows the distribution of each target.

**Figure 3.** The dataset's target distribution for each dimension. The percentage of each student preference is represented, such as: sequential 42%, global 58%, visual 38%, verbal 62%, sensing 45%, intuitive 55%, active 49%, and reflective 51%.

As shown in Figure 3, this dataset does not have strongly imbalanced data for any of the targets. For each target, the imbalance ratio (IR) is given by dividing the size of the majority class by that of the minority class [18]. As a result, we obtained active\_reflective (1.04), sensing\_intuitive (1.22), visual\_verbal (1.63), and sequential\_global (1.38).
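The imbalance ratio can be computed directly from a target column; a minimal sketch (the label list below is a hypothetical stand-in for the study's 100 students):

```python
from collections import Counter

def imbalance_ratio(labels):
    """IR = size of the majority class / size of the minority class [18]."""
    counts = Counter(labels).most_common()
    return counts[0][1] / counts[-1][1]

# Hypothetical target column: 49 active vs. 51 reflective students
labels = ["active"] * 49 + ["reflective"] * 51
print(round(imbalance_ratio(labels), 2))  # 1.04
```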

The algorithm chosen for multi-target prediction was the artificial neural network (ANN), for five reasons: (a) there is evidence that this algorithm is well suited to learning style classification problems [5]; (b) since many authors use this algorithm, we can compare our results with other published ones [3]; (c) an ANN works well with rather small datasets, which is important for this line of research considering that typical datasets are rather small [17]; (d) the problem can be translated into the network structure of an ANN; and (e) an ANN allows multiple outputs to be analyzed at the same time. The ANN architecture we used is the feedforward multilayer perceptron, that is, a neural network with one or more hidden layers [17,20].
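The feedforward architecture can be sketched as a plain forward pass in pure Python. This is an illustrative skeleton only: the layer sizes follow the paper (26 inputs, hidden layers, 4 outputs), but the weights here are random, not the trained ones from the study.

```python
import math
import random

def forward(x, layers):
    """Propagate an input vector through fully connected layers
    with logistic (sigmoid) activations."""
    for weights, biases in layers:
        x = [1 / (1 + math.exp(-(sum(w * v for w, v in zip(row, x)) + b)))
             for row, b in zip(weights, biases)]
    return x

def make_layer(n_in, n_out, rng):
    """Random small weights and zero biases for one dense layer."""
    weights = [[rng.uniform(-0.5, 0.5) for _ in range(n_in)] for _ in range(n_out)]
    return weights, [0.0] * n_out

rng = random.Random(0)
# 26 behavior descriptors in, two hidden layers, 4 outputs (one per FSLSM dimension)
layers = [make_layer(26, 10, rng), make_layer(10, 10, rng), make_layer(10, 4, rng)]
student = [rng.random() for _ in range(26)]   # one normalized behavior vector
outputs = forward(student, layers)
print(len(outputs))  # 4
```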

The hidden layers act as feature detectors and, as such, play an important role in the operation of a multilayer perceptron. As the learning process advances, the hidden neurons gradually begin to discover the features that describe the training data. They do so by performing nonlinear processing on the input data and transforming them into a new space, called the feature space. In this new space, the classes of interest in a pattern-classification task, for instance, may be more easily separated from each other than in the original input space. Indeed, it is the creation of this feature space through supervised learning that distinguishes the multilayer perceptron from the single-layer perceptron. The literature suggests that the number of hidden neurons should be between log T (where T is the size of the training set) and twice the number of inputs [17].

A popular approach for training a multilayer perceptron is the back-propagation algorithm, which includes the least mean squares (LMS) algorithm as a special case. The training proceeds in two phases. In the first, referred to as the forward phase, the synaptic weights of the network are held fixed and the input signal is propagated through the network, layer by layer, until it reaches the output. Consequently, in this phase, changes are confined to the activation potentials and outputs of the neurons in the network. In the second, called the backward phase, an error signal is produced by comparing the output of the network with the expected response. The resulting error signal is propagated through the network, again layer by layer, but this time in the backward direction. In this phase, successive adjustments are made to the synaptic weights of the network. Calculating the adjustments for the output layer is straightforward, but it is much more challenging for the hidden layers [17].

The back-propagation algorithm provides an approximation to the trajectory in weight space computed by the method of stochastic gradient descent [17]. The smaller the value of the learning rate parameter α, the smaller the changes to the synaptic weights in the network from one iteration to the next, and the smoother the trajectory in weight space. This improvement, however, is attained at the cost of a slower rate of learning. The learning rates used were between 0.01 and 0.1, in steps of 0.01 (0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08, 0.09, 0.1) [17].
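The weight update behind this discussion can be sketched as a single gradient step (pure Python; the weight and gradient values below are made up for illustration, not taken from the trained model):

```python
def sgd_update(weights, gradients, learning_rate):
    """One stochastic gradient descent step: w <- w - alpha * dE/dw.
    A smaller learning rate alpha yields smaller weight changes and a
    smoother trajectory in weight space, at the cost of slower learning."""
    return [w - learning_rate * g for w, g in zip(weights, gradients)]

w = [0.5, -0.2, 0.1]       # hypothetical synaptic weights
grads = [0.4, -1.0, 0.0]   # hypothetical error gradients
for alpha in (0.01, 0.1):  # two of the learning rates swept in the study
    print(alpha, [round(v, 3) for v in sgd_update(w, grads, alpha)])
```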

A training set is a set of labeled data (for example, whether a student is active or reflective) providing known information, which is used in supervised learning to build a classification or regression model. The training dataset is used to fit the model (the weights and biases, in the case of an artificial neural network), which then learns from this data.

Model testing is a critical but frequently underestimated part of model building and assessment. After preprocessing, the data are needed to build a model with the potential to accurately predict further observations. If the built model fits the training data perfectly, it is probably not reliable after deployment in the real world. This problem is called overfitting and needs to be avoided. A common way to address the lack of an independent dataset for model evaluation is to reserve part of the learning data for this purpose. The basis for analyzing classifier performance is a confusion matrix (CF), which describes how well a classifier can predict classes.

A typical cross validation technique is k-fold cross validation. This method can be viewed as a repeated holdout method (the holdout method divides the original dataset into two subsets, the training and testing datasets). The whole dataset is divided into k equal subsets and each subset is assigned in turn as the test set, with the others used for training the model. Thus, each observation gets a chance to be in both the test and training sets; therefore, this method reduces the dependence of performance on the test–training split and decreases the variance of the performance metrics. The extreme case of k-fold cross validation occurs when k equals the number of data points: the predictive model is trained on all the data points except one, which takes the role of the test set. This method of leaving one data point out as a test set is known as leave-one-out cross validation (LOOCV) [17]. This technique is shown in Figure 4.

**Figure 4.** Leave-one-out cross validation technique (LOOCV). It is a special case of cross validation where the number of folds equals the number of instances in the dataset [17].

As shown in Figure 4, each iteration leaves one observation out for testing and uses the others for training. Therefore, the number of iterations equals the number of observations in the dataset. The leave-one-out procedure allows the model to be tested on all observations and prevents us from wasting any of them. This method was used to split the original dataset.
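The LOOCV split can be sketched as a simple index generator (pure Python; the helper name is illustrative):

```python
def leave_one_out(n):
    """Yield (train_indices, test_index) pairs: each of the n observations
    serves as the test set exactly once (LOOCV)."""
    for i in range(n):
        yield [j for j in range(n) if j != i], i

splits = list(leave_one_out(100))  # 100 students, as in the study
print(len(splits))                 # 100 iterations, one per observation
print(len(splits[0][0]))           # 99 training observations per iteration
```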

A classifier is evaluated based on performance metrics computed after the training process. In a binary classification problem, a matrix presents the number of instances predicted for each of the four possible outcomes: the number of true positives (#TP), true negatives (#TN), false positives (#FP), and false negatives (#FN). Most classifier performance metrics are derived from these four values [21]. We used the following metrics to evaluate our model (Equations (3)–(14)) [22].

$$\text{Sensitivity} = \frac{\text{TP}}{\text{TP} + \text{FN}} \tag{3}$$

$$\text{Specificity} = \frac{\text{TN}}{\text{TN} + \text{FP}} \tag{4}$$

$$\text{Prevalence} = \frac{\text{TP} + \text{FN}}{\text{TP} + \text{FN} + \text{FP} + \text{TN}} \tag{5}$$

$$\text{PPV} = \frac{\text{sensitivity} \times \text{prevalence}}{(\text{sensitivity} \times \text{prevalence}) + ((1 - \text{specificity}) \times (1 - \text{prevalence}))} \tag{6}$$

$$\text{NPV} = \frac{\text{specificity} \times (1 - \text{prevalence})}{((1 - \text{sensitivity}) \times \text{prevalence}) + (\text{specificity} \times (1 - \text{prevalence}))} \tag{7}$$

$$\text{Detection Rate} = \frac{\text{TP}}{\text{TP} + \text{FN} + \text{FP} + \text{TN}} \tag{8}$$

$$\text{Detection Percentage} = \frac{\text{TP} + \text{FP}}{\text{TP} + \text{FN} + \text{FP} + \text{TN}} \tag{9}$$

$$\text{Balanced Accuracy} = \frac{\text{sensitivity} + \text{specificity}}{2} \tag{10}$$

$$\text{Precision} = \frac{\text{TP}}{\text{TP} + \text{FP}} \tag{11}$$

$$\text{Recall} = \frac{\text{TP}}{\text{TP} + \text{FN}} \tag{12}$$

$$\text{F1} = \frac{(1+\alpha^2) \times \text{precision} \times \text{recall}}{(\alpha^2 \times \text{precision}) + \text{recall}} \tag{13}$$

$$\text{Accuracy} = \frac{\text{TP} + \text{TN}}{\text{TP} + \text{FP} + \text{TN} + \text{FN}} \tag{14}$$

For binary problems, the sensitivity, specificity, positive predictive value, and negative predictive value were calculated with respect to the positive class. The overall method is shown in Figure 5.
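Under the standard definitions (sensitivity = TP/(TP+FN), specificity = TN/(TN+FP), precision = TP/(TP+FP)), a few of these metrics can be computed directly from the confusion-matrix counts. The counts below are hypothetical, not values from Table 3:

```python
def binary_metrics(tp, fn, fp, tn):
    """Performance metrics derived from a binary confusion matrix."""
    total = tp + fn + fp + tn
    sensitivity = tp / (tp + fn)   # also called recall
    specificity = tn / (tn + fp)
    precision = tp / (tp + fp)
    return {
        "sensitivity": sensitivity,
        "specificity": specificity,
        "precision": precision,
        "accuracy": (tp + tn) / total,
        "balanced_accuracy": (sensitivity + specificity) / 2,
        # F1 is the alpha = 1 case of Equation (13)
        "f1": 2 * precision * sensitivity / (precision + sensitivity),
    }

m = binary_metrics(tp=40, fn=10, fp=5, tn=45)  # hypothetical counts
print(round(m["accuracy"], 2))            # 0.85
print(round(m["balanced_accuracy"], 2))   # 0.85
```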

**Figure 5.** The overall method for learning styles classification based on count and time descriptors and targets. The behavior of the student is presented to the Multi Layer Perceptron (MLP) to train and get the weights which explain the labeled dataset.

As shown in Figure 5, the behavior of the 100 students is presented to the Multi Layer Perceptron (MLP) in order to train the neural network. The weight of each synapse (neuron connection) is obtained and the result is compared with the expected outcome (Equations (3)–(14)). When the accuracy (shown in Table 3) is optimal (i.e., there is no improvement during the training step), the neural network training stops. The pseudocode for this method is presented below.


    accuracy = 0
    variation = max
    while (variation > 0.01) {
        present the subset of train and test to MLP
        weights = weights - delta_weights * learning_rate
        compare the predicted outcome to the expected one
        variation = the difference of the predicted and expected
        update accuracy
    }

**Table 3.** Metrics of each dimension.


<sup>1</sup> NIR and ACC represent, respectively, No Information Rate (NIR) and Accuracy (ACC).

As shown in Pseudocode 1, all the subsets (train and test) are presented to the MLP for training and testing. Accuracy is the metric from Table 3 that is used to build the model while avoiding overfitting and underfitting [17].

### **4. Discussion**

In this section, the results from the experiments are presented and discussed. We initially investigated three types of variables: (a) count descriptors, (b) time descriptors, and (c) targets, the last of which is of type count. No outliers were found in the dataset. The median of the count descriptors was around four accesses per element (content\_visit, outline\_visit, etc.). The time descriptors define the time spent on each element (content\_stay, outline\_stay, etc.); a value of zero (0) means that the element was not accessed. The median time spent on the time descriptors was around 60 s, with access limited to 120 s because of the session time limit. Finally, the median of the target variables was around 0, which expresses the balanced learning styles dataset in each dimension (processing, perception, input, and understanding); for example, the students are about evenly distributed between the active and reflective classes. These values are shown in Figure 6.

**Figure 6.** Box plot analysis of dataset (count and time descriptors and targets). There is a limitation to the time descriptor by 120 s and count by 10 times.

We also explored the frequency of each preference dimension before the targets were transformed into binary variables (Equation (2)). As a result, we obtained the students' learning styles for each dimension. The active\_reflective, sensing\_intuitive, and sequential\_global dimensions presented an approximately uniform distribution; however, the visual\_verbal dimension presented a concentration close to −5, which represents a preference by the students for acquiring visual information. Figure 7 shows this analysis.

We also investigated the possibility of using the dataset to identify students' preferences. If, for example, a given set of attributes represents one of the four learning dimensions, these attributes may help in dimensionality reduction and improve classifier precision [7,16]. The groups were investigated using the k-means algorithm to identify natural clusters in the dataset, with k = {2, 3, 4, 5}. We obtained clusters with low overlap for two and three groups; with more than three groups, the resulting clusters overlapped. These results are shown in Figure 8.

**Figure 7.** Learning style frequency distribution of each dimension.

**Figure 8.** Natural clusters in dataset using k-means (with k = {2, 3, 4, 5}). No natural cluster was identified and, therefore, the classes are not linearly separated.

Additionally, another analytics technique, principal component analysis (PCA), was applied to investigate other relevant attributes or correlations and whether the targets of each dimension might be explained by some descriptor (count or time). This is an important issue for dimensionality reduction, which can improve accuracy and reduce the cost of model building [20]. These results are shown in Figure 9.

The dataset is balanced for the dimensions perceive, processing, and store, and presents some variation in the dimension input. In addition, there is not a predominant descriptor, making it possible to use all descriptors for the model construction.

*Appl. Sci.* **2020**, *10*, 1756

**Figure 9.** Principal component analysis (PCA) of each dimension, processing, perception, input, and understanding, respectively.

We may identify the onset of overfitting through the use of a hold-out validation procedure (a variant of cross validation), for which the training data are split into an estimation subset and a validation subset. The estimation subset is used to train the network in the usual way, except for a minor modification: the training session is stopped periodically (i.e., every so many epochs), and the network is tested on the validation subset after each period of training.
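The stop-when-validation-worsens logic described above can be sketched as follows; the loss values and the `patience` parameter are illustrative assumptions.

```python
def train_with_early_stopping(val_losses, patience=3):
    """Return the epoch with the best validation loss.

    `val_losses` simulates the validation loss measured after each
    training period; training stops once the loss has failed to improve
    for `patience` consecutive periods, signaling the onset of overfitting.
    """
    best, best_epoch, waited = float("inf"), 0, 0
    for epoch, loss in enumerate(val_losses):
        if loss < best:
            best, best_epoch, waited = loss, epoch, 0
        else:
            waited += 1
            if waited >= patience:
                break
    return best_epoch

# Validation loss falls, then rises as the network starts to overfit.
losses = [0.9, 0.7, 0.5, 0.45, 0.46, 0.48, 0.50, 0.55]
stop_at = train_with_early_stopping(losses)
```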

In our procedure, we varied the number of hidden layers for each model to determine a suitable number and obtain the optimal result. The best model built has two hidden layers, 26 input neurons, and four output neurons. The resulting model is shown in Figure 10.

This model evaluates all descriptors simultaneously, providing the students' classification in each learning dimension; this is a multi-target prediction algorithm [19]. Equations (4)–(14) were used to evaluate the model. We used the confusion matrix (CF) [17] to assess the predictions. The results for each dimension are shown in Table 3.
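The confusion-matrix metrics reported in Table 3 (accuracy, sensitivity, specificity) can be computed as in this minimal sketch; the example labels are invented.

```python
def confusion_metrics(y_true, y_pred, positive="active"):
    """Accuracy, sensitivity and specificity from a 2x2 confusion matrix."""
    tp = sum(t == positive and p == positive for t, p in zip(y_true, y_pred))
    tn = sum(t != positive and p != positive for t, p in zip(y_true, y_pred))
    fp = sum(t != positive and p == positive for t, p in zip(y_true, y_pred))
    fn = sum(t == positive and p != positive for t, p in zip(y_true, y_pred))
    accuracy = (tp + tn) / len(y_true)
    sensitivity = tp / (tp + fn) if tp + fn else 0.0   # true positive rate
    specificity = tn / (tn + fp) if tn + fp else 0.0   # true negative rate
    return accuracy, sensitivity, specificity

# Invented predictions on the active/reflective dimension.
y_true = ["active", "active", "reflective", "reflective", "active"]
y_pred = ["active", "reflective", "reflective", "reflective", "active"]
acc, sens, spec = confusion_metrics(y_true, y_pred)
```

The same function, applied per dimension, yields the per-target metrics against which other algorithms can be compared.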

The best model achieved 85%, 76%, 75%, and 80% accuracy for the target attributes active\_reflective, sensing\_intuitive, visual\_verbal, and sequential\_global, respectively. These results are better than those presented by Bernard et al. [5] (80%, 74%, 72%, and 82% in the same order) except for sequential\_global, and our approach deals with all descriptors and targets simultaneously instead of one at a time: we generated one model, while Bernard et al. [5] generated four models to solve the problem. Moreover, we provide several performance metrics for each dimension to support further research aiming to compare and improve these results (Table 3). In addition, we investigated the CF using the area under the curve (AUC) of the receiver operating characteristic (ROC). For each target, the results were superior to those of Bernard et al. [5]. Figure 11 shows these results.

**Figure 10.** The built model with two hidden layers, 26 input neurons, and four output neurons.

All metrics indicated that the model might serve as a method to automatically classify a student in a MOOC environment. The relation between descriptors improves the accuracy (as shown by the specificity and sensitivity in Table 3). Moreover, multi-target prediction (MTP) is a class of algorithms for the simultaneous prediction of multiple target variables of diverse types, and the model uses the Felder–Silverman procedure, by far the most popular theory applied in adaptive e-learning systems [5]. Meanwhile, from another point of view, the accuracy (and other performance metrics) of the proposed approach could be further improved by the use of a larger dataset. Another limitation of the current research is that the experiments were only run on a platform with a specific subject (computer science and project management). The consistency of performance needs to be tested with different learning management systems and/or other online courses (for example, administration, economics, and so forth). Our future work will involve further exploration of the performance metrics and practical implications in different environments.

### **5. Conclusions and Further Development**

This paper presents an automatic approach to identify learning styles from student behavior in a MOOC using a computational intelligence algorithm (CIA), namely a deep artificial neural network (ANN). An assessment with data from 100 students was performed, demonstrating the overall accuracy of the approach for each of the four Felder–Silverman learning styles model (FSLSM) dimensions. The results obtained show that this approach may be used to identify students' learning styles based on their behavior in a MOOC. This approach reduces the noise of questionnaires [3–5], allows re-classification when it is necessary to check whether a style has changed over time [1], and allows data to be stored for future use.

Thus, by identifying students' learning styles, adaptive learning systems can use this information to provide more accurate personalization, leading to improved satisfaction and reduced learning time [3]. In addition, students can directly benefit from a more accurate identification of their learning styles, being able to leverage their strengths and understand their weaknesses. Furthermore, teachers can use this learning style information to provide students with more accurate advice, which, as pointed out before, becomes more useful as learning style identification becomes more accurate. Moreover, students with similar learning styles may work together in the same classroom to improve their learning experience and help teachers with their methods. Finally, other stakeholders in the education ecosystem, such as parents, teachers, and administrators, can make use of such an approach to improve education in general [2,21].

Suggestions for further work include the practical application of this approach through MOOC plug-ins. Different algorithms can be tested by comparing their results with those of the artificial neural network, since this work presents reference values based on the confusion matrix that can be replicated for other algorithms. Social issues can also be investigated to identify whether they influence learning styles. Concept drift (CD) should be investigated to identify whether the target variables change over time, and the results compared with learning process questionnaires (LPQs). Finally, we may investigate how learning styles affect the propagation of information in networks, based on complex networks [17]. This is an important topic with great impact on real-world applications, because it underlies recommendation systems and may be used to improve students' learning processes.

**Author Contributions:** This paper was developed with the active contribution of the doctorate candidate E.G., the orientation of L.d.S.M. and R.M.d.B. L.d.S.M. and R.M.d.B. are both advisors of the candidate E.G. and in addition to heading the research group, are mainly responsible for the propositions and hypotheses being proposed and tested. L.d.S.M. research areas are in Intelligent Cities and applied problems in Computational Intelligence. R.M.d.B. main research fields are in Governance, Strategic Planning, and applied Computational Intelligence. E.G. main research field is applied Computational Intelligence for Smart Cities. His main contribution to the paper was in processing data with computational intelligence techniques and mathematical models. All participants gave important contributions to the formulation of the hypotheses and the definition of the data analytics for the educational dataset. Based on The CRediT Roles, the contributions are E.G.: conceptualization, data curation, formal analysis, investigation, methodology, software, and writing—draft; R.M.d.B. and L.d.S.M.: project administration, resources, supervision, validation, writing—review and editing. All authors have read and agreed to the published version of the manuscript.

**Funding:** This research received no external funding.

**Conflicts of Interest:** The authors declare no conflicts of interest.

#### **References**


© 2020 by the authors. Licensee MDPI, Basel, Switzerland. This article is an open access article distributed under the terms and conditions of the Creative Commons Attribution (CC BY) license (http://creativecommons.org/licenses/by/4.0/).

## *Article* **Transfer Learning from Deep Neural Networks for Predicting Student Performance**

### **Maria Tsiakmaki, Georgios Kostopoulos, Sotiris Kotsiantis \* and Omiros Ragos**

Department of Mathematics, University of Patras, 26504 Rio Patras, Greece; m.tsiakmaki@gmail.com (M.T.); kostg@sch.gr (G.K.); ragos@math.upatras.gr (O.R.)

**\*** Correspondence: sotos@math.upatras.gr

Received: 14 February 2020; Accepted: 19 March 2020; Published: 21 March 2020

**Abstract:** Transferring knowledge from one domain to another has gained a lot of attention among scientists in recent years. Transfer learning is a machine learning approach aiming to exploit the knowledge retrieved from one problem for improving the predictive performance of a learning model for a different but related problem. This is particularly the case when there is a lack of data regarding a problem, but there is plenty of data about another related one. To this end, the present study intends to investigate the effectiveness of transfer learning from deep neural networks for the task of students' performance prediction in higher education. Since building predictive models in the Educational Data Mining field through transfer learning methods has been poorly studied so far, we consider this study as an important step in this direction. Therefore, a plethora of experiments were conducted based on data originating from five compulsory courses of two undergraduate programs. The experimental results demonstrate that the prognosis of students at risk of failure can be achieved with satisfactory accuracy in most cases, provided that datasets of students who have attended other related courses are available.

**Keywords:** transfer learning; deep learning; educational data mining; student performance prediction

#### **1. Introduction**

Transferring knowledge from one domain to another has gained a lot of attention among scientists in the past few years. Consider the task of predicting student performance (pass/fail) in higher education courses. According to the traditional supervised learning approach, a sufficient amount of training data, regarding a specific course *CA*, is required for building an accurate predictive model which is subsequently used for making predictions on testing data derived from the same course. If the testing dataset is derived from a different course, *CB*, sharing some common characteristics with course *CA* (hereinafter referred to as related or similar courses), then transfer learning is the appropriate machine learning methodology for building accurate learning models in a more efficient manner, since it could contribute to the improvement of the predictive performance of the target domain model (course *CB*) exploiting the knowledge of the source domain (course *CA*) [1]. In a nutshell, a learning model is trained for a specific task using data derived from a source domain and, subsequently, it is reused for another similar task in the same domain or the same task in a different domain (target domain) [2,3]. More generally, when we lack information about a problem, we could train a learning model for a related problem, for which there is plenty of information, and apply it to the existing one.

Transfer learning is currently gaining popularity in deep learning [4]. Not long ago, it was claimed as the second "driver of machine learning commercial success", whereas supervised learning was the first one [5]. Pre-trained deep networks, usually trained on large datasets and thus requiring significant computation time and resources, are employed as the starting point for other machine learning problems due to their ability to be repurposed either for a new or for a similar task. Therefore, these networks could support complex problems in a more efficient way, since they can decrease the training time for building a new learning model and finally improve its generalization performance [6].

In recent years, several types of Learning Management Systems (LMSs) have been successfully adopted by universities and higher education institutions, recording a variety of student learning features and gathering huge amounts of educational data. Educational Data Mining (EDM) is a fast-growing scientific field offering the potential to analyze these data and harness valuable knowledge from them. To this end, a plethora of predictive algorithms have been effectively applied in educational contexts for solving a wide range of problems [7]. However, building predictive models in the EDM field through transfer learning methods has been poorly studied so far. Therefore, the main question in the present study is whether a predictive model trained on a past course would perform well on a new one. Boyer and Veeramachaneni observe that courses (a) might evolve over time in a dissimilar way, even if they are not much different in terms of context and structure, (b) are populated with different students and instructors, and (c) might have features that cannot be transferred (e.g., a feature defined on a specific learning resource which is not available on another course) [8]. In addition, the complexity of LMSs as well as the course design have a significant impact on the course progress during the semester [9]. Therefore, there may be problems where transfer learning does not deliver the anticipated results, introducing some uncertainty about the predictive accuracy of the newly created learning model [10].

In this context, the present study aims to propose a transfer learning methodology for predicting student performance in higher education, a task that has been extensively studied in the field of EDM through traditional supervised methods. To this purpose, we exploit a set of five datasets corresponding to five undergraduate courses, each one lasting one semester, all supported by a Moodle platform. Initially, we form all the unique pairs of datasets (twenty pairs in total) matching the features of the paired courses one by one and generating new features if necessary. Next, a deep network model is trained by using the dataset of the first course and, subsequently, it is applied on the dataset of the second course for further training after a predefined number of epochs. Deep networks have been successfully applied in the EDM field for solving important educational problems, such as predicting student performance [11–14], dropout [15–17], or automatic feature extraction [18]. The main objective is to discover whether transfer learning accelerates training and improves the predictive performance utilizing the potential of deep neural networks in the EDM field. On this basis, we hope to provide a useful contribution for researchers.

The remainder of this paper is organized as follows. In the next section, we discuss the transfer learning approach, while in Section 3 we present an overview of some related studies in the EDM field. The research goal, together with an analysis of the datasets and description of the proposed transfer learning method, is set in Section 4. The experimental results are presented in Section 5, while Section 6 discusses the research findings. Finally, Section 7 summarizes the study, considering some thoughts for future work.

#### **2. The Transfer Learning Approach**

The traditional supervised learning methods exploit labeled data to obtain predictive models in the most efficient way. Let us consider the task of predicting whether a student is going to successfully pass or fail the examinations of an undergraduate course *CA*. In this case, the training and the testing set are both derived from the same domain (course). The training set is used to build a learning model *h* by means of a classification algorithm (e.g., a deep network), which is subsequently applied on the testing set to evaluate its predictive performance (Figure 1a). Some key requirements for achieving high-performance models are the quality and sufficiency of the training data, which are, unfortunately, not always easy to meet in real-world problems. In addition, the direct implementation of model *h* for a different course *CB* or a new task (e.g., predicting whether a student is going to drop out of the course) seems rather difficult. The existing model does not have the ability to generalize well to data coming from a different distribution, while, at the same time, it is not applicable, since the class labels of the two tasks are different.

**Figure 1.** The traditional machine learning process (**a**), the transfer learning process (**b**).

Contrasting these methods, knowledge transfer or transfer learning intends to improve the performance of learning and provide efficient models in cases where data sources are limited or difficult and expensive to acquire [1,2], primarily due to their generalization ability on heterogeneous data (i.e., data from different domains, tasks and distributions [19]). Transfer learning might help us to train a predictive model *h* based on data derived from course *CA* (the source course) and apply it on data derived from a different but related course *CB* (the target course), whose data are not sufficient to train a model, for predicting the performance of a student. This, indeed, is the aim of transfer learning: to transfer the knowledge acquired from course *CA* to course *CB* and improve the predictive performance of model *h* (Figure 1b) instead of developing a totally new model, on the basis that both datasets share some common attributes (i.e., common characteristics of students, such as their academic achievements or interactions within an LMS).

More formally, the transfer learning problem is defined as follows [1,20]:

A domain D is formed by a feature space X and a marginal probability distribution *P*(*X*), where *X* = {*x*1, *x*2, ... , *xn*} ∈ X. A learning task T is formed by a label space Y and an objective predictive function *f*(·). The function *f* can also be written as *P*(*y*|*x*), representing the conditional probability distribution of label *y* given a new instance *x*; *P*(*y*|*x*) is learned from the training data {X, Y}. Given a source domain D*S* = {X*S*, *PS*(*X*)}, its corresponding learning task T*S* = {Y*S*, *fS*(·)}, a target domain D*T* = {X*T*, *PT*(*X*)} and its corresponding learning task T*T* = {Y*T*, *fT*(·)}, the purpose of transfer learning is to obtain an improved target predictive function *fT*(·) by using the knowledge in D*S* and T*S*, where D*S* ≠ D*T* or T*S* ≠ T*T*. The condition D*S* ≠ D*T* means that either X*S* ≠ X*T* or *P*(*XS*) ≠ *P*(*XT*). Similarly, T*S* ≠ T*T* means that either Y*S* ≠ Y*T* or *fS*(·) ≠ *fT*(·).

The inequalities contained in the definition give rise to four different transfer learning settings:

- different feature spaces, X*S* ≠ X*T*;
- different marginal distributions, *P*(*XS*) ≠ *P*(*XT*);
- different label spaces, Y*S* ≠ Y*T*;
- different objective predictive functions, *fS*(·) ≠ *fT*(·).
Based on the above definition and conditions, three types of transfer learning settings are identified [1,8,18,21]: inductive transfer learning, transductive transfer learning and unsupervised transfer learning. In inductive transfer learning, the target task is different but related to the source task (T*S* ≠ T*T*), regardless of the relationship between the domains. In transductive transfer learning, the source and target tasks are the same (T*S* = T*T*), while the domains are different (D*S* ≠ D*T*). Finally, in unsupervised transfer learning, the tasks are different (T*S* ≠ T*T*) and neither dataset contains labels. The latter type is intended for clustering and dimensionality reduction tasks.

### **3. Related Work**

Predicting students' learning outcomes is considered one of the major tasks of the EDM field [22]. This is demonstrated by a great number of significant studies which put emphasis on the development and implementation of data mining methods and machine learning algorithms for resolving a plethora of predictive problems [23]. These problems are mainly intended to predict the future value of an attribute (e.g., students' grades, academic performance, dropout, etc.) based on a set of input attributes that describe a student. One typical problem is to detect whether a student is going to successfully pass or fail a course by the end of a semester based on his/her activity on the LMS, as in this study. The successful and accurate detection of students at risk of failure is of vital importance for educational institutions, since remedial measures and intervention strategies can be applied to support low performers and enhance their overall learning performance [24]. It is therefore necessary to build very accurate and robust learning models. Transfer learning could contribute to improving these models, since prior knowledge regarding a specific task could be useful to another similar task. As evidenced by the current literature, transfer learning has still not been sufficiently examined in the field of EDM. To the best of our knowledge, there are few studies focusing on resolving prediction problems by transferring learning models from one domain to another, although this prospect is appealing. These studies indicate that building models based on a particular course and then applying them to a new one (different but somehow related) is a rather complex task, which, unfortunately, does not always deliver the anticipated outcomes [10]. A list of some notable works regarding transfer learning in the EDM field is presented in the following paragraphs.

Ding et al. investigated the transferability of dropout prediction across Massive Open Online Courses (MOOCs) [9]. To this end, they presented two variations of transfer learning based on autoencoders: (a) using transductive principal component analysis, and (b) adding a correlation alignment loss term. The input data were click-stream log events from mixtures of similar and dissimilar courses. The proposed transfer learning methods proved to be quite effective for improving dropout prediction, in terms of Area Under the Curve (AUC) scores, compared to the baseline method. In a similar study, Vitiello et al. [25] examined how models trained on one MOOC system could be transferred to another. To this end, they built a unified model allowing the early prediction of dropout students across two different systems. At first, the authors confirmed significant differences between the two systems, such as the number of active students and the structure of courses. After that, they defined a set of features based on the event logs of the two systems. Overall, three dropout prediction experiments were conducted: one for each separate system, one where each system applied a learning model built on the other system and one where the dataset contained data from both systems. The accuracy measure was above the baseline threshold (0.5) in most cases.

The method put forward by Hunt et al. [26] examined the effectiveness of TrAdaBoost, an extension of AdaBoost to the transfer learning framework, for predicting students' graduation rates in undergraduate programs. The dataset was based on a set of academic and demographic features (152 features in total) regarding 7637 students of different departments. Two separate experiments were conducted, each time using specific data for the training set. In the first experiment, the training set comprised all students apart from those studying engineering, while in the second one, the training set comprised all students that were suspended on academic warnings. The experimental results showed that the TrAdaBoost method recorded the smallest error in both cases. In the same context, Boyer and Veeramachaneni suggested two different approaches for predicting student dropout, taking into account the selection method of the training data and how to make use of past course information [8]. Several tests were performed using either all available information for a learner or a fixed subset of it. In addition, two different scenarios were formulated: inductive and transductive transfer learning. The experimental results indicated that the produced learning models did not always perform as intended. Very recently, Tri, Chau and Phung [27] proposed a transfer learning algorithm, named CombinedTL, for the identification of failure-prone students. They combined a case-based reasoning framework and four instance-based transfer learning algorithms (MultiSource, TrAdaBoost, TrAdaBoost, and TransferBoost). The experimental results showed that the proposed method outperformed the single instance-based transfer learning algorithms. 
In addition, the authors compared the CombinedTL with typical case retrieval methods (k-NN and C4.5), experimenting with a varying number of target instances, finding that the performance of the proposed method was improved as the number of target instances was increased.

The notion of domain adaptation is highly associated with transfer learning. Zeng et al. [28] proposed a self-training algorithm (DiAd) which adjusts a classifier trained on the source domain to the target domain, based on the most confident examples of the target domain and the most dissimilar examples of the source domain; the classifier is adjusted to the new domain without using any labeled examples. Very recently, López-Zambrano et al. [29] investigated the portability of learning models based on Moodle log data regarding courses of different universities. The authors explored whether the grouping of similar courses (i.e., by similarity level of learning activities) influences the portability of the prediction models. The experimental results showed that models based on discretized datasets obtained better portability than those based on numerical ones.

### **4. Research Methodology**

### *4.1. Research Goal*

The main purpose of our study is to evaluate the effectiveness of transfer learning methods in the EDM field. More specifically, we investigate whether a deep learning model that has been trained using student data from one course can be repurposed for other related courses. Deep neural networks are represented by a number of connecting weights between the layers. During the training process, these weights are adjusted in order to minimize the error of the expected output. Therefore, the main notion behind the suggested transfer learning approach is to initialize a deep network using the pre-tuned weights from a similar course. Two main research questions guide our research:

(1) Can the weights of a deep learning model trained on a specific course be used as the starting point for a model of another related course?

(2) Will the pre-trained model reduce the training effort for the deep model of the second course?

### *4.2. Data Analysis*

In the present study, we selected data regarding five compulsory courses of two undergraduate programs offered by the Aristotle University of Thessaloniki in Greece. More precisely, three courses (Physical Chemistry I (Spring 2018) and Analytical Chemistry Laboratory (Spring 2018, Spring 2019)) were offered by the department of Chemical Engineering, while two courses (Physics III (Spring 2018, Spring 2019)) were offered by the department of Physics. Table 1 provides detailed information regarding the gender and target class distribution of the five courses.


**Table 1.** Gender and target class distribution of the courses.

Each course was supported by an online LMS, embedding a plethora of resources and activities. The course pages were organized into topic sections containing the learning material in the form of web pages, document files and/or URLs, while the default announcements forum was enabled for each course, allowing students to post threads and communicate with colleagues and tutors. Each course required the submission of several assignments, which were evaluated on a grading scale from zero to 10. All sections were available to the students until the end of the semester, while the course final grade corresponded to the weighted average of the marks of all submitted assignments and the final exam. Note that successful completion of the course required a minimum grade of five.

For the purpose of our study, the collected datasets comprised six different types of learning resources: forums, pages, resources, folders, URLs and assignments (Table 2). For example, course *C*<sup>1</sup> was associated with one forum, seven pages, 17 resources, two folders and eight assignments, three of which were compulsory. Regarding the forum module, we recorded the total number of views for each student. We also recorded the total number of times students accessed a page, a resource, a folder or a URL. Moreover, two counters were embedded in the course LMS, aggregating the number of student views (course total views) as well as the number of every type of recorded activity for a student (course total activity). Learning activities that were not accessed by any student were not included in the experiments, while a student who did not access a learning activity was marked with a zero score. Finally, a custom Moodle plugin was developed, enabling the creation of the five datasets [30].


**Table 2.** Features extracted from students' low-level interaction logs.

It is worth noting that there were certain differences among the five courses (Tables 1 and 2). First, they were offered by different departments (Physics and Chemical Engineering) and had different formats and content. Although courses *C*2, *C*<sup>4</sup> and *C*3, *C*<sup>5</sup> covered the same topics (Physics and Chemistry, respectively), their content varied depending on the academic year of study. In addition, courses *C*1, *C*2 and *C*<sup>4</sup> were theoretical (Physical Chemistry and Physics), while *C*3 and *C*<sup>5</sup> were laboratory courses (Analytical Chemistry Lab). Moreover, each course required the submission of a different number of assignments. Finally, it should be noted that different students attended these courses.

### *4.3. The Proposed Transfer Learning Approach*

The present study intends to address the problem of transferring knowledge across different undergraduate courses. Hence, we employed a simple deep neural network architecture comprising four layers: an input layer, two hidden dense layers and an output one. The input layer consists of input units corresponding to each of the dataset input features (Table 3). The first hidden layer has 12 hidden units and the second one has eight; both dense layers use the ReLU activation function. Finally, the output layer consists of a single neuron employing the sigmoid activation function and the binary cross-entropy loss function for predicting the output class (pass or fail).
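A sketch of the described architecture follows, written as a plain NumPy forward pass rather than the authors' Keras code; the number of input features (`n_features = 30`) is an assumption, since it varies with the course pair.

```python
import numpy as np

def relu(z):
    return np.maximum(0.0, z)

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def init_layer(n_in, n_out, rng):
    """Small random weights and zero biases for one dense layer."""
    return rng.normal(0.0, 0.1, (n_in, n_out)), np.zeros(n_out)

def forward(x, params):
    """Input layer -> 12 ReLU units -> 8 ReLU units -> 1 sigmoid unit."""
    (W1, b1), (W2, b2), (W3, b3) = params
    h1 = relu(x @ W1 + b1)
    h2 = relu(h1 @ W2 + b2)
    return sigmoid(h2 @ W3 + b3)  # estimated probability of the "pass" class

rng = np.random.default_rng(0)
n_features = 30  # illustrative feature count; it differs per course pair
params = [init_layer(n_features, 12, rng),
          init_layer(12, 8, rng),
          init_layer(8, 1, rng)]
probs = forward(rng.normal(size=(5, n_features)), params)
```

In the paper's setup, these weights would be trained with the binary cross-entropy loss; the sketch only shows the layer shapes and activations.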

The experimental procedure was divided into three distinct phases (Figure 2). In the first phase, we constructed all the unique pairs of courses that could be formed (ten pairs of courses in total). Each time, the related datasets were rebuilt to share a common set of features. For each pair of courses, we made use of the following notation:

$$\left\{C_i, C_j\right\}, \quad i, j \in \{1, 2, 3, 4, 5\}, \; i \neq j. \tag{1}$$


**Table 3.** Features of the paired datasets.

**Figure 2.** The three-step process of the proposed method.

In order to create a common set of features for each pair of courses, we matched the features of the first course to the related features of the second course one by one. Among the common features were gender as well as the course total activity and course total views counters. The first assignment of the first course was matched with the first assignment of the second course, the second assignment of the first course with the second assignment of the second course, and so forth; the same procedure was followed for all six types of resources. In cases where a matching feature was not found, a new feature was created, with zero values for every instance. For example, the *C*<sup>1</sup> course contained features related to seven page resources, whereas the *C*<sup>2</sup> course contained features related to six page resources (Table 2). Consequently, the new {*C*1, *C*2} pair of datasets contained seven features regarding the page resources, since a new empty feature was created and added to the *C*<sup>2</sup> course dataset, matching the seventh feature of the *C*<sup>1</sup> course (Table 3).
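The zero-padding feature-matching step can be sketched as follows; the `align_columns` helper and the toy column counts are illustrative, not the authors' plugin code.

```python
def align_columns(cols_a, cols_b, n_students_a, n_students_b):
    """Match the i-th feature of course A with the i-th feature of course B.

    The shorter side is padded with all-zero columns so both courses end
    up with the same number of features of a given resource type.
    """
    width = max(len(cols_a), len(cols_b))
    pad_a = [[0.0] * n_students_a for _ in range(width - len(cols_a))]
    pad_b = [[0.0] * n_students_b for _ in range(width - len(cols_b))]
    return cols_a + pad_a, cols_b + pad_b

# C1 has seven page features, C2 only six: C2 gains one all-zero feature.
pages_c1 = [[1.0] * 3 for _ in range(7)]   # 7 columns, 3 students
pages_c2 = [[2.0] * 4 for _ in range(6)]   # 6 columns, 4 students
a, b = align_columns(pages_c1, pages_c2, n_students_a=3, n_students_b=4)
```

Repeating this per resource type (forums, pages, resources, folders, URLs, assignments) yields the common feature set of each course pair.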

The second phase refers to the training process of the two supporting deep networks. The first one was trained on the new source course *Ci* in order to extract its adjusted weights, while the second one was trained on the new target course *Cj* in order to calculate the baseline evaluation. In both cases, we calculated the accuracy metric, which corresponds to the percentage of correctly classified instances, and the models were trained for 150 epochs. In addition, the 10-fold cross-validation resampling procedure was adopted for evaluating the overall performance of the deep network models.
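The 10-fold evaluation loop can be sketched as follows (a minimal NumPy version; the shuffling seed and the `train_and_score` callback are illustrative assumptions, not the paper's exact implementation):

```python
import numpy as np

def kfold_accuracy(X, y, train_and_score, k=10, seed=0):
    """Average accuracy over k folds. `train_and_score` trains a model
    on the training split and returns accuracy on the held-out fold."""
    rng = np.random.default_rng(seed)
    idx = rng.permutation(len(X))
    folds = np.array_split(idx, k)
    scores = []
    for i in range(k):
        test = folds[i]
        train = np.concatenate([folds[j] for j in range(k) if j != i])
        scores.append(train_and_score(X[train], y[train], X[test], y[test]))
    return float(np.mean(scores))
```

In the paper's setting, `train_and_score` would fit the deep network for 150 epochs on the training split and return its accuracy on the held-out fold.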

The third phase was the most fundamental, since it implemented the transfer learning strategy. The deep model of the target course was fitted from scratch, but this time the network weights were initialized using the previously calculated weights from the source course (second phase). The pre-trained model (hereinafter denoted as *Ci*,*j*) was further tuned by running it each time for a certain number of epochs: zero (i.e., the starting point), 10, 20, 30, 40, 50, 100 and 150. Algorithm 1 provides the pseudocode of the proposed transfer learning method. All the experiments were conducted using the Keras library in Python [31].
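In Keras terms, this weight transfer amounts to copying the source model's weights into an identically structured target model before fine-tuning; a hedged sketch with the actual training calls elided:

```python
from tensorflow import keras
from tensorflow.keras import layers

def make_net(n_features):
    """Same four-layer architecture for the source and target courses."""
    return keras.Sequential([
        keras.Input(shape=(n_features,)),
        layers.Dense(12, activation="relu"),
        layers.Dense(8, activation="relu"),
        layers.Dense(1, activation="sigmoid"),
    ])

source = make_net(20)   # phase 2: would be fitted on the source course
target = make_net(20)   # same shape, thanks to the aligned feature sets
target.set_weights(source.get_weights())   # initialize from the source
target.compile(optimizer="adam", loss="binary_crossentropy",
               metrics=["accuracy"])
# phase 3: fine-tune, e.g. target.fit(X_target, y_target, epochs=10)
```

The `set_weights`/`get_weights` pair only works because the two datasets were rebuilt to share a common feature set, which keeps the layer shapes identical.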


### **5. Results**

The averaged accuracy results (over the 10 folds) are presented in Table 4. For each pair, we conducted two experiments, using each course alternately as the source course and the other as the target course. Therefore, we evaluated 20 distinct combinations formed by the five courses. For each pair, we highlighted in bold the cases where the transfer model produced better results than the baseline. Overall, it is observed that the model *Ci*,*<sup>j</sup>* benefits from the knowledge of the source course *Ci*, since the predictive performance of the transfer learning deep network is better than that of the baseline *Cj*.


**Table 4.** Averaged accuracy results.

A one-tailed, paired t-test (*α* = 0.05) was conducted to verify whether the improvement of the transfer model was statistically significant. Therefore, we compared the accuracy results obtained by the baseline deep network (using the target course dataset) with the results obtained by the transfer method, iteratively, for each number of epochs. Since the *p*-value is less than or equal to 0.05, we conclude that the difference is significant in all cases except the starting point, where the number of epochs equals zero (Table 5). Moreover, the *p*-value gradually decreases as the number of epochs increases from 10 to 100.
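Such a test can be sketched with SciPy as follows (the per-fold accuracies below are made-up illustrative numbers, not the paper's results):

```python
import numpy as np
from scipy import stats

# Hypothetical per-fold accuracies for the baseline and transfer models
# (illustrative values only).
baseline = np.array([0.70, 0.72, 0.68, 0.71, 0.69,
                     0.73, 0.70, 0.71, 0.72, 0.69])
transfer = np.array([0.74, 0.75, 0.72, 0.76, 0.73,
                     0.77, 0.74, 0.75, 0.76, 0.73])

# Paired t-test with the one-tailed alternative "transfer > baseline".
t_stat, p_value = stats.ttest_rel(transfer, baseline,
                                  alternative="greater")
significant = p_value <= 0.05
```

The pairing matters: each fold's two accuracies come from the same data split, so `ttest_rel` (rather than an independent-samples test) is the appropriate choice.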



The analysis of the experimental results, in question-and-answer format, underlines the efficiency of the proposed method for transferring knowledge from one course to a related one.

1. Can the weights of a deep learning model trained on a specific course be used as the starting point for a model of another related course?

At the starting point of each transfer learning model (i.e., zero epochs), we used the weights estimated by the previously trained deep network models (for 150 epochs) instead of starting with randomly initialized weights. For example, at the starting point of the *C*4,2 transfer model, we used the weights estimated by the *C*<sup>4</sup> model.

Comparing the results of the pre-trained weights without further tuning (i.e., zero epochs) to the baseline model, an improvement is noticed in half of the datasets (10 out of 20). The statistical results (t-test) confirm that the difference is not significant when the pre-trained model is not further tuned for the second dataset (target course *Cj*), since *p* = 0.2449 > *α* = 0.05. However, the transfer model prevails in 16 out of 20 datasets when it is further tuned for only 10 epochs.

2. Will the pre-trained model reduce the training effort for the deep model of the second course?

Overall, the increase in the number of epochs improves the performance of the proposed transfer learning model. Moreover, the improvement is significant for every number of epochs, apart from the starting point, as statistically confirmed by the t-test results. It is worth noting that the transfer model prevails in 18 out of 20 datasets after 100 epochs, where the lowest p-value is 0.0002.

In addition, we can detect three cases of overfitting, since the accuracy ceases to improve after a certain number of iterations and begins to decrease. In particular, this is observed in the cases where *C*<sup>1</sup> starts with *C*<sup>2</sup> weights, *C*<sup>2</sup> with *C*<sup>1</sup> weights and *C*<sup>4</sup> with *C*<sup>1</sup> weights. For instance, *C*<sup>1</sup> outperforms the baseline with an accuracy of 0.7768 after 100 epochs of retuning the preloaded weights of *C*2. However, after 150 epochs the accuracy decreases to 0.7552.

### **6. Discussion**

An important finding to emerge from this study is that even a small amount of prior knowledge from a past course dataset can yield a fair measure of accuracy when predicting student performance in a related current course. This was verified by a plethora of experiments carried out on twenty different pairs formed from five distinct one-semester courses, investigating the effectiveness of transfer learning in deep neural networks for the task of predicting at-risk students in higher education. In most cases, the transfer model obtained better accuracy than the baseline one. An improvement was noticed in half of the datasets (10 out of 20) using only the pre-trained weights from the source course (i.e., zero epochs). There was also a considerable accuracy improvement in most cases (16 out of 20) when the pre-trained model was further tuned for 10 to 40 epochs. Therefore, fine-tuning provides a substantial benefit over training with random initialization of the weights, leading to higher accuracy with fewer passes over the data. Overall, there was only one case where transfer learning did not achieve better results (*C*5,4). Hence, it is evident that it is not always feasible to transfer knowledge from one course to another. In addition, it is worth noting that the type of course, laboratory or theoretical, does not seem to directly affect the predictive accuracy of the transfer learning model. This indicates that there is some uncertainty about the transferability level of a predictive model. The ambiguity lies in the definition of what a "transferable" model is. A model trained on a set of courses is considered "transferable" if it achieves comparably fair results on a new, related course [10].

We believe this is another important attempt towards transferring knowledge in the educational field. Further, there are key issues to be considered, such as measuring the degree of similarity between two courses (i.e., the number and form of learning activities), the type of attributes and the duration of the course. Finally, it is similarly important to build simple and interpretable transferable models that could be easily applied by educators from one course to another [29]. Therefore, more studies are required on this topic to establish these results.

### **7. Conclusions**

In the present study, an effort was made to propose a transfer learning method for the task of predicting student performance in undergraduate courses. The identification of failure-prone students could lead academic staff to develop learning strategies that aim to improve students' academic performance [32]. Transfer learning enables us to train a deep network using the dataset of a past course (source course) and reuse it as the starting point for a dataset of a new, related course (target course). Moreover, it is possible to further tune the repurposed model. Our findings show that a fair performance was achieved in most cases, while the proposed method handily outperforms the baseline model.

Transfer learning offers many future research directions. Our results are encouraging and should be validated on larger samples of courses from different departments and programs. An interesting task is to apply a model built for a specific task, such as the prediction of students' performance, to another related task, such as the prediction of student dropout, or to regression tasks (e.g., predicting students' grades). In future work we will also investigate the efficiency of transfer learning on imbalanced datasets obtained from several educational settings. If only the target task is available, but a limited number of additional training data can be chosen and collected, then active learning algorithms can be used to make choices that will improve the performance on the target task. These algorithms may also be combined into active transfer learning [33].

**Author Contributions:** Conceptualization, M.T. and S.K.; methodology, S.K.; software, M.T.; validation, G.K., S.K. and O.R.; formal analysis, M.T.; investigation, M.T.; resources, S.K.; data curation, G.K.; writing—original draft preparation, M.T.; writing—review and editing, G.K.; visualization, M.T.; supervision, S.K.; project administration, O.R.; funding acquisition, S.K. All authors have read and agreed to the published version of the manuscript.

**Funding:** This research received no external funding.

**Conflicts of Interest:** The authors declare no conflict of interest.

### **References**


© 2020 by the authors. Licensee MDPI, Basel, Switzerland. This article is an open access article distributed under the terms and conditions of the Creative Commons Attribution (CC BY) license (http://creativecommons.org/licenses/by/4.0/).

## *Article* **Analysis of Cross-Referencing Artificial Intelligence Topics Based on Sentence Modeling**

**Hosung Woo <sup>1</sup>, JaMee Kim <sup>2</sup> and WonGyu Lee <sup>3,\*</sup>**


Received: 5 May 2020; Accepted: 25 May 2020; Published: 26 May 2020

**Abstract:** Artificial intelligence (AI) is bringing about enormous changes in everyday life and today's society. Interest in AI is continuously increasing as many countries are creating new AI-related degrees, short-term intensive courses, and secondary school programs. This study was conducted with the aim of identifying the interrelationships among topics, based on an understanding of various bodies of knowledge, and of providing a foundation for topic compositions to construct an academic body of knowledge of AI. To this end, machine learning-based sentence similarity measurement models used in machine translation, chatbots, and document summarization were applied to the body of knowledge of AI. Consequently, several topics similar to agent design in AI were identified in areas such as algorithms and complexity, discrete structures, software development fundamentals, and parallel and distributed computing. The results of this study provide the knowledge necessary to cultivate talent by identifying relationships with other fields in the edutech field.

**Keywords:** Machine learning analysis; sentence modeling; topic analysis; cross referencing topic

### **1. Introduction**

Information technology (IT) is driving changes in society and leading to a new paradigm shift in national development across the globe. Among these technologies, artificial intelligence (AI)-related research and talent cultivation are becoming the basis for national development, while emerging as a competitive edge. Through AI.gov, the United States has promoted AI research and development at the government level since February 2019 [1], and at the Future Strategy Innovation Conference in March 2019, Japan presented AI, quantum technology, and biotechnology as the three strategic technologies that will lead Japan's development [2]. South Korea's Ministry of Science and ICT also announced its plans to foster AI talent by 2022 according to the "AI Research & Development Strategy" of May 2018 [3].

This interest in AI from different countries is motivating talent cultivation, seen in the creation of new AI degree programs at universities, the dedication of new colleges to AI, and so on. The U.S. has already implemented talent development plans at universities [4,5], while Japan has introduced its future strategic plans by linking AI with primary and secondary education. China, through the "New Generation Artificial Intelligence Development Plan" and the "AI Innovation Action Plan" of 2017 and 2018, has begun AI-focused primary and secondary education, where students study the core technologies of AI [6,7]. India also announced the inclusion of an AI curriculum for eighth and ninth graders at the 2019 Central Board of Secondary Education [8].

As a way of implementing various AI talent policies, AI is being taught from the primary school level; however, to date, no core knowledge has been defined for AI education. In the absence of such core content standards, the structures of AI education vary among educational institutions [9,10]. In a similar context, the British Prime Minister emphasized the importance of establishing rules and standards for AI technologies at the 2018 World Economic Forum. Although AI is now in the spotlight, both academically and economically, it is widely regarded as a union of knowledge containing topics from various fields. Therefore, developing a body of knowledge in the field of AI will also help in establishing an academic foundation for AI.

A body of knowledge is a reconstruction of the knowledge area (KA) based on the knowledge that experts in academic fields must obtain. It is important for the topics covered in the knowledge areas to be constructed in a way that they correlate with other areas or topics. This means that the body of knowledge based on semantic relations can be said to correspond to the overall knowledge that one must acquire in the given academic field.

Therefore, this study focused on deriving areas and topics that correlate with the field of AI. To fulfill its goals, this study derived knowledge areas and topics using sentence models. AI's body of knowledge extracted from sentence models will provide implications for which research topics should be continued from the academic perspective. In particular, by identifying relationships with other areas, it will also contribute to constructing the knowledge that is required to develop professionals at the university level.

#### **2. Related Research**

A body of knowledge analysis aims to identify the hierarchies of knowledge areas or the relationships among topics. In some cases, a visual representation is used for a clear knowledge sequence or for secondary utilization of knowledge. Analysis and reorganization of a body of knowledge can be performed either by experts or by computer systems. Relying on experts can be expensive and time-consuming, and maintaining consistency is quite challenging; therefore, computer systems are used instead. This section discusses previous studies on body of knowledge analysis and sentence modeling using computer systems.

### *2.1. Knowledge Areas Analysis Research*

Analyzing the body of knowledge for each academic field takes curriculum composition or evaluation into account. Various types of research related to body of knowledge analysis have been conducted. The main types are discussed below.

The first type of body of knowledge analysis research is category-based research. Based on knowledge about the units of the Computing Curricula 2001, which can be regarded as the computer science (CS) body of knowledge, research was performed to set the flow of the teaching syllabus, and a syllabus-maker that can develop or analyze a teaching syllabus was proposed [11]. Using the body of knowledge of Computer Science Curricula 2013 (CS2013), the distribution of knowledge areas for a teaching syllabus has also been predicted [12]; this approach was combined with learner information in an attempt to provide learners with personalized learning paths. Further, Ida (2009) proposed a new analysis method based on a library classification system, which uses the classification information of the curriculum instead of the body of knowledge itself.

In the abovementioned studies, analysis was performed based on categories such as KA and knowledge units (KU), which do not take specific topics or the contents of each body of knowledge into account. Hence, although they can be useful in identifying the entire frame or hierarchy of a body of knowledge, they have a limited ability to identify meaning based on the detailed contents of the subject unit.

The second type of body of knowledge analysis research is word-based research, which uses stochastic topic modeling. Topic modeling estimates topics based on the distribution of words contained in a document. Latent Dirichlet allocation (LDA) is one of the popular methods used in topic modeling [13]. Sekiya (2014) used supervised LDA (sLDA) to analyze the changes in the body of knowledge from CS2008 to CS2013 [14]. More specifically, ten words that could represent each KA in the body of knowledge were extracted and compared based on common words. Another study, based on the CS2013 body of knowledge, applied an extended method called simplified, supervised latent Dirichlet allocation (ssLDA) [15]: syllabus topics from the top 47 universities worldwide in 2014–2015 were quantified and, based on the related topics, the features of similar topics were analyzed.

In 2018, the Information Processing Society of Japan presented a computer science body of knowledge with 1540 topics. The presented topics range from approximately one to five words in length. For example, if the topic "Finite-state Machine" is processed on a word basis, the result includes various machine-related topics such as Turing machines, assembly level machines, and machine learning, which makes it difficult to extract the exact meaning of the given topic. As with the previous studies, unigram-based research through topic modeling is suitable for classifying the words that appear in the body of knowledge by topic, but it can be semantically limited, as it treats topics segmentally.

For analysis based on an accurate understanding of the body of knowledge, it is necessary to work with unprocessed topics. This preserves the clear meaning of the topics and makes it possible to infer the relationships among the words. Therefore, this study uses sentence modeling to find topic meanings in the body of knowledge and to identify the relationships among these topics.

### *2.2. Sentence Modeling*

Sentence modeling is an important problem in natural language processing [16]. It allows the insertion of sentences into vector spaces and uses the resulting vectors for classification and text generation [17]. Convolutional neural networks (CNNs) and recurrent neural networks (RNNs) are generally used in sentence modeling approaches.

CNNs preserve local information about sentences using filters that are applied to local features at each layer [18]. Figure 1 depicts a simple CNN structure with a convolutional layer on top of word vectors obtained from an unsupervised neural language model. This model achieved several excellent benchmark results even without parameter tuning [19].
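The convolve-then-max-pool step this structure describes can be sketched in NumPy (the filter count, filter width, and embedding size are arbitrary illustrative choices):

```python
import numpy as np

def conv1d_maxpool(word_vectors, filters, width=3):
    """Apply each filter over every window of `width` consecutive word
    vectors, then max-pool over positions, yielding one feature per
    filter, as in a basic CNN sentence model."""
    n_words, dim = word_vectors.shape
    # Each window is the concatenation of `width` word embeddings.
    windows = np.stack([word_vectors[i:i + width].ravel()
                        for i in range(n_words - width + 1)])
    feature_maps = np.maximum(windows @ filters.T, 0)  # ReLU
    return feature_maps.max(axis=0)                    # max over positions

rng = np.random.default_rng(0)
sentence = rng.normal(size=(7, 4))       # 7 words, 4-dim embeddings
filters = rng.normal(size=(10, 3 * 4))   # 10 filters of width 3
features = conv1d_maxpool(sentence, filters)
```

Max-pooling over positions is what makes the sentence representation fixed-length regardless of sentence length, which is the property a downstream classifier relies on.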

**Figure 1.** Structure of convolutional neural networks (CNN).

The CNN models used in early computer vision are also effective in natural language processing, and they have been applied in semantic parsing, search query retrieval [20], sentence modeling [21], and other traditional natural language processing (NLP) works [22].

RNNs process word inputs in a particular order and learn from the order of appearance of particular expressions; thus, the modeling can capture semantic similarities between sentences and phrases. RNNs can be used as standalone models, but they are often used in modified, extended forms. Manhattan long short-term memory (LSTM), a modified RNN model, uses two LSTMs to read and process the word vectors that represent two input sentences, and it has been shown to outperform more complex neural network models. The structure of the Manhattan LSTM is shown in Figure 2 [23]. The main feature of this model is that the LSTM and the Manhattan metric are used within a Siamese network structure that contains two identical subnetwork components. The hidden state h<sub>1</sub> is learned from the word vectors and randomly generated weights. Then, each hidden state h<sub>t</sub> is generated from the input at position t and the previous hidden state h<sub>t−1</sub>. Finally, the semantic similarity of the sentences is measured using the vectors of the final hidden states. For example, "he," "is," and "smart" are the words represented by the vectors x<sub>i</sub>: x<sub>1</sub> is the input used to calculate the state value h<sub>1</sub>; h<sub>2</sub> is calculated from the previous state value and x<sub>2</sub>; and the final hidden state h<sub>3</sub> is calculated from x<sub>3</sub> and the previous state value h<sub>2</sub>. The words "a," "truly," "wise," and "man" are treated in the same manner, and the similarity is measured from the final vectors obtained by processing the two sentences.
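The sequential hidden-state update can be illustrated with a plain recurrence (a simplified stand-in: a real Manhattan LSTM uses gated LSTM updates, but the flow from the previous state and the current word vector to the new state is the same):

```python
import numpy as np

def encode(words, W_x, W_h):
    """Sequentially update the hidden state over a sentence's word
    vectors, h_t = tanh(W_x x_t + W_h h_{t-1}), and return the final
    state. (Plain recurrence for illustration; the Manhattan LSTM
    replaces this update with gated LSTM cells.)"""
    h = np.zeros(W_h.shape[0])
    for x in words:
        h = np.tanh(W_x @ x + W_h @ h)
    return h
```

In the Siamese setup, the same weight matrices encode both sentences, and the two final states are what the similarity function compares.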

**Figure 2.** Structure of Manhattan LSTM.

In addition to the abovementioned models, modified RNN models such as bidirectional LSTM, multidimensional LSTM [24], gated recurrent units [25], and recursive tree structures [26] have been applied to text modeling through architectural modifications of the existing models. RNN-based models generally compute over the word order of the input and output sentences. Such sequential processing can reflect the syntax and semantics of the given sentences, but its slow computational speed and limited parallelism are drawbacks.

There are also structures that allow modeling dependencies regardless of the length of the input or output sentences, for computational efficiency. Figure 3 shows the structure of the transformer model, which is based only on attention mechanisms, without using CNNs or RNNs [27]. Attention mechanisms refer back to the entire input sentence of the encoder at every time-step of the output word prediction in the decoder. However, the input words are not all weighted equally; instead, the focus is placed on the input words most relevant at a specific time, based on similarity. The transformer model allows parallelism, and through "attention" in the encoder and decoder it can emphasize the values that are most closely related to the query.
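The attention computation described above reduces to the scaled dot-product form; a NumPy sketch (single head, no learned projections, for illustration only):

```python
import numpy as np

def scaled_dot_attention(Q, K, V):
    """Scaled dot-product attention: each query attends to all keys,
    and the softmax weights emphasize the values whose keys are most
    similar to the query."""
    d_k = K.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)
    # Numerically stable softmax over the key dimension.
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ V, weights
```

Because every query attends to every key in one matrix product, the whole sentence is processed in parallel, which is the efficiency gain over sequential RNN recurrences noted above.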

**Figure 3.** Structure of the transformer model.

The CNN, Manhattan LSTM, and transformer models have complementary strengths and weaknesses. Therefore, all three models were implemented to evaluate the semantic similarity of body of knowledge topics, and the model with the highest accuracy was subsequently selected for cross-referencing among the topics.

#### **3. Methods**

### *3.1. Experimental Procedure*

AI is an artificial implementation of human intelligence, partial or full, based on the broad concept of "smart" computers. Intelligent systems (IS) operate in the same manner as AI, but they do not feature deep neural networks that support self-learning [28]. In the field of CS, IS contains AI-related content. In this study, IS was used to derive the body of knowledge of AI. The procedure used to derive cross-references among the topics presented in the IS knowledge area of the CS field was as follows.


### *3.2. Subject of Analysis*

"Computer Science Curricula CS2013," a standard body of knowledge for the CS field presented by the ACM and the IEEE Computer Society, was selected. The CS2013 body of knowledge comprises eighteen KAs, and each KA contains approximately ten KUs, as shown in Figure 4. Each KU topic belongs to one of three tiers: tier 1 covers the basic introductory concepts, tier 2 covers the undergraduate-major concepts, and the elective tier covers content that is more advanced than the undergraduate-major content.

**Figure 4.** Computer Science Curricula CS2013 body of knowledge.

To derive cross-references among the various topics in the body of knowledge of AI, this study was conducted based on the topics of IS, and the topics were classified as follows. First, the CS2013 body of knowledge was divided into knowledge areas, knowledge units, and topics. Second, CS2013 classifies the topics of each KU into Tier 1, Tier 2, and elective for all eighteen KAs. Third, IS consists of four knowledge units: fundamental issues, basic search strategies, basic knowledge representation and reasoning, and basic machine learning; all four KUs are at the Tier 2 level, with 31 topics. Fourth, the 17 areas excluding IS comprise 323 Tier 1 and 327 Tier 2 topics. Thus, the similarities between the 31 topics of IS and the 650 topics of the other 17 KAs were analyzed by level.

### *3.3. Sentence Model Performance*

CNN, Manhattan LSTM, and multi-head attention networks (a transformer model) were implemented, and their performances were compared. The SNLI (Stanford Natural Language Inference) and QQP (Quora Question Pairs) corpora were used for model training. For each corpus, 90% of the data was used as the training set, and 10% was used as the testing set to verify the accuracy of each model.


The content and accuracy of the models are detailed in Table 1.

On QQP, CNN had an accuracy of 83.8%, which was superior to Manhattan LSTM's 82.8%. Conversely, Manhattan LSTM had an accuracy of 80.6% on SNLI, the highest among the three models. Multi-head attention networks showed better accuracy than CNN on SNLI, but their mean accuracy was the lowest of all. In this study, Manhattan LSTM was used, as it showed the highest accuracy when the accuracies on the two corpora were averaged.


**Table 1.** Content and accuracy of sentence models.

### *3.4. Setting the Similarity of the Sentence Model*

This study was performed using the Manhattan LSTM, which had the highest average accuracy of the three models. The semantic similarity computed from the final hidden-state vectors of the two sentences was thresholded to provide outputs at two levels, as shown in Figure 5.

**Figure 5.** Structure of the sentence model used.

The threshold was set at two levels through various experiments. In other words, the similarity between topics is either greater than 0.95, or greater than 0.90 but less than or equal to 0.95.

The Manhattan distance was applied as the similarity function. In general, the Euclidean distance is not used in similarity-determination problems because it makes learning slow and has difficulty correcting errors owing to vanishing gradients in the early stages of learning [29]. On the other hand, applying the form *e*<sup>−*x*</sup> to the Manhattan distance maps two sentences to a value close to one if they are semantically similar and close to zero otherwise, without a separate activation function to determine the output of the neural network. The cost function uses the mean squared error (MSE) of this difference.
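A minimal sketch of this similarity function, together with the two-level thresholding described above (the function names and level labels are illustrative assumptions):

```python
import numpy as np

def manhattan_similarity(h1, h2):
    """Similarity in (0, 1]: exp(-Manhattan distance) between the final
    hidden-state vectors of the two sentences. Identical states give
    exactly 1; increasingly distant states decay towards 0."""
    d = np.abs(np.asarray(h1) - np.asarray(h2)).sum()
    return float(np.exp(-d))

def similarity_level(sim):
    """Two-level thresholding used in the study: > 0.95, or
    > 0.90 but <= 0.95."""
    if sim > 0.95:
        return "high"
    if sim > 0.90:
        return "moderate"
    return "below threshold"
```

Because exp(−x) is bounded in (0, 1] for x ≥ 0, the network output is already a similarity score, which is why no separate output activation is needed and why the MSE against the 0/1 label is a sensible cost.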

### **4. Application Results**

#### *4.1. Sentence Model Performance*

Table 2 presents the results of the analysis of the 31 topics of IS against the 323 Tier 1 topics in other areas. More semantically similar topics were extracted in the "0.90 < similarity ≤ 0.95" range than at the "0.95 < similarity" level, because "0.95 < similarity" is the more robust (stricter) criterion.


**Table 2.** Tier 1: Topic pairs with high similarity between IS and topics in other areas.

There were 74 pairs extracted from 57 topics of DS in the similarity range "0.90 < similarity ≤ 0.95"; note that the 31 topics of IS and the 57 topics of DS form 1767 (31 × 57) topic pairs in total. Next, many topic pairs with high similarity appeared, in the order of SDF and AL. For "0.95 < similarity," six pairs were extracted from five areas, namely AL, DS, NC, SDF, and SP; AL had two pairs, while the other four areas had one pair each. The details are shown in Figure 6.

**Figure 6.** "0.95 < similarity": Topics with high relevance in Tier 1.

We focused on the topic pairs with "0.95 < similarity" because high similarity at this level is robust. In other words, it was determined that knowledge could be extracted in order of decreasing similarity with IS.

We now examine "Reflexive, goal-based, and utility-based" from the perspective of AI. An AI agent is an autonomous process that automatically performs tasks for users. It is a software system with independent functions that performs tasks, typically operating in a distributed environment [30]. To perform tasks, an agent interacts with other agents through its own reasoning method, using a database of non-procedurally processed information called knowledge. The agent then continues to act based on the learning and purpose-oriented skills obtained from its experiences. This means that "Reflexive, goal-based, and utility-based" explains how to design an agent program. In this study, as shown in Figure 6, the following topics were identified as knowledge to be handled before learning about agents.

First, agents interact with external environments using sensors to achieve their goals in the complex and fluid real-world environments. In other words, agents can be said to be faced with a complex problem that requires considering various situations. Divide-and-conquer strategies are effective in solving complex problems. With agent designs, it is easy and simple to approach and solve several smaller sub-problems [31]. This problem-solving strategy can be applied not only to agent design but also to general problem-solving.

Second, to achieve a predefined goal or solution, a systematic search method is applied for problem-solving in agents. "Depth- and breadth-first traversals" are search algorithms belonging to common search strategies [32].

Third, frames are a means of representing knowledge in agents. Knowledge consists of small packets called frames, where the contents of a frame are specific slots with values. The topic "permutations and combinations" is a basic concept that helps with finding or identifying new knowledge [33].

Fourth, there is a need to understand the "references and aliasing" topic to recognize and reason knowledge in agents. There are physical difficulties such as memory limits in perceiving all the knowledge in a computing environment. To solve this problem, knowledge can be limited to certain categories and then perceived. Processing can then be done by referencing the address of the memory where the knowledge is stored [34].

Fifth, the main rationale for constructing multi-agent systems is that by forming a community, multiple agents can provide more added value than a single agent can. Additionally, agents can participate in an agent society through communication, and they can acquire services owned by other agents through interactions. In other words, even if an agent does not have all the information, it can provide various services because of its interactions with other agents [35]. The concept of "multiplexing with TCP and UDP" has been applied to implement the means of information exchange and communication in the agent society.

Finally, the progress and globalization of various technologies including agents are affecting society. Ethical and moral issues that can arise from agents are indispensable factors in terms of education. "Moral assumptions and values" can be said to be a required topic for the development and use of technology [36].

The topics that have semantic similarities with IS at the "0.90 < similarity ≤ 0.95" range are shown in Table 3.

The areas with the highest distribution of similar topics were, in order, DS (Discrete Structures), SDF (Software Development Fundamentals), and AL (Algorithms and Complexity). The topics extracted from DS included topics covered by IS, such as mathematical proofs and the development of the ability to understand concepts [37]. In other words, these areas include important content such as set theory, logic, graph theory, and probability theory, which are foundations of computer science.

Students should be able to implement IS to solve problems effectively in the field of AI. This means that they should be able to read and write programs in several programming languages. Topics that were sampled from SDF were more than just programming skill topics. They included basic concepts and techniques in the software development process, such as algorithm design and analysis, proper paradigm selection, and the use of modern development and test tools.

Performance in implementing IS may vary depending on both the accuracy of the model trained with a large amount of data and the chosen algorithm and its suitability. In other words, algorithm design is very important for improving the efficiency of the IS model design [31]. AL is said to be the basis of computer science and software engineering, as well as of IS design. As demonstrated by brute-force algorithms, greedy algorithms, and search algorithms, building an understanding of and insight into algorithms is one of the important factors in IS composition. As the results show, the sampled topics with high similarity are closely related to the IS and AI fields.


**Table 3.** "0.90 < similarity ≤ 0.95": Topics with high similarity in Tier 1.

### *4.2. Results for Tier 2*

Table 4 presents the results of analyzing 327 Tier 2 topics in the IS and other areas. As shown in Table 2, Table 4 also shows more similar topic samples at the "0.90 < similarity ≤ 0.95" range than at the "0.95 < similarity" level. At the "0.90 < similarity ≤ 0.95" range, SE had 31 topic pairs of samples from 59 topics. The next highest area was PD (Parallel and Distributed Computing). At the "0.95 < similarity" level, ten pairs were extracted from seven areas, and the NC (Networking and Communication) topic had the most pairs compared to other topics with three pairs. As shown in Figure 7, the three topics of IS and the topic of "Routing versus forwarding" of NC were similar.


**Table 4.** Tier 2: Topic pairs with high similarity between IS and other topic areas.

**Figure 7.** "0.95 < Similarity": Topics with high relevance in Tier 2.

Based on CS2013, Tier 2 features more advanced topics compared to Tier 1. To learn these topics, results from Section 4.1 can be considered as prerequisite knowledge or references.

As shown in Figure 7, seven topic pairs from six areas were extracted for "reflexive, goal-based, and utility-based" at the "0.95 < similarity" level. AR (Architecture and Organization) had two pairs, while the other five areas had one topic pair each.

Looking at AL in relation to "reflexive, goal-based, and utility-based," the Tier 1 topics concern problem-solving and graph-based querying (see Figure 6), whereas Tier 2 concerns recursive or iterative approaches through "analysis of iterative and recursive algorithms." In other words, the similar topics in Tier 1 relate to theories or concepts for problem-solving, whereas those in Tier 2 relate to strategies for efficiently implementing agents [37].

Stability is an important factor both for processing the knowledge of agents and for agents interacting with one another using processed knowledge. In terms of implementation, "signed and two's complement representations" relates to the sign and representable range of variables to be stored in memory. Memory overflow occurs when data overrun the designated boundary of the memory space, which can lead to unexpected behavior or computer security vulnerabilities [38]. In addition, "fault handling and reliability" is also about ensuring stability by handling faults of the memory system.

In addition to the above, the extracted topics included concepts related to extending or parallelizing the information systems of agents, and topics related to tools that can be used without implementing everything from scratch. Agents, which are typically complex and massive systems, are developed in a collaborative design environment. "Ethical dissent and whistle-blowing," an ethical and altruistic act, helps to improve the development of an organization and further contributes to forming a healthy community [39].

"Deterministic versus stochastic," "discrete versus continuous," and "autonomous versus semi-autonomous" were very similar to "routing versus forwarding" of NC. The three IS topics relate to the characteristics of the given problems, which the agents need to infer or solve optimally. These concepts are applied in various fields beyond AI, which makes them important topics for agents. The same distinctions also apply to routing algorithms for deterministic and probabilistic routing [40], to autonomous systems for network topology management and control [41], and to the process of finding a network [42].

Table 5 presents topics with similarity to IS in Tier 2 at the "0.90 < Similarity ≤ 0.95" range. Topics were extracted from 13 of 17 areas excluding IS. Of these, the SE (Software Engineering) and PD (Parallel and Distributed Computing) areas had the most topics extracted. Schedule, cost, and quality are important factors in the design, implementation, and testing phases of IS. The SE topics are applicable to software development in all areas of computing, including IS [43]. SE, along with IS, can be said to be related to the application of theory, knowledge, and practice to build general-purpose software more efficiently.



In terms of the efficiency of IS, understanding parallel algorithms, strategies for problem decomposition, system architectures, detailed implementation strategy, and performance analysis and tuning is important [31]. For this reason, the topics extracted from PD, such as concurrency and parallel execution, consistency of status and memory operation, and latency, had high similarities.

### *4.3. KU of IS vs. KU of Other Knowledge Areas*

The IS area comprises four KUs containing Tier 2 topics at the academic major level: "fundamental issues," "basic search strategies," "basic knowledge representation and reasoning," and "basic machine learning." KUs with high similarity to the KUs of IS are shown in Table 6 below. At "0.95 < similarity," only the KU "fundamental issues" was extracted. At "0.90 < similarity ≤ 0.95," the three KUs other than "basic knowledge representation and reasoning" were extracted.

Out of 18 KAs, 53 KUs from 14 KAs showed high similarity. The areas with high similarity to the IS KUs were DS, NC, and SE. "Sets, relations, and functions" of DS had the highest similarity, with 14 KUs. Limiting the results to KUs with more than five pairs of similar topics, "fundamental issues" was related to 15 KUs: nine KUs corresponding to AL, DS, and SDF in Tier 1, and six KUs corresponding to AL, NC, PD, and SE in Tier 2. "Basic machine learning" had high similarity with "algorithmic strategies" of AL, and with "graphs and trees" and "sets, relations, and functions" of DS from Tier 1. "Fundamental issues," which covers the general contents of AI relative to the other KUs, had 111 pairs from Tier 1 and 84 pairs from Tier 2 with high similarity. There were more than 40 pairs of topics for "basic machine learning," which covers the basics of machine learning, such as supervised learning, unsupervised learning, and reinforcement learning.

### *4.4. Evaluation*

Currently, there is no research available on the curriculum relevance of the topics. The validity of the study was evaluated using two methods because a direct comparison with previous studies was not possible. The first method was a content validation from experts. The degree of similarity among the topics for content validity verification was quantified. At the same time, "rater reliability" among experts was determined. The second method was an index term-based method using a search engine. Specifically, the index term-based method was applied in the study of variables affecting similarities between two documents or two topics. This is because if the two topics appear simultaneously in the same document referring to a specific area, they can be said to be semantically related.

### 4.4.1. Content Validation by Experts

This study analyzed opinions from five experts based on the topics related to "reflexive, goal-based, and utility-based" for intelligent agent design. The experts met at least three of the following four criteria: Ph.D. in AI, seven years of work experience in the field of AI, seven years of AI research experience, and five years of experience teaching AI at a university.

The content validation proceeded as follows. In the first step, 66 topics were selected as random samples from four different similarity ranges: "0.95 < similarity," "0.95 ≥ similarity > 0.9," "0.9 ≥ similarity > 0.8," and "0.8 ≥ similarity > 0.7." In the second step, the experts were asked to examine the extracted topics based on their semantic relevance to "reflexive, goal-based, and utility-based" and the need for those topics. The experts' assessments of the relevance and need of the topics in each range are shown in Table 7.

Regarding the relevance of "reflexive, goal-based, and utility-based" to the topics in the "0.95 < similarity" range, the expert response was 3.74, close to four, suggesting high relevance. The necessity score of the topics for intelligent agent design was also high at 3.74.

For the topics in the "0.95 ≥ similarity > 0.9" range, the relevance score was 3.49, and the "0.9 ≥ similarity > 0.8" range received a relevance score of 3.08. For the topics in the "0.8 ≥ similarity > 0.7" range, an "average" relevance score of 3.07 was received. In other words, based on the topics sampled in the study, the higher the similarity range, the higher the relevance in the expert content validation. The inter-rater reliability among the experts was high, at 0.78–0.90. Therefore, it can be concluded that the topics sampled for each range have high semantic relevance.






**Table 7.** Expert review results with regards to topics.

### 4.4.2. Validation through Index Terms

Each of the six topic pairs with semantic similarity between "reflexive, goal-based, and utility-based" and other Tier 1 areas at "0.95 < similarity" (see Figure 6) was entered into a search engine, and the results displayed on the first page were checked. The search engine results are shown in Figure 8.

**Figure 8.** Topic search results from search engine.

As shown in Figure 8, for (1), (2), (3), (4), and (6), the words included in the two topics appeared together in documents related to AI. This implies that the two topics are either AI-related or semantically close. Further, in the case of (5), "utility-based" and "alias, references" appeared together in a document on utility-based mechanisms for managing communications in collaborative multi-sensor networks. From this, it can be interpreted that the terms appeared simultaneously in one document because communications using sensors are directly linked to communications among agents.

### **5. Conclusions**

The body of knowledge in a specific field of study reflects the linkages and continuity of knowledge. Identifying the inter-topic relationships among bodies of knowledge will help in constructing a hierarchy or a sequence of knowledge areas.

This study examined the interrelationships among the topics based on the understanding of various bodies of knowledge to provide suggestions for topic compositions to construct an academic body of AI knowledge. In constructing new topics through the analysis of the body of knowledge, sentence modeling, which is a data science method, was used to minimize noise in the semantic interpretation of the topics containing multiple words. Further, unprocessed original topics were applied. Regarding extracted contents, the validity was verified through expert validation tests and semantic similarity tests based on index terms.

Based on the obtained results, there were several topics with high similarity related to the agent design in the areas of "algorithm complexity," "discrete structures," "fundamentals of software development," "parallel and distributed computing," and "software engineering." These topics included algorithm design strategy, "divide-and-conquer," "data retrieval method," "depth and breadth-first traversals," and a theoretical foundation in the CS field, "permutations and combinations." The top three units with high similarity topic distribution based on KU were "sets, relations, and functions" and "graphs and trees" of DS, and "algorithmic strategies" of AL. These are the basic theories and concepts applied throughout the CS field, including the agent design.

One limitation of this study is the sparsity of data on the core topics used in education. A large amount of data is needed to extract stable values through machine learning, but documents on curricula or the core topics of education are scarce as training data. Furthermore, little research has been conducted on how to justify results obtained through learning models after a study. Consequently, identifying the possibilities through the educational use of the research results is time-consuming. To reflect emerging knowledge in education, the composition of knowledge or the modeling of topics is essential. To achieve this, it is necessary to establish various methodologies to overcome limitations in the availability of extracted knowledge.

The field of AI deals with expert knowledge built on basic knowledge. The composition of the AI body of knowledge is crucial for preparing an academic foundation and deciding which topics to cover in the future. Therefore, the sentence modeling method presented in this study will contribute to the construction of various levels of knowledge hierarchies in the body of knowledge of AI and to a further understanding of the knowledge of the topics.

**Author Contributions:** Conceptualization, H.W. and J.K.; methodology, H.W.; software, H.W.; validation, H.W. and J.K.; writing—original draft preparation, H.W. and J.K.; writing—review and editing, H.W. and J.K.; visualization, H.W.; project administration, W.L.; funding acquisition, W.L. All authors have read and agreed to the published version of the manuscript.

**Funding:** This work was supported by the National Research Foundation of Korea (NRF) grant funded by the Korea government (MSIT) (No. 2016R1A2B4014471).

**Conflicts of Interest:** The authors declare no conflict of interest.

### **References**


© 2020 by the authors. Licensee MDPI, Basel, Switzerland. This article is an open access article distributed under the terms and conditions of the Creative Commons Attribution (CC BY) license (http://creativecommons.org/licenses/by/4.0/).

## *Article* **Prediction of Academic Performance at Undergraduate Graduation: Course Grades or Grade Point Average?**

**Ahmet Emin Tatar 1,† and Dilek Düştegör 2,\*,†**


Received: 22 June 2020; Accepted: 16 July 2020; Published: 19 July 2020

**Abstract:** Predicting the academic standing of a student at graduation time can be very useful, for example, in helping institutions select among candidates, or in helping potentially weak students overcome educational challenges. Most studies use individual course grades to represent college performance, with a recent trend towards using grade point average (GPA) per semester. It is unknown, however, which of these representations yields the best predictive power, due to the lack of a comparative study. To answer this question, a case study is conducted that generates two sets of classification models, using respectively individual course grades and GPAs. Comprehensive sets of experiments are conducted, spanning different student data, using several well-known machine learning algorithms, and trying various prediction window sizes. Results show that using course grades yields better accuracy if the prediction is done before the third term, whereas using GPAs achieves better accuracy otherwise. Most importantly, variance analysis on the experiment results reveals easily generalizable insights: individual course grades with a short prediction window induce noise, and GPAs with a long prediction window cause over-simplification. The demonstrated analytical approach can be applied to any dataset to determine when to use which college performance representation for enhanced prediction.

**Keywords:** academic performance; course grades; data mining; grade point average; machine learning; prediction; undergraduate

### **1. Introduction and Motivation**

Educational Data Mining (EDM) is a fast-growing scientific field offering the potential to analyze a variety of student features to harness valuable knowledge from them. To this end, a plethora of predictive algorithms have been effectively applied in educational contexts for numerous purposes using a variety of data and student records. As compiled in the review paper [1], two main application purposes can be identified in college contexts: predictors and early warning systems (EWS). A predictor, "given a specific set of input data, aims to anticipate the outcome of a course or degree" [1], and an EWS "performs the same tasks as a predictor, and reports its findings to a teacher and/or to students at an early enough stage so that measures can be taken to avoid or mitigate potentially negative outcomes" [1]. Common prediction goals are listed as risk of failing a course, dropout risk, grade prediction, and graduation rate.

Among the various prediction goals, the prediction of academic performance at graduation time is especially important, as this information can be useful for:


Looking at the literature on prediction of academic performance at the graduation time, we can observe that all studies rely mainly on four types of information on students, namely: (1) demographics and socio-economic, (2) high-school related, (3) college enrollment, and (4) college performance (up to the time of prediction).

Commonly used demographics and socio-economic information are sex/race [2], household income [3], age, first generation student [4], marital status, parents' jobs and educational levels [2]. Among the high-school related information, high-school GPA [2], pre-college marks [5,6], college admission test scores [3], public or private high-school [2], are frequently observed. As college related, in terms of enrollment information, the major and campus [2], a student's full-time vs. part-time status as well as whether s/he has a scholarship [3], enrolled hours and earned credit hours [4], year of entry and program [2,7] are often used. Finally, we observe that college performance has mostly been represented with grades from courses taken earlier [2,4,6,8], unless the prediction model is meant to be used at admission time [3,8,9].

Based on the above background, we observe that the bulk of previous studies used datasets with relatively large dimensionality of observations, some of them being expensive to measure (when not already available in the records). This, often combined with small samples, caused the curse of dimensionality, potentially yielding models with sub-optimal predictive power.

Recently, there has been a trend towards using only college performance [8,9], or using course averages per semester instead of individual course grades (i.e., GPA per semester, or CGPA for the cumulative average at the time of prediction) [2,4,7]. However, no study has compared the performance of EDM models using individual course grades vs. grade point averages. It is unknown whether these two college performance representations are equivalent and can be used interchangeably, or whether one yields better predictive power than the other.

The main purpose of the present study, therefore, is to elucidate this matter by answering the following research question:

### *Is the individual course grade or grade point average more relevant for predicting student graduation academic performance?*

To answer this question, recent student data compiled at the College of Computer Science and Information Technology (CCSIT) from Imam Abdulrahman bin Faisal University (IAU) are used to generate two sets of predictive models, one using individual course grades, the other using the grade point averages. Thus, predictive power of respective models can be compared. However, it is well known that the performance of such models can also be affected by (1) student data used (besides the academic performance), (2) the data mining technique applied, as well as (3) how far from graduation the prediction is performed. Therefore, a comprehensive set of experiments is designed for spanning the whole search space made of student information besides the college performance, several machine learning methods commonly used in the literature, and prediction window of various sizes.

In the following sections, we first describe the research methodology, including, the dataset description, its preprocessing, the methods used, the experimental setup, and the evaluation criteria. Then, each conducted experiment and its results are reported, followed by the discussion and the concluding remarks.

### **2. Research Methodology**

#### *2.1. The Dataset and its Preprocessing*

Our dataset contains records of 357 students who were admitted to the CCSIT at IAU from Fall 2011 to Fall 2013 (included), and thus includes three batches of students. The institutional review board at IAU reviewed and approved using the data anonymously (application approved on 19 December 2018; IRB Number: 2018-09-304). Two programs of CCSIT are included in this study, namely Computer Science (CS) and Computer Information Systems (CIS). During the first three years, all the students of the College follow the same plan. In their first year, they attend the Preparatory program where they take mainly intensive English Language courses. In their second and third years, called General Years, the students take courses fundamental to computer and information sciences. At the end of their third year, students select either CS or CIS program based on their interests.

Student records populated from IAU learning management system contain features of three different nature: the demographic features, the pre-college features, and the college records including enrollment information and college performance.

**The demographic features** consist of gender and nationality (see Figure 1). The female gender dominates the CCSIT as the College is one of the top ranking colleges in the Eastern Region of Saudi Arabia for females. As expected, the dominant nationality is Saudi Arabian with over 85%. The other significant countries represented are Yemen (YEM), Egypt (EGY), Jordan (JOR), and Syria (SYR). There are nine other countries represented (Morocco, Pakistan, Palestine, Ethiopia, India, Iraq, USA, Bahrain, and Sudan), each with less than 0.5%, grouped in the OTHR class.

**Figure 1.** Bar charts showing the demographic information of the dataset.

The demographic features being nominal, we need to convert them into numerical features to use them in machine learning models. We used three different approaches for this purpose.


We experimented with all three approaches. We did not see a major difference in the performance metrics when the Logistic Regression or Random Forest machine learning methods were used. However, we observed a significant decline in performance when the Naive Bayes method was used with either the redundant or the non-redundant dummification, which can be explained by the introduction of new features that are not probabilistically independent. Because of this performance drop, we adopted the third approach in all our experiments.
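The two dummification variants can be illustrated with pandas (a sketch on toy data, interpreting "redundant" as full one-hot encoding and "non-redundant" as one-hot with one category dropped; the category values are illustrative):

```python
import pandas as pd

df = pd.DataFrame({"nationality": ["SAU", "YEM", "EGY", "SAU"]})

# "Redundant" dummification: one indicator column per category
redundant = pd.get_dummies(df["nationality"])

# "Non-redundant" dummification: one column dropped; the dropped
# category is implied when all remaining indicators are zero
non_redundant = pd.get_dummies(df["nationality"], drop_first=True)

print(redundant.shape)      # 4 rows, 3 indicator columns
print(non_redundant.shape)  # 4 rows, 2 indicator columns
```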

**The pre-college features** consist of scores obtained from three national exams. These are numeric scores out of 100. The only preprocessing we applied to these features was standardization (i.e., we shifted the mean to 0 and scaled the standard deviation to 1).
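The standardization step can be sketched as follows (the exam scores below are assumed for illustration):

```python
import numpy as np

# Hypothetical national exam scores out of 100
scores = np.array([62.0, 75.0, 88.0, 95.0, 70.0])

# Standardize: shift the mean to 0 and scale the standard deviation to 1
standardized = (scores - scores.mean()) / scores.std()

print(standardized.mean())  # ~0.0
print(standardized.std())   # ~1.0
```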

**The academic records** are the third group of features, containing all transcript information, including admission term, graduation term, and the letter grades for all courses taken per term, from the preparatory year until graduation. We only use the numerical values of the letter grades, as described in Table 1. Irregular students, as they are very rare, are not included in this study. Thus, per semester, students take the courses shown in Table 2, which is prepared based on the degree plan (the actual degree plans can be found at the links in [10] for the CS program and [11] for the CIS program).

**Table 1.** Conversion table from ordinal to numerical value for letter grades (defined by the university).


**Table 2.** Courses taken by term by regular CCSIT students during the academic years 2011/2012, 2012/2013, and 2013/2014.


The target variable in all the models is the graduation GPA, the weighted mean of the numeric scores of all the courses taken by a student. To draw more meaningful results, we use the graduation GPA not as a numerical feature but as an ordinal feature with three categories determined by the university. A student whose graduation GPA out of 5 is greater than or equal to 4.5 belongs to the class *"High GPA"*, between 4.5 and 3.75 (included) to the *"Average GPA"* class, and less than 3.75 to the *"Low GPA"* class. Figure 2 shows the distribution of three classes in the dataset.

**Figure 2.** The bar chart showing the three classes in the target variable. The GPA Classes 0, 1, and 2 represent the classes Low GPA, Average GPA, and High GPA, respectively.
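The class boundaries described above can be expressed as a small helper (numeric class codes follow Figure 2):

```python
def gpa_class(gpa: float) -> int:
    """Map a graduation GPA (out of 5) to the university's ordinal classes:
    2 = High GPA (>= 4.5), 1 = Average GPA (3.75 to < 4.5), 0 = Low GPA (< 3.75)."""
    if gpa >= 4.5:
        return 2
    if gpa >= 3.75:
        return 1
    return 0

print([gpa_class(g) for g in (4.8, 4.5, 4.0, 3.75, 3.2)])  # [2, 2, 1, 1, 0]
```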

#### *2.2. Experimental Set-Up*

To answer the research question, we develop several classification models (and not regression, as the target variable has been transformed into an ordinal variable in Section 2.1) that differ with respect to (1) the way college performance is defined, (2) the type of student's data included, (3) the machine learning algorithm applied, (4) the size of the performance window, and (5) the size of the observation window (historical data).

In this study, we define the "term performance" of a student in two different ways. In the first representation, *by courses*, term performance is represented by a vector whose size equals the number of courses that should be taken in that term according to Table 2, with components being the numeric scores of the courses. In the second representation, *by GPA*, we represent term performance by the credit-hour-weighted average of the course scores. Comparing the results obtained from the models *by courses* vs. the models *by GPA* allows identifying which of the two (individual course grades or grade averages) is more relevant for predicting a student's graduation academic performance, thus answering the main research question of this study.

Then, the college performance data for students are modeled using two observation window sizes. In the first approach, only the immediate past term's performance is included in the model (either last term *by courses* or last term *by GPA*). In the second approach, a cumulative view is adopted, where all past terms' performance is included in the analysis (either cumulative *by courses* or cumulative *by GPA*). The first approach corresponds to using only one term as the history window, while the second corresponds to using all past terms' data since the student joined the college. The reason for considering models that include only the last term's performance is to isolate the term that is most impactful on the student's success.

Figure 3 shows sample student transcript data for the first 6 terms. For instance, consider predicting the graduation GPA class at the end of the second year, i.e., the end of term 4. If we use a one-term observation history, then we use the term 4 performance alone: either the term's courses as the vector [2.5, 2, 4, 2.5, 4], or the GPA, calculated as (3 × 2.5 + 4 × 2 + 3 × 4 + 4 × 2.5 + 2 × 4)/16 = 2.84375. On the other hand, if we use all past observations cumulatively, then we use all past terms' performance: either all past courses, i.e., the vector [4.5; 4.5; 4; 4; 4; 3.5; 4.75; 4.5; 3.5; 3; 2.5; 2.5; 4.75; 2.5; 2; 4; 2.5; 4; 2; 3; 2; 2.5; 3; 5; 2; 3; 2; 3.5; 3; 4.5], or the accumulated GPAs as [4.3125; 4.1; 3.1; 2.84375].
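The term 4 GPA computation above can be reproduced with a few lines (credit hours and grades taken from the worked example):

```python
# Term 4 of the sample transcript: (credit hours, numeric grade) pairs
term4 = [(3, 2.5), (4, 2.0), (3, 4.0), (4, 2.5), (2, 4.0)]

def term_gpa(courses):
    """Credit-hour-weighted average of the numeric course grades."""
    total_hours = sum(hours for hours, _ in courses)
    return sum(hours * grade for hours, grade in courses) / total_hours

print(term_gpa(term4))  # 2.84375
```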

For investigating the impact of the prediction window, we develop six models at different times of the curriculum, (1) as early as by the end of the first semester of the preparatory year, *term 1*, (2) by the end of preparatory year, *term 2*, (3) after the first semester of the general year, *term 3*, (4) by the end of the first general year, *term 4*, (5) after completing the first term of the second common year, *term 5*, (6) by the end of the common general years, *term 6*. With reference to the student in Figure 3, the above described models correspond respectively to using only the term 1 data, adding one term at a time until all six terms data are used.

**Figure 3.** Sample student transcript data (color code: yellow: 1, orange: 2, light green: 3, dark green: 4, blue: 7 credits each).

We develop machine learning (ML) models using the algorithms commonly applied in EDM, namely Logistic Regression (LR), Random Forest (RF), and Naive Bayes (NB), with accuracy as the performance metric. *Logistic Regression* is a linear model used for classification. It is often the first model considered due to its simplicity and interpretability. *Random Forest* is an ensemble method that fits a number of decision trees and makes a prediction based on the average of the trees' predictions; it is the method most used to predict graduation performance in the literature, as identified in the review paper [12]. *Naive Bayes*, the runner-up method in the literature [12], is a statistical method based on Bayes' Theorem. Its performance depends on the statistical independence of the features.
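Under this set-up, the three algorithms might be instantiated as follows (a sketch with scikit-learn defaults on synthetic data, not the study's dataset):

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.naive_bayes import GaussianNB

# The three algorithms used in this study, with library defaults
models = {
    "LR": LogisticRegression(max_iter=1000),
    "RF": RandomForestClassifier(random_state=0),
    "NB": GaussianNB(),
}

# Synthetic 3-class data standing in for the student records
X, y = make_classification(n_samples=120, n_features=8, n_informative=5,
                           n_classes=3, random_state=0)
for name, model in models.items():
    model.fit(X, y)
    print(name, round(model.score(X, y), 2))  # training accuracy
```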

Finally, in order to investigate the impact of the feature set on model performance, we designed four experiments that exclude some features, as seen in Table 3. Please note that we excluded the four experiments that do not include academic records, as they are not relevant to the goal of this study.



Thus, a total of 288 experiments are conducted. Figure 4 recapitulates them.

**Figure 4.** Experiments Conducted.

### *2.3. Performance Evaluation*

As a base-case model, instead of a simple random guess, we develop naive models based on simple statistics that use only the term performance features and no demographic or pre-college features. The idea behind these models is the following: for every term, we calculate the average term performance across all students. However the term performance is calculated (*Single-Course*, *Single-GPA*, *Cumulative-Course*, or *Cumulative-GPA*), if for a student it is always equal to or above the mean term performance across all students, then that student is classified as a *High GPA* student (i.e., a student always better than the average). Conversely, if it is always below, then the student is classified as a *Low GPA* student (i.e., a student always below the average). All other cases are classified as *Average GPA*. Table 4 illustrates the naive models with three sample students using term performance by current GPAs.

**Table 4.** The table shows how Students A, B, and C are classified by the naive models by the end of Term 4, assuming the means of the first 4 term GPA's are 4, 3.25, 3.5, and 3.75, respectively.
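The naive classification rule can be sketched as follows (class codes 0/1/2 as in Figure 2; the cohort means are the illustrative values assumed in Table 4):

```python
def naive_class(term_gpas, term_means):
    """Classify a student against the cohort means of the past terms:
    always at/above the mean -> 2 (High), always below -> 0 (Low), else 1 (Average)."""
    above = [gpa >= mean for gpa, mean in zip(term_gpas, term_means)]
    if all(above):
        return 2
    if not any(above):
        return 0
    return 1

means = [4.0, 3.25, 3.5, 3.75]  # assumed cohort means for terms 1-4
print(naive_class([4.3, 3.5, 3.6, 3.9], means))  # 2: always at or above the mean
print(naive_class([3.5, 3.0, 3.2, 3.4], means))  # 0: always below the mean
print(naive_class([4.3, 3.0, 3.6, 3.4], means))  # 1: mixed
```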


While developing the naive models, the size of the dataset, 357 samples in total, can be problematic. If we use all the samples to develop our model, we have none left to estimate its true performance (performance on unseen data). Therefore, we split the dataset into training and testing sets: we develop the model on the training set and evaluate its performance on the testing set. Performance indicators obtained with this approach, called the *hold-out technique*, are more realistic. Yet some concerns remain. First, because of the held-out samples, the model does not learn from all available data. To improve learning, we use split ratios with a high training percentage. This creates yet another problem: because the testing sets are small, the performance results may vary significantly. To reduce this variance, we can repeat the training and testing phases, use subsampling methods such as *k*-fold cross-validation, or even use repeated subsampling methods. We decided to use repeated subsampling to minimize the variance in the accuracy scores. For this, the dataset is divided randomly into training and testing sets at a ratio of 4:1. We then calculate the statistics on the training set, classify the samples in the testing set based on those statistics, and record the classification accuracy. Finally, we repeat this experiment 500 times and report the arithmetic mean as the result. Table 5 reports the performance of the naive models.
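The evaluation protocol, repeated random subsampling at a 4:1 ratio, can be sketched as follows; the `classify` callback and all names are illustrative, standing in for whichever naive-model statistics are computed on the training split:

```python
import random
from statistics import mean

def repeated_subsampling_accuracy(samples, labels, classify, n_repeats=500,
                                  train_ratio=0.8, seed=0):
    """Repeated random subsampling (Monte Carlo cross-validation).

    `classify(train_samples, train_labels, test_sample)` is a hypothetical
    callback: it derives whatever statistics the model needs from the
    training split and returns a predicted label for one test sample.
    """
    rng = random.Random(seed)
    idx = list(range(len(samples)))
    accuracies = []
    for _ in range(n_repeats):
        rng.shuffle(idx)
        cut = int(train_ratio * len(idx))          # 4:1 train/test split
        train, test = idx[:cut], idx[cut:]
        tr_x = [samples[i] for i in train]
        tr_y = [labels[i] for i in train]
        hits = sum(classify(tr_x, tr_y, samples[i]) == labels[i] for i in test)
        accuracies.append(hits / len(test))
    return mean(accuracies)                        # arithmetic mean accuracy
```

Averaging over many random splits smooths out the variance that a single small test set would introduce.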


**Table 5.** Results of the Naive Models.

### **3. Experiments and Results**

All experiments are done in Python 3. We build the ML models with Python's scikit-learn library, version 0.22.2, using default hyper-parameters. For the ML models, we also use repeated 5-fold stratified cross-validation with 100 repetitions. Since our goal is to observe the change in performance when the academic features are used either *by course* or *by GPA*, a hyper-parameter search is not relevant. Nevertheless, when we tested random hyper-parameters, we observed the same trends as explained in the Discussion (Section 4). We record the performance of each model on both the training and the testing datasets. All results are reported in the following sub-sections, per scenario.
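The described protocol maps directly onto scikit-learn. A minimal sketch, using the iris toy dataset as a stand-in for the student data and logistic regression with default hyper-parameters (the dataset choice and model here are ours, for illustration only):

```python
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import RepeatedStratifiedKFold, cross_val_score

X, y = load_iris(return_X_y=True)  # stand-in for the student dataset

# 5-fold stratified CV repeated 100 times, as described in the text;
# this yields 500 train/test evaluations per model.
cv = RepeatedStratifiedKFold(n_splits=5, n_repeats=100, random_state=0)
scores = cross_val_score(LogisticRegression(max_iter=1000), X, y,
                         cv=cv, scoring="accuracy")
print(round(scores.mean(), 3))  # mean accuracy over all 500 splits
```

Stratification keeps the class proportions (here, the GPA classes) roughly equal in every fold, which matters for small, imbalanced datasets.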

### *3.1. Scenario 1: Academic Features Only*

The first scenario only includes academic features, and the performance of the obtained prediction models is reported in Table 6. The best performance by the end of term 1, 65.6%, is achieved by the LR method with term performance represented *by courses*, whether single or cumulative. When the prediction is performed later, the best performance is systematically obtained again with the LR method, but with term performance represented *by cumulative GPA*. Please note that GNB shows the same best performance for terms 5 and 6, and the second best for terms 3 and 4. Looking at Figure 5b,d, we observe that the performance of the cumulative models improves from 65.6% (term 1) to 94.9% (term 6) with decreasing prediction window size. Finally, Figure 5a,c show that among the models using only one past term, performance mostly reaches a peak when the prediction is performed by the end of term 4 or term 5.


**Table 6.** Accuracy results of the ML Models only with the academic features.

**Figure 5.** Accuracy plots along the terms of the ML Models only with the academic features.

### *3.2. Scenario 2: Demographics and Academic Features*

The second scenario includes demographic and academic features. The performance of the obtained prediction models is reported in Table 7. The best performance by the end of term 1, 64.4%, is achieved by the LR method with term performance represented *by courses*, whether single or cumulative. When the prediction is performed later, we observe results similar to Scenario 1, i.e., the best performance is mainly obtained with the LR method, with term performance represented *by cumulative GPA*. Please note that GNB shows a slightly superior performance for term 4, and the second best for terms 5 and 6. Looking at Figure 6b,d, we observe that the performance of the cumulative models improves from 64.4% (term 1) to 94.9% (term 6) with decreasing prediction window size. Again, as in Scenario 1, the models using only one past term mostly reach a performance peak when the prediction is performed by the end of term 4 or term 5 (see Figure 6a,c).


**Table 7.** Accuracy results of the ML Models with the academic and demographic features.

**Figure 6.** Accuracy plots along the terms of the ML Models with the academic and demographic features.

### *3.3. Scenario 3: Pre-College and Academic Features*

The third scenario includes pre-college and academic features. The performance of the obtained prediction models is reported in Table 8. The best performance by the end of term 1, 63.2%, is again achieved by the LR method with term performance represented *by courses*, whether single or cumulative. When the prediction is performed later, we observe results similar to the previous two scenarios; in other words, the best performance is mainly obtained with the LR method, with term performance represented *by cumulative GPA*. GNB shows a slightly superior performance for term 6. Figure 7b,d shows that the performance of the cumulative models improves from 63.2% (term 1) to 93.7% (term 6) with decreasing prediction window size. Again, as in the previous two scenarios, the models using only one past term reach a performance peak for predictions performed by the end of term 4 (see Figure 7a,c).


**Table 8.** Accuracy results of the ML Models with the academic and pre-college features.

**Figure 7.** Accuracy plots along the terms of the ML Models with the academic and pre-college features.

### *3.4. Scenario 4: Demographics, Pre-College, and Academic Features*

The last scenario includes all student data, that is to say, demographic, pre-college, and academic features. The performance of the obtained prediction models is reported in Table 9. The best performance by the end of term 1, 62.5%, is achieved by the LR method with term performance represented *by courses*, whether single or cumulative. When the prediction is performed later, by the end of terms 3, 4, 5, and 6, we observe results similar to all the previous scenarios, that is, the best performance is mainly obtained with the LR method, with term performance represented *by cumulative GPA*. Looking at Figure 8b,d, we observe that the performance of the cumulative models improves from 62.5% (term 1) to 93.5% (term 6) with decreasing prediction window size. Models using only one past term mostly reach a performance peak when the prediction is performed by the end of term 4 or term 6 (see Figure 8a,c).


**Table 9.** Accuracy results of the ML Models with the academic, demographic, and pre-college features.

**Figure 8.** Accuracy plots along the terms of the ML Models with the academic, demographic, and pre-college features.

### **4. Discussion**

In this section, we are going to discuss our findings from two perspectives: (1) findings that can be generalized to any dataset and (2) findings specific to our dataset.

Looking at Table 10, which summarizes the best ML models per term, one can see the answer to the research question, namely, which representation of college academic performance is more relevant for predicting graduation academic performance. When the models are compared based on how academic performance is represented, that is, whether *by courses* or *by GPA* as explained in Section 2.2, we notice that, in general, the models where the academic records are used *by GPA* achieve higher accuracy scores in the later terms, whereas the models where the academic records are used *by courses* achieve higher accuracy scores in the earlier terms. To be more precise, we can see in Table 10 that course grades yield better results until the end of term 2, and GPAs give better results afterwards.


**Table 10.** Summary of the best models per term.

However, in an attempt to gain more insight, we analyzed the variance of the training and testing performance of the ML models. We observed that after term 2, the course models all had a higher variance. For example, Figure 9 illustrates the train and test accuracy scores by the number of experiments of all ML models on term 6. In all the plots, it is clearly visible that the difference between the green and the red curves (train and test of the GPA models) is smaller than the difference between the blue and the orange curves (train and test of the course models). This indicates that when we introduce academic performance into the models *by courses* in later terms, the models learn the noise (i.e., overfitting), which results in performance loss. This is not the case for GPA models in later terms, since one GPA per term replaces several (up to six at IAU-CCSIT) individual course grades, thus representing equivalent information with a reduced number of academic performance features. On the other hand, if we introduce academic performance into the model as GPA in earlier terms, then we over-simplify the model by reducing the academic features to a single number per term. We expect this observation to hold for any academic dataset and thus conclude that for earlier predictions academic performance should be used *by courses* and for later predictions *by GPA*.

**Figure 9.** Variance plots of ML models for term 6.

The conducted experiments also allow answering several common EDM questions with respect to the specific IAU-CCSIT dataset. Our experimental results show that all the models perform better than a random guess (which can be considered to be roughly 33.3% for three classes). Moreover, the ML models perform significantly better than the naive models that we defined in Section 2.3 as our baseline. We can conclude that the prediction models can be used as early as the end of term 1, knowing that delaying the prediction, and thus gathering more academic data about the student, will improve the performance of the classifier. Among the ML models, LR shows the best overall performance, with GNB being the runner-up. As for the highest performance, both LR and GNB record 94.9% accuracy using all six current GPAs cumulatively. The poor performance of RFC can be explained by overfitting, which is very clear in Figure 9c,d, where the training performance reaches 100%. Looking at the four scenarios and the corresponding results, we observe that demographics are not a significant group of features, as Scenario 2 never yields the best performance. Academic records alone give the best-performing models when used by the end of term 1, term 3, and term 6. For terms 2 and 5, knowing the pre-college exam results slightly improves the performance. Certainly, these results should be interpreted as specific to the IAU-CCSIT dataset and congruent with the many case studies in the field ([13] and the references therein).

Finally, we can draw conclusions from the models where the academic records are used singly or cumulatively. As expected, when the academic records, alone or with the other features, are used cumulatively, we get the best accuracy scores. Yet we can extract some valuable information from the models where each term's academic performance is used on its own. For instance, from the accuracy graphs we see that the accuracy scores peak at term 4 and stabilize afterwards. Hence, term 4, which is the end of the general year 1, has the maximum impact on the graduation GPA class. We can thus conclude that term 4 is a good moment to start predicting graduation performance. This information can also be shared with students, explaining to them that any extra effort they put into their studies in term 4 will have a higher impact on their graduation GPA.

### **5. Conclusions**

With a plethora of studies in EDM for predicting a student's academic success at graduation time, this study investigated which of individual course grades or grade averages is more relevant for predicting graduation academic performance. Although both types of data are used interchangeably in the literature, there is no study comparing the performance of EDM models using grade averages vs. individual course grades, and it is unknown when and how to use these two college performance representations to attain the best predictive power. To elucidate this matter, a comprehensive set of experiments was conducted on recent student data compiled from the second author's college.

The experimental results show that for earlier predictions, individual course grades should be used to represent academic performance, while it is preferable to use GPAs for predictions after a few terms. Based on a variance analysis, we explain that this helps avoid both oversimplification and noise, each of which can lower the performance of a predictive model. This is a novel contribution to the field of EDM that will enable scientists and educators to decide which representation to adopt depending on the time of prediction.

The second main contribution of this study is the investigation of the individual impact of each semester on graduation academic performance. The results of such an analysis can help identify the best time to make the prediction (e.g., in order not to miss the most impactful term), or can help in advising and motivating students about when to put extra effort into their studies.

**Author Contributions:** Conceptualization, D.D.; Methodology, A.E.T.; Software, A.E.T.; Validation, A.E.T.; Formal Analysis, A.E.T.; Investigation, D.D.; Resources, D.D.; Writing—Original Draft Preparation, A.E.T.; Writing—Reviewing—Editing, D.D.; Visualization, A.E.T.; Supervision, D.D.; Project Administration, D.D. All authors have read and agreed to the published version of the manuscript.

**Funding:** This research received no external funding.

**Conflicts of Interest:** The authors declare no conflict of interest.

### **References**


© 2020 by the authors. Licensee MDPI, Basel, Switzerland. This article is an open access article distributed under the terms and conditions of the Creative Commons Attribution (CC BY) license (http://creativecommons.org/licenses/by/4.0/).

## *Article* **Predicting and Interpreting Students' Grades in Distance Higher Education through a Semi-Regression Method**

### **Stamatis Karlos, Georgios Kostopoulos and Sotiris Kotsiantis \***

Department of Mathematics, University of Patras, 26504 Rio Patras, Greece; stkarlos@upatras.gr (S.K.); kostg@sch.gr (G.K.)

**\*** Correspondence: sotos@math.upatras.gr

Received: 30 October 2020; Accepted: 24 November 2020; Published: 26 November 2020

**Abstract:** Multi-view learning is a machine learning approach aiming to exploit the knowledge retrieved from data represented by multiple feature subsets known as views. Co-training is considered the most representative form of multi-view learning and a very effective semi-supervised classification algorithm for building highly accurate and robust predictive models. Even though it has been implemented in various scientific fields, it has not been adequately used in educational data mining and learning analytics, since the hypothesis about the existence of two feature views cannot be easily satisfied. Some notable studies have emerged recently dealing with semi-supervised classification tasks, such as student performance or student dropout prediction, while semi-supervised regression remains uncharted territory. Therefore, the present study attempts to implement a semi-regression algorithm for predicting the grades of undergraduate students in the final exams of a one-year online course, which exploits three independent and naturally formed feature views, since they are derived from different sources. Moreover, we examine a well-established framework for interpreting the acquired results regarding their contribution to the final outcome per student/instance. To this purpose, a plethora of experiments is conducted based on data offered by the Hellenic Open University and representative machine learning algorithms. The experimental results demonstrate that the early prognosis of students at risk of failure can be accurately achieved, compared to supervised models, even with a small amount of data collected from the first two semesters. The robustness of the applied semi-supervised regression scheme, along with supervised learners, and the investigation of the features' reasoning could highly benefit the educational domain.

**Keywords:** educational data mining; student grade prediction; semi-regression; early prognosis; interpretation; COREG algorithm

### **1. Introduction**

Educational data mining (EDM) has emerged in the past two decades as a highly-growing research field concerning the development and implementation of machine learning (ML) methods for analyzing datasets coming from various educational environments [1]. The key concept is to utilize these methods, extract meaningful knowledge about students' performance, and improve the learning process enriching the insights that the tutor may obtain on time. These methods are grouped into five main categories [2]: Prediction, clustering, relationship mining, discovery with models, and distillation of data for human judgment. The main research interest has been centered on predictive problems primarily concerned with three major questions [3]: (1) What outcome of students will be predicted? (2) Which ML methodology is the most effective for the specific problem? (3) How early can such a prediction be made?

Most EDM research is focused on implementing supervised methods utilizing only labeled datasets. To this end, a plethora of classification and regression techniques have been successfully applied for predicting various student learning outcomes, such as dropout, attrition, failure, academic performance, and grades, to name a few. In addition, the main interest concentrates on building efficient predictive models at the end of a course, using all available information about students [4]. However, it is of practical importance to provide both accurate and early predictions at minimum cost [5]. A review of recent studies and developments in the field of EDM reveals an urgent demand for identifying students at risk of failure as early as possible during the academic year, so that early intervention activities and strategies can be implemented. Preventing academic failure, enhancing student performance, and improving learning outcomes are of utmost importance for higher education institutions that intend to provide high-quality education [6]. Some new directions that have recently formed concern the recognition of errors during the composition or writing of code assessments, usually based on self-attention mechanisms, for providing high-quality automated debugging solutions to undergraduate and postgraduate students, as well as the extraction of remarkable insights about the obstacles they meet during such tasks [7].

Apart from supervised methods, semi-supervised learning (SSL) has gained a lot of attention among scientists in the past few years for solving a wide range of problems in various domains [8]. SSL methods exploit a small pool of labeled examples together with a large pool of unlabeled ones for building robust and highly efficient learning models. However, as a thorough literature review readily shows, SSL has not been adequately used in the educational domain. Nevertheless, some notable studies have emerged recently dealing with semi-supervised classification (SSC) tasks, such as student performance prediction or student dropout, while semi-supervised regression (SSR) remains uncharted territory. The primary difference between SSC and SSR is that the target attribute is categorical in the former case, while a purely numeric quantity has to be predicted in the latter. A recent literature review of SSR depicts the most important works in this field [9], grouping approaches by the common strategy they use to solve their task, while more related works have been demonstrated on the SSC side [10].

Multi-view learning has also attracted the interest of this research community, distilling information from separate views, original or transformed, while the search for more appropriate subspaces within the initial feature set always remains a crucial learning task for boosting the performance of SSL methods [11,12]. Adopting ensemble learners has also been an active research area within SSL [13], and we have demonstrated some similar works ourselves [14,15]. Although some recent advances have taken place, exploiting graph-based solutions [16–18] or deep learning neural networks (DNNs) [19,20] in an attempt to acquire ever more accurate predictions, or even robust ones in cases where noisy inputs/labels violate the ideal of clean training data [21], such mechanisms introduce some important drawbacks:


The main scope of the present study is three-fold. First, we implement a well-known semi-supervised regression algorithm that is based on multi-view learning, adopting several ML learners into its main kernel, to tackle the early prediction of undergraduate students' final exam grades in a one-year distance learning course. Each student is represented by a plethora of features collected from three different sources, thus producing three distinct sets of attributes: demographics, academic achievements, and interaction within the course Learning Management System (LMS). Second, we investigate the effectiveness of the separate SSR variants that are produced, compared with their corresponding supervised performance on the examined EDM task. In this sense, the proposed model may serve as an early alert tool with a view to providing appropriate interventions and support actions to low performers.

Finally, we apply a well-established framework for acquiring trustworthy reasoning scores per attribute/indicator included in the original dataset. Hence, interpretable models are created, providing carefully computed explanations of the predicted grades and ranking the importance of each indicator, without any dimensionality-reduction trick and without over-consuming computational resources in specific cases. To the best of our knowledge, this is the first completed study in this direction [24], which we hope will provide the basis for further research in the field of EDM, as stated in the related-work and concluding sections.

The remainder of this paper is organized as follows. In the next section, we discuss the need for explainable artificial intelligence (XAI) solutions in the field of EDM, highlighting some of the most important approaches to interpreting the decisions/predictions of various learning models and the assets of the selected interpretability framework. Section 3 presents a brief overview of relevant studies in the EDM field and some recently published works related to the SSR task. The research goal is set in Section 4, together with an analysis of the dataset used in the experimental procedure. The total pipeline for applying the well-known COREG algorithm (CO-training REGressors) [25] as an SSR wrapper, along with several ML learners and some DNN variants, is provided in Section 5, which also describes the two distinct explanation mechanisms based on the computation of Shapley values [26]. The experimental process and results are presented in Section 6. Finally, our conclusions are drawn in Section 7, which also mentions some promising improvements to this work.

### **2. Interpretability in Machine Learning**

Consider the problem of predicting the final exam grades of students enrolled in a distance learning course using ML. In this case, a supervised algorithm is trained over a set of labeled data (data whose target attribute values are known), and an ML model is produced (supervised learning), which is subsequently deployed for predicting the grade of a previously unknown student given the values of the input attributes (features of students). The predictive model does not know why the student received the specific grade, and, at the same time, it fails to grasp the difference between anticipated grades and actual performance. Decision-makers are often hesitant to trust the results of these models, since their internal functioning is primarily hidden in black boxes [27]. This is quite reasonable, since people outside the ML field can neither understand the manner in which outputs are produced, nor are they comfortable consuming bare decisions unaccompanied by consistent justification. There is also a well-known trade-off between the predictive ability and the interpretability of ML algorithms, which, in general, prevents both of these properties from being highly developed in the same ML algorithm. Since predictive models play a decisive role in the decision-making process in higher education institutions, the ability to comprehend these models seems indispensable. Thus, the interpretability of the provided solutions usually needs to be supported by XAI tools [28,29].

Model interpretability is the process of understanding the predictions of an ML model. In fact, it is the key point to build both accurate and reliable learning models. In traditional ML problems, the objective is to minimize the predictive error, while interpretability is focused on extracting more valuable information from the model [30]. Commonly, it aims to address questions, such as (Figure 1):


**Figure 1.** Why was a specific prediction made by the model?

Although several works have appeared in the XAI literature recently, the majority of them make assumptions that are not actually consistent with the specifications of an educational task. For example, dimensionality reduction or feature transformations (e.g., semantic embeddings) may lead to incorrect conclusions or reasoning factors that ignore some of the underlying relationships that may be crucial for the real-life problem [31]. Furthermore, DNNs and their variants that operate by manipulating raw data directly have highly attracted the interest of the XAI community, leading to solutions that are not applicable to our numerical source data. However, this fact does not exclude DNNs from being used as accurate black boxes for such problems, mainly by adopting model-agnostic approaches [32]. A representative work was done by Akusok et al. [22], exploiting extreme learning machines (ELMs) trained on sampled subsets of the initial training set to increase the output variance of the learning model, and later explaining the information gained through this strategy via proper confidence intervals for specific confidence levels. Both artificial and real-life datasets were evaluated, exhibiting robust behavior without inducing much computational effort.

Besides DNNs, conventional ML algorithms need to overcome the long-standing obstacle of explainable predictions. One of the most popular libraries is LIME (local interpretable model-agnostic explanations) [33], which offers explanations based on local approximations of the examined learning model. A proper function that measures the interpretability and the local fidelity is defined, which is optimized using sparse linear models fed with perturbed samples from the region of interest. Global patterns are taken into consideration in [32]. A framework of teacher-student models was proposed in Reference [34], where the corresponding explanations are obtained by adopting additional models that mimic the behavior of the target black-box model; their performance is compared with ground-truth-trained models to clarify possible bias factors or to reveal cases where missing information has corrupted the final predictions. Because of the behavior of the adopted models, confidence intervals are also produced for determining the importance of the detected differences.

Linear models and ensembles of trees were used in the previous work, while a solution that exploits some unsupervised mechanisms internally and focuses on exporting small, comprehensible, and more reliable rules from ensembles of trees was proposed by Mollas et al. [35]. Both quantitative and qualitative investigations of the proposed LionForests approach, which is categorized as a local-based one, have taken place with respect to Random Forest (RF) over binary classification tasks. Another work that investigates classification tasks, but specializes in interpreting convolutional neural networks (CNNs), was recently demonstrated in Reference [36], where the Layer-wise Relevance Propagation strategy was applied for extracting meaningful information when usual image transformations of audio signals are given as input. This process has been widely preferred for such networks; it tries to propagate the computed weights of the total network back to the input nodes, transforming them into importance indications.

Regarding the XAI framework we adopted in the context of this work, Shapley values, which stem from coalitional game theory, constitute the basic concept of a more recent approach, named Shapley additive explanations (SHAP), which seems to best fit our research scope [37]. First, it is based on well-established theory and operates without violating a series of axioms: efficiency, symmetry, dummy, and additivity. Without providing an extended analysis, we mention that Shapley values provide helpful insights by measuring the contribution of each feature in the original d-dimensional feature space *F* ⊆ ℝ^*d*. Although this process demands computations exponential in the size of *F*, it is an accurate and safe manner of revealing the actual contribution of each feature, taking into consideration all the underlying dependencies of the measured values, thus assigning a combined profile of both local and global explanations. The exact formula for computing the total contribution of an arbitrary feature *i* ∈ *F* through all the necessary weighted marginal contributions is given here:

$$contribution\_i = \sum\_{S \subseteq F \setminus \{i\}} \frac{|S|!\,(d-|S|-1)!}{d!} \left( payout\_{S \cup \{i\}} - payout\_S \right) \tag{1}$$

$$payout\_S = \int model(F)\, dF\_{F \setminus S} - E\_F\left[ model(F) \right] \tag{2}$$

where each pay-out integrates the predictions of the selected model over the features that belong to the coalition, while the remaining ones are replaced by their mean values. In total, the Shapley values express the contribution of each feature to the difference between the predicted value and the average predicted value. Modifications that obtain the SHAP values more efficiently, reducing the overhead of the original procedure based on statistical assumptions or by exploiting the nature of the base learner, have been implemented. Two such variants were adopted to facilitate the overall efficacy of our methodology [26].
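For small *d*, Equation (1) can be computed exactly by brute force over coalitions. Below is a plain-Python sketch using mean replacement for out-of-coalition features, as in the pay-out definition; it illustrates the formula under a feature-independence assumption and is not one of the optimized SHAP variants used in the study:

```python
from itertools import combinations
from math import factorial

def shapley_values(model, x, background):
    """Exact Shapley values for one instance `x` of a black-box `model`.

    Features outside a coalition S are replaced by their mean over the
    `background` data (independence assumption), matching the pay-out
    definition; cost is exponential in the number of features d.
    """
    d = len(x)
    means = [sum(row[i] for row in background) / len(background)
             for i in range(d)]

    def payout(S):
        z = [x[i] if i in S else means[i] for i in range(d)]
        return model(z)

    phi = []
    for i in range(d):
        others = [j for j in range(d) if j != i]
        total = 0.0
        for k in range(d):                      # coalition sizes |S| = k
            for S in combinations(others, k):
                w = factorial(k) * factorial(d - k - 1) / factorial(d)
                total += w * (payout(set(S) | {i}) - payout(set(S)))
        phi.append(total)
    return phi
```

By the efficiency axiom, the values sum to the model's prediction on `x` minus its prediction with every feature at its background mean; for a linear model, each value is exactly the weighted deviation of that feature from its mean.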

### **3. Related Work**

Semi-regression has not been sufficiently explored in the domain of EDM, as evidenced by a thorough study of the pertinent literature. SSL classification algorithms cannot be directly applied to regression tasks, due to the nature of the target attribute, which is real-valued. Nevertheless, some recent and notable studies are discussed below.

Nunez et al. [38] proposed an SSR algorithm for predicting the exam marks of fourth-grade primary school students. The dataset comprised a wide range of students' information, such as demographics, social characteristics, and educational achievements. At first, the Tree-based Topology Oriented Self-Organizing Maps (TTOSOM) classifier was employed for building clusters exploiting all available data. These clusters were subsequently used for training the semi-regression model, which proved quite effective for handling the missing values directly without requiring a pre-processing stage. The experimental results demonstrated that the proposed algorithm achieved better results in terms of mean errors, compared to representative regression methods. Kostopoulos et al. [39] designed an SSR algorithm for predicting student grades in the final examination of a distance learning course. A plethora of demographic, academic, and activity attributes in the course Learning Management System (LMS) were employed, while several experiments were carried out. The results indicated the efficiency of the SSR algorithm compared to familiar regression methods, such as linear regression (LR), model trees (MTs), and random forests (RFs).

Bearing in mind the aforementioned studies and their findings, an attempt is made in the present study to implement an SSR algorithm for predicting the grades of undergraduate students in the final exams of a one-year online course offered by the Hellenic Open University. The main contributions of our research are the following:


We also include some related works in the SSR field that tackle problems from different domains. Besides the COREG algorithm [25], which inspired most of the subsequent SSR works on how to exploit unlabeled data when predicting numeric target attributes, the co-training scheme did not find wide acceptance in SSR works. We highlight only the direct expansion of COREG designed by Hady et al., who introduced the Co-training by Committee for Regression (CoBCReg) scheme [40], which tries to use more than one regressor to reduce noisy predictions, as well as the co-regularized least squares regression approach (CoRLSR) [41]. The latter formulates a risk minimization problem on the combined space of labeled and unlabeled data through proper kernel methods, focusing mainly on proposing two variants, a semi-parametric and a non-parametric one, that scale linearly in the size of the unlabeled subset. The predictive benefits of adopting the co-training scheme without using any sophisticated feature split, just a random one, were remarkable.

More recently, Liang et al. [42] employed a local linear regressor, iteratively applied to minimize a joint problem on the neighborhood of each unlabeled example through sub-gradient descent algorithms. The authors transformed two datasets stemming from unstructured data into structured problems and outperformed the compared algorithms on every posed performance metric, while remaining competitive in time consumption. A multi-target SSR model was presented in Reference [43], where the self-training scheme was combined with an efficient ensemble of decision tree-based algorithms. Several modifications of the proposed scheme were examined, differing in how the decisions drawn from the corresponding ensemble learner are handled. Although their approach depends heavily on a domain-specific reliability threshold, a qualitative analysis of a dynamic selection managed to outperform both the supervised baseline and a random strategy for selecting unlabeled data to augment the initially collected data. Finally, an SSR method was used before applying an SSL method in the field of optical sensors, where only limited data were readily available. In that scenario, a randomized method generated artificial unlabeled data to augment the labeled subset, but their annotation with pseudo-values was still crucial [44]. Therefore, a typical SSR strategy was applied before providing the finally created dataset to the classification process.

### **4. Dataset Description**

The dataset used in the research was provided by the Hellenic Open University and comprised records of 1073 students who attended the 'Introduction to Informatics' module of the 'Computer Science' course during the academic year 2013–2014.

These records were collected from three different sources: the course database, the teachers, and the course LMS, thus producing three distinct sets of attributes (Figure 2):

• The demographic set S1 = {Gender, NewStudent} (Table 1).

The distribution of male and female students was 76.5% and 23.5%, respectively. In addition, 87.5% of the students had enrolled in the course for the first time, while the rest failed to pass the previous year's final exams.






**Table 3.** LMS activity attributes in the i-th time-period, i ∈ {1, 2}.


**Figure 2.** Gathering the data during the academic year.

Each instance of the dataset represents a single student (Figure 2) and is described by a vector of attributes *x* = (*s*1, *s*2, *s*3), where *s*1, *s*2, *s*3 correspond to the attribute vectors of the sets S1, S2, S3, respectively. Since the early prognosis of students at risk of failure is of utmost importance for higher education institutions, the academic year was divided into four time-periods according to each written assignment submission deadline (Figure 3). To this end, the notation V1i denotes the total number of student views in the pseudo-code forum in the i-th period, i ∈ {1, 2}, and so forth. For example, attribute P21 refers to the total number of student posts in the compiler forum in the first time-period (i.e., from the beginning of the academic year until the first written assignment submission deadline). Finally, the output attribute *y* ∈ [0, 10] represents the grade of students in the final examinations of the course. Note that we examine two distinct scenarios, corresponding to the first one and the first two time-periods, respectively.

**Figure 3.** Time-periods of the academic year.

### **5. Proposed Semi-Supervised Regression Wrapper Scheme**

Semi-Supervised Learning (SSL) is a rapidly evolving subfield of ML, embracing a wide range of high-performance algorithms. Typically, an ML model *h* is built from a training dataset *D* = *L* ∪ *U* consisting of a small pool of labeled examples *L* = {(*x<sub>i</sub>*, *y<sub>i</sub>*)}<sup>*l*</sup><sub>*i*=1</sub> and a large pool of unlabeled ones *U* = {*x<sub>i</sub>*}<sup>*u*</sup><sub>*i*=1</sub>, with *l* ≪ *u*, *x<sub>i</sub>* ∈ X, *y<sub>i</sub>* ∈ Y, *L* ∩ *U* = ∅, without human intervention [45]. Depending upon the nature of the output attribute, SSL is divided into two settings [9]:

• Semi-supervised classification (SSC), where the output attribute is discrete.
• Semi-supervised regression (SSR), where the output attribute is real-valued.
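The labeled/unlabeled split described above can be illustrated with a minimal sketch on synthetic data (the helper `make_ssl_split` is hypothetical, not part of the paper's code):

```python
import numpy as np

def make_ssl_split(X, y, n_labeled, seed=0):
    """Split a training set into a small labeled pool L and a large
    unlabeled pool U, with L ∩ U = ∅ and l << u."""
    rng = np.random.default_rng(seed)
    idx = rng.permutation(len(X))
    lab, unlab = idx[:n_labeled], idx[n_labeled:]
    # the true targets of the unlabeled pool are hidden from the learner
    return X[lab], y[lab], X[unlab]

# toy data mimicking the paper's setting: n = 973 students, grades in [0, 10]
X = np.random.rand(973, 14)
y = np.random.rand(973) * 10
X_L, y_L, X_U = make_ssl_split(X, y, n_labeled=50)
```

An inductive SSR algorithm then trains on *L* and *U* and is evaluated on a disjoint test set.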


In our case, we employed an SSR scheme to exploit both labeled and unlabeled data, trying to acquire accurate estimations of the target attribute—students' final grade—based on a set of readily available data. Thus, one or more regressors are trained iteratively by selecting the most appropriate unlabeled data and annotating their missing target values in an automated fashion. Of course, the initial hypothesis is formed on the manually gathered labeled subset *L*. Furthermore, the fact that the training set is split into two disjoint subsets, *L* and *U*, and that we aim to apply our trained model to another subset—the test set—which does not overlap with the training set, leads us to an inductive SSR algorithm.

The most representative algorithm found in the literature that satisfies these requirements is COREG, first proposed by Zhou [25]. This learning scheme constitutes the analog, for regression, of the co-training scheme based on the disagreement rule in the case of SSC [46], inserting a local criterion for measuring the effectiveness of candidate unlabeled instances for the currently trained model. Although various criteria have been designed in the context of SSC [47,48], the corresponding essential stage of an inductive SSR algorithm has not been widely studied by the research community, which has followed variants of the criterion proposed in COREG or proposed new metrics that are mainly used in single-view works [44,49,50].

More specifically, the main concern of inductive SSR algorithms during the annotation of unlabeled examples is their *consistency* with the already existing labeled instances. This property is examined through the following formula:

$$\text{Consistency}\_{\mathbf{x}\_{j}} = \sum\_{\mathbf{x}\_{i} \in L} \left( f(y\_{i}, h(\mathbf{x}\_{i})) - f\left(y\_{i}, \hat{h}(\mathbf{x}\_{i})\right) \right), \forall \mathbf{x}\_{j} \in U \tag{3}$$

where *f* is a suitable performance metric, *yi* is the actual value of the labeled example *xi*, while *h*(*xi*) and *ĥ*(*xi*) denote the output of regressor *h* when it is trained solely on the current labeled set and on the labeled set augmented with the currently examined example *xj*, respectively. According to the COREG algorithm, a local criterion is inserted to investigate whether the consistency of each unlabeled example is beneficial for the current model per iteration. Thus, instead of examining the whole current *L* subset, only the neighbors of each *xj* ∈ *U* are considered when measuring the corresponding consistency metric described in Equation (3). As discussed in the original work on COREG, by maximizing this variant—denoted hereinafter as δ*xj*, ∀*xj* ∈ *U*—we either safely reach the maximization of the general consistency metric or we obtain a zero value. In the first case, we pick the *j*-th unlabeled instance with the greatest impact; otherwise, we do not select any of them.
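A minimal sketch of this local consistency test is given below, assuming scikit-learn's `KNeighborsRegressor` and MSE as the metric *f*; the function name `delta` and the default `k=3` are illustrative choices, not the paper's exact implementation:

```python
import numpy as np
from sklearn.neighbors import KNeighborsRegressor
from sklearn.metrics import mean_squared_error

def delta(X_L, y_L, x_j, k=3):
    """Local variant of the consistency criterion: change in MSE over the
    k nearest labeled neighbours of x_j after pseudo-labelling x_j."""
    h = KNeighborsRegressor(n_neighbors=k).fit(X_L, y_L)
    y_hat = h.predict(x_j.reshape(1, -1))[0]          # pseudo-label for x_j
    # neighbours of x_j within the current labeled set
    nb = h.kneighbors(x_j.reshape(1, -1), return_distance=False)[0]
    X_nb, y_nb = X_L[nb], y_L[nb]
    # regressor retrained on L augmented with (x_j, y_hat)
    h_aug = KNeighborsRegressor(n_neighbors=k).fit(
        np.vstack([X_L, x_j]), np.append(y_L, y_hat))
    err_before = mean_squared_error(y_nb, h.predict(X_nb))
    err_after = mean_squared_error(y_nb, h_aug.predict(X_nb))
    return err_before - err_after        # δ > 0 ⇒ x_j is consistent
```

Per iteration, the unlabeled instance maximizing a positive δ would be selected, as described above.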

This strategy is similar to fitting an instance-based algorithm, such as k-Nearest Neighbors (kNN) [51], to select the unlabeled instances that augment the current labeled set per iteration, as preferred in the COREG approach. However, this does not hinder us from applying different kinds of regressors on the augmented labeled set, thus exploiting possible advantages of other learning models for better capturing the underlying relationships of the examined data. To the best of our knowledge, such a study has not yet been conducted.

Moreover, the labeled subset augmented per iteration does not contain exclusively accurate target values, since during the training stage pseudo-labeled instances join the initially labeled examples, and their estimated values may differ from the actual ones. This kind of noise may heavily deteriorate the total performance of any SSL scheme, rendering it a myopic approach that cannot guarantee safe predictions and that undermines the interpretation of the exported results.

Therefore, to alleviate the inherent overconfidence of COREG, we examine its efficacy on an EDM task that supports a multi-view description, thus increasing the diversity of the trained regressors. Since the COREG algorithm is based on the co-training scheme, the feature space *F* of the original problem *D* is split into two disjoint views: *F* = *F*<sup>1</sup> ∪ *F*2. Although a random split has proven quite competitive in several cases [52,53], the co-training scheme should work best if these two views are independent and sufficient.

The examined real-world problem brings a multidimensional and multi-view description that encourages us to train each regressor on a separate view and obtain trustworthy predictions that harm neither the predictiveness nor the interpretability of our learning model, despite the limited labeled data. Algorithm 1 presents the pseudocode of the end-to-end SSR pipeline.

```
Algorithm 1. The extended framework of the COREG algorithm.

Framework: Pool-based COREG(D, selector1, selector2, regressor1, regressor2)
Input:
    • Initially collected labeled L = {(xi, yi)}, i = 1, ..., l, and unlabeled
      U = {xj}, j = 1, ..., u, instances, where D = L ∪ U and L ∩ U = ∅
    • F1, F2: the split of the original feature space F, where F = F1 ∪ F2 and F1 ∩ F2 = ∅
    • Max_iter: maximum number of semi-supervised iterations; f: performance metric
Main process:
    • Set iter = 1, consistentSet = ∅
    • Train selector_i, regressor_i on L(F_i) ∀ i ∈ {1, 2}
    • While iter ≤ Max_iter do
        • For each i ∈ {1, 2} do
            • For each xj ∈ U do
                • Compute δ_xj(f) based on selector_i
                • If δ_xj(f) > 0: add j to consistentSet
        • If consistentSet is empty do
            • iter := iter + 1 and continue to the next iteration
        • else do
            • Find the index j* of consistentSet s.t. j* = arg max_j δ_xj
            • Update U: U ← U \ {x_j*}
            • For each i ∈ {1, 2} do
                • Update L_i: L_i ← L_i ∪ {(x_j*, regressor_~i(x_j*))},
                  where ~i denotes the opposite index of the current one
            • Retrain selector_i, regressor_i on L(F_i) ∀ i ∈ {1, 2}
        • iter := iter + 1
Output:
    • Apply the following rule to each test instance x_test:
        h_COREG(x_test) = 1/2 · (regressor1(x_test) + regressor2(x_test))
```
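A much-simplified executable sketch of this multi-view co-training loop is shown below, assuming scikit-learn kNN regressors for both views. For brevity, the δ consistency test is replaced here by a crude nearest-to-L heuristic, and the function `coreg_sketch` is an illustrative stand-in, not the authors' implementation:

```python
import numpy as np
from sklearn.neighbors import KNeighborsRegressor

def coreg_sketch(X_L, y_L, X_U, view1, view2, max_iter=20, k=3):
    """Minimal pool-based co-training loop in the spirit of Algorithm 1.
    Each regressor is trained on its own view and pseudo-labels one
    unlabeled instance per iteration for its counterpart."""
    views = [view1, view2]
    L = [(X_L.copy(), y_L.copy()), (X_L.copy(), y_L.copy())]
    U = X_U.copy()
    regs = [KNeighborsRegressor(n_neighbors=k) for _ in range(2)]
    for i in range(2):
        regs[i].fit(L[i][0][:, views[i]], L[i][1])
    for _ in range(max_iter):
        for i in range(2):
            if len(U) == 0:
                break
            # crude stand-in for arg-max δ: the unlabeled point closest to L
            d, _ = regs[i].kneighbors(U[:, views[i]], n_neighbors=1)
            j = int(np.argmin(d))
            y_hat = regs[i].predict(U[j][None, views[i]])[0]
            o = 1 - i                       # the opposite view (~i)
            L[o] = (np.vstack([L[o][0], U[j]]), np.append(L[o][1], y_hat))
            U = np.delete(U, j, axis=0)
            regs[o].fit(L[o][0][:, views[o]], L[o][1])
    def predict(X_test):
        # output rule: average of the two view-specific regressors
        return 0.5 * (regs[0].predict(X_test[:, views[0]])
                      + regs[1].predict(X_test[:, views[1]]))
    return predict
```

The final predictor averages the two view-specific regressors, mirroring the output rule of Algorithm 1.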
### **6. Experimental Process and Results**

To conduct our experiments, we exploited the scikit-learn Python library along with its integrated regressors and an implementation for computing the necessary Shapley values [37,54]. In order to systematically examine the efficiency of the extended COREG variant on the problem of early prognosis of students' performance, various choices of instance-based selectors and different learning models for the regressors were chosen. Furthermore, we investigated two separate cases of the total dataset based on the measured indicators: regarding only the first semester (*D*1, first scenario) and only the first two semesters (*D*2, second scenario). Thus, our predictions justify the characterization of an early prognosis task, providing timely predictions using indicators that stem from the initial stages of an academic year. To be more specific, the size and the attributes of each view per dataset scenario are reported here:

• First scenario:

$$D\_1 = F\_1 \cup F\_2$$

$$|F\_1| = 4, \; F\_1 = (\text{Gender}, \text{NewStudent}, \text{Ocs}\_1, \text{Wri}\_1)$$

$$|F\_2| = 10, \; F\_2 = \left(L\_{1}, V\_{11}, V\_{21}, V\_{31}, V\_{41}, V\_{51}, P\_{11}, P\_{21}, P\_{31}, P\_{41}\right)$$

• Second scenario:

$$D\_2 = F\_1 \cup F\_2$$

$$|F\_1| = 6, \; F\_1 = (\text{Gender}, \text{NewStudent}, \text{Ocs}\_1, \text{Wri}\_1, \text{Ocs}\_2, \text{Wri}\_2)$$

$$|F\_2| = 20, F\_2 = (L\_1, L\_2, V\_{11}, V\_{12}, V\_{21}, V\_{22}, V\_{31}, V\_{32}, V\_{41}, V\_{42}, V\_{51}, V\_{52}, P\_{11}, P\_{12}, P\_{21}, P\_{22}, P\_{31}, P\_{32}, P\_{41}, P\_{42})$$

Besides the multi-view role of our extended COREG framework, the diversity of the SSR algorithm is enriched by the fact that the two selectors cannot select the same instance *xj*\* during one iteration, whereas in the initial design of COREG, randomly selected subsamples of the original *U* set were drawn per iteration. Although we also attempted to implement this strategy, our results were constantly worse than when exploiting the full original *U* set. This is probably due to the relatively small size of our total problem *D*, which we hope to address during the next semesters by enriching our collected data.

Regarding the choice of the investigated selectors and regressors for the extended COREG framework, we mention here all the different variants/models included in our experiments:


serves our ambitions nor any great improvement was achieved. More information could be found in Reference [41].

Regarding the rest of the required information about our evaluations, we set *Max\_iter* equal to 100 and the performance metric f ≡ MSE (Mean Squared Error). Moreover, we applied a 5-fold cross-validation (5-fold-CV) evaluation process, while we held 100 of the 1073 instances out as the test set. Consequently, the remaining *n* = 973 instances constitute the *D* set, where the sizes of the *L* (*l*) and *U* (*u*) subsets sum up to *n*. We examined four different values for the number of initially labeled instances: 50, 100, 150, and 200, while all of the remaining instances were exploited from the first iteration as the *U* subset, since, as already mentioned, random sampling of the total *U* subset per iteration did not favor us. Finally, the scenario under which our selectors exploit the kNN algorithm with (k1, k2) = (1, 1) did not manage to detect instances satisfying the consistency restriction described in Equation (3) in the majority of the conducted experiments and was, for this reason, excluded from our results. The performance of the examined COREG variants based on the mean absolute error (MAE) metric is presented in Tables 4 and 5.
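The evaluation protocol above can be sketched as follows, a hypothetical skeleton (supervised baseline only, without the SSR step) showing the 100-instance holdout, the 5-fold CV, and an *l*-sized labeled pool inside each training fold:

```python
import numpy as np
from sklearn.model_selection import KFold, train_test_split
from sklearn.neighbors import KNeighborsRegressor
from sklearn.metrics import mean_absolute_error

def evaluate(X, y, n_labeled=50, seed=0):
    """Hold out 100 instances as the test set, run 5-fold CV on the rest,
    and train a baseline regressor on an n_labeled-sized subset L of each
    training fold, reporting the mean MAE on the holdout."""
    X_tr, X_te, y_tr, y_te = train_test_split(
        X, y, test_size=100, random_state=seed)
    maes = []
    kf = KFold(n_splits=5, shuffle=True, random_state=seed)
    for tr_idx, _ in kf.split(X_tr):
        lab = tr_idx[:n_labeled]          # the initially labeled subset L
        model = KNeighborsRegressor(n_neighbors=3).fit(X_tr[lab], y_tr[lab])
        maes.append(mean_absolute_error(y_te, model.predict(X_te)))
    return float(np.mean(maes))
```

An SSR run would replace the baseline fit with the co-training loop before scoring the holdout; the tables then report the relative improvement over this supervised baseline.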


**Table 4.** Relative improvement of mean absolute error (MAE) metric (±std) of the dataset based only on the first semester during the best iteration per different combination of selector and regressor.



To be more specific, in these tables we have recorded the relative improvement between the performance of each regressor on the initially provided labeled set and at the iteration with the best recorded performance before either *Max\_iter* is exceeded or the consistency criterion is no longer satisfied. The results indicate that the MAE metric decreases as the number of labeled instances increases, as could be expected. Based only on the information regarding the first semester, the best performers are LR and MLP for size(*L*) = 50, while the tree-based learners achieved a more stable improvement over all the examined initially labeled subsets. Based on the information regarding both the first and the second semester, the best performers are again LR and MLP for size(*L*) = 50, while they also showed greater improvement in the rest of the examined scenarios compared to their behavior in the previous case.

Additionally, we observe that as the cardinality of the *L* subset increases, the relative improvement of the investigated multi-view SSR approaches decreases in both cases for the majority of the recorded results. This helps us better understand the benefits of SSR approaches such as COREG for multi-view problems, even when both the labeled data and the volume of unlabeled data are highly restricted, which reduces the informativeness of this source of knowledge that is crucial for the SSL scenario. Hence, the most important asset of transforming the COREG approach into a multi-view SSR variant is the remarkable reduction of the mean absolute error under strict conditions on the initially provided labeled instances. Although supervised learning performance in such cases is usually poor, since it depends heavily on the initially labeled data, both the insights obtained through the distinct, independent views and the disagreement mechanism that interchanges information between the regressors fitted to these views lead to superior performance. Therefore, we believe that this is our most important contribution: proof that, in a real-life scenario, the complementary behavior of two separate views can be a trustworthy solution, even with highly limited labeled instances and without a large pool of unlabeled ones.

Another key observation is that by mining additional unlabeled instances, we would expect even larger improvements in some cases, since some approaches achieved their best performance at late iterations, while almost no approach recorded its best performance during the early iterations. Thus, we are confident that providing additional unlabeled instances should yield even better improvements. Another interesting point to examine in the future is inserting a dynamic stage for terminating such a learning algorithm, avoiding saturation phenomena. A validation set could be useful, but the small cardinality of a real-life dataset does not favor such a strategy.

Furthermore, in the majority of the presented results, we conclude that when the selectors coincide with two 3NN algorithms, larger improvements of the relative error are recorded, especially for the more accurate models: the GB-based variants and MLP. This happens because, in the majority of the cases where one selector coincides with the 1NN algorithm, that view, through its fitted regressor, does not detect any unlabeled instance that satisfies the consistency criterion. Hence, the other view is not actually enriched with annotated unlabeled instances. However, in the case of weaker regressors—3NN and LR—this behavior may prove beneficial when noisy annotations take place, thus reducing the chances of degeneration. For completeness of our experimental procedure, all our results are included at the following link: http://ml.math.upatras.gr/wp-content/uploads/2020/11/mdpi-Applied-Sciences-math-upatras-2020.7z, where the index of the best position per examined fold, along with the improvement at the arbitrarily selected value of *Max\_iter*, is recorded per regressor based on the separate views, as well as for the finally exported one. Furthermore, the supervised performance on the whole dataset *D* for both cases and each investigated regressor, as well as their performance for all four initial sizes of *L*, are included—informing each interested researcher about the efficacy of our approaches.

Regarding the interpretability of our results, we computed the Shapley values for each of the five distinct regressors. To safely conclude that the COREG scheme can produce trustworthy explanations under limited labeled data per different learner, we made the following comparison: we compared the purely supervised decisions on the total dataset, evaluated with the aforementioned 5-fold-CV process per learner, with the corresponding decisions exported by training the same regressor on the finally augmented *L* subset according to the adopted COREG scheme, having fixed the choice of selectors to (3NN, 3NN) with the pre-defined distance metrics mentioned previously in this Section. Encouragingly, in all cases we obtained sufficiently similar decisions regarding the importance weights assigned to each indicator, while we had a perfect match between the rankings of the indicators. This fact verifies our main scope: to apply a multi-view SSR scheme that can improve the initial predictiveness of the model despite the limited number of provided instances, while acquiring trustworthy explanations about the importance of each included attribute.

Next, we present through suitable visualizations the SHAP values per case, exploiting the implementation provided by the authors of Reference [55]. Before this stage, a short description is given of the two approaches used for computing these explainable weights, which approximate the actual, but computationally expensive, Shapley values. First, a kernel-based approach was applied to all five examined regressors (KernelSHAP); it is agnostic to the applied learning model and introduces a linear model fitted over sampled pairs of (data, targets) and their generated weights. To generate these weights, several coalitions over the *F* space are produced, while the marginal distribution, instead of the exact conditional distribution, is sampled to replace the features that are absent from a random coalition. Although these assumptions may lead to poor results because the randomly selected coalitions ignore some feature dependencies, since a linear regression is applied during the last stage of the computation, additional strategies may easily be implemented to smooth possible defects of this approximation (regularization, different learning models). On the other hand, a tree-based approach (TreeSHAP) was applied in the case of the GB-based approaches, trying to detect possible discrepancies in the explanations for this kind of learner. TreeSHAP constitutes an expansion of the KernelSHAP approximations, leading to faster results and facilitating learners based on Decision Trees, integrating aggregating behavior through proper additive properties. Further information is provided in the original work [55].
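To make the marginal-expectation idea concrete, the following sketch computes exact Shapley values for one instance of a tiny model by enumerating all coalitions; this is the quantity KernelSHAP approximates by sampling (the function `shapley_values` is illustrative, not the SHAP library's API):

```python
import numpy as np
from itertools import combinations
from math import factorial

def shapley_values(predict, x, background, n_features):
    """Exact Shapley values for one instance x.  Features absent from a
    coalition S are replaced by the background value, mirroring the
    marginal-expectation replacement used by KernelSHAP."""
    phi = np.zeros(n_features)
    features = list(range(n_features))
    for i in features:
        rest = [f for f in features if f != i]
        for r in range(len(rest) + 1):
            for S in combinations(rest, r):
                # classic Shapley coalition weight |S|!(n-|S|-1)!/n!
                w = factorial(len(S)) * factorial(n_features - len(S) - 1) \
                    / factorial(n_features)
                z_with, z_without = background.copy(), background.copy()
                z_with[list(S) + [i]] = x[list(S) + [i]]
                z_without[list(S)] = x[list(S)]
                phi[i] += w * (predict(z_with) - predict(z_without))
    return phi

# demo: for a linear model with a zero background, the Shapley value of
# feature i equals coefficient_i * x_i
phi = shapley_values(lambda z: 2 * z[0] + 3 * z[1],
                     np.array([1.0, 1.0]), np.zeros(2), 2)
# phi ≈ [2.0, 3.0]
```

The enumeration is exponential in the number of features, which is why sampling (KernelSHAP) or tree-specific shortcuts (TreeSHAP) are used in practice.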

We present here only the corresponding diagrams of GB (ls) with both SHAP explainers, omitting GB (huber) due to its sufficiently similar performance, since GB is the only tree-based regressor. The SHAP visualization plots (Figures 4–8) illustrate the attribute impact on the output of the produced regression model (the attributes are ranked in descending order from top to bottom) and how the attribute values impact the prediction (red color correlates with a positive impact) in the first scenario using the *D*<sup>1</sup> dataset. Attributes Wri1 (grade in the first written assignment), Ocs1 (presence in the first optional contact session), and V31 (number of views in the module forum) are the most important ones in all cases, regardless of the regressor employed. In addition, these attributes seem to positively influence the target attribute (i.e., the student grade in the final examinations). Therefore, high-achieving students in the first written assignment, students with high participation rates in the first optional contact session, and students with high view rates in the module forum achieve a higher grade in the final course exam. Very similar results were produced for the second scenario using the *D*<sup>2</sup> dataset. In this case, attribute Wri2 (grade in the second written assignment) proved to be the most significant, along with attribute V32 (number of views in the module forum).

**Figure 4.** KernelSHAP values of the 5NN regressor (*D*<sup>1</sup> dataset).

**Figure 5.** KernelSHAP values of the LR regressor (*D*<sup>1</sup> dataset).

**Figure 7.** TreeSHAP values of the GB (ls) regressor (*D*<sup>1</sup> dataset).

**Figure 8.** KernelSHAP values of the MLP regressor (*D*<sup>1</sup> dataset).

### **7. Conclusions**

In the present study, an effort was made to build a highly accurate semi-supervised regression model based on multi-view learning for the task of predicting student grades in a distance learning course. Additionally, we sought to gain insights and extract meaningful information from the model, interpreting the predictions made and providing computed explanations of the predicted grades. The experimental results demonstrate the benefits brought by a natural split of the feature space. Therefore, our work contributes a different perspective to the existing single-view methods by fully exploiting the potential of different feature subsets through extending the COREG framework to the multi-view setting. In addition, it points out the importance of specific attributes that heavily influence the target attribute. Finally, the produced learning model may serve as an early alert tool for educators aiming to provide targeted interventions and support actions to low performers.

Generating synthetic data could prove a highly favorable technique for mitigating the problem of limited labeled data. A recently demonstrated work adopted such a strategy for training a boosting variant of the self-training scheme in the context of SSC [56]. In that work, the concept of Natural Neighbors was preferred, applying the kNN algorithm as the base classifier, and the obtained results seem encouraging enough to try extending that work to our case as well. Another future direction could be applying pre-processing stages that may help us better discriminate the initially gathered data. A combination of semi-supervised clustering with conventional learners, ensembles, or even DNNs—as has been validated in other real-life cases (e.g., geospatial data [57], medical image classification [58])—reducing inherent biases and helping us uncover underlying data relationships before applying the learning model, could prove quite useful in practice. Another possible effect of clustering has been highlighted in Reference [50], where this strategy facilitated the scaling of a time-consuming learner over large volumes of unlabeled examples.

Finally, the strategy of transfer learning has found great acceptance across several fields in recent years and could prove beneficial in the case of EDM tasks. The two different aspects of this combination are expressed through either creating pre-trained models based on other learning tasks or enriching the discriminative ability of selected regressors through separate source domains that contain plentiful training data [59,60]. A combination of Active Learning with semi-supervised learning might also find great acceptance, especially in cases where limited labeled data are provided and the available annotation budget is highly bounded [61]. The modification of transductive approaches to operate under inductive learning scenarios also seems a promising idea that combines the accuracy of the former category with the generalization ability of the latter. Such a study was presented in Reference [62] and should be explored for SSR tasks.

**Author Contributions:** The authors contributed equally to the work. All authors have read and agreed to the published version of the manuscript.

**Funding:** This research received no external funding.

**Conflicts of Interest:** The authors declare no conflict of interest.

### **References**


**Publisher's Note:** MDPI stays neutral with regard to jurisdictional claims in published maps and institutional affiliations.

© 2020 by the authors. Licensee MDPI, Basel, Switzerland. This article is an open access article distributed under the terms and conditions of the Creative Commons Attribution (CC BY) license (http://creativecommons.org/licenses/by/4.0/).

## *Article* **Social Media Rumor Refuter Feature Analysis and Crowd Identification Based on XGBoost and NLP**

**Zongmin Li 1, Qi Zhang 1, Yuhong Wang <sup>2</sup> and Shihang Wang 1,\***


Received: 12 May 2020; Accepted: 6 July 2020; Published: 8 July 2020

### **Featured Application: Results of this work can be applied to anti-rumor microblog recommendation decisions for social media platforms, in order to reduce the impact of rumors by promoting the spread of the truth.**

**Abstract:** One prominent dark side of online information behavior is the spreading of rumors. The feature analysis and crowd identification of social media rumor refuters based on machine learning methods can shed light on the rumor refutation process. This paper analyzed the association between user features and rumor refuting behavior in five main rumor categories: economics, society, disaster, politics, and military. Natural language processing (NLP) techniques are applied to quantify the user's sentiment tendency and recent interests. Then, those results were combined with other personalized features to train an XGBoost classification model, and potential refuters can be identified. Information from 58,807 Sina Weibo users (including their 646,877 microblogs) for the five anti-rumor microblog categories was collected for model training and feature analysis. The results revealed that there were significant differences between rumor stiflers and refuters, as well as between refuters for different categories. Refuters tended to be more active on social media and a large proportion of them gathered in more developed regions. Tweeting history was a vital reference as well, and refuters showed higher interest in topics related with the rumor refuting message. Meanwhile, features such as gender, age, user labels and sentiment tendency also varied between refuters considering categories.

**Keywords:** rumor refuter; machine learning; natural language processing; XGBoost; feature analysis

### **1. Introduction**

Because of the widespread popularity of social networks and mobile devices, users are able to immediately exchange information and ideas or access news reports through social media feeds such as Twitter, Facebook, or Sina Weibo [1]. However, the dark side of online information behavior should not be neglected. Due to a general lack of control, incorrect, exaggerated, or distorted information can be easily circulated throughout the networks [2]. This kind of information is defined as a rumor [3], as it has neither publicized confirmation nor official refutation. In all the controversial news stories since Twitter's inception, rumors were found to reach more people and spread deeper and faster than the actual facts [4]. Rumors have been found to affect a country's public opinion [5], lead to economic losses [6], and even cause political consequences [7]. When a sudden crisis breaks out, online rumors become even more popular, seriously disrupting social stability. For example, since January 2020, the novel coronavirus epidemic has spread, and rumors have emerged on the Internet, causing public panic and anger and intensifying social conflicts.

Combating rumors has become an active research area. Much of the research focuses on the rumor itself: its identification [8], spread [9–11], and influencing factors. Deep-rooted human nature is a main driver of the viral spread of rumors; that is, people tend to read/share tweets that confirm their existing attitudes (selective exposure), regard information that is more consistent with their pre-existing beliefs as more persuasive (confirmation bias), and prefer entertaining and incredible content (desirability bias) [4]. Existing research on rumor participants is mainly aimed at influential individuals [12] in social networks. At the public level, the crowd identification of rumor participants is still worth further study.

When people receive a piece of 'news', they may (1) retweet and comment at the same time, or only retweet (spreaders), (2) deny it and spread a corresponding rumor refuting message (refuters), or (3) only comment on it or neglect it (stiflers). Individuals' behaviors are closely related to their attitudes [13]. Lewandowsky et al. argued that rumor refutation information should be adapted in opinion and angle to the characteristics and thinking patterns of different groups of people, avoiding sensitive positions such as political stances and world views [14]. Therefore, given a rumor category, analyzing the characteristics of voluntary refuters and identifying this special group among all rumor participants makes it possible to design targeted rumor refutation strategies based on the characteristics and thinking patterns of refuters. The application value is that the platform can consciously recommend rumor refutation information to them, and even adapt the information to suit their personalities. This is of great significance for expanding the acceptance of real news and suppressing the spread of rumors. Understanding the content of a rumor refutation is a re-learning process, marked by strong subjectivity and group differences. Therefore, netizen feature analysis and crowd identification are critical to overcoming the difficulty of rumor governance. Recent rapid developments in deep learning and machine learning make it possible to extract and process large amounts of unstructured social media data [15–17], and thus to identify different crowds and extract group features. This research topic largely remains unexplored. The only prior work was from Wang et al., who predicted social media rumor refuters only in the disaster category [18].

This study intends to reveal the features of netizens who are willing, without extra incentives, to retweet rumor refutations (refuters) when confronted with rumor refuting messages, and to propose a rumor refuter crowd identification model. Five main rumor refuting microblog categories that can potentially affect social stability are considered: economics, society, disaster, politics, and military [19]. The similarities and differences of rumor refuters across these five categories are compared.

Natural Language Processing (NLP) and XGBoost are the main tools in this research. NLP is a subfield of computer science, information engineering, and artificial intelligence concerned with programming computers to process and analyze large amounts of human (natural) language data [20]. Although NLP is already a mature technology, as far as the authors know, short text similarity and sentiment analysis have not been well applied to combating rumors, especially in association with rumor refuting behaviors. Baidu NLP [21] was applied to quantify the sentiment of a user's microblog content (recent sentiment tendency) and its similarity with the original rumor refuting message (recent interests) as values between 0 and 1, which can also be viewed as probabilities. The higher the value, the higher the probability that the sentiment of the microblog is positive or that the microblog content matches the rumor refuting message.
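Since the scoring inside Baidu NLP is proprietary, the idea of mapping a text to a positive-sentiment probability in [0, 1] can be illustrated with a minimal lexicon-based stand-in. The word lists below are purely hypothetical, and the score defaults to 0.5 when no opinion words are present, mirroring the neutral default used later in the paper:

```python
# Hypothetical opinion-word lists; a real service uses learned models instead.
POSITIVE = {"good", "great", "support", "true", "helpful"}
NEGATIVE = {"fake", "panic", "bad", "rumor", "angry"}

def sentiment_score(text: str) -> float:
    """Return a positive-sentiment value in [0, 1]; 0.5 if no opinion words."""
    words = text.lower().split()
    pos = sum(w in POSITIVE for w in words)
    neg = sum(w in NEGATIVE for w in words)
    if pos + neg == 0:
        return 0.5  # neutral / no evidence
    return pos / (pos + neg)

print(sentiment_score("great and helpful clarification"))  # 1.0
print(sentiment_score("fake news caused panic"))           # 0.0
```

This is only a sketch of the interface, not the service's actual algorithm: whatever the scorer, the downstream features consume a single probability-like value per text.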

XGBoost [22] is a relatively new algorithm that has gained popularity due to its accuracy and robustness. XGBoost utilizes boosting, which trains each new instance to emphasize the training instances previously mis-modeled, yielding better classification results. It builds on classification and regression trees (CART) [23] but redefines the objective function to include both the training loss and the complexity of the trees, which decreases the chance of overfitting. XGBoost is thus a very strong model with high extensibility.
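For reference, the regularized objective that XGBoost minimizes, as formulated by Chen and Guestrin [22], combines the training loss with a tree-complexity penalty:

$$
\mathcal{L}(\phi) = \sum_{i} l(\hat{y}_i, y_i) + \sum_{k} \Omega(f_k),
\qquad
\Omega(f) = \gamma T + \tfrac{1}{2}\lambda \lVert w \rVert^2
$$

where $l$ is the training loss, $f_k$ the $k$-th tree, $T$ the number of leaves in a tree, $w$ the leaf weights, and $\gamma$, $\lambda$ the complexity penalty coefficients.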

In recent years, XGBoost has been widely applied to practical problem solving [24,25]. Wang et al. found XGBoost to be the most efficient machine learning method for disaster rumor refuter identification compared with logistic regression, support vector machines, and random forest [18]. Therefore, this paper chose XGBoost to construct the potential rumor refuter identification model.

The main contributions of this paper can be summarized as follows:

(1) A focus on ordinary social media users (instead of only influential individuals) in the rumor refutation process.

(2) A feature analysis of rumor refuters for five different categories of rumors, which can guide personalized recommendations for social media users and thereby accelerate rumor refutation information dissemination.

(3) An XGBoost-based identification model to identify rumor refuters and extract their significant features.

The remainder of this paper is organized as follows. Section 2 gives our research motivations. Section 3 presents the methods and results, and Section 4 gives the discussion. Section 5 concludes the work and discusses future research applications.

### **2. Motivations**

Our research motivation lies in two aspects.

### *2.1. Decision-Making Support for Rumor Countermeasures*

Identifying rumor refuters based on their features is quite valuable, as social media platforms can then recommend rumor refuting microblogs or messages to them; this group is more likely to spread anti-rumor information and accelerate rumor refutation [18,26]. Although far fewer people tend to refute rumors than spread them [4], this ordinary refuter crowd is still considerable. Due to the potential risk of rumors, it is necessary to develop restraining countermeasures. Most identified countermeasures have focused on blocking rumors and spreading the truth [26]. Current practice has tended to identify influential nodes or opinion leaders to refute rumors, but has neglected the significance of netizens who are willing to retweet rumor refutations without extra incentives to convince irrational followers. Therefore, from the perspective of accelerating the truth dissemination process, this paper employs feature analysis and voluntary rumor refuter crowd identification under the hypothesis that, if these targeted users can be identified and leveraged, it is possible to gain new insights into internet rumor countermeasures.

### *2.2. Adapting User Features into Rumor Control*

Many studies have attempted to identify how the unique features of social network users influence social media behavior. Detecting accurate and targeted user properties is essential for personalized recommendation systems [27]. Multiple social network identities exist, such as microblog authors, stiflers, and retweeters. For example, some researchers collected potential author attributes such as gender, age, regional origin, and political orientation and found some feature-based differences [28]; others differentiated the features of stiflers and retweeters, and concluded that stiflers were more concerned about social relationships while retweeters were more driven by message content [5].

All of this prior research contributed to this paper's feature set construction. In addition, user retweet histories, status, active time, and interests have also been shown to impact retweet behavior; the similarity between the content of the target tweet and a retweeter's past posts [29] and users' subjective feelings [30] have likewise been identified as determining factors. However, little work has been done on adapting these user analyses to rumor control and refuter identification. Previous investigations have involved social media platforms with different user structures (i.e., Twitter and Facebook), so their conclusions cannot be directly generalized to microblog users. Overall, few studies on retweeter attributes have commented on the distinctions between different original microblog types; therefore, this paper analyzes refuter features based on anti-rumor classifications.

### **3. Methods and Results**

In this section, we present the methods and relevant results in detail. The overall framework of the methods is shown in Figure 1. First, the data are collected and cleaned. Second, gender and label frequency comparisons between refuters of the five categories are made, which form a rough refuter portrait. Third, sentiment analysis and short text similarity analysis are conducted; if a value is missing in the microblogs, label information, verified information, or signature, a value of 0.5 is assigned. Finally, based on the trained XGBoost classification model, the refuter feature analysis is conducted.

**Figure 1.** The overall framework of the methods.

### *3.1. Data Collection*

Sina Weibo is a Chinese microblogging (Weibo) website and one of the most popular social media platforms in China, with 431 million monthly active users in 2019. Different from Wechat, which only allows a user to post to certified friends, Sina Weibo is the Chinese equivalent of Twitter in that it allows wider, more open dissemination. Therefore, crawling microblogs on Sina Weibo has high research value for analyses of rumor propagation or rumor refutation spread. All anti-rumor microblogs with a retweet/comment count larger than 100 were collected from October 2018 to April 2019 using a web crawler.

This paper only takes into consideration the anti-rumor microblogs verified, confirmed, and announced by official accounts (police accounts, government agency accounts, and authoritative media accounts). Therefore, the refuters discussed in this paper are those who deliberately spread official accounts' rumor refutation information. As shown in Figure 2, the collected anti-rumor microblogs were classified into five categories based on content [18]: economics, society, disaster, politics, and military, all of which are common rumor types on social media platforms and can result in societal damage. The economics category contained business and entrepreneurial information; the society category covered rumors about social public affairs; the disaster category consisted of distorted information on natural and man-made disasters; the politics category comprised false political messages mainly involving certain political figures, groups, or specific policies; and the military category included rumors about national defense or military affairs.

These five main categories comprised a total of 106 anti-rumor microblogs, of which 45 were related to society, 31 to economics, 20 to politics, seven to disaster, and three to the military, yielding a total of 58,807 user samples. Far more stiflers than refuters were collected, so the task of identifying refuters from the population was inherently an imbalanced classification problem. As this research simulated the refuter identification process and examined the validity of the XGBoost model, testing on a small data subset was considered sufficient to examine the algorithm's performance [31].

**Figure 2.** Microblog and sample quantities.

The users' most recent concerns were strongly associated with their most recent microblogs, and our previous work found that the 11 most recently posted microblogs (topping microblogs included) were reliable predictors in disaster rumor refuter prediction [18]. Except for a few users who had posted fewer than 11 microblogs since registration, 11 microblogs were extracted for each user, i.e., the topping microblog (the sticky microblog) and the 10 most recently posted microblogs. Although the topping microblog might not have been recently posted, it reveals the overall attitude of the user to some extent. Basic user information was also extracted for each user: gender, membership level, location, birthday, verified information, signature or brief introduction, user label, number of microblogs, number of followers, number following, and group number.

As the aim of this research was to identify social media rumor refuter features, information was mined for two groups of people: refuters from the retweet lists and stiflers from the comment lists. Stiflers consist of commenters and users who only view the rumor refuting message; due to the inaccessibility of the viewer list, only commenters were treated as stiflers. Although both refuters and stiflers had viewed the rumor refuting microblogs, their responses were quite different.

### *3.2. Comparison between Refuters*

#### 3.2.1. Gender

Figure 3 compares the gender ratios for the different categories. The gender gap was particularly large in the military and political fields, with male refuters outnumbering females nearly two to one (roughly 150 men vs. 74 women and 1621 men vs. 888 women, respectively). In contrast, in the economics, disaster, and society categories, there were only minor gender differences. In 2018, male users made up 57% of total Weibo users [32]; the results broadly correspond to the Weibo user gender ratio.

**Figure 3.** Male to female ratio in database.

#### 3.2.2. Label Frequency

A user portrait analysis was conducted based on the refuter label information. From the word frequency count, it was possible to roughly depict the refuter features and preferences.

As can be seen in Table 1, economics-related rumor refuters showed high interest in IT, Dig, and investment, with most being young practitioners in the internet or finance industries. The economics-related and politics-related rumor refuters had some common interests (i.e., military, investment, finance, IT, and Dig), and could be the same group of people. For the military-related rumor refuters, the label "military" ranked third, with interest also shown in design and history. The society-related and disaster-related rumor refuters were both interested in education, with the former group having a specific "campus" label and the latter a specific "employment" label. Based on this information, we infer that, for these groups, college students account for a relatively large proportion of refuters.


**Table 1.** Label frequency for the different rumor refuters.

### *3.3. Rumor Refuter Identification*

A crowd identification process was applied in two steps.

Step 1. Convert the textual content into numerical values.

For rumors linked to specific geographical locations, the locations were derived from the original rumor texts. Then, by comparing the location where the rumor "took place" with the location of each user, a value of 1 was assigned to the location feature if the user was in the same province as the rumor, and 0 otherwise.
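This binary encoding can be sketched as follows; the province strings are illustrative, and the prefix match reflects the common "province city" layout of Weibo profile locations rather than a rule stated in the paper:

```python
def location_feature(user_location: str, rumor_province: str) -> int:
    """1 if the user's profile location starts with the rumor's province, else 0.

    Weibo profile locations typically read "<province> <city>", so a prefix
    match on the province name is used here (an assumption for illustration).
    """
    return 1 if user_location.strip().startswith(rumor_province) else 0

print(location_feature("Guangdong Guangzhou", "Guangdong"))  # 1
print(location_feature("Beijing", "Guangdong"))              # 0
```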

Baidu's AipNLP [21], regarded as one of the most advanced Chinese text analysis tools, was applied to convert the textual content into numerical values. The similarities between the rumor refuting microblogs and the user labels, the verified information, the signature, and the most recent 11 microblog contents (including the topping microblog) were transformed into values between 0 and 1. For the sentiment analysis, the emotional inclinations of the user signatures and the most recent 11 microblog contents were also converted into values between 0 and 1. The processed variables are listed in Table 2.
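Baidu's service is a commercial API; as a self-contained illustration of the idea, a simple bag-of-words cosine similarity likewise yields a score in [0, 1] for a pair of short texts. This is a stand-in for the short-text similarity step, not the service's actual algorithm:

```python
import math
from collections import Counter

def text_similarity(a: str, b: str) -> float:
    """Cosine similarity over bag-of-words counts, in [0, 1].

    A stand-in for a short-text similarity service; real Chinese text would
    additionally require word segmentation (e.g. before counting tokens).
    """
    ca, cb = Counter(a.lower().split()), Counter(b.lower().split())
    dot = sum(ca[w] * cb[w] for w in ca)
    norm = (math.sqrt(sum(v * v for v in ca.values()))
            * math.sqrt(sum(v * v for v in cb.values())))
    return dot / norm if norm else 0.0

print(round(text_similarity("flood rumor refuted by police",
                            "police refuted the flood rumor"), 2))  # 0.8
```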

Additional processing steps were applied to the variables to make the classification results more valid and reliable:

(1) The corresponding rumor refuting microblog was deleted if it was one of the 11 most recent microblogs from the user.

(2) A value of 0.5 was assigned to the sentiment or short text similarity score of the microblogs, label information, verified information, or signature if the text was missing.

(3) Words irrelevant to the content of the text but that significantly influenced the result of the sentiment analysis, such as "Comment", "Like", and "Collect", were removed.
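Steps (2) and (3) above can be sketched together; the `scorer` callable is a placeholder for a sentiment or similarity call, and the boilerplate word list mirrors the examples given in the text:

```python
# Interface words that would bias sentiment analysis, per step (3).
BOILERPLATE = {"Comment", "Like", "Collect"}

def clean_text(text: str) -> str:
    """Strip interface boilerplate words before scoring."""
    return " ".join(w for w in text.split() if w not in BOILERPLATE)

def score_or_default(text, scorer):
    """Apply `scorer` to the cleaned text, or fall back to 0.5 if missing."""
    if text is None or not text.strip():
        return 0.5  # missing text, per step (2)
    return scorer(clean_text(text))

print(score_or_default("", lambda t: 0.9))                        # 0.5
print(score_or_default("Like good news Comment", lambda t: 0.9))  # 0.9
```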


**Table 2.** Variables for refuter identification.

Step 2. The XGBoost model was utilized for refuter identification in the different categories.

In this research, the XGBoost model performed a binary classification task: people who retweeted the anti-rumor microblog were treated as rumor refuters and labeled 1, while people who only commented but did not retweet were treated as rumor stiflers and labeled 0.

The samples were randomly divided into two parts: 80% for training the XGBoost model and the remaining 20% for testing the trained model. Two criteria, the F1 Score [33] and the AUC [34], were applied to evaluate the classification results (as shown in Table 3). The F1 score is a measure of a test's accuracy; it considers both precision and recall, making it practical for classification tasks [33]. The AUC is the probability that a randomly selected positive sample is ranked above a randomly selected negative sample; the higher the AUC, the better the classification result [34].
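Both criteria are straightforward to compute from first principles. The sketch below implements the F1 score from confusion-matrix counts and the AUC via its ranking-probability definition (ties counted as 0.5); the score lists are illustrative:

```python
def f1_score(tp: int, fp: int, fn: int) -> float:
    """Harmonic mean of precision and recall."""
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    return 2 * precision * recall / (precision + recall)

def auc_by_ranking(pos_scores, neg_scores) -> float:
    """AUC as P(random positive outranks random negative), ties = 0.5."""
    wins = sum((p > n) + 0.5 * (p == n)
               for p in pos_scores for n in neg_scores)
    return wins / (len(pos_scores) * len(neg_scores))

print(f1_score(tp=8, fp=2, fn=2))                        # 0.8
print(auc_by_ranking([0.9, 0.8, 0.4], [0.7, 0.3, 0.2]))
```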

With a learning rate of 0.05, max\_depth of 10, a subsample value of 0.8, scale\_pos\_weight set according to the proportion of positive and negative samples, and all other parameters kept at their defaults, the model was trained with the Python xgboost package (num\_boost\_round = 300 and early\_stopping\_rounds = 50).
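Assuming scale\_pos\_weight follows the usual negative-to-positive ratio convention (the paper states only that it corresponds to the class proportions), the reported configuration can be sketched as the parameter dictionary passed to the xgboost Python API. The label list is illustrative, and the training call itself appears only as a comment since it requires the prepared data matrices:

```python
def imbalance_weight(labels) -> float:
    """Negative/positive ratio, the usual convention for scale_pos_weight."""
    pos = sum(labels)
    neg = len(labels) - pos
    return neg / pos

labels = [1] * 100 + [0] * 900  # illustrative 1:9 refuter/stifler split
params = {
    "objective": "binary:logistic",
    "eta": 0.05,          # learning rate
    "max_depth": 10,
    "subsample": 0.8,
    "scale_pos_weight": imbalance_weight(labels),
}
print(params["scale_pos_weight"])  # 9.0
# Training would then be:
#   bst = xgb.train(params, dtrain, num_boost_round=300,
#                   early_stopping_rounds=50, evals=[(dvalid, "valid")])
```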

The efficiency of the XGBoost model can also be affected by the number of samples. Therefore, for the disaster, economics, politics, and society categories, as well as for all samples combined, the number of samples (randomly selected each time; this analysis was not applied to the military category because it had only 1055 samples) was gradually increased to examine the influence of sample size on the classification results. During this process, different samples were used for robustness testing and to determine the relationship between sample quantity and the F1 Score/AUC Score, with the overall aim of determining the number of samples needed to obtain a stable, usable F1 Score/AUC Score.

The feature importance was also ranked using the XGBoost model to determine the most important refuter crowd identification features for the different rumor categories. For the most important features, *t*-tests were conducted to identify which individual features, for instance, the number of microblogs or the number of followers, differed significantly between refuters and stiflers.
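A Welch-style *t* statistic, which does not assume equal variances between the two groups, can be computed directly from the two samples; the follower counts below are purely illustrative, not the paper's data:

```python
import math

def welch_t(x, y) -> float:
    """Welch's t statistic for two samples with possibly unequal variances,
    as used here to compare a feature's mean between refuters and stiflers."""
    mx, my = sum(x) / len(x), sum(y) / len(y)
    vx = sum((v - mx) ** 2 for v in x) / (len(x) - 1)  # sample variances
    vy = sum((v - my) ** 2 for v in y) / (len(y) - 1)
    return (mx - my) / math.sqrt(vx / len(x) + vy / len(y))

refuter_followers = [120, 340, 560, 410, 300]  # illustrative
stifler_followers = [80, 150, 90, 200, 110]    # illustrative
print(round(welch_t(refuter_followers, stifler_followers), 2))
```

In practice one would compare the statistic against the t distribution with Welch–Satterthwaite degrees of freedom (e.g. via `scipy.stats.ttest_ind(..., equal_var=False)`) to obtain the p-value.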

Because the F1 Score and AUC curves were similar, only the F1 Score curve is drawn for better readability. As shown in Figure 4, starting with 500 randomly selected samples, the F1 Score gradually increased and then, as the sample quantity increased, stabilized at around 0.75 in all categories (except for the military category, which had only 1055 samples and F1 and AUC scores of 0.65). Even when all sample types were included, the observations remained the same.

**Table 3.** AUC and F1 Scores for each category and all samples.


Therefore, it was concluded that, given sufficient data, the XGBoost model was effective in identifying rumor refuters irrespective of rumor category.

**Figure 4.** Refuter classification results based on XGBoost.

### *3.4. Feature Analysis between Refuters and Stiflers*

The XGBoost model also provided feature importance rankings (see Figure 5). Except in the disaster and political categories, gender was the least important feature in the XGBoost classifications. However, MSe11 and MSm11 (with the topping microblog regarded as the 1st microblog, MSe11 and MSm11 refer, respectively, to the sentiment value of the 10th most recent microblog and its similarity with the original rumor refuting microblog) were among the most important features for all categories.

It was therefore proposed that, if a user had a topping microblog, its emotional inclination and similarity to the original microblog would influence the classification judgement. Therefore, samples with MSe11 and MSm11 not equal to 0.5 (i.e., samples with a topping microblog) were extracted and their MSe1 and MSm1 were tested; the results are shown in Table 4. There were no significant differences between the refuter and stifler values of MSe1 and MSm1 in the politics-, disaster-, and military-related categories. The economics- and society-related rumor refuters' MSm1 values were lower than those of the stiflers. There were no significant differences in MSe1 between stiflers and refuters in the economics-related category but, in the society-related category, the refuters' MSe1 values were higher than those of the stiflers.

The NOM and NOF were also found to be very important features. The *t*-test results in Table 4 show that the rumor refuters' NOM and NOF values were somewhat higher than those of the stiflers, indicating that refuters could be more active. The ML (another measure of user activity) was also somewhat higher for refuters in the economics, politics, and society categories.

**Figure 5.** Feature importance in the different categories.


**Table 4.** *T*-test results for the rumor refuters and stiflers.

Note: the value is shown in bold if there was a significant difference between refuters and stiflers at a 95% confidence level; "\_\_\_" means that the refuter value was lower than that of the stiflers; for users with a topping microblog, the MSe1 and MSm1 refuter and stifler values were compared.

Except for the LSm in the society category and the SSm in the political category, there were significant differences in the VISm, LSm, and SSm between refuters and stiflers at a 95% confidence level. For the economics, disaster, and military categories, the VISm and LSm of refuters were significantly lower than those of stiflers, while refuters had higher SSm values. Similarly, for the society category, the VISm of refuters was significantly lower than that of stiflers, while refuters had higher SSm values. In contrast, in the political category, the VISm and LSm of refuters were significantly higher than those of stiflers.

The average values of MSe(1–11) and MSm(1–11) for refuters and stiflers in the five main categories were calculated and denoted MSe and MSm. As shown in Table 4, there were no significant differences between refuters and stiflers in MSe at a 95% confidence level, except in the economics category (where the MSe of refuters was significantly lower than that of stiflers). The MSm of refuters in all five categories, however, was higher than that of stiflers, indicating that refuters' average short text similarity with the original rumor refuting microblogs was significantly higher than that of stiflers.

According to Table 5, at the 95% confidence level, correlations were confirmed between user behavior (refute/stifle) and user location (whether the user was in the same province in which the rumor-related event occurred) for the disaster, economics, and society related rumor refuting microblogs; in the political category, however, no correlation was found.
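For a 2×2 table of behavior by location, the Pearson chi-square statistic reduces to a closed form. The counts below are illustrative rather than the paper's data; a statistic above the 95% critical value of 3.84 (one degree of freedom) indicates a significant association:

```python
def chi_square_2x2(a: int, b: int, c: int, d: int) -> float:
    """Pearson chi-square statistic for a 2x2 contingency table
    [[a, b], [c, d]], e.g. rows = refute/stifle, cols = same/other province."""
    n = a + b + c + d
    denom = (a + b) * (c + d) * (a + c) * (b + d)
    return n * (a * d - b * c) ** 2 / denom

# Illustrative counts: 30 same-province refuters, 170 other-province refuters,
# 90 same-province stiflers, 210 other-province stiflers.
print(round(chi_square_2x2(30, 170, 90, 210), 2))
```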

**Table 5.** Chi-square test of contingency results between user behavior and user location.


As shown in Table 6, users in the same location as the rumor refuting microblogs were less likely to retweet disaster, economics, or society related rumor refutation information, with same-province users accounting for only 21.56%, 15.67%, and 23.77% of total viewers, respectively. One possible explanation is that these users know the local situation better and do not feel the urge to spread the truth. Therefore, social media platforms can recommend disaster, economics, or society related anti-rumor information to users who are not in the same location as the rumor refuting microblogs.



### **4. Discussion**

Based on a feature analysis of users with different social media behaviors, this study sought to identify potential voluntary rumor refuters and to utilize them in anti-rumor countermeasures: truth propagation and targeted immunization. Because of the growing popularity of social media and the availability of complete user information, it is possible to accurately obtain user features and therefore easier to identify potential refuters. Thus, personalized recommendation services could be provided to trigger the spread of the truth and thereby enhance rumor refutation.

Although previous works have explored the features of retweeters, there have been few studies utilizing these findings to combat rumor spread. This paper extended the scope of current studies: instead of studying the general features of retweeters or of the opposite group, rumor spreaders, it focused on refuters and examined them across five main rumor categories that can affect social stability. Although rumor spreaders and rumor refuters exhibit the same behavior, retweeting, their features differed. In contrast with the conclusion of Vosoughi et al. [4], in which rumor spreaders were found to have fewer followers, we observed that rumor refuters had greater numbers of microblogs and followers; i.e., they were more active. This result could be partially explained by Zhang et al. [35], who found that social relationships and message content were notable driving factors of retweeting behavior. Our findings were also in line with the literature indicating that users mainly retweet to inform others and express themselves, and that retweetability is closely related to the number of followers and the informational value of the tweet content [36]. It can be surmised that, when users have more followers, they tend to be more cautious with their microblog content and thus retweet messages that seem more reliable; rumor refuting messages released by authoritative media are one such type.

Except for the economics category, there were no significant general sentiment tendency differences between stiflers and refuters. However, the microblog contents and signature contents (except for the political type) of refuters showed higher similarity with the original rumor refuting message, a result consistent with Luo et al. [29] and Macskassy and Michelson [37]. In contrast, the similarity between the rumor refuting message and the verified information and user labels was generally lower for refuters (except for the political type), which indicates that users' circles and occupations were not closely consistent with their daily interests on social media. Meanwhile, refuters tended to gather in more developed regions.

There were specific rumor refuter feature variations across the microblog categories that had not been previously detected. The politics- and military-related rumor refuters were generally older, and many of them showed interest in finance and investment. Conversely, younger users were more likely to be economics-, society-, and disaster-related rumor refuters; many of them showed interest in IT & Dig, reading, fashion, education, and employment, labels that matched their ages well. Users in the same province as the rumor seemed less likely to retweet the rumor refuting message. This phenomenon can be explained by the third-person effect [38]. On the one hand, the more negative the event, the more obvious the third-person effect. Owing to an underlying sense of superiority and confidence, people unconsciously believe that negative content exerts greater impacts on others than on themselves, which leads to retweeting behavior that conveys information to others (this is also why the phenomenon was most obvious for disaster related rumors). On the other hand, the strength of the third-person effect is influenced by the geographical distance between the receiver and the information source: the farther the receiver is from the source, the stronger the effect. Therefore, people in other locations considered retweeting the correct message urgent and important, believing that those at both long social and geographical distances would be significantly influenced by the media content.

However, the small microblog sample size may have affected the study's validity to a certain extent, and the study was also limited by the basic variables extracted to characterize the refuter profiles. As the issue of user features could be explored along many dimensions, it is expected that a wider range of features, such as ethnicity, personal preferences, active time, and sociolinguistic features, will be used in future works to model rumor refuters more comprehensively. More empirical studies are also needed to investigate the usefulness and feasibility of the method developed in this paper on other social media platforms, such as Wechat, Facebook, and Twitter, so that it can be incorporated into active applications.

An additional uncontrolled factor is the difficulty of accurately identifying rumors and anti-rumors on Weibo. In the Chinese legal framework, rumors are generally fake news; this definition emphasizes deviation from the truth and the facts. From the perspective of mass communication, a rumor is a statement or piece of news deliberately made up out of thin air; the malicious motives behind the information source might also be considered. In addition to rumors themselves, other forms of information fill Weibo, such as uncertain and speculative information. There is thus no unified definition of rumors across academic perspectives, and no clear judgment criteria for them, so, in practice, many difficulties and problems are unavoidable in the identification of rumors.

The principal purpose of countering rumors is not merely to address their literal meaning, but to dig into the once-hidden problems behind them and solve the underlying deep-seated social problems they reflect, effectively responding to social anxiety. Given that the main bodies in China dealing with rumors, solving social problems, and taking targeted action are mostly government agencies, this paper takes whether a piece of false information was refuted by mainstream official accounts (such as police department accounts, government accounts, and authoritative media accounts) as the criterion for recognizing and identifying rumors and anti-rumors, so as to maximize the distinction between truth and rumor. Such criteria might still lead to bias in rumor/anti-rumor judgments. Further research might add more dimensions and standards for searching for and identifying rumors and anti-rumors on Weibo, for instance, taking into comprehensive consideration the scientific validity, social influence, the poster's subjective intention, and other aspects of a web message.

### **5. Conclusions and Future Work**

The purpose of the current study was to determine the association between user features (including sentiment tendency, recent interests, gender, geographic distribution, age structure, and label frequency) and rumor refuting behavior, so as to identify rumor refuters and counter the dark side of online information behavior by accelerating the dissemination of rumor refutation information.

The findings shown in Table 7 reveal some general features of refuters as well as variations between refuters across different rumor categories: (1) there were more male refuters than female, especially in the politics and military categories; (2) rumor refuters of all categories were highly concentrated in East, North, and South China, particularly in provinces with first-tier cities; (3) when users were from the same geographic locations as the refutation microblogs, they were less inclined to retweet economics, society, and disaster related rumor refutation microblogs; (4) refuters were mainly aged between 18 and 40, with refuters in the politics and military categories being somewhat older than those in the economics, society, and disaster categories; and (5) the politics and society related rumor refuters tended to follow and post relevant information more frequently, as shown by their higher MSm.

On the other hand, as shown in Table 8, there were significant differences between refuters and stiflers: (1) rumor refuters were more active, with higher NOM and NOF; (2) their ML was comparatively higher in the economic, political, and societal categories; (3) in general, the refuters' VISm and LSm were significantly lower than the stiflers' (except for LSm in the societal category), while their SSm was higher; however, refuters in the political category had higher VISm and LSm than the stiflers, and there were no significant differences between the SSm of rumor refuters and stiflers in that category; (4) economic rumor refuters had less positive microblog content sentiment, but refuters tended to have higher MSm in all categories.

Provided that an adequately large amount of data was available, the XGBoost model was broadly applicable in identifying refuters, regardless of differences in rumor categories.
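As a rough illustration of this classification setup, the sketch below trains a gradient-boosted tree classifier on synthetic user features. It uses scikit-learn's `GradientBoostingClassifier` as a stand-in for XGBoost, and both the data and the decision rule are entirely invented; the feature names are only loosely borrowed from the paper's metrics.

```python
import numpy as np
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

rng = np.random.default_rng(0)
n = 2000
# Hypothetical user features loosely named after the paper's metrics:
# NOM (number of microblogs), NOF (number of followers), ML (microblog length),
# SSm (sentiment score), MSm (topic match score).
X = rng.normal(size=(n, 5))
# Synthetic rule: active users with high topic match are "refuters" (label 1).
y = ((X[:, 0] + X[:, 4] + rng.normal(scale=0.5, size=n)) > 0).astype(int)

X_tr, X_te, y_tr, y_te = train_test_split(X, y, test_size=0.25, random_state=0)
clf = GradientBoostingClassifier(random_state=0).fit(X_tr, y_tr)
acc = accuracy_score(y_te, clf.predict(X_te))
print(f"held-out accuracy: {acc:.3f}")
```

With XGBoost installed, `xgboost.XGBClassifier` could be dropped in with the same `fit`/`predict` interface.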

This paper only takes into consideration anti-rumor posts verified, confirmed, and announced by official accounts, but there is still a small chance that a rumor refuted on the platform turns out to be true eventually. In the future, we plan to examine refuter characteristics in a wider range of microblog samples, hoping that larger data will mitigate the possible bias. In addition, we will consider more personalized and individualized features beyond demographic attributes to identify the refuter crowd more precisely. Analysis of the influence and power of refuters is also a focus of future research.


#### **Table 7.** General features of the rumor refuters.

**Table 8.** Features of the rumor refuters compared with the rumor stiflers.


**Author Contributions:** Conceptualization, S.W., Z.L., and Y.W.; Data Curation, S.W., Q.Z., and Y.W.; Formal Analysis, S.W., Z.L., Y.W., and Q.Z.; Investigation, S.W. and Y.W.; Methodology, S.W., Z.L., Q.Z., and Y.W.; Supervision, Z.L.; Validation, S.W., Z.L., Y.W., and Q.Z.; Writing—Original Draft, S.W., Z.L., Y.W., and Q.Z.; Funding acquisition, Z.L. All authors have read and agreed to the published version of the manuscript.

**Funding:** This research was supported by the China Postdoctoral Science Foundation Funded Project (Grant No. 2017M612983), Chengdu Philosophy and Social Science Planning Project (Grant No. 2019L40), and the Fundamental Research Funds for the Central Universities (Grant No. SCUBS-PY-202017).

**Acknowledgments:** We thank the editors and reviewers for their helpful comments.

**Conflicts of Interest:** The authors declare no conflict of interest.



© 2020 by the authors. Licensee MDPI, Basel, Switzerland. This article is an open access article distributed under the terms and conditions of the Creative Commons Attribution (CC BY) license (http://creativecommons.org/licenses/by/4.0/).

### *Article*

## **The Efficiency of Social Network Services Management in Organizations. An In-Depth Analysis Applying Machine Learning Algorithms and Multiple Linear Regressions**

### **Luis Matosas-López and Alberto Romero-Ania**


Received: 25 June 2020; Accepted: 24 July 2020; Published: 27 July 2020

**Abstract:** The objective of this work is to detect the variables that allow organizations to manage their social network services efficiently. The study, applying machine learning algorithms and multiple linear regressions, reveals which aspects of published content increase the recognition of publications through retweets and favorites. The authors examine (I) the characteristics of the content (publication volumes, publication components, and publication moments) and (II) the message of the content (publication topics). The research considers 21,771 publications and thirty-nine variables. The results show that the recognition obtained through retweets and favorites is conditioned both by the characteristics of the content and by the message of the content. The recognition through retweets improves when the organization uses links, hashtags, and topics related to gender equality, whereas the recognition through favorites increases when the organization uses original tweets, publications between 8:00 and 10:00 a.m. and, again, gender equality related topics. The findings of this research provide new knowledge about trends and patterns of use in social media, providing academics and professionals with the necessary guidelines to efficiently manage these technologies in the organizational field.

**Keywords:** machine learning algorithms; multiple linear regression; support vector machines; SVM; management; social network services

### **1. Introduction**

The widespread use of the Internet has prompted numerous changes in recent decades. The Internet has transformed all sectors of the economy and society as a whole. In this sense, social network services are one of the best examples of how technology has changed our behavior patterns.

In 2020, the number of Internet users in the world is 4.54 billion, with an average penetration of social network use of 49% [1]. This global aggregate penetration datum obviously varies between countries. Thus, for example, the percentage in India is 29%, in Germany 45%, in the United States (US) 70%, and in South Korea 87% [1]. In the particular case of Spain, two-thirds of the population are regular users of platforms such as Facebook and Twitter [2].

The level of penetration of these technologies has transformed the way people interact with their environment, to the point of making it necessary to create a descriptive term for the typical user of these platforms: "media prosumer". Some authors describe the media prosumer as the subject capable of taking center stage, producing and consuming information in the net [3]. Others define the media prosumer as that user who actively assumes the role of the communication channel, taking advantage of it to become a recommender on different topics [4].

One of the approaches used by researchers to examine the behavior of the media prosumer on social network services is that of the uses and gratifications theory. Although this theory was initially developed to describe how audiences interact with mass media, such as radio, the press, or television [5,6], the power of these technologies to propagate information to large audiences, as traditional media do, makes the uses and gratifications theory especially suitable for contextualizing research in this field.

The conceptual framework defined by the uses and gratifications theory allows researchers to explore how the use of these media (mass media then and social media today) serves to gratify the underlying needs of the audience that uses them. The popularity of this theory is such that in recent years, it has been used to address numerous studies on the use of social network services in all types of contexts and organizations [7–11].

#### *1.1. Social Network Services in Organizations*

The potential of these platforms as a means of gratification for their users has captured the attention of organizations of various kinds. Since social networks began to become popular in the early 2000s, many organizations have used these technologies to gratify the needs of their audiences.

However, social network services not only help organizations gratify the needs of their stakeholders; these platforms have also become interaction tools that serve as an alternative to official websites, an economical means to create user communities around the organization, and, in many cases, an instrument to enhance the brand image of the institution [12].

In this sense, business organizations, on the one hand, and university organizations, on the other, are two of the entities in which social network services have gained the most traction. In both cases, the main objective of the organization is to convey information to their audiences in an agile way through a channel that facilitates dialogue between parties [13]. Therefore, we can say that both companies and universities use these platforms for communication purposes. However, the differential nuance is that, whereas university organizations adhere exclusively to this communicational purpose, business organizations also seek a transactional goal. That is, in the business field, these technologies also pursue the formalization of transactions, whether they are understood as customer acquisition or as selling products or services [14].

Thus, in recent years, platforms, such as Twitter, Facebook, and Instagram, have been integrated into the strategies of many organizations, until becoming, in many cases, the cornerstone of the actions carried out in their communication and marketing departments.

The literature specialized on the topic of the use of social network services within organizations, both in the company and in the university, includes numerous references. Table 1 lists a sample of some of the studies conducted in the last five years.

Regarding the use of social media in the business context, several works can be highlighted [15–19]. Balan's research [15] explored the way in which the topics of the Instagram posts of a major sports equipment brand influenced the recognition received by its publications. Their study revealed significant differences in views, comments, and likes received depending on the topic of publication.

Matosas López [16] analyzed the aspects that condition the propagation of the content of companies in the food sector on Twitter. The author examined the way in which the interactivity of the content (links or mentions), the vividness of the publications (photos or hashtags), the sentiment of the emoticons, and the posting time influenced the dissemination of messages.

Carlson et al. [17] studied the design characteristics of a sample of company pages on the social network Facebook. In this work, the authors observed that the design of the fan page determined the way the client perceived the organization, as well as the client's predisposition to build links with it.


**Table 1.** Studies on the use of social network services in organizations.

The research of Mukherjee and Banerjee [18], based on surveys, analyzed the impact that advertising insertions on Facebook had on the users of the platform. The authors showed that advertising can lead the audience to have a positive attitude towards the brand, also increasing the purchase intention of the products or services of the company.

Giakoumaki and Krepapa [19] analyzed how the contents of luxury brands on the Instagram platform can obtain greater or lesser recognition, depending on whether the publication came from one source or another. The authors found that the recognition when the source of the publication was a personal account was greater than when the content was published by an influencer or by the corporate account of the company.

Finally, Majumdar and Bose [20], applying a multi-period analysis, studied, in a sample of manufacturing firms, the relationship between Twitter-related activities and the company market value. The researchers revealed the existence of positive associations between the distribution of product-related information in this social network and the firm's value.

Among studies that take university organizations as an object of analysis, several works stand out [21–25]. Laudano et al. [21] examined the Twitter presence of a sample of university libraries. Their findings revealed that, although libraries use this platform to disseminate information about collections, services, or the promotion of activities, its use is in general diffuse and poorly planned.

López-Pérez and Olvera-Lobo [24] explored the use of social media technologies for the distribution of research results in public university organizations. The authors confirmed that approximately 40% of the institutions examined used their corporate accounts on Facebook and Twitter to disseminate this type of content.

Cabrera Espín and Camarero [22] analyzed the different communication channels used by a sample of university institutions. Among other results, the researchers addressed that approximately 80% of the students turned to the university Facebook account to learn about the current affairs of their school, even more than on the school's own website.

Kimmons et al. [26], using a wide sample of publications, investigated the institutional uses of Twitter in colleges and universities. Their study suggested that even though these technologies are commonly considered as dialogic platforms, their use, in many cases, remains remarkably monologic, focusing all attention on the unidirectional distribution of information of an institutional nature.

Quitana Pujalte et al. [23] examined the ways universities use their corporate accounts to respond to situations of reputational crisis. The study showed that the university's Twitter profile can be used, in such circumstances, to redirect traffic to the institutional website or to official press releases.

Finally, Wu et al. [25] analyzed the comments that the publications of a sample of universities are capable of generating on Facebook. The authors noted that publications that use a friendly and familiar tone receive a greater volume of comments than those that use a more direct and authoritative tone.

#### *1.2. The Efficiency of Social Network Services Management in Organizations*

As we can see, both business and university organizations use these technologies regularly and for different purposes. However, the keys to be considered by these organizations for developing efficient management of their platforms continue to be debated. Some authors hold that one of the problems in the management of these technologies lies in the lack of professionalization of the work teams [27]. Others point out that the management of social network services in organizations suffers from a lack of strategic planning [21].

The deficiencies in social media management are evident, but academics and professionals nevertheless maintain a firm consensus on which indicator to use to evaluate whether this management is adequate: the recognition that the account's audience gives to its publications when they see their needs gratified.

As soon as the user perceives that the need that had originated his or her connection with the organization has been satisfied, he or she reacts positively by resorting to the relevant functionalities enabled in the platform. According to some authors [9,16,28], the user manifests this recognition of the organization by sharing its content or marking it as a favorite.

Even though the recognition of content, whether in the form of sharing or favoriting publications, seems to be the unquestionable indicator of success for the management of social network services, the way to maximize this indicator is still under research. Fortunately, the enormous volume of information hosted in social network services enables a detailed study of the activity and behavior of its users.

#### *1.3. Objectives*

The millions of interactions that occur daily between organizations and users on these platforms generate millions of terabytes of information. The application of machine learning algorithms and multiple linear regressions allows us to extract the knowledge underlying these immense information banks. The ultimate goal of these techniques is the identification of trends, patterns, or models that facilitate decision-making and allow the organization to manage these technologies efficiently [29,30].

Although some works have previously applied machine learning algorithms and multiple linear regressions to examine activity on social media, the dynamic and changing nature of these spaces requires constant updates of this knowledge [2]. An in-depth analysis of the trends and patterns of use is, without a doubt, the basis on which professionals in this field develop efficient management of these technologies in their organizations.

This research, based on the application of machine learning algorithms and multiple linear regressions, aims to provide information that serves to update this knowledge about the media. Consequently, the main objective of this work is to identify the variables that allow organizations to manage their social network services efficiently.

The object of study is the official Twitter accounts of university organizations in Spain. The social network service Twitter was chosen because of the ease of access to its data. Likewise, the decision to opt for university organizations is due to the purely communicative purpose of these organizations, leaving aside the transactional objective of business organizations. Finally, the selection of Spanish institutions is justified both by the intense social media activity shown by universities in this country and by the variety of publication topics traditionally addressed by their accounts [31].

In this setting, the research analyzes how certain characteristics of the content, on the one hand, and the message of the content, on the other hand, increase the recognition of publications through retweets and favorites.

The characteristics of the content considered are publication volumes, publication components, and publication moments, whereas the effect of the message of the content focuses on the publication topics. This research will, therefore, answer the following two research questions:

Research Question I (RQI): What are the publication volumes, publication components, and publication moments that increase content recognition in the form of retweets and favorites?

Research Question II (RQII): What are the publication topics that increase content recognition in the form of retweets and favorites?

### **2. Materials and Methods**

Applying the postulates of studies on the analysis of social network services, and following the recommendations of Saura et al. [32], this study was organized into three stages: (1) sample design and data extraction, (2) data cleaning and organization, and (3) data analysis. These three stages (see Figure 1) are described below.

**Figure 1.** Stages of the methodology.

### *2.1. Sample Design and Data Extraction*

The researchers used a sample of Spanish university organizations. The selection of sampling elements was based on two of the most recognized rankings for assessing the activity of university institutions: the Webometrics list [33] and the Academic Ranking of World Universities (ARWU), also known as Shanghai ranking [34].

The authors took as their starting point the institutions located in the first fifteen positions of the Webometrics ranking in Spain in 2019 to then check whether these organizations appeared among the global top 500 of the ARWU of that same year. The authors selected only those institutions that rank in the top fifteen in Spain on the Webometrics list and, at the same time, among the top 500 in the world, according to the ARWU. This screening reduced the sample to ten organizations. The institutions were the University of Barcelona, Complutense University of Madrid, Autonomous University of Barcelona, University of Valencia, University of Granada, Autonomous University of Madrid, Polytechnic University of Catalonia, Polytechnic University of Valencia, Polytechnic University of Madrid and Pompeu Fabra University.

Once the sample was selected, the researchers extracted from the Twitter platform all the content published by the official accounts of the ten organizations over a one-year period. Following the procedure of previous studies [23,35], the data were extracted through Twitter's API using the service provider Twitonomy. This process yielded 21,771 publications, together with the recognition obtained by each of them in terms of retweets and favorites.

### *2.2. Data Cleaning and Organization*

The compiled data set was stored for cleaning and organization, extracting a total of thirty-nine variables arranged into six categories: (a) Publication volumes, (b) Publication components, (c) Publication day of the week, (d) Publication time slot, (e) Publication topic, and (f) Recognition obtained by the publication (see Table 2). These six categories, and the variables contained in them, were determined in accordance with previous research.


**Table 2.** Variables extracted from the data set.

The variables gathered in categories (a), (b), (c), (d), and (e) were taken as independent variables, whereas the variables in category (f) were used as dependent variables.

Publication volumes were defined considering the proposal of Bruns and Stieglitz [36]. Publication components were operationalized through the adaptation of post characteristics from De Vries et al. [37]. Publication moment, covering the categories of publication day and publication time slot, was based on the analysis of Valerio Ureña et al. [38] in their study on associations between the moment of publication in social media and the engagement concept. Publication topics were addressed in accordance with the proposal of García [39] in her study on communication management in social network services. Finally, the category representing the recognition obtained by the publication was determined following the recommendations of authors such as Chen [9] or Pletikosa Cvijikj and Michahelles [28], among others.

The independent variables were clearly separated and differentiated from each other. For instance, a publication could be "Original Tweet", "Retweet", or "Reply", but never "Original Tweet" and "Reply" at the same time. Similarly, a message with a unique publication ID can only be posted on a specific day of the week. In the same way, a publication cannot be categorized in two time slots at the same time.
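This mutual exclusivity is exactly what one-hot (dummy) coding of categorical variables produces. A minimal pandas sketch, with hypothetical column names, shows that each publication activates exactly one indicator within each category group:

```python
import pandas as pd

# Toy publication records; column names are invented for illustration.
pubs = pd.DataFrame({
    "pub_id": [101, 102, 103],
    "type": ["Original Tweet", "Retweet", "Reply"],
    "day": ["Monday", "Monday", "Saturday"],
})

# One-hot encoding yields exactly one 1 per category group, so
# "Original Tweet" and "Reply" can never co-occur for the same post.
dummies = pd.get_dummies(pubs[["type", "day"]])
print(dummies.columns.tolist())
```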

Nevertheless, there could be potential correlations between variables placed in different categories. Thus, for example, publication days or publication time slots could be correlated with the topics of publication. This could lead us to think that the variables in the publication topics' category could also be considered as dependent variables. However, this work, in line with previous studies on the efficiency of social media management in organizations, used as dependent variables those commonly taken by the research community when evaluating the recognition of publications [9,16,28]. That is, the variables contained in category (f), Retweeted Pubs. and Favorite Pubs.

### *2.3. Data Analysis*

To carry out the data analysis, the authors used a two steps approach. First, machine learning algorithms were applied for the classification of publication topics (Category e). Second, multiple linear regressions were used to reveal the volumes (Category a), components (Category b), publication moments (Categories c and d), and publication topics (Category e) that increased the recognition of content.

#### 2.3.1. First Step: Machine Learning Algorithms

The authors applied machine learning algorithms to classify the publication topics (Category e). These publication topics would be used as independent variables in the multiple linear regression carried out in the second step of the data analysis.

In the field of social network services, machine learning algorithms are used to conduct categorizations or classifications of text publications [40]. These systems allow organizations to classify thousands or millions of pieces of text efficiently and conveniently for later exploration.

The textual information analyzed using machine learning algorithms is classified as unstructured information. These data do not adhere to a previously defined scheme; therefore, their processing requires the application of certain rules (idiomatic, grammatical, and semantic) to extract the information they contain.

Specifically for the platform under study, the methodologies based on Twitter Analytics approaches addressed by Goonetilleke et al. [41], Kumar et al. [42], and Lin and Ryaboy [43] generally use machine learning algorithms, either to analyze the sentiment of publications or to study specific hashtags. Examples of works focused on the analysis of the sentiment of publications (positive, negative, and neutral) are the studies by Hoeber et al. [44] and Saura et al. [32], whereas examples of investigations focused on the observation of specific hashtags are the works of Lakhiwal and Kar [45] and De Maio et al. [46].

With respect to the techniques used in these methods, the following stood out: decision trees (DT), random forest (RF), Naïve Bayes classifier (NBC), logistic regression (LR), k-nearest neighbors (kNN) and support vector machines (SVM) [47–49]. However, whereas many of these techniques can be effective in determining the sentiment of publications or for hashtag examinations, the most appropriate technique for the classification of complex publication topics, and the one that offers the highest accuracy, is the SVM technique [14].

The SVM technique applied in the present study used the linear kernel function as its classification method. The general kernel function is defined as follows:

$$K(\mathbf{x}_i, \mathbf{x}_j) = \Phi(\mathbf{x}_i) \cdot \Phi(\mathbf{x}_j) \tag{1}$$

where K (x*i*, x*j*) is the kernel function, and Φ (x*i*) represents the mapping of the vector x*i* into the feature space.
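For the linear kernel used here, the feature map Φ is the identity, so Equation (1) reduces to a plain dot product; a quick NumPy check:

```python
import numpy as np

x_i = np.array([1.0, 2.0, 3.0])
x_j = np.array([0.5, -1.0, 2.0])

def linear_kernel(a, b):
    # Linear kernel: Phi is the identity map, so
    # K(x_i, x_j) = Phi(x_i) . Phi(x_j) is just the dot product.
    return float(np.dot(a, b))

print(linear_kernel(x_i, x_j))  # 1*0.5 + 2*(-1) + 3*2 = 4.5
```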

The machine learning algorithm used for the classification of publication topics (Category e) was a supervised machine learning algorithm. With supervised machine learning algorithms, there exists an initial set of already labeled data with input-output pairs that allows for training of the predictive model. From this initial data set, the algorithm learns to assign the appropriate output label to each incoming element in the model [50]. In the case of the specific application of these algorithms in the classification of texts in social media, the labeled data set is typically created, on a small scale, via the intervention of a subject or group of subjects who assign each publication the most appropriate label in each case.

In line with previous research, the classification was performed through the text classification API of the MonkeyLearn library [51,52]. This text classification API uses JSON (JavaScript Object Notation) and also allows the researcher to carry out, before classification, a training process with the algorithm.

After carrying out this training process, manually categorizing 300 publications, the algorithm had the necessary knowledge to develop a personalized machine learning model. This model allowed the classification of the texts of each of the 21,771 publications into one publication topic.
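The overall procedure (training a linear SVM on a small hand-labeled set, then classifying the remaining publications) can be sketched with scikit-learn in place of the MonkeyLearn API. The texts and topic labels below are invented for illustration:

```python
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.pipeline import make_pipeline
from sklearn.svm import LinearSVC

# Tiny hand-labeled set standing in for the 300 manually categorized tweets.
train_texts = [
    "New scholarship call for master's students, apply before June",
    "Grants available to fund undergraduate tuition this year",
    "Our volleyball team wins the university sports championship",
    "Basketball finals this weekend at the campus sports hall",
    "Seminar on machine learning this Thursday in the main auditorium",
    "Conference on applied statistics, registration now open",
]
train_topics = ["Scholarships", "Scholarships", "Sports", "Sports",
                "Seminars and conferences", "Seminars and conferences"]

# TF-IDF features fed into a linear-kernel SVM classifier.
model = make_pipeline(TfidfVectorizer(), LinearSVC())
model.fit(train_texts, train_topics)

pred = model.predict(["Our team wins the basketball match this weekend"])[0]
print(pred)
```

Once trained, the same `predict` call classifies each of the remaining publications into one topic.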

#### 2.3.2. Second Step: Multiple Linear Regressions

The researchers used multiple linear regressions to discover the publication volumes (Category a), publication components (Category b), publication moments (Categories c and d), and publication topics (Category e) that increased the recognition of content.

In the context of social network services, multiple linear regressions focus on quantitative analyses of activity metrics from organizations and users [16].

The information, in the form of metrics that is analyzed by applying multiple linear regressions is generally regarded as structured information. These data are collected in predefined fields and presented using tables of values in which fields and cases are represented in columns and rows, respectively.

The analyses of activity metrics on these platforms can be carried out using techniques such as simple linear regressions (SLR), structural equation modeling (SEM), or even descriptive explorations. Examples of studies that use these techniques are the works of Valerio Ureña and Serna Valdivia [53], Pletikosa Cvijikj and Michahelles [28], or Alonso [54], among others. However, although these techniques are widely accepted, multiple linear regression is probably the most effective technique when the purpose is knowing not only the influence of the independent variables on the dependent ones individually, but also the joint potential of these within the predictive model [16].

The general equation, which is used to represent the multiple linear regression, is expressed as:

$$Y_i = \alpha + \beta_1 X_{i1} + \beta_2 X_{i2} + \dots + \beta_k X_{ik} + \varepsilon_i \tag{2}$$

where α is the constant term of the model, Y*i* is the dependent variable, X*ik* represents the independent variables, β*k* represents the regression coefficients, and ε*i* is the error term (the residual).

The multiple linear regressions allowed us to reveal the volumes, components, publication moments, and publication topics that increased the recognition of content. In this analysis, the authors included all thirty-nine variables considered in the study: thirty-seven acted as independent variables and two as dependent variables. The thirty-seven independent variables corresponded to the categories (a) Publication volumes, (b) Publication components, (c) Publication day of the week, (d) Publication time slot, and (e) Publication topic. The two dependent variables were those in category (f) Recognition obtained by the publication (Retweeted Pubs. and Favorite Pubs.).
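Equation (2) can be fitted by ordinary least squares. The NumPy sketch below recovers known coefficients from synthetic data; the two predictors merely stand in for publication features, and the true coefficients are invented:

```python
import numpy as np

rng = np.random.default_rng(42)
n = 500
# Two synthetic predictors standing in for publication features.
X = rng.normal(size=(n, 2))
# Ground truth echoing Eq. (2): Y = alpha + b1*X1 + b2*X2 + noise.
y = 1.0 + 0.56 * X[:, 0] + 0.45 * X[:, 1] + rng.normal(scale=0.1, size=n)

# Design matrix with a leading column of ones for the constant term alpha.
A = np.column_stack([np.ones(n), X])
coef, *_ = np.linalg.lstsq(A, y, rcond=None)
alpha, b1, b2 = coef
print(f"alpha={alpha:.2f}, beta1={b1:.2f}, beta2={b2:.2f}")
```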

### **3. Results**

#### *3.1. First Step: Machine Learning Algorithms*

The SVM technique applied, using the linear Kernel function as a classification method, reflected the existence of sixteen publication topics: General news, Scholarships, Science and technology, Contests, Culture and exhibitions, Sports, Entrepreneurship, Complementary training, Gender equality, Institutional information, Employability, Research, Seminars and conferences, Awards and recognitions, Health and green environment, and Volunteering.

These publication topics were determined by the researchers and validated by a panel of five judges who were experts on the management of social networks services in different organizational contexts.

In line with previous research, the authors used Krippendorff's alpha to measure the accuracy of the text classification carried out by the supervised machine learning algorithm [32]. The alpha value obtained (0.886), above the recommended threshold of 0.800, indicated that the supervised machine learning algorithm had been properly trained and that its predictive power was sufficiently accurate.
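For nominal labels with no missing ratings, Krippendorff's alpha can be computed from a coincidence matrix as α = 1 − D<sub>o</sub>/D<sub>e</sub>, where D<sub>o</sub> is the observed and D<sub>e</sub> the expected disagreement. The function below is a minimal sketch of that special case (the full estimator also handles ordinal/interval metrics and missing values); the label data are invented:

```python
from collections import Counter
from itertools import permutations

def krippendorff_alpha_nominal(units):
    """Nominal-level alpha for units each rated by >= 2 coders, no missing data.

    `units` is a list of tuples, one tuple of category labels per unit.
    """
    coincidence = Counter()
    for ratings in units:
        m = len(ratings)
        for c, k in permutations(ratings, 2):   # ordered pairs within a unit
            coincidence[(c, k)] += 1.0 / (m - 1)
    n_c = Counter()
    for (c, _k), v in coincidence.items():
        n_c[c] += v
    n = sum(n_c.values())
    d_observed = sum(v for (c, k), v in coincidence.items() if c != k) / n
    d_expected = sum(n_c[c] * n_c[k] for c in n_c for k in n_c if c != k) / (n * (n - 1))
    return 1.0 - d_observed / d_expected

# Two coders, mostly agreeing on topic labels:
labels = [("Sports", "Sports"), ("Science", "Science"),
          ("Sports", "Sports"), ("Science", "Sports")]
alpha = krippendorff_alpha_nominal(labels)
print(round(alpha, 3))
```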

The descriptive exploration of the publication topics points to differences in the recognition obtained by the different topics. Table 3 reveals differences in how the content of each publication topic was recognized and which topics obtained greater recognition.


**Table 3.** Retweets and favorites obtained by publication topic.

This descriptive examination revealed that institutional information and general news were the most recurring topics, accounting for 20.13% and 11.29% of the publications, respectively.

The average number of retweets and favorites received per publication showed that the contents that achieved the greatest recognition were those related to the gender equality topic. Paradoxically, this topic, with the highest retweet and favorite average, represented only 3.22% of all publications. Therefore, it seems to be clear that certain topics get far more recognition than others.

#### *3.2. Second Step: Multiple Linear Regressions*

Two multiple linear regressions were performed: one for the dependent variable Retweeted Pubs. and another for the dependent variable Favorite Pubs. The results of these analyses allowed the identification of the variables that increased content recognition in the form of retweets and favorites within the respective models. To examine the explanatory power of each independent variable, the items of categories (a), (b), (c), (d), and (e) were introduced into their respective model as individual indicators. The researchers applied the stepwise method for incorporating the variables.
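The stepwise idea can be sketched as greedy forward selection. The toy version below adds, at each step, the predictor that most improves adjusted R² (a simplification: statistical-package stepwise procedures typically govern entry and removal by F-test p-values), using invented variable names and synthetic data in which only two predictors truly matter:

```python
import numpy as np

def adjusted_r2(X, y):
    """Adjusted R^2 of an OLS fit with intercept."""
    n, k = X.shape
    A = np.column_stack([np.ones(n), X])
    coef, *_ = np.linalg.lstsq(A, y, rcond=None)
    resid = y - A @ coef
    r2 = 1 - (resid @ resid) / ((y - y.mean()) ** 2).sum()
    return 1 - (1 - r2) * (n - 1) / (n - k - 1)

def forward_stepwise(X, y, names):
    """Greedily add the variable that most improves adjusted R^2."""
    chosen, best = [], -np.inf
    while True:
        scores = {j: adjusted_r2(X[:, chosen + [j]], y)
                  for j in range(X.shape[1]) if j not in chosen}
        if not scores:
            break
        j_best = max(scores, key=scores.get)
        if scores[j_best] <= best:       # stop when no improvement
            break
        best = scores[j_best]
        chosen.append(j_best)
    return [names[j] for j in chosen]

rng = np.random.default_rng(1)
n = 400
X = rng.normal(size=(n, 4))
# Only "Links" and "Hashtags" truly drive retweets in this toy setup.
y = 0.56 * X[:, 0] + 0.45 * X[:, 1] + rng.normal(scale=0.3, size=n)
selected = forward_stepwise(X, y, ["Links", "Hashtags", "Mentions", "Photos"])
print(selected)
```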

In the first regression, the one performed for the Retweeted Pubs., the item "Links" (β = 0.560, *p*-value < 0.0001), was added in the first step of the procedure. The variable "Hashtags" (β = 0.455, *p*-value < 0.005) was introduced in the second step. Finally, the item "Gender equality" (β = 0.447, *p*-value < 0.0001) was added in the third step.

The model for this first dependent variable (Retweeted Pubs.) was significant as a whole (F = 78.341, *p*-value < 0.0001), optimally explaining the variance of the dependent variable with values of R = 0.976 and R<sup>2</sup> = 0.951. Therefore, this first regression showed the impact of the variables "Links", "Hashtags", and "Gender equality" when predicting the recognition of content published through retweets (see Table 4).


**Table 4.** Coefficients of multiple linear regressions.

\* *p*-value < 0.005; \*\* *p*-value < 0.0001.

For the second regression, the one developed for the variable Favorite Pubs., the item "Original Tweets" (β = 0.198, *p*-value < 0.005) appeared in the first step of the process. The variable added in the second step was called "Pub. 8:00 to 10:00" (β = 0.237, *p*-value < 0.005). To finish, the item "Gender equality" (β = 0.531, *p*-value < 0.005) appeared in the third step.

The model for the second dependent variable (Favorite Pubs.) was also significant (F = 311.278, *p*-value < 0.0001), adequately explaining the variance of this variable with values of R = 0.931 and R<sup>2</sup> = 0.917. Therefore, this second regression revealed the influence of the variables "Original Tweets", "Pub. 8:00 to 10:00", and "Gender equality" on the recognition obtained, through favorites, of the content published by the organization (see Table 4).

To corroborate the validity of the above regressions, the authors analyzed the residuals of both using the Shapiro–Wilk test and the Durbin–Watson test.

The Shapiro–Wilk test was performed to see whether the values of the standardized residuals followed a normal distribution. The *p*-values above 0.050 in the two regressions (0.851 for the first and 0.721 for the second) confirmed that the residuals were normally distributed [55]. The Durbin–Watson test served to verify whether the assumption of independence of the residuals was met. The values of this indicator between 1 and 3 in both regressions (1.847 for the first and 1.425 for the second) verified that the requirement of independence of the residuals was satisfied [56]. The values of the Shapiro–Wilk and Durbin–Watson tests confirmed that the predictive models obtained from the multiple linear regressions were adequate and robust.
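Both diagnostics can be reproduced on any fitted model's residuals. A sketch using SciPy's Shapiro–Wilk test and a hand-computed Durbin–Watson statistic (the residuals here are simulated, not the paper's):

```python
import numpy as np
from scipy.stats import shapiro

def durbin_watson(residuals):
    """DW = sum of squared successive differences over the residual sum of squares."""
    r = np.asarray(residuals)
    return np.sum(np.diff(r) ** 2) / np.sum(r ** 2)

rng = np.random.default_rng(42)
resid = rng.normal(size=150)            # simulated, roughly independent residuals

w_stat, p_value = shapiro(resid)
dw = durbin_watson(resid)

# Normality holds if p > 0.05; independence is usually accepted for DW near 2.
print(f"Shapiro-Wilk p = {p_value:.3f}, Durbin-Watson = {dw:.3f}")
```

For truly independent, normal residuals the Shapiro–Wilk *p*-value is large and the Durbin–Watson statistic sits near 2, as in the two regressions reported above.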

### **4. Discussion**

Different authors have highlighted the need for organizations to invest in efficient management in social network services. Some studies suggest that these platforms require professionalized management systems and that their management cannot be left to nonspecialized profiles [12,27]. Other authors claim that the way many organizations handle these technologies lacks strategic vision [21]. Along the same lines, there are studies that indicate that organizations should not settle for using their accounts to build their institutional image but rather must also protect the reputation of the organization [57]. Some authors claim that properly managed, social network services can even serve as a customer acquisition tool [58].

The findings of this study provide academics and professionals with the necessary knowledge to efficiently manage their use of these technologies, enabling organizations to satisfy many of the aforementioned purposes. The results obtained, thanks to the application of machine learning algorithms and multiple linear regressions, allow us to answer the two research questions posed: the one that concerns the characteristics of the content (publication volumes, publication components, and publication moments) and the one related to the message of the content (publication topics).

### *4.1. Volumes, Components, and Publication Moments That Increase Content Recognition (RQI)*

The multiple linear regressions showed that content recognition through retweets was conditioned on the use of links and hashtags in publications, whereas recognition by favorites was fundamentally determined by the frequency of original tweets and a publication time between 8:00 and 10:00 a.m. The influence of these four variables had a positive valence. Thus, greater exploitation of links, hashtags, original tweets, and early-morning publication boosts the recognition achieved by the organization in the form of retweets and favorites.

Such results corroborate, for example, the findings of Túñez López et al. [34] in their work on the use of Facebook and Twitter as communication channels. Those authors highlighted the value of links as an essential element in any message.

Regarding the use of hashtags, the results are in line with those of Guzmán Duque et al. [58] in their study on the impact of the use of Twitter in the organizational field. In this work, the authors highlighted the potential of these markers in facilitating the promotion and projection of the organization to the audience.

As for the frequency of original tweets, the findings of this work corroborate what was indicated by Chen [9] in a study on uses and gratifications on Twitter: a high publication frequency of original content acts as a motivating factor that encourages the subject to interact with other users.

Finally, with respect to the publication moment, the results are aligned with the findings of Hanifawati et al. [59] in their work on the management of corporate Facebook accounts. In that study, the researchers emphasized that messages posted in the most active time slots, those in which the user is more likely to visit the platform, increase both the amount of shared content and the comments received on it.

### *4.2. Publication Topics That Increase Content Recognition (RQII)*

The multiple linear regressions demonstrated that content recognition through retweets and favorites can be influenced using topics related to gender equality. Therefore, the use of publications with this thematic approach increases the recognition achieved by the organization in the form of retweets and favorites.

Although other authors have highlighted the importance of the theme of the publication in the context of social network services in organizations [27,38], few studies examine this issue in depth. Where it has been done, the analyses tend to focus more on the sentiment or tone of the publications [25,60], or on superficial explorations of hashtags or predetermined search terms [34,61]. Consequently, these studies generally ignore most of the text of the publication and the semantics of the expressions contained in it. The present work, thanks to the use of a supervised machine learning algorithm previously trained by the researchers, enabled a highly accurate text classification of the publications. This text classification, carried out in the first step of the analysis, allowed us to identify the publication topics that were later used in the multiple linear regressions performed in the second step of the data analysis.

The analysis carried out by the authors revealed differences in the recognition obtained by the different publication topics. These findings are in line with the findings of Pletikosa Cvijikj and Michahelles [28] in their study on engagement factors in online communities within Facebook. Those authors pointed out that the type of content published by the organization can indeed determine the recognition obtained in its audience.

The findings achieved in the present research revealed that, paradoxically, topics with a smaller weight over the total number of publications, such as those that address topics related to gender equality, were the most successful in terms of recognition by the audience.

### **5. Conclusions**

Although in recent years some works have used machine learning algorithms and multiple linear regressions to examine activity on social media, these studies tend to focus exclusively on content characteristics (publication volumes, publication components, and publication moments) [16,62] or on its message (tone, sentiment, or publication topic) [60,63]. However, few studies have examined both topics simultaneously.

Perhaps the most emblematic of these studies is that of Pletikosa Cvijikj and Michahelles [28]. Their work, like the present one, considered the characteristics of the content (publication volumes, publication components, and publication moments) and the message of the content (publication topics). Nevertheless, their study analyzed a data set smaller than that of the present study and applied a text classification with only three publication topics (information, entertainment, and rewards), as opposed to the sixteen considered here.

The present research not only combines a study of the characteristics (volumes, components, and moments) and the message (topics) of the publications, it also addresses this challenge more comprehensively than previous works. In the authors' opinion, the examination of both the characteristics and the message of the publications, together with the two-step analysis approach applied in the investigation, are among the key contributions of the current study. The supervised machine learning algorithm applied in the first step allowed the classification of the texts into publication topics. With the publication topics identified, in addition to the publication volumes, components, and moments, the authors applied multiple linear regressions to discover the influence of all these variables on the recognition of content.

The results gathered to answer the first research question (RQI) are in line with the findings of previous studies, whereas the results obtained for the second question (RQII) provide novel and relevant information on the current field of investigation.

On this second point, regarding the publication topic, the findings confirmed that publications on topics of gender equality achieve much higher recognition than content focused on other topics. In the opinion of the authors, this situation is conditioned by the recent social sensitivity around this issue. Likewise, given that no organization goes uninfluenced by social issues of this nature, the knowledge derived from these results in the context of university organizations is likely to extend to the field of business organizations.

On the other hand, this finding can prompt reflections on the importance of media professionals' managing these technologies adequately and being able to identify, at all times, the topics of interest to their audience, adapting the content of their organizations to these preferences.

In view of all the above, the authors confirm the value of applying machine learning algorithms and multiple linear regressions to carry out an in-depth analysis of the enormous amount of information generated by social network services, gaining new knowledge about trends and usage patterns in the media. This knowledge will ultimately be the basis for the efficient management of social network services in the organizational field.

### **6. Limitations and Further Research**

This paper also has several limitations. The sample, although significant, could be expanded to examine the observed phenomena more deeply.

In addition, future research could complement the analyses carried out in the present study with other analytical approaches. A work like the present one could be complemented, for example, by using network centrality analysis [42] or OLAP (On-Line Analytical Processing) techniques [64,65].

Centrality analyses, often supported by open-source frameworks such as JUNG (Java Universal Network/Graph), are used to identify the most important users in the network, revealing graphically who receives the most retweets (degree centrality), who is the most influential user (eigenvector centrality), or the number of shortest paths through which a user distributes information (betweenness centrality).

Likewise, OLAP techniques allow the extraction of information related to user behaviors, emerging topics, or trends, providing generic multidimensional models for the analysis of data on social network services.

The aforementioned issues address new avenues for research in this field, confirming that further investigation is still needed to expand our understanding of the activity on social media.

**Author Contributions:** Conceptualization, methodology, formal analysis, investigation, data curation, writing—original draft, project administration, L.M.-L.; validation, writing—review and editing, visualization, supervision, A.R.-A. All authors have read and agreed to the published version of the manuscript.

**Funding:** This research received no external funding.

**Conflicts of Interest:** The authors declare no conflict of interest.

### **References**


© 2020 by the authors. Licensee MDPI, Basel, Switzerland. This article is an open access article distributed under the terms and conditions of the Creative Commons Attribution (CC BY) license (http://creativecommons.org/licenses/by/4.0/).

## *Article* **Research on Sentiment Classification of Online Travel Review Text**

### **Wen Chen 1,2, Zhiyun Xu 1, Xiaoyao Zheng 3, Qingying Yu <sup>3</sup> and Yonglong Luo 1,3,\***


Received: 19 June 2020; Accepted: 27 July 2020; Published: 30 July 2020

**Abstract:** In recent years, the number of review texts on online travel review sites has increased dramatically, which has provided a novel source of data for travel research. Sentiment analysis is a process that can extract tourists' sentiments regarding travel destinations from online travel review texts. The results of sentiment analysis form an important basis for tourism decision making. Thus far, there has been minimal concern as to how sentiment analysis methods can be effectively applied to improve the effect of sentiment analysis. However, online travel review texts are largely short texts characterized by uneven sentiment distribution, which makes it difficult to obtain accurate sentiment analysis results. Accordingly, in order to improve the sentiment classification accuracy of online travel review texts, this study transformed sentiment analysis into a multi-classification problem based on machine learning methods, and further designed a keyword semantic expansion method based on a knowledge graph. Our proposed method extracts keywords from online travel review texts and obtains the concept list of keywords through Microsoft Knowledge Graph. This list is then added to the review text to facilitate the construction of semantically expanded classification data. Our proposed method increases the number of classification features used for short text by employing the huge corpus of information associated with the knowledge graph. In addition, this article introduces online travel review text preprocessing, keyword extraction, text representation, sampling, establishment classification labeling, and the selection and application of machine learning-based sentiment classification methods in order to build an effective sentiment classification model for online travel review text. Experiments were implemented and evaluated based on the English review texts of four famous attractions in four countries on the TripAdvisor website. 
Our experimental results demonstrate that the method proposed in this paper can be used to effectively improve the accuracy of the sentiment classification of online travel review texts. Our research attempts to emphasize and improve the methodological relevance and applicability of sentiment analysis for future travel research.

**Keywords:** user generated content; sentiment analysis; classification; keyword extraction; text representation; sampling; machine learning; TripAdvisor

### **1. Introduction**

Tourism research has entered the era of big data. Based on big data analysis, academia and industry are now better positioned to understand and explore tourist behavior and the tourism market. Li et al. [1] contend that big data analysis can provide sufficient data without introducing sampling bias, and can also make up for the sample size limitations encountered by the survey data, thereby enabling a better understanding of tourist behavior. Sivarajah et al. [2] argued that big data analysis

can lead to new knowledge; subsequently, such analysis has become the mainstream method used to obtain useful information.

From blogs and social media posts to online travel review sites, user-generated content (UGC) is one of the most important data sources for big data. UGC comprises insightful feedback that is spontaneously provided by users. This feedback information is widely available at little to no cost and can also be easily obtained [3]. Such feedback also has potential commercial value in fields such as targeted advertising, customer–company relationships, and brand communication [4,5].

Online travel review sites, such as TripAdvisor, generate large amounts of text-based online travel review data, which constitute an important type of UGC [6]. Online review text data can help researchers and practitioners to correctly understand tourists' travel preferences and needs [7,8]. The opinions expressed in user-generated comments also play an important role in influencing the choices of potential tourists [9,10].

The characteristics of big data have complicated the process of knowledge extraction. The question of how to transform data into valuable knowledge has become crucial for big data applications [11,12]. Previous research into online reviews has mainly focused on the quantitative ratings provided on the website, ignoring the text of online reviews [3]. Ratings cannot provide any information about the specific product characteristics that visitors like or dislike, and such information is typically included in the review text [5,13]. In addition, many users are overwhelmed by the enormous amount of review information provided on online travel review sites. Researchers in other fields have raised similar questions. Ali et al. [14] noted that while urban traffic congestion is rapidly increasing, a city's rating score is insufficient to provide accurate information; however, comments or tweets may help travelers and traffic managers to understand all aspects of the city. As a result, it is necessary to establish an effective mechanism to help users identify the main content and emotions embedded in the review text [15].

Human emotions and emotional reasoning are understood to be important factors that influence consumer decision-making [16]. This makes sentiment analysis an effective method for mining the connotations of online travel review texts. Text sentiment analysis methods can be divided into dictionary-based methods [17], machine learning methods [13,18], deep learning methods [19,20], and hybrids of the above methods [21,22]. Alaei et al. [8] contend that dictionary-based systems rely on the use of sentiment dictionaries and rule sets. Their article proposes that such methods are unable to adapt to the rapid increase in data volume in the era of big data, so it is necessary to develop more effective automated methods for sentiment analysis. Deep learning methods usually require a large amount of training data to fully realize their potential; this training data usually requires expensive class labeling [23]. Among machine learning methods, support vector machines (SVM) and naive Bayes are the most widely used in the tourism-related sentiment analysis context [13]. Compared with neural networks, SVM and naive Bayes require fewer class annotations to train the model [8]. Most studies on the subject have shown [18] that SVM-based sentiment analysis of text produces superior results relative to other machine learning methods. Kirilenko et al. [13] compared automatic text sentiment analysis classifiers with humans and evaluated whether various types of automatic classifiers are suitable for typical applications in the tourism, hotel, and marketing research contexts. The article argues that on difficult and noisy datasets, automatic classifiers achieve worse performance than humans. It can therefore be concluded that the existing sentiment analysis technology needs to be improved to enable the analysis of specific data.

Contemporary researchers have proposed many effective solutions to improve the performance of SVM in sentiment analysis. Successful feature extraction is one of the main challenges faced by machine learning methods [24]. Feature extraction can reduce information loss and achieve improved discrimination ability in sentiment classification [25] tasks. In their study of feature selection methods, Manek et al. [25] proposed a Gini index feature selection method based on SVM to carry out sentiment classification for a large movie review dataset. Ali et al. [26] proposed a robust classification technology based on SVM and fuzzy domain ontology (FDO), used for the recognition of comment features and

the mining of semantic knowledge. Their experimental results showed that the integration of FDO and SVM greatly improves the accuracy of extracting comments and opinion words, as well as the accuracy of opinion mining. Parlar et al. [27] proposed a new feature selection method based on the query expansion term weighting method in the information retrieval context. This study uses four classifiers to compare their method with other widely used feature selection methods, thereby verifying their method's effectiveness. Zainuddin et al. [28] proposed a latent semantic analysis (LSA) and random projection (RP) feature selection method for the sentiment analysis of Twitter data, and thereby constructed a new Twitter mixed sentiment classification method. Kumar et al. [29] introduced swarm intelligence algorithms into the field of feature optimization in order to improve the sentiment classification performance accuracy. Pu et al. [30] used a variety of features to identify candidate opinion sentences, then used structured SVM to encode these opinion sentences for document sentiment classification. This article resolves the issue of sentiment classification problems arising when the sentiment of most sentences is inconsistent with the sentiment of the document overall.

As an effective feature selection method, semantic expansion has also been widely studied. Adhi et al. [31] designed a sentiment analysis model based on a naive Bayes classifier and the semantic extension method, proving that the semantic extension method can improve the accuracy of sentiment analysis. Fang et al. [32] integrated the context features extracted from the comment sentences and the external knowledge retrieved from the sentiment knowledge graph into a neural network to compensate for the lack of available training data, consequently obtaining better sentiment analysis results. At the same time, as an effective channel for semantic expansion, knowledge bases such as WordNet and ConceptNet are widely used in sentiment analysis in multiple languages. Alowaidi et al. [33] proposed using Arabic WordNet as an external knowledge base to enrich the representation of tweets due to the weakness of the bag of words model; the use of naive Bayes and SVM on the Arab Twitter dataset verified that this external knowledge base can be used to improve sentiment analysis accuracy. Asgarian et al. [34] used Persian WordNet to generate a review corpus, proving that sentiment dictionary quality plays a key role in improving the quality of sentiment classification in the Persian language. Moreover, Agarwal et al. [35] proposed a novel sentiment analysis model based on ConceptNet and common sense extracted from context information.

At the same time, a number of scholars in tourism research have studied the application of sentiment analysis to tourism and hospitality-related data. Several existing works [8,13] have already summarized the sentiment analysis methods adopted by the academic community in the tourism context prior to 2016; therefore, this article only summarizes the relevant literature published after 2017 in Table 1. Among these works, Höpken et al. [36] extracted customer feedback from two online platforms and carried out sentiment analysis and opinion mining, verifying that SVM is best able to solve the problem of sentiment analysis compared with other related methods. Akhtar et al. [37] used topic modeling technology to identify hidden information and other aspects, then performed sentiment analysis on classified hotel review text sentences. Ma et al. [38] performed sentiment analysis on TripAdvisor's review data using Leximancer. Ko et al. [39] applied statistical analysis methods to a large number of consumer review texts obtained from Expedia, enabling these authors to understand the experiences of hotel guests and analyze their association with satisfaction. Stepchenkova et al. [40] selected and compared three of the best-performing sentiment analysis methods to quantify respondents' views on travel in China. Bansal et al. [41] further proposed a sentiment classification method based on mixed attributes. By capturing implicit word relationships and combining domain-specific knowledge, these authors were able to obtain a fine-grained emotional orientation of online consumer reviews. Finally, Lawani et al. [42] used the AFINN dictionary (a lexicon based on unigrams) to extract the sentiments from comments left by Airbnb guests and derive a quality score from those comments.

An analysis of the above literature reveals that the academic community has carried out fruitful work in the field of sentiment analysis, particularly as regards the feature selection of SVM. Although these related topics have been extensively researched, certain specific types of content, such as online travel review texts for TripAdvisor, still present some challenges when using sentiment

analysis [43]. This is because the key features of reviews vary significantly from site to site, meaning that it cannot be assumed that the sentiment analysis method and findings of a certain site will be applicable to all other review sites [44]. On the subject of the sentiment analysis of online travel review texts, most existing sentiment analysis models fail to comprehensively and effectively consider the data characteristics of travel review texts during the modeling process. Online travel review texts have their own inherent characteristics. Most review texts are short, which makes it difficult to extract keywords; in addition, the sentiment distribution of short texts is uneven [45] (for example, the texts with the highest and lowest scores are comparatively few). These characteristics make it difficult for accurate sentiment analysis results to be obtained for online travel review texts [46]. In addition, the accuracy of existing automated sentiment analysis methods is also low [13].


**Table 1.** The methods of sentiment analysis used in tourism research.

In order to deal with the sentiment analysis-related challenges brought about by the data features of online travel review texts, this study converted the sentiment analysis of online travel review texts into a multi-classification process based on machine learning methods, and further conducted research on sentiment classification methods for such texts. In order to improve the classification accuracy of online travel review texts, the current research mainly addresses the following problems related to previous research. The main contributions of the paper include:


### **2. Materials and Methods**

Sentiment analysis generally includes multiple steps [48]. As can be seen from Figure 1, the sentiment analysis process proposed in this paper includes the following five steps:

**Figure 1.** System framework.

(1) Data retrieval. A crawler written in Python was used to obtain English review texts of four famous attractions in four countries from the travel review website TripAdvisor; these served as the sentiment analysis data. This process is relatively simple and, due to space limitations, is not described here.

(2) Data preprocessing. Section 2.1 introduces the steps involved in online comment text preprocessing.

(3) Keyword extraction and semantic expansion of comment texts. In order to improve classification accuracy, Section 2.2 introduces our online travel review text keyword extraction method and keyword semantic expansion method based on Microsoft Knowledge Graph.

(4) Text representation. Section 2.3 introduces the text representation method based on Word2vec.

(5) Sentiment classification. Section 2.4 introduces the sentiment classification method adopted in this paper.

### *2.1. Data Preprocessing*

Not all characters included in the text of online travel reviews are important. For example, most reviews include words, punctuation, etc. that do not describe the subject of the text. Retaining all characters will lead to the formation of high-dimensional features; this will not only increase the time required for classification learning, but will also introduce a lot of noisy data into the classification and affect the classification accuracy. It is therefore necessary to preprocess the data. The preprocessing process used in this article comprises the following four steps:


In step 1, Python's BeautifulSoup library was used to remove HTML tags such as `<br>` from the comment text. Steps 2–4 were implemented using NLTK (Natural Language Toolkit) [49] and regular expressions. The second step deletes punctuation, numbers, and other non-English characters from the comment text; the third step divides each sentence into words and converts all words to lower case; finally, the fourth step uses the stop word list provided by NLTK and deletes the words in that list from the comment text. The stop word list contains noise words that do not describe the text subject ("the", "is", "are", "a", "an", etc.). In addition, given the characteristics of the dataset used in this article, we added some dataset-specific vocabulary to the stop list (for example, "Mutianyu", "Great Wall", "China"). These specific high-frequency words would otherwise affect the subsequent keyword extraction and sentiment analysis results; they are usually objective descriptions of scenic spots and accordingly do not help with the sentiment analysis.
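The four steps can be sketched in pure Python. Here a small hand-rolled stop-word set stands in for NLTK's list, and a regular expression replaces BeautifulSoup's HTML stripping; both are simplifications of the pipeline described above:

```python
import re

# Tiny stand-in for NLTK's English stop-word list, plus the dataset-specific
# words mentioned above (high-frequency but objective terms).
STOP_WORDS = {"the", "is", "are", "a", "an", "and", "it", "to", "of", "was", "at", "in"}
DOMAIN_WORDS = {"mutianyu", "great", "wall", "china"}

def preprocess(review_html):
    text = re.sub(r"<[^>]+>", " ", review_html)        # 1. strip HTML tags
    text = re.sub(r"[^A-Za-z\s]", " ", text)           # 2. keep English letters only
    tokens = text.lower().split()                      # 3. tokenize and lowercase
    return [t for t in tokens                          # 4. drop stop/domain words
            if t not in STOP_WORDS and t not in DOMAIN_WORDS]

print(preprocess("The Great Wall at Mutianyu is amazing!<br>Views are stunning in 2019."))
# → ['amazing', 'views', 'stunning']
```

Note how the surviving tokens carry most of the sentiment, while tags, digits, punctuation, and objective scenic-spot names are removed.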

### *2.2. Keyword Extraction and Semantic Expansion*

The online travel review text obtained in this article pertains to multiple attractions. As shown in Figures 2–5, before preprocessing, the length of the review text about Mutianyu Great Wall, Beijing, China is mostly between 260 and 280 words. Moreover, the length of the comment text for the Harry Potter Wizarding World Theme Park in Orlando, USA is between 90 and 130 words; the comment text for the Tower of London, England is between 90 and 140 words in length; and the lengths of the comment text for the Sydney Opera House in Australia are mostly in two categories (90 to 120 and 200 to 300 words). Because preprocessing will delete some characters that are not related to sentiment classification, the text will be shorter after preprocessing. It is difficult to extract effective feature words from shorter text and thus more difficult to obtain better sentiment classification results [46]. In order to improve the effectiveness of sentiment classification for online travel review text, this paper proposes a keyword semantic expansion method based on knowledge graphs. First, we compared several keyword extraction methods and selected the TextRank method as having the best effect [50] for achieving keyword extraction for online travel review text. Secondly, through the use of Microsoft Knowledge Graph, a conceptual list of keywords for each comment was obtained. This concept list of keywords can be used to expand the semantics of the comment text and provide a richer and more valuable classification feature for the classifier. Next, the specific implementation steps will be introduced.
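The expansion step can be illustrated with a toy concept table standing in for Microsoft Knowledge Graph; the lookup below is entirely hypothetical (the real service maps a term to a ranked list of concepts), but the mechanics of appending concepts to the token sequence are the same:

```python
# Hypothetical concept table; the real mapping would come from a knowledge graph.
CONCEPTS = {
    "castle": ["building", "attraction", "landmark"],
    "queue": ["wait", "crowd"],
    "guide": ["service", "staff"],
}

def expand(tokens, keywords, concepts=CONCEPTS):
    """Append the concept list of each extracted keyword to the token sequence."""
    expanded = list(tokens)
    for kw in keywords:
        expanded.extend(concepts.get(kw, []))
    return expanded

tokens = ["castle", "impressive", "long", "queue"]
print(expand(tokens, keywords=["castle", "queue"]))
# → ['castle', 'impressive', 'long', 'queue', 'building', 'attraction',
#    'landmark', 'wait', 'crowd']
```

The appended concepts give the classifier extra features that a short review alone cannot supply, which is the core idea behind the semantic expansion method.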

**Figure 2.** Length of the comment text of Mutianyu Great Wall in Beijing, China.

**Figure 3.** Length of the comment text in the Harry Potter Wizarding World Theme Park in Orlando, USA.

**Figure 4.** Length of the comment text of the Tower of London.

**Figure 5.** Length of the comment text of the Sydney Opera House in Australia.

#### (1) Keyword extraction

Text keyword extraction is a machine learning algorithm-based text feature extraction method. In fields such as text-based recommendation and search, the accuracy of text keyword extraction is directly related to the final effect. Accordingly, text keyword extraction is an important research direction in the field of text mining. Text keyword extraction methods can be divided into supervised, semi-supervised, and unsupervised methods [51]. Supervised and semi-supervised methods regard keyword extraction as a classification problem and require a labeled training corpus to train the keyword extraction model. However, for massive datasets, labeling the training corpus is often very time-consuming. For its part, the unsupervised keyword extraction method does not require a manually annotated corpus, and is therefore more suitable for the keyword extraction of massive comment texts [52].

The TextRank algorithm proposed by Mihalcea et al. [50] draws on PageRank, the core algorithm of Google search, and is an unsupervised keyword extraction method. Unlike TF-IDF (term frequency-inverse document frequency), LDA (Latent Dirichlet Allocation), etc., TextRank divides the text into several units (e.g., words, sentences) and builds a graph model; keyword extraction can thus be achieved using only the information contained in a single document.

The process by which TextRank extracts text keywords comprises the following steps:

(1) Divide the given text into sentences.

(2) For each sentence, perform word segmentation and part-of-speech tagging, and filter out stop words, so that only words belonging to the specified parts of speech are retained as candidate keywords.

(3) Construct the candidate keyword graph G = (V, E), where V is the node set comprising the candidate keywords generated in step (2); then, use the co-occurrence relationship to construct an edge between any two nodes whose words co-occur.

(4) Calculate the weight of each node. The nodes are sorted in descending order of weight, so that the most highly weighted words are obtained as candidate keywords.

(5) Mark the candidate keywords obtained in step (4) in the original text; if adjacent phrases are formed, these are combined into multi-word keywords.
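As a rough sketch of these steps, the following pure-Python implementation builds the co-occurrence graph and iterates the PageRank-style update; the small window size and unweighted graph are simplifying assumptions (one common TextRank variant), not the exact configuration used in the paper:

```python
from collections import defaultdict

def textrank_keywords(words, window=2, d=0.85, iters=50):
    """Rank candidate keywords with TextRank over a co-occurrence graph.

    `words` is the filtered word sequence from step (2); an edge links any
    two words that co-occur within `window` positions (step (3)).
    """
    graph = defaultdict(set)
    for i, w in enumerate(words):
        for j in range(i + 1, min(i + window + 1, len(words))):
            if words[j] != w:
                graph[w].add(words[j])
                graph[words[j]].add(w)
    # Step (4): iterate the PageRank-style update with damping factor d.
    scores = {w: 1.0 for w in graph}
    for _ in range(iters):
        scores = {
            w: (1 - d) + d * sum(scores[u] / len(graph[u]) for u in graph[w])
            for w in graph
        }
    return sorted(scores, key=scores.get, reverse=True)

tokens = "bus ride great wall cable car bus station crowded".split()
print(textrank_keywords(tokens)[:3])
```

The repeated, well-connected word ("bus") receives the highest score, matching the intuition that frequently co-occurring terms dominate the ranking.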

A variety of keyword extraction algorithms, with TextRank as a representative example, are widely used in tourism and many other fields. Shouzhong et al. [53] integrated TF-IDF and TextRank to mine and analyze personal interests from Weibo text. Paramonov et al. [54] developed a new method combining well-known keyword extraction algorithms (e.g., TextRank and Topic PageRank) with a thesaurus-based procedure, thereby improving the connectivity of the text-via-keyphrase graph while also increasing the precision and recall of key phrase extraction. Gagliardi et al. [55] integrated a word embedding model and a clustering algorithm to establish a novel method capable of automatically extracting keywords/phrases from text without supervision. Ali et al. [56] used the N-gram method to extract the risk factors of heart disease diagnosis and applied these to an intelligent heart disease prediction system, improving the accuracy of heart disease diagnosis.

In Section 3.2, based on the similarity calculation results of the words, and following experiments with TF-IDF and LDA, it is determined that the keywords extracted by TextRank are more suitable for ascertaining the actual semantics of online travel text reviews. Therefore, this study used TextRank for text keyword extraction purposes.

(2) Keyword semantic expansion

Text feature semantic expansion is an effective method of solving the sparse text problem [57]. Wang et al. [58] conceptualized short text into a set of concepts and embedded the original text in order to form word vectors. Experimental results verify that the convolutional neural network based on this word vector can achieve good short text classification results. Rosso et al. [59] believe that combining large-scale unstructured content (text) and high-quality structured data (knowledge graph) can improve text analysis.

Microsoft Knowledge Graph [60] has acquired a large amount of common sense knowledge by learning from billions of web pages and years of search logs. The conceptual model it provides maps text entities to semantic concept categories with specific probabilities; for example, "Microsoft" may automatically map to "software companies" and "Fortune 500 companies" [61]. This paper introduces the conceptual model of the Microsoft Knowledge Graph to expand the semantics of online travel review text keywords. By drawing on the huge information corpus of the knowledge graph, this method overcomes the scarcity of features caused by the sparseness of short texts, and accordingly provides richer and more valuable classification features for short text sentiment classification. We demonstrate the improvement in classification accuracy brought about by this method in the experiments discussed in Section 3.

### *2.3. Text Representation*

(1) Text representation of comments based on Word2vec

Representing text as structured data that can be handled by machine learning classification algorithms is a highly important part of the text classification process. In 2013, Google released the software tool Word2vec for training word vectors [62]. Word2vec's high-dimensional vector model addresses the multi-dimensional semantic problem, because it can quickly and effectively express words in high-dimensional vector form through an optimized training model over a given corpus, thereby providing a new tool for applied research in the field of natural language processing [63]. Academic research [64,65] demonstrates that Word2vec achieves excellent performance in text similarity calculation and text classification. In light of the above analysis, this study opted to construct Word2vec vectors for the pre-processed and semantically expanded comment text.

#### (2) Data normalization

Normalized data exhibits enhanced stability for attributes with very small variances, while maintaining 0 entries in the sparse matrix [62]. Therefore, this study used the normalization method to scale the text vector represented by Word2vec to between 0 and 1. The formula utilized is as follows:

$$x\_{i}^{\prime} = \frac{x\_{i} - x\_{\min}}{x\_{\max} - x\_{\min}} \tag{1}$$

In Equation (1), *x′i* represents the result of normalization, while *xi* represents the data that needs to be normalized. Moreover, *xmax* and *xmin* represent the maximum and minimum values in the dataset, respectively.
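Equation (1) amounts to min-max scaling (the same operation scikit-learn's `MinMaxScaler` applies column-wise); a minimal sketch:

```python
import numpy as np

def min_max_normalize(x):
    """Scale each value into [0, 1] following Equation (1)."""
    x = np.asarray(x, dtype=float)
    return (x - x.min()) / (x.max() - x.min())

vec = np.array([2.0, 5.0, 11.0])
print(min_max_normalize(vec))  # → [0.    0.333 1.   ] (approximately)
```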

### *2.4. Sentiment Classification*

For massive texts, one effective solution involves transforming sentiment analysis into classification and applying machine learning methods to solve such problems [66]. This article has introduced the problems encountered by deep learning methods, along with the excellent results achieved by machine learning methods in the text sentiment analysis context. Therefore, using the online travel review text data processed in Sections 2.1–2.3 as the training data, SVM was chosen in this study as the sentiment classification method. In Section 3.4, through the analysis of experimental results, the most suitable sentiment classification model for processing online travel review texts is then identified.

### **3. Case Study**

This section introduces the research process utilized in this article and draws conclusions from a sentiment classification experiment on online tourist review texts of multiple attractions. In more detail, Section 3.1 describes the experimental dataset and the results of preprocessing; Section 3.2 introduces the experimental process of keyword semantic expansion based on the knowledge graph; Section 3.3 introduces the text representation based on Word2vec; finally, Section 3.4 introduces the sentiment classification based on SVM experiments and result analysis.

### *3.1. Data Acquisition and Preprocessing Experiment*

As shown in Table 2, the present research used a crawler program written in Python to obtain four datasets from the TripAdvisor website. Mutianyu\_Great\_Wall contains review text pertaining to Mutianyu Great Wall, which is the number one attraction on the TripAdvisor website in Beijing, China. It contains a total of 2772 pieces of review data in English published by tourists from January 2016 to December 2019. Wizarding\_World\_of\_Harry\_Potter contains comment text pertaining to the Harry Potter Wizarding World Theme Park in Orlando, USA, specifically, 6641 pieces of comment data published by tourists from June 2017 to December 2019 in English. Tower\_of\_London contains comment text about the Tower of London in the United Kingdom. It contains the data of 4428 comments in English published by tourists from July 2018 to December 2019. Finally, Sydney\_Opera\_House contains review text pertaining to the Sydney Opera House in Australia. It contains 6776 pieces of comment data in English published by tourists from March 2017 to December 2019.

Table 3 presents a piece of comment text in the Mutianyu\_Great\_Wall dataset and its pre-processed results. In the table, the first column is the comment text published by tourists, while the second column is the pre-processed text. As can be seen from the introduction in Section 2.1, the preprocessing operation removes any original comment text content that is unrelated to sentiment classification.


**Table 2.** Experimental data set.

**Table 3.** Review and preprocessed results.


### *3.2. Keyword Semantic Expansion Experiment Based on Knowledge Graph*

(1) Comparison of text keyword extraction methods

In this study, three text keyword extraction methods, namely TF-IDF, LDA, and TextRank, were selected for comparative experiments. Taking the comment text in Table 3 as an example, the manually provided subject term is "transport", while the second column of Table 4 presents the keyword with the largest score obtained by each of the three methods. Among them, the two top keywords obtained by TF-IDF have the same score.

**Table 4.** Results of three types of keyword extraction methods.


Word2vec is able to convert words into vectors and calculate the distance between them: the larger the calculated similarity value, the more similar the two words [63]. Based on Word2vec's word vector similarity calculation, this study computed the similarity between the top-ranked keyword obtained by each of the three methods and the manually provided subject word of the comment text ("transport"). The calculation results are shown in the third column of Table 4. Here, the keyword obtained by TextRank is the most relevant to the subject word of the manually provided review text. TF-IDF also identified the most relevant keyword, "bus". However, the keyword "cable", which has the same weight as "bus", has poor relevance to the subject word, which affects the final result. LDA requires a large corpus (i.e., a large amount of comment text) for accurate results to be obtained, whereas this research requires keywords to be derived for each short comment text. Therefore, LDA is unsuitable for this research, and its final keyword extraction effect is also poor. This study randomly selected 10% of the samples in each dataset and used the above three methods to extract keywords. Following experimental comparison, TextRank was found to have the best keyword extraction effect. Therefore, TextRank was used to extract the text keywords from online travel reviews.
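The similarity computation can be illustrated as follows; the 3-dimensional vectors are made-up toy values standing in for real 300-dimensional Word2vec embeddings:

```python
import numpy as np

def cosine_similarity(a, b):
    """Word2vec-style similarity: cosine of the angle between two vectors."""
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

# Toy vectors (illustrative values only, not actual embeddings).
vectors = {
    "transport": np.array([0.9, 0.1, 0.2]),
    "bus":       np.array([0.8, 0.2, 0.1]),
    "cable":     np.array([0.1, 0.9, 0.3]),
}
for word in ("bus", "cable"):
    print(word, round(cosine_similarity(vectors["transport"], vectors[word]), 3))
```

With real embeddings, the same comparison ranks "bus" closer to "transport" than "cable", mirroring the result in Table 4.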

(2) Keyword semantic expansion experiment

This study obtained a concept list of online travel review text keywords using the conceptual model of Microsoft Knowledge Graph [67]. For example, the conceptual list of the keyword "bus" in the comment text of Table 3 is as follows: vehicle, public transportation, large vehicle, etc.

For the four datasets listed in Table 2, TextRank was used to extract text keywords from a total of 20,617 comment texts. In the next step, the concept list for each comment's top-weighted keyword was obtained. Although Microsoft Knowledge Graph covers a very wide range, it does not cover every word. Following calculation, Microsoft Knowledge Graph returned results for 97.6% of the keywords in this experiment. Finally, we added the returned results for each keyword to the pre-processed comment text in order to create a pre-classification dataset.
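The expansion step can be sketched as below; a local lookup table stands in for calls to the Microsoft Knowledge Graph conceptual model, and the concept list for "bus" follows the example given above (the function and table names are hypothetical):

```python
# Local stand-in for the Microsoft Knowledge Graph conceptual model;
# the concept list below is illustrative, not actual API output.
CONCEPT_GRAPH = {
    "bus": ["vehicle", "public transportation", "large vehicle"],
}

def expand_comment(preprocessed_words, keyword):
    """Append the keyword's concept list (when covered) to the comment text."""
    concepts = CONCEPT_GRAPH.get(keyword, [])  # ~97.6% of keywords were covered
    return preprocessed_words + [w for phrase in concepts for w in phrase.split()]

print(expand_comment(["bus", "ride", "crowded"], "bus"))
```

Keywords without knowledge graph coverage simply leave the comment text unchanged.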

### *3.3. Text Representation and Normalization Experiment*

Once preprocessing and semantic expansion were complete, the comment text was typically under 300 characters in length. Therefore, GoogleNews-vectors-negative300.bin [62], a word vector library pre-trained by Google on a news corpus, was selected to create the comment text vectors. The final results are illustrated in Figure 6. Each line in the figure is a normalized 300-dimensional Word2vec real-valued vector, which represents a specific comment text.


**Figure 6.** Vector representation of text.
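One common way to turn per-word vectors into a single comment vector is to average them; the paper does not state which aggregation it used, so the pooling below (and the toy 4-dimensional embeddings) are illustrative assumptions:

```python
import numpy as np

EMBEDDING_DIM = 4  # 300 in the paper; 4 here to keep the toy example readable

# Toy embeddings standing in for GoogleNews-vectors-negative300.bin.
embeddings = {
    "bus":  np.array([0.8, 0.2, 0.1, 0.0]),
    "ride": np.array([0.4, 0.6, 0.0, 0.1]),
}

def comment_vector(words):
    """Average the vectors of in-vocabulary words into one comment vector
    (a common pooling choice; out-of-vocabulary words are skipped)."""
    vecs = [embeddings[w] for w in words if w in embeddings]
    if not vecs:
        return np.zeros(EMBEDDING_DIM)
    return np.mean(vecs, axis=0)

print(comment_vector(["bus", "ride", "unknownword"]))  # mean of the two vectors
```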

#### *3.4. Sentiment Classification Experiment*

#### (1) Acquisition of training set classification labels

The machine learning classification method represented by SVM requires training data with sentiment classification results for model training. The sentiment classification results for these training data are also referred to as the training set classification labels. In this study, manual analysis and sentiment analysis software were used to generate the classification labels for the training set.

SentiStrength [68] is a software package that estimates the strength of positive and negative emotions contained in text, and it achieves human-level accuracy on short social network texts in English. We chose the nine-level sentiment classification results provided by SentiStrength. For negative emotions, the scores range from −1 (not negative) to −4 (extremely negative); for positive emotions, the scores range from 1 (not positive) to 4 (extremely positive); 0 represents neutral emotion. In this study, the SentiStrength results were then re-scored by humans, and the adjustment rate was about 24.7%. Finally, the sentiment analysis results of the dataset are presented in Figures 7–10. The abscissa represents the sentiment analysis results, which range from −4 to 4 across a total of nine categories; the ordinate indicates the number of samples in each category. It can be seen that the number of samples per sentiment value is extremely uneven. For example, in the review dataset for the Mutianyu Great Wall in Beijing, China, the number of texts in category 2 is 1266, while there is only 1 text in category −4. Moreover, there is no category −4 data in the review data of the Sydney Opera House in Australia.

**Figure 7.** The sentiment distribution of the review text of Mutianyu Great Wall in Beijing, China.

**Figure 8.** The sentiment distribution of the review text of the Harry Potter Wizarding World Theme Park in Orlando, USA.

**Figure 9.** The sentiment distribution of the review text of Tower of London.

**Figure 10.** The sentiment distribution of the review text of Sydney Opera House, Australia.

(2) Sampling experiment of unbalanced data

From the analysis presented in the previous section, we can see that the sentiment distribution of online travel review texts is very uneven. In fact, this is a typical unbalanced dataset. For unbalanced datasets, machine learning classifiers will tend to incorrectly divide new samples into categories with more samples, resulting in classification errors [69]. The methods used to process unbalanced datasets are mainly divided into undersampling, oversampling, and improved methods [70]. This study used Python to implement two types of sampling methods. Our experimental results demonstrate that, due to the extremely uneven sentiment distribution of the experimental dataset used, the undersampling dataset was so small that it was difficult to obtain more accurate classification results. Overall, Naive Random Over Sampler (ROS) [71] achieved the best sampling results.
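Naive random oversampling can be sketched in a few lines; the function below is a simplified stand-in for imbalanced-learn's `RandomOverSampler`, duplicating minority-class samples at random until every class matches the majority count:

```python
import random
from collections import Counter

def random_oversample(samples, labels, seed=42):
    """Naive ROS: randomly duplicate minority-class samples until every
    class reaches the majority class count."""
    rng = random.Random(seed)
    counts = Counter(labels)
    target = max(counts.values())
    out_x, out_y = list(samples), list(labels)
    for cls, n in counts.items():
        pool = [x for x, y in zip(samples, labels) if y == cls]
        for _ in range(target - n):
            out_x.append(rng.choice(pool))
            out_y.append(cls)
    return out_x, out_y

# Toy imbalanced data: four samples of sentiment 2, one of sentiment -4.
X, y = ["a", "b", "c", "d", "e"], [2, 2, 2, 2, -4]
Xr, yr = random_oversample(X, y)
print(Counter(yr))  # both classes now have 4 samples
```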

(3) Evaluation index

The evaluation indicators of classification results adopted in academia include Accuracy, Precision, Recall, and F1 score [72]. In binary classification, the sample categories are divided into positive and negative. Suppose that TP represents the number of samples that are actually positive and classified as positive, while FP denotes the number of samples that are actually negative but classified as positive; moreover, FN represents the number of samples that are in fact positive but classified as negative, while TN indicates the number of samples that are actually negative and classified as negative. The precision refers to the proportion of the samples classified as positive that are actually positive. The calculation formula is as follows:

$$Precision = \frac{TP}{TP + FP} \tag{2}$$

Furthermore, the recall refers to the proportion of the actually positive samples that are classified as positive; the calculation formula is as follows:

$$Recall = \frac{TP}{TP + FN} \tag{3}$$

Finally, the F1 score is the harmonic mean of the precision and the recall. The calculation formula is as follows:

$$F1\text{ score} = \frac{2 \ast Precision \ast Recall}{Precision + Recall} \tag{4}$$

The precision reflects the model's ability to distinguish negative samples: the higher the precision, the stronger the model's ability to distinguish negative samples. Moreover, the recall reflects the model's ability to identify positive samples: the higher the recall, the stronger the model's ability to recognize positive samples. The F1 score combines the two: the higher the F1 score, the more robust the model. While accuracy is the simplest and most intuitive evaluation index in classification, it has an obvious defect. For example, if 99% of the samples are positive, a classifier that always predicts a positive result obtains 99% accuracy, even though its actual performance is very poor. That is to say, when the proportion of samples in different categories is highly uneven, the category with the largest proportion often becomes the dominant factor affecting the accuracy. As the experimental data in this study were unbalanced, we did not use accuracy as an evaluation index for the classification results. Instead, we selected three indicators to measure the classification results: precision, recall, and F1 score.
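Equations (2)–(4) can be checked numerically:

```python
def precision_recall_f1(tp, fp, fn):
    """Compute Equations (2)-(4) from confusion-matrix counts."""
    precision = tp / (tp + fp)            # Equation (2)
    recall = tp / (tp + fn)               # Equation (3)
    f1 = 2 * precision * recall / (precision + recall)  # Equation (4)
    return precision, recall, f1

# Example: 80 true positives, 20 false positives, 20 false negatives.
p, r, f1 = precision_recall_f1(tp=80, fp=20, fn=20)
print(p, r, f1)  # → 0.8 0.8 0.8 (harmonic mean of equal values is that value)
```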

(4) SVM-based sentiment classification

Python's scikit-learn (sklearn) library was used to implement the SVM algorithm. After a large number of experiments, the RBF (Radial Basis Function) kernel was found to achieve the highest classification accuracy, while the other parameters were assigned their default values. We used 30% of the data as test data and the remaining 70% as training data.
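A minimal version of this setup, with made-up 2-dimensional points standing in for the 300-dimensional comment vectors:

```python
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC

# Toy 2-d points in two clusters, standing in for comment vectors
# (illustrative data only).
X = [[0.10, 0.20], [0.20, 0.10], [0.15, 0.25], [0.12, 0.18],
     [0.90, 0.80], [0.80, 0.90], [0.85, 0.75], [0.88, 0.82]]
y = [-1, -1, -1, -1, 1, 1, 1, 1]

# 70/30 split and an RBF-kernel SVM with default parameters, as in the paper.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=0, stratify=y)
clf = SVC(kernel="rbf").fit(X_train, y_train)
print(clf.predict([[0.1, 0.15], [0.9, 0.85]]))
```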

The comparative experimental results for one dataset (Mutianyu\_Great\_Wall) are presented in Table 5. The first row of Table 5 displays the classification results of SVM alone. The classification accuracy of SVM on this imbalanced dataset is very low, as it assigns most of the samples to the category with the largest number of samples. After ROS sampling and Word2vec vectorization of the text, the data in the second row of Table 5 show that the SVM algorithm's classification results are greatly improved. The next experiment involved extracting TextRank keywords from the comment text and expanding the semantics of the highest-weighted keywords based on the Microsoft Knowledge Graph; the semantically expanded keywords and the pre-processed online travel review text together make up the SVM classification dataset. The experimental results in the third row of Table 5 list the final classification results; it can be seen that the knowledge graph-based keyword semantic expansion method proposed in this paper further improves the classification results.



**Table 5.** Experimental results of the Mutianyu\_Great\_Wall dataset.

The optimal solution in each comparison is marked in bold.

Table 6 presents the experimental results of the other three datasets. Similar to the experimental results of the Mutianyu\_Great\_Wall dataset, it can be seen that the sampling technique, Word2vec-based text vectorization, and knowledge graph-based keyword semantic expansion method effectively improve the classification effect. Similar experimental results obtained on different datasets verify the universality of this method. In short, this provides an effective solution for sentiment analysis of online travel review texts.

**Table 6.** Experimental results of the other three data sets.


The optimal solution in each comparison is marked in bold.

The receiver operating characteristic (ROC) curve is an evaluation method that demonstrates the accuracy of classification through intuitive graphics. Figures 11–14 show the ROC curves of the four datasets. We have labeled each sentiment category (the Sydney\_Opera\_House dataset has only eight sentiment categories) with a different color. The abscissa in the figures indicates the proportion of all negative samples that are classified as positive (the false positive rate); the ordinate represents the proportion of all positive samples that are classified as positive (the true positive rate). The closer the ROC curve is to the upper left corner, the higher the accuracy of the experiment.

**Figure 11.** Mutianyu\_Great\_Wall receiver operating characteristic curve.

**Figure 12.** Wizarding\_World\_of\_Harry\_Potter receiver operating characteristic curve.

**Figure 13.** Tower\_of\_London receiver operating characteristic curve.

**Figure 14.** Sydney\_Opera\_House receiver operating characteristic curve.
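For one sentiment category treated one-vs-rest, the FPR/TPR points behind such a curve can be obtained with scikit-learn; the labels and scores below are illustrative:

```python
import numpy as np
from sklearn.metrics import roc_curve, auc

# True one-vs-rest labels and classifier scores for one sentiment category
# (made-up values).
y_true  = np.array([1, 1, 0, 0, 1, 0, 1, 0])
y_score = np.array([0.9, 0.8, 0.7, 0.3, 0.6, 0.2, 0.75, 0.4])

fpr, tpr, _ = roc_curve(y_true, y_score)  # abscissa: FPR, ordinate: TPR
print(round(auc(fpr, tpr), 4))            # area under the curve
```

Plotting `fpr` against `tpr` for each category, as in Figures 11–14, shows how close each curve comes to the upper-left corner.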

### **4. Discussion**

Sentiment analysis is a mainstream technology that employs social media analysis strategies to analyze customer feedback and comments. Conducting sentiment analysis based on websites such as TripAdvisor is desirable because a large number of free datasets can be obtained from such websites for large-scale research, while such large-scale data cannot easily be obtained via traditional research methods. Big data provides a new type of data for use in tourism research, and also puts forward higher requirements for data processing. Currently, few studies have been conducted on the applicability and accuracy of sentiment analysis methods in the tourism research literature. In addition, contemporary research ignores the possibility of integrating human knowledge, such as knowledge graphs, into existing methods in order to improve the text sentiment analysis performance. Big data is characterized by a huge data volume, and the speed and accuracy requirements for sentiment analysis are becoming steadily higher [8]. Therefore, the prospect of developing suitable and efficient sentiment analysis methods for specific types of big data in the tourism context is a highly valuable proposition.

The obtained sentiment analysis results based on TripAdvisor review text can be applied to multiple fields. For example, they can help sightseeing spots, restaurants, or hotels to explain comments and adopt corresponding countermeasures, which can in turn provide decision makers and customers with better decision-making information. Similarly, this approach can also be used to study theoretical issues related to customer satisfaction (for example, whether a tour guide service would improve the tourist experience). However, existing studies [43,44] have found that the key features of the review text differ substantially depending on which websites they are drawn from, and that it is therefore necessary to conduct sentiment analysis research on one specific website at a time. Therefore, research into machine learning sentiment analysis methods for TripAdvisor review texts will aid in the development of tourism research utilizing these texts. Compared with vocabulary sentiment analysis, one of the advantages of machine learning sentiment analysis is that it does not require humans to create a dictionary; this is beneficial because the production of such a dictionary is a time-consuming and laborious process. In addition, machine learning methods achieve more accurate performance on larger amounts of training data than can be obtained using vocabulary sentiment analysis [8]. Feature extraction is a key issue in the application of machine learning to the field of sentiment analysis [24]. Accordingly, this study designed and implemented a sentiment classification method based on the semantic expansion of text keywords that both increases the classification features and improves the accuracy of sentiment analysis, thereby providing a novel solution for machine learning sentiment analysis.

In terms of the specific details of the work of this article, in order to improve the accuracy of sentiment analysis conducted on online travel review texts, this study conducted extensive research work on the classification problems caused by the data features of online review texts. First, most online review texts are short texts, which makes it difficult to obtain more accurate sentiment classification results. To solve this problem, we designed a text keyword semantic expansion method based on a knowledge graph. In this part of the research, the present study compared three typical text keyword extraction methods and provided keyword extraction methods that are suitable for online travel review texts. In addition, based on Microsoft Knowledge Graph, the semantics of text keywords were expanded, and richer and more valuable sentiment classification features were constructed. The second part of the research involved comparing the two types of sampling methods and identifying which of these is more suitable for use in solving the uneven sentiment distribution problem in online review texts. This article fully describes the key aspects of online travel review text sentiment classification, establishes an effective sentiment classification research framework for online travel review text, and validates the proposed method based on a relatively extensive sample.

The work put forward in this paper aims to emphasize and improve the methodological relevance and applicability of sentiment analysis. However, there are some limitations:


classification results. However, the question of whether these novel methods are suitable for the research object of this article is worthy of further study.

• In terms of experimental subjects, this article only studies English reviews from TripAdvisor, and does not investigate other online travel platforms and other languages. Therefore, it is highly advisable to investigate data in other languages and other platforms to verify the applicability of this method.

**Author Contributions:** Conceptualization, W.C., Z.X. and Y.L.; Methodology, W.C.; Software, W.C.; Validation, W.C., Q.Y. and X.Z.; Data curation, W.C.; Writing—original draft preparation, W.C.; Writing—review and editing, W.C., Z.X. and Y.L.; Supervision, Y.L.; Project administration, Y.L.; Funding acquisition X.Z., Q.Y. and Y.L. All authors have read and agreed to the published version of the manuscript.

**Funding:** This research was funded by the National Natural Science Foundation of China under Grants 61972439, 61672039, 61702010, 61772034.

**Acknowledgments:** The authors would like to acknowledge all of the reviewers and editors.

**Conflicts of Interest:** The authors declare no conflict of interest.

### **References**


© 2020 by the authors. Licensee MDPI, Basel, Switzerland. This article is an open access article distributed under the terms and conditions of the Creative Commons Attribution (CC BY) license (http://creativecommons.org/licenses/by/4.0/).

## *Article* **Towards the Discovery of Influencers to Follow in Micro-Blogs (Twitter) by Detecting Topics in Posted Messages (Tweets)**

### **Mubashir Ali 1, Anees Baqir 2, Giuseppe Psaila 1,\* and Sayyam Malik 2**


Received: 8 July 2020; Accepted: 12 August 2020; Published: 18 August 2020

**Abstract:** Micro-blogs, such as Twitter, have become important tools for sharing opinions and information among users. Messages concerning any topic are posted daily. A message posted by a given user reaches all the users that decided to follow her/him. Some users post many messages because they aim to be recognized as influencers, typically on specific topics. How can a user discover influencers concerned with her/his interests? Micro-blog apps and web sites lack a functionality that recommends influencers to users on the basis of the content of posted messages. In this paper, we envision such a scenario and identify the problem that constitutes the basic building block for developing a recommender of (possibly influencer) users: training a classification model on messages labeled with topical classes, so that this model can be used to classify unlabeled messages and let the hidden topics they talk about emerge. Specifically, the paper reports the investigation activity we performed to demonstrate the suitability of our idea. To perform the investigation, we developed a framework that exploits various patterns for extracting features from within messages (labeled with topical classes) in conjunction with the classifiers most used for text classification problems. By means of the investigation framework, we were able to perform a large pool of experiments that allowed us to evaluate all the combinations of feature patterns with classifiers. By means of a cost-benefit function called "Suitability", which combines accuracy with execution time, we were able to demonstrate that a technique for discovering topics from within messages that is suitable for the application context is available.

**Keywords:** social media; micro-blogs (Twitter); towards recommending influencers based on topic classification; investigation framework; comparison of various techniques for topic classification; cost-benefit function

### **1. Introduction**

Micro-blogs have become widely-used online platforms for sharing ideas, political views, emotions and so on. One very famous micro-blog is *Twitter*: it is an online social network that allows users to publish short sentences; every day, millions of messages (also called *tweets*) concerning a very large variety of topics are published (or *posted*) by users. According to [1], Twitter is a famous micro-blogging site where more than 313 million users from all over the world are active monthly.

Due to the importance it has gained, Twitter has inspired novel research in many areas of computer science, in particular data mining [2], sentiment analysis [3], text mining [4], discovering the mobility of people [5–7] and so on. For example, tweets are analyzed to find political friends [8]; this implies that texts are analyzed to detect their political polarity. Another interesting application

is detecting communities from networks of users [9], in which sentiment analysis plays an important role; sentiment analysis and opinion mining can also be adopted to study the general sentiment of a given country [10], in order to detect the degree of support for terrorists. In summary, most works concerned with the analysis of tweets focus on sentiment analysis and opinion extraction; thus, the common perspective is that tweets posted by users are collected and queried to provide useful information about users. We can say that users are analyzed from outside the micro-blog; the results of the analysis are not used to provide a service or functionality to users of the micro-blog itself.

Nevertheless, many users post a lot of messages because they wish to influence other users. In fact, when a followed user posts a new tweet, all her/his followers receive it. Typically, users post many messages because they would like to be recognized as influencers on a specific topic. This goal requires a user to have many followers who are interested in the same topic. Consequently, it is critical for an influencer to be interesting for other users and easily found by them. Conversely, non-influencer users would like to easily find interesting influencers to follow.

How can users find other users to follow? The reasons for deciding to follow other users can be various; typically, one reason is affinity of interests: a user would like to follow other users with similar interests. However, it is currently quite hard to find users with the same interests, because micro-blog platforms in general (and Twitter in particular) do not provide any end-user functionality or service that recommends users with similar interests; consequently, we envision a new scenario for micro-blog platforms.

This new scenario can become reality only if it is technically possible to realize it. This is the goal of this paper, i.e., addressing the problem at the basis of the envisioned functionality: we show that messages can be classified, both effectively and efficiently, with the topics they talk about. In practice, we demonstrate that it is possible to define a technique that characterizes user interests (in terms of topics) by analyzing their posted messages; this opens the way to building a recommender system that recommends to one user other users having similar interests. To the best of our knowledge, very limited work has appeared in the literature on this topic.

In this paper, we investigate the definition of an approach based on supervised learning to discover the topics that messages posted by micro-blog users talk about. To this end, we devised an investigation framework whose goal is to apply various combinations of feature patterns (extracted from posted messages) and classification techniques: this framework enabled us to identify the best combination to address the problem. In this work, basic and combined n-grams, weighted with a "Term Frequency-Inverse Document Frequency" (TF-IDF)-like metric, are used to extract features from messages to train four of the classifiers most widely used for text classification, i.e., Naive Bayes (NB), Support Vector Machine (SVM), K-Nearest Neighbors (kNN) and Random Forest (RF). By means of the investigation framework, we performed a comparative analysis of accuracy and execution times; to identify the most suitable solution, we defined a cost-benefit function called *Suitability*, able to balance the benefit of a technique in terms of accuracy against the computational cost of using it. We will show that the comparative analysis yielded a solution that we consider suitable for discovering the topics messages talk about: this is the preliminary step towards extending micro-blog user interfaces with functionalities able to suggest influencers to follow. This comparative analysis, which considers both accuracy and execution time, is the distinctive contribution of this paper: in fact, to the best of our knowledge, a similar approach has not been proposed yet.

The rest of the paper is organized as follows. Section 2 gives a brief review of existing approaches for text mining on micro-blog data sets. Section 3 depicts the envisioned application scenario and defines the specific problem addressed by the paper. Section 4 presents the investigation framework, discussing the dimensions of the investigation. Section 5 reports on the experimental analysis conducted by means of the investigation framework; based on its results, we perform a comparative analysis of techniques, considering both their effectiveness (in terms of accuracy) and their computational cost. By means of the cost-benefit function called Suitability, we rank the techniques and identify the most suitable solution for the application scenario depicted in Section 3. Finally, Section 6 draws the conclusions.

#### **2. Literature Review**

To the best of our knowledge, the problem of discovering topics from messages in micro-blogs has not been significantly addressed yet, especially if the goal is to introduce new functionalities in the micro-blog interface. Nevertheless, micro-blogs have become valuable for many application fields, and many techniques have been developed. In this sense, the related literature is so vast that it is impossible to be exhaustive. In the rest of this section, we propose a brief overview of techniques developed for application areas that are somehow related to our paper.

### *2.1. Sentiment Analysis and Opinion Mining*

Topic discovery is somewhat close to sentiment analysis and opinion mining. Various approaches to perform sentiment analysis and opinion mining on micro-blogs (and Twitter in particular) have been proposed. Their application context is very different with respect to the context and the goal considered in this paper. Nevertheless, it is useful to give an overview of these techniques.

Kanavos et al. [11] proposed an algorithm to exploit the emotions of Twitter users, considering a very large data set of tweets for sentiment analysis. They proposed a distributed framework to perform sentiment classification, using Apache Hadoop and Apache Spark to take advantage of big data technology. They partitioned tweets into three classes, i.e., positive, negative and neutral. The proposed framework is composed of four stages: (i) feature extraction; (ii) feature vector construction; (iii) distance computation; and (iv) sentiment classification. They utilized hashtags and emoticons as sentiment labels, while they performed classification by adopting the AkNN method (specifically designed for Map-Reduce frameworks).

The study [12] by Hassan et al. evaluated the impact of research articles on individuals, based on the sentiments expressed on them within tweets citing scientific papers. The authors defined three categories of tweets, i.e., positive, negative, and neutral. They observed that articles which were cited in positive or neutral tweets have more impact if compared to articles cited in negative tweets or not cited at all. To perform sentiment analysis, a data-set of 6,482,260 tweets linking to 1,083,535 publications was used.

Twitter data are also very important for companies, which exploit them to improve their understanding of how customers perceive the quality of their products. In [13], the authors proposed an approach to process the comments of customers about a popular food brand, using tweets from customers. A Binary Tree Classifier was used for discovering the polarity lexicon of English tweets, i.e., positive or negative. To group similar words in tweets, a K-means clustering algorithm was employed.

### *2.2. Sociological Analysis*

The area of sociological analysis is the target of many classification techniques on micro-blog messages.

The paper [14] presents a technique to understand the emotional reactions of supporters of the two Super Bowl 50 teams, i.e., the Panthers and the Broncos. The author applied a lexicon-based text mining approach. About 328,000 tweets were posted during the match by supporters, in which they expressed their emotions regarding different events during the match. For instance, supporters expressed positive emotions when their team scored; on the other hand, they expressed negative emotions when the opposing team scored. It was concluded that the results supported sociological theories of affective disposition and opponent process.

The work [15] shows how the authors used tweets to monitor the opinion of citizens regarding vaccination in Italy, i.e., in favor, not in favor and neutral. To improve the proposed system, different combinations of text representations and classification approaches were tested; the best accuracy was achieved by the combination of a bag-of-words representation, with stemmed n-grams as tokens, and Support Vector Machines (SVM) for classification. The proposed approach fetched and pre-processed tweets related to vaccines, applied SVM to classify them, and achieved an accuracy of 64.84%, which is acceptable but not very good. The investigation approach is similar to the one adopted in our research, i.e., various combinations of techniques are tested to find the most effective one.

Geetha et al. [16] aimed to analyze the state of mind expressed on Twitter through emoticons and text in tweets. They developed FPAEC—Future Prediction Architecture Based on Efficient Classification; it incorporates different classification algorithms, including Fisher's linear discriminant classifier, artificial neural networks, Support Vector Machines (SVM), Naive Bayes and balanced iterative reducing; it also incorporates a hierarchical clustering algorithm. In fact, they propose a two-step approach, where clustering follows a preliminary classification step, to aggregate classified data.

### *2.3. Politics*

Politics is an interesting application field of sentiment analysis and opinion mining on micro-blogs. Here, we report a few works.

In [17], the authors proposed a framework to predict the popularity of political parties in Pakistan's 2013 general election by finding the sentiments of Twitter users. The proposed framework is based on the following steps: (1) collection of tweets; (2) pre-processing of tweets; (3) manual annotation of the corpus. Then, to perform sentiment classification, supervised machine learning techniques such as Naive Bayes (NB), k-Nearest Neighbors (kNN), Support Vector Machines (SVM) and Naive Bayes Multinomial (NBMN) were used to categorize the tweets into the predefined labels.

In [18], the authors utilized tweets to reveal the views of the leaders of two democratic parties in India. The tweet data set was collected from public Twitter accounts, and Opinion Lexicon [19] was used to compute the number of positive, negative and neutral tweets. They proposed a "Twitter Sentiment Analysis" framework which, after pre-processing the data set crawled from Twitter, applies the opinion lexicon to classify tweets into three classes, i.e., positive, negative and neutral, for the evaluation of user sentiments.

To discover the sentiments of Twitter users, with the aim of exploring their opinions regarding political activities during election days, the authors of [20] proposed a methodology and compared the performance of three sentiment lexicons, i.e., W-WSD, SentiWordNet and TextBlob, and two well-known machine learning classifiers, i.e., Support Vector Machines (SVM) and Naive Bayes. They achieved better classification results with the W-WSD sentiment lexicon.

In [10], the authors utilized tweets to predict the sentiment about the Islamic State of Iraq and Syria (ISIS); opinions are organized by geographical location. To perform the experimental evaluation, they collected tweets over a period of three days and used Jeffrey Breen's algorithm together with data mining algorithms such as Support Vector Machines, Random Forest, Bagging, Decision Trees and Maximum Entropy to classify tweets related to ISIS.

The paper [8] presents a study where the authors exploit tweets to find political friends. They named their approach "Politic Ally"; it identifies friends having the same political interests.

### *2.4. Phishing and Spamming*

Aspects related to phishing and spamming can be addressed by analyzing micro-blogs as well, and are close to the problem of topic discovery.

The work [21] proposed an effective security alert mechanism to counter phishing attacks targeting users of social networks such as Twitter, Facebook and so on. The proposed methodology is based on a supervised machine learning technique. Eleven critical features of messages were identified, including: URL length, SSL connection, Hexadecimal, Alexa rank, Age of domain-Year, Equal Digit in host, Host length, Path length, Registrar and Number of dots in host name. Based on these features, messages were classified, to build a classification model able to identify phishing.

Similarly, to deal with spam content shared on Twitter by spammers, Washha et al. [22] introduced a framework called Spam Drift, which combined various classification algorithms, such as Random Forest, Support Vector Machines (SVM) and J48 [23]. In short, they developed an unsupervised framework that dynamically retrains the classifiers used during the on-line classification of new tweets to detect spam.

### *2.5. Frameworks for Topic Discovery (Interest Mining)*

As far as topic discovery (or user interest mining) is concerned, the work [24] proposed a framework for "Tweets Classification, Hashtags Suggestion and Tweet Linking". The framework performs seven activities: (i) data-set selection; (ii) pre-processing of data-set; (iii) separation of hashtags; (iv) finding relevant domain of tweets; (v) suggestion of possible interesting hashtags; (vi) indexing of tweets; (vii) linking of tweets. Thus, topics are represented by hashtags, that are suggested to users. With respect to our approach, discovered topics are very fine grained (at the level of hashtags), because the idea is to suggest hashtags to follow, not users.

In a similar study [25], to detect user interests by automatically categorizing data collected from Twitter and Reddit, the authors proposed a methodology comprising two steps: (i) a multi-label text classification model using Word2vec [26], a predictive model; and (ii) topic detection using Latent Dirichlet Allocation (LDA) [27], a statistical topic detection model based on counting word frequencies in a set of documents. A pool of 42,100 documents collected from Reddit and manually labeled was used to train the model; then, a pool of 1,573,000 tweets (posted by 1573 users) was classified. This work is interesting because it uses Reddit to build the classification model used to classify unlabeled tweets from Twitter. However, the scenario is quite different from that of our paper: in fact, we propose that users wishing to be influencers voluntarily label their posts, with the goal of being recognized as influencers.

The work [28] presents a web-based application to classify tweets into predefined categories of interest, related to health, music, sport and technology. The system performs various activities: first, tweets are fetched from Twitter and pre-processed; second, feature selection from texts is performed; finally, the machine learning algorithm is applied. Although it is an interesting system from a general point of view, it is designed to analyze messages from outside the micro-blog. In contrast, our goal is to find the best technique for discovering topics within the micro-blog application.

So, we can say that our envisioned application scenario is quite novel; furthermore, the specific goal of the investigation framework presented in this paper is not to be the end-user solution, but to be a tool for discovering the technique most suitable for execution within the micro-blog application to discover topics.

### *2.6. Recommendation Techniques*

Recommendation techniques have been proposed in the social network world by a multitude of papers. They are so many that it is impossible to report them all. Hereafter, we report those that we consider representative of most recent developments.

Reference [29] proposed a Recommendation System for Podcasts (RSPOD). The system recommends podcasts, i.e., audio programs, to listen to. It exploits the intimacy between social network users, i.e., how well they virtually communicate with each other. RSPOD works by (i) crawling podcast information, (ii) extracting data from social network services and (iii) applying a recommendation module for podcasts.

To predict users' ratings for several items, [30] considers social trust and user influence. In fact, it is argued that social trust and user influence can play a vital role in overcoming the negative impact on recommendation quality caused by data sparsity. The phenomenon of social trust is based on the sociological theory called "Six Degrees of Separation" [31]: the authors proposed a framework that jointly adopts user similarity, trust and influence, balancing the preferences of users, the trust between them and the ratings from influential users to recommend shops, hotels and other services. The proposed framework was applied to a data set collected from dianping.com, a Chinese platform that allows users to rate the aforementioned services.

According to Chen et al. [32], previous recommendation systems mainly focus on recommendations based on users' preferences and overlook the significance of users' attention. The influence of trust relations dwells more on users' attention than on users' preferences. Therefore, an item of interest to a user can be skipped if it does not get her/his attention. To counter this, they proposed a probabilistic model called Hierarchical Trust-Based Poisson Factorization, which utilizes both users' attention and preferences for social recommendation of movies, music, software, television shows and so on.

Similarly, [33] aimed at accurately predicting users' preferences and recommending relevant products on social networks by integrating interactions, trust relationships and product popularity. The key focus of the proposed model is on analyzing users' interaction behavior to infer users' latent interaction relationships, based on product ratings and comments. Moreover, product popularity is considered as well, to help support decision making for purchasing products.

Emphasizing the importance of social interaction in recommendation systems, [34] presented an approach based on mapping weighted social interactions to represent interactions among users of a social network, including historical information about users' behavior. This information is further mined by an algorithm called Complete Path Mining, which helps find social neighbors possessing tastes similar to those of the target user. To predict the final ratings of unrated items (such as software, music, movies and so on), the proposed model uses the similar social tendencies of the users on complete paths.

To summarize, the reader can see that recommendation techniques are designed to recommend single items (such as posts, podcasts or products) to users, based on the existing relationships among users. Li et al. [35] address the same general problem that we envision in our application scenario, i.e., recommending users to follow: they propose a framework to recommend the 50 users most similar to a specific user; they jointly exploit user features (such as ID, gender, region, job, education and so on) and user relationships. In contrast, in our envisioned scenario, we propose a different approach, i.e., recommending other users to establish a relationship with (e.g., to follow) on the basis of the topics their posts talk about. To the best of our knowledge, this problem has not been addressed yet in the literature.

### **3. Scenario and Problem Statement**

In this section, we illustrate the application scenario we are considering, in order to define the problem we address in the rest of the paper.

Suppose a user of a micro-blog platform wants to look for other users to follow, in order to receive their posts. How can she/he find them? Currently, both micro-blog apps and web sites provide a functionality for keyword-based search of users. So, the activity a user has to perform to find interesting users to follow, depicted in the right-hand side of Figure 1 (the block titled *Current Scenario*), can be summarized as follows.


Such a process is quite tedious, so user *u* is likely to miss interesting users to follow.

In contrast, we envision a novel functionality for micro-blog apps and web sites: suggesting users based on similar interests. Let us clarify our vision:

• User *u* starts posting some messages, possibly re-posting messages received from followed users.


Clearly, it is necessary to devise a technique able to learn about user interests. This must necessarily be a multi-label classification technique that, based on the analysis of features extracted from posted messages, builds a model of user interests.

So, the application scenario we envision, which is illustrated in the left-hand side of Figure 1 (the block titled *Proposed Scenario*), can be described as follows.

**Figure 1.** Application scenario.


messages. Once the most frequent topics in unlabeled messages posted by user *u* are collected, the application suggests the list *S* = {*s*1, ... , *sm*} of users possibly posting messages concerning the same topics of interest for *u*.

• User *u* can inspect the profiles of users in *S* and choose the ones to follow, if any.

To avoid misunderstandings, we clearly state that we do not consider two different types of users, i.e., influencers and regular users: all users are equal. However, if a user wishes to be recognized as an influencer, she/he can succeed more easily if the micro-blog platform provides a tool that helps achieve this aim. In fact, the basic condition for a user to be considered an influencer is that the number of her/his followers is significantly high; thus, a tool that recommends potentially interesting users is the solution. Such a tool could complement classical text-based search: in fact, we envision that the micro-blog platform is pro-active in suggesting users; furthermore, text search could be too fine grained to be successful. In other words, we explore the possibility of improving the service provided by micro-blog platforms to users, both those who wish to become influencers and those who wish to find possibly interesting and emerging influencers.

Clearly, the basic brick for developing the envisioned functionalities is the ability to assign the proper topic to unlabeled messages. The main goal of this paper is to investigate whether there exists a classification technique suitable for this task, both in terms of effectiveness and in terms of efficiency. The specific problem that must be addressed by the desired technique is defined as follows.

**Problem 1.** *Consider a set LP* = {*lp*1, ... , *lpn*} *of labeled posts; each lpi* = (*mti*, *ati*) *denotes a labeled post, where mti is the message text and ati is the assigned topic.*

*Consider a second set UM* = {*umt*1, ... , *umtm*} *of unlabeled messages umtj. Based on the set of labeled posts LP, a classification model C*(*umt*) *must be built, such that given a message text umtj* ∈ *UM, C*(*umtj*) = *tpj, i.e., the classification model C provides the topic tpj of the umtj message.*

In the rest of the paper, we address Problem 1, looking for the text classification technique that provides the best compromise between accuracy (as far as topic detection is concerned) and efficiency. In fact, if we are able to demonstrate that there exists a technique suitable to solve Problem 1, the way is open to further investigate how to rank the influencers to suggest to users.

### **4. The Investigation Framework**

In this section, we introduce the framework we built to investigate how to discover the topics messages talk about, as stated in Problem 1. First of all, we discuss the dimensions of investigation we considered (Section 4.1); then, we present the technical aspects of the framework in detail.

### *4.1. Dimensions of the Investigation*

Problem 1 is a multi-label text classification problem. Thus, through the investigation framework, two dimensions must be investigated.


Figure 2 graphically reports the dimensions of the investigation: the reader can see that the two dimensions are orthogonal. Thus, the goal of the investigation framework is to test all combinations, in order to find the best one to solve Problem 1. Hereafter, we discuss each dimension separately.

**Figure 2.** Dimensions of the investigation.

### 4.1.1. Feature Extraction

In order to apply the classification techniques, we need to extract features from texts, so as to obtain a different representation of them. We decided to adopt the *n-gram model*, which is widely used in text classification.

Hereafter, we shortly introduce the four basic n-gram patterns we adopted in our investigation.

• **Uni-gram patterns.** In our model, a *uni-gram* is a single word (or token) that is present in the text. Uni-gram patterns are singleton patterns, i.e., a single word is a pattern itself (i.e., n-grams with *n* = 1).

Uni-gram patterns do not consider the relative position of words in the text.


With these premises, we can represent a document *d* (a message text, in our context) as a vector of terms, i.e., *d*[*j*] is the *j*-th term in the document. When we consider n-grams, the document is represented as a vector of n-grams, i.e., *d*[*j*] is the n-gram whose first word is in position *j* in the original document (of course, if *n* = 1, the vector of uni-grams and the vector of words coincide).

Table 1 reports four different ways of representing a sample document, based on uni-grams, bi-grams, tri-grams and quad-grams, by reporting the different vectors that represent the same document. For example, if we consider the case *n* = 3 in Table 1, *d* contains only two items, i.e., *d*[1] and *d*[2].

**Table 1.** An example of n-gram patterns for the string "this is a sentence".

| *n* | Pattern | Vector Representation of *d* |
| --- | --- | --- |
| 1 | uni-grams | "this", "is", "a", "sentence" |
| 2 | bi-grams | "this is", "is a", "a sentence" |
| 3 | tri-grams | "this is a", "is a sentence" |
| 4 | quad-grams | "this is a sentence" |

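As a small illustrative sketch (our own code, not part of the framework), the n-gram vectors of Table 1 can be produced as follows:

```python
def ngrams(tokens, n):
    """Return the n-grams of a token sequence, each as a tuple of n consecutive words."""
    return [tuple(tokens[j:j + n]) for j in range(len(tokens) - n + 1)]

tokens = "this is a sentence".split()
print(ngrams(tokens, 1))  # four uni-grams
print(ngrams(tokens, 3))  # two tri-grams: d[1] and d[2] in the notation above
```

For *n* = 3 the function returns exactly the two items *d*[1] and *d*[2] mentioned in the text.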
Starting from the methodology proposed in [36], we also consider combined features, i.e., features obtained by combining basic features (i.e., uni-grams, bi-grams, tri-grams and quad-grams).

Given *z* sets of basic features *BFi*, with 1 ≤ *i* ≤ *z*, a set *CF* of complex features is obtained by means of the Cartesian product of the sets *BFi*, i.e., *CF* = *BF*1 × *BF*2 × ··· × *BFz*. Thus, a feature in *CF* is a tuple of *z* basic features (n-grams).

As an example, consider Table 1. A feature obtained by combining a uni-gram and a bi-gram is the tuple ("this", "a sentence").
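The Cartesian product of basic feature sets can be sketched with the standard library (an illustrative snippet; `combined_features` is our name, not the framework's):

```python
from itertools import product

def combined_features(*basic_sets):
    """CF = BF1 x BF2 x ... x BFz: each complex feature is a tuple of basic n-grams."""
    return set(product(*basic_sets))

uni = {("this",), ("is",), ("a",), ("sentence",)}
bi = {("this", "is"), ("is", "a"), ("a", "sentence")}
cf = combined_features(uni, bi)
print(len(cf))  # 4 x 3 = 12 complex features
print((("this",), ("a", "sentence")) in cf)  # the tuple <"this", "a sentence">
```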

In our framework, we considered the four basic feature patterns and five complex feature patterns. In Table 2, we report them and the corresponding abbreviation we will use throughout the paper.

**Table 2.** Basic and complex feature patterns computed by the investigation framework.


**Feature weight.** To help the construction of the classification model, features are weighted. In text classification, the most frequently adopted metric is *Term Frequency-Inverse Document Frequency* (TF-IDF) [37]. It is a numerical score which denotes the importance of a term in a collection of documents. TF-IDF is the combination of two scores, called *Term Frequency* and *Inverse Document Frequency*. The comparative analysis in [38] demonstrated that TF-IDF significantly improves the effectiveness of classifiers.

The score balances the importance of a term for a given document with respect to its capability of characterizing a small number of documents. The rationale is that if a term is highly frequent in the collection, it does not characterize a subset of documents; thus, terms that appear in many documents cannot be considered relevant features for any document.

Consider a set *D* = {*d*1, ... , *dn*} of documents, where each document *di* ∈ *D* is a vector of terms (in the broadest sense, i.e., terms can be either n-grams or tuples of n-grams). The Term Frequency *Tf*(*t*, *d*) of a term *t* in a document *d* is the number of times *t* appears within *d* divided by the total number of terms in *d* (see [39]). It is defined as:

$$Tf(t,d) = \frac{|\{j|d[j] = t\}|}{|d|}.$$

The Inverse Document Frequency *Idf*(*t*, *D*) of a term *t* in the collection (of documents) *D* measures the capability of *t* to denote a small set of documents in *D*: the lower the number of documents in which *t* appears, the greater its *Idf* score [39]. It is defined as:

$$Idf(t, D) = \log_e \frac{|D|}{|\{d \in D \mid (\exists j \mid d[j] = t)\}|}.$$

By combining *Tf* and *Idf*, we obtain the overall *TfIdf*(*t*, *d*, *D*) score of a term *t* within a document *d* belonging to a collection of documents *D*, as follows:

$$TfIdf(t,d,D) = Tf(t,d) \times Idf(t,D).$$
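The three formulas above translate directly into code. In this toy example (the document collection is invented for illustration), documents are term vectors as in the definitions:

```python
import math

def tf(t, d):
    """Term frequency: occurrences of t in d over the total number of terms in d."""
    return d.count(t) / len(d)

def idf(t, D):
    """Inverse document frequency, using the natural logarithm as in the text."""
    return math.log(len(D) / sum(1 for d in D if t in d))

def tfidf(t, d, D):
    return tf(t, d) * idf(t, D)

D = [["apple", "banana", "apple"], ["banana", "cherry"], ["cherry", "cherry"]]
# "apple" occurs twice in a 3-term document and appears in 1 of 3 documents:
print(tfidf("apple", D[0], D))  # (2/3) * ln(3)
```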

In our model, terms are either basic n-grams (basic feature patterns denoted as **U**, **B**, **T** and **Q**) or combined n-grams, such as **U,B** and so on (see Table 2): thus, we apply the *TfIdf* metric to rank these features. However, we do not compute TF-IDF on a document basis, but on a class basis: the frequency of a term is the number of documents in the class that contain the term; the inverse document frequency should properly be called *Inverse Class Frequency*, because we count the number of classes that contain the term out of the total number of classes. Formula 1 formally defines the weight.

$$Weight(t, c, C) = \frac{|\{d \mid (d.class = c \land \exists j \mid d[j] = t)\}|}{|\{d \mid d.class = c\}|} \times \log_e \frac{|C|}{|\{c_i \in C \mid (\exists d \mid d.class = c_i \land \exists j \mid d[j] = t)\}|} \tag{1}$$

In other words, the weight of a term *t* in a class *c* ∈ *C* is the frequency of *t* within the documents that belong to that class, multiplied by the inverse frequency of *t* among all classes in *C*. Notice that *d*.*class* denotes the class which document *d* belongs to.

In Section 5.2, we perform experiments with the full set of features and with the strongest 80%, 65% and 50% of features, on the basis of the function *Weight*(*t*, *c*, *C*) defined in Formula 1.
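Formula 1 can be transcribed directly (an illustrative sketch with a toy corpus; `class_weight` and the sample documents are ours, not part of the framework):

```python
import math

def class_weight(t, c, docs):
    """Formula 1: class frequency of term t in class c times its inverse class frequency.
    docs is a list of (term_vector, class_label) pairs."""
    in_class = [d for d, cl in docs if cl == c]
    class_freq = sum(1 for d in in_class if t in d) / len(in_class)
    n_classes = len({cl for _, cl in docs})
    classes_with_t = len({cl for d, cl in docs if t in d})
    return class_freq * math.log(n_classes / classes_with_t)

docs = [(["goal", "match"], "sports"), (["goal", "team"], "sports"),
        (["vote", "party"], "politics"), (["party", "budget"], "politics")]
# "goal" appears in every sports document and in no other class:
print(class_weight("goal", "sports", docs))  # 1.0 * ln(2)
```

Feature selection then amounts to sorting features of a class by this weight and keeping the strongest 80%, 65% or 50%.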

### 4.1.2. Classification Techniques

The second dimension of investigation is finding the classification technique that proves most suitable for the application scenario. Recall from Problem 1 that the classifier has to discover the topic *tpj* that an unlabeled message *umtj* talks about. Hereafter, we briefly introduce the four classification techniques we considered in our investigation framework.

• **Naive Bayes (NB).** The Naive Bayes classifier [40] is a simple, fast, efficient, easy-to-implement and popular classification technique for texts: in fact, this technique is quite efficient as far as computation time is concerned; however, it performs well only when features behave as statistically independent variables.

In short, it is a probabilistic classification technique, which relies entirely on the probabilities of features: for each feature, the probability that it falls into a given class is calculated. It is widely used to address many different problems, such as predicting social events, denoting personality traits, analyzing social crime, and so on.


becomes the class of *d* [43]. For example, if *N*(*d*) contains 15 nearest neighbors to document *d*, where seven documents are labeled with the "politics" class, four documents are labeled with the "sports" class, three documents are labeled with the "weather" class and one document is labeled with the "health" class, then *d* is labeled with the "politics" class.
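The majority vote in the example above can be sketched in a few lines (an illustrative snippet, not the framework's kNN implementation; computing the actual nearest neighbors is omitted):

```python
from collections import Counter

def knn_class(neighbor_labels):
    """Assign the class held by the majority of the k nearest neighbors."""
    return Counter(neighbor_labels).most_common(1)[0][0]

# the example from the text: 7 politics, 4 sports, 3 weather, 1 health neighbors
labels = ["politics"] * 7 + ["sports"] * 4 + ["weather"] * 3 + ["health"]
print(knn_class(labels))  # politics
```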

• **Random Forest (RF).** Random Forest is a well-known supervised learning method for classification devised by Ho [44]. It is an evolution of classical tree-based classifiers. The name of the technique is motivated by the fact that, during the training phase, many classification trees are generated: we can say that the classification model is a *forest* of classification trees. During the test phase, all the classification trees are independently used to classify the unclassified case: the class assigned by the majority of trees is chosen as the class label of the unclassified case.

This technique is very general and widely used in many application contexts, not only for text classification [45].
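As a purely illustrative sketch of how the four techniques can be instantiated, the following uses scikit-learn (the library adopted by the framework, as stated in Section 4.2); the toy corpus, the parameter values (e.g., k = 3 neighbors, 100 trees) and the variable names are our own assumptions, not the authors' configuration:

```python
# Sketch: instantiating the four compared classifiers with scikit-learn
# on a tiny invented corpus. Hyper-parameters are illustrative only.
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.naive_bayes import MultinomialNB
from sklearn.neighbors import KNeighborsClassifier
from sklearn.svm import LinearSVC
from sklearn.ensemble import RandomForestClassifier

messages = [
    "oil prices rise on wall street",
    "the team won the final match",
    "heavy rain expected tomorrow",
    "stocks fall as companies report losses",
    "the striker scored two goals",
    "sunny skies and mild temperatures",
]
topics = ["business", "sports", "weather",
          "business", "sports", "weather"]

X = TfidfVectorizer().fit_transform(messages)  # term-weight features

classifiers = {
    "NB": MultinomialNB(),
    "kNN": KNeighborsClassifier(n_neighbors=3),
    "SVM": LinearSVC(),
    "RF": RandomForestClassifier(n_estimators=100, random_state=0),
}
for name, clf in classifiers.items():
    clf.fit(X, topics)                         # train each model
    print(name, clf.score(X, topics))          # training-set accuracy
```

In the actual framework, the feature representation is the n-gram-based one described in Section 4.2.4 rather than plain TF-IDF over uni-grams.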

### *4.2. Framework Overview*

The investigation framework is composed of many modules. First of all, we give a high-level overview of them, describing the task performed by each module.

• **Module M1-Pre-processor.** The module named *Pre-processor* performs many pre-processing activities on the data set, i.e., the set of labeled messages that constitutes the input data set for the investigation framework. Specifically, it performs tokenization, stop-word removal, special symbol elimination, and stemming (typical pre-processing techniques adopted in information retrieval). In particular, stemming is important to reduce the dimensionality of features: natural languages provide many different forms for the same word (for instance, singular and plural); stemming reduces words to their root form.

The result of the *Pre-processor* module is the corpus of messages, where each document is represented as an array of stems.


Figure 3 reports the organization of the framework, illustrating data flows between and inside modules. They are discussed in detail hereafter.

The framework is implemented in the Python programming language, exploiting the nltk library for pre-processing and the sklearn library for feature extraction, generation of n-gram combinations, and training and testing of classifiers.

### 4.2.1. External Module EM1-Message Collector

This module is responsible for gathering messages from the source micro-blog (in our case, Twitter) and for supporting researchers in labeling messages with class labels denoting topics.

Since our investigation framework is designed to be independent of the specific source micro-blog, we decided to consider it as an external module that can be replaced with a different one, suitable to gather data from a different micro-blog.

**Figure 3.** The Investigation Framework.

4.2.2. Module M1-Pre-Processor in Details

When users write messages, they include punctuation, single characters and stop-words that are not useful for topic classification (and even decrease the accuracy of classification). So, before features are extracted, messages must be pre-processed in order to clean them from noise. Specifically, module *Pre-processor* performs text tokenization, special symbol removal, stop-word filtering and stemming. Hereafter, we describe these activities in more detail.

Let us denote the input data set as *T* = {*lp*<sub>1</sub>, *lp*<sub>2</sub>, ...}, where each *lp<sub>i</sub>* is a labeled message such that *lp<sub>i</sub>* = (*mt<sub>i</sub>*, *at<sub>i</sub>*), i.e., *mt<sub>i</sub>* is the message text and *at<sub>i</sub>* is the class label (or topic) associated with the message.

For each message *lp* ∈ *T*, the module performs the following processing steps on its message text *lp*.*mt*, in order to generate the set *Tp* of pre-processed messages.

1. The *lp*.*mt* message is tokenized, in order to represent it as a vector *d*<sup>1</sup> of terms, where a term is a token found within the message text.


$$\mathit{RemoveSW}(d, pos) = \begin{cases} \mathit{RemoveSW}(\mathit{remove}(d, pos), pos) & \text{if } (1 \le pos \le |d|) \land (|d[pos]| \le 2 \lor d[pos] \in S_L) \\ \mathit{RemoveSW}(d, pos + 1) & \text{if } (1 \le pos \le |d|) \land \neg(|d[pos]| \le 2 \lor d[pos] \in S_L) \\ d & \text{if } \neg(1 \le pos \le |d|) \end{cases}$$

where *d* is the message represented as a vector of terms, |*d*| is the size of the vector, *pos* denotes a position index, *d*[*pos*] denotes the term (string within vector *d* in position *pos*), and |*d*[*pos*]| denotes the length of the term (string) in position *pos* in vector *d*. Furthermore, *S<sub>L</sub>* is the list of stop-words, while function *remove*(*d*, *pos*) removes the item in position *pos* from vector *d*. The function is defined by the following formula.

$$remove(d,pos) = \begin{cases} \, d[1,(pos-1)] \bullet \, d[(pos+1),|d|] & \text{if } 1 < pos < |d|\\ \, d[2,|d|] & \text{if } pos = 1\\ \, d[1,(pos-1)] & \text{if } pos = |d| \end{cases}$$

where with *d*[*i*, *j*] we denote the sub-vector with items from position *i* to position *j* and the • denotes an operator that concatenates two vectors.

The *d*<sup>3</sup> vector representing the message without stop-words is obtained by calling the *RemoveSW* function as *d*<sup>3</sup> = *RemoveSW*(*d*<sup>2</sup>, 1).

4. After stop-word removal, stemming is performed on vector *d*<sup>3</sup> by applying the *Porter stemming algorithm* [47]; we obtain the final *d*<sup>4</sup> vector, i.e., the vector of terms that represents the *lp*.*mt* message text after pre-processing. The *d*<sup>4</sup> vector is paired with the class label *lp*.*at*, obtaining the pair *lmp* = (*d*<sup>4</sup>, *lp*.*at*) that is inserted into *Tp*, the set of pre-processed messages.

*Tp* is the final output of this module: the source data set *T* has been transformed into *Tp*, where instead of strings, message texts are represented by vectors of terms.
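The pre-processing steps above can be sketched as follows. The `remove`/`remove_sw` functions are direct Python renditions of the recursive definitions; the regular-expression tokenizer, the stop-word list and the crude suffix-stripping stemmer are simplified stand-ins we substitute for the nltk tokenizer, stop-word list and Porter stemmer used by the framework:

```python
# Sketch of module M1 on one labeled message. Tokenizer, stop-word list
# and stemmer are simplified stand-ins, not the framework's nltk tools.
import re

STOP_WORDS = {"the", "are", "and", "about", "on"}  # illustrative list

def remove(d, pos):
    """Drop the term at 1-based position pos from vector d."""
    return d[:pos - 1] + d[pos:]

def remove_sw(d, pos):
    """Recursively delete stop-words and terms of length <= 2."""
    if not (1 <= pos <= len(d)):
        return d
    if len(d[pos - 1]) <= 2 or d[pos - 1] in STOP_WORDS:
        return remove_sw(remove(d, pos), pos)
    return remove_sw(d, pos + 1)

def toy_stem(term):
    """Crude suffix stripper standing in for the Porter stemmer."""
    for suffix in ("ing", "es", "s"):
        if term.endswith(suffix) and len(term) > len(suffix) + 2:
            return term[: -len(suffix)]
    return term

lp_mt, lp_at = "Oil prices are rising on Wall Street!", "business"
d1 = re.findall(r"[a-z]+", lp_mt.lower())   # step 1: tokenize, drop symbols
d3 = remove_sw(d1, 1)                       # stop-word / short-term removal
d4 = [toy_stem(t) for t in d3]              # step 4: stemming
lmp = (d4, lp_at)                           # pair inserted into Tp
# d4 == ["oil", "pric", "ris", "wall", "street"]
```

Note that the toy stemmer over-truncates (e.g., "prices" becomes "pric"); the Porter algorithm used by the framework is considerably more careful.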

#### 4.2.3. External Module EM2-Data Splitter

The pre-processed data set *Tp* is now split into training set *TR* and test set *TE*. The training set becomes the input of module *M2-Feature Extractor*, while the test set *TE* will be used by module *M3-Multi-classifier* for computing the accuracy. Notice that *TE* contains labeled messages: this is necessary for validating classification and computing accuracy, since true positives, false positives, true negatives and false negatives are obtained simply by comparing the topic assigned by the classifier with the label originally associated with the message.

This is an external module of the investigation framework, unlike modules M1, M2 and M3, which are its core modules. This choice is motivated by the need for flexibility: different techniques for splitting could be used; this way, the investigation framework is parametric with respect to data splitting.
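One such splitting technique can be sketched as a stratified split in pure Python; the function name and the toy data are illustrative, while the per-class counts echo the experimental setting of Section 5 (3500 training and 1500 test messages per class):

```python
# Sketch of an EM2-style data splitter: a stratified split that keeps a
# fixed number of messages per class in TR and TE. Illustrative only.
import random

def split(tp, n_train_per_class, n_test_per_class, seed=0):
    rng = random.Random(seed)
    by_class = {}
    for lmp in tp:                      # lmp = (term_vector, class_label)
        by_class.setdefault(lmp[1], []).append(lmp)
    tr, te = [], []
    for msgs in by_class.values():
        rng.shuffle(msgs)               # randomize within each class
        tr.extend(msgs[:n_train_per_class])
        te.extend(msgs[n_train_per_class:n_train_per_class + n_test_per_class])
    return tr, te

tp = [((f"msg{i}",), c) for c in ("business", "sports") for i in range(5)]
tr, te = split(tp, 3, 2)
assert len(tr) == 6 and len(te) == 4    # 3 + 2 per class, 2 classes
```

sklearn's `train_test_split(..., stratify=labels)` would be an equivalent off-the-shelf alternative.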

#### 4.2.4. Module M2-Feature Extractor in Details

Module *M2-Feature Extractor* receives the training set *TR* extracted from the overall set of pre-processed messages *Tp*. Its goal is to give different representations of each labeled message *lmp* ∈ *TR*, based on the basic and combined feature patterns reported in Table 2.

The module generates nine different versions of the training set *TR*, one for each feature pattern, denoted as *TR*<sup>U</sup>, *TR*<sup>B</sup>, *TR*<sup>T</sup>, *TR*<sup>Q</sup>, *TR*<sup>U,B</sup>, *TR*<sup>B,T</sup>, *TR*<sup>T,Q</sup>, *TR*<sup>U,B,T</sup> and *TR*<sup>U,B,T,Q</sup>. These are intermediate results, necessary to generate the actual output of the module, i.e., a pool of feature vector sets *FV*<sup>fp</sup> (where *fp* denotes the feature pattern, as in Table 2): each *FV*<sup>fp</sup> contains a feature vector for each topical class. Let us start by describing the generation of *TR*<sup>fp</sup>.


Similarly, the training sets *TR*<sup>T</sup> and *TR*<sup>Q</sup> contain descriptions *lm*<sup>T</sup> and *lm*<sup>Q</sup> of messages, whose vectors *lm*<sup>T</sup>.*d* and *lm*<sup>Q</sup>.*d* are vectors of tri-grams and quad-grams, respectively.

• Training sets based on combined feature patterns are derived from training sets based on basic patterns.

Given a message *lmp*, its representations based on combined feature patterns are obtained as follows:


Once the training sets are prepared, for each of them (generically denoted as *TR*<sup>fp</sup>) the module performs the following activities.

1. For each message *lm<sub>i</sub><sup>fp</sup>* ∈ *TR*<sup>fp</sup>, a set of terms *s<sub>i</sub><sup>fp</sup>* is derived from the vector of terms *lm<sub>i</sub><sup>fp</sup>*.*d*:

$$s_i^{fp} = \{\, t \mid \exists\, pos\ (1 \le pos \le |lm_i^{fp}.d| \land lm_i^{fp}.d[pos] = t) \,\}$$

so that duplicate occurrences of a term in *lm<sub>i</sub><sup>fp</sup>*.*d* become a single occurrence in *s<sub>i</sub><sup>fp</sup>*.

2. The frequency matrix *Freq*<sup>fp</sup>[*t*, *c*] is built, where *t* is a term and *c* is a class (or topic). To obtain the *Freq*<sup>fp</sup> matrix, the module first builds *tc<sub>i</sub><sup>fp</sup>*, a set of (term, class, message identifier) triples (*t*, *c*, *i*), obtained as

$$tc_i^{fp} = s_i^{fp} \times \{lm_i^{fp}.at\} \times \{i\}$$

that is, by performing the Cartesian product among the set *s<sub>i</sub><sup>fp</sup>* of terms in the message, the singleton set of the class label (topic) *lm<sub>i</sub><sup>fp</sup>*.*at* associated with the message, and the singleton set containing *i* (i.e., the identifier of the message). All the *tc<sub>i</sub><sup>fp</sup>* sets are united into the *TC*<sup>fp</sup> set:

$$TC^{fp} = \bigcup_{lm_i^{fp} \in TR^{fp}} tc_i^{fp}$$

Each item of the *Freq*<sup>fp</sup> matrix is then computed as follows:

$$\mathit{Freq}^{fp}[t, c] = |\{\, (t_j, c_j, i_j) \in TC^{fp} \mid t_j = t \land c_j = c \,\}|$$

where we count, for each term *t* and each class *c*, the number of distinct documents associated with class *c* in which term *t* occurs (the third element *i<sub>j</sub>* in the triples is necessary to distinguish term occurrences coming from different messages).

3. For each term *t* in each class *c*, the module computes the *Weight*(*t*, *c*, *C*) score (defined in Formula 1), where *C* is the set of all class labels. We denote the weight for the feature pattern *fp* as *w*<sup>fp</sup>; it is defined as:

$$w^{fp}(t, c) = \mathit{Weight}(t, c, C) = \frac{\mathit{Freq}^{fp}[t, c]}{\sum_{\forall t_j} \mathit{Freq}^{fp}[t_j, c]} \times \log_e\left(\frac{|C|}{|\{c_k \mid \mathit{Freq}^{fp}[t, c_k] > 0\}|}\right) \tag{2}$$

where *C* is the overall set of class labels. In the product on the right-hand side of the formula, the first operand is the term frequency, while the second operand is the inverse document frequency.

4. Finally, for each class *c<sub>k</sub>* ∈ *C*, the feature vector *f*<sup>fp</sup>(*c<sub>k</sub>*) for the given feature pattern is built, where each item is a pair (*t*, *w*), with *t* the term and *w* = *w*<sup>fp</sup>(*t*, *c<sub>k</sub>*) the weight. The sets *FV*<sup>fp</sup> of feature vectors *f*<sup>fp</sup>(*c<sub>k</sub>*), where *fp* denotes the feature pattern (as in Table 2), are the final output of module M2.
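Steps 1–3 of module M2 can be sketched on a toy training set as follows; this is an illustration of Formula 2 in pure Python, not the authors' sklearn-based implementation, and the data are invented:

```python
# Sketch of the M2 frequency matrix and the weight of Formula 2.
# Each training item is (set of de-duplicated terms s_i, class label).
import math
from collections import defaultdict

tr = [
    ({"oil", "price", "stock"}, "business"),
    ({"stock", "share"}, "business"),
    ({"goal", "match"}, "sports"),
]

freq = defaultdict(int)                 # Freq[t, c]: per-class document
for s_i, c in tr:                       # frequency of each term
    for t in s_i:
        freq[(t, c)] += 1

classes = {c for _, c in tr}
terms = {t for s_i, _ in tr for t in s_i}

def weight(t, c):
    """Formula 2: term frequency in class c times inverse class frequency."""
    tf = freq[(t, c)] / sum(freq[(tj, c)] for tj in terms)
    idf = math.log(len(classes) / len({ck for ck in classes if freq[(t, ck)] > 0}))
    return tf * idf

# Feature vector f(c) for one class: term -> weight pairs
fv_business = {t: weight(t, "business") for t in terms if freq[(t, "business")] > 0}
```

For example, "stock" occurs in 2 of the 5 business term occurrences and in no other class, so its weight is (2/5) · log(2/1).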

### 4.2.5. Module M3-Multi-Classifier in Details

Module *M3-Multi-classifier* performs the last step of the investigation process, i.e., it builds the classification models by training the classifiers, then exploits the classification models to label the test set. It is called *multi-classifier* because it uses all four classification techniques briefly presented in Section 4.1.2 to train a classification model and label the test set.

Let us describe the process performed by module M3 in detail. The module receives two inputs: the pool of feature vector sets *FV*<sup>U</sup>, *FV*<sup>B</sup>, *FV*<sup>T</sup>, *FV*<sup>Q</sup>, *FV*<sup>U,B</sup>, *FV*<sup>B,T</sup>, *FV*<sup>T,Q</sup>, *FV*<sup>U,B,T</sup> and *FV*<sup>U,B,T,Q</sup>, generated by module M2, and the test set *TE*, generated by the external module EM2. For each of the feature vector sets and for each of the classification techniques, the module performs the following activities.


At this point, the module performs the accuracy evaluation, i.e., it evaluates the accuracy of classification for all the classified test sets, in order to produce a final report, which is the outcome of the investigation framework.

### **5. Experimental Evaluation**

The investigation framework was run on a data set specifically collected. In Section 5.1, we present both the way we collected and prepared the data set and the metrics we adopted to evaluate the classification results. In Section 5.2, we present the experiments and discuss the results as far as the effectiveness of classification is concerned, while in Section 5.3 we present the sensitivity analysis of classifiers. Then, Section 5.4 considers execution times and introduces the metric called *Suitability*.

#### *5.1. Data Preparation and Evaluation Metrics*

To perform the experimental evaluation through the proposed investigation framework, we performed data collection and labeling. Data collection is the process of collecting messages (from Twitter) that are relevant to the problem domain. It is a crucial step, because it strongly determines the results obtained by classifiers. Messages were collected from Twitter by using Tweepy API [48]; 133,306 messages were collected from different accounts.

The next step was to manually label the messages with a pool of predefined topics. In this process, we involved five volunteer Masters students at the University of Sialkot (Pakistan) to label messages. Each message was labeled by two different students who worked separately: if two different labels were assigned to the same message, the message was discarded from the data set. This way, only messages labeled with the same class by two different students were kept: messages that did not clearly talk about one of the selected topics were not considered.

Hereafter, we list the topics considered as classes and the criteria adopted for labeling messages with each single topic.

• *Business*: Messages talk about stocks, business activities, oil prices, Wall Street and companies' shares.


The list of topics was inspired by [49], which proposed a list of categories for classifying sensitive tweets; we did not adopt the full list proposed in [49], because some of the proposed categories did not denote topics that users would use to label messages (e.g., racism); we selected and integrated those that, presumably, users would often use. Table 3 provides a sample message for each one of the topical classes.


**Table 3.** Chosen topics with example messages.

In order to have a homogeneous distribution among classes, the training set *TR* contained 3500 messages for each class, while the test set *TE* contained 1500 messages for each class. Consequently, the training set *TR* contained 24,500 messages and the test set *TE* contained 10,500 messages; the total number of messages was 35,000, which constituted the input for the investigation framework. All messages were written in English.

Remember that messages in the test set *TE* were labeled by hand as well, in order to allow module M3 to automatically compute accuracy.

To evaluate the results, the investigation framework computed accuracy, precision, recall and F1-measure. These measures are typical metrics adopted in information retrieval. Since we operated in a context of multi-label classification, we adopted the definitions reported in [50,51].

Given a set *C* = {*c*<sub>1</sub>, *c*<sub>2</sub>, ..., *c<sub>n</sub>*} of class labels, for each class *c<sub>j</sub>* we define the following counts:


For each class *c<sub>j</sub>*, we can define the four above-mentioned metrics.


Since we are in a context of multi-label classification, we have to compute a global version of each measure. This is usually done by averaging the values computed for each class:

$$\mathit{Accuracy} = \frac{\sum_{c_j \in C} \mathit{Accuracy}_j}{|C|}, \qquad \mathit{Precision} = \frac{\sum_{c_j \in C} \mathit{Precision}_j}{|C|}, \qquad \mathit{Recall} = \frac{\sum_{c_j \in C} \mathit{Recall}_j}{|C|}, \qquad \textit{F1-measure} = \frac{\sum_{c_j \in C} \textit{F1-measure}_j}{|C|}$$
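The macro-averaging of the per-class metrics can be sketched as follows; the labels are invented and, for brevity, only precision, recall and F1 are shown:

```python
# Sketch of macro-averaged metrics: per-class precision, recall and F1
# from the confusion counts, averaged over the set of classes.
def macro_scores(true_labels, predicted_labels, classes):
    precision = recall = f1 = 0.0
    for c in classes:
        pairs = list(zip(true_labels, predicted_labels))
        tp = sum(1 for t, p in pairs if t == c and p == c)
        fp = sum(1 for t, p in pairs if t != c and p == c)
        fn = sum(1 for t, p in pairs if t == c and p != c)
        p_c = tp / (tp + fp) if tp + fp else 0.0
        r_c = tp / (tp + fn) if tp + fn else 0.0
        precision += p_c
        recall += r_c
        f1 += 2 * p_c * r_c / (p_c + r_c) if p_c + r_c else 0.0
    n = len(classes)
    return precision / n, recall / n, f1 / n

p, r, f = macro_scores(["a", "a", "b", "b"], ["a", "b", "b", "b"], ["a", "b"])
```

Here class "a" has precision 1 and recall 0.5, class "b" has precision 2/3 and recall 1, so the macro precision is 5/6 and the macro recall is 0.75.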

We are now ready to discuss the results of our investigation, based on the two dimensions discussed in Section 4.1.

#### *5.2. Experiments and Comparison of Classifiers*

Based on the dimensions of investigation discussed in Section 4.1, we performed a large number of experiments, which involved the four classification techniques presented in Section 4.1.2.

Let us start by considering the results obtained by the Naive Bayes classifier. Table 4 is organized as follows: for each basic n-gram pattern, i.e., **U**, **B**, **T** and **Q**, as well as for each combined feature pattern **U,B**, **B,T**, **T,Q**, **U,B,T** and **U,B,T,Q**, experiments were performed with the full set of features (100%) and with the most relevant 80%, 65% and 50% of features, selected on the basis of the weight defined in Formula 1.
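The selection of the most relevant 80%, 65% and 50% of features can be sketched as follows, assuming a feature vector represented as a dict from term to weight (toy values of our own, not from the paper):

```python
# Sketch: keep only the top fraction of a class's features ranked by
# their weight (Formula 1). Weights below are invented.
def top_fraction(feature_vector, fraction):
    ranked = sorted(feature_vector.items(), key=lambda tw: tw[1], reverse=True)
    keep = max(1, round(len(ranked) * fraction))   # at least one feature
    return dict(ranked[:keep])

fv = {"stock": 0.28, "oil": 0.14, "share": 0.13, "price": 0.07}
assert set(top_fraction(fv, 0.50)) == {"stock", "oil"}
```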


**Table 4.** Experimental results for Naïve Bayes classifier.


**Table 4.** *Cont.*

Similarly, Table 5 shows the results obtained by applying the kNN classification technique to the same feature patterns previously discussed; in the same way, we report the sensitivity analysis for each feature pattern. Table 6 reports the results obtained by applying the SVM classification technique, while Table 7 reports the results obtained by applying the Random-Forest classification technique.


**Table 5.** Experimental results for kNN classifier.


**Table 5.** *Cont.*

**Table 6.** Experimental results for Support Vector Machines (SVM) classifier.



**Table 7.** Experimental results for Random-Forest classifier.

Figure 4 depicts the results obtained by each classifier for all feature patterns, using the full set of features extracted from the training set. The blue line depicts the results obtained by the Naïve Bayes classifier; the red line, those obtained by the kNN classifier; the brown line, those obtained by the SVM classifier; the black line, those obtained by the Random-Forest classifier.

We can notice that the Naïve Bayes classifier (blue line) performed as the best classifier, always obtaining the highest accuracy. The SVM classifier (brown line) performed only slightly worse, with comparable results. The Random Forest classifier also showed comparable accuracy, although slightly lower than the Naïve Bayes and SVM classifiers.

In contrast, the inability of the kNN classifier to exploit most feature patterns was evident. In detail, we noticed that for the **U**, **U,B**, **U,B,T** and **U,B,T,Q** feature patterns, the kNN classifier obtained results that were comparable with the other classifiers. Instead, for feature patterns that did not include uni-grams, the kNN classifier obtained very poor results.

Nevertheless, notice that the other three tested classification techniques also suffered from the absence of uni-grams in the feature sets, even though they behaved better than the kNN classifier.

If we focus on results obtained by each classifier for feature patterns that contain uni-grams, it clearly appears that no advantage was obtained by combining uni-grams with other features. Looking at Table 4, we see that the Naïve Bayes classifier obtained a very slight improvement; in contrast, looking at Tables 5–7, we can see a slight deterioration of accuracy, when comparing the **U** pattern with **U,B**, **U,B,T** and **U,B,T,Q** combined patterns.

**Figure 4.** Comparing accuracy of classifiers for different feature patterns with 100% features.

### *5.3. Sensitivity Analysis*

We can now consider the sensitivity analysis we performed. Recall that, apart from the full set of features, we also considered the best 80%, 65% and 50% of features, on the basis of the weight defined in Formula 1.

Figures 5–7 depict the results obtained, respectively, with 80%, 65% and 50% of features. In the 80% case (Figure 5), no significant variations appeared: the performances obtained by all classifiers were more or less the same. This is also confirmed by the tables, which show very small reductions of accuracy. Nevertheless, the general behavior of the four classifiers remained exactly the same as for the full set of features. Consequently, we could argue that using the full set of features for training the classifiers was unnecessary, making it possible to save time and computational power.

**Figure 5.** Comparing accuracy of classifiers for different feature patterns with 80% features.

Considering the 65% case (depicted in Figure 6) and the 50% case (depicted in Figure 7), we still observed a very slight reduction of accuracy. Only the kNN classifier behaved significantly worse with uni-gram patterns in the 50% case; nevertheless, looking at Figure 7, we notice that the combined feature patterns containing uni-grams, i.e., **U,B,T** and **U,B,T,Q**, helped the classifier to obtain good results.

**Figure 6.** Comparing accuracy of classifiers for different feature patterns with 65% features.

**Figure 7.** Comparing accuracy of classifiers for different feature patterns with 50% features.

In effect, looking at Table 5, we can see that both precision and recall strongly penalized the kNN classifier, with respect to the other competitors. This happened with all feature patterns.

### *5.4. Execution Times and Suitability Metric*

Based on accuracy, the kNN classifier was not suitable for the investigated application context, while the other classifiers showed comparable performance. However, the cost of computation is an important issue, thus we also gathered execution times both for training and testing.

We performed experiments on a PC with an Intel(R) Core(TM) i7-5600U processor with a clock frequency of 2.60 GHz, equipped with 8 GB of RAM; the operating system was Windows 10 Pro (64 bit).

Table 8 reports the execution times shown by the four classifiers on the full set of features, for the most promising feature patterns, i.e., **U**, **U,B**, **U,B,T** and **U,B,T,Q**. Specifically, we evaluated execution times during the training phase and during the test phase; notice that we also measured the execution times concerned with feature extraction, so as to understand how heavy the computation of Cartesian products of features was.

The first thing we can notice is that feature extraction was performed in a negligible time compared with the actual training performed by the classifier; even for the most complicated feature pattern (i.e., **U,B,T,Q**), this time was negligible. Furthermore, experiments confirmed that the adopted library behaved deterministically when generating a given feature pattern, since the execution time was independent of the specific attempt.


**Table 8.** Execution times (in seconds) for the most promising feature patterns.

In contrast, looking at the execution time for model building, the reader can see that there were significant differences. Consequently, in order to choose the best classifier, the cost of computation should be considered. For this reason, we defined a cost-benefit metric in which accuracy represents the benefit and execution time represents the cost. We called this metric *Suitability* because, by means of it, we wanted to rank classifiers and find out the one most suitable for our context.

Consider a pool of experiments *E* = {*e*<sub>1</sub>, *e*<sub>2</sub>, ..., *e<sub>h</sub>*}, where for an experiment *e<sub>i</sub>* we refer to its accuracy as *e<sub>i</sub>*.*Accuracy*, to its training execution time as *e<sub>i</sub>*.*trtime* (the execution time shown during the training phase) and to its testing execution time as *e<sub>i</sub>*.*tetime* (the execution time shown during the test phase). The *Training Suitability* of an experiment is defined as

$$\mathit{TrainingSuitability}(e_i) = e_i.\mathit{Accuracy} \times \frac{\mathit{mintrtime}}{\mathit{mintrtime} + (e_i.\mathit{trtime} - \mathit{mintrtime}) \times \beta} \tag{3}$$

where *mintrtime* = min<sub>∀*e<sub>i</sub>*∈*E*</sub>(*e<sub>i</sub>*.*trtime*) (i.e., the minimum training execution time). *β* is an importance weight of the difference between the training execution time of experiment *e<sub>i</sub>* and the minimum training execution time; we decided to set it to 50%, in order to mitigate the effect of execution times on the final score; in fact, with *β* = 1, the penalty effect would be excessive.

Similarly, we can define the *Testing Suitability*, defined in Formula 4.

$$\mathit{TestingSuitability}(e_i) = e_i.\mathit{Accuracy} \times \frac{\mathit{mintetime}}{\mathit{mintetime} + (e_i.\mathit{tetime} - \mathit{mintetime}) \times \gamma} \tag{4}$$

where *mintetime* = min<sub>∀*e<sub>i</sub>*∈*E*</sub>(*e<sub>i</sub>*.*tetime*) (i.e., the minimum testing execution time among the experiments). Similarly to *β*, *γ* weights the difference between the testing execution time of experiment *e<sub>i</sub>* and the minimum testing time. We decided to set it to 50% as well.

*Training Suitability* and *Testing Suitability* ranked experiments by keeping the two phases (training and testing) separated.

In Formula 5, we propose a unified *Suitability* metric.

$$\mathit{Suitability}(e_i) = \alpha \times \mathit{TrainingSuitability}(e_i) + (1 - \alpha) \times \mathit{TestingSuitability}(e_i) \tag{5}$$

i.e., the unified suitability is the weighted average of *Training Suitability* and *Testing Suitability*, where *α* ∈ [0, 1] balances the two contributions.
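Formulas 3–5 can be sketched as follows, with β = γ = 0.5 as in the paper's setting; the value α = 0.5 and the experiment records are our own illustrative assumptions, not the paper's measurements:

```python
# Sketch of the Suitability metric (Formulas 3-5) for a pool of
# experiments. Numbers are invented for illustration.
def suitability(experiments, beta=0.5, gamma=0.5, alpha=0.5):
    min_tr = min(e["trtime"] for e in experiments)   # mintrtime
    min_te = min(e["tetime"] for e in experiments)   # mintetime
    scores = {}
    for e in experiments:
        tr_s = e["accuracy"] * min_tr / (min_tr + (e["trtime"] - min_tr) * beta)
        te_s = e["accuracy"] * min_te / (min_te + (e["tetime"] - min_te) * gamma)
        scores[e["name"]] = alpha * tr_s + (1 - alpha) * te_s
    return scores

experiments = [
    {"name": "NB",  "accuracy": 0.95, "trtime": 2.0,  "tetime": 1.0},
    {"name": "SVM", "accuracy": 0.94, "trtime": 40.0, "tetime": 5.0},
]
s = suitability(experiments)
assert s["NB"] > s["SVM"]   # fast, accurate classifiers rank higher
```

Note how a classifier whose execution times equal the minima keeps its full accuracy as its score, while slower classifiers are progressively penalized.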

Table 9 reports the values of training suitability, testing suitability and unified suitability for the same experiments considered in Table 8. Figure 8 depicts the results based on the unified suitability, using the same conventions as in Figure 4. The reader can see that the Naive Bayes classifier had the highest suitability values, due to its ability to combine high accuracy with very low execution times. Surprisingly, the kNN classifier ranked second: although it obtained the worst accuracy values, it had the lowest execution times. Finally, the SVM classifier and the Random-Forest classifier were strongly penalized by their execution times.


**Table 9.** Suitability for the most promising feature patterns.

Consequently, on the basis of Table 9 and Figure 8, we can say that the Naive Bayes classifier applied to the feature pattern of uni-grams clearly emerged as the most suitable solution for our application context (presented in Section 3) and, specifically, for solving Problem 1.

**Figure 8.** Suitability for the most promising feature patterns.

### **6. Conclusions and Future Work**

In this paper, we have laid the first brick towards extending micro-blog user interfaces with a new functionality: a tool that recommends users to follow (influencers) on the basis of the topics their messages talk about. The basic brick is a text classification technique, applied to a given feature pattern, that provides good accuracy while requiring limited execution times. To identify it, we built an investigation framework that allowed us to perform experiments, measuring both effectiveness (accuracy) and execution times. A cost-benefit function, called *Suitability*, has been defined: by means of it, we discovered that the best solution to address the problem is to apply a Naive Bayes classifier to uni-grams extracted from messages, both to train the model and to classify unlabeled messages. We considered execution times because, in our opinion, the envisioned application scenario demands fast functionalities; thus, execution times emerge as critical factors. To the best of our knowledge, this comparative study of the performance of classifiers, based on both accuracy and execution times, is a unique contribution of this paper.

The next step towards the more ambitious goal of building a recommender system for influencers is to develop the surrounding methodology that actually enables recommending them: once messages are labeled with topics, potential influencers must be ranked on the basis of the frequency with which they post messages about a given topic. This methodology will be the next step of our work.

**Author Contributions:** Conceptualization and methodology, G.P.; software, M.A., A.B.; writing—original draft preparation, M.A.; writing—review and editing, G.P.; Data collection and annotation, S.M., A.B. All authors have read and agreed to the published version of the manuscript.

**Funding:** This research received no external funding.

**Conflicts of Interest:** The authors declare no conflict of interest.

### **References**


© 2020 by the authors. Licensee MDPI, Basel, Switzerland. This article is an open access article distributed under the terms and conditions of the Creative Commons Attribution (CC BY) license (http://creativecommons.org/licenses/by/4.0/).

## **Time-Aware Learning Framework for Over-The-Top Consumer Classification Based on Machine- and Deep-Learning Capabilities**

### **Jaeun Choi <sup>1</sup> and Yongsung Kim <sup>2,\*</sup>**


Received: 29 October 2020; Accepted: 26 November 2020; Published: 27 November 2020

**Abstract:** With the widespread use of over-the-top (OTT) media, such as YouTube and Netflix, network markets are changing and innovating rapidly, making it essential for network providers to quickly and efficiently analyze OTT traffic with respect to pricing plans and infrastructure investments. This study proposes a time-aware deep-learning method of analyzing OTT traffic to classify users for this purpose. With traditional deep learning, classification accuracy can be improved over conventional methods, but it takes a considerable amount of time. Therefore, we propose a novel framework to better exploit accuracy, which is the strength of deep learning, while dramatically reducing classification time. This framework uses a two-step classification process. Because only ambiguous data need to be subjected to deep-learning classification, vast numbers of unambiguous data can be filtered out. This reduces the workload and ensures higher accuracy. The resultant method provides a simple method for customizing pricing plans and load balancing by classifying OTT users more accurately.

**Keywords:** consumer classification; deep learning; machine learning; over-the-top; time-aware classification

### **1. Introduction**

With the advancements of smart devices and the rapid development of wired and wireless networks, our modes of entertainment and venues for information are changing rapidly. In the past, our receipt of multimedia mainly relied upon standardized broadcasts from franchise television (TV) networks. Since the turn of the century, our multimedia consumption has centered on the internet and smartphones, as typified by the "over-the-top" (OTT) internet-protocol (IP) services of YouTube and Netflix [1]. Because of its ubiquity and simplicity, OTT content can be viewed worldwide without TVs. According to PricewaterhouseCoopers, the global OTT market is expected to grow sharply from USD 81.6 B in 2019 to USD 156.9 B in 2024, with an annual average growth of approximately 14% [2]. Furthermore, subscription OTT services are expected to reach approximately 650 million by 2021 [3]. With this growth, the dominant OTT players, such as YouTube, Netflix, Amazon Prime, and Hulu, are actively promoting entry into the global market, and Disney, Apple, Warner Media, and HBO are gaining access using their existing content. The rising OTT market is indeed becoming a fierce battleground.

Consequently, network broadcast media and more recent TV-over-IP enterprises have experienced heavy competition [4,5]. On the other hand, internet-service providers (ISPs) increasingly find themselves in conflict with OTT media providers because of bandwidth- and fee-related issues. Because OTT services consume a vast amount of network resources, ISPs seek to bolster their profits to support growth [6]. Furthermore, consumers increasingly demand higher quality of service (QoS) from their ISPs [7], who desperately need to expand their throughput capabilities [8]. Often, the ISPs respond to bandwidth overloads by limiting the amount of throughput (i.e., throttling) based on the OTT service in demand. This process is most often reactionary, inevitably creating surges of consumer complaints. To balance this vicious cycle of competing demands, the ISPs require better and more-timely consumer OTT-usage analysis capabilities, so that they can better mitigate network performance issues while balancing customer demand. Furthermore, utilizing such a sound and trustworthy tool for this purpose would put the ISPs in a better position to negotiate with OTT providers [9].

Traffic analysis models have been widely researched and utilized for this purpose in conventional hypertext-transfer-protocol (HTTP) mobile-network environments. Machine learning has been key to their success. However, only a few academic studies have focused on OTT content in the context of strategic service provision [9]. In this study, we analyze network consumption patterns based on consumers' OTT-usage patterns, confirming that a combination of machine- and deep-learning capabilities can achieve the highest accuracy. Additionally, by mitigating the time and resource requirements of deep learning, we provide a novel MetaCost-based framework related to OTT user analysis that can reduce the time required for analysis while exploiting the technology's high accuracy. This framework drastically reduces the analysis workload, making the process very efficient and timely so that ISPs can achieve instantaneous status and influence over OTT service demands.

This paper is structured as follows. In Section 2, we examine why the analysis of OTT-related trends and data-usage patterns is critical. We also justify the application of machine and deep learning to this pursuit with a review of previous traffic-analysis studies. In Section 3, we fully describe the OTT user-analysis framework. Section 4 presents a discussion of our experimental results. Finally, Section 5 presents the conclusions of this study with further research directions.

### **2. Literature Review**

### *2.1. OTT Services*

The success of OTT services is owed, in part, to the increase in the number of single-person households and their desire for highly personalized content. With the phenomena of cord-cutting, which circumvents paid broadcasting, and cord-shaving, which leverages alternate broadcasting venues, the demand for new and innovative OTT services is unlikely to decrease [10]. For many consumers worldwide, OTT services have replaced legacy subscription models. In Korea and China, consumers spend around USD 3 per month for high-definition OTT services with tailored recommendation systems [11].

Currently, the OTT market is dominated by giant companies such as Netflix, YouTube, Amazon Prime Video, and Hulu, which account for approximately 75%, 55%, 44%, and 32% of the US OTT market, respectively. As a whole, these firms currently account for 79% of the US market share [12]. Competitors are quickly entering the fray. Disney launched Disney+ in November 2019 after acquiring 21st Century Fox, securing 50 million US subscribers as of July 2020 [13]. After acquiring Warner Media, AT&T launched HBO Max, which utilizes current and past HBO content, securing more than 34 million US subscribers as of June 2020 [14].

The rapidly changing OTT landscape presents both a crisis and an opportunity for conventional TV networks. Although it proves to be a disadvantage to extant strategies, it does provide an opportunity to enter new markets by providing OTT services using their current content and service supply chains [5]. Hulu, launched in 2008, dominates the market with content from FOX, NBC, and ABC [4]. In Korea, Wavve, which offers content from KBS, MBC, and SBS, has approximately 3 million monthly subscribers [15].

The market positions and strategies of ISPs are more complicated than those of legacy TV networks. ISPs that simultaneously provide IP-TV and internet services must install and maintain high-quality broadband infrastructures for both. Additionally, bundling strategies are required to prevent cord-cutting that would cancel IP-TV services in favor of OTT services alone [16]. In fact, in Korea, KT, which holds the highest share of the IP-TV market, is considering a strategy to create synergy through a partnership with Netflix. As such, Netflix could use the opportunity to expand their
market further into Korea. KT has the largest number of wired-network subscribers and, by providing Netflix content, can attract more subscribers and their abundant network-usage fees [17]. With the gradual distribution of 5G, the number of customers using wireless OTT services is bound to grow. If stable QoS cannot be guaranteed, customer churn will be difficult to manage. In particular, with the increasing use of real-time video-streaming services (e.g., Twitch and Discord), providing stable services will remain a challenge [8,18].

As mentioned, ISPs must be able to dynamically execute service-degradation plans to minimize network-resource overconsumption while meeting the QoS needs of consumers. That is, it is essential to create a win–win situation for both ISPs and OTT providers.

Based on assumptions of network neutrality, various studies have analyzed the complex pricing systems relating content providers, networks, and consumers. Dai et al. [7] proposed a pricing plan that could guarantee QoS based on the Nash equilibrium. They showed that the direct sale of QoS by an ISP to a consumer achieved better results than selling QoS to an OTT provider. Based on the quality of experience (QoE), a model that would benefit all OTT providers, networks, and consumers was also proposed [19]. This study compared three methods: providing better QoE to customers who paid more; satisfying the QoE of the most profitable customers (MPCs) to increase lifetime value; and providing fair QoE to all customers. Of these, the method of providing QoE to MPCs was found to be the most beneficial. A study based on shadow pricing was also conducted to determine an effective method to price broadband services [20], concluding that setting the pricing plan according to the usage patterns of consumers was the best strategy. Because OTT services utilize a considerable amount of network bandwidth, some studies have proposed methods of predicting network consumption and pricing via a content-delivery or a software-defined network [21,22]. With the spread of OTT services, the pricing-related issues of OTT and network providers persist, and most existing studies have suggested plans based on the amount of network usage. Thus, in order for networks or OTT providers to establish an optimal strategy related to OTT, the OTT-service usage patterns of consumers must be identified very quickly.
Both network and OTT providers can establish effective pricing strategies only when they can correctly and immediately identify which users are using what OTT services and how much data they consume, classifying all items and load-balancing accordingly. Thus, OTT user classification is the first step in establishing an effective strategy.

### *2.2. Review of Classification Using Machine Learning*

Researchers have conducted extensive studies on methods to manage and operate networks by analyzing network traffic and user behaviors. Hence, the widespread use of network technologies incorporating artificial-intelligence (AI) technologies has gained attention. Extensive research has been conducted on the knowledge-defined-networking paradigm, in which AI technology is incorporated into network routing, resource management, log analysis, and planning. Several companies have already applied AI data analysis to network operations [23]. In turn, multiple studies have been conducted to determine how network providers can leverage machine learning to analyze user traffic. Middleton and Modafferi [24] exploited machine learning to classify IP traffic in support of QoS guarantees. Yang et al. [25] proposed the classification of Chinese mobile internet users into heavy and high-mobility users by analyzing the network traffic of 2G and 3G services. Various other studies proposed methods, such as decision trees [26,27], support-vector machine (SVM) [28–30], *k*-nearest neighbor (KNN) [31,32], hidden Markov model (HMM) [33,34], and K-Means [35,36] for traffic classification. The application targets of these methods differed depending on whether the analysis was performed on wired, wireless, or encrypted traffic. However, these techniques demonstrated the following structure: They captured and analyzed network traffic; they applied machine learning by using traffic characteristics as features; and they classified the traffic data. The traffic data were diverse, ranging from captured packets to public datasets. However, data related to OTT usage, which is a recent trend, were rarely considered. Few studies have proposed methods to classify users based on OTT consumption [9,37]. Those that did classified consumers into three consumption categories
(i.e., high, medium, and low) by using various machine-learning methods to analyze OTT traffic. The current study is significant in that it is the first to use deep learning in an attempt to classify users in terms of OTT usage. Deep learning has the advantage of high accuracy but the disadvantage of requiring large numbers of calculations; hence, it has been widely utilized for similar classification problems [38–40]. Therefore, in this study, to overcome the demerit of excessive time consumption, we propose a time-aware user-analysis framework that applies the MetaCost method [41]. Based on Bayes risk theory, MetaCost can reduce specific classification errors by setting the cost of each type of misclassification differently. By using these properties, we can reduce the load on deep learning through cost adjustment.

### **3. Research Design**

First, we verified whether OTT users can be effectively classified using machine- and deep-learning methods. The machine- and deep-learning methods used in this study, and how they are applied, are detailed in Section 3.1. We confirmed that the classification accuracy was high when using deep learning; however, we also found that the time required was long due to the characteristics of deep learning. Therefore, in Section 3.2, we propose a framework that reduces time consumption while still exploiting the accuracy of deep learning.

### *3.1. Applying Machine and Deep Learning to OTT Consumer Classification*

Figure 1 shows the process of analyzing OTT users based on machine and deep learning [42,43]. The steps include raw-data collection, data preprocessing for feature extraction, and dataset processing for machine learning. We leverage OTT-usage data for this purpose. Our dataset was previously published [37] and is open to the public. It includes general network-traffic characteristics but is also appropriate for OTT-specific research. It contains traffic information about the activities of actual internet users with respect to 29 types of OTT services. A detailed description of the dataset is presented in Section 3.3.

**Figure 1.** Machine- and deep-learning-based over-the-top (OTT) consumer classification process.

In this study, we analyze OTT users according to the conventional machine-learning methods of KNN, decision tree, SVM, naïve Bayes, and repeated incremental pruning to produce error reduction (RIPPER) models. We also use the multilayer perceptron (MLP) and convolutional neural network (CNN) as deep-learning applications.

### 3.1.1. Conventional Machine Learning Methods

**The KNN** is a typical classification method based on the proximity of samples in the feature space. It first confirms which classes the *k* neighbors of a data point belong to, and it then performs classification by taking the majority vote over the result. If *k* = 1, the point is assigned to the class of its nearest neighbor. If *k* is an odd number, classification becomes easier, because there is no possibility of a tie [44]. A neighbor in KNN is defined by calculating the distance between vectors in a multidimensional feature set. The distance between vectors is typically calculated using the Euclidean or Manhattan distance. Suppose that input sample *x* has *m* features. The feature set of *x* is expressed as (*x*1, *x*2, ··· , *xm*), and the Euclidean and Manhattan distances of *x* and *y* are defined as follows [45]:

$$\text{Euclidean}: D(\mathbf{x}, \mathbf{y}) = \sqrt{\sum\_{i=1}^{m} \left| x\_i - y\_i \right|^2} \tag{1}$$

$$\text{Manhattan}: D(\mathbf{x}, \mathbf{y}) = \sum\_{i=1}^{m} |x\_i - y\_i|. \tag{2}$$

The greatest advantage of the KNN is its straightforwardness: no model parameters need to be tuned, and only the value of *k* must be assigned. However, if the data distribution is distorted, KNN can fail, because a data point may not belong to the same class as its nearest neighbors. Additionally, when dealing with multidimensional data having many features, classification accuracy may be degraded owing to the curse of dimensionality; thus, it is essential to reduce the number of features [46].
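As a minimal, self-contained sketch, Eqs. (1) and (2) and the majority vote can be written as follows. The feature vectors and class labels are hypothetical toy values, not drawn from the paper's OTT dataset:

```python
import math
from collections import Counter

def euclidean(x, y):
    # Eq. (1): square root of the summed squared feature differences
    return math.sqrt(sum((xi - yi) ** 2 for xi, yi in zip(x, y)))

def manhattan(x, y):
    # Eq. (2): summed absolute feature differences
    return sum(abs(xi - yi) for xi, yi in zip(x, y))

def knn_classify(sample, train, k=3, dist=euclidean):
    """Majority vote over the k nearest training samples.
    `train` is a list of (feature_vector, class_label) pairs."""
    neighbors = sorted(train, key=lambda item: dist(sample, item[0]))[:k]
    votes = Counter(label for _, label in neighbors)
    return votes.most_common(1)[0][0]

# Hypothetical toy data: (mean packet size, flow duration) -> consumption class
train = [((10, 1), "low"), ((12, 2), "low"), ((50, 40), "high"),
         ((55, 45), "high"), ((30, 20), "medium")]
print(knn_classify((11, 1), train, k=3))  # → "low"
```

With an odd *k*, the vote over two classes cannot tie, which is why odd values are preferred in the text above.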

**Decision trees** are built upon the tree model and are used to classify an entire dataset into several subgroups. When traversing from an upper to a lower node, nodes are split according to the classification variables. Nodes of the same branch have similar attributes, whereas nodes of different branches have different attributes. The most typical decision-tree algorithms are ID3 and C4.5. The C4.5 algorithm, derived from ID3, minimizes the entropy sum of the subsets by leveraging the concept of "information gain": each subset is split in the direction that maximizes information gain. Thus, accuracy is high when classification is performed using the learned result. Decision trees have the advantages of intuitiveness, high classification accuracy, and simple implementation. Therefore, they are widely adopted for various classification tasks. However, for data containing variables with different numbers of levels, splits tend to be biased toward the variables that cover most of the data. Moreover, for a small tree with few branches, rule extraction is easy and intuitive, but accuracy may decrease; for a deep and wide tree with many branches, rule extraction is difficult and non-intuitive, but accuracy may be higher than that of a small tree.
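The entropy-based split criterion used by C4.5 can be sketched as below. The labels are hypothetical; a real implementation would also evaluate every candidate attribute and apply gain-ratio normalization, which is omitted here:

```python
import math
from collections import Counter

def entropy(labels):
    """Shannon entropy of a list of class labels."""
    total = len(labels)
    return -sum((c / total) * math.log2(c / total)
                for c in Counter(labels).values())

def information_gain(labels, subsets):
    """Entropy reduction achieved by splitting `labels` into `subsets`.
    C4.5 chooses the split that maximizes this quantity."""
    total = len(labels)
    remainder = sum(len(s) / total * entropy(s) for s in subsets)
    return entropy(labels) - remainder

parent = ["high"] * 4 + ["low"] * 4       # maximally mixed node (entropy = 1 bit)
perfect = [["high"] * 4, ["low"] * 4]     # a split that fully separates the classes
print(information_gain(parent, perfect))  # → 1.0
```

A split that leaves both subsets as mixed as the parent would yield a gain of 0, so it would never be chosen over the perfect split above.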

**The SVM** performs classification based on a hyperplane that separates two classes in the feature space. The chosen hyperplane is the one with the maximal margin, i.e., the longest distance to the closest data points of the two classes. For input samples *xi* with class labels *yi*, the hyperplane separating the classes is defined as *w*<sup>T</sup>*x* + *b* = 0. After finding the distance between the hyperplane and the closest data points of the two classes, the optimization problem for maximizing the margin is defined as follows:

$$\min\_{\mathbf{w},\, b} \Phi(\mathbf{w}) = \frac{1}{2} \|\mathbf{w}\|^2,\tag{3}$$

where variables *w* and *b*, which satisfy the following convex quadratic programming, become variables that build the optimal hyperplane [45]:

$$\text{s.t.} \, y\_i(\mathbf{w}^T \mathbf{x}\_i + b) \ge 1, \; i = 1, \dots, l. \tag{4}$$

The SVM achieves high performance for a variety of problems and is known to be effective when there are many features. Although the SVM natively solves binary-class problems, it can be applied to multiclass problems by solving multiple binary-class subproblems to derive accurate results; thus, its calculation time is relatively long [45,46].
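The constraints of Eq. (4) and the margin maximized by Eq. (3) can be illustrated for a fixed hyperplane. Note that this sketch only *checks* feasibility for hypothetical one-dimensional samples; actually solving Eq. (3) requires a quadratic-programming solver:

```python
def decision(w, b, x):
    # The decision function w^T x + b
    return sum(wi * xi for wi, xi in zip(w, x)) + b

def satisfies_margin(w, b, samples):
    """Check the constraints of Eq. (4): y_i (w^T x_i + b) >= 1 for all i.
    `samples` is a list of (x, y) pairs with y in {-1, +1}."""
    return all(y * decision(w, b, x) >= 1 for x, y in samples)

# Two linearly separable classes on a line; candidate hyperplane x = 0
samples = [((2,), +1), ((3,), +1), ((-2,), -1), ((-3,), -1)]
w, b = (1.0,), 0.0
print(satisfies_margin(w, b, samples))  # → True: every point lies outside the margin
print(2 / abs(w[0]))                    # margin width 2 / ||w||, maximized by Eq. (3)
```

Shrinking ||*w*|| widens the margin 2/||*w*|| until some constraint in Eq. (4) becomes tight, which is exactly the tradeoff the minimization in Eq. (3) resolves.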

**The naïve Bayes method** is a typical classification method based on Bayes' theorem. It starts from the assumption that all input features are independent. Under this assumption, classes can be assigned through the following process [46]:

$$y(f\_1, f\_2, \dots, f\_m) = \arg\max\_{k \in \{1, \dots, K\}} p(C\_k) \prod\_{i=1}^m p(f\_i | C\_k),\tag{5}$$

where *m* is the number of features, *K* is the number of classes, *fi* is the *i*th feature, *Ck* is the *k*th class, *p*(*Ck*) is the prior probability of *Ck*, and *p*(*fi* | *Ck*) is the conditional probability of feature *fi* given class *Ck*.

The key merit of naïve Bayes is its short calculation time for learning. This is because, under the assumption that features are independent, high-dimensional density estimation reduces to one-dimensional kernel-density estimation. However, because the assumption that all features are independent is rarely realistic, accuracy may decrease when performing classification using only a small sample of data. To increase accuracy, a large amount of data should be collected [46,47].
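Eq. (5) can be sketched directly for categorical features. The priors and conditional probabilities below are invented toy values for two hypothetical binary features, not estimates from the paper's dataset:

```python
def naive_bayes_class(features, priors, cond):
    """Eq. (5): pick the class k maximizing p(C_k) * prod_i p(f_i | C_k).
    `priors[k]` is p(C_k); `cond[k][i][v]` is p(f_i = v | C_k)."""
    best_class, best_score = None, -1.0
    for k, prior in priors.items():
        score = prior
        for i, v in enumerate(features):
            score *= cond[k][i][v]
        if score > best_score:
            best_class, best_score = k, score
    return best_class

# Toy model: feature 0 = "streams in the evening", feature 1 = "uses mobile data"
priors = {"high": 0.3, "low": 0.7}
cond = {
    "high": [{True: 0.9, False: 0.1}, {True: 0.6, False: 0.4}],
    "low":  [{True: 0.2, False: 0.8}, {True: 0.5, False: 0.5}],
}
print(naive_bayes_class((True, True), priors, cond))  # → "high"
```

Despite the low prior of 0.3, the likelihood terms dominate for (True, True): 0.3 × 0.9 × 0.6 = 0.162 versus 0.7 × 0.2 × 0.5 = 0.07, illustrating how Eq. (5) weighs evidence against the prior.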

**The RIPPER** algorithm is a typical rule-set classification method [48]. Rules are derived from the training data using a separate-and-conquer algorithm: each rule is grown to cover as many training examples as possible and is then pruned to maximize performance. Data correctly classified by a rule are removed from the training dataset before the next rule is learned [46]. The RIPPER algorithm overcomes the shortcoming of early rule-learning algorithms, which could not effectively process big data. However, because the RIPPER algorithm starts by classifying two classes, performance can decrease as the number of classes increases. Its performance may also decrease because of its heuristic approach.

### 3.1.2. Deep Learning

After AlphaGo beat Lee Se-dol in the 2016 Google DeepMind challenge match, deep learning captured the attention of the worldwide public. However, research on deep learning had already been actively underway in academia and in practical application fields. Deep learning is an extension of the artificial neural network; it learns and makes decisions through the configuration of the layers that make up its neural network. It took a while for computer hardware to catch up, but with recent graphics-processing-unit developments, deep learning has been widely applied in various fields. Deep learning automatically selects features through its training process, requires little assistance from domain experts, and learns complex patterns of constantly evolving data [43]. As such, many related studies on internet traffic analysis have been published [49]. The MLP and CNN deep-learning methods are used in this paper.

**The MLP** has a simple deep-learning structure and comprises an input layer, an output layer, and a hidden layer of neurons. Figure 2 shows the structure of the MLP. In each layer, several neurons are connected to the adjacent layer. The neurons calculate the weighted sum of inputs and output the results via a nonlinear activation function. In this process, the MLP uses a supervised back-propagation learning method. Because all nodes are connected within the MLP, each node in each layer has a specific weight, *wij*, with all nodes of the adjacent layer. Node weights are adjusted based on back-propagation, which minimizes the error of the overall result [49]. However, the MLP method has the disadvantage of being very complex and inefficient, owing to the huge number of variables the model must learn [43]. Accordingly, an MLP should be applied to data that are not too complex, or close attention must be paid to time consumption. The OTT dataset used in this study has quantitative values for each feature. Since it does not have complicated feature structures such as images or videos, it shows sufficiently good performance even when only a simple MLP is applied. Therefore, we sought to save the time required for learning and detection by using the simplest MLP structure possible. In this study, we conduct an experiment with an input layer, an output layer, and a single hidden layer between them.
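The forward pass of such a single-hidden-layer network can be sketched as below. The weights are hypothetical hand-picked numbers standing in for values that back-propagation would learn; the paper's actual MLP uses different sizes, activations, and trained weights:

```python
def relu(v):
    # Nonlinear activation applied element-wise to a layer's outputs
    return [max(0.0, x) for x in v]

def dense(inputs, weights, biases):
    """One fully connected layer: each output neuron takes the weighted
    sum of all inputs plus its bias (weight w_ij to every adjacent node)."""
    return [sum(w * x for w, x in zip(row, inputs)) + b
            for row, b in zip(weights, biases)]

def mlp_forward(x, hidden_w, hidden_b, out_w, out_b):
    """Input -> single hidden layer (ReLU) -> output layer, mirroring the
    one-hidden-layer architecture described above."""
    h = relu(dense(x, hidden_w, hidden_b))
    return dense(h, out_w, out_b)

# Hypothetical tiny network: 2 inputs, 2 hidden neurons, 1 output
hidden_w = [[1.0, -1.0], [0.5, 0.5]]
hidden_b = [0.0, 0.0]
out_w = [[1.0, 2.0]]
out_b = [0.5]
print(mlp_forward([2.0, 1.0], hidden_w, hidden_b, out_w, out_b))  # → [4.5]
```

Training would repeat this pass, compare the output with the target, and back-propagate the error to adjust every *wij*; only the inference direction is shown here.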

**The CNN** is similar to the MLP, comprises several layers, and updates variables through learning. Although the MLP does not handle multi-dimensional inputs well, the CNN does so by applying a convolution layer. Figure 3 shows the structure of the CNN. The convolution layer produces results for the next layer by using kernels with learnable variables as inputs. The local filter is used to complete the mapping process, which is regarded as a convolution function. Additionally, because the filter is replicated across units, it shares the same weight vector and bias, thus increasing efficiency by greatly reducing the number of parameters. CNNs use a pooling process for down-sampling and can be widely applied to a variety of classifications. If the dimension of a vector used in the CNN process is 1, 2, or 3, it corresponds to 1D-, 2D-, and 3D-CNNs, respectively. The 1D-CNN is suitable for sequential data (e.g., language), the 2D-CNN is suitable for images or audio, and the 3D-CNN is suitable for video or large-volume images. Although there has been no research on classifying OTT traffic data by CNN, a study analyzing general network traffic via CNNs utilized a 1D-CNN [50]. This is because traffic characteristics are sequential; therefore, 1D-CNNs are sufficient and multi-dimensional CNNs are not required. In this study as well, OTT traffic was analyzed using a 1D-CNN, since OTT traffic is similar to traditional traffic data. In addition, the filtering and pooling processes were performed using two convolutional layers, as the complexity of the dataset was not high.

**Figure 3.** Convolutional neural network (CNN) process.
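The two core 1D-CNN operations, convolution with a shared kernel and max pooling, can be sketched as follows. The traffic sequence and the moving-average-style kernel are illustrative; in a trained CNN the kernel weights would be learned:

```python
def conv1d(signal, kernel, bias=0.0):
    """Valid 1D convolution: slide one shared, learnable kernel over the
    sequence; weight sharing is what keeps the parameter count small."""
    k = len(kernel)
    return [sum(kernel[j] * signal[i + j] for j in range(k)) + bias
            for i in range(len(signal) - k + 1)]

def max_pool1d(signal, size=2):
    """Non-overlapping max pooling for down-sampling the feature map."""
    return [max(signal[i:i + size]) for i in range(0, len(signal) - size + 1, size)]

# Hypothetical per-interval traffic volumes for one user (sequential data)
traffic = [1.0, 2.0, 4.0, 3.0, 1.0, 0.0]
feature_map = conv1d(traffic, kernel=[0.5, 0.5])  # smoothing-style filter
pooled = max_pool1d(feature_map, size=2)
print(feature_map)  # → [1.5, 3.0, 3.5, 2.0, 0.5]
print(pooled)       # → [3.0, 3.5]
```

Stacking two such convolution/pooling stages, as in the configuration described above, simply feeds `pooled` into another `conv1d` with its own kernel.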

As described in detail in the experimental results of Section 4, when deep learning is applied to OTT user analysis, the accuracy is higher than when applying general machine-learning methods, although it takes much longer. The dataset used in this study contains 1581 users, which is considerably fewer than the number of users serviced by ISPs or OTT providers. With larger numbers of simultaneous users, the time consumption could become prohibitive. Therefore, we apply the aforementioned MetaCost framework.

#### *3.2. Time-Aware Consumer Classification Based on MetaCost and Deep Learning*

In this study, we classify consumers into three consumption types (i.e., high, medium, and low) by analyzing their OTT-usage traffic. Notably, two groups of users need not be passed to deep learning: those who use an extremely heavy amount of OTT services and those who rarely use them. Because these groups have extreme characteristics, they can be easily classified via general machine learning, and deep learning is not required for them. Therefore, we propose a framework that first filters definite high- and low-consumption consumers through a fast and relatively accurate machine-learning technique and then performs deep-learning-based classification for the remaining customers. This framework shortens the overall computation time by reducing the number of samples to which deep learning is applied, allowing it to focus on the more ambiguous classes [41]. Figure 4 illustrates the proposed time-aware framework based on MetaCost.

**Figure 4.** Proposed framework.

Figure 5 shows a simple schematic of the high- and low-consumption filtering process, which forwards the non-extreme data to the deep-learning-based classifier. By setting the cost of errors that misclassify non-high-consumption data as "high" greater than that of errors that incorrectly classify high-consumption data as "medium" or "low", we prevent other classes of data from being misclassified as "definite high consumption". When classifying definite low-consumption data, we filter out only the obvious data by setting the costs in the reverse manner. If the cost is set high so that the filtered data are not mixed with other data, all data that are even slightly ambiguous are forwarded to the next step, as shown in the "high-cost" process of Figure 5. As a result, the load on deep learning is mitigated. However, there is an increased chance of lower accuracy, because, during the filtering process, medium-consumption data can be misclassified as high- or low-consumption data. We adjust the tradeoff between accuracy and time consumption by controlling the costs according to the amount of data and the available resources.

In general machine learning, the weights for the errors resulting from the classification process are all the same. The MetaCost method instead sets the cost of each error differently and is suitable for the proposed filtering framework, because classification is performed so as to minimize costs. The MetaCost method assigns each datum to the class satisfying the following equation:

$$\text{class}(\mathbf{x}) = \arg\min\_{i} \sum\_{j} p(j|\mathbf{x})\, C(i,j),\tag{6}$$

where *p*(*j* | *x*) is the probability that *x* belongs to class *j*, and *C*(*i*, *j*) is the cost incurred when *x* actually belongs to class *j* but is classified as class *i*. After the expected cost of misclassification is calculated for each datum, the datum is assigned to the class having the lowest cost [41].
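Eq. (6) and its use as a high-consumption filter can be sketched with hypothetical numbers. The posterior and cost matrix below are invented for illustration; a real MetaCost run estimates *p*(*j* | *x*) by bagging a base classifier:

```python
def metacost_class(posterior, cost):
    """Eq. (6): assign x to the class i minimizing the expected cost
    sum_j p(j | x) * C(i, j). `posterior[j]` is p(j | x); `cost[i][j]` is
    the cost of predicting i when the true class is j."""
    return min(cost, key=lambda i: sum(posterior[j] * cost[i][j]
                                       for j in posterior))

# Hypothetical posteriors for one slightly ambiguous consumer
posterior = {"high": 0.55, "medium": 0.40, "low": 0.05}
# Misclassifying a non-high consumer as "high" is made expensive, so the
# filter keeps only obvious cases and forwards the rest to deep learning
cost = {
    "high":   {"high": 0.0, "medium": 10.0, "low": 10.0},
    "medium": {"high": 1.0, "medium": 0.0,  "low": 1.0},
    "low":    {"high": 1.0, "medium": 1.0,  "low": 0.0},
}
print(metacost_class(posterior, cost))  # → "medium"
```

Although the posterior favors "high" (0.55), the inflated cost of a wrong "high" label pushes the expected cost of predicting "high" to 4.5 versus 0.6 for "medium", so the sample is not filtered as definite high consumption; with uniform costs the same sample would simply receive the maximum-posterior class.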

**Figure 5.** Filtering results according to cost.

The proposed framework's classification time is far less than that of the simple deep-learning method, although deviations can occur based on the cost setting. Because deep learning is generally applied only after more than half of the data have been filtered out, the framework retains its advantage in terms of time consumption. Thus, as the size of the analysis dataset greatly increases, the strengths of the proposed framework stand out even more. As mentioned, the framework gains flexibility via cost-setting adjustments: if the cost is properly adjusted according to the environment in which the framework is adopted, classification can be performed according to the time and accuracy desired by the analyst.

### *3.3. Dataset Description*

To verify the proposed methodology, we applied the pre-existing dataset mentioned in Section 3.1. Traffic captured directly from the Universidad del Cauca (Unicauca) network in 2017 was converted into a dataset comprising 130 features and 1581 user samples. These samples were divided into classes of high, medium, and low consumption, and the OTT-usage data were well represented. The traffic of 29 types of OTT services and applications was analyzed: Amazon, Apple Store, Apple iCloud, Apple iTunes, Deezer, Dropbox, EasyTaxi, Ebay, Facebook, Gmail, Google suite, Google Maps, HTTP\_Connect, HTTP\_Download, HTTP\_Proxy, Instagram, LastFM, MS OneDrive, Facebook Messenger, Netflix, Skype, Spotify, Teamspeak, Teamviewer, Twitch, Twitter, Waze, WhatsApp, Wikipedia, Yahoo, and YouTube. Features were extracted by analyzing the traffic flow of each service, as shown in Table 1. For each service, these features were extracted from the dataset [37].


**Table 1.** Feature Description.

### **4. Results and Discussion**

#### *4.1. Machine and Deep Learning*

For classification based on KNN, decision tree, SVM, naïve Bayes, and RIPPER, we employed Weka, a JAVA-based machine-learning library [51]. For MLP and CNN, we employed scikit-learn and TensorFlow. For the experiments, the hardware included an Intel i7-1065G7 processor, 16-GB LPDDR4x memory, and NVIDIA® GeForce® MX250 graphics with GDDR5 2-GB graphic memory. To select the machine-learning parameters with the best performance for each method, we experimented with various parameter values to find those appropriate for the OTT dataset. For the KNN, *k* was set to 17, and J48 was used as the decision tree. In the SVM, a linear kernel was used, and, in RIPPER, the number of folds used for pruning was set to 10. The naïve Bayes classifier used Weka's default settings. For the MLP, ActivationELU was used as the activation function of the hidden and output layers, ADAM was used as the optimizer of the loss function, and AdaDelta was used as the bias updater. For the CNN, two each of convolution, pooling, and fully connected layers were used. ActivationIdentity and ActivationSoftmax were used as the activation functions of the convolution and output layers, respectively. Adamax was used as the optimizer of the loss function, and AdaDelta was used as the bias updater. Additionally, the number of epochs was set to 100.

We considered recall, precision, and F-measure as the evaluation metrics, calculating them based on true positives (TP), false positives (FP), and false negatives (FN). Recall indicates the proportion of actual class members that are detected, i.e., the true-positive rate. Precision is the accuracy of detection and refers to the probability that, when a datapoint is classified into a class, it actually falls into that class. F-measure is the harmonic mean of precision and recall and thus indicates accuracy in a single value. These metrics are defined as follows:

$$\text{Recall} = \frac{\text{TP}}{\text{TP} + \text{FN}}, \quad \text{Precision} = \frac{\text{TP}}{\text{TP} + \text{FP}}, \quad \text{F-Measure} = \frac{2 \cdot \text{Recall} \cdot \text{Precision}}{\text{Recall} + \text{Precision}}.$$
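These three definitions translate directly into code; the confusion counts below are hypothetical, not taken from Tables 2–4:

```python
def recall(tp, fn):
    # Proportion of actual class members that were detected
    return tp / (tp + fn)

def precision(tp, fp):
    # Proportion of detections that were correct
    return tp / (tp + fp)

def f_measure(tp, fp, fn):
    # Harmonic mean of recall and precision
    r, p = recall(tp, fn), precision(tp, fp)
    return 2 * r * p / (r + p)

# Hypothetical confusion counts for one consumption class
tp, fp, fn = 90, 10, 30
print(recall(tp, fn), precision(tp, fp), f_measure(tp, fp, fn))
```

Because the harmonic mean is dominated by the smaller operand, the F-measure here (≈0.82) sits closer to the weaker recall (0.75) than to the precision (0.90).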

Table 2 shows the experimental results based on the aforementioned environment. Among the conventional machine-learning methods, KNN achieved good performance with a classification accuracy of 95.1%, and SVM achieved a satisfactory 92.9%. Except for naïve Bayes, all machine-learning methods showed detection rates over 90%, confirming their applicability to classifying OTT users. The accuracy of deep learning was even higher: MLP and CNN classified consumers with detection rates of 98.2% and 97.6%, respectively. Because the input data were not complex, the accuracy of MLP was higher than that of CNN. Because both deep-learning methods achieved high performance, we confirmed that they can also be effectively applied to OTT user analysis.

**Table 2.** Classification results of machine and deep learning. Abbreviations: k-nearest neighbor (KNN); support-vector machine (SVM); repeated incremental pruning to produce error reduction (RIPPER); multilayer perceptron (MLP); convolutional neural network (CNN).


Tables 3 and 4 show the detailed classification results for the three types of consumers obtained through deep learning. As shown in Table 3, MLP classified high- and low-consumption users with an accuracy of ≥98%. For medium consumption, the classification accuracy was lower but still relatively high at 96%. The CNN results in Table 4 show a similar tendency: the classification accuracy reached ~99% for high- and low-consumption users, whereas medium consumption showed a lower detection rate of 94.6%.

**Table 3.** Detailed classification results obtained through MLP.


When classifying OTT users based on deep learning, the classification accuracy was observed to be relatively high, as in other application fields. However, as confirmed by the classification times shown in Table 2, deep learning took longer to classify consumers than did the conventional machine-learning methods. The next subsection describes the savings achieved with MetaCost.


**Table 4.** Detailed classification results of CNN.

### *4.2. Time-Aware Consumer Classification*

To reduce the time required for deep learning, we first classified high- and low-consumption data by using machine learning and MetaCost. We then classified only the remaining ambiguous data using deep learning. KNN and J48 decision trees were the machine learning methods used as the primary filter. Although KNN showed the best performance among all machine-learning methods, J48 achieved fast and highly accurate results. The cost requirement of applying MetaCost is defined as "the cost incurred when classifying data other than high/low consumption as high or low consumption". Therefore, with the cost set to "high", ambiguous data are forwarded to the secondary deep-learning classification. This experiment was performed while changing the cost from 1 to 30 in steps of five. If the cost was one, the weight for all errors was one. Accordingly, the result obtained was the same as that without the application of MetaCost. If the cost was set higher than one, classification was performed to reduce costs. At each step as the cost approached 30, no significant difference was observed from the previous step. Thus, to observe the most conspicuous difference, the experiment was conducted with the cost set to 30. In the secondary classification process, the MLP algorithm was applied for deep learning. Table 5 shows the processing results of the primary filter using KNN and J48 while adjusting the cost from 1 to 30. The table presents the number of data filtered by the primary filter, incorrectly classified by the filter, and processed by deep learning (the secondary classification) with the final detection time.


**Table 5.** Filtering result according to cost change.

With an increase in the cost setting, only the more obvious data were filtered out. Thus, the number of filtered data decreased, the number processed through deep learning increased, and the detection time rose. However, even when the cost was set to the extreme value of 30, the detection time was about half that of applying deep learning alone, yielding the best accuracy while still reducing the detection time by more than half. If the cost was set low, the detection time was reduced to approximately 23%. However, in this case, the number of incorrect classifications increased, negatively affecting the classification accuracy of the entire framework. When using KNN as the primary filter, the results showed fewer errors, but the number of filtered data was smaller than when using J48; therefore, KNN required more time than J48. Conversely, although J48 required less time because it filtered more data, it produced more errors, so the overall classification result of J48 was poorer. Table 6 summarizes the overall classification accuracy of the framework per filtering method. Because of space limitations, the detailed results are included in Appendix A. As shown in the

results of Table 6, with an increase in the cost setting, the ambiguous data were forwarded for accurate deep learning, leading to higher accuracy. In terms of accuracy, the results showed only a slight difference when classifying the entire dataset using deep learning. The filter using KNN showed higher accuracy, because it filtered less data than did the filter using J48, resulting in more data being processed during the classification step. Therefore, as observed, the use of the filter with KNN utilized more classification time than that did that of J48. Overall, the filters using KNN and J48 showed classification accuracies of 97 and 96%, respectively, with no significant difference from the value obtained using only deep learning. For both filters, with the cost set higher, the accuracy increased, but the classification time also increased, as shown in Figure 6. Overall, while the accuracy of KNN was high, it utilized more time. When analyzing OTT users, if a considerable amount of data must be analyzed, the focus should be on reducing the time by setting the cost low. Furthermore, if the data to be analyzed are relatively small or if there is sufficient time for analysis, accuracy can be improved by setting the cost high. Thus, optimal time and detection rates can be set while adjusting the cost according to the given environment.



**Figure 6.** Changes in accuracy and time according to changes in the cost.

The dataset used in this study contained the metadata of 1581 people, so no significant time difference was observed regardless of the analysis method. As mentioned, ISPs and OTT providers will likely face hundreds of thousands or millions of users, and at that scale the time savings of the proposed method become evident. To demonstrate this, we applied SMOTE [52], an oversampling method, to create a dataset with 159,681 instances, approximately 100 times larger than the original. Unlike oversampling techniques that simply duplicate existing samples, SMOTE creates synthetic samples based on existing ones, making it possible to build a dataset with characteristics similar to the original but with far more samples. We used SMOTE to create data at a scale similar to a real environment and then verified the proposed framework on it. Table 7 shows the time differences between plain deep learning and the proposed framework on the oversampled dataset. Because the

number of instances grew enormously, the time required for filtering was considerable; however, the time saved relative to classifying the entire dataset with plain deep learning was far more conspicuous. When analyzing hundreds of thousands or millions of data items, the proposed framework is thus confirmed to significantly reduce the time requirements.
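SMOTE's core interpolation step can be sketched as below. This is a minimal illustration of the idea only (the paper applies the standard SMOTE algorithm [52], presumably via an existing implementation); the toy data and parameter names are our own.

```python
import numpy as np
from sklearn.neighbors import NearestNeighbors

def smote_oversample(X, n_new, k=5, seed=0):
    """Create n_new synthetic samples by interpolating between a randomly
    chosen sample and one of its k nearest same-class neighbours."""
    rng = np.random.default_rng(seed)
    nn = NearestNeighbors(n_neighbors=k + 1).fit(X)
    _, idx = nn.kneighbors(X)                      # idx[:, 0] is the sample itself
    base = rng.integers(0, len(X), n_new)          # which sample to extend
    neigh = idx[base, rng.integers(1, k + 1, n_new)]  # a random neighbour of it
    gap = rng.random((n_new, 1))                   # position along the segment
    return X[base] + gap * (X[neigh] - X[base])

X = np.random.default_rng(0).normal(size=(100, 4))   # toy minority class
X_syn = smote_oversample(X, n_new=900)
X_big = np.vstack([X, X_syn])    # ~10x the original, analogous to the 100x dataset
```

Because each synthetic point lies on a segment between two real points, the enlarged dataset keeps the distributional shape of the original rather than merely duplicating rows.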


**Table 7.** Changes in time taken for classification according to cost change (unit: Second).

### **5. Conclusions**

In this study, we proposed machine- and deep-learning methods for OTT user analysis to provide ISPs and OTT providers with critical, timely information about OTT usage so that they can effectively monitor usage and execute pricing and mitigation plans. By classifying users according to OTT usage, we confirmed that classification accuracy was high with both deep learning and conventional machine-learning methods, with deep learning showing the higher accuracy; this implies that the application of deep learning to OTT user classification was successful. With plain deep learning, however, the high accuracy of OTT user classification comes at the expense of a long classification time. To shorten this time, we proposed a time-aware MetaCost filtering framework: after first filtering the obvious data with a relatively light algorithm, deep learning is applied only to the remaining ambiguous data, significantly reducing classification time while keeping accuracy about the same as plain deep learning.

This study has the following implications for network and OTT providers. It is the first study to demonstrate how deep learning can be employed to classify OTT user behaviors in a timely manner. ISPs bear a heavy burden in deploying and maintaining the network infrastructure and load balancing needed to support not only OTT services but all other internet services, much of which is privately or government contracted. These investments therefore strongly drive strategy, and timely, highly accurate usage analysis is needed. This study thus has a wide range of applications across those domains.

The proposed framework drastically reduces the time consumption of deep-learning methods with respect to ever-changing user behavior. In fact, when business providers analyze this information, they must consider hundreds of thousands of data items at once. The analysis of such a large amount of data using deep learning can be prohibitively time-consuming and requires heavy computer-resource investments. When applying the proposed method, the costs of time consumption can be drastically reduced.

The proposed method can perform classification tailored to the situation by adjusting the cost factor. When the amount of data is relatively small, or there is sufficient time or available resources, accuracy can be improved by increasing the number of data analyzed through deep learning (i.e., setting the cost high). Conversely, if many cases must be analyzed promptly, the cost can be set low so that more of the obvious data are filtered out. Such flexible responses are possible by adjusting the cost factor, and the proposed framework can therefore be used by providers for real analysis purposes.

In the future, we plan to focus on the following points. First, when using deep learning, a methodology tailored to the particular dataset is needed. Because the OTT dataset used in this study comprised relatively unsophisticated features, a simple MLP or CNN produced

significant outcomes; however, analyzing more complex data would require more complex deep-learning architectures. Furthermore, analysis needs to be performed on various types and categories of OTT user data. To the best of our knowledge, the dataset used in this study is the only public dataset specializing in OTT. If more datasets related to OTT user behavior are made public in the future, additional and improved research will be possible.

**Author Contributions:** Conceptualization, J.C. and Y.K.; methodology, J.C.; software, J.C.; validation, J.C. and Y.K.; formal analysis, J.C.; investigation, J.C. and Y.K.; resources, J.C.; data curation, J.C.; writing—original draft preparation, J.C.; writing—review and editing, Y.K.; visualization, J.C. and Y.K.; supervision, J.C. and Y.K.; project administration, J.C. and Y.K.; funding acquisition, Y.K. All authors have read and agreed to the published version of the manuscript.

**Funding:** This work was supported by the National Research Foundation of Korea (NRF) grant funded by the Korea government (MSIT) (No. 2020R1G1A1099559).

**Conflicts of Interest:** The authors declare no conflict of interest.

### **Appendix A**


**Table A1.** Detailed classification accuracy of each filter according to cost change.

### **References**


**Publisher's Note:** MDPI stays neutral with regard to jurisdictional claims in published maps and institutional affiliations.

© 2020 by the authors. Licensee MDPI, Basel, Switzerland. This article is an open access article distributed under the terms and conditions of the Creative Commons Attribution (CC BY) license (http://creativecommons.org/licenses/by/4.0/).

## **Effectiveness of Machine Learning Approaches Towards Credibility Assessment of Crowdfunding Projects for Reliable Recommendations**

### **Wafa Shafqat <sup>1</sup>, Yung-Cheol Byun <sup>1,</sup>\* and Namje Park <sup>2</sup>**


Received: 22 October 2020; Accepted: 16 December 2020; Published: 18 December 2020

**Abstract:** Recommendation systems aim to decipher user interests, preferences, and behavioral patterns automatically. However, making trustworthy and reliable recommendations becomes trickier when users' hard-earned money is at risk, so the credibility of the recommendation is of paramount importance in crowdfunding project recommendations. This research work devises a hybrid machine learning-based approach for credible crowdfunding project recommendations by wisely incorporating backers' sentiments and other influential features. The proposed model has four modules: a feature extraction module, a hybrid LDA-LSTM (latent Dirichlet allocation and long short-term memory) based latent-topic evaluation module, a credibility formulation module, and a recommendation module. The credibility analysis offers a process of correlating a project creator's proficiency, reviewers' sentiments, and their influence to estimate a project's authenticity level, which makes our model robust to unauthentic and untrustworthy projects and profiles. The recommendation module selects the projects matching the user's interests that have the highest credibility scores and recommends them. The proposed recommendation method harnesses numeric data and the sentiment expressions linked with comments, backers' preferences, profile data, and the creator's credibility for quantitative examination of several alternative projects. The model's evaluation shows that credibility assessment based on the hybrid machine-learning approach yields more efficient results (with 98% accuracy) than existing recommendation models. We also evaluated our credibility assessment technique on different categories of projects, i.e., suspended, canceled, delivered, and never-delivered projects, and achieved satisfactory outcomes: 93%, 84%, 58%, and 93% of the projects, respectively, were accurately classified into our desired range of credibility.

**Keywords:** LDA; LSTM; crowdfunding; project recommendation system; optimization; deep learning

### **1. Introduction**

Recommendation systems aim to assist users in daily decision-making processes and are utilized by perpetually developing online business ventures. Crowdfunding is a platform that plays the role of a venture capitalist for entrepreneurs with creative minds. Recommendation in crowdfunding is trickier and more complicated than in offline businesses due to many challenges, such as information scrutiny and less proficient investors [1]. Moreover, online data is less reliable and prone to alteration, making it difficult for investors to rely on a new business idea [2]. Therefore, the credibility assessment of crowdfunding projects becomes an absolute necessity to mitigate the risk of fraud. Crowdfunding is undoubtedly becoming popular: a study [3] shows that approximately 6,445,080 fundraising campaigns were hosted in 2019, with gaming companies being the most successful in generating profit. It is also predicted that annual transaction growth will reach up to 5.8%,

resulting in a total of 1180.5 million dollars by 2024 [4]. Despite the remarkable development and inexhaustible possibilities that crowdfunding provides, the challenges and risks of trust, reliability, and transparency are equally daunting and seemingly mounting. This study attempts to close that credibility gap by analyzing and filtering key players' features to build trust between investors and the creator. In addition to basic campaign features, we also concentrate on the comments section of crowdfunding sites, which plays a significant role in fighting the considerable risks of deceitful online events.

In this paper, we propose a credibility formulation for project recommendations based on a hybrid model. Our proposed architecture has four modules: a text analysis module for project comments, a deep learning module, a credibility estimation module, and a recommendation module. The input data are based on comment-related features and other project-related features. The comment-related features are derived from the comment section of a campaign, where users leave their feedback about particular topics; other project-related features include the project funding goal, the creator's experience, and the number of images/videos/updates/comments. We perform tokenization, stemming, stop-word removal, and data normalization in the data preprocessing layer. In the next step, we perform parameter estimation and topic modeling through latent Dirichlet allocation (LDA). LDA clusters words with the same meaning into a single topic, and the output is passed to the long short-term memory (LSTM) layer, where the input data consist of the word embeddings, topic embeddings per time step, and topic distributions.

The LSTM is then trained against new comments to generate sentiments, i.e., positive, negative, and neutral. As output from the LSTM layer, we get the topic class, accuracy, and project classification. These results are used in the recommendation module to estimate and analyze the product's credibility through sentiment scoring and authenticity calculation. For optimization, we compute an objective function that combines maximization and minimization terms, and we formulate the authenticity score and credibility of the project. Through the developed equations, we obtain the credibility score and the recommended product as output. We developed these equations step by step to build the recommendation system, considering the critical impact of positive and negative user comments, and show how authenticity and credibility are interrelated in project evaluation. Our results show that the proposed approach is feasible for all scenarios, achieving high accuracy in recommendation and authenticity-level evaluation with a low error rate. The model's evaluation indicates that credibility assessment based on the hybrid machine-learning approach yields more efficient results (with 98% accuracy) than existing recommendation models. We also evaluated our credibility assessment technique on different categories of projects and achieved satisfactory outcomes: 95%, 89%, 58%, and 96% of the projects from the suspended, canceled, delivered, and never-delivered categories, respectively, were accurately classified into our desired range of credibility.

The rest of this paper includes related works in Section 2, data in Section 3, the proposed method in Section 4, results in Section 5, and conclusion in Section 6.

#### **2. Related Works**

In this era of the internet and digitalization, an enormous amount of textual data is generated at a high rate. Text data analysis applications are widespread, from customer review analysis to extracting the hidden meaning of a large dataset. Blei proposed a novel approach to recognizing topics, which ultimately enabled sentiment classification and document classification and opened many new assessment prospects for textual data [5]. Topic models are of crucial importance for the representation of discrete data and are used in research fields such as medical sciences [6], software engineering [7], geography [8], and political sciences [9]. There are many topic-modeling techniques, each with its strengths and limitations. The most frequently used approaches include latent semantic analysis (LSA) [10], probabilistic latent semantic analysis (PLSA) [11], latent Dirichlet allocation (LDA) [12], and the correlated topic model (CTM) [13].

LSA's primary focus is to generate vector-based representations of texts to capture semantic content [10,14]. These vector representations are designed to relate words by computing the similarity among text data. LSA has many applications, such as keyword matching, word quality assessment, powering collaborative learning, guidance in career choices, forming optimal teams [15], dimensionality reduction [16], and identification of research trends [17]. PLSA was introduced to fix the limitations of LSA [18]. It has many uses, including differentiating words with several meanings and clustering words that share similar contexts [19]. In [20], PLSA is introduced as an aspect model based on a latent variable responsible for linking observations with unseen class variables. In addition to advancing LSA, PLSA has many other applications, including recommender systems and computer vision [21–23]. The LDA model aims to overcome the limitations of LSA and PLSA in capturing the exchangeability of document words.

LDA, as an unsupervised approach to topic modeling, has recently become very popular, mainly for topic discovery in large corpora. In [24], LDA is used for text mining based on Bayesian topic models. LDA is also a generative, probabilistic model that attempts to imitate the writing process: given a topic, it attempts to produce a document. A variety of LDA-based algorithms are used in different domains, including author–topic analysis [25], LDA-based bioinformatics [26], temporal text mining [27], supervised topic models, and latent co-clustering. In simple words, LDA's fundamental idea is that each document is represented as a mixture of topics, and each topic is a discrete probability distribution reflecting each word's likelihood of occurring in that topic; a document is therefore described by probability distributions of words over topics. LDA has many applications, such as role discovery [28], emotion topics [29], automatic grading of essays [30], and email filtering [31]. Biterm topic modeling (BTM) is a topic-modeling approach for short texts; topic modeling over short texts is becoming an important task because of the pervasiveness of short texts on the internet. BTM is used to discover discriminative and comprehensible latent topics from short text [32].

A recommendation system (RS) is an intelligent system that suggests items to users that might interest them. Practical example applications of RSs include movie, book, and tourist-spot recommendations; a familiar example is the "People you may know" feature on Facebook or LinkedIn. In a personalized RS, users receive item suggestions based on their past behaviors and interpersonal relationships in social networks. There are four categories of personalized recommendation systems based on the approach: content-based filtering, collaborative filtering (CF), knowledge-based filtering, and hybrid. A novel clustering method is proposed in [33] that uses the latent class regression model as a baseline and considers both general ratings and textual reviews. In [34], a system that uses a user's location as an attribute of the recommendation is proposed. A recommendation method suggested in [35] investigates differences in user feedback to discover a customer's preferences; it considers user ratings and focuses on the data-sparsity issue. In [36], a CF method is suggested that uses ratings of different items and feedback from various social networks such as Twitter.

The convolutional neural network (CNN) devised by Krizhevsky et al., referred to as deep CNN [37], learned 1000 semantic concepts by training on the ImageNet Large Scale Visual Recognition Challenge (ILSVRC) 2012 dataset. The deep CNN proposed by [38] is not suitable for the clothing domain; therefore, fully connected layers were included between the seventh and eighth layers to bridge the gap between semantics and mid-level features. In [39], the author built a CNN model for music-genre classification comprising two convolutional layers, one fully connected layer, and two max-pooling layers, with ten softmax units and a logistic regression layer to classify the music genre.

Xin Liu et al. [40] used a fusion of matrix factorization and LDA to build a web content-based recommendation model that presents users with fake credibility information to analyze their reactions and improve the model. Schwarz et al. [41] considered webpage popularity and the PageRank metric to assess a web page's credibility for users. Studies [42,43]

have shown that varied linguistic features, writing styles, and project creators' patterns reveal how communication impacts crowdfunding projects' success. Generally, crowdfunding success is predicted by extracting LDA's semantic features and then applying feature selection and data mining [44]. Most literature studies focus on simple embeddings and have not considered using word plus topic embeddings for LSTM training. We incorporate these embeddings to make more meaningful recommendations that are highly authentic and trustworthy. Moreover, our methodology is novel in that we focus on crowdfunding comments to analyze and formulate their impact on a crowdfunding project's credibility. In the following sections, we present how we overcome the shortcomings of the literature to build a system that considers several factors to recommend credible crowdfunding projects.

### **3. Proposed Credibility Formulation for Project Recommendation Based on Hybrid Model**

This section elaborates the credibility assessment formulation based on learned topics from text and other vital features. The proposed approach uses LDA and LSTM as underlying methods for the credibility assessment process. The overall procedure is divided into multiple tasks as shown in Figure 1, which primarily includes data collection, features selection, text data analysis for topics discovery, topic classification, and formulation for credibility estimation and recommendations. Each task is elaborated separately in the following subsections.

**Figure 1.** Layered view of the proposed approach.

#### *3.1. Input Data*

Each crowdfunding project is rich in information, comprising both project data and the user's profile data. The project data include elements such as the project description, duration, number of backers, number of comments, and the project's success status. The user's profile data relate to the creator's information, such as name, ID, linked social networks, number of friends, and number of created or backed projects. We focus mainly on the comments section of a project, as comments reveal much about a project's status and its creator's behavior through backers' experiences. In addition to features extracted from the comment section, we consider statistical features such as the number of comments, updates, pledged amount, and number of backers. We also recorded the time delay between posts to track the project creator's activities. We collected data from a famous reward-based crowdfunding

platform, Kickstarter, whose mission is to bring creative projects to life. Projects belong to 15 categories grouped into eight sections: arts, comics & illustration, design and tech, film, food and craft, games, music, and publishing. Table 1 describes the data in detail. There is no limitation on the length of a comment.



The temporal patterns of a review, the interaction patterns between project backers and creators, and the average timeline from the proposal stage to the approval stage vary for every project. The importance of social links and user descriptions in assessing credibility is described in later sections.

### *3.2. Data Pre-Processing*

This unit is in charge of several jobs. It first tokenizes the comments into words, which are then passed through the cleansing unit, where all punctuation is removed. The words then pass through the stemming unit, which lowercases each word and converts it to its root (e.g., "working" is replaced with "work"). Finally, we filter out all stop words, i.e., words used in any language for grammatical reasons (e.g., a, an, is). After this processing, the comment is passed to LDA for further processing.
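A simplified stand-in for this preprocessing unit might look like the following. The stop-word list and suffix-stripping stemmer are deliberately tiny illustrations of the described steps; a real system would use a full stemmer (e.g., NLTK's Porter stemmer) and a complete stop-word list.

```python
import re

# Tiny sample stop-word list for illustration only.
STOP_WORDS = {"a", "an", "is", "the", "and", "to", "of", "it"}

def naive_stem(word):
    # Crude suffix stripping, e.g. "working" -> "work".
    for suffix in ("ing", "ed", "s"):
        if word.endswith(suffix) and len(word) > len(suffix) + 2:
            return word[: -len(suffix)]
    return word

def preprocess(comment):
    # Lowercase, tokenize, and drop punctuation in one pass, then
    # remove stop words and stem what remains.
    tokens = re.findall(r"[a-z]+", comment.lower())
    return [naive_stem(t) for t in tokens if t not in STOP_WORDS]

print(preprocess("Working on the comments"))  # ['work', 'on', 'comment']
```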

Then we label those clusters as meaningful topics. After LDA, we thus have topic distributions, representing the probability of each topic in a document, and word distributions, representing the probability of each word in a topic. These probability distributions are then prepared as LSTM input. For the LSTM, word embeddings and topic embeddings are also generated, and these embeddings for each new input comment are trained in the LSTM network. The topic classes are distributed across three basic sentiment types, i.e., positive, negative, and neutral; the percentage of each topic class is calculated and a sentiment class is assigned accordingly.

### *3.3. LDA and LSTM Based Hybrid Model*

The preprocessed data are passed to the hybrid module responsible for the primary processing of the data. Here, the data are first handed to the topic-modeling process, where LDA is applied: the number of topics and the Dirichlet parameters are initialized, and LDA generates clusters of the words with the highest similarity.

### A. Topic Discovery and Classification

We used LDA for topic discovery in the comments data, as the comments left by backers express their emotions, feelings, thoughts, and experiences related to the project; reviews and comments are thus powerful enough to shape others' decisions. Figure 2 elaborates on the overall LDA process. Each project's input data are in the form of comments, and each comment is treated as one document, resulting in N documents per project. Data preprocessing is a crucial part of any NLP technique; therefore, we perform the necessary preprocessing tasks on the input data, such as removing quotes, stop words, and URLs, tokenization, and stemming. Our data preprocessing was influenced by [42], which helps us work with

short texts and shows that LDA works on them with equivalent efficiency. Once the data are preprocessed, LDA is performed with the Dirichlet parameters set to calculate the desired distributions. We present the output as probability distributions of topics over projects and of words over documents. This output is then used as input for the next step, where we use the discovered topics as ground truth and train our LSTM model to predict the topic class of new comments. All the learned topics are divided into classes, and each class depicts a specific sentiment.

**Figure 2.** Latent Dirichlet allocation (LDA) process.
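The topic-discovery step just described can be sketched with scikit-learn's `LatentDirichletAllocation`. The toy comments, the two-topic setting, and the Dirichlet prior values below are illustrative assumptions only; the paper's experiments settled on 12 topics.

```python
from sklearn.decomposition import LatentDirichletAllocation
from sklearn.feature_extraction.text import CountVectorizer

comments = [  # each comment is treated as one document
    "still waiting for my reward no update from the creator",
    "great project fast delivery thank you",
    "refund please this is a scam no communication",
    "love the design quality is amazing",
]
X = CountVectorizer().fit_transform(comments)   # bag-of-words counts

lda = LatentDirichletAllocation(
    n_components=2,        # the paper searched 2..30 topics and chose 12
    doc_topic_prior=0.1,   # Dirichlet alpha (assumed value)
    topic_word_prior=0.01, # Dirichlet beta (assumed value)
    random_state=0,
).fit(X)

doc_topics = lda.transform(X)  # P(topic | document), one row per comment
```

Each row of `doc_topics` is the topic distribution of one comment; these distributions, together with the word distributions in `lda.components_`, are exactly the two outputs the text describes.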

### B. Deep Learning using LSTM

We use a bidirectional LSTM to capture context dependencies over time: when an input is provided, it is analyzed in both its natural order and its reverse order to capture the maximum dependencies within the data. We use a 128-unit bidirectional LSTM for this purpose. The preprocessing module's output is passed to an embedding layer that converts the input into a 64-dimensional vector representation. This representation is processed by the LSTM layer, which is connected to a dense layer that consolidates the LSTM results. The output layer gives the probability distribution over the output categories. The detailed architecture of the proposed approach is presented in Figure 3.
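The described network can be sketched in PyTorch as follows. The paper does not name a framework, so this is an assumed implementation; the layer sizes follow the text (reading the "64-bit vector representation" as a 64-dimensional embedding, a 128-unit bidirectional LSTM, a dense layer, and a softmax output), while the vocabulary size, sequence length, and dense width are illustrative.

```python
import torch
import torch.nn as nn

class TopicClassifier(nn.Module):
    def __init__(self, vocab_size=5000, n_classes=12):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, 64)     # 64-dim word vectors
        self.lstm = nn.LSTM(64, 128, batch_first=True,
                            bidirectional=True)       # 128 units per direction
        self.dense = nn.Linear(2 * 128, 64)           # consolidates LSTM output
        self.out = nn.Linear(64, n_classes)           # one unit per topic class

    def forward(self, token_ids):
        h, _ = self.lstm(self.embed(token_ids))       # (batch, seq, 256)
        h = torch.relu(self.dense(h[:, -1]))          # last time step
        return torch.softmax(self.out(h), dim=-1)     # P(topic class)

model = TopicClassifier()
probs = model(torch.randint(0, 5000, (8, 20)))        # batch of 8 comments, len 20
```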

### *3.4. Project Credibility Estimation*

In this section, we present a detailed explanation of the credibility module and derive, step by step, the formulas used to estimate the credibility of a project. Trust is a significant element in any domain that helps gain the customer's confidence, and this holds for e-commerce sites and online social networks as well. Therefore, multiple trust-aware recommender systems have been proposed that adopt users' trust statements and their personal or profile data to considerably improve the quality of recommendations.

As we target crowdfunding projects, we aim to formulate an equation to calculate any project's credibility before recommending it to a user. A highly credible recommendation is a project that most likely reflects the user-defined interests and categories and has high chances of delivery; it must also show the lowest probability of factors that can undermine trustworthiness, such as communication delays and infrequent updates. A credible project can be defined precisely as one with the maximum likelihood of being completed and delivered to the backers within the promised period. Various factors are associated with a project's credibility; we define and link a project's credibility with its estimated authenticity score range. A project's authenticity is a multi-fold view of different latent aspects, such as latent aspects of a creator's profile and all of his or her external social links. It also involves the frequency of account usage and updates from the creator; in other words, keeping the backers up to date with each development in the project earns more credibility points. In addition, factors such as the most frequent keywords used, promises related to product or reward delivery, and investors' sentiments are also crucial. These backer sentiments are discovered during the LDA process that finds latent topics in their comments. A document can contain multiple topics, and each topic represents a particular class of sentiments. As shown in Table 2 [45], we identified 12 topic classes, labeled Topic-1 to Topic-12. The number of topics was varied between 2 and 30 during the experiments to find the optimal number.

**Figure 3.** Detailed view of the proposed approach.

The coherence score increased as the number of topics grew; we selected the number of topics at the point just before the coherence score flattened out, i.e., 12 topics, and evaluated the topics accordingly. After training the LSTM, each comment is classified into one of these topic classes. We divided the sentiment classes into three categories, customized for the problem at hand, i.e., credibility assessment. These categories are referred to as A, B, and C: category A holds extremely negative comments, represented by Topic-4 to Topic-7; category B holds negative reviews, represented by Topic-1 to Topic-3; and category C represents positive or neutral reviews, represented by Topic-8 to Topic-12. More emphasis is laid on the negative comments because, with regard to credibility or trust, negative comments and reviews impact the viewer's mind and

decision-making process more significantly than positive comments. Therefore, we divided the negative comments into the extremely unfavorable class A and the negative class B.


**Table 2.** Topic classes identified using LDA analysis.
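The division of the 12 topic classes into categories A, B, and C reduces to a simple lookup. A sketch, with topic numbers following the text and the `Topic-N` labels standing in for the actual labels of Table 2:

```python
# Map each LDA topic class to its sentiment category.
CATEGORY = {}
CATEGORY.update({f"Topic-{i}": "B" for i in range(1, 4)})    # negative
CATEGORY.update({f"Topic-{i}": "A" for i in range(4, 8)})    # extremely negative
CATEGORY.update({f"Topic-{i}": "C" for i in range(8, 13)})   # positive / neutral

def sentiment_category(topic_label):
    """Return the credibility category (A, B, or C) of a topic class."""
    return CATEGORY[topic_label]
```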

To evaluate our selected topics, we measured the agreement between two raters using Cohen's Kappa coefficient [46] and followed the process mentioned in [47] to assess our LDA model. Two students (student A and student B) from different laboratories who were unaware of our proposed methodology and had no prior knowledge about the list of LDA topics were requested to extract topics from 250 sampled reviews.

Student A and student B were not allowed to communicate or discuss their reasoning behind labeling each review. They identified 9 and 11 topics, respectively; student A had seven topics in common with LDA, whereas student B had ten. Among all the topics, we selected the six most common between the two students' lists to measure our LDA model's reliability, as shown in Table 3. As Table 3 shows, student A and student B have a high degree of agreement for all six topics, and the agreement between the LDA model and each student is also relatively high, as indicated by the Kappa coefficient.
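Cohen's Kappa itself is readily computed; the sketch below uses scikit-learn's `cohen_kappa_score` on invented labels (the actual review labels from the study are not reproduced here).

```python
from sklearn.metrics import cohen_kappa_score

# Hypothetical topic labels assigned by two raters to six sampled reviews.
student_a = ["delay", "refund", "praise", "delay", "scam", "praise"]
student_b = ["delay", "refund", "praise", "refund", "scam", "praise"]

# Kappa corrects raw agreement (5/6 here) for chance agreement.
kappa = cohen_kappa_score(student_a, student_b)
```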

Category A is for extremely negative comments of a severe nature, typically reflecting anger through filed lawsuits or complaints. Category B is for relatively simple, generic negative comments reflecting sadness or disappointment. The classification is based on the nature of the malicious content. All other comments belong to category C. The purpose behind this arrangement, with more emphasis on negative comments, is the pronounced impact of malicious content on a product's credibility. Table 4 summarizes the parameters used for the authenticity measures, with their definitions and notations. In addition to sentiments, we have also included other relevant and impactful features, such as the readability of content (referred to as readScore), the existence of a profile picture, etc.


**Table 3.** Reliability Assessment of LDA model using Cohen's Kappa coefficient.

**Table 4.** Definitions of the parameters of authenticity.


Hence, by incorporating all the factors mentioned above, we formulated an equation for calculating a given project's authenticity. To determine the authenticity of a project, it must first fulfill the eligibility criteria given in Equation (1). Once a project passes the eligibility criteria, its authenticity is calculated (Equation (16)). The eligibility criteria in Equation (1) are based on a project's content and partially on profile-associated features.

$$\text{Eligibility}_{\text{criteria}} = -(e\text{NegA} + \alpha \cdot \text{picY}) \tag{1}$$

Here, α represents the weight associated with the existence of a profile picture. The weight assigned to *picY* is lower than that of *eNegA* because of the level of impact of each parameter. The value of α is set to 0.4. From the above equation, we define the ranges for both parameters.

$$e\text{NegA} = \begin{cases} 0 & \text{if } A \le 0 \\ 1 & \text{if } A > 0 \end{cases} \tag{2}$$

Similarly,

$$\text{picY} = \begin{cases} 0 & \text{if profile picture exists} \\ 1 & \text{if profile picture does not exist} \end{cases} \tag{3}$$

Hence from above Equations (2) and (3), we have

$$\text{Eligibility}_{\text{project}} = \begin{cases} 0 & \text{favorable} \\ [-0.4,\ 0) & \text{can be considered} \\ < -0.4 & \text{unfavorable} \end{cases} \tag{4}$$

Therefore, based on Equation (1) and following the conditions in Equation (4), we list all possible eligibility scenarios in Table 5. The content in the *eNegA* category is extremely unfavorable, as one can sense fear, suspicion, and frustration in it. Therefore, this category is handled independently to reduce the probability of an unreliable recommendation. A reliable project must be free of any comments in the *eNegA* category; thus, we used this to set our eligibility criteria. The objective function targets the maximum percentage of positive comments, i.e., category C, as well as the maximum number of social links from the project's creator.
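The eligibility check in Equations (1)–(4) can be sketched directly; the function name and boolean inputs below are our own illustrative choices, following the binary definitions of *eNegA* and *picY* in Equations (2) and (3):

```python
ALPHA = 0.4  # weight for the profile-picture indicator (Equation (1))

def eligibility(has_extremely_negative: bool, has_profile_picture: bool) -> str:
    """Eligibility check following Equations (1)-(4)."""
    e_neg_a = 1 if has_extremely_negative else 0  # Equation (2)
    pic_y = 0 if has_profile_picture else 1       # Equation (3)
    score = -(e_neg_a + ALPHA * pic_y)            # Equation (1)
    if score == 0:
        return "favorable"
    if score >= -ALPHA:
        return "can be considered"
    return "unfavorable"

print(eligibility(False, True))   # -> favorable
print(eligibility(False, False))  # -> can be considered
print(eligibility(True, True))    # -> unfavorable
```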


**Table 5.** All possible cases for a project's eligibility criteria.

In crowdfunding, a backer's faith and confidence rely on content authenticity and the creator's transparency. Therefore, these aspects are fundamental to a project's success. Table 4 shows that the factor delaycomm is one prime feature of a project, representing a creator's communication style through updates and comments. This feature, delaycomm, is defined as the average time gap between any consecutive posts by the project creator, whether updates or comments. It reflects the rate of the creator's communication during a project's development. Because of the impact of delaycomm, the project's authenticity is damaged as the communication delay grows.
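The delaycomm feature defined above can be computed from the timestamps of a creator's posts; the dates below are hypothetical:

```python
from datetime import datetime

def average_post_delay(timestamps):
    """delay_comm: average gap in days between consecutive creator posts."""
    times = sorted(timestamps)
    if len(times) < 2:
        return 0.0
    gaps = [(b - a).total_seconds() / 86400 for a, b in zip(times, times[1:])]
    return sum(gaps) / len(gaps)

# Hypothetical update/comment timestamps for one project.
posts = [datetime(2020, 1, 1), datetime(2020, 1, 11), datetime(2020, 2, 10)]
print(average_post_delay(posts))  # -> 20.0
```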

After observing and estimating all the relevant features, all values are normalized between 0 and 1. Here, 0 represents the least authentic feature value, and 1 the most authentic. In other words, these values depict the trustworthiness of a project. Equation (5) below describes this relationship: the higher the authenticity, the higher the reliability of a project.

$$\text{Authenticity}_{\text{project}} \propto \text{Credibility}_{\text{project}} \tag{5}$$
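The normalization step described above can be sketched as min-max scaling, inverting features (such as communication delays) for which a larger raw value implies *lower* authenticity; the delay values below are hypothetical:

```python
def min_max_normalize(values, invert=False):
    """Scale feature values to [0, 1]; set invert=True for features where
    a larger raw value means less authenticity (e.g., delays)."""
    lo, hi = min(values), max(values)
    if hi == lo:
        return [1.0] * len(values)
    scaled = [(v - lo) / (hi - lo) for v in values]
    return [1.0 - s for s in scaled] if invert else scaled

# Hypothetical communication delays in days; a longer delay -> lower score.
delays = [2, 30, 365]
print(min_max_normalize(delays, invert=True))
```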

As a result, a project can have different credibility levels, i.e., extremely low, low, normal, high, and extremely high credibility. Each credibility level corresponds to a different range of authenticity. Extremely low and low-credibility projects have higher chances of being forged: projects with lower credibility levels carry the greatest risk of non-payment, no communication or communication delays, late posts by the creator in the form of updates or comments, and late or no deliveries. Therefore, such projects should not be recommended to backers for investment. Conversely, a project with a higher credibility level (high or extremely high) is a profitable project that can be recommended to backers: it has the maximum probability of on-time delivery and more consistent communication patterns throughout its duration.

For any recommendation system, the percentage of positive and negative reviews is pre-eminent as it reflects a user's attitude towards a product. Therefore, we assess the following points wisely:

1. A fundamental requirement for a product to be reliable is to have a maximum percentage of positive comments and a minimum rate of negative comments. A product with a relatively high number of negative reviews becomes less favorable. Therefore, Equations (6) and (7) represent the relationship between comments and authenticity.

$$\text{Authenticity} \propto \text{PosC}_i \tag{6}$$

and

$$\text{Authenticity} \propto \frac{1}{\text{NegB}_i} \tag{7}$$

where the percentage of positive and negative comments is referred to as *PosC*i and *NegB*i, respectively.

2. The accessibility of social and profile information, such as profile links, display pictures, and the number of friends or followers, is a persuasive and compelling element of a profile's credibility. Thus, the more personal and relevant information a project creator shares, the easier it is to earn trust. Therefore, we can say,

$$\text{Authenticity} \propto \text{Links}_{\text{Ext}} \tag{8}$$

In the above Equation (8), LinksExt is the number of links a person provides for his/her external social media networks, such as Facebook, Twitter, etc.

3. The clarity of speech also plays a vital role in trust development. If the content is easy to follow and understand, a user will easily connect and comprehend it. It helps diminish the misunderstandings, and the confidence level of the reader increases. Therefore,

$$\text{Authenticity} \propto \frac{1}{\text{read}_{\text{Score}}} \tag{9}$$

In Equation (9), readScore is the readability score of a document. If readScore is high, the document is difficult to follow or understand. The lower the readability score, the more quickly the document can be understood.

4. Communication patterns are the key to trust maintenance. Smooth, consistent communication helps people put their trust in a project. If there is no communication from the product creator, backers become frustrated and angry and lose interest. Therefore, the communication delay between the creator's posts should be minimized.

$$\text{Authenticity} \propto \frac{1}{\text{delay}_{\text{comm}}} \tag{10}$$

In Equation (10), delaycomm is the average delay between any successive posts, i.e., comments or updates by the project creator. The higher delays will negatively affect project authenticity.

5. Hence, we can summarize the factors mentioned above as

$$\text{Authenticity} \propto \left[\text{PosC}_i,\ \text{Links}_{\text{Ext}}\right] \tag{11}$$

also,

$$\text{Authenticity} \propto \left[\frac{1}{\text{NegB}_i},\ \frac{1}{\text{delay}_{\text{comm}}},\ \frac{1}{\text{read}_{\text{Score}}}\right] \tag{12}$$

By combining Equations (11) and (12), Equation (13) is formulated as below,

$$\text{Authenticity} \propto \left[\frac{\text{PosC}_i,\ \text{Links}_{\text{Ext}}}{\text{NegB}_i,\ \text{delay}_{\text{comm}},\ \text{read}_{\text{Score}}}\right] \tag{13}$$

6. We divide the equation into two parts, combining similar factors based on their priority. Hence, Equation (14) combines the sentiment-based factors.

$$\text{Authenticity} = \left[\frac{\text{PosC}_i}{\text{NegB}_i}\right] \tag{14}$$

This factor is only associated with product comments. For higher authenticity, *PosC*i has to be greater than *NegB*i. We have combined other features related to the product or creator into one Equation as,

$$\text{Authenticity} = \left[\frac{\text{Links}_{\text{Ext}}}{\text{read}_{\text{Score}} + \text{delay}_{\text{comm}}}\right] \tag{15}$$

7. Combining all the factors then results in Equation (16) below:

$$\text{Authenticity}_{\text{project}} = \left[\sum_{i=1}^{n} \frac{\text{PosC}_i}{\text{NegB}_i} + \left(\frac{\text{Links}_{\text{Ext}}}{\text{read}_{\text{Score}} + \text{delay}_{\text{comm}}}\right)\right] \tag{16}$$

8. In the final step, we apply optimization and formulate our objective functions. We have both maximization and minimization functions: the maximization function maximizes the values of the favorable factors, and the minimization function minimizes the weight of the least desirable parameters. Hence, we can formulate the credibility estimation in terms of maximization and minimization functions in Equation (17).

$$\text{Credibility}_{\text{project}} = \left[\sum_{i=1}^{n} \frac{\max\left(\text{PosC}_i\right)}{\min\left(\text{NegB}_i\right)} + \left(\frac{\max\left(\text{Links}_{\text{Ext}}\right)}{\min\left(\text{read}_{\text{Score}}\right) + \min\left(\text{delay}_{\text{comm}}\right)}\right)\right] \tag{17}$$

For the above Equation, we can define the ranges of all the parameters as below in Equations (18)–(22).

$$\text{NegB}_i = \begin{cases} 0 & \text{if percentage of negative comments} = 0\% \\ 1 & \text{if percentage of negative comments} = 100\% \end{cases} \tag{18}$$

$$\text{PosC}_i = \begin{cases} 0 & \text{if percentage of positive comments} = 0\% \\ 1 & \text{if percentage of positive comments} = 100\% \end{cases} \tag{19}$$

$$\text{Links}_{\text{Ext}} = \begin{cases} 0 & \text{if number of links} = 0 \\ (0,\ 9] & \text{if number of links} > 0 \end{cases} \tag{20}$$

The value of LinksExt was decided based on the maximum number of external links provided by the project creator. In our case, the maximum number of links a person can provide is considered to be 9. Therefore, LinksExt can have any value between 0 and 9.

$$\text{read}_{\text{Score}} = \begin{cases} \text{near } 1 & \text{comprehensible (easy to understand)} \\ 50\text{–}100 & \text{incomprehensible (difficult to understand)} \end{cases} \tag{21}$$

$$\text{delay}_{\text{comm}} = \left[0\text{–}365 \text{ days}\right] \tag{22}$$

In Table 6, we define the maximum and minimum ranges of each parameter.


**Table 6.** The value ranges for each credibility parameter.
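Equation (16) can be implemented directly once every feature has been normalized as described above. This is a minimal sketch; the function signature and the epsilon guard against division by zero are our own illustrative choices, not part of the paper:

```python
def authenticity(pos_c, neg_b, links_ext, read_score, delay_comm):
    """Authenticity of a project following Equation (16).

    pos_c, neg_b: shares of positive / negative comments, normalized to [0, 1].
    links_ext:    number of external social links (0-9).
    read_score:   readability score (higher = harder to read).
    delay_comm:   average delay between creator posts.
    A small epsilon guards against division by zero.
    """
    eps = 1e-6
    sentiment_term = pos_c / (neg_b + eps)                  # PosC_i / NegB_i
    profile_term = links_ext / (read_score + delay_comm + eps)
    return sentiment_term + profile_term

# Hypothetical feature values for one project.
print(authenticity(pos_c=0.8, neg_b=0.2, links_ext=4,
                   read_score=0.5, delay_comm=1.5))
```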

### **4. Implementation and Experimental Setup**

In this section, we present our implementation environment, along with the experimental setup, in detail. This section also explains the evaluation metrics used for results assessment.

### *4.1. Experimental Setup*

The core system components include Ubuntu 18.04.1 LTS as the operating system, 32 GB of memory, and an Nvidia GeForce 1080 graphics processing unit (GPU). In addition, we used the Python language for development, along with the TensorFlow API.

### *4.2. Evaluation Metrics*

The performance of our system is measured by using the following evaluation metrics.

1. Accuracy: The accuracy of the model is calculated using the formula shown in Equation (23).

$$\text{Accuracy} = 1 - \frac{\|Y - \hat{Y}\|_F}{\|Y\|_F} \tag{23}$$

where $Y$ and $\hat{Y}$ represent the actual data and predicted data, respectively.

2. Root mean square error (RMSE): The RMSE is calculated using Equation (24).

$$\text{RMSE} = \sqrt{\frac{1}{MN} \sum_{j=1}^{M} \sum_{i=1}^{N} \left( y_i^j - \hat{y}_i^j \right)^2} \tag{24}$$

where $y_i^j$ and $\hat{y}_i^j$ are elements of $Y$ and $\hat{Y}$ and represent the actual and predicted data at the $j$th time sample in the $i$th session, respectively. $M$ is the total number of time samples, and $N$ is the number of projects. RMSE is used to evaluate the prediction error: the smaller the RMSE, the better the prediction rate or score, according to Equation (25).

$$\text{Prediction rate} \propto \frac{1}{\text{RMSE}} \tag{25}$$

While accuracy measures the precision of the predictions, it has the opposite effect on the prediction rate to RMSE, as shown in Equation (26): the higher the accuracy, the better the prediction rate.

$$\text{Prediction rate} \propto \text{Accuracy}\tag{26}$$
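Both evaluation metrics can be sketched directly from Equations (23) and (24), treating the actual and predicted data as $M \times N$ matrices; the sample matrices below are illustrative only:

```python
import numpy as np

def accuracy(Y, Y_hat):
    """Equation (23): 1 - ||Y - Y_hat||_F / ||Y||_F (Frobenius norms)."""
    return 1.0 - np.linalg.norm(Y - Y_hat, "fro") / np.linalg.norm(Y, "fro")

def rmse(Y, Y_hat):
    """Equation (24): root mean square error over M time samples x N projects."""
    return np.sqrt(np.mean((Y - Y_hat) ** 2))

# Illustrative data: a perfect prediction gives accuracy 1.0 and RMSE 0.0.
Y = np.array([[1.0, 2.0], [3.0, 4.0]])
Y_hat = np.array([[1.0, 2.0], [3.0, 4.0]])
print(accuracy(Y, Y_hat), rmse(Y, Y_hat))  # -> 1.0 0.0
```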

### **5. Results**

This section presents the results and analysis for crowdfunding project recommendations based on the user's previous interests and credibility. In Section 5.1, we offer a study of the recommendation results for crowdfunding projects. In Section 5.2, we report the accuracy of the proposed model results compared with other models. Table 7 shows the selection criteria for credible projects with different levels of credibility.


### *5.1. Statistical Analysis of Recommendation Results*

To observe the results of our recommendation module, we used ground truth data. This data includes 100 projects: 55 non-scam projects, 20 suspended projects, 10 canceled projects, and 15 successfully funded projects. Our proposed model works efficiently, with high accuracy, in both scenarios, i.e., when projects are from the same category or from different categories. This dataset exemplifies all possible use case scenarios, so we can evaluate how well our recommendation system performs on each type of project in terms of its funding status. Figure 4 shows the percentage of crowdfunding projects in each category.

**Figure 4.** Percentage of crowdfunding projects for each category.

These types are categorized based on the funding status of a project, i.e., "Non-Scam" are projects that are successfully delivered after successful funding; "successfully funded scam" are projects that successfully raised the required funds but failed to deliver; "canceled" category represents projects that have been withdrawn by the project creator before its funding period expires; and "suspended" type means those projects which have been discontinued by the platform in case they figure out any suspicious activity or content.

We have tested our model on all the above categories of projects to find their authenticity. In Figure 5, we estimated the authenticity levels for the suspended projects. The *x*-axis shows the authenticity levels between 0 and 1, and the *y*-axis presents the percentage of suspended projects. The estimated authenticity level for 93% of the suspended projects falls in the range (0–0.2).

**Figure 5.** Authenticity estimation for suspended projects.

The data used for this experiment included 80 suspended projects, of which 74 fell into the highly undesired range of credibility. This means that these projects are highly undesirable for backers, which is expected, since these projects were suspended for suspicious activities. From this range, we can infer that most suspended projects fail to fulfill the credibility assessment tool's selection criteria for recommendation. The statistical analysis of the results is presented in Table 8.


**Table 8.** Statistical analysis of credibility assessment of suspended projects (Total projects = 80).

In Figure 6, we have evaluated the authenticity levels for 70 canceled projects. The estimated authenticity level for 84% of the canceled projects falls in the range (0.4–0.6). Hence, keeping the risk factor in mind, these projects can be considered for investments. The statistical analysis is presented in Table 9. These projects are canceled for multiple reasons, such as lack of funding and budget issues during development phases.

The results show that, for most cases, the predicted authenticity level is between 0 and 0.2. We used 120 undelivered projects, i.e., successfully funded but never delivered projects, and of these, 112 did not meet the credibility criteria and are highly undesired projects. This shows that, despite the successful raising of funds, backers are disappointed with the progress and development. For such projects, comments play a vital role in understanding a creator's behavior towards investors after successfully collecting the desired funds. False promises, long delays in communication, or disappearance from the platform are the essential characteristics found in such cases.

**Figure 6.** Authenticity estimation for canceled projects.

**Table 9.** Statistical analysis of credibility assessment of canceled projects (Total projects = 70).


Figure 7 presents the accuracy of the recommendation results for the successfully funded projects that did not deliver, i.e., scam projects.

**Figure 7.** Authenticity estimation for successfully funded scam projects.

The statistical analysis is presented in Table 10. These projects are undelivered and counted as scam projects because the creators did not fulfill their promises and lacked transparency during the project's development phase.


**Table 10.** Statistical analysis of credibility assessment of undelivered projects (Total projects = 120).

Figure 8 presents an interesting trend. It shows the accuracy of authenticity level estimation for non-scam projects.

**Figure 8.** Authenticity estimation for Non-Scam projects.

This trend indicates that an authentic, genuine project rarely exhibits parameters that lower its authenticity score. Most projects fall in the range of 0.4 to 0.9, reflecting that the comments lie in the gray range typical of less risky projects.

The statistical analysis is in Table 11. These projects are successfully delivered to the backers. We used 120 successfully delivered projects for this experiment, and 64 of them were highly credible.


**Table 11.** Statistical analysis of credibility assessment of delivered projects (Total projects = 120).

Different learning rates were used in the experiments, i.e., 0.1, 0.01, and 0.001, referred to as LR\_0.1, LR\_0.01, and LR\_0.001, respectively. Figure 9 evaluates the RMSE over different numbers of iterations. The testing error decreases as the learning rate gets smaller, indicating that with a smaller learning rate the system's performance improves.

**Figure 9.** Testing error against different learning rates.

### *5.2. Comparison with Other Approaches*

Here, we used RMSE as an evaluation metric to compare our technique with different ML approaches, such as a basic NN, a bidirectional LSTM, and an integrated model of a recurrent neural network (RNN) and LDA, referred to as RNN-LDA. Table 12 presents the RMSE values in comparison with other models.

**Table 12.** Evaluation Metrics for Applied Machine Learning Approaches.


Figure 10 presents the accuracy percentage of different models in comparison with our proposed model. It shows that topic models, combined with deep learning models, can achieve better performance than other models.

**Figure 10.** Accuracy Rate of Different Models.

### **6. Conclusions**

We have proposed a methodology for measuring a project's credibility to build a recommendation system. The proposed method uses both textual and non-textual data. This system is developed to help users select reliable and trustworthy options in their preferred categories. The proposed method is a hybrid LDA-LSTM model that joins the benefits of (1) LSTMs, which capture time dependencies for class and topic prediction, and (2) topic modeling, which extracts topics that summarize the content well. A case study on crowdfunding was performed to analyze and test the proposed system's behavior. We have also embedded an optimized recommendation strategy based on a project's credibility.

This study aims to overcome the limitations of topic models and deep learning and get the most out of both approaches. The main objectives include:


This joint LDA-LSTM model exploits word and topic embeddings together with the temporal data, attaining 96% accuracy in predicting the topic categories. The discovered topic classes were also evaluated in the context of helping investors identify suspicious campaigns. The prediction quality can be improved by examining different configurations of comments with respect to a project's timeline; we experimented with this by dividing the comments into five different batches. To maintain the quality of the results, we did not consider projects with fewer than 50 comments.

Many applications of recommendation systems in different fields have been proposed. To the best of our knowledge, our approach is the first to recommend credible crowdfunding projects. Moreover, none of the existing works have focused on crowdfunding comments to find discussion trends and their impact on project credibility. Hence, in crowdfunding, this approach can be used to recommend safe and secure projects to investors. In Table 13, we show a comparative analysis of our proposed model against existing recommender-system models. In [47–50], the authors use Kickstarter to predict a project's success. Other works analyze comments and updates to estimate project completion and then recommend projects to users. We summarize this research work's contributions: (1) a hybrid method is proposed for reliable and promising recommendations; this approach can model user preferences and word representations in a dynamic style to enable active measurement of the semantic similarity between the user's preferences and the words. (2) The proposed algorithm infers the dynamic embeddings of both documents and words, and we offer a credibility measurement approach for reliable recommendations. The results show that our proposed method significantly outperforms similar state-of-the-art methods.



### **7. Discussion**

Recommendation systems help users in their decision-making process. Nowadays, these systems have applications in every domain, e.g., location recommendation for tourists, product recommendation for online buyers, restaurant recommendation, route recommendation for travelers, etc. In other words, the need for and importance of recommendation systems are not limited to one platform; crowdfunding is also taking advantage of such applications to make the platform trustworthy for its investors. In this paper, we propose a hybrid model for recommending crowdfunding projects to backers.

The main contribution of this study is that we evaluate different features of a campaign to assess its credibility. A credibility assessment is required to build backers' trust in a campaign. If a backer is at least partially aware of a campaign's likely outcome, he can more easily decide whether to invest in it. It is essential to build trustworthy recommendation systems, especially when users' hard-earned money is at risk.

We have tried to delve into the details of a campaign and analyze the outcomes of different campaigns based on their funding status. The hybrid model based on topic modeling and deep learning can (1) learn latent topics in comments, (2) predict the outcome of a project based on the topics discovered so far, and (3) carefully evaluate, through the credibility formulation process, the impact of each feature on the outcome of a project.

**Author Contributions:** W.S. conceived the idea for this paper, designed the experiments, wrote the article, assisted in algorithm implementation, and assisted with design and simulation; Y.-C.B. finalized, evaluated, and proofread the manuscript, and supervised the work; N.P. performed investigation, proofreading, and evaluation. All authors have read and agreed to the published version of the manuscript.

**Funding:** This research received no external funding.

**Acknowledgments:** This work was supported by the Ministry of Education of the Republic of Korea and the National Research Foundation of Korea (NRF-2019S1A5C2A04083374), and also following are the results of a study on the "Leaders in INdustry-university Cooperation+" Project, supported by the Ministry of Education and National Research Foundation of Korea.

**Conflicts of Interest:** The authors declare no conflict of interest.

### **References**


**Publisher's Note:** MDPI stays neutral with regard to jurisdictional claims in published maps and institutional affiliations.

© 2020 by the authors. Licensee MDPI, Basel, Switzerland. This article is an open access article distributed under the terms and conditions of the Creative Commons Attribution (CC BY) license (http://creativecommons.org/licenses/by/4.0/).

## *Article* **Prediction of Stock Performance Using Deep Neural Networks**

### **Yanlei Gu 1,\*, Takuya Shibukawa 2, Yohei Kondo 2, Shintaro Nagao <sup>2</sup> and Shunsuke Kamijo <sup>3</sup>**


Received: 1 October 2020; Accepted: 10 November 2020; Published: 17 November 2020

**Abstract:** Stock performance prediction is one of the most challenging issues in time series data analysis. Machine learning models have been widely used to predict financial time series during the past decades. Even though automatic trading systems that use Artificial Intelligence (AI) have become a commonplace topic, there are few examples that successfully leverage the proven method invented by human stock traders to build automatic trading systems. This study proposes to build an automatic trading system by integrating AI and the proven method invented by human stock traders. In this study, firstly, the knowledge and experience of the successful stock traders are extracted from their related publications. After that, a Long Short-Term Memory-based deep neural network is developed to use the human stock traders' knowledge in the automatic trading system. In this study, four different strategies are developed for the stock performance prediction and feature selection is performed to achieve the best performance in the classification of good performance stocks. Finally, the proposed deep neural network is trained and evaluated based on the historic data of the Japanese stock market. Experimental results indicate that the proposed ranking-based stock classification considering historical volatility strategy has the best performance in the developed four strategies. This method can achieve about a 20% earning rate per year over the basis of all stocks and has a lower risk than the basis. Comparison experiments also show that the proposed method outperforms conventional methods.

**Keywords:** deep neural network; stock performance; earning rate; volatility

### **1. Introduction**

Stock performance prediction is one of the most challenging issues in time series data analysis. How to accurately predict changes in stock performance is an open question for both the financial world and academia. Stock performance prediction is a difficult task due to the complexity and dynamics of the markets and the many inexplicit, intertwined factors involved. Economic analysts and stock traders were the earliest pioneers of stock performance prediction. In the past several decades, thousands of books on stock trading have been published.

Many economic analysts and stock traders have studied the historical patterns of financial time series data and have proposed various methods to predict stock performance. In order to achieve a promising performance, most of these methods require careful selection of index variables and finding the sharing features among the distinguished stocks. William J. O'Neil and M. Weinstein are two representatives of successful traders. They summarized their stock trading experience in the publications [1–4]. William J. O'Neil's CAN SLIM method has a huge following, and also performed well in American Association of Individual Investors (AAII)'s implementation of his model [5]. M. Minervini revealed the proven, time-tested trading system he used to achieve triple-digit returns for five consecutive years, averaging 220% per year [6]. Many followers referred to their methodology in stock trading due to their remarkable achievement.

On the other hand, machine learning models, such as Artificial Neural Networks (ANNs) [7–12], Support Vector Regression (SVR) [13–15], Genetic Algorithms (GA) [16], as well as hybrid models [17], have been widely used to predict financial time series during recent decades. In addition, the time-series problem considers Dynamic Time Warping (DTW), which handles the scaling and shifting that are common in the stock market. The recently developed DTW Network is an algorithm candidate for financial time series data processing [18]. Ramos-Requena et al. [19] used the Hurst exponent to measure the correlation and co-movement between two different series. Krollner et al. [20] surveyed papers using machine learning techniques for financial time series forecasting based on technique categories, such as ANN-based, evolutionary and optimization techniques, and multiple/hybrid methods. Cavalcante et al. [21] provided a comprehensive overview of the most important primary studies, which cover techniques such as ANN, SVM, hybrid mechanisms, optimization, and ensemble methods. The surveys indicate that the approaches differ regarding the number and types of variables used in modeling financial behavior; however, there is no consensus on which input variables are the best to use. In addition, it is important to note that there is no well-established methodology to guide the construction of a successful intelligent trading system. The profit evaluation of the proposed methods when used in real-world applications is generally neglected [21].

Recently, deep learning, as an advanced version of ANN, has attracted attention in the machine learning field because of its high performance in areas such as image recognition and speech recognition. In the field of financial forecasting, a similar new trend considers that a deep neural network has the possibility to increase the accuracy of stock market prediction [22,23]. There are two main deep learning approaches that have been used in stock market prediction: Recurrent Neural Networks (RNN) and Convolutional Neural Networks (CNN). Rout et al. made use of a low complexity RNN for stock market prediction [24]. Pinheiro et al. explored RNN with character-level language model pre-training for both intraday and interday stock market forecasting. The proposed automated trading system that, given the release of news information about a company, predicted changes in stock prices [25]. Li et al. adopted the Long Short-Term Memory (LSTM) neural network, which is an improved version of RNN, and incorporates investor sentiment and market factors to analyze the irrational component of stock price [26]. Nelson et al. studied the usage of LSTM networks to predict future trends of stock prices based on the price history, alongside with technical analysis indicators [27]. Bao et al. presented a deep learning framework where wavelet transforms (WT), stacked autoencoders (SAEs), and LSTM are combined for stock price forecasting [28]. Fischer et al. used deep learning, random forests, gradient-boosted trees, and different ensembles as forecasting methods on all S&P 500 constituents from 1992 to 2015. One key finding in their research is that LSTM networks outperform memory-free classification methods [29]. In order to show accountability to their customers, Nakagawa et al. proposed to approximate and linearize the learned LSTM models by layer-wise relevance propagation [30].

Compared to RNN and LSTM, there are relatively few examples of applying CNN to stock market prediction. Sezer et al. proposed a novel algorithmic trading model, CNN-TA, using a 2-D convolutional neural network fed with 2-D images converted from financial time series data [31]. Zhou et al. proposed a generic framework employing LSTM and CNN for adversarial training to forecast the high-frequency stock market [32]. On the whole, the LSTM network is the most widely used deep learning technology for stock performance prediction.

As mentioned above, automatic trading systems that use Artificial Intelligence (AI) have become a commonplace topic, but there are few examples that successfully leverage the proven methods invented by human stock traders to build automatic trading systems. The first contribution of this study is the development of an intelligent trading system that integrates AI with the knowledge of human stock traders: the important index variables suggested by economic analysts and stock traders are used in a deep neural network to predict future stock performance. The second contribution is the verification of the effectiveness of this trader knowledge and of various investment strategies for constructing a successful intelligent trading system. This study investigates which index variables are the most significant and how to perform stock performance prediction so as to maximize earnings and minimize investment risk. The study focuses on Japanese stock data to explore a reliable investment algorithm for the Japanese stock market. It also aims to verify whether methods developed on United States (US) stocks are effective for Japanese stocks, because the traders William J. O'Neil and M. Minervini summarized their experience based on US stocks.

The rest of the paper is organized as follows: Section 2 describes the important index variables and four strategies for stock classification. Section 3 presents the proposed deep neural network to classify the distinguished stocks with good performance. Section 4 shows the evaluation of the proposed systems. Finally, Section 5 concludes this paper.

### **2. Important Index Variables and Stock Classification**

### *2.1. Important Index Variables for Stock Performance Prediction*

In the current stock market, there are hundreds of index variables indicating the value of a stock from different aspects. Professional analysts and stock traders have worked hard to find correlations between these variables and the future performance of stocks. William J. O'Neil and M. Minervini provided many important points and rules for successful stock trading [1,2]. Table 1 lists the 21 index variables (also called features) that are the most frequently used for recognizing distinguished stocks in their related publications [1,2].

One of the most important issues in the development of an intelligent trading system is deciding which features should be used for stock performance prediction. One way is simply to follow the suggestions of human stock traders and feed all features into the developed system. In addition to using all suggested features, this study presents a feature selection test that verifies their effectiveness. Based on their definitions and characteristics, the 21 features are categorized into four groups: price-related features, trading volume features, company financial status-related features, and others. The results of the feature selection test are discussed in Section 4.

In this study, the data for the important indices are downloaded from the stock database with a weekly resolution, and these features are used as the input of the deep learning algorithm. Daily price-related data and daily trading volume data are also used as input, because these two kinds of data (price and trading volume) are the most important for stock price prediction from the viewpoint of human stock traders. Using the additional daily-resolution data avoids missing significant dynamics within each week. When the weekly and daily data are used together, the two kinds of data must be synchronized in time. The solution for data synchronization is explained in Section 3.



### *2.2. Definition of Positive Samples for Stock Classification*

One of the simplest ways of constructing an intelligent trading system is to employ a binary classification algorithm to divide all stocks into two groups: positive samples, which are the stocks with good future performance, and negative samples, which are the rest. This study presents four strategies for this classification and evaluates them in terms of both earning rate and investment risk.

#### 2.2.1. Constant Threshold-Based Stock Classification

Fund managers are concerned with the rising rate α<sub>C</sub> of the stock price, expressed in Equation (1). For example, if the price rising rate of a stock surpasses a threshold in the next 12 weeks, this stock could be a good candidate for investment. In Equation (1), CP is the closing price in the current week, and HP<sub>in\_next\_12weeks</sub> denotes the highest price in the next 12 weeks. In this constant threshold-based stock classification, α<sub>C</sub> is used to classify stocks: if α<sub>C</sub> of a stock is higher than the threshold, the stock is defined as a positive sample; otherwise, it is a negative sample. In this study, 70% was chosen as the threshold based on the experience of fund managers. If the developed intelligent trading system uses the constant threshold-based method to select stocks, the system simply predicts whether α<sub>C</sub> of the stock exceeds the threshold. The system buys the positive stock in the current week, and sells it when its α<sub>C</sub> reaches the threshold within the next 12 weeks.

$$\alpha\_{\rm C} = \frac{{\rm HP}\_{\rm in\\_next\\_12weeks} - {\rm CP}}{{\rm CP}} \tag{1}$$
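As an illustration, the labeling rule of Equation (1) can be sketched as follows (a minimal sketch; the 70% default comes from the text, while the function name and data layout are illustrative assumptions):

```python
def label_constant_threshold(cp, next_12week_highs, threshold=0.70):
    """Label one sample under the constant-threshold rule.

    cp                -- closing price in the current week (CP)
    next_12week_highs -- prices observed over the next 12 weeks
    threshold         -- rising-rate threshold (70% per the fund managers' experience)
    Returns 1 (positive sample) if alpha_C exceeds the threshold, else 0.
    """
    hp = max(next_12week_highs)          # HP_in_next_12weeks
    alpha_c = (hp - cp) / cp             # Equation (1)
    return int(alpha_c > threshold)

# Example: a stock at 100 whose price reaches 180 within 12 weeks (alpha_C = 0.8)
print(label_constant_threshold(100.0, [110, 150, 180, 165]))  # prints 1
```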

### 2.2.2. Ranking-Based Stock Classification

Because the market situation differs every year, in a "good" year many stocks perform well and have a high price rising rate, whereas in a "bad" year the number of stocks with a high price rising rate is smaller. In the latter case, the investment would concentrate on fewer stocks if the constant threshold method were used for selecting stocks. However, it is necessary to maintain the number of selected stocks and distribute the investment across different stocks to reduce risk.

This study therefore proposes a second classification strategy. All the stock samples are ranked by the rising rate α<sub>R</sub> expressed in Equation (2); the top *x*% of samples are defined as positive samples and the remaining (100 − *x*)% are negative samples. In Equation (2), CP is the closing price in the current week, and CP<sub>after\_12weeks</sub> denotes the closing price after 12 weeks. In this ranking-based stock classification, α<sub>R</sub> is used to classify the stocks. The value of *x* was empirically set to 10 in this research. When the developed intelligent trading system uses this method, it predicts whether a stock ranks in the top 10% and selects those stocks as the positive samples for investment. Because there is no constant threshold to decide whether a stock is in the top 10%, the trading strategy of the constant threshold-based method (selling the positive stock when its α<sub>C</sub> reaches the threshold) cannot be used. Instead, the system predicts whether the α<sub>R</sub> of a stock is in the top 10%, and keeps the detected top-10% stocks for 12 weeks before selling them. Therefore, α<sub>R</sub> is defined as the price rising rate after 12 weeks, which differs from Equation (1).

$$\alpha\_{\rm R} = \frac{{\rm CP}\_{\rm after\\_12weeks} - {\rm CP}}{{\rm CP}} \tag{2}$$
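A sketch of the ranking rule, assuming the current and 12-weeks-later closing prices of all stocks for the same week are available as arrays (`label_ranking` is an illustrative name):

```python
import numpy as np

def label_ranking(cp, cp_after_12weeks, top_percent=10):
    """Label stocks under the ranking rule: the top x% by alpha_R are positive.

    cp, cp_after_12weeks -- arrays over all stocks for the same week
    Returns a 0/1 array of labels.
    """
    cp = np.asarray(cp, dtype=float)
    cp_after = np.asarray(cp_after_12weeks, dtype=float)
    alpha_r = (cp_after - cp) / cp                      # Equation (2)
    cutoff = np.percentile(alpha_r, 100 - top_percent)  # rising rate at the top-x% boundary
    return (alpha_r >= cutoff).astype(int)
```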

### 2.2.3. Constant Threshold-Based Stock Classification Considering Historical Volatility

In addition to the price rising rate, volatility also needs to be considered in investment decisions: it is desirable to select stocks with both a high price rising rate and low volatility. Therefore, this study proposes a third strategy: constant threshold-based stock classification considering historical volatility. In this method, the target rate β<sub>C</sub> is defined in Equation (3).

$$\beta\_{\rm C} = \frac{\alpha\_{\rm C}}{{\rm STD}\left({\bf CP}\_{\rm in\\_past\\_12weeks}/{\rm CP}\right)} \tag{3}$$

where α<sub>C</sub> is calculated using Equation (1), and **CP**<sub>in\_past\_12weeks</sub> is a vector of the daily closing prices in the past 12 weeks. STD(**CP**<sub>in\_past\_12weeks</sub>/CP) is the standard deviation of the normalized price over the past 12 weeks and indicates the historical volatility. In this method, if β<sub>C</sub> of a stock is higher than a threshold, the stock is defined as a positive sample; otherwise, it is a negative sample. In this study, 8 was chosen as the threshold based on the experience of fund managers. When the developed intelligent trading system uses this strategy, it predicts whether β<sub>C</sub> of the stock exceeds the threshold. The system buys the positive stock in the current week, and sells it when its β<sub>C</sub> reaches the threshold within the next 12 weeks.
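The volatility-normalized target rate of Equation (3) might be computed as follows (a sketch; `beta_constant` is an illustrative name, and the past 12 weeks are assumed to be given as daily closing prices):

```python
import numpy as np

def beta_constant(cp, next_12week_highs, past_12week_daily_closes):
    """Volatility-adjusted target rate beta_C of Equation (3).

    cp                       -- closing price in the current week
    next_12week_highs        -- prices observed over the next 12 weeks
    past_12week_daily_closes -- daily closing prices over the past 12 weeks
    """
    alpha_c = (max(next_12week_highs) - cp) / cp                    # Equation (1)
    volatility = np.std(np.asarray(past_12week_daily_closes) / cp)  # STD(CP_in_past_12weeks / CP)
    return alpha_c / volatility

# A sample is positive when beta_C exceeds the threshold of 8.
```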

### 2.2.4. Ranking Threshold-Based Stock Classification Considering Historical Volatility

Similar to the idea of the second strategy, it is also possible to develop a ranking threshold-based stock classification considering historical volatility. In this method, the target rate β<sub>R</sub> is defined in Equation (4).

$$\beta\_{\rm R} = \frac{\alpha\_{\rm R}}{{\rm STD}\left({\bf CP}\_{\rm in\\_past\\_12weeks}/{\rm CP}\right)} \tag{4}$$

where α<sub>R</sub> is calculated using Equation (2). The developed intelligent trading system predicts whether β<sub>R</sub> of the stock is in the top 10%, and keeps the detected top-10% stocks for 12 weeks before selling them.

The four designed strategies are summarized in Table 2. Section 4 presents the performance of the stock trading systems developed based on these four strategies.


#### **Table 2.** Summary of the proposed four strategies for stock classification.

### **3. Deep Neural Network-Based Model for Stock Performance Prediction**

#### *3.1. Long Short-Term Memory Networks*

Long Short-Term Memory networks, usually just called "LSTMs", are a special kind of RNN equipped with a gating mechanism that controls access to memory cells. Since the introduction of these gates, LSTM and its variants have shown great promise in tackling various sequence modeling tasks in machine learning, e.g., natural language processing, image captioning, and speech recognition. Basically, an LSTM unit consists of an input gate, a forget gate, and an output gate. The architecture of an LSTM unit is shown in Figure 1.

**Figure 1.** Visualization of a Long Short-Term Memory (LSTM) unit.

Suppose that **x**<sub>t</sub> is the input and **h**<sub>t−1</sub> is the hidden output from the previous time step *t*−1. The input gate decides how much of the new information will be added to the cell state **c**<sub>t</sub>, and generates a candidate **c̃**<sub>t</sub> by:

$$\mathbf{i}\_t = \sigma(\mathbf{W}\_{\text{xi}}\mathbf{x}\_t + \mathbf{W}\_{\text{hi}}\mathbf{h}\_{t-1} + \mathbf{b}\_i) \tag{5}$$

$$\tilde{\mathbf{c}}\_t = \phi\left(\mathbf{W}\_{\text{xc}}\mathbf{x}\_t + \mathbf{W}\_{\text{hc}}\mathbf{h}\_{t-1} + \mathbf{b}\_c\right) \tag{6}$$

where **i**<sub>t</sub> can be thought of as a knob that the LSTM learns in order to selectively incorporate **c̃**<sub>t</sub> at the current time step. σ is the logistic sigmoid function and φ is *tanh*. Generally, **W** terms denote weight matrices (e.g., **W**<sub>xi</sub> is the matrix of weights from the input to the input gate), and **b** terms are bias vectors. The forget gate decides how much of the previous information is kept at the new time step, and is defined as:

$$\mathbf{f}\_t = \sigma(\mathbf{W}\_{xf}\mathbf{x}\_t + \mathbf{W}\_{hf}\mathbf{h}\_{t-1} + \mathbf{b}\_f) \tag{7}$$

Then, the cell state **c***t* is updated by:

$$\mathbf{c}\_{t} = \mathbf{f}\_{t} \odot \mathbf{c}\_{t-1} + \mathbf{i}\_{t} \odot \tilde{\mathbf{c}}\_{t} \tag{8}$$

where ⊙ is the element-wise product of vectors. The output gate then uses the output **o**<sub>t</sub> to control what is read from the new cell state **c**<sub>t</sub> onto the hidden vector **h**<sub>t</sub> as follows:

$$\mathbf{o}\_t = \sigma(\mathbf{W}\_{\text{xo}}\mathbf{x}\_t + \mathbf{W}\_{\text{ho}}\mathbf{h}\_{t-1} + \mathbf{b}\_o) \tag{9}$$

$$\mathbf{h}\_t = \mathbf{o}\_t \odot \phi(\mathbf{c}\_t) \tag{10}$$

In this study, the functional form LSTM(·) is used as shorthand for the LSTM model, as in Equation (11):

$$(\mathbf{h}\_t, \mathbf{c}\_t) = \text{LSTM}(\mathbf{x}\_t, \mathbf{h}\_{t-1}, \mathbf{c}\_{t-1}, \mathbf{W}, \mathbf{b}) \tag{11}$$

where **W** and **b** comprise the weight matrices and bias vectors of Equations (5)–(9). The values of **W** and **b** are determined in the training step.
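Equations (5)–(10) can be collected into a single step function; the following numpy sketch mirrors the equations one-to-one (the dictionary layout of **W** and **b** is an assumption for illustration):

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(x, h_prev, c_prev, W, b):
    """One LSTM time step implementing Equations (5)-(10).

    W and b are dicts keyed by gate name (an illustrative layout):
    W['xi'] maps the input to the input gate, W['hi'] the hidden state, etc.
    """
    i = sigmoid(W['xi'] @ x + W['hi'] @ h_prev + b['i'])        # Eq. (5): input gate
    c_tilde = np.tanh(W['xc'] @ x + W['hc'] @ h_prev + b['c'])  # Eq. (6): candidate
    f = sigmoid(W['xf'] @ x + W['hf'] @ h_prev + b['f'])        # Eq. (7): forget gate
    c = f * c_prev + i * c_tilde                                # Eq. (8): cell update
    o = sigmoid(W['xo'] @ x + W['ho'] @ h_prev + b['o'])        # Eq. (9): output gate
    h = o * np.tanh(c)                                          # Eq. (10): hidden output
    return h, c
```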

### *3.2. Concatenated Double-Layered LSTM for Stock Performance Prediction*

This study proposes an LSTM-based network to predict the future performance of stocks by classification. The proposed network classifies stocks into two categories (buying or not) based on historical sequence data. The "many to one" model has been widely used in sequence data processing; to fully use the memory and forgetting ability of LSTM, the proposed network is also a "many to one" model. The architecture of the proposed classification network is shown in Figure 2. When classifying the stocks into two categories (buying or not), the historical sequence data from time *t*−*n* to *t*, (**x**<sub>t−n</sub>, ... , **x**<sub>t</sub>), are input into the network together. In the proposed network, *n* was empirically set to 52 based on the experience of professional traders.

**Figure 2.** "Many to one" architecture for the classification of stocks.

Figure 3 corresponds to the model block in Figure 2 and shows its architecture: two stacked LSTM layers, with the output connected to the last LSTM layer. Here, **x**<sub>t</sub> denotes the input data at week *t*. The double-layered LSTM model can be expressed as:

$$\left(\mathbf{h}\_t^1, \mathbf{c}\_t^1\right) = \text{LSTM1}\left(\mathbf{x}\_t, \mathbf{h}\_{t-1}^1, \mathbf{c}\_{t-1}^1, \mathbf{W}^1, \mathbf{b}^1\right) \tag{12}$$

$$\left(\mathbf{h}\_t^2, \mathbf{c}\_t^2\right) = \text{LSTM2}\left(\mathbf{h}\_t^1, \mathbf{h}\_{t-1}^2, \mathbf{c}\_{t-1}^2, \mathbf{W}^2, \mathbf{b}^2\right) \tag{13}$$

In Figure 3, **s**<sub>t</sub><sup>n</sup> stands for (**h**<sub>t</sub><sup>n</sup>, **c**<sub>t</sub><sup>n</sup>), *n* = 1 or 2. To reduce the complexity of the task, a binary classification system was developed for stock performance prediction, i.e., the system is expected to recognize two categories. Therefore, the final hidden layer connects to two output nodes indicating the probabilities of the two categories, which can be estimated from the output of the second LSTM layer as:

$$\mathbf{P}\_t = \mathbf{W}\_2\,\mathbf{h}\_t^2 + \mathbf{b}\_2 \tag{14}$$

where **W**<sub>2</sub> is the weight matrix from the hidden layer to the output layer, and **b**<sub>2</sub> is the bias vector. **P**<sub>t</sub> is a vector indicating the probability of the sample belonging to each of the two categories, 0 and 1. Category "0" means not buying a stock, and category "1" means buying a stock. The softmax layer and classification layer are responsible for normalization and category selection, as explained in the following equations:

$$\text{Normalization: } \sigma(\mathbf{P}\_t)\_i = \frac{e^{P\_{t,i}}}{\sum\_{k=0}^{1} e^{P\_{t,k}}} \text{ for } i = 0 \text{ or } 1, \text{ where } \mathbf{P}\_t = [P\_{t,0}, P\_{t,1}] \tag{15}$$

$$\text{Category selection} = \begin{cases} \begin{bmatrix} 1,0 \end{bmatrix} \text{ if } \sigma(\mathbf{P}\_t)\_0 > \sigma(\mathbf{P}\_t)\_1 \\ \begin{bmatrix} 0,1 \end{bmatrix} \text{ else} \end{cases} \tag{16}$$

The double-layer LSTM model can predict whether the automatic trading system should buy or not buy a stock, given the past 52-week history information of that stock. The output of the double-layer LSTM model could be [1, 0] or [0, 1]. [1, 0] means not buying the stock, and [0, 1] means buying the stock. The output ([1, 0] or [0, 1]) is decided in the category selection block based on the comparison of the normalized probabilities provided by the softmax layer. If the probability of category "0" (not buying a stock) is higher than the probability of category "1" (buying a stock), the output is [1, 0]. Otherwise, the output will be [0, 1].
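For concreteness, the "many to one" forward pass of Equations (12)–(16) can be sketched in numpy as below; the gate stacking, hidden sizes, and random initialization are illustrative assumptions, not the trained model:

```python
import numpy as np

rng = np.random.default_rng(0)

def _sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def _lstm_layer(xs, n_hidden, params):
    """Run one LSTM layer over a sequence xs (T x n_in); return all hidden states."""
    Wx, Wh, bias = params                     # the four gates are stacked: i, f, g, o
    h = np.zeros(n_hidden)
    c = np.zeros(n_hidden)
    hs = []
    for x in xs:
        z = Wx @ x + Wh @ h + bias
        i, f, g, o = np.split(z, 4)
        i, f, o = _sigmoid(i), _sigmoid(f), _sigmoid(o)
        c = f * c + i * np.tanh(g)            # cell update, Eq. (8)
        h = o * np.tanh(c)                    # hidden output, Eq. (10)
        hs.append(h)
    return np.array(hs)

def init_params(n_in, n_hidden):
    """Random initialization for one layer (illustrative, not the trained weights)."""
    return (rng.normal(0, 0.1, (4 * n_hidden, n_in)),
            rng.normal(0, 0.1, (4 * n_hidden, n_hidden)),
            np.zeros(4 * n_hidden))

def predict(seq, params1, params2, W2, b2, n_hidden):
    """'Many to one' forward pass over a weekly sequence, Eqs. (12)-(16)."""
    h1 = _lstm_layer(seq, n_hidden, params1)          # first LSTM layer, Eq. (12)
    h2 = _lstm_layer(h1, n_hidden, params2)           # second LSTM layer, Eq. (13)
    p = W2 @ h2[-1] + b2                              # output from the last step, Eq. (14)
    probs = np.exp(p) / np.exp(p).sum()               # softmax normalization, Eq. (15)
    return [1, 0] if probs[0] > probs[1] else [0, 1]  # category selection, Eq. (16)
```

With 52 weekly feature vectors as input, `predict` returns [1, 0] (not buy) or [0, 1] (buy), mirroring the category selection block.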

In the training of the network, the pair of the input sequence (**x**<sub>t−n</sub>, ... , **x**<sub>t</sub>) and its ground truth ([1, 0] or [0, 1]) is used as the sample data because of the "many to one" architecture. The training algorithm automatically adjusts the model parameters based on the principle of gradient descent. The loss function is the cross entropy:

$$Loss = \ -\sum\_{j=1}^{N} \mathbf{T}\_j (\log(\sigma(\mathbf{P}\_{j,t})))^T \tag{17}$$

where *N* is the number of samples in the training dataset, and **T**<sub>j</sub> is the row-vector ground truth for sample *j*. The objective of the training process is to minimize the loss in Equation (17). In Figure 3, both weekly and daily data are used as input to the LSTM-based deep learning network; each weekly datum is connected with the five daily data from Monday to Friday.

Considering the particularities of stock classification, mistakes in classification have different practical meanings. Compared to a false negative (a stock with good performance missed in detection), a false positive (a stock incorrectly detected as a good performer) carries a higher risk in real investment. It is possible to give a higher weight to false positives in the loss function to force the training to reduce them; for example, the ratio of the false-negative weight to the false-positive weight could be 1:2, 1:3, and so on. Therefore, the loss function is reformulated as Equation (18). In the experiments, this research attempts to find the best value of the parameter **W***Loss* to achieve good performance of the developed trading system.

$$Loss = -\sum\_{j=1}^{N} \mathbf{T}\_j \Big(\mathbf{W}\_{Loss} \odot \log\Big(\sigma\Big(\mathbf{P}\_{j,t}\Big)\Big)\Big)^T \tag{18}$$
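A sketch of the weighted loss of Equation (18), assuming one-hot row-vector ground truths; note that weighting the "not buy" class (index 0) more heavily is what penalizes false positives:

```python
import numpy as np

def weighted_cross_entropy(probs, targets, w_loss=(2.0, 1.0)):
    """Weighted cross entropy of Equation (18).

    probs   -- (N, 2) softmax outputs sigma(P_j)
    targets -- (N, 2) one-hot ground truths T_j ([1,0] = not buy, [0,1] = buy)
    w_loss  -- W_Loss; w_loss[0] scales the not-buy class term, so raising it
               penalizes false positives (the default sketches a 1:2 FN:FP weighting)
    """
    w = np.asarray(w_loss)
    return -np.sum(targets * (w * np.log(probs)))
```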

### **4. Experiment Results**

### *4.1. Experiment Setup and Evaluation Criteria*

In this study, a deep neural network was adopted, and 52 weeks of historical feature data were used as its input for a binary classification. For example, when the system performed the classification on 2018/04/01 to select the stocks with good performance in the next 12 weeks, the historical data of 2017/04/03–2018/04/01 (52 weeks) were input into the deep neural network. The ground truth for each historical sample is binary (buy the stock or not), determined by its future 12-week performance. Because the data are organized weekly, one stock can provide 52 samples per year. For example, the historical data of the samples could be 2017/04/03–2018/04/01, 2017/04/10–2018/04/08, 2017/04/17–2018/04/15, ... , 2018/03/27–2019/03/25, and the future 12-week data of these samples are 2018/04/01–2018/06/25, 2018/04/08–2018/07/02, 2018/04/15–2018/07/09, ... , 2019/03/25–2019/06/17, as shown in Figure 4. In the following description, this paper uses the end date of the data to denote the 52-week historical data input into the deep neural network; for example, "2018/04/01" denotes the historical data "2017/04/03–2018/04/01".
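The weekly sliding-window sample extraction described above might look like this (a sketch; the labels are assumed to be precomputed from each week's future 12-week performance):

```python
import numpy as np

def extract_samples(weekly_features, labels, window=52):
    """Slide a 52-week window over one stock's weekly feature rows.

    weekly_features -- (T, n_features) array, one row per week
    labels          -- (T,) binary ground truth per week, assumed precomputed
                       from that week's future 12-week performance
    Yields (window_data, label) pairs; each label belongs to its window's last week.
    """
    for end in range(window - 1, len(weekly_features)):
        yield weekly_features[end - window + 1: end + 1], labels[end]
```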

**Figure 4.** The 52 samples extracted from one stock data.

In this study, the training, validation, and test datasets were separated by year and month as shown in Table 3. In order to verify the repeatability of the proposed method, this study presents evaluations on different datasets. For example, when the system was tested on the period from 2018/04/01 to 2018/09/30 (the final row of Table 3), the data from 2017/04/01 to 2017/12/31 were used for validation, and the data from 2001/04/01 to 2016/09/30 were used for training. The datasets in each row of Table 3 are considered one set. On average, the number of samples in each training, validation, and test dataset is about 450,000, 45,000, and 30,000, respectively. It is important to note that there is no overlap among the training, validation, and test data in each set. The basic process in the evaluation of each set is to use the training dataset to train the model and obtain multiple classifiers; the best classifier is then selected based on the validation dataset; finally, the selected classifier is evaluated on the test dataset. This process was conducted on each set of datasets to demonstrate the repeatability of the proposed methods.

The training dataset was used to train the model. Training is an iterative process with multiple epochs; one epoch means that all training data have been used once for backpropagation. In this study, the number of epochs was set to 50, because 50 epochs are enough for the convergence of the training process. The training process outputs one classifier after each epoch, so 50 different classifiers were generated after 50 epochs. Theoretically, the final classifier should have the best performance; however, the performance of the classifiers did not change much during training. One reason is that enough training data were provided for the deep learning algorithm: after several epochs, the training process converged, and the parameters of the classifier were close to optimal.


**Table 3.** Training dataset, validation dataset, and test dataset for repeatability evaluation.

However, how to choose the best classifier is a problem. In this study, a validation dataset was used to choose the best classifier. As shown in Table 3, the validation dataset is the most recent year before the test dataset, and there is a three-month gap between the validation data and the test data. When the system works on 2018/03/31 and predicts the future of stocks over the next 12 weeks, the validation dataset should be the data from 2017/04/01 to 2018/03/31. However, the future 12-week data for the historical data from 2018/01/01 to 2018/03/31 are not available on 2018/03/31. Therefore, the data from 2018/01/01 to 2018/03/31 cannot be used for validation, and the validation dataset covers a 9-month period.

In addition, as described in Section 3, a low false-positive count is also desirable in the stock selection. Therefore, when choosing the classifier, it is also necessary to consider which classifier has a low false-positive count while maintaining a high true-positive count. In this study, the following best-precision criterion was adopted to select the classifier on the validation dataset:

$$\underset{i}{\text{argmax}}\left(\text{Precision}\_i\right), \text{ where } \text{Precision} = \frac{\text{TP}}{\text{TP} + \text{FP}} \tag{19}$$

where TP is the number of true positive samples, FP is the number of false positive samples, and *i* indexes the candidate classifiers.

In the repeatability evaluation, the test dataset covered a half-year period, and the validation dataset covered a 9-month period. The training dataset comprises the data excluding the test and validation datasets. For example, when the test dataset is the data from 2011/04/01 to 2011/09/30, the validation dataset is the data from 2010/04/01 to 2010/12/31; in this case, the training dataset is the data from 2001/04/01 to 2010/03/31 plus the data from 2012/10/01 to 2018/09/30. In this study, future data after the test period were also used for training, because deep learning needs a huge amount of training data to achieve good performance, and using the data after the test period increases the number of training samples. It is important to note that there is a one-year gap from the end of the test period to the start of the later training data, because excluding the data in that year strictly guarantees that no part of the test data is used in training.

In this study, True Positive Rate (TPR, Recall), True Negative Rate (TNR), Average Correction Rate (ACR), and Precision were used to evaluate the performance of the developed prediction systems. The evaluation criteria are denoted in Equations (20)–(23):

$$\text{TPR} = \text{Recall} = \frac{\text{TP}}{\text{P}} = \frac{\text{TP}}{\text{TP} + \text{FN}} \tag{20}$$

$$\text{TNR} = \frac{\text{TN}}{\text{N}} = \frac{\text{TN}}{\text{TN} + \text{FP}} \tag{21}$$

$$\text{ACR} = \frac{\text{TPR} + \text{TNR}}{2} \tag{22}$$

$$\text{Precision} = \frac{\text{TP}}{\text{TP} + \text{FP}} \tag{23}$$

where P is the number of all positive samples, and N is the number of all negative samples. TP is the number of true positive samples, FP is the number of false positive samples, TN is the number of true negative samples, and FN is the number of false negative samples.
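Equations (20)–(23) computed from raw confusion-matrix counts (a direct transcription):

```python
def evaluation_metrics(tp, fp, tn, fn):
    """Compute Equations (20)-(23) from confusion-matrix counts."""
    tpr = tp / (tp + fn)            # True Positive Rate (Recall), Eq. (20)
    tnr = tn / (tn + fp)            # True Negative Rate, Eq. (21)
    acr = (tpr + tnr) / 2           # Average Correction Rate, Eq. (22)
    precision = tp / (tp + fp)      # Precision, Eq. (23)
    return tpr, tnr, acr, precision
```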

In addition to the four criteria, the average maximum price rising rate and the average maximum price decreasing rate of stocks were also used in the evaluation. Moreover, this study presents a simulation of stock trading to evaluate the performance of the proposed methods; the details of the simulation are presented in each of the following subsections.

### *4.2. Results of Constant Threshold-Based Stock Classification*

In the evaluation of the constant threshold-based stock classification method, two factors should be discussed: the **W***Loss* parameter in Equation (18) and the features in Table 1. Table 4 shows the performance of the classification using all features and a **W***Loss* of 1:1.

The first column indicates the time period of the test dataset. True negative rate, true positive rate (recall), average correction rate, and precision are listed from the second to fifth columns. The average of the maximum rising rate of detected good performance stocks and all stocks are demonstrated in the sixth and seventh columns. The average of the maximum decreasing rate of detected good performance stocks and all stocks are shown in the eighth and ninth columns.

Moreover, this study also presents a simulation of a real stock trading system. In the case of the binary classification system using a 70% rising-rate threshold, the system sets the 70% rising rate as the selling point. The system first buys all selected stocks; if a selected stock (detected positive sample) achieves a 70% rising rate, the system sells it immediately; otherwise, the stock is kept and sold at the end of the 12 weeks. The tenth and eleventh columns of Table 4 show the earning rate of the simulated stock trading system. In addition, the twelfth column of Table 4 provides the basis of all stocks; the basis is the average earning rate from the present to 12 weeks later.

In addition, the other two criteria are used in the evaluation of risk: Sharpe ratio with trading on selling point and Sharpe ratio without trading on selling point. The two criteria are defined as Equations (24) and (25):

$$\text{SR}\_{\text{T}} = \frac{\left(\text{P}\_{\text{selling}} - \text{CP}\right)/\text{CP}}{\text{STD}\left(\mathbf{CP}\_{\text{until\\_selling}}/\text{CP}\right)} \tag{24}$$

$$\text{SR}\_{\text{NT}} = \frac{(\text{CP}\_{\text{after\\_12weeks}} - \text{CP})/\text{CP}}{\text{STD}(\text{CP}\_{\text{in\\_next\\_12weeks}}/\text{CP})} \tag{25}$$

where CP is the closing price in the current week, P<sub>selling</sub> is the price when selling the stock, and CP<sub>after\_12weeks</sub> is the closing price after 12 weeks. In Equation (24), **CP**<sub>until\_selling</sub> is a vector of the daily closing prices from the current week until selling, and **CP**<sub>in\_next\_12weeks</sub> is a vector of the daily closing prices in the next 12 weeks.
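Equations (24) and (25) transcribed directly (the function names are illustrative):

```python
import numpy as np

def sharpe_with_trading(cp, p_selling, daily_closes_until_selling):
    """SR_T of Equation (24): return up to the selling point over normalized volatility."""
    ret = (p_selling - cp) / cp
    vol = np.std(np.asarray(daily_closes_until_selling) / cp)
    return ret / vol

def sharpe_without_trading(cp, cp_after_12weeks, daily_closes_next_12weeks):
    """SR_NT of Equation (25): 12-week holding return over normalized volatility."""
    ret = (cp_after_12weeks - cp) / cp
    vol = np.std(np.asarray(daily_closes_next_12weeks) / cp)
    return ret / vol
```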

In fact, the Sharpe ratio without trading on the selling point means the stocks are sold at the end of the 12 weeks. The thirteenth and fourteenth columns of Table 4 show the Sharpe ratio with trading on the selling point, and the fifteenth and sixteenth columns show the Sharpe ratio without it. In addition, the average of all tests is listed in the last row of Table 4. This study presents the evaluation of different features and **W***Loss* values; owing to page length limitations, Table 5 summarizes these evaluations. The results generated using different **W***Loss* values were compared, and the best **W***Loss* values were then chosen for the feature selection. The following conclusions can be obtained from the data in Table 5:


### *4.3. Results of Ranking-Based Stock Classification*

Similar to the constant threshold-based stock classification, this study also presents multiple evaluations for the ranking-based stock classification. In the evaluation, the effect of different values of **W***Loss* and different input features was tested. In the ranking-based stock classification, the top 10% of stocks were considered positive samples. Table 6 shows the results generated using the different configurations. There was no fixed threshold for stock trading; therefore, the simulated stock trading system kept the detected stocks (or all stocks, for the baseline) until the end of the following 12 weeks and then sold them, and there is no Sharpe ratio with trading. The following conclusions can be obtained from the data in Table 6.


### *4.4. Results of Constant Threshold-Based Stock Classification Considering Historical Volatility*

Similar to the constant threshold-based stock classification, this study also presents multiple evaluations for the constant threshold-based stock classification considering historical volatility. In the evaluation, the effect of different values of the parameter **W***Loss* and different input features was tested. Table 7 shows the results using different configurations. In this method, the threshold for the target rate β<sub>C</sub> was set to eight to define positive samples. In the simulated trading system, the target rate of eight was set as the selling point: if a stock achieves a target rate of eight, it is sold immediately; otherwise, the system keeps the stock and sells it at the end of the 12 weeks. The following conclusions can be obtained from the data in Table 7.




**Table 5.** Constant threshold-based stock classification with different **W***Loss* values and different features (all: all features; price: price-related features; trading volume: trading volume feature; financial status: company financial status-related features).



**Table 6.** Ranking-based stock classification with different **W***Loss* values and different features (all: all features; price: price-related features; trading volume: trading volume feature; financial status: company financial status-related features).


### *4.5. Results of Ranking-Based Stock Classification Considering Historical Volatility*

Similar to the previous methods, this study also presents multiple evaluations for the ranking-based stock classification considering historical volatility. In the evaluation, the effect of different values of the parameter **W***Loss* and different input features was tested. Table 8 summarizes the results using different configurations. In this method, the top 10% of stocks were considered positive samples. There was no fixed threshold for stock trading; therefore, the simulated stock trading system kept the detected stocks (or all stocks, for the baseline) until the end of the next 12 weeks and then sold them, and there is no Sharpe ratio with trading. The following conclusions can be obtained from the data in Table 8.


In addition, in this study, the proposed deep neural network-based stock performance prediction method was compared with two conventional methods: logistic regression-based classification and Support Vector Machine (SVM)-based classification. This comparison was performed for the ranking-based stock classification considering historical volatility. The comparison in Table 9 shows that the proposed method achieves a higher earning rate and a lower risk than the conventional methods.



### **5. Conclusions**

This study presents four strategies for stock classification and performs feature selection to achieve a higher earning rate and lower risk in stock classification. The following conclusions are drawn from the evaluations and analysis:


The proposed method can achieve an earning rate of about 36% per year (9.044% per 12 weeks) in the Japanese stock market. However, excellent human stock traders have achieved triple-digit returns per year [6], so there is still a gap between the performance of the developed intelligent trading system and the achievements of human stock traders. This paper proposed using a classification network to classify stocks into two categories: buy or not. In the future, a regression network will be developed to predict the exact value of the future price; in this way, the trading system is expected to obtain higher earnings.

**Author Contributions:** Conceptualization, S.K. and S.N.; methodology, S.K., S.N., T.S., Y.K. and Y.G.; data acquisition, T.S. and Y.G.; software, Y.G.; experiment, Y.G.; writing, Y.G.; project administration, S.K. and S.N.; All authors have read and agreed to the published version of the manuscript.

**Funding:** This research received no external funding.

**Conflicts of Interest:** The authors declare no conflict of interest.

### **References**


**Publisher's Note:** MDPI stays neutral with regard to jurisdictional claims in published maps and institutional affiliations.

© 2020 by the authors. Licensee MDPI, Basel, Switzerland. This article is an open access article distributed under the terms and conditions of the Creative Commons Attribution (CC BY) license (http://creativecommons.org/licenses/by/4.0/).

## *Article* **Hybrid Forecasting Models Based on the Neural Networks for the Volatility of Bitcoin**

### **Monghwan Seo <sup>1</sup> and Geonwoo Kim 2,\***


Received: 10 June 2020; Accepted: 9 July 2020; Published: 10 July 2020

**Abstract:** In this paper, we study volatility forecasts in the Bitcoin market, which has become popular in the global market in recent years. Since volatility forecasts inform the trading decisions of traders seeking profit, volatility forecasting is an important task in this market. To improve the forecasting accuracy of Bitcoin's volatility, we develop hybrid forecasting models combining GARCH family models with a machine learning (ML) approach. Specifically, we adopt the Artificial Neural Network (ANN) and the Higher Order Neural Network (HONN) for the ML approach and construct hybrid models using the outputs of the GARCH models and several relevant variables as input variables. We carry out many experiments based on the proposed models and compare their forecasting accuracy. In addition, we apply the Model Confidence Set (MCS) test to find the statistically best model. The results show that the hybrid models based on HONN provide more accurate forecasts than the other models.

**Keywords:** Bitcoin; artificial neural network; higher order neural network; volatility forecasting; hybrid models

### **1. Introduction and Review of Models**

### *1.1. Introduction*

Online transactions over the Internet have depended on trusted financial institutions, which act as central players for safe transactions. Nakamoto [1] proposed Bitcoin as a digital currency providing an easy method to perform online transactions. Bitcoin is a peer-to-peer cryptocurrency system in which transactions occur with no central players. All Bitcoin transactions are verified by the nodes of the peer-to-peer network and added to the blockchain, which serves as the Bitcoin ledger. The information on all historical transactions and all Bitcoin clients is stored in the blockchain; that is, Bitcoin transactions are recorded in the blockchain. The value of Bitcoin is not based on the economic conditions of any country and depends only on the supply and demand of the network. Thus, Bitcoin has been widely used as a digital currency that can be exchanged for real products or services based on its market value. In fact, there are various digital currencies, such as Ethereum, Ripple and Stellar. However, we focus only on Bitcoin because its market capitalization is currently about 50% of the total estimated digital currency capitalization.

As the Bitcoin market has grown over the years, there have been many studies analyzing it. Urquhart [2] studied the efficiency of the Bitcoin market. In an efficient market, due to the random nature of unpredictable events, variations are random. To detect inefficiency, Urquhart employed a battery of powerful tests for randomness and found evidence of inefficiency. The high-frequency multifractal properties of Bitcoin were examined in [3]. Gajardo et al. [4] analyzed the asymmetric multifractal cross-correlations among stock market indices, commodities and Bitcoin. Yonghong et al. [5] also investigated the time-varying long-term memory in the Bitcoin market. Dyhrberg [6,7] showed that Bitcoin has a clear role in the market for portfolio management. Some researchers studied Bitcoin as an investment vehicle [8–10]. They found that Bitcoin investment has characteristic features such as high average return and high volatility. Although the volatilities of various financial indices have an important impact on the Bitcoin market, the most important factor behind the high volatility of Bitcoin is the speculative behavior of users. In addition, there have been economic analyses of Bitcoin as a currency [11]. According to Iwamura et al. [11] and Yermack [12], Bitcoin may not be suitable as a currency because of its high volatility. Baur et al. [13] also showed that Bitcoin is used as a speculative investment due to its high volatility and large returns. In practice, since the Bitcoin market has high volatility, studying the volatility of Bitcoin is very important. We focus on the volatility of Bitcoin in this paper; specifically, we study accurate methods for forecasting Bitcoin volatility.

Many researchers have recently investigated the analysis and prediction of Bitcoin volatility. Baur and Dimpfl [14] analyzed asymmetric volatility effects for Bitcoin. Other studies attempted to show that Bitcoin volatility has properties such as chaos, randomness, multi-fractality and long-range memory [15,16]. Additionally, there have been many studies on the forecasting of Bitcoin volatility. Balcilar et al. [17] studied the prediction of Bitcoin volatility with a quantile test based on the trading volume. Katsiampa [18] investigated several GARCH family models to find the best model for Bitcoin volatility and found that the AR-CGARCH is the optimal model. Chu et al. [19] provided the best fitting models based on GARCH models for the volatilities of cryptocurrencies including Bitcoin. They fit 12 GARCH models to each cryptocurrency and found that the IGARCH(1,1) model provides a good fit. Conrad et al. [20] used the GARCH-MIDAS model to improve the prediction of long-term Bitcoin volatility. However, GARCH models have limitations: they struggle to capture the complex fluctuations and nonlinear correlations of time series data. To overcome these limitations, many researchers have proposed non-parametric forecasting methods based on machine learning approaches, such as ANN, for better forecasting of Bitcoin volatility [21–23].

Over the past few years, various hybrid models based on ANN have been proposed to improve the forecasting of time series data. In particular, hybrid models combining ANN and GARCH models have been proposed to improve forecast accuracy for time series such as market indices, exchange rates, stock volatility, gold prices, oil prices and metals [24–30]. These results have shown that the hybrid models have an advantage over pure ANN models. The so-called ANN-GARCH models are hybrid models that incorporate the GARCH forecasts as explanatory variables in the ANN and have been developed consistently by many researchers. For instance, Hajizadeh et al. [31] proposed two ANN-GARCH models to improve the forecasting performance of the S&P 500 index volatility. They used various input variables, including financial indicators and the volatility simulated by GARCH models, and the proposed hybrid model with the EGARCH model showed better accuracy than the traditional GARCH models and ANN models. Kristjanpoller et al. [32] provided the methodology and the application for the volatility forecast of three Latin American stock indexes using a hybrid ANN-GARCH model. Lahmiri and Boukadoum [33] presented an ensemble system based on hybrid EGARCH-ANN models trained with different distributional assumptions. In addition, Seo et al. [34] constructed a hybrid ANN-GARCH model with Google domestic trends and various activation functions for better forecasting accuracy of S&P 500 index volatility. In this paper, we also employ ANN-GARCH models for accurate forecasting of the realized volatility of Bitcoin. Specifically, we develop ANN-GARCH models with HONN and Google Trends (GT) data and compare the proposed models to find the best fitting model for Bitcoin volatility.

The contribution of this work is to find the optimal hybrid model for forecasting Bitcoin's volatility. To present our result, this paper is structured as follows. In the next subsection, we review the models used in this paper. In Section 2, we describe the data used for the proposed hybrid models. In Section 3, we construct efficient hybrid models and provide the results of the experiments by the proposed models. In Section 4, we present the concluding remarks.

### *1.2. Review of Models*

In this section, we introduce the GARCH family models used to construct our hybrid models. More specifically, we review the GARCH, EGARCH and GJR-GARCH models; the forecasts produced by these models are used as explanatory variables for the ANN. We also review the ANN and HONN models with the various activation functions used in this paper.

#### 1.2.1. GARCH Model

The ARCH model proposed by Engle [35] was the first conditional-variance model to describe the fat-tail characteristics and volatility-clustering properties of time series. However, the ARCH model runs into computational problems when a high-order model requires a large number of parameters. To solve these problems, Bollerslev [36] proposed the GARCH model, which is one of the most popular models for forecasting the volatility of time series. Since GARCH models include conditional variance terms as well as squared residual terms, they can predict volatility well by using a weighted sum of past predicted variances.

The GARCH (*p*, *q*) model is defined as follows.

$$y\_t^2 = w + \sum\_{i=1}^q \alpha\_i \varepsilon\_{t-i}^2 + \sum\_{i=1}^p \beta\_i y\_{t-i}^2, \tag{1}$$

where *εt* = *ytZt*, {*Zt*} is a sequence of independent and identically distributed random variables with zero mean and unit variance, {*εt*} is the sequence of error terms, and the positive parameters *αi* and *βi* satisfy the stability condition ∑*αi* + ∑*βi* < 1 (with *i* running from 1 to *q* and from 1 to *p*, respectively). This condition ensures that the conditional variance *yt*<sup>2</sup> is nonnegative and has a finite expected value. Here, *w*, *αi* and *βi* are the parameters estimated by maximum likelihood.
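As a hedged illustration of the recursion in Equation (1) (a sketch, not the authors' code; the parameter values and residual series are hypothetical), a GARCH(1,1) conditional-variance filter can be written in a few lines of NumPy:

```python
import numpy as np

def garch_variance(eps, w, alpha, beta):
    """One-step GARCH(1,1) variance recursion (Eq. (1) with p = q = 1).

    eps            : array of residuals eps_t
    w, alpha, beta : hypothetical parameters; stability needs alpha + beta < 1
    """
    y2 = np.empty(len(eps))
    y2[0] = w / (1.0 - alpha - beta)  # unconditional variance as the start value
    for t in range(1, len(eps)):
        y2[t] = w + alpha * eps[t - 1] ** 2 + beta * y2[t - 1]
    return y2

rng = np.random.default_rng(0)
eps = rng.standard_normal(500) * 0.02  # toy residual series
var = garch_variance(eps, w=1e-5, alpha=0.1, beta=0.85)
```

With alpha + beta < 1 the recursion stays positive and mean-reverts towards the unconditional variance, which is why the filter is started there.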

#### 1.2.2. EGARCH Model

The exponential GARCH (EGARCH) model proposed by Nelson [37] allows negative parameters unlike the GARCH model. That is, the parameters of the model have no restrictions to ensure the non-negativity of the volatility. This model can describe the volatility leverage effect which reflects the asymmetric impacts and captures asymmetric behavior of the time series.

The EGARCH (*p*, *q*) model is defined as follows.

$$\log y\_t^2 = w + \sum\_{i=1}^q \alpha\_i \left[ \frac{|\varepsilon\_{t-i}|}{y\_{t-i}} - \sqrt{\frac{2}{\pi}} + \gamma \frac{\varepsilon\_{t-i}}{y\_{t-i}} \right] + \sum\_{i=1}^p \beta\_i \log y\_{t-i}^2,\tag{2}$$

where *α<sup>i</sup>* with no restrictions captures the volatility clustering effect, *β<sup>i</sup>* measures the persistence in conditional volatility irrespective of the events in the market and *γ* measures the asymmetric leverage coefficient to describe the leverage effect of volatility. *αi*, *β<sup>i</sup>* and *γ* are parameters to be estimated.
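A minimal sketch of the EGARCH(1,1) recursion in Equation (2) follows (illustrative only; the parameter values are hypothetical). Because the recursion runs on the log variance, the resulting variance is positive by construction even with negative parameters:

```python
import numpy as np

def egarch_variance(eps, w, alpha, beta, gamma):
    """EGARCH(1,1) recursion on log y_t^2 (Eq. (2));
    the parameters need no positivity restrictions."""
    logv = np.empty(len(eps))
    logv[0] = w / (1.0 - beta)  # long-run log variance as the start value
    for t in range(1, len(eps)):
        z = eps[t - 1] / np.sqrt(np.exp(logv[t - 1]))  # standardised shock eps/y
        logv[t] = (w + alpha * (abs(z) - np.sqrt(2.0 / np.pi) + gamma * z)
                   + beta * logv[t - 1])
    return np.exp(logv)  # back to the variance scale

rng = np.random.default_rng(1)
eps = rng.standard_normal(200) * 0.02  # toy residual series
v = egarch_variance(eps, w=-0.5, alpha=0.1, beta=0.9, gamma=-0.05)
```

A negative gamma makes negative shocks (z < 0) raise the log variance more than positive shocks of the same size, which is the leverage effect described above.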

#### 1.2.3. GJR-GARCH Model

The GJR-GARCH model proposed by Glosten et al. [38] is one of the nonlinear GARCH family models; it allows for asymmetry effects by integrating a dichotomous variable into the GARCH model. This model allows negative shocks to have a larger, more distinct impact on volatility than positive shocks. The model has also demonstrated improved forecasting ability [39].

The conditional variance of GJR-GARCH (*p*, *q*) model is defined as follows.

$$y\_t^2 = w + \sum\_{i=1}^q \left[\alpha\_i + \gamma\_i \mathbf{1}\_{\{\varepsilon\_{t-i} < 0\}}\right] \varepsilon\_{t-i}^2 + \sum\_{i=1}^p \beta\_i y\_{t-i}^2 \tag{3}$$

where

$$\mathbf{1}\_{\{\cdot\}} = \begin{cases} 1, & \varepsilon\_{t-i} < 0, \\ 0, & \varepsilon\_{t-i} \ge 0, \end{cases}$$

and

$$w \ge 0, \ p \ge 0, \ q \ge 0, \ \alpha\_i \ge 0, \ \beta\_i \ge 0, \ \alpha\_i + \gamma\_i \ge 0 \quad \text{and} \quad \sum\_{i=1}^q \alpha\_i + \sum\_{i=1}^p \beta\_i + \frac{1}{2} \sum\_{i=1}^q \gamma\_i < 1.$$

Here, *αi* and *βi* are similar to the coefficients in the EGARCH model, and *γi* is the asymmetric leverage coefficient. The parameters *w*, *αi*, *βi* and *γi* are estimated by the maximum likelihood approach.
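The asymmetry in Equation (3) can be sketched as follows (again a toy illustration with hypothetical parameters, not the estimated model): the indicator adds the extra weight gamma only when the previous shock was negative.

```python
import numpy as np

def gjr_variance(eps, w, alpha, beta, gamma):
    """GJR-GARCH(1,1) recursion (Eq. (3)): negative shocks get extra weight gamma."""
    y2 = np.empty(len(eps))
    y2[0] = w / (1.0 - alpha - beta - 0.5 * gamma)  # unconditional variance
    for t in range(1, len(eps)):
        neg = 1.0 if eps[t - 1] < 0 else 0.0        # the indicator 1_{eps < 0}
        y2[t] = w + (alpha + gamma * neg) * eps[t - 1] ** 2 + beta * y2[t - 1]
    return y2

# Two shocks of equal magnitude but opposite sign:
up = gjr_variance(np.array([0.03, 0.0]), 1e-5, 0.05, 0.85, 0.1)
dn = gjr_variance(np.array([-0.03, 0.0]), 1e-5, 0.05, 0.85, 0.1)
```

Here `dn[1] > up[1]`: the negative shock produces a larger next-step variance, which is exactly the distinct impact of negative shocks described above.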

#### 1.2.4. Artificial Neural Network (ANN)

ANN is a nonparametric nonlinear model widely used to overcome the limitations of linear models in machine learning. An ANN is constructed according to the characteristics extracted from the real data and makes no hypothesis about the underlying model. An ANN has at least three layers (input layer, hidden layer, output layer). An ANN with a single hidden layer, as used for forecasting, is illustrated in Figure 1.

**Figure 1.** The structure of Artificial Neural Network (ANN).

The output computed from the input and hidden layers is generally given as follows.

$$\text{output} = f\left(\sum\_{i=0}^{n} x\_i w\_i\right),\tag{4}$$

where *xi* and *wi* represent the input from node *i* and the weight associated with the connection to node *i*, respectively, and *f* is one of the activation functions. The activation functions used in this paper are presented in Table 1. The sigmoid function shows high sensitivity to small changes in the input variables; this property makes it a good classifier. The hyperbolic tangent function (Tanh) has an advantage over the sigmoid function: since its derivative is steeper, it learns faster. In addition, it is well known that the Rectified Linear Unit (ReLU) is a good estimator and is very efficient to compute. The Exponential Linear Unit (ELU) provides fast learning because it shrinks the difference between the unit natural gradient and the normal gradient.


**Table 1.** Activation functions used in this paper.

The main work of ANN is to find the optimal weights for better performance using the activation functions. We use the back-propagation method to obtain the weights. We also carry out many experiments with four activation functions to find the best forecasting model.
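The four activation functions discussed above, together with the node output of Equation (4), can be sketched in vectorised NumPy (a generic illustration; the exact parameterisations used in the paper's Table 1 may differ):

```python
import numpy as np

# The four activation functions, in vectorised NumPy form.
def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def tanh(x):
    return np.tanh(x)

def relu(x):
    return np.maximum(0.0, x)

def elu(x, a=1.0):  # a is the ELU scale parameter
    return np.where(x > 0, x, a * np.expm1(x))

def ann_output(x, w, f):
    """Equation (4): a node's output is f(sum_i x_i * w_i)."""
    return f(np.dot(x, w))
```

For example, `ann_output(np.array([1.0, 2.0]), np.array([0.5, 0.25]), relu)` applies ReLU to the weighted sum 1.0.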

#### 1.2.5. Higher Order Neural Network (HONN)

HONN, proposed by Giles and Maxwell [40], has been widely used to simulate higher-order nonlinear inputs and to provide some basis for the simulations as an 'open box' [41]. Because first-order networks do not take advantage of meaningful relationships between the input variables, they need many training passes with a large training set. HONN was developed to overcome this disadvantage. In general, with a good selection of input variables, HONN is known to provide better forecasting performance than the classic ANN.

In Equation (4), the argument of the activation function is a linear combination: each input variable (*xi*) is multiplied by a weight (*wi*) and the results are summed. We can easily construct higher-order terms of the inputs from these first-order terms. Here, we consider a second-order HONN to improve the volatility forecasting. Let us define the input vector *x* and the weight vector *w* by

$$\vec{x} = [x\_0, x\_1, \dots, x\_n] \text{ and } \vec{w} = [w\_0, w\_1, \dots, w\_n],$$

respectively. Then the input vector *x<sup>h</sup>* and the weight vector *w<sup>h</sup>* in HONN are given by

$$\vec{x}\_h = [x\_0, x\_1, \dots, x\_n, x\_0^2, x\_0 x\_1, x\_0 x\_2, \dots, x\_{n-1} x\_n, x\_n^2] \text{ and } \vec{w}\_h = [w\_0, w\_1, \dots, w\_n, w\_{00}, w\_{01}, w\_{02}, \dots, w\_{(n-1)n}, w\_{nn}],$$

respectively. From these vectors, the output with the activation functions *f* can be calculated as follows.

$$\text{output} = f\left(\vec{w}\_h \cdot \vec{x}\_h\right) = f\left(\sum\_{i=0}^n w\_i \mathbf{x}\_i + \sum\_{i=0}^n \sum\_{j=i}^n w\_{ij} \mathbf{x}\_i \mathbf{x}\_j\right). \tag{5}$$

The structure of the second-order HONN used in this paper is illustrated in Figure 2. We construct the hybrid models based on this second-order HONN for accurate forecasting.

**Figure 2.** The structure of Higher Order Neural Network (HONN).
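The second-order input expansion in Equation (5) can be sketched as follows (an illustrative helper, not the paper's implementation): the first-order inputs are followed by every product *xixj* with *j* ≥ *i*.

```python
import numpy as np
from itertools import combinations_with_replacement

def second_order_inputs(x):
    """Build the HONN input vector of Equation (5): the first-order inputs
    followed by every product x_i * x_j with j >= i."""
    x = np.asarray(x, dtype=float)
    cross = [x[i] * x[j]
             for i, j in combinations_with_replacement(range(len(x)), 2)]
    return np.concatenate([x, cross])

# e.g. [x0, x1] expands to [x0, x1, x0^2, x0*x1, x1^2]
xh = second_order_inputs([2.0, 3.0])
```

For *n* + 1 first-order inputs the expanded vector has (*n* + 1) + (*n* + 1)(*n* + 2)/2 entries, which is why HONN benefits from a small, well-chosen input set.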

### **2. Material and Methods**

The time series data analyzed in this paper are the daily historical prices of Bitcoin over the period from 1 January 2012 to 30 November 2019. The data were downloaded from the website (https://bitcoincharts.com/). To define the volatility of the Bitcoin price, the closing prices *pt* at time *t* are transformed into log returns *rt* = log *pt* − log *pt*−1. The realized volatility of Bitcoin is computed as the variance of *rt*, and the realized volatilities in a 5-day window (weekly volatilities) are used to analyze the volatility of Bitcoin in this paper. The realized volatility (*RVt*) of Bitcoin at time *t* is then computed as

$$RV\_t = \frac{1}{5} \sum\_{i=t+1}^{t+5} (r\_i - \overline{r}\_t)^2,$$

where the overline denotes the mean of *ri* over the 5 days after time *t*.
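The two transformations above (log returns, then the 5-day windowed variance) can be sketched as follows; the price values here are hypothetical and serve only to exercise the formula:

```python
import numpy as np

def realized_volatility(r, t, window=5):
    """RV_t: the variance of the log returns over the `window` days after time t."""
    seg = r[t + 1 : t + 1 + window]
    return np.mean((seg - seg.mean()) ** 2)

# Toy price series (hypothetical values, for illustration only).
prices = np.array([100.0, 101.0, 99.5, 100.5, 102.0, 101.5, 103.0])
r = np.diff(np.log(prices))        # r_t = log p_t - log p_{t-1}
rv0 = realized_volatility(r, t=0)  # variance of r_1 .. r_5
```

Note that the window starts at *t* + 1, so *RVt* uses only returns strictly after time *t*, matching the formula.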

In order to improve the accuracy of the volatility forecast, the selection of the input data which influence the volatility of Bitcoin is very important. In this paper, we consider GT data and VIX data as the explanatory variables. GT data present the popularity of search queries related to various sectors in Google. In fact, GT data have been used by many researchers as explanatory variables in ANNs for forecasting financial time series [34,42–44]. We used 'Bitcoin' GT data as an input variable, since they are a good measure of the Bitcoin market [45]. The VIX index, introduced by the Chicago Board Options Exchange (CBOE) in 2004, extrapolates future volatility from the liquid options written on the S&P 500 and is calculated as the square root of the risk-neutral expectation of the 30-day variance of the S&P 500 return, estimated from the forward option prices expiring in 30 days. Previous works [46,47] have found a significant relationship between the VIX index and Bitcoin; thus, based on this research, we choose the VIX index as input data to the ANN. Specifically, 5-day moving averages of the VIX index and the GT data are used as the input data. In Figure 3, the time series of the log returns *rt* of the Bitcoin price is displayed. Figures 4 and 5 illustrate the realized volatility of the Bitcoin price and the VIX index, respectively.

**Figure 3.** Log return *rt* of Bitcoin price from 1 January 2012 to 30 November 2019.

**Figure 4.** Realized volatility *RVt* of *rt* from 1 January 2012 to 30 November 2019.

**Figure 5.** VIX index from 1 January 2012 to 30 November 2019.

In order to construct a more accurate model for forecasting of Bitcoin volatility, we use the 1-day lagged weekly volatility (*LVt*) as the endogenous variable and the outputs of GARCH family models as the exogenous variables. In other words, *LVt* and GARCH family outputs are used as the input variables to improve the forecasting ability of the hybrid model. Here, the outputs of the GARCH models introduced in the previous section are used, and *LVt* is calculated by

$$LV\_t = \frac{1}{5} \sum\_{i=t-5}^{t-1} (r\_i - \overline{r}\_t)^2. \tag{6}$$

Note that the 5 days in the window of *LVt* do not intersect with the 5 days in the window of *RVt*. *LVt* is displayed in Figure 6. In this study, 80% of the data set (in-sample: 2012.01.01–2018.04.30) is used for training, and 20% (out-of-sample: 2018.05.01–2019.11.30) is used for testing. All experiments are implemented in Python 3. Additionally, we use three measures to compare the performance of the proposed models: the mean absolute error (MAE), the root mean square error (RMSE) and the mean absolute percentage error (MAPE), defined as follows.

$$\begin{aligned} \text{MAE} &= \frac{1}{n} \sum\_{t} |\hat{\sigma}\_{t} - RV\_{t}|, \\ \text{RMSE} &= \left( \frac{1}{n} \sum\_{t} \left( \hat{\sigma}\_{t} - RV\_{t} \right)^{2} \right)^{1/2}, \\ \text{MAPE} &= \frac{1}{n} \sum\_{t} \left| \frac{\hat{\sigma}\_{t} - RV\_{t}}{RV\_{t}} \right|, \end{aligned}$$

where *σ*ˆ*<sup>t</sup>* is the predicted volatility of Bitcoin and *n* is the number of predicted data points. Obviously, the lower the values of the measures, the better the accuracy of the model. For more details, see [48].
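The three measures are straightforward to implement; a sketch with hypothetical forecast and target arrays:

```python
import numpy as np

def mae(pred, actual):
    """Mean absolute error."""
    return np.mean(np.abs(pred - actual))

def rmse(pred, actual):
    """Root mean square error."""
    return np.sqrt(np.mean((pred - actual) ** 2))

def mape(pred, actual):
    """Mean absolute percentage error (as a fraction)."""
    return np.mean(np.abs((pred - actual) / actual))

pred = np.array([1.0, 2.0, 3.0])    # hypothetical forecasts
actual = np.array([1.0, 2.5, 2.0])  # hypothetical realized volatilities
```

Note that MAPE divides by the actual values, so it is undefined when any realized volatility is zero; in that case MAE or RMSE is the safer comparison.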

**Figure 6.** Lagged volatility *LVt* of *rt* from 1 January 2012 to 30 November 2019.

### **3. Hybrid Models and Results**

In this paper, we propose several hybrid models based on GARCH family models, ANN and HONN to find a more accurate model for forecasting Bitcoin volatility. Specifically, the hybrid models are constructed with the ANN by using the selected GARCH models and the selected explanatory variables. The models are implemented as an ANN with a single hidden layer and a varying number of neurons, trained with the back-propagation method, and are classified according to whether or not they include the explanatory variables. The proposed models are used for a 1-day-ahead forecast of the weekly realized volatility, and the best model is then determined by comparing the results.

We compare the proposed models to find the best volatility forecasting model in the Bitcoin market. We first forecast the volatility of the Bitcoin price using the classic GARCH family models. Concretely, we use the GARCH, EGARCH and GJR-GARCH models with the (*p*, *q*) parameters ranging from (1,1) to (3,3). In order to find the optimal GARCH model for the hybrid model, we provide the AIC and BIC values in Table 2 and the three measures comparing the volatility forecasting performance of the models in Table 3. According to the results in Table 2 and the AIC and BIC criteria, the EGARCH(3,3) model is the best model. On the other hand, according to the results in Table 3, the GJR-GARCH(1,1) model performs best among the introduced GARCH family models.




**Table 3.** GARCH models performance.

Models other than the classic GARCH models are based on the ANN or the HONN approach; that is, they are constructed by feeding the selected input variables to an ANN or HONN. Similar to [31,34], we propose ANN-GARCH models for forecasting Bitcoin volatility using the outputs of the GARCH family models. Specifically, we define the GT-GARCH model and the GT-VIX-GARCH model according to their input variables, which are listed in Table 4. In order to find the optimal number of nodes in the hidden layer and the best activation function for each model, we carry out experiments using the Adam optimizer [49] to update the network weights. The results for the four activation functions are given in Tables 5 and 6. As shown there, two measures (MAE, RMSE) indicate that the GT-GARCH model is better than the GT-VIX-GARCH model, while one measure (MAPE) indicates the opposite. From these results, we cannot find a significant performance difference between the two models; that is, we conclude that they may have similar predictive ability. To improve the accuracy, we adopt the HONN approach. Specifically, we propose three types of hybrid models (GT-H model, GT-VIX-H model, GT-VIX-GARCH-H model) based on the HONN.

Tables 7–9 present the results of the models based on the HONN. For clarity, a summary of the input variables of each model is given in Table 10, where '*LVt*' is defined in Equation (6), 'GT' denotes the Google Trends data, 'VIX' denotes the VIX index data, and 'GJR-GARCH(1,1)' and 'EGARCH(3,3)' denote the forecasts of the corresponding GARCH models. Tables 7 and 8 present the results of the HONN models without the outputs of the GARCH models, as shown in Table 10. We can see that MAE and MAPE in Tables 7 and 8 increase in all cases compared to the values in Tables 5 and 6; that is, the GT-H and GT-VIX-H models do not perform better than the ANN-based models. To improve the model, we adopt the HONN model with the outputs of the GARCH family models. Among the introduced GARCH models, we chose GJR-GARCH(1,1) and EGARCH(3,3) based on the results in Tables 2 and 3. By using the outputs of GJR-GARCH(1,1) and EGARCH(3,3) as input variables in the HONN, we finally construct a new type of hybrid model (the GT-VIX-GARCH-H model) for better forecasting of Bitcoin volatility.



**Table 5.** GT-GARCH model performance.

**Table 6.** GT-VIX-GARCH model performance.



**Table 7.** GT-H model performance.

**Table 8.** GT-VIX-H model performance.



**Table 9.** GT-VIX-GARCH-H model performance.

**Table 10.** Input variables of models.


Table 9 shows the three performance measures obtained by the GT-VIX-GARCH-H model, and we can see the improvement in forecasting accuracy. The results show that the hybrid models with selected GARCH models based on the HONN reduce the performance measures (MAE, RMSE, MAPE); that is, in all cases the measures decrease compared to those of the other models. More specifically, compared to the GJR-GARCH(1,1) forecast, MAE is reduced by 11% and MAPE by 30%. Furthermore, we analyze the robustness of our results to determine whether the proposed models are statistically significant. For this analysis, we apply the MCS test [50] to the GT-VIX-GARCH-H models. The detailed results of the MCS test, which can be interpreted as a level of confidence for the forecasts, are presented in Table 11. According to these results, the GT-VIX-GARCH-H model with the ReLU function and 30 nodes, which has the lowest MAE, is the best model for forecasting Bitcoin volatility.


### **4. Concluding Remarks**

In this paper, we develop models based on neural networks for forecasting the volatility of the Bitcoin price. Specifically, we propose several hybrid models to improve the forecasts and conduct more than 10,000 experiments to find the optimal model. We proceed as follows. First, we construct ANN-GARCH models with the 1-day lagged volatility, Google Trends, VIX and the outputs of GARCH models, based on previous works. Second, we propose new hybrid models which incorporate the outputs of GARCH models as inputs to a HONN model. The HONN model, which uses products of the input variables as additional inputs, is efficient and generally performs better than the classic ANN model when the number of good input variables is small. In fact, most of the proposed hybrid models show good performance with no statistical difference, but we focus on finding the best forecasting model for Bitcoin's volatility.

In order to find the best model among the proposed models, we carry out many experiments varying the activation functions and the number of nodes. We also adopt three performance measures to compare the forecasting accuracy of the proposed models. Consequently, the hybrid models based on the HONN, which can capture higher-order correlations among the input variables, show improved performance for forecasting Bitcoin volatility. Compared to the best GARCH model, the best GT-VIX-GARCH-H model improves MAE, RMSE and MAPE by 11%, 2.2% and 30%, respectively. In addition, compared to the best ANN-GARCH model, the best GT-VIX-GARCH-H model improves MAE, RMSE and MAPE by 2.2%, 2.5% and 3.9%, respectively. In other words, these results show that the hybrid models based on the HONN provide more accurate forecasts and are appropriate for forecasting volatility in the Bitcoin market.

**Author Contributions:** G.K. designed the experiments; M.S. collected and analyzed the data; M.S. and G.K. contributed analysis tools; M.S. and G.K. wrote the paper. All authors have read and agreed to the published version of the manuscript.

**Funding:** This work is supported by the National Research Foundation of Korea grant funded by the Korea government (No. NRF-2017R1E1A1A03070886).

**Conflicts of Interest:** The authors declare no conflict of interest.

### **References**


© 2020 by the authors. Licensee MDPI, Basel, Switzerland. This article is an open access article distributed under the terms and conditions of the Creative Commons Attribution (CC BY) license (http://creativecommons.org/licenses/by/4.0/).

## *Article* **A Comparison of Time-Series Predictions for Healthcare Emergency Department Indicators and the Impact of COVID-19**

**Diego Duarte 1,2, Chris Walshaw 2,\* and Nadarajah Ramesh <sup>2</sup>**


**Featured Application: This application is being developed as part of a suite of tools used by the National Health Service (NHS) to analyse and predict pressure in resource management indicators (see transformingsystems.com for more details).**

**Abstract:** Across the world, healthcare systems are under stress and this has been hugely exacerbated by the COVID pandemic. Key Performance Indicators (KPIs), usually in the form of time-series data, are used to help manage that stress. Making reliable predictions of these indicators, particularly for emergency departments (ED), can facilitate acute unit planning, enhance quality of care and optimise resources. This motivates models that can forecast relevant KPIs, and this paper addresses that need by comparing the Autoregressive Integrated Moving Average (ARIMA) method, a purely statistical model, to Prophet, a decomposable forecasting model based on trend, seasonality and holiday variables, and to the General Regression Neural Network (GRNN), a machine learning model. The dataset analysed comprises four hourly valued indicators from a UK hospital: Patients in Department; Number of Attendances; Unallocated Patients with a DTA (Decision to Admit); Medically Fit for Discharge. Typically, the data exhibit regular patterns and seasonal trends and can be impacted by external factors such as the weather or major incidents. The COVID pandemic is an extreme instance of the latter, and the behaviour of the sample data changed dramatically. The capacity to adapt quickly to these changes is crucial, and it is on this factor that GRNN shows better results in both accuracy and reliability.

**Keywords:** healthcare; COVID; time-series predictions; machine learning; ARIMA; Prophet; GRNN

### **1. Introduction**

Across the world, healthcare systems are under stress and this has been hugely exacerbated by the COVID pandemic. Key Performance Indicators (KPIs), usually in the form of time-series data, are used to help manage that stress. Making reliable predictions of these indicators, particularly for emergency departments (ED), can help to identify pressure points in advance and also allow for scenario planning, for example, optimising staff shifts and planning escalation actions.

According to Medway Foundation Trust (MFT), where the results of this study are applied, it is important to accurately forecast after exceptional events in their data, such as the pandemic, because forecasts are of increased importance at these critical moments. When the uncertainty level is greater, correct predictions may benefit their decision making more than usual (although models must also perform well in non-exceptional circumstances).

When analysing healthcare systems, great significance has been placed on predicting patient arrivals in acute units, in particular emergency department (ED) attendances and throughput. Typically, patients arrive at irregular intervals, often beyond the control of the hospital, and arrivals show strong seasonal and stochastic fluctuations driven by factors such as weather, disease outbreaks, day of the week and socio-demographic effects [1]. Accordingly, researchers use a variety of methods to predict ED visits over various periodic intervals, e.g., [2,3].

**Citation:** Duarte, D.; Walshaw, C.; Ramesh, N. A Comparison of Time-Series Predictions for Healthcare Emergency Department Indicators and the Impact of COVID-19. *Appl. Sci.* **2021**, *11*, 3561. https://doi.org/10.3390/app11083561

Academic Editors: Anton Civit and Grzegorz Dudek

Received: 5 November 2020; Accepted: 25 March 2021; Published: 15 April 2021

**Publisher's Note:** MDPI stays neutral with regard to jurisdictional claims in published maps and institutional affiliations.

**Copyright:** © 2021 by the authors. Licensee MDPI, Basel, Switzerland. This article is an open access article distributed under the terms and conditions of the Creative Commons Attribution (CC BY) license (https://creativecommons.org/licenses/by/4.0/).

Furthermore, predicting KPIs in EDs can depend on several factors, e.g., [4,5]. Typically, EDs deal with a variety of life-threatening emergencies, arising from causes such as car accidents, disease, pollution, work-place hazards, aging, changing populations, pandemics and many other sources. However, multivariate models usually consider only a small number of dependent/independent variables out of numerous possibilities, and in an ED that can result in dangerous omissions. For example, variables that are not captured (especially those with high correlations, multi-variable correlation lags and high levels of noise) may drive multivariate predictions of the captured variables in the wrong direction [6].

For predictions of this type, a common approach is to auto-correlate the data. Since all of the numerous variables have influenced the past values of the time-series under investigation, prediction models can use these data as a baseline from which to calculate future time-series values. Beyond this baseline, other variables, both dependent and independent, and represented by other indicators, may be included in the model by considering the covariance between them.

Accurate and reliable forecasts of various types of indicator can help to efficiently allocate key healthcare resources, including staff, equipment and vehicles, when and where they are most needed [2], and avoid allocating them at non-critical periods. This justifies considerable efforts to make reliable and accurate predictions of ED data to help hospital managers in making decisions about how to meet the expected healthcare demand effectively and in a timely fashion [1].

As a linear model, ARIMA can typically capture linear patterns in a time series efficiently, and so many studies adopt it either to evaluate relationships between variables, or use ARIMA forecasts as a benchmark against which to test the effectiveness of other models, e.g., [7].

This has been applied to ED data previously and as examples, Sun et al. [3], Afilal et al. [8], and Milner [9] have developed ARIMA models for forecasting hospital ED attendances and have verified the fitness of ARIMA, in that context, as a readily available tool for predicting ED workload. In particular, Milner [9] used ARIMA to forecast yearly patient attendances at EDs in the Trent region of the UK and found that the forecast for attendances was just 3 per cent away from the actual figure. Meanwhile, Sun et al. [3] were able to improve on the ARIMA results for daily patient attendances in Singapore General Hospital by incorporating other variables such as public holidays, periodicities and air quality.

Another option is to use proprietary models. One such model, Prophet, is a forecasting model developed at Facebook Research, an R&D branch of the company. Taylor & Letham [10] observed that ARIMA forecasts can be prone to large trend errors, particularly when a change in trend occurs close to the cut-off period, and often fail to capture any seasonality. In experiments with Prophet, Taylor and Letham [10] predicted the daily number of Facebook events over a 30-day horizon with 25% more accuracy than ARIMA. The accurate prediction of these daily Facebook events, which are captured by an indicator with similar characteristics to some of the healthcare indicators, particularly in terms of the trend, seasonality and holiday effects at the heart of the Prophet model decomposition, motivated its use in this study.

Finally, a number of Machine Learning models are used to predict data, specifically for time-series processes, and in particular Neural Network variants have presented encouraging results for classification problems, e.g., [11,12]. In this study, after fitting LSTM (Long Short-Term Memory), RNN (Recurrent Neural Network) and RBN (Radial Basis Network) models, we present the results for the model found to best fit the data sample in our simulations, that is RBN, and in particular the GRNN (Generalised Regression Neural Network).

The rest of the paper is organised as follows: Sections 2–4 give more details of ARIMA, Prophet and GRNN, respectively. Section 5 describes the data and Section 6 gives details of how COVID has impacted on it. The main parts of the paper are Section 7, which describes the experimentation, and Section 8, which analyses the results. Finally, Section 9 presents conclusions and suggestions for further work.

#### **2. Autoregressive Integrated Moving Average (ARIMA)**

The Autoregressive Integrated Moving Average (ARIMA) model integrates Auto Regressive and Moving Average calculations. A basic requirement for predicting time-series with ARIMA is that the time-series should be stationary or, at the very least, trend-stationary [3,7]. A stationary series is one that has no trend, and where variations around the mean have a constant amplitude, e.g., [1,13,14]. Although ARIMA expects a stationary stochastic process as input, very few datasets are natively in such a format; thus differencing is used to "stationarise" the data in the model identification stage [1,7] (Figure 1).
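The differencing step can be sketched in a few lines of Python (a minimal illustration, not the paper's implementation): each pass subtracts every value's predecessor, so a *d*-th order difference removes a polynomial trend of degree *d*.

```python
def difference(series, d=1):
    """Apply d-th order differencing: repeatedly replace the series
    with the gaps between consecutive values."""
    for _ in range(d):
        series = [b - a for a, b in zip(series, series[1:])]
    return series

# A series with a linear trend becomes constant after one difference
# and zero after two.
trend = [2 * t + 3 for t in range(6)]   # 3, 5, 7, 9, 11, 13
print(difference(trend, d=1))           # [2, 2, 2, 2, 2]
print(difference(trend, d=2))           # [0, 0, 0, 0]
```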

The ARIMA model has 3 key parameters (*p*, *d*, *q*), all non-negative integers: *p* is the order (the number of time lags) of the autoregressive model, *d* is the degree of differencing (the number of times the data have had past values subtracted if random walk is chosen), and *q* is the order of the moving-average model. The values for *p*, *d*, and *q* are defined by calculations and subsequent analysis and the process of fitting an ARIMA model, by calculating these parameters, is commonly known as the Box-Jenkins method [15].

The integration of the *p* autoregressive and *q* moving average terms gives rise to the formula [15]:

$$\hat{y}\_t = c + \varepsilon\_t + \sum\_{i=1}^{p} \phi\_i y\_{t-i} + \sum\_{i=1}^{q} \theta\_i \varepsilon\_{t-i}$$

where *c* represents a constant, the error terms *εt* are generally assumed to be independent, identically distributed random variables, *yt−i* are previously observed values of the time-series, and *φi* and *θi* are coefficients determined as part of the fitting process.
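As a toy illustration of the formula, the one-step ARMA forecast can be computed directly once the coefficients are known (the coefficient values below are arbitrary, not fitted from the paper's data):

```python
def arma_predict(c, phi, theta, y, e):
    """One-step ARMA(p, q) forecast:
    c + sum(phi_i * y_{t-i}) + sum(theta_i * e_{t-i}),
    where y and e hold past observations / past errors, most recent last."""
    p, q = len(phi), len(theta)
    ar = sum(phi[i] * y[-1 - i] for i in range(p))
    ma = sum(theta[i] * e[-1 - i] for i in range(q))
    return c + ar + ma

# ARMA(2,1) with illustrative coefficients:
y_hat = arma_predict(c=1.0, phi=[0.5, 0.2], theta=[0.3],
                     y=[10.0, 12.0], e=[0.4])
print(y_hat)   # 1.0 + 0.5*12.0 + 0.2*10.0 + 0.3*0.4 = 9.12
```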

Each indicator under analysis here has been through the three-stage ARIMA/Box-Jenkins iterative modelling approach (Figure 1):

- identification: assess stationarity, difference the data as required, and use the autocorrelation functions to select candidate orders *p* and *q*;
- estimation: fit the coefficients *φi* and *θi* of the candidate model;
- diagnostic checking: verify that the residuals resemble white noise, returning to the identification stage if not.
### **3. Prophet**

Prophet is a forecasting model from Facebook Research. As discussed by Taylor & Letham, forecasting is a data science task, central to many activities within a large organization such as Facebook, and crucial for capacity planning and the efficient allocation of resources [10].

Prophet is motivated by the idea that not every prediction problem can be tackled using the same solution, and has been created with the aim of optimising the business forecasting tasks encountered at Facebook, which typically have some or all of the following characteristics [10]:

- hourly, daily or weekly observations with at least a few months (preferably a year) of history;
- multiple strong seasonalities, e.g., day of week and time of year;
- important holidays that occur at irregular intervals and are known in advance;
- a reasonable number of missing observations or large outliers;
- historical trend changes, for instance due to product launches;
- trends that are non-linear growth curves, where a trend hits a natural limit or saturates.
Prophet uses a composite time-series model to predict *y*(*t*), with three main components, trend, seasonality and holidays, which are combined as:

$$y(t) = g(t) + s(t) + h(t) + \varepsilon\_t$$

Here, *g*(*t*) is the trend function, *s*(*t*) models seasonal changes, and *h*(*t*) represents holidays, whilst *ε<sup>t</sup>* represents an error term which is expected to be normally distributed [10].
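The additive decomposition can be illustrated with a toy stand-in for each component (all values below are illustrative, not Prophet's actual fitted components, and the noise term is omitted):

```python
import math

def prophet_like(t, growth_rate=0.1, amp=2.0, period=7.0, holidays=None):
    """Toy additive model y(t) = g(t) + s(t) + h(t), mirroring the
    structure of Prophet's decomposition with hand-picked components."""
    g = growth_rate * t                              # linear trend g(t)
    s = amp * math.sin(2 * math.pi * t / period)     # weekly seasonality s(t)
    h = (holidays or {}).get(t, 0.0)                 # holiday effect h(t)
    return g + s + h

# Day 14 falls on a seasonal zero-crossing, so y = trend + holiday bump:
print(prophet_like(14, holidays={14: 5.0}))   # 0.1*14 + 0 + 5.0 = 6.4
```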

### *3.1. Trend Model*

Prophet's main forecast component is the trend term, *g*(*t*), which defines how the time-series has developed previously and how it is expected to continue. Two choices are available for this trend function depending on the data characteristics: a Nonlinear Saturating Growth model, where the growth is non-linear and expected to saturate at a carrying capacity, and a Linear Model, where the rate of growth is stable [17].

In the case of the experiments in Section 7, the choice of trend function was determined automatically by Prophet.

### *3.2. Seasonality*

Prophet also considers seasonal data patterns arising from similar behaviour repeated over several data intervals. To address this component of the model, Prophet employs Fourier series to capture and model periodic effects [17].

This is very appropriate for the data under investigation which may exhibit multiple seasonal patterns. For example, hospital ED departments are typically busier every day between 5 p.m. and 8 p.m. from a variety of causes such as people suffering injuries when commuting, taking relatives to hospital after working hours, or injuries when exercising, etc. Meanwhile longer nested trends can arise as hospital EDs are usually busier in winter as compared with summer due to respiratory illness, slippery conditions, etc.
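The Fourier approach to seasonality can be sketched as a partial sum of harmonics; the coefficients below are arbitrary illustrative values (Prophet fits them from data), chosen so that the 24 h cycle peaks in the late afternoon:

```python
import math

def fourier_season(t, period, coeffs):
    """Partial Fourier series
    s(t) = sum_n [ a_n*cos(2*pi*n*t/P) + b_n*sin(2*pi*n*t/P) ],
    the smooth periodic form used to model seasonality."""
    s = 0.0
    for n, (a, b) in enumerate(coeffs, start=1):
        x = 2 * math.pi * n * t / period
        s += a * math.cos(x) + b * math.sin(x)
    return s

# Two harmonics of a daily (24 h) cycle:
coeffs = [(0.0, -1.0), (0.2, 0.0)]
print(fourier_season(17, period=24, coeffs=coeffs))  # high, circa 5 p.m.
print(fourier_season(3, period=24, coeffs=coeffs))   # low, middle of the night
```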

### *3.3. Holidays and Events*

Holidays and special events produce relatively significant, and normally predictable, changes in a time-series. Normally these do not follow a regular periodic pattern (unlike, say, weekends) and thus their effects are not well modelled by a smooth cycle.

To address this component Prophet offers the functionality to include a custom list of holidays and events, both past and future [10]. However, this feature has not been tested in the investigations in Section 7.

### **4. General Regression Neural Network (GRNN)**

The General Regression Neural Network (GRNN) proposed by Specht [18] is a feedforward neural network, responding to an input pattern by processing the input data from one layer to the next with no feedback paths. GRNN is a one-pass learning algorithm with a highly parallel structure. Even when a time-series has sparse observations of a non-stochastic process, the model outputs smooth transitions between the resulting observations. The algorithmic form can be used for any regression problem in which an assumption of linearity is not justified.

The GRNN is considered a type of RBF (Radial Basis Function) neural network that employs fast, single-pass learning. It consists of a hidden layer of RBF neurons. Typically, the hidden (Pattern) layer has one neuron per training example. The centre of each neuron is its linked training example, so the neuron's output provides a measure of the closeness of the input vector to that training example. A subsequent summation layer is added to compute the results (Figure 2).

Normally, each neuron uses the multivariate Gaussian function [19]:

$$\mathcal{G}(\mathbf{x}, \mathbf{x\_i}) = \exp\left(-\frac{||\mathbf{x} - \mathbf{x\_i}||^2}{2\sigma^2}\right)$$

where **xi** is the centre, *σ* the smoothing parameter and **x** the input vector.

Considering a training set of size *n*, with patterns {x1 ... xn} and associated targets {y1 ... yn}, the output is calculated in two stages. First, the hidden layer produces weights *wi* representing the similarity of the input **x** to each training pattern **xi**:

$$w\_i = \frac{\exp\left(-\frac{||\mathbf{x} - \mathbf{x\_i}||^2}{2\sigma^2}\right)}{\sum\_{j=1}^n \exp\left(-\frac{||\mathbf{x} - \mathbf{x\_j}||^2}{2\sigma^2}\right)}$$

The further away a training pattern is, the smaller its contribution to the weight. The weights sum to 1, each representing its proportional strength in the result.

The smoothing parameter *σ* controls how many training patterns effectively contribute to the calculation and is an important part of model fitting, as its best value depends on the level of correlation and the lags.

The second stage is the resulting calculation of future values *y*(*t*):

$$y(t) = \sum\_{i=1}^{n} w\_i y\_i$$

The resulting calculation is nothing more than a weighted average of training targets.
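The two stages above fit in a few lines of Python; this is a minimal sketch of the GRNN output calculation (Gaussian kernel weights, normalised, then a weighted average of the targets), using the same *σ* = 0.3 default as the experiments in Section 7:

```python
import math

def grnn_predict(x, train_x, train_y, sigma=0.3):
    """GRNN output: Gaussian similarity of x to each training pattern,
    normalised so the weights sum to 1, then a weighted average of targets."""
    def gauss(a, b):
        d2 = sum((ai - bi) ** 2 for ai, bi in zip(a, b))
        return math.exp(-d2 / (2 * sigma ** 2))
    weights = [gauss(x, xi) for xi in train_x]
    total = sum(weights)
    return sum(w * y for w, y in zip(weights, train_y)) / total

# Two training patterns; a query equidistant from both averages their targets.
train_x = [(0.0, 0.0), (1.0, 1.0)]
train_y = [10.0, 20.0]
print(grnn_predict((0.5, 0.5), train_x, train_y))   # 15.0 (equal weights)
print(grnn_predict((0.0, 0.0), train_x, train_y))   # very close to 10.0
```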

**Figure 2.** Schematic diagram of generalized regression neural networks (GRNN).

### *4.1. Multi-Step Ahead Strategy*

When predicting an hourly time-series of ED data, such as the Number of Patient Attendances, it is natural to want to forecast a number of hours ahead, ideally every hour of the next seven days. For this, a multi-step ahead strategy must be employed. Two options are considered here: the MIMO (Multiple Input Multiple Output) strategy and a Recursive strategy.

The MIMO strategy for predicting a series of future observations consists of employing training target vectors of consecutive observations of the time-series, with the size of those vectors corresponding to the number of observations to be predicted.

The Recursive strategy is based on the one-step-ahead model given by the above equations. The model is applied recursively, once per desired step ahead, feeding predicted values back in as input variables. The Recursive strategy was the option utilised in the GRNN simulations in this study: the motivation was to avoid depending on training target vectors, as these can demand frequent re-training.
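The Recursive strategy can be sketched generically; the one-step model below (mean of the last three observations) is a hypothetical stand-in for the fitted GRNN:

```python
def recursive_forecast(history, one_step, horizon):
    """Recursive multi-step strategy: apply a one-step-ahead model
    repeatedly, appending each prediction to the input window as if
    it were an actual observation."""
    window = list(history)
    preds = []
    for _ in range(horizon):
        y_hat = one_step(window)
        preds.append(y_hat)
        window.append(y_hat)
    return preds

def mean_of_last_3(window):
    # Hypothetical stand-in for a fitted one-step model.
    return sum(window[-3:]) / 3

print(recursive_forecast([1.0, 2.0, 3.0], mean_of_last_3, horizon=3))
```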

### *4.2. Controlling Seasonality by Autoregressive Lags*

The GRNN model does not contain an embedded seasonal parameter, but this can be addressed by the use of autoregressive lags, grouping correlated data into the training vector. The lag definition and the smoothing parameter can be indicated by the ACF (Auto Correlation) and PACF (Partial Auto Correlation) functions, as discussed in Section 2.

#### **5. Data Description**

The choice of the most valuable Emergency Department Key Performance Indicators (ED KPIs), with the intention of capturing pressure in an acute unit, is an inter-disciplinary issue. Input from healthcare professionals, particularly hospital managers, and data scientists, was sought for this study. However, the KPIs must be chosen from data that are already being captured and not every theoretically useful indicator may be readily available.

The four KPIs under analysis, chosen from those available, are shown in Table 1. These data have been provided by Medway Foundation Trust (MFT), located in Kent, South East England; in particular, the study focuses on the MFT Hospital ED Acute Unit. The time-series frequency of all four KPIs is hourly.

**Table 1.** Key Performance Indicator Descriptions.


KPIs of this nature are key metrics in evaluating NHS performance, and related KPIs are used as targets. For example, the NHS Constitution sets out that a minimum of 95% of patients attending an A&E department should be admitted, transferred or discharged within 4 h of their arrival [20].

#### **6. COVID Impact**

The COVID pandemic has impacted UK society, and hence the KPI data, dramatically. The first government advice on social distancing was published on 12 March 2020, before a formal lockdown was announced on 23 March 2020; this led to a huge reduction in consumer demand in certain sectors, the closure of businesses and factories, and disrupted supply chains [21].

In preparation for the pandemic, hospitals invested in expanding resources that were necessary to treat a high volume of patients with the infection, including respirators, PPE (Personal Protective Equipment), staff, protocols and treatments [22].

During March, NHS trusts rapidly re-designed their services on a large scale to release capacity for treating patients with COVID-19. This included discharging thousands to free up beds, postponing planned treatment, shifting appointments online where possible and redeploying staff, a process covered widely in the media. NHS England alone published more than 50 sets of guidance to hospital specialists for the treatment of non-COVID-19 patients during the pandemic [22].

The impact of such measures was easily observed in the data, simulations and results of this study. Due to the lockdown, combined with the fear of contracting COVID, the chosen KPIs showed sudden and dramatic reductions in absolute numbers, as fewer people were attending the ED.

Not all of the prediction models reviewed in this paper were able to immediately account for sudden changes of this nature, and hence their accuracy and reliability dropped drastically. In particular, the immediately obvious inaccuracies arise because the models mostly rely on auto-regressive and seasonal parameters drawn from the 6 to 8 weeks prior to the current time. Thus, it usually took 3 to 6 weeks for the models to learn and adapt to the change, depending on the indicator. However, well-trained machine learning models present encouraging results, as can be seen below.

### **7. Simulations and Results**

Simulations with the three presented forecasting models have been performed to compare the relative accuracy and reliability when applied to the four chosen indicators. ARIMA, traditionally the standard model for healthcare time-series predictions, is used as a benchmark.

Observing the ACF (Auto Correlation Function) and initial model training, a strong correlation was found by hour and weekday, so the data have been classified by these parameters when applying the models. Beyond this classification, it was observed that the autocorrelation of the data was strongest in the 8 weeks prior to the forecast, as the residuals when utilising this interval were the smallest in simulations. Summarising, each prediction *y*ˆ was calculated using a rolling actuals time-series subset of the 8 previous observations, classified with a lag = 168.

The subset of 8 observations means that when applying GRNN, the size of the input layer was equal to 8. Although this number is small, it produced smaller residuals than trials with input layers of size 12, 20 and 30; this is believed to be due to the higher correlation of the data observed in the ACF. The smoothing parameter for the GRNN was chosen as *σ* = 0.3, which was found to be the best fit in preliminary experimentation.
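Assembling the 8-observation input window at lag 168 (the same hour of the week in each of the 8 prior weeks) can be sketched as follows; this is an illustrative reconstruction of the data preparation, not the authors' code:

```python
def lagged_inputs(series, t, lag=168, n=8):
    """Collect the n observations at the same hour-of-week as time t,
    i.e., series[t - lag], series[t - 2*lag], ..., oldest first."""
    idx = [t - k * lag for k in range(1, n + 1)]
    if min(idx) < 0:
        raise ValueError("not enough history before t")
    return [series[i] for i in reversed(idx)]

# 10 weeks of synthetic hourly values (value == index, for clarity):
hourly = list(range(24 * 7 * 10))
print(lagged_inputs(hourly, t=24 * 7 * 9))  # same hour in the 8 prior weeks
```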

Each model has been trained until its residuals were minimal in simulations; the parameters were then fixed and the model subsequently applied to the same rolling actuals time-series subset, of the same data source, with equal periods and frequency. Predictions were then compared with actual values and analysed for accuracy and reliability over the period from 1 January 2020 to the end of October 2020. Each prediction made is hourly and has a forecast horizon of *h* = 168 h, corresponding to predicting every hour for 7 days in advance.

The results are presented in a number of figures showing actual observed values, predictions and residuals and then analysed in Section 8.

### *7.1. Patients in Department*

Figure 3 shows a comparison of actual observed data (in black) for the Patients in Department indicator with the ARIMA (green), Prophet (yellow) and GRNN (purple) predictions. Here, the red square indicates the start of lockdown due to the COVID-19 pandemic. Three regions of the chart can be clearly distinguished (to the left of, inside and to the right of the square): the normal number of Patients per hour in the Medway Foundation Trust ED before the lockdown; the sudden change; and the way the data slowly grow back, in an apparently linear trend, from May to June. It is particularly interesting to see that ARIMA and Prophet overestimate their predictions for a five-week period, whilst GRNN is quick to adjust to the change in data ingress. Another interesting observation is that the variance of the data is reduced after the lockdown, making the data easier to predict for all models. This is confirmed in the following figures, which show predictions and actuals (top) and residuals (bottom) in more detail for each model.

**Figure 3.** Comparison of Actuals, ARIMA, Prophet and GRNN, February 2020–June 2020.

In Figure 4 it can be observed that the ARIMA model predicts the Patients in Department time-series with fair accuracy until the lockdown period, during which the residuals become increasingly negative. Then, after five weeks, the model learns the new levels and the residuals return to acceptable levels.

**Figure 4.** Comparison of Actual and ARIMA with residual distribution chart, January 2020–November 2020.

In Figure 5 it can be observed that Prophet, similar to ARIMA, predicts the Patients in Department indicator with fair accuracy until lockdown, during which the predictions are overestimated and the residuals increase by an even greater margin than ARIMA. However, after five weeks the model learns the new levels and the residuals are again back to acceptable levels.

**Figure 5.** Comparison of Actual and Prophet with residual distribution chart, January 2020–November 2020.

In Figure 6 it can be observed that GRNN, in contrast to the ARIMA and Prophet models, is able to quickly adapt to the change in observed values of Patients in Department due to COVID-19, keeping a good level of accuracy throughout the whole data sample.

### *7.2. Patient Attendances*

Figure 7 shows a comparison of actual (black) data with ARIMA (green), Prophet (yellow) and GRNN (purple) predictions for the Patient Attendances indicator. The observed data contain a greater variance than Patients in Department, and this can alter the performance of prediction models. The COVID-19 pandemic caused a similar change in the behaviour of Patient Attendances, and GRNN was again the only model that quickly adapted to the change and accurately predicted the time-series values during this period.

**Figure 7.** Comparison of Actuals, ARIMA, Prophet and GRNN, February 2020–June 2020.

The following figures break down this information and show the residual distributions. Observing the residual charts at the bottom of Figures 8–10, in the pre-lockdown phase ARIMA appears to be the best-fit model, as its residuals are concentrated closer to the x axis, whilst GRNN appears to have greater residuals and less accuracy in this phase. However, when the pandemic changed the data, GRNN was the only one of the three models to adapt quickly and predict with accuracy almost instantaneously, whilst both ARIMA and Prophet overestimated for a period of approximately 5 weeks. In the post-lockdown phase, after all three models had learned the new data behaviour, the performance is similar.

**Figure 8.** Comparison of Actual and ARIMA with residual distribution chart, January 2020–November 2020.

**Figure 9.** Comparison of Actual and Prophet with residual distribution chart, January 2020–November 2020.

**Figure 10.** Comparison of Actual and GRNN with residual distribution chart, January 2020–November 2020.

#### *7.3. Unallocated Patients with DTA*

Figure 11 shows a comparison of actual (black) data compared to ARIMA (green), Prophet (yellow) and GRNN (purple) predictions for the Unallocated Patients with DTA indicator. This indicator is perhaps the most challenging to predict in this study, as it includes a native and "variable" seasonal factor in that typically the decisions to admit patients are made in specific periods of the day, i.e., at the beginning of the doctors' shifts, after their first meetings with patients. As doctors' shifts commence at different times (unknown by the models), the lag factor of the seasonality is not constant, making the indicator somewhat unpredictable. This is the main reason residuals are typically greater than for other indicators.

**Figure 11.** Comparison of Actuals, ARIMA, Prophet and GRNN, February 2020–June 2020.

Looking at the residual charts at the bottom of Figures 12–14, the unpredictability mentioned can clearly be observed in the high residuals in the pre-COVID phase for all three analysed models (ARIMA, Prophet and GRNN). During the lockdown-affected period, as expected, only GRNN predicts values reliably. In the post-lockdown learning period, all three models are able to predict Unallocated Patients with DTA similarly reliably, until October, when the data change behaviour again and ARIMA seems to explain the variables better.

**Figure 12.** Comparison of Actual and ARIMA with residual distribution chart, January 2020–November 2020.

**Figure 13.** Comparison of Actual and Prophet with residual distribution chart, January 2020–November 2020.

**Figure 14.** Comparison of Actual and GRNN with residual distribution chart, January 2020–November 2020.

### *7.4. Medically Fit for Discharge*

Figure 15 shows a comparison of actual (black) data with ARIMA (green), Prophet (yellow) and GRNN (purple) predictions for the Medically Fit for Discharge indicator. This indicator, unlike Unallocated Patients with DTA, presents a very uniform seasonality and is thus a candidate for predictions of higher accuracy and reliability. The reason is that doctors typically discharge patients at the same fixed times every day, and this creates a very uniform lag factor, which is important to all auto-regressive processes. However, once again within the red square it can be observed that GRNN is the only prediction model that can quickly adapt and provide reliable predictions during the lockdown period. Finally, there are some obvious outliers in the data, especially in February, but investigation has shown that these are due to data quality issues rather than unusual occurrences.

**Figure 15.** Comparison of Actuals, ARIMA, Prophet and GRNN, February 2020–June 2020.

Once again, the residual charts in Figures 16–18 show very small residuals for all three models, apart from the lockdown period, where GRNN is the only model that can adapt rapidly. This clear accuracy advantage is also highlighted by the red squares in Figures 3, 7, 11 and 15. It may depend on the training, especially the size of the input layer: although this is the same for all models (8), its relatively small size may benefit GRNN when the data change behaviour, allowing a more agile response.

Finally, in the prediction chart of Figure 18, a few GRNN predictions are missing, and this is due to the model being sensitive to missing input of actual observations in this indicator time-series.

**Figure 16.** Comparison of Actual and ARIMA with residual distribution chart, January 2020–November 2020.

**Figure 17.** Comparison of Actual and Prophet with residual distribution chart, January 2020–November 2020.

**Figure 18.** Comparison of Actual and GRNN with residual distribution chart, January 2020–November 2020.

#### **8. Reliability and Accuracy Analysis**

Time-series predictions are commonly evaluated by looking at the "residuals", as shown above, which are the differences between actual values and predictions, often summarised using the root mean squared error (RMSE).
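For reference, the RMSE over a set of actuals and predictions is simply:

```python
import math

def rmse(actuals, preds):
    """Root mean squared error: the square root of the mean squared residual."""
    return math.sqrt(sum((a - p) ** 2 for a, p in zip(actuals, preds)) / len(actuals))

print(rmse([10, 12, 14], [11, 12, 13]))   # sqrt((1 + 0 + 1) / 3) ≈ 0.816
```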

#### *8.1. Analysis*

Emphasising the importance of relevant measurements when evaluating forecasting models, Krollner et al. [23] noted that over 80% of the papers they surveyed reported that their model outperformed the benchmark model, yet most of the analysed studies did not consider real-world constraints such as trading costs and slippage. In this study, the residual analysis has therefore been reconsidered, as the usual measurements do not necessarily capture the quality of time-series forecasting in an Emergency Department once the applied aspect, with its real-world constraints, is taken into account. In particular, in an ED environment it is often more important for predictions to remain consistently close to the actuals, or within a certain threshold of them, than to maintain a low average RMSE.

To understand this, consider a situation where the average RMSE is low but exhibits occasional spikes. In an ED situation a spike corresponds to high unpredicted demand, such as a sudden large influx of patients, and may have dangerous clinical outcomes for patients if resource capacity is significantly exceeded on those occasions. A prediction which is less accurate on average, but which exhibits fewer spikes in RMSE is much preferred as it predicts changing demand more reliably, even if it is not always exact.

For this reason, in the analysis below (Table 2), every simulation contains confidence bands and accuracy metrics which provide a more readable measure (following discussions with healthcare professionals) for comparing predictions. Thus "Very good" predictions have residuals no further than 15% of the actual mean away from their actual counterparts, "Good" are within 15–25%, "Regular" are within 25–40% of the mean, and any prediction with a residual above 40% of the actual mean is considered "Unreliable". Accordingly, all predictions that turned out to be within 40% of the actual values are classed as "Reliable", and the overall "Reliability" (i.e., the total percentage of "Very good", "Good" and "Regular" predictions) is used as a summary metric.
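The banding scheme can be sketched directly from those thresholds (the actual/predicted values below are made up for illustration):

```python
def reliability(actuals, preds):
    """Band each residual by its size relative to the actual mean,
    using the thresholds of Section 8.1, and report the overall share
    of 'Reliable' predictions (residual within 40% of the mean)."""
    mean = sum(actuals) / len(actuals)
    bands = {"Very good": 0, "Good": 0, "Regular": 0, "Unreliable": 0}
    for a, p in zip(actuals, preds):
        r = abs(a - p) / mean
        if r <= 0.15:
            bands["Very good"] += 1
        elif r <= 0.25:
            bands["Good"] += 1
        elif r <= 0.40:
            bands["Regular"] += 1
        else:
            bands["Unreliable"] += 1
    reliable = 1 - bands["Unreliable"] / len(actuals)
    return bands, reliable

# Residual ratios 0.10, 0.20, 0.35, 0.60 -> one per band, 75% reliable.
bands, reliable = reliability([100, 100, 100, 100], [110, 120, 135, 160])
print(bands, reliable)
```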


**Table 2.** Reliability comparison of ARIMA, Prophet and GRNN.

In addition, the average RMSE is also provided (Table 3) for each model and indicator, as the most neutral statistical residual measurement, even though, for the reasons outlined above, it may not be the best fit for the ED environment.

**Table 3.** Accuracy comparison of ARIMA, Prophet and GRNN.


In both cases, green highlighting indicates the model with the best results.

#### *8.2. Discussion*

As can be seen, and as already previewed above, the machine learning technique GRNN usually provides the best results for this dataset, usually outperforming ARIMA, which is traditionally used for predicting ED indicators. However, ARIMA usually comes a close second and actually outperforms GRNN for Patient Attendances, probably due to the very high variance in this indicator.

The relative performance is heavily impacted by the ability of GRNN to adapt rapidly to sudden dramatic changes in the data trends; for more stable time-series data, the performance of all three models is likely to be closer.

The COVID-19 pandemic drove just such a change and strongly altered the behaviour of the ED indicators, which then presented a challenge to the purely autoregressive models, ARIMA and Prophet, which rely solely on abstractions of previous values. GRNN was able to adapt to the change in data patterns promptly, resulting in closer predictions during this change period for all four indicators, as seen in Section 7. In addition, when analysing the entire 10 months of data for reliability (Table 2) and accuracy (Table 3), GRNN was the best-fit model for 3 out of the 4 indicators.

Although the differences in both accuracy and reliability are evident and relevant, data analysis suggests that, outside the pandemic period, the improvements shrink to slim margins. Hence, further simulations with different KPIs and data sample periods are necessary to validate the findings of this study.

#### **9. Conclusions**

The main conclusion of this paper is that prediction of Emergency Department (ED) indicators can be improved by using machine learning models such as GRNN in comparison to traditional models such as ARIMA, particularly when sudden changes are observed in the data.

Indeed, ARIMA may be better overall when indicators are very stable, so the proposed extension of this study aims to create hybrid predictions, mixing the predictions of ARIMA, GRNN, and possibly Prophet into a combined time-series forecasting model weighted by the short-term seasonal accuracy achieved by each model.
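One plausible form for such a hybrid is to weight each model's prediction by the inverse of its recent error; the function below is a hypothetical sketch of that idea, not the study's implementation:

```python
def hybrid_forecast(preds, recent_errors):
    """Combine model predictions with inverse-error weighting.

    preds: dict mapping model name -> current prediction.
    recent_errors: dict mapping model name -> recent error (e.g., MAPE);
    a smaller recent error yields a larger weight.
    """
    inv = {m: 1.0 / e for m, e in recent_errors.items()}
    total = sum(inv.values())
    weights = {m: v / total for m, v in inv.items()}
    combined = sum(weights[m] * preds[m] for m in preds)
    return combined, weights
```

For example, a model with half the recent MAPE of another receives twice the weight, so the combination leans towards whichever model has been more accurate lately.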

**Author Contributions:** Conceptualization, all authors; methodology, D.D.; software, D.D.; investigation, D.D.; data curation, D.D.; writing—original draft preparation, D.D.; writing—review and editing, C.W. and N.R.; visualization, D.D.; supervision, C.W. and N.R.; funding acquisition, C.W. All authors have read and agreed to the published version of the manuscript.

**Funding:** This research was partially funded by an Innovate UK Knowledge Transfer Partnership, KTP10507.

**Data Availability Statement:** Restrictions apply to the availability of these data. Data was obtained from Transforming Systems and are available from the D.D. subject to the permission of Transforming Systems.

**Conflicts of Interest:** The authors declare no conflict of interest.

#### **References**


## *Article* **Data-Driven Real-Time Online Taxi-Hailing Demand Forecasting Based on Machine Learning Method**

### **Zhizhen Liu, Hong Chen \*, Xiaoke Sun and Hengrui Chen**

College of Transportation Engineering, Chang'an University, Xi'an 710000, China; 2018021077@chd.edu.cn (Z.L.); 2019021061@chd.edu.cn (X.S.); 2019021060@chd.edu.cn (H.C.)

**\*** Correspondence: glch@chd.edu.cn; Tel.: +86-137-0022-9619

Received: 13 August 2020; Accepted: 22 September 2020; Published: 24 September 2020

### **Featured Application: This research provides a valuable data-driven method for forecasting online taxi-hailing demand, and it could potentially be applied to multi-mode transportation prediction.**

**Abstract:** The development of the intelligent transport system has created conditions for solving the supply–demand imbalance of public transportation services. For example, forecasting the demand for online taxi-hailing could help to rebalance the supply of taxis. In this research, we introduce a method to forecast real-time online taxi-hailing demand. First, we analyze the relation between taxi demand and online taxi-hailing demand. Next, we propose six models containing different information, based on the backpropagation neural network (BPNN) and extreme gradient boosting (XGB), to forecast online taxi-hailing demand. Finally, we present a real-time online taxi-hailing demand forecasting model considering the projected taxi demand ("PTX"). The results indicate that including more information leads to better prediction performance: including the projected taxi demand reduces MAPE from 0.190 to 0.183, reduces RMSE from 23.921 to 21.050, and increases R2 from 0.845 to 0.853. The analysis reveals the demand regularity of online taxi-hailing and taxis, and the experiment realizes real-time prediction of online taxi-hailing by considering the projected taxi demand. The proposed method can help to schedule online taxi-hailing resources in advance.

**Keywords:** online taxi-hailing demand; backpropagation neural network; extreme gradient boosting; real-time prediction

### **1. Introduction**

With the development of intelligent transportation systems, resident travel is becoming more convenient. Nevertheless, because of the information asymmetry between passengers and drivers, their spatial and temporal distributions are inconsistent, and this asymmetry wastes limited urban transportation resources. Therefore, trip demand in urban areas urgently needs to be studied. Recently, online taxi-hailing has gradually become a primary trip mode for urban residents, while the taxi still serves a public transportation function. Under these circumstances, online taxi-hailing demand is affected by taxi demand because of the homogeneity between the two modes. Thus, taxi demand should be taken into account when studying online taxi-hailing demand.

In the past, research on forecasting traffic demand was mostly based on environmental data and GPS data [1–32]. That research mined the features of GPS and environmental data to forecast trip demand, but it ignored the relationship between taxis and online taxi-hailing.

Therefore, this study aims to enhance the prediction effects of forecasting online taxi-hailing demand considering the taxi demand. Moreover, this research is a follow-up experiment of [32]. First, we use Pearson correlation analysis to screen the determinative influence factors to enhance the prediction accuracy. Then, online taxi-hailing demand forecasting models based on extreme gradient boosting (XGB) and backpropagation neural network (BPNN) were introduced to explore the relationship between taxi demand and online taxi-hailing demand. Next, we realize the real-time forecasting of online taxi-hailing demand by proposing a data-driven prediction method. This study would help to enhance the accuracy of online taxi-hailing demand forecasting and is essential for rebalancing traffic resources.

The literature review related to our study is presented in Section 2. Section 3 describes the data and the preprocessing of data in this study. Next, we proposed methods to enhance the accuracy of predicting online taxi-hailing demand in Section 4, while Section 5 concludes the results. Finally, the discussion and conclusion are shown in Section 6.

### **2. Related Work**

Over the years, numerous works have been dedicated to enhancing the accuracy of trip demand forecasting. An early application was predicting trip demand with a four-step process considering spatiotemporal factors [1]. L. Moreira-Matias et al. presented a method to predict the spatial distribution of taxi demand [2], then proposed a learning model incorporating real-time data to forecast the spatiotemporal distribution of taxi-passenger demand [3], and later a combination forecasting model for the same problem [4]. K. Zhang et al. forecasted the locations of hotspots and tested the heat of the hotspots with an adaptive forecasting method [5]. N. Davis et al. proposed a time-series method to forecast taxi demand by mining the patterns in taxi mobile app data [6]. X. Peng et al. forecasted taxi demand hotspots based on social media check-ins to reduce the imbalance between taxi supply and demand [7]. K. Zhao et al. predicted taxi demand with three forecasting methods based, respectively, on the Markov model, the Lempel–Ziv–Welch model, and an ANN model [8]. Besides GPS and environmental data, J. Xu et al. also considered historical traffic behaviors as an important variable in the taxi demand forecasting problem and proposed an LSTM method to forecast taxi demand in several urban areas [9]. D. Zhang improved the hidden Markov chain model and proposed a D-model to forecast taxi demand [10]. To explore the relationship between taxi and subway, Y. Bao et al. took the interaction between taxi demand and subway demand into account, studied its impact on taxi demand accuracy, and proposed a taxi demand prediction method based on a neural network model [11]. N. Davis explored the impact of tessellation on demand-prediction effects and proposed a combination of different tessellation strategies to predict taxi demand [12].

The research above considered the impacts of GPS data and environmental data on prediction accuracy but did not take real-world event information into account. To address this, I. Markou et al. mined real-world event information from unstructured data and applied machine learning methods to taxi demand forecasting [13]. S. Ishiguro et al. introduced real-time demographic data into the taxi demand forecasting method and explored its impact on accuracy with a stacked denoising autoencoder [14]. S. Liao compared two deep neural networks for forecasting trip demand and found that DNNs perform better than other traditional machine learning methods [15]. U. Vanichrujee et al. presented an ensemble consisting of an LSTM model, a GRU model, and an extreme gradient boosting (XGB) model to forecast taxi demand [16]. J. Xu proposed a sequence learning method considering historical demand to forecast trip demand [17]. H. Yao et al. presented a multi-view spatiotemporal network framework to model spatiotemporal relationships and forecast traffic demand [18]. H. Yan analyzed taxi requests and proposed a Bayesian hierarchical semiparametric model to forecast taxi demand [20]. L. Kuang introduced unstructured data into a deep learning method to forecast trip demand [21].

However, the methods above ignored the destinations of passengers. L. Liu proposed a method to forecast the taxi demand between origin–destination pairs [22]. I. Markou introduced real-world events into the prediction method and used the data to forecast traffic demand [23]. F. Rodrigues et al. explored the relationship between drop-off points and pick-up points and proposed a spatio-temporal LSTM model to forecast taxi demand [25]. F. Terroso-Saenz predicted taxi demand through the QUADRIVEN method based on human-generated data [26]. Y. Xu proposed a graph and time-series learning model considering the relationships between non-adjacent regions for city-wide taxi demand prediction [27]. H. Yu proposed a deep spatiotemporal recurrent convolutional neural network to forecast traffic flow [28]. X. Liu explored the impacts of socio-economic factors, the transport system, and land-use patterns on taxi demand forecasting [29]. A. Saadallah introduced the BRIGHT method, an ensemble of time series analysis models, to forecast taxi demand precisely [30]. A. Safikhani proposed a STAR model to analyze the spatiotemporal distribution of taxis and introduced LASSO-type penalized methods for parameter estimation [31]. Recently, Z. Liu proposed a combination forecasting model based on the random forest method and the ridge regression method to predict taxi demand in hotspots [32].

In general, given the relationship between different trip modes, further attempts along this line are justified. This study is motivated by a real-world case study to better understand the underlying relationship between the demands of different trip modes.

### **3. Data Description**

### *3.1. Taxi GPS Data*

We obtained the GPS data from the Xi'an Taxi Management Office in Xi'an city of China. The data include location information, vehicle state information, time information, and license plate information. Moreover, the taxi GPS data were recorded every 5 s for 30 days in November 2016 and include 40 million points which are located in Xi'an city of China. The GPS data were cleaned and selected. An instance of taxi GPS data is shown in Table 1.


**Table 1.** An instance of taxi GPS data.

#### *3.2. Online Taxi-Hailing GPS Data*

Online taxi-hailing GPS data are from Didi Chuxing GAIA Initiative, and the GPS data are located in Xi'an city of China. The dataset consists of 600 million track points, and it was recorded every 2–4 s for 30 days in November 2016. An instance of online taxi-hailing GPS data is shown in Table 2.


**Table 2.** An instance of online taxi-hailing GPS data.

### *3.3. Environmental Data*

The environmental data comprise air quality data and meteorological data. The air quality data for Xi'an city are from the official website of Green Breathing. The meteorological data for Xi'an city were derived from the National Meteorological Information Center. This study selects the hourly environmental data of Xi'an. In total, the environmental data contain 15 dimensions for the research (Table 3).

**Table 3.** Environmental data structure description.


#### **4. Methods**

### *4.1. Feature Selection*

Ensuring that the features correlate with the dependent variable is important in a prediction problem. Likewise, ensuring that the features are independent of one another is also important for improving prediction accuracy. While modeling the forecasting method, features that exhibit strong multicollinearity and features that have a low correlation with the dependent variable should both be eliminated to enhance prediction accuracy. Thus, we use Pearson correlation analysis to test the correlation between all features and the dependent variable [33,34]. The Pearson correlation is calculated as Equation (1).

$$\rho_{X,Y} = \frac{\operatorname{cov}(X,Y)}{\sigma_X \sigma_Y} = \frac{\operatorname{E}\left[(X-\mu_X)(Y-\mu_Y)\right]}{\sigma_X \sigma_Y},\tag{1}$$

$\operatorname{cov}(X,Y)$ is the covariance between the features X and Y, $\sigma_X$ and $\sigma_Y$ are the standard deviations of X and Y, and $\rho_{X,Y}$ is their correlation, with values in [−1, 1]. If $\rho_{X,Y} > 0$, the two features are positively correlated; if $\rho_{X,Y} < 0$, they are negatively correlated. A larger absolute value of $\rho_{X,Y}$ indicates a stronger correlation between X and Y.
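Equation (1) can be checked numerically; the minimal implementation below is illustrative only (NumPy's built-in `corrcoef` computes the same quantity):

```python
import numpy as np

def pearson(x, y):
    """Pearson correlation per Equation (1): cov(X, Y) / (sigma_X * sigma_Y)."""
    x, y = np.asarray(x, float), np.asarray(y, float)
    cov = np.mean((x - x.mean()) * (y - y.mean()))  # population covariance
    return cov / (x.std() * y.std())
```

Because the same (population) normalization appears in both numerator and denominator, the result matches `np.corrcoef(x, y)[0, 1]`.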

### *4.2. BPNN*

Artificial neural networks (ANNs) possess attributes of learning, generalizing, parallel processing, and error endurance. These attributes make the ANNs useful in modeling complex situations. Therefore, we employ BPNN, a type of ANN, for forecasting online taxi-hailing demand in this study [35,36]. A three-layer BPNN employed in this paper is shown in Figure 1 [37]. In Figure 1, "T" indicates the information of time factors, "E" is the information of environmental factors, and "TX" represents the information of taxi demand.

**Figure 1.** An instance of three-layer backpropagation neural network (BPNN).

The connection weights among nodes are obtained by data training in the backpropagation process, which minimizes the least-mean-square error between the true values and the estimated values at the network's output. First, the connection weights are assigned initial values. Then, the weights are updated based on the back-propagated error between the predicted and true output values. Assuming n input neurons, m hidden neurons, and one output neuron, a training pass can be described as follows.

Hidden layer stage: Calculating the outputs of all neurons in the hidden layer as Equations (2) and (3).

$$\text{net}_j = \sum_{i=0}^{n} v_{ij} x_i,\quad j = 1,2,\cdots,m,\tag{2}$$

$$y_j = f_H\left(\text{net}_j\right),\quad j = 1,2,\cdots,m,\tag{3}$$

$\text{net}_j$ is the activation value of the jth node, $y_j$ is the output of the hidden layer, and $f_H$ is the activation function of a node; here, the activation function is the rectified linear unit, as in Equation (4).

$$\mathbf{f\_{H}(x)} = \max(0, \mathbf{x}), \tag{4}$$

Output stage: The outputs of all neurons in the output layer are as Equation (5).

$$O = f_o\left(\sum_{j=0}^{m} w_j y_j\right),\tag{5}$$

$f_o$ is the activation function, again as in Equation (4). All weights are assigned random values initially and then modified by the delta rule according to the learning samples.
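Equations (2)–(5) amount to a single forward pass; a minimal NumPy sketch with hypothetical weight matrices follows (bias terms, i.e., the i = 0 and j = 0 indices, are omitted for brevity):

```python
import numpy as np

def bpnn_forward(x, V, w):
    """One forward pass of the three-layer BPNN in Equations (2)-(5).

    x: input vector of shape (n,); V: hidden-layer weights of shape (m, n);
    w: output-layer weights of shape (m,).
    """
    net = V @ x                    # Equation (2): net_j = sum_i v_ij * x_i
    y = np.maximum(0.0, net)       # Equations (3)-(4): ReLU hidden activations
    return np.maximum(0.0, w @ y)  # Equation (5): output with the same activation
```

Training would then back-propagate the output error to update V and w; the grid search mentioned below only varies the hidden-layer size m.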

The three-layer BPNN above is the basic form used in the online taxi-hailing demand prediction method. To find the best network structure of the BPNN for each forecasting model, we use a grid search algorithm to determine the network structure of every BPNN-based model.

### *4.3. XGB*

XGB is a boosting model based on a classification and regression tree (CART), which takes full advantage of the residual of a base classifier [38]. The boosting algorithm combines simple tree models to establish a more precise model, and it overcomes the influence of the interference signal. The prediction is as Equation (6).

$$\hat{y}_i = \sum_{k=1}^{K} f_k(x_i),\quad f_k \in F,\tag{6}$$

fk is the kth tree, K is the number of trees, and F is a set of all trees.

Suppose that $S = \{(x_1, y_1), \dots, (x_i, y_i), \dots, (x_N, y_N)\}$ is a known dataset with N samples, where each x has L features and y is the target value. The objective function is as Equation (7).

$$\mathbf{O} = \sum\_{\mathbf{i}=1}^{N} \mathbf{l}\left(\mathbf{\hat{y}}\_{\mathbf{i}}, \mathbf{y}\_{\mathbf{i}}\right) + \sum\_{\mathbf{k}=1}^{K} \mathbf{r}(\mathbf{f}\_{\mathbf{k}}),\tag{7}$$

$\hat{y}_i$ is the predicted value of $x_i$, and l represents the difference between the true and predicted values. $r(f_k)$ is the regularization term of the kth tree, which penalizes model complexity to avoid overfitting; it is calculated as Equation (8).

$$r(f_k) = \gamma T + \frac{\omega}{2}\left\|\vartheta\right\|^2,\tag{8}$$

γ and ω are penalty coefficients, T is the number of leaves in the tree, and ϑ is the vector of leaf weights.
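Assuming the squared leaf-weight norm used in standard gradient-boosting regularization, Equation (8) can be computed as follows (an illustrative sketch; the function name and arguments are hypothetical):

```python
import numpy as np

def xgb_penalty(leaf_weights, gamma, omega):
    """Regularization term of Equation (8): gamma * T + (omega / 2) * ||theta||^2,
    where T is the number of leaves and theta the leaf-weight vector."""
    theta = np.asarray(leaf_weights, float)
    return gamma * theta.size + 0.5 * omega * np.sum(theta ** 2)
```

A larger γ discourages trees with many leaves, while a larger ω shrinks the leaf weights themselves.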

### *4.4. Evaluation Criteria*

Moreover, three accuracy measures are applied to evaluate the performance of online taxi-hailing prediction. The measures are root-mean-square error (RMSE), mean absolute percentage error (MAPE) and goodness of fit (R2), which are calculated as Equations (9)–(11).

$$\text{RMSE} = \left(T^{-1} \sum_{n=1}^{T} \left(\hat{C}_n - C_n\right)^2\right)^{1/2},\tag{9}$$

$$\text{MAPE} = T^{-1} \sum_{n=1}^{T} \left|\left(\hat{C}_n - C_n\right)/C_n\right|,\tag{10}$$

$$R^2 = \frac{\sum_{n=1}^{T} \left(\hat{C}_n - \overline{C}\right)^2}{\sum_{n=1}^{T} \left(C_n - \overline{C}\right)^2},\tag{11}$$

$\hat{C}_n$, $C_n$, and $\overline{C}$ are the predicted value, the true value, and the mean of the true values, respectively, and T is the number of samples.
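Assuming the conventional forms of these three measures, they can be implemented as follows (illustrative sketch):

```python
import numpy as np

def rmse(true, pred):
    """Root-mean-square error, Equation (9)."""
    true, pred = np.asarray(true, float), np.asarray(pred, float)
    return np.sqrt(np.mean((pred - true) ** 2))

def mape(true, pred):
    """Mean absolute percentage error, Equation (10)."""
    true, pred = np.asarray(true, float), np.asarray(pred, float)
    return np.mean(np.abs((pred - true) / true))

def r2(true, pred):
    """Goodness of fit, Equation (11): explained variation over total variation."""
    true, pred = np.asarray(true, float), np.asarray(pred, float)
    return np.sum((pred - true.mean()) ** 2) / np.sum((true - true.mean()) ** 2)
```

Lower RMSE and MAPE and higher R2 indicate better prediction performance.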

#### **5. Results**

### *5.1. Feature Selection*

Before predicting online taxi-hailing demand, we must select a reasonable set of forecasting features. Therefore, we use Python to calculate the correlations among the prediction indicators, and we eliminate factors with strong collinearity and factors with low correlation with the target. Correlations among environmental factors are shown in Table 4. In Table 4, "OT" and "TX" indicate online taxi-hailing demand and taxi demand, respectively, "DW" is the day of the week, "HD" represents the hour of the day, and "WON" indicates whether the day is a workday. The other features are as in Table 3.


**Table 4.** Correlations among environmental factors.

As shown in Table 4, the correlations among AQ, AQI, PM2.5, PM10, and CO are above 0.8. Therefore, we remove AQI, PM2.5, PM10, and CO from the predictive factors. Next, we eliminate the features whose correlation with the OT factor is less than 0.2. The predictive indicators of online taxi-hailing demand are shown in Table 5, divided into "T", "E" and "TX": "T" indicates time information, "E" represents the environmental factors, and "TX" contains information about taxi demand.
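The two-step screening described above (dropping features whose pairwise correlation exceeds 0.8 and features whose correlation with "OT" is below 0.2) can be sketched as follows; the function and its greedy keep-first strategy are illustrative assumptions, not the authors' exact procedure:

```python
import numpy as np

def select_features(X, y, names, collin_thresh=0.8, target_thresh=0.2):
    """Correlation-based feature screening.

    Keeps a feature only if |corr with target y| >= target_thresh and
    |corr with every already-kept feature| <= collin_thresh.
    X: array of shape (samples, features); names: feature labels.
    """
    kept = []
    for j, name in enumerate(names):
        r_target = np.corrcoef(X[:, j], y)[0, 1]
        if abs(r_target) < target_thresh:
            continue  # too weakly related to the target
        if any(abs(np.corrcoef(X[:, j], X[:, k])[0, 1]) > collin_thresh
               for k, _ in kept):
            continue  # collinear with a feature already kept
        kept.append((j, name))
    return [name for _, name in kept]
```

With the paper's thresholds, a feature group such as {AQ, AQI, PM2.5, PM10, CO} would collapse to a single representative.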


**Table 5.** Predictive indicators of trip demand.

Then, all data are processed by the OneHotEncoder function in the scikit-learn preprocessing library. An instance of the DW indicator is shown in Figure 2.


**Figure 2.** An instance of One-Hot Encoder.

After the encoding of indicators in Table 5, the dimension of the dataset was expanded to 58. Additionally, the first 23 days of November 2016 are taken as the training set, with the other seven days as the testing set in this study.
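For a single categorical column such as "DW", the encoding is equivalent in spirit to the minimal sketch below (scikit-learn's OneHotEncoder additionally handles multiple columns and unseen categories):

```python
import numpy as np

def one_hot(values, n_categories):
    """One-hot encode integer category codes, e.g. DW (day of week) in [0, 6].

    Returns an array of shape (len(values), n_categories) with a single 1 per row.
    """
    return np.eye(n_categories)[np.asarray(values, int)]
```

Encoding every categorical indicator this way is what expands the dataset to 58 dimensions.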

### *5.2. Data Preprocessing*

In this study, we choose the Bell Tower area as the research object, following the study of Liu et al. [32], because it contains the most trip demand. The Bell Tower area is a commercial area, and its traffic demand exhibits a strong tidal pattern. The area is shown in Figure 3.

Then, we cut the taxi data and online taxi-hailing data into time slices. The trip demand for taxis and online taxi-hailing is shown in Figure 4. We find that both demands are regular, and that taxi demand decreases while online taxi-hailing demand increases during peak hours.

### *5.3. Online Taxi-Hailing Demand Forecasting*

Then, we forecast the online taxi-hailing demand in the Bell Tower area based on BPNN and XGB, testing the prediction effects of different indicators. In the experiment, we add time factors, environmental factors, and taxi demand factors to the models based on BPNN and XGB. The models with different impacting factors are shown in Table 6. Next, we use the grid search algorithm to tune the hyperparameters of the models based on BPNN and XGB; the hyperparameters are listed in Table 7. The results of the models with different impacting factors are shown in Figures 5 and 6, and the factors "T", "E" and "TX" are as in Table 5.

**Figure 3.** Research scope in the Bell Tower area.

**Figure 4.** The demand for taxi and online taxi-hailing in Bell Tower area.





**Table 6.** Models with different impacting factors.

**Table 7.** Hyperparameters of the models based on BPNN and XGB.

**Figure 5.** Online taxi-hailing demand prediction results of models based on BPNN. (**a**) Result of model "BPNN + T"; (**b**) Result of model "BPNN + T + E"; (**c**) Result of model "BPNN + T + E + TX".

**Figure 6.** Online taxi-hailing demand prediction results of models based on extreme gradient boosting (XGB). (**a**) Result of model "XGB + T"; (**b**) Result of model "XGB + T + E"; (**c**) Result of model "XGB + T + E + TX".

Then, we use RMSE, MAPE and R2 to test the prediction effects of the models. Table 8 shows the RMSE, MAPE and R2 of the six models on the test dataset for the Bell Tower area. Among the predictions based on BPNN, the model "BPNN + T + E + TX" performs best for the online taxi-hailing prediction problem; likewise, among the three predictions based on XGB, the model "XGB + T + E + TX" performs best.

Next, we analyze the contributions of the different sources of information. From Table 8, we find that including information about taxi demand ("TX") enhances the prediction effects of both BPNN and XGB. In the BPNN models, including information "E" reduces MAPE from 0.224 to 0.190, decreases RMSE from 28.576 to 23.921, and increases R2 from 0.819 to 0.845. Likewise, including information "TX" reduces MAPE from 0.190 to 0.132 and increases R2 from 0.845 to 0.866. Meanwhile, in the XGB models, including information "E" reduces MAPE from 0.333 to 0.197, reduces RMSE from 26.296 to 21.206, and increases R2 from 0.833 to 0.857. Including information "TX" reduces MAPE from 0.197 to 0.139 and increases R2 from 0.857 to 0.865. Additionally, the performance of the model "BPNN + T + E + TX" is the best among the six models in Table 8.


**Table 8.** Prediction effects of BPNN and XGB.

To evaluate the prediction performance of BPNN and XGB at different hours, we report the MAPE and RMSE of the six models by hour. Figure 7a shows that the model "BPNN + T + E + TX" obtains the lowest MAPE of the three BPNN-based predictions except at 6 a.m., 8 a.m., and 9 p.m. Figure 7b shows that the model "XGB + T + E + TX" performs best except at 11 a.m. and 5 p.m. From Figure 8, the model "BPNN + T + E + TX" obtains the lowest RMSE of the three except at 4, 7, and 9 p.m., and the model "XGB + T + E + TX" performs best except at 11 a.m., 12 a.m., 4 p.m. and 5 p.m.

**Figure 7.** Average mean absolute percentage error (MAPE) of online taxi-hailing demand prediction based on BPNN and XGB in different hours. (**a**) MAPE of the prediction models based on BPNN; (**b**) MAPE of the prediction models based on XGB.

### *5.4. Real-Time Online Taxi-Hailing Demand*

While forecasting online taxi-hailing demand with the different models in Table 6, we ignored the fact that future taxi demand is unavailable. To realize real-time online taxi-hailing demand prediction, we must predict the taxi demand before forecasting the online taxi-hailing demand with the models "BPNN + T + E" and "XGB + T + E". The results of the taxi demand prediction are shown in Figure 9 and Table 9.

**Figure 8.** Average RMSE of online taxi-hailing demand prediction based on BPNN and XGB in different hours. (**a**) RMSE of the prediction models based on BPNN; (**b**) RMSE of the prediction models based on XGB.

**Figure 9.** The taxi demand prediction based on "BPNN + T + E" and "XGB + T + E". (**a**) Result of taxi demand prediction based on "BPNN + T + E"; (**b**) Result of taxi demand prediction based on "XGB + T + E".

**Table 9.** Prediction effects of taxi demand prediction based on BPNN and XGB.


Table 9 shows that the model "BPNN + T + E" performs better than the model "XGB + T + E" in forecasting taxi demand. Based on the projected taxi demand ("PTX"), we forecast the online taxi-hailing demand with the model "BPNN + T + E + PTX", as shown in Figure 10. From Table 10, we find that including the information of "PTX" reduces MAPE from 0.190 to 0.183, reduces RMSE from 23.921 to 21.050, and increases R2 from 0.845 to 0.853. However, because "PTX" is the projected rather than the observed taxi demand, the model "BPNN + T + E + TX" still performs better than the model "BPNN + T + E + PTX". Furthermore, Figure 11 indicates that the model "BPNN + T + E + PTX" outperforms the model "BPNN + T + E" for most hours.

**Figure 10.** The demand prediction of online taxi-hailing based on the model "BPNN + T + E + PTX".

**Table 10.** Prediction effects of model "BPNN + T + E", "BPNN + T + E + PTX" and "BPNN + T + E + TX".


**Figure 11.** Prediction effects of online taxi-hailing demand predictions in different hours. (**a**) Average RMSE of online taxi-hailing demand predictions in different hours; (**b**) Average MAPE of online taxi-hailing demand predictions in different hours.

### **6. Discussion and Conclusions**

### *6.1. Discussion*

We proposed a real-time prediction method of online taxi-hailing demand and studied the impacts of forecasting taxi demand on the accuracy of online taxi-hailing demand. Then, we obtained the findings below:


However, the research still has some limitations, which should be addressed in future work. For example, we did not use linear regression models to predict online taxi-hailing demand. Moreover, a method to forecast multiple trip demands simultaneously should be proposed. Additionally, we will include projected environmental data as factors in real-time demand prediction.

### *6.2. Conclusions*

In this research, a data-driven forecasting method for online taxi-hailing demand was developed. To improve prediction, we proposed two methods for predicting online taxi-hailing demand based on BPNN and XGB, respectively, and tested them with the information of "T", "E" and "TX". The results indicate that considering more information improves the prediction accuracy of the models. Next, we forecasted the taxi demand and introduced a real-time online taxi-hailing demand forecasting method based on the projected taxi demand. We found that including the information of "PTX" improved the prediction performance of the model "BPNN + T + E": the MAPE, RMSE and R2 of the testing set of the model "BPNN + T + E + PTX" improved to 0.183, 21.050, and 0.853, respectively. Because a more precise traffic demand forecasting method provides a more reasonable basis for dispatching public resources, the proposed method is cost-effective within an intelligent transportation system.

Furthermore, more experiments about traffic demand prediction can be considered. For instance, the multi-mode traffic demand predictor could be proposed to improve the prediction accuracy by considering the interaction among multiple modes of transportation. Meanwhile, the multi-mode traffic demand predictor can also take the interaction among different regions into account.

**Author Contributions:** Conceptualization, Z.L. and H.C. (Hong Chen); methodology, Z.L.; software, Z.L.; validation, X.S., H.C. (Hengrui Chen) and Z.L.; formal analysis, Z.L.; investigation, Z.L.; resources, Z.L.; data curation, H.C. (Hong Chen); writing—original draft preparation, Z.L.; writing—review and editing, Z.L.; visualization, H.C. (Hong Chen); supervision, H.C. (Hong Chen); project administration, Z.L. All authors have read and agreed to the published version of the manuscript.

**Funding:** This research was funded by the Technology Project of the Shaanxi Transportation Department, grant number 15-39R.

**Acknowledgments:** This study was supported by the Technology Project of Shaanxi Transportation Department.

**Conflicts of Interest:** The authors declare no conflict of interest.

### **References**


© 2020 by the authors. Licensee MDPI, Basel, Switzerland. This article is an open access article distributed under the terms and conditions of the Creative Commons Attribution (CC BY) license (http://creativecommons.org/licenses/by/4.0/).

## *Article* **Heatwave Damage Prediction Using Random Forest Model in Korea**

### **Minsoo Park 1, Daekyo Jung 2, Seungsoo Lee <sup>2</sup> and Seunghee Park 1,3,\***


Received: 31 October 2020; Accepted: 19 November 2020; Published: 20 November 2020

**Abstract:** Climate change increases the frequency and intensity of heatwaves, causing significant human and material losses every year. Big data, whose volumes are rapidly increasing, are expected to be used for preemptive responses. However, human cognitive abilities are limited, which can lead to ineffective decision making during disaster responses when artificial intelligence-based analysis models are not employed. Existing prediction models have limitations with regard to their validation, and most models focus only on heat-associated deaths. In this study, a random forest model was developed for the weekly prediction of heat-related damages on the basis of four years (2015–2018) of statistical, meteorological, and floating population data from South Korea. The model was evaluated through comparisons with other traditional regression models in terms of mean absolute error, root mean squared error, root mean squared logarithmic error, and coefficient of determination (*R*2). In a comparative analysis with observed values, the proposed model showed an *R*<sup>2</sup> value of 0.804. The results show that the proposed model outperforms existing models. They also show that the floating population variable collected from mobile global positioning systems contributes more to predictions than the aggregate population variable.

**Keywords:** heatwaves; big data; random forest regression model; machine learning; prediction

### **1. Introduction**

According to the National Center for Environmental Information of the National Oceanic and Atmospheric Administration, the average annual global temperature has reached an all-time high over the past five years (a 0.75–0.95 °C rise over the 20th-century average annual temperature) and is continuing to increase gradually. Global warming has considerably changed the climate in recent decades, increasing the probability and intensity of meteorological and climatic disasters [1,2]. The duration and intensity of heatwaves are expected to increase with the average annual temperature, and deaths from heatwaves are expected to double [3]. The record European heatwave of 2003, which killed about 70,000 people, is expected to represent normal summer weather by the 2040s [4].

Because heatwaves cause human and physical damage every year, it is important to minimize disaster damage by establishing timely and preemptive disaster responses. A disaster response is a continuous decision-making process conducted on the basis of a variety of information and past experiences that are continuously gathered from a range of locations, and it lasts from the moment a disaster is perceived to have occurred until the time it ends. In the past, data collection techniques were less effective and provided limited information for use in contextual judgment and decision making. Consequently, owing to this lack of information, disaster responses were highly dependent on the subjective experiences of decision makers. Furthermore, although data collection technology has developed rapidly, increasing the information available for decision making, the capability of humans to process and use this information in disaster response is limited, especially in cases that require swift decision making.

The importance of utilizing big data and artificial intelligence (AI)-based analysis for the rapid processing of various types of data has been recognized. Big data refers to large and diverse forms of data that cannot be processed by traditional database systems. Further, big datasets can include signals, images, and documents whose sizes increase exponentially; such data are abundant, owing to the development of sensing and social media-oriented communication technologies within the present Internet of Things environment [5,6]. Big data systems not only utilize a variety of data quickly but are also expected to play a crucial role in analyzing meaningful information. However, early systems only focused on data collection and storage [7]. To produce meaningful results from big data, AI technology as well as simple statistical and visualization functions must be employed for analysis and prediction.

Heatwave definitions vary among different countries [8]; however, heatwaves are generally defined on the basis of the normal weather and temperatures corresponding to the seasons of a region, and they are said to occur when there is a large deviation from the normal climate pattern in a given region. These extreme weather conditions occur both locally and extensively, which limits rapid disaster response; in particular, because they occur extensively throughout a region, response procedures, such as immediately preparing resources in the event of a disaster, are limited. This indicates the need to develop early warning systems to guide disaster responses. Previous studies have focused on mortality as an endpoint for the analysis of damage caused by heatwaves [9–11], and only a few studies have focused on morbidity as an indicator [12]. In addition, most studies have adopted only weather-related parameters as predictor variables of mortality. However, even under the same weather conditions, the damage pattern can vary, depending on other variables such as the vulnerable population. This emphasizes the need to consider various variables in addition to weather-related parameters to predict heatwave damage.

In this study, a heat-related health prediction model was developed on the basis of a machine learning algorithm for early warning systems. The purpose of this study is to help decision makers respond preemptively, reducing human and economic losses. This paper is organized as follows: Section 2 describes the architecture of the random forest (RF) model and the variables representing heatwave damage, obtained from a big data collection site operated in South Korea. Experimental results, including variable evaluation, model optimization, and an evaluation of the RF model's accuracy in comparison with traditional regression models, are presented in Section 3; the application of the trained model to the site and its visualization are also described there. Section 4 presents the discussion and conclusions of this study.

### **2. Methodology**

### *2.1. Test Area*

South Korea was selected as the test bed, and its heatwave characteristics were investigated to establish the range and duration of the collected data for model training. The typical weather pattern that causes heatwaves in South Korea is a significant rise in temperature during the daytime, owing to stagnant high atmospheric pressure, which is a widespread occurrence across the country [13]. Although heatwave standards vary by country, a heatwave warning in South Korea is issued when the daily maximum temperature is expected to be above 33 °C for at least 2 consecutive days. Alerts are concentrated mainly from June to August. Heatwave occurrences in South Korea exhibit substantial interannual variability, but recently, they have become more frequent in late May and early September, and their frequency and intensity have increased [14,15]. In particular, record-breaking heatwaves occurred in 2016 and 2018, causing many casualties [16]. The Korea Disease Control and Prevention Agency (KDCA; formerly the Korea Centers for Disease Control and Prevention) has been operating a nationwide thermal disease monitoring system since 2011 to determine the weekly health damage caused by heatwaves from late May to early September every year. South Korea has 17 administrative districts composed of 8 municipalities and 9 provinces. In accordance with the characteristics of these test beds, we set the range resolution to match the 17 administrative divisions, and the temporal resolution was set to a 1-week period to match the disease monitoring system data from the KDCA.

### *2.2. Variable Selection*

Relevant variables were selected to predict heat-related damage. Heat-related diseases mainly occur in the form of cardiovascular and respiratory diseases and heatstroke [17]; consequently, various epidemiological studies of their occurrence have been conducted worldwide [17–19]. The damage patterns of heatwave disasters and the corresponding vulnerabilities cannot be captured by temperature variables alone [20,21]. Studies have shown that the damage caused by disasters is related to geographic features [22,23], surface relative humidity [24–26], wind speed [27,28], population density [29], economic status [30], and vulnerable occupational groups (laborers, construction, and agricultural workers) [31–34]. On the basis of important characteristics determined in previous studies, we selected the following variables: temperature, humidity, wind speed, number of vulnerable occupational groups, insurance premiums per person, personal income per person, floating population, and registered population of residents (the number of people counted by the administration). The vulnerable population can be inferred from data on insurance premiums, income, and vulnerable occupational groups; further, as the values of these indicators increased, the number of patients with thermal diseases increased. Both the aggregate (registered) and floating populations were used as population variables; it was expected that the floating population, which reflects real-time information, would be a more useful predictor than the aggregate population.

### *2.3. Random Forest Regression*

RF is an ensemble machine learning method that combines several separately trained models to create a strong learner that can be applied for classification and regression [35]. Such a combination of individual models can reduce overfitting and improve generalization; therefore, RF has the advantages of high prediction accuracy and algorithm robustness. When training ensemble classifiers, techniques involving the use of different datasets or properties are applied to create different training models. As shown in Figure 1, RF is based on the bootstrap method, a resampling technique that randomly samples a dataset with replacement. The sampling process is repeated k times to obtain several independent and identically distributed training subsets {*Strain*,1, *Strain*,2, ... , *Strain*,k}, each containing n samples. A decision tree is built from each training subset, and at each split only m randomly selected features are considered. The final forecast is obtained by aggregating the predictions of the individual trees: the most commonly obtained result for classification, or the average for regression [35,36]. In conclusion, although some trees created in RF may be exposed to overfitting, overfitting of the RF can be prevented by generating a large number of trees. RF algorithms have been applied to various disaster fields for prediction [37,38], forecasting, and risk evaluation [39,40].

**Figure 1.** Architecture of a random forest.
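The bootstrap-aggregation process described above can be sketched with scikit-learn (the library the study itself used, per Section 3.2). The dataset below is synthetic and purely illustrative, not the study's heatwave data:

```python
# A minimal sketch of random forest regression as described above:
# each tree is trained on a bootstrap sample, only a random subset of
# features is considered at each split, and predictions are averaged.
import numpy as np
from sklearn.ensemble import RandomForestRegressor
from sklearn.tree import DecisionTreeRegressor

rng = np.random.default_rng(42)

# Synthetic stand-in for the heatwave dataset: 8 features, noisy target.
X = rng.normal(size=(500, 8))
y = X[:, 0] * 3.0 + np.abs(X[:, 1]) + rng.normal(scale=0.5, size=500)

X_train, X_test = X[:400], X[400:]
y_train, y_test = y[:400], y[400:]

tree = DecisionTreeRegressor(random_state=0).fit(X_train, y_train)
forest = RandomForestRegressor(
    n_estimators=100,     # k bootstrap subsets / trees
    max_features="sqrt",  # m features considered at each split
    bootstrap=True,       # sample the training set with replacement
    random_state=0,
).fit(X_train, y_train)

print("single tree R^2:", round(tree.score(X_test, y_test), 3))
print("forest R^2:     ", round(forest.score(X_test, y_test), 3))
```

Averaging many de-correlated trees reduces variance, which is why the forest typically generalizes better than a single fully grown tree on the same data.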

A loss function measures the similarity between the values predicted by a model and the correct values. To increase the accuracy of a model, the loss should be reduced as the model is trained. Different loss functions are used depending on the characteristics of the model (classification or regression) and dataset. The representative loss functions for measuring errors in regression models are mean absolute error (MAE) and mean squared error (MSE):

$$MAE = \frac{\sum\_{i=1}^{N} |y\_i - \hat{y}\_i|}{N} \tag{1}$$

$$MSE = \frac{\sum\_{i=1}^{N} (y\_i - \hat{y}\_i)^2}{N} \tag{2}$$

where *N* is the total number of data points, *yi* is the real (observed) output value of data point *i*, and *y*ˆ*<sup>i</sup>* is the corresponding predicted output value. When determining the MAE, the absolute difference between the observed and predicted values of each data point is summed, whereas when determining the MSE, the square of this difference is summed. Therefore, the MSE is more sensitive to outlier values than the MAE. When the temperature exceeds a certain range, the incidence of heatstroke and other heat-related illnesses rises rapidly, and such sudden spikes behave like outliers. Consequently, the MAE, which is less sensitive to these spikes, was adopted as the loss function in this study to suit the characteristics of the target data.

To evaluate regression models, the proximity of predicted values to the observed data is quantified on the basis of the MAE, root mean squared error (RMSE), root mean squared logarithmic error (RMSLE), and coefficient of determination (*R*<sup>2</sup>), which are mainly used to evaluate accuracy [41,42]. However, the mean deviations of MAE, RMSE, and RMSLE (the lower the value, the higher the accuracy) have different values depending on the scale; therefore, it is difficult to make inferences using the absolute values alone. In contrast, *R*<sup>2</sup> is a relative value because it is the variance ratio of dependent variables predicted from independent variables; thus, the performance can be intuitively determined. *R*<sup>2</sup> generally ranges from 0 to 1. Note that if the *R*<sup>2</sup> value of a model is 0.7 or more, the model is usually considered reasonable [43].
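As an illustration, the four evaluation metrics can be computed directly with NumPy; the small arrays here are made-up values, not results from the study:

```python
# NumPy implementations of the four evaluation metrics used in this study.
import numpy as np

def mae(y, y_hat):
    return np.mean(np.abs(y - y_hat))

def rmse(y, y_hat):
    return np.sqrt(np.mean((y - y_hat) ** 2))

def rmsle(y, y_hat):
    # log1p avoids log(0) for weeks with zero patient counts
    return np.sqrt(np.mean((np.log1p(y) - np.log1p(y_hat)) ** 2))

def r2(y, y_hat):
    ss_res = np.sum((y - y_hat) ** 2)
    ss_tot = np.sum((y - np.mean(y)) ** 2)
    return 1.0 - ss_res / ss_tot

y_true = np.array([3.0, 10.0, 25.0, 60.0])
y_pred = np.array([4.0, 8.0, 27.0, 55.0])

print(mae(y_true, y_pred))    # 2.5
print(rmse(y_true, y_pred))
print(rmsle(y_true, y_pred))
print(r2(y_true, y_pred))
```

Note how MAE and RMSE keep the units of the target (patients), while RMSLE and *R*<sup>2</sup> are scale-free, which is why *R*<sup>2</sup> is the easiest to interpret across datasets.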

The RF model was established to predict the number of patients with heat-related diseases caused by heatwaves. Socioeconomic, demographic, and meteorological data were collected and used as input variables for the model. The Boruta algorithm was used to filter the variables in the RF model [38]; this algorithm uses a Z score calculated by dividing the average importance loss by its standard deviation. It was implemented using an R package [44] to confirm whether certain variables can be used as predictive model inputs. Typical hyperparameters of the random forest algorithm are the number of decision trees (n-tree) and their maximum depth (max depth). To select optimal hyperparameters, the minimum loss function value (MAE) was found by increasing n-tree and max depth. After separating the dataset comprising the selected variables into training and test datasets, we evaluated the model trained using the training dataset by comparing it with other traditional regression models. Finally, the mean decrease in impurity (i.e., Gini importance) was used to extract the variable importance values, i.e., to determine the contribution of each variable to the model's predictions.
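The shadow-feature idea behind the Boruta filter can be illustrated with a simplified one-pass sketch. The real Boruta R package iterates this comparison with statistical tests over many runs; the feature names and data below are hypothetical:

```python
# Simplified one-pass illustration of Boruta's shadow-feature idea:
# permuted copies ("shadows") of each feature set an importance baseline;
# real features whose importance beats the best shadow are kept.
import numpy as np
from sklearn.ensemble import RandomForestRegressor

rng = np.random.default_rng(0)
n = 400
# Hypothetical features: two informative, one pure noise.
temp = rng.normal(30, 3, n)
humidity = rng.normal(60, 10, n)
noise = rng.normal(size=n)
y = 2.0 * temp + 0.5 * humidity + rng.normal(scale=2.0, size=n)

X = np.column_stack([temp, humidity, noise])
shadows = rng.permuted(X, axis=0)  # shuffle each column independently
X_aug = np.hstack([X, shadows])

model = RandomForestRegressor(n_estimators=200, random_state=0).fit(X_aug, y)
imp = model.feature_importances_
real, shadow = imp[:3], imp[3:]

names = ["temp", "humidity", "noise"]
for name, score in zip(names, real):
    status = "kept" if score > shadow.max() else "dropped"
    print(f"{name}: {score:.3f} ({status})")
```

Permuting a column destroys its relationship with the target while preserving its marginal distribution, so any real feature that cannot outscore its own shuffled copy carries no usable signal.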

### **3. Results of Predicting the Number of Heatwave-Related Patients**

### *3.1. Data Collection and Pre-Processing for Model Training*

The variables and target data are listed in Table 1 with their data sources and renewal cycles. The variables are categorized as static or dynamic, and their abbreviations are used hereafter in the main text, figures, and tables. The static variables were pre-collected from a government agency that manages big data; they are typically updated quarterly or yearly, making them less volatile when predicting the number of heatwave-related patients in summer. In contrast, the dynamic variables, such as the floating population and weather information, change with time. In South Korea, big data regarding the floating population are estimated on the basis of mobile big data collected hourly and monthly by SK Telecom's nationwide mobile communication base stations, and the estimated data are obtained from the Statistical Data Center. They are also estimated using public big data and communication data provided by the Seoul Open Data Plaza. Weather data are collected hourly and were provided by the Korea Meteorological Administration (KMA).


**Table 1.** Descriptions of variables to predict the number of heatwave-related patients.

However, the weather data, particularly those collected from sensors, may have missing values due to sensor defects. To address this problem, we used the columns with no missing values to predict the missing values in the other columns. Target data were based on weekly data obtained from the thermal disease monitoring system managed by the KDCA (patients with heat-related diseases and deaths caused by heatwaves in emergency rooms nationwide); data regarding heat-related diseases such as heat stroke, exhaustion, cramps, fainting, and edema were also provided as weekly data. The resolution of the entire dataset was unified by considering the data properties of the features and targets; the temporal resolution was set to 1 week, and the range resolution was set on the basis of the South Korean administrative divisions. The dataset was then randomly split: 80% was used as training data and the remaining 20% as test data. Finally, the variables were normalized before being input into the RF model to avoid creating a model that depends on specific variable units owing to the different ranges of each variable.
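The split-and-normalize step described above can be sketched as follows (synthetic data for illustration):

```python
# Sketch of the preprocessing described above: min-max normalization of
# the features, then a random 80/20 train/test split.
import numpy as np

rng = np.random.default_rng(1)
# Synthetic stand-ins: temperature, humidity, wind speed.
X = rng.uniform(low=[15, 30, 0], high=[38, 95, 12], size=(1000, 3))
y = rng.poisson(lam=5, size=1000)  # weekly patient counts

# Min-max normalization so no variable dominates due to its unit/range.
X_norm = (X - X.min(axis=0)) / (X.max(axis=0) - X.min(axis=0))

idx = rng.permutation(len(X_norm))
split = int(0.8 * len(X_norm))
train_idx, test_idx = idx[:split], idx[split:]
X_train, y_train = X_norm[train_idx], y[train_idx]
X_test, y_test = X_norm[test_idx], y[test_idx]

print(X_train.shape, X_test.shape)  # (800, 3) (200, 3)
```

In a production pipeline the normalization minima and maxima would be computed on the training portion only and then applied to the test portion, to avoid information leaking from test to training data.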

All variables were confirmed using the Boruta algorithm. The contribution of each variable to the RF prediction model is shown in Figure 2. The edges of each box represent the quartiles, and the line through each box represents the median. Each whisker extends to 1.5 times the interquartile range from the nearer quartile, and the open circles represent outliers. The blue boxes correspond to the minimal, average, and maximum Z scores of the shadow attributes in the Boruta algorithm. The green boxplots correspond to confirmed important attributes. The Boruta algorithm thus confirmed that all collected variables can be used as inputs to the predictive model.

**Figure 2.** Importance ranking of the 12 independent variables in the random forest (RF)-based variable reduction algorithm from the Boruta package [44] in R.

### *3.2. Hyper-Parameter Optimization*

The experiment was conducted using the Scikit-learn (v.0.22.2) Python package [45] to implement the RF; the hardware platform was an Intel (R) Core (TM) i9-9900k 3.60 GHz CPU with 32 GB of RAM. The out-of-bag (OOB) error is mainly used to measure errors in machine learning models based on bootstrap aggregation (bagging), and it can be substituted for the test error [46]. Figure 3 shows the training, OOB, and test errors (MAE) as n-tree and max depth increase. When the number of decision trees was more than 100, all curves remained almost unchanged, and when the number of decision trees was 181, the lowest OOB error was found (4.59). The training and test errors at this point were 1.67 and 3.94, respectively.

**Figure 3.** Hyperparameter optimization in the RF regression model: (**a**) training curves with respect to number of trees and (**b**) training curves with respect to maximum depth of trees.

On the other hand, when the number of decision trees was fixed at 181, the error curves remained nearly constant for maximum depths above 40, and the lowest OOB error (4.58) was found at a depth of 46. The training and test errors at this point were 1.69 and 3.94, respectively. Therefore, the hyperparameters were set to 181 decision trees and a maximum depth of 46.
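The OOB-based search over n-tree can be sketched as follows, assuming scikit-learn; the study also searched over max depth, which would simply add a second loop, and the data here are synthetic:

```python
# Sketch of the OOB-based hyperparameter search described above: for each
# candidate number of trees, the OOB MAE is computed from the forest's
# out-of-bag predictions, and the setting with the lowest error is kept.
import numpy as np
from sklearn.ensemble import RandomForestRegressor

rng = np.random.default_rng(7)
X = rng.normal(size=(600, 6))
y = X[:, 0] * 2 + np.sin(X[:, 1]) + rng.normal(scale=0.3, size=600)

def oob_mae(n_trees, depth=None):
    model = RandomForestRegressor(
        n_estimators=n_trees, max_depth=depth,
        oob_score=True, bootstrap=True, random_state=0,
    ).fit(X, y)
    # oob_prediction_ holds, for each sample, the average prediction of
    # the trees whose bootstrap sample did not contain it.
    return np.mean(np.abs(y - model.oob_prediction_))

candidates = [50, 100, 150, 200]
errors = {n: oob_mae(n) for n in candidates}
best_n = min(errors, key=errors.get)
print("OOB MAE per n-tree:", {n: round(e, 3) for n, e in errors.items()})
print("selected n-tree:", best_n)
```

Because OOB predictions come from trees that never saw the sample, this search needs no separate validation split, which matters when, as here, only a few years of weekly data are available.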

### *3.3. Model Comparison*

The RF model was trained on the basis of the determined hyperparameters, and the test data were applied to the regression model. The linear regression relationship between the model's predictions and the test data is shown in Figure 4. The black line in the graph represents the regression line. The *x*-axis represents the weekly number of patients with heat-related diseases in a specific region as predicted by the model, and the *y*-axis indicates the weekly number of observed patients with heat-related diseases in that region. The translucent band around the regression line indicates the 95% confidence interval. The red dotted line indicates where the model would reflect reality exactly (slope: 1). The linear fitting slope of this RF model was 1.11; in particular, high observed values tended to be underestimated by the model. Nevertheless, the model was confirmed to be relatively reasonable.

**Figure 4.** Linear fitting results of test data.

To compare the accuracy of the regression model more quantitatively, Table 2 compares its results with those of other regression models; in particular, the RF model is compared with linear regression, decision tree, and support vector machine (SVM) models. All models were trained on the same training set and evaluated on the same test set. However, some of the values predicted by the SVM model were negative; because the values must be greater than or equal to zero, we treated all negative values as zeros. As shown in Table 2, the best values for all the considered metrics, including MAE (3.816), RMSE (8.655), RMSLE (0.645), and *R*<sup>2</sup> (0.803), were obtained for the RF model. This means that the RF model is more accurate than the other models, and the *R*<sup>2</sup> value of 0.803 indicates that the model is reasonable.


**Table 2.** Comparisons of performance evaluation.

Bold values indicate the best result among the compared methods.

### *3.4. Feature Importance*

Figure 5 shows the estimated variable importance rankings corresponding to the model. The weekly mean temperature variable, which had a value of 0.440, contributed the most in this model, followed by the vulnerable occupational groups (0.129), weekly median temperature (0.102), floating population (0.098), and weekly maximum temperature (0.085) variables. These five variables can be considered the main variables for prediction, whereas the rest are less important. Interestingly, the variable importance rankings showed that the floating population variable, which changes with time, had a greater effect on prediction than the registered resident population. In contrast, regional economic indicators had less impact on heatwave-related diseases, as observed from the low values for income (0.020) and insurance (0.013).

**Figure 5.** Variable importance in the RF model.
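Extracting Gini (mean decrease in impurity) importances from a fitted forest, as done for Figure 5, can be sketched as follows; the data and variable names below are synthetic and hypothetical, not the study's dataset:

```python
# Sketch of ranking variables by Gini importance from a fitted forest.
import numpy as np
from sklearn.ensemble import RandomForestRegressor

rng = np.random.default_rng(3)
n = 500
mean_temp = rng.normal(28, 4, n)
floating_pop = rng.normal(1e5, 2e4, n)
income = rng.normal(30, 5, n)

# Target driven mostly by temperature, somewhat by floating population.
y = 1.5 * mean_temp + 2e-5 * floating_pop + rng.normal(scale=1.0, size=n)

X = np.column_stack([mean_temp, floating_pop, income])
names = ["weekly mean temp", "floating population", "income"]

model = RandomForestRegressor(n_estimators=200, random_state=0).fit(X, y)
# feature_importances_ sums to 1; higher means a larger share of the
# impurity reduction across all splits in all trees.
for name, score in sorted(zip(names, model.feature_importances_),
                          key=lambda p: -p[1]):
    print(f"{name}: {score:.3f}")
```

One caveat worth noting: Gini importance can be inflated for high-cardinality or correlated features, so permutation importance is sometimes used as a cross-check.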

### *3.5. Model Application and Visualization*

To apply the validated model, the dynamic variables had to be determined in advance from forecasts: because the time series of the variables and the target values are aligned, predicted variable values must be used to make predictions about the future. For the static variables, we employed the latest periodically updated data as inputs, as in model training. The dynamic variables were replaced with forecasted values obtained from the weekly weather forecast data provided by the KMA and from Prophet [47], a time series forecasting library provided by Facebook, which was used to forecast the floating population.

The performance results and visualization of the model are shown in Figure 6. From the end of May, when heatwave management began, the substituted variables were input into the model for 4 weeks, and the predicted values were compared with the observed values obtained from the KDCA at the end of each week. The forecasted and observed values for Seoul were compared over 4 weeks, and in the second week of June, the predicted values for each administrative district of South Korea were numerically quantified to visualize the high-risk areas and provide information to heatwave disaster response decision makers. The high-risk areas are shown in dark colors, whereas the lower-risk areas are shown in lighter colors. To ensure objectivity across regions, the risk was calculated from the ratio of the predicted number of patients to the predicted floating population. Applying 4 weeks of data to this real-world situation resulted in an *R*<sup>2</sup> value of 0.70.

**Figure 6.** (**a**) Visualization of predicted number of heatwave-related patients in the second week of June 2020 in South Korea. (**b**) Predicted and observed data for number of heatwave-related patients in Seoul over a month.

### **4. Discussion and Conclusions**

Heatwave damage prediction has been investigated in the United States, Europe, and Asia [9,48–50]. However, existing prediction models have seen limited use in practical notification systems owing to unrepresentative data and insufficient accuracy [51]. According to previous research, this problem stems from the use of heatwave mortality alone as the endpoint of damage. Because the mortality rate of heatwaves is exceedingly small relative to the general population, it is more effective to predict risk on the basis of morbidity, which occurs at a considerably higher rate than heatwave mortality.

"Temperatures exceeding 33 ◦C" is the only available criterion for identifying the danger of heatwaves in South Korea, which allows the government to raise risk awareness by alerting the public. However, the damage to the population (deaths and sickness) caused by heatwaves varies even at the same temperature. Therefore, using only temperature data cannot determine the level of damage to peoples' health. On the basis of epidemiological investigations performed in previous research, we selected relevant variables and evaluated them by the Boruta algorithm. Then, a random forest-based heatwave damage prediction method was proposed, and its performance was compared with other traditional models. Previous studies considering demographic information have mainly used data with static characteristics, such as monthly statistical information. More accurate predictions were achieved by matching the exposure to heatwaves in a specific area to the population in that area

using dynamic population data updated more frequently, allowing this variable to contribute more to the prediction.

In the evaluation of variable importance, the weekly mean temperature variable and the number of workers in occupational groups considered vulnerable to heatwaves ranked highly (the weekly mean temperature reflects the accumulated temperature over the week). This result supports the importance of predicting the cumulative temperature in advance and responding preemptively in order to minimize heat damage. In addition, it provides grounds for preemptive government responses, such as operating sprinkler trucks and installing shade curtains, in areas with many vulnerable workers. The learning process performed to build the machine learning model used independent variables based on previously recorded data. However, it is difficult to apply big data to a real-time environment owing to limitations such as irregularity in the frequency of data. Furthermore, when the time series of the dependent and independent variables are configured identically in training, the conditions for prediction are not effectively established in practical systems. Therefore, a method that employs predicted variable values was proposed. Although the accuracy of predicting future patients from predicted input data is lower than the test accuracy during validation, the *R*<sup>2</sup> value of 0.70 indicates that this model provides reasonable information. Governments can use the methods developed in this study to provide disaster response decision makers with a reasonable basis for prioritizing administrative areas for preemptive response and disaster support.

Nonetheless, this study has several limitations. First, the temporal resolution of the predicted values is relatively coarse-grained; thus, it is impossible to provide daily predictions for heatwaves, which can occur on any day. Most studies on heatwaves in South Korea have provided weekly information [11,49]; accordingly, the resolution of the target values (heat patients) is tailored to the minimum reporting time unit of the data source. Second, during the experiment, data obtained across the country were input into one learning algorithm, and regional differences between administrative districts were not considered. South Korea is a relatively small country; although its regional environmental differences are smaller than those in larger countries, the differences in geography and weather between its eastern and western regions are substantial. Therefore, developing individual algorithms for each region could improve model performance; the model can be trained for each region if sufficient datasets are available. Since 2018, South Korea has been managing vulnerable populations by operating heatwave shelters. The prediction model established in this study will contribute to future studies that select regions at risk of heatwaves and provide decision makers with a basis for installing heat shelters using high regional resolution and estimating the cumulative number of patients relative to the population.

**Author Contributions:** Idea development and original draft writing, M.P.; project administration, D.J.; draft review and editing, S.L.; supervision and funding acquisition, S.P. All authors have read and agreed to the published version of the manuscript.

**Funding:** This research was supported by a grant (2019-MOIS31-011) from the Fundamental Technology Development Program for Extreme Disaster Response funded by the Ministry of Interior and Safety, Korea, and supported by the Korea Ministry of Land, Infrastructure and Transport (MOLIT) as an Innovative Talent Education Program for Smart City.

**Conflicts of Interest:** The authors declare no conflict of interest.

### **References**


**Publisher's Note:** MDPI stays neutral with regard to jurisdictional claims in published maps and institutional affiliations.

© 2020 by the authors. Licensee MDPI, Basel, Switzerland. This article is an open access article distributed under the terms and conditions of the Creative Commons Attribution (CC BY) license (http://creativecommons.org/licenses/by/4.0/).

## *Article NowDeepN***: An Ensemble of Deep Learning Models for Weather Nowcasting Based on Radar Products' Values Prediction**

**Gabriela Czibula 1,\*,†, Andrei Mihai 1,† and Eugen Mihule¸t <sup>2</sup>**


**Abstract:** One of the hottest topics in today's meteorological research is weather nowcasting, which is the weather forecast for a short time period such as one to six hours. Radar is an important data source used by operational meteorologists for issuing nowcasting warnings. With the main goal of helping meteorologists in analysing radar data for issuing nowcasting warnings, we propose *NowDeepN*, a supervised learning based regression model which uses an ensemble of *deep artificial neural networks* for predicting the values for radar products at a certain time moment. The values predicted by *NowDeepN* may be used by meteorologists in estimating the future development of potential severe phenomena and would replace the time consuming process of extrapolating the radar echoes. *NowDeepN* is intended to be a proof of concept for the effectiveness of learning from radar data relevant patterns that would be useful for predicting future values for radar products based on their historical values. For assessing the performance of *NowDeepN*, a set of experiments on real radar data provided by the Romanian National Meteorological Administration is conducted. The impact of a *data cleaning* step introduced for correcting the erroneous radar products' values is investigated both from the computational and meteorological perspectives. The experimental results also indicate the relevance of the features considered in the supervised learning task, highlighting that the radar products' values at a certain geographical location at a time moment may be predicted from the products' values from a neighboring area of that location at previous time moments. An overall *Normalized Root Mean Squared Error* less than 4% was obtained for *NowDeepN* on the cleaned radar data. *NowDeepN* also outperforms several similar approaches from the nowcasting literature, which highlights the performance of our proposal.

**Keywords:** weather nowcasting; machine learning; deep neural networks; autoencoders; Principal Component Analysis

### **1. Introduction**

Weather nowcasting [1,2] refers to short-term weather prediction, namely weather analysis and forecast for the next 0 to 6 h. Nowadays, the role of nowcasting in crisis management and risk prevention is increasing, as more and more severe weather events are expected [3]. Large volumes of meteorological data, including radar, satellite, and weather stations' observations, are held by meteorological institutes and available for analysis. Radars and weather stations are constantly collecting real-time data, while data about cloud patterns, winds, and temperature are continuously gathered by weather-focused satellites. Thus, there is a large amount of meteorology-related data available to be analyzed using machine learning (ML) based algorithms for improving the accuracy of short-term weather-prediction techniques.

The World Meteorological Organization (WMO) [4] mentions that "nowcasting plays an increasing role in crisis management and risk prevention, but its realization is a highly complex task", the highest difficulties being related to the small-scale nature of convective weather phenomena. This implies that the Numerical Weather Prediction (NWP) approach [5] is not feasible and that the forecast has to rely mainly on the extrapolation of known weather parameters. As meteorological institutes worldwide expect climate changes including extreme rain phenomena [3], there is an increasing need for accurate and early warnings of severe weather events. Considering the increased number and intensity of severe meteorological phenomena, predicting them in due time to avoid disasters becomes highly demanding for meteorologists.

**Citation:** Czibula, G.; Mihai, A.; Mihuleţ, E. *NowDeepN*: An Ensemble of Deep Learning Models for Weather Nowcasting Based on Radar Products' Values Prediction. *Appl. Sci.* **2021**, *11*, 125. https://doi.org/10.3390/app11010125

Received: 28 October 2020; Accepted: 18 December 2020; Published: 24 December 2020

**Publisher's Note:** MDPI stays neutral with regard to jurisdictional claims in published maps and institutional affiliations.

**Copyright:** © 2020 by the authors. Licensee MDPI, Basel, Switzerland. This article is an open access article distributed under the terms and conditions of the Creative Commons Attribution (CC BY) license (https://creativecommons.org/licenses/by/4.0/).

Issuing a nowcasting warning is a complex and difficult task, mainly due to the extremely large volume of data that has to be analyzed in a short period of time: radar, satellite and weather stations' observations must all be examined before an appropriate decision can be taken. Besides, given the stochastic and chaotic character of the atmosphere, the evolution of certain weather phenomena is difficult to predict even for human experts. Thus, machine learning (ML) and *deep learning* [1,2] techniques are useful for assisting meteorologists in the decision-making process, offering solutions for nowcasting by learning relevant patterns from large amounts of weather data. Most of the existing operational and semi-operational methods for nowcasting use the extrapolation of radar data and algorithms mainly based on cell tracking. Existing nowcasting techniques use various data sources which may be relevant for accurate nowcasting, such as meteorological data (radar, satellite, meteorological observations) and geographical data (elevation, exposure, vegetation, hydrological features, anthropic features).

In the current study we used real radar data provided by one of the WSR-98D weather radars [6] of the Romanian National Weather Administration. Given its capability to determine the location, size, direction and speed of water droplets, the weather radar is an essential tool used by meteorologists for nowcasting. Since this type of forecast requires fast decisions, and with the aim of facilitating the analysis in an operative environment, the data retrieved by the radar is supplied to the meteorologist in the form of coloured maps, which are easy to assess at a glance. On such a map, each pixel corresponds to a geographical location and its colour represents a certain value of the displayed product. In a short period of time, commonly under one minute, the meteorologist can locate potentially dangerous storm cells and analyse their vertical structure, relative speed and direction, the dimension of hail, the cloud tops and other relevant parameters. At this point, the operator should assess the current phase of the storm and, by extrapolating those values in time (usually up to one hour), predict the future development (intensity and area affected) of the storm. Although extrapolating the radar echoes is one of the main techniques used in nowcasting, it presents weaknesses in terms of processing time and precision, both related to the skills and experience of the meteorologist.

Most of the currently existing operational and semi-operational methods for nowcasting use the extrapolation of radar data and algorithms mainly based on cell tracking [7–9]. However, an important limitation of existing centroid cell-tracking algorithms lies in the detection of storms segments having irregular shapes or with variable wind speeds, resulting in identification and tracking errors.

With the goal of helping meteorologists analyse radar data for issuing nowcasting warnings, we introduce in this paper *NowDeepN*, a supervised learning model based on an ensemble of deep neural network (DNN) regressors for predicting the values of radar products, which may be used for weather nowcasting. As an additional goal, we aim to empirically validate the hypothesis that similar values of the radar products at a given time moment for a certain geographical region are encoded in similar neighborhoods of that region at previous time moments. The values predicted by *NowDeepN* for the radar products at certain time moments may be used by the meteorologist in estimating the future development of potentially severe phenomena, and thus *NowDeepN* would replace, for operational meteorologists, the time-consuming process of extrapolating the radar echoes.

As a proof of concept, *NowDeepN* is proposed for learning to approximate a function between past values of the radar products, extracted from radar observations, and their future values. Experiments will be performed on real radar data provided by the Romanian National Meteorological Administration and collected over the Central Transylvania region. Our experimental goal is to obtain empirical evidence that, in both normal and severe weather conditions, the values of a radar product at a given moment in a certain location are predictable from the values of the neighboring locations at previous time moments. If this holds, the study may be further improved and extended on a larger scale. In addition, we propose a *data cleaning* step useful for correcting erroneous input radar data. To the best of our knowledge, this is the first time an approach such as *NowDeepN* has been introduced in the nowcasting literature.

To summarize the contribution of this paper, our goal is to answer the following research questions:


The remainder of the paper is structured as follows. The problem of weather nowcasting, its importance and main challenges are discussed in Section 2. Section 3 reviews the fundamental concepts regarding the machine learning models used, whilst a literature review on supervised learning based weather nowcasting is presented in Section 4. Section 5 introduces our data model and the *NowDeepN* hybrid model for predicting the values of radar meteorological products using deep neural networks. Section 6 presents the experimental setting and results, while an analysis of the obtained results and their comparison to related work is provided in Section 7. The conclusions of our paper and future improvements and enhancements of *NowDeepN* are outlined in Section 8.

### **2. Weather Nowcasting**

The World Meteorological Organization (WMO) [4] mentions that "nowcasting plays an increasing role in crisis management and risk prevention, but its realization is a highly complex task", the highest difficulties being related to the small-scale nature of convective weather phenomena. This implies that the Numerical Weather Prediction approach is not feasible and that the forecast has to rely mainly on the extrapolation of known weather parameters. In this context, there is a high need for automated tools and systems supporting the nowcasting meteorologist, so significant research and development have been carried out on the topic of nowcasting. In spite of those efforts, both the operative personnel in nowcasting and the beneficiaries (population, relevant public institutions) are in demand of even more efficient products and services concerning the weather forecast itself and the alerting system.

Most of the currently existing operational and semi-operational methods for nowcasting use the extrapolation of radar data and algorithms mainly based on cell tracking. Dixon and Wiener developed a real-time cell tracker named TITAN [10] which is useful for single cells. A centroid method is used in Reference [10] to identify storm entities in consecutive radar scans and estimate the future movement of the storm centroid by minimizing a cost function. SCIT was proposed by Johnson et al. [11] as a cell-tracking algorithm with a more complex cell detection method. SCIT used reflectivity thresholds and a distance function between cells for tracking and estimating future positions. An operational nowcasting tool called TRT was developed by Hering et al. [7] and is used by MeteoSwiss. TRT is a centroid cell-tracking method using radar data and, in addition to the previously mentioned methods, a multiple sensor system. Jung and Lee [12] introduced a cell-tracking algorithm using fuzzy logic based on a large set of historic radar data, with the goal of minimizing errors induced by single features. Germany's National Meteorological Service (DWD) uses a nowcasting system called NowCastMIX developed by James et al. [8]. NowCastMIX employs remote and ground observations analyzed using fuzzy logic rules. AROME [13] is a small-scale numerical prediction model which has been used by Météo-France since 2008. AROME assimilates measurements of the precipitation systems, low-level humidity, low-level wind, upper-level wind and temperature, and it provides a forecast for up to 30 h. AROME-NWC [9] was developed in 2015 on top of AROME for nowcasting in the range of 0–6 h. The INCA system [14] was developed specifically for mountainous regions with the goal of forecasting precipitation amounts, temperature, wind and convective parameters.

An important limitation of existing centroid cell-tracking algorithms lies in the detection of storm segments having irregular shapes or variable wind speeds, resulting in identification and tracking errors. Errors also occur when storm cells are clustered together and the algorithm detects them as a single larger cell. In other cases, a single storm cell is detected as two or more cells at the same horizontal location but on different altitude levels. Given the high spatial and temporal resolution, the NowCastMIX [8] estimates of severe storms change quickly with the assimilation of new observations, leading to difficulties in reasoning for meteorologists. It is also prone to overestimating the likelihood of severe convection. According to Reference [14], the INCA system provides high accuracy for temperature but comparatively low accuracy for wind and precipitation, given the relatively sparse distribution of relevant stations in mountainous regions. The systems described in References [8,9,14], which are considered to be among the best-performing automated nowcasting systems at present, are highly customized for their respective locations and infrastructure, so their performance and adaptability to other contexts is unclear.

Several industrial players have also shown an interest in the problem of weather forecasting and have developed various solutions in this direction: IBM Deep Thunder [15] aims to provide high-resolution weather prediction for a variety of applications, Panasonic Global 4D Weather [16] is a weather forecasting platform that uses Panasonic patented atmospheric TADMAR sensor collected data, ClimaCell company [17] offers weather forecasting solutions tailored for various weather-sensitive industries while TempoQuest [18] is a company putting on the market a proprietary forecasting software for commercial users and government agencies.

### **3. Machine Learning Models Used**

This section reviews the fundamental concepts regarding the machine learning models used in this paper: deep neural networks, *autoencoders* and *Principal Component Analysis*.

*Supervised learning* is a subfield of machine learning dealing with the task of approximating a mapping from some input domain to some output domain based on input-output example pairs. The data set consisting of the pre-existing examples of input-output pairs is called the *training data set*. A *supervised learning algorithm* generalizes the training data, producing a function that, given an input outside the training data set, can return a close enough approximation of the correct output. What can be considered a "good enough approximation" depends on the specific problem. The vast quantity of data available nowadays, as well as the increasing computational power, makes supervised learning methods an extremely useful tool that can be applied in a wide range of domains.

#### *3.1. Deep Neural Networks*

*Neural network* learning methods provide a robust approach to approximating real-valued, discrete-valued or vector-valued target functions [19]. As a biological motivation, neural networks have been modeled to resemble learning systems found in humans and animals, namely complex networks of neurons. This morphology has been adopted in computer science by building densely interconnected systems whose building blocks are basic units that take as input a series of real-valued numbers and produce a single real-valued output [19]. These basic units are called artificial neurons. Neural networks are suited for problems that deal with noisy, complex data, such as camera, microphone or sensor data. Their success is due to their similarity to effective biological systems, which are able to generalize and associate data that has not been explicitly encountered during the training phase, and to correlate that data with the class where it belongs. Each neuron of the network has an array of parameters, called weights, based on which it processes the input data. The weights are adjusted during the training phase based on the error of the network, which represents the difference between the correct output and the network output. The learning algorithm used for adjusting the weights based on the error is the *backpropagation* algorithm.

Unlike classical neural networks, deep neural networks (DNNs) [20] contain multiple hidden layers and have a large number of parameters, which makes them able to express complicated target functions, that is, complex mappings between their inputs and outputs [21]. Nowadays, DNNs are powerful models in the machine learning literature, applied to complex classification and regression problems from various domains. Machine learning models, including deep ones, have been successfully applied for developing forecasting models, such as for bank failures [22], prediction markets [23] or gambling [24]. Due to their complexity, large networks are slow to use and are prone to *overfitting*, which is a serious problem in DNNs. Overfitting is a major problem for supervised learning models, in which the model learns the training data "by heart" but does not have the capability to generalize well on the testing data. An overfit model is revealed by a very good performance on the training data but a much lower performance on the testing set. A possible cause of overfitting in DNNs is limited training data, as in such cases the relationships learned by the networks may be the result of sampling noise; these complex relationships will exist in the training data but not in real test data [21]. There are various methods for addressing and reducing overfitting, such as: (1) stopping the training when the performance on a validation set starts decreasing; (2) introducing weight penalties through regularization techniques such as soft weight sharing [25]; (3) applying cross-validation; (4) extending the data set to include more training examples; and (5) *dropout*, that is, randomly dropping neurons and their connections during training.

*Regularization* stands for an ensemble of techniques whose purpose is to simplify the model in order to avoid overfitting. Dropout is a regularization technique applied to neural networks. It consists of deactivating some neurons during the training process, forcing the network to achieve the result using a reduced (and simpler) neuron configuration. During the prediction phase, all neurons are reactivated. Each neuron is kept active with an arbitrarily chosen probability *p* and dropped out with probability 1 − *p* [21].
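As an illustration of the mechanism described above, the following is a minimal sketch of the "inverted" dropout variant (an illustrative implementation, not the one used by *NowDeepN*; the function name and the `p_keep` parameter are our own):

```python
import numpy as np

def dropout(activations, p_keep, training=True, rng=None):
    """Inverted dropout: keep each neuron with probability p_keep.

    During training, dropped neurons are zeroed and the survivors are
    scaled by 1/p_keep so that the expected activation is unchanged;
    at prediction time all neurons stay active and no rescaling is needed.
    """
    if not training:
        return activations
    rng = rng or np.random.default_rng(0)
    mask = rng.random(activations.shape) < p_keep
    return activations * mask / p_keep
```

The scaling by 1/`p_keep` is what lets the layer be a no-op at prediction time; the classical formulation instead rescales the weights after training.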

#### *3.2. Autoencoders and Principal Component Analysis*

*Autoencoders* (AEs) [20] are deep feed-forward neural networks which aim to learn to reconstruct their input, being known in the machine learning literature as *self-supervised* learning systems. An AE has two main components: an *encoder* and a *decoder*. Assuming that the input space is R<sup>*n*</sup>, the encoder learns a mapping *f* : R<sup>*n*</sup> → R<sup>*m*</sup>, while the decoder learns the function *g* : R<sup>*m*</sup> → R<sup>*n*</sup>. If *m* < *n*, the network learns to compress the input data into a lower-dimensional latent space and to reconstruct it based on the latent representation, by learning the essential characteristics of the data. For avoiding the overfitting symptom (i.e., simply copying the input to the output), L1 regularization is applied to the encoded state and the model is called a *sparse* autoencoder. Autoencoders have been successfully applied in various tasks ranging from image analysis [26] and speech processing [27] to protein analysis and classification [28,29].
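The two mappings can be sketched as follows (untrained, for shape illustration only; the dimensions, weights and activation function are arbitrary choices, not those of any model discussed in this paper):

```python
import numpy as np

rng = np.random.default_rng(0)
n, m = 8, 3                      # input dimension and latent dimension, m < n

# Randomly initialized (untrained) weights, purely to illustrate the shapes.
W_enc = rng.normal(size=(n, m))
W_dec = rng.normal(size=(m, n))

def encode(x):
    """f : R^n -> R^m (the encoder)."""
    return np.tanh(x @ W_enc)

def decode(z):
    """g : R^m -> R^n (the decoder)."""
    return z @ W_dec

x = rng.normal(size=n)
z = encode(x)                    # compressed latent representation, length m
x_hat = decode(z)                # reconstruction, length n
# Training would minimize ||x - x_hat||^2, plus an L1 penalty on z
# for the sparse variant described in the text.
```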

Principal Component Analysis (PCA) is a dimensionality reduction statistical technique heavily used in the ML field for tasks such as descriptive data analysis, data visualization and data preprocessing. It can be seen as an unsupervised learning technique that learns to represent the data in a new, lower-dimensional space. Given a data set with *n* samples (*x*<sub>1</sub>, ... , *x<sub>n</sub>*), each sample having *p* attributes (*a*<sub>1</sub>, ... , *a<sub>p</sub>*), the PCA algorithm searches for linear combinations of the attributes (i.e., ∑<sub>*i*=1</sub><sup>*p*</sup> *a<sub>i</sub>* · *c<sub>i</sub>*, with *c*<sub>1</sub>, ... , *c<sub>p</sub>* constants) such that they are linearly uncorrelated, the first linear combination has the largest possible variance, the second has the largest possible variance while being orthogonal to the first, and so on. These linear combinations are called *principal components*. The principal components can be found by computing the eigenvectors and eigenvalues of the covariance matrix of the data set: the constants of each linear combination are given by an eigenvector, with the first linear combination corresponding to the eigenvector associated with the largest eigenvalue, and so on [30].
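The recipe above can be sketched directly (a didactic version; production implementations typically rely on the SVD rather than an explicit covariance eigendecomposition):

```python
import numpy as np

def pca(X, k):
    """Project the rows of X onto the first k principal components.

    Textbook recipe: center the data, form the covariance matrix and
    take its eigenvectors sorted by decreasing eigenvalue.
    """
    Xc = X - X.mean(axis=0)
    cov = Xc.T @ Xc / (len(X) - 1)
    eigvals, eigvecs = np.linalg.eigh(cov)   # eigh returns ascending order
    order = np.argsort(eigvals)[::-1]        # largest variance first
    components = eigvecs[:, order[:k]]
    return Xc @ components, eigvals[order[:k]]
```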

### **4. Literature Review on Supervised Learning Based Weather Nowcasting**

The literature contains various machine learning-based approaches for weather nowcasting. Relevant results obtained recently in predicting short-term weather are summarized in the following and the limitations of the existing solutions are emphasized.

Han et al. [31] train Support Vector Machines (SVMs) on box-based features in order to classify whether or not a radar echo > 35 dBZ will appear on the radar within 30 min. The approach uses both temporal and spatial features, derived from vertical wind and perturbation temperature. Greedy feature selection was used to select the final features. The obtained results were around 0.61 Probability Of Detection (POD), 0.52 False Alarm Ratio (FAR) and 0.36 Critical Success Index (CSI), which outperformed the Logistic Regression, J48, AdaBoost and MaxEnt approaches compared against on the same data set.

Beusch et al. [32] present the COALITION-3 ("Context and Scale Oriented Thunderstorm Satellite Predictors Development") algorithm developed by MeteoSwiss. Its goal is to identify, track and nowcast the position and development of storms in a robust manner. In order to do this, it employs optical flow methods in combination with winds predicted by COSMO. In order to separate the movement of the storm from its temporal evolution, Lagrangian translation is used. Following this, the algorithm estimates future intensification and quantifies the probability and severity of the different risks associated with thunderstorms. This information is extracted by applying machine learning techniques to a data archive containing observations from the MeteoSwiss dual polarized Doppler radar system and information from other available systems.

Shi et al. [33] introduce an extension of the Long Short-Term Memory network (LSTM), named ConvLSTM. Through experiments on Moving MNIST and a radar echo data set, the authors showed their method to be suitable for spatiotemporal data, retaining all the advantages of a simple LSTM while preserving spatiotemporal features due to its convolutional structure. Further advantages of ConvLSTM over a basic LSTM are that the input-to-state and state-to-state transitions are made in a convolutional manner and that deeper models can produce better results with a smaller number of parameters. Their proposed architecture contains two networks: an encoder composed of two ConvLSTM layers and a forecasting network containing the same number of ConvLSTM layers plus an additional 1 × 1 convolutional layer to generate the final predictions.

Kim et al. [34] have also proposed a ConvLSTM-based model for precipitation prediction, but they used three-dimensional, four-channel data, unlike Shi et al. [33], who used three-dimensional, one-channel data. The four channels correspond to four altitudes. The proposed model, called DeepRain, predicts the amount of rainfall on a large scale based on radar reflectivity data. The experimental evaluation showed a 23% decrease in root mean square error (RMSE) compared to linear regression, as well as superior performance compared to a fully connected LSTM. For rainfall prediction, an RMSE value of 11.31 was obtained on the test set.

Heye et al. [35] present a practical solution to the precipitation nowcasting problem, covering how the data is stored and preprocessed, the deep learning model used, and the necessary frameworks and hardware. They used the ConvLSTM cells with peepholes described by Shi et al. [33] in an architecture inspired by those used for machine translation. The model is composed of an encoder and a decoder, each made of four ConvLSTM layers. They experimented both with taking the last step of the encoder and with using attention for transferring information to the decoder; the former produced better results in terms of Probability of Detection and Critical Success Rate and worse results for the False Alarm Rate. The decoder uses both the encoded actual reflectivity and the predicted reflectivity. To accelerate learning, the ground truths are fed back into the decoder during training, while the prior predictions are used during evaluation. Additionally, they noticed that using as a start symbol a tensor of ones, which represents high reflectivity when data is scaled to [0, 1], improves the Probability of Detection and the Critical Success Rate.

Shi et al. [36] proposed a benchmark for the nowcasting prediction problem as well as a new model. The benchmark includes a new data set, two testing protocols and two training measures, Balanced Mean Squared Error and Balanced Mean Absolute Error. These measures were necessary due to the imbalance between the small number of potentially dangerous events and the normal amounts of rainfall. The proposed model, Trajectory Gated Recurrent Unit (TrajGRU), deals better than the Convolutional GRU with the representation of location-variant relationships. TrajGRU allows aggregation of the state along learned trajectories. The architecture consists of an encoder and a forecaster part, inserting downsampling or upsampling operations between RNN layers depending on the region of the architecture. Their proposal proved to be flexible and superior to other methods. The experiments performed on the Moving MNIST and the HKO-7 precipitation nowcasting data sets emphasize the ability of the model to capture spatiotemporal correlations.

Narejo and Pasero [37] have proposed a hybrid model combining Deep Belief Networks (DBNs) with Restricted Boltzmann Machines (RBMs) to predict different weather parameters (air temperature, pressure and relative humidity) at a local level, that is, restricted to a particular geographical area. Initially, the RBMs are trained in an unsupervised manner. Subsequently, the already trained RBMs are stacked to create a DBN, which is then trained in a supervised manner to predict the weather parameters.

Sprenger et al. [38] approached foehn prediction using the AdaBoost machine learning algorithm. The authors motivate the choice of the model by the reasonable balance it offers between predictive power, computational speed and interpretability of the results. Trained with three years of hourly simulation data and using modified decision stumps as weak learners, the model achieved, on a validation data set, a sensitivity (or probability of detection) of 0.88 and a probability of false detection of 0.29.

Yan Ji [39] approached the problem of short-term precipitation prediction from radar observations using *artificial neural networks*. The nowcasting of rain intensity was carried out on raw radar data and rain gauge data collected in China from 2010 to 2012. The reflectivity values were extracted from the raw data and interpolated onto a 3D rectangular lattice grid of 1 km × 1 km horizontal resolution at heights of 1.5 and 3 km [39]. The collected data set was afterwards used for training the predictive model. A *root mean squared error* of less than 5 (a minimum of 0.97 and a maximum of 4.7) was obtained. The experiments showed a correlation coefficient *R* in the radar-rainfall estimation of more than 0.6 and indicate that the accuracy rate of the 36-min forecast is higher than 50% [39].

In a recent approach, Tran and Song [40] proposed a new loss function for the convolutional neural network (CNN) based models used for weather nowcasting. Taking a computer vision perspective, the authors used Image Quality Assessment metrics as loss functions in training, finding that the Structural Similarity function performed better than both MSE and MAE, especially by improving the quality of the images (i.e., the output images were much less blurry). However, the best performance was achieved by combining the Structural Similarity with the MAE and MSE loss functions.

Han et al. [41] proposed a CNN model for convective storm nowcasting based on radar data. The model predicts whether radar echo values will be higher than 35 dBZ in 30 min, thus casting the problem as a classification problem. The authors modeled the input radar data as multi-channel 3D images and proposed a CNN model that performs cross-channel 3D convolution. In their model, the output is also a 3D image, where each point of the image is 0 if the radar echo is predicted to be ≤ 35 dBZ in 30 min and 1 otherwise. A Critical Success Index (CSI) score of 0.44 was obtained.

Yan et al. [42] proposed a model for precipitation nowcasting employing a convolutional architecture with multi-head attention and residual connections (MAR-CNN). The data fed into the network consists of radar reflectivity images at three elevation levels and other numerical features, such as cloud movement speed. To deal with the imbalanced precipitation classes, extreme meteorological events were oversampled. The proposed model outperformed several deep learning baselines obtained by using only some of the MAR-CNN components (a dual-channel convolutional attention module, a dual-channel convolutional model and a single-channel convolutional model), as well as a GBDT and an SVM.

#### *Limitations of Existing Approaches*

Han et al. [31] highlight that support vector machines require feature reduction in order to be more accurate, so they might not be able to make use of all the information that a data set could provide. They also argue for the importance of collecting more data in order to train the machine learning models better. Another concern is that of feature selection: while other informative features do exist in the literature, it is not always straightforward to bring them into a form usable for machine learning.

As most works [31,32] show, each system is mostly tested with some form of local real-world data. This does not allow one to accurately assess a system's ability to adapt to other locations and their specific meteorological events and trends.

A general concern, which can later turn into a limitation, is possible data imbalance. The performance of the proposed methods can be diminished by the small number of risky weather conditions compared with normal rain [35,36]. The results from Reference [33] might not reflect the overall competence of ConvLSTM, since Shi et al. used a rather small data set for testing and a low threshold for rain rate. Another limitation mentioned in Reference [35] is that predictions tend toward lower precipitation values over time, due to previously less confident predictions being fed back into the recurrent network. The architecture proposed by Reference [35] is also limited to the training data and does not learn from new data. A concern experienced by others when using the proposed TrajGRU cell [36] is its speed: the implementation of this operator turns out to be slower than a simple ConvGRU.

A common limitation of the solutions proposed by References [34] and [37] is that they compromise the interpretability of the prediction results. The boosting-based solution [38] alleviates the issue of interpretability, but Dietterich has highlighted in one of his studies [43] the sensitivity to outliers as a particular disadvantage of boosting: outliers can negatively affect the final predictions since they might be excessively weighted during the boosting steps. An additional limitation of most of the existing solutions is that they do not combine multiple data sources, thus being deprived of an expected enhancement of their predictive capability [44].

### **5. Methodology**

We further introduce our *NowDeepN* approach for weather nowcasting using *deep neural networks*. We start by describing in Section 5.1 the raw radar data used in our experiments. Section 5.2 introduces the theoretical model on which *NowDeepN* is based, then Section 5.3 presents our proposal. The section ends with the evaluation measures we will use for assessing the predictive performance of *NowDeepN*.
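As a hedged illustration of one of these measures, the *Normalized Root Mean Squared Error* mentioned in the abstract is commonly computed as the RMSE divided by the range of the observed values (the exact normalization convention used by the authors may differ from this one):

```python
import numpy as np

def nrmse(y_true, y_pred):
    """Root mean squared error normalized by the range of the observations.

    Range-based normalization is one common convention; the normalization
    actually used in the paper may differ.
    """
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    rmse = np.sqrt(np.mean((y_true - y_pred) ** 2))
    return rmse / (y_true.max() - y_true.min())
```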

#### *5.1. Radar Data*

The experiments in the current study were conducted on real radar data provided by one of the WSR-98D weather radars [6] of the Romanian National Weather Administration. The WSR-98D is a type of Doppler radar used by meteorological administrations for weather surveillance, capable of remotely detecting water droplets in the atmosphere (i.e., clouds and precipitation) and retrieving data on their location, size and motion. The WSR-98D scans the volume of air above an area of over 70,000 square kilometers, and about every 6 min a complete set of about 30 base and derived products for 7 different elevations is collected. The base products are *particle reflectivity* (*R*), providing information on particle location, size and type, and *particle velocity* (*V*), supplying information on particle motion, that is, direction and speed relative to the radar. Both products are available for several elevation angles of the radar antenna, and for each time step a set of seven data products, *R*01–*R*07 and *V*01–*V*07, is delivered, each of them corresponding to a certain tilt of the antenna. Among the derived products, of particular interest for this study is *VIL* (vertically integrated liquid), an estimation of the total mass of precipitation above a certain unit of area. The data in the NEXRAD Level III files is stored in a gridded format, each point of the grid corresponding to a geographical location and containing the value of a certain product at the respective time frame. In the data grid provided by the WSR-98D radar, the OX axis contains the longitude values, while the OY axis contains the latitude values.

### *5.2. Data Model*

As shown in Section 1, the raw data provided by the radar scans during one day (24 h) on a certain geographic region was exported in the form of a sequence of *m* × *n* dimensional matrices, one matrix corresponding to a certain time moment *t* and a certain meteorological product *p* (i.e., each element from the matrix represents the value for the product *p* at a certain location from the map). As a set *Prod* of multiple meteorological products is provided by the radar, the radar data collected at a time *t* may be visualized as a 3D data grid in which the OZ axis corresponds to the radar products.

For instance, Figure 1 depicts a sample 3D grid with *m* = 2 rows, *n* = 2 columns and three products (*Prod* = {*R*01, *R*02, *R*03}) recorded at a certain time stamp *t*. In the figure, the values for *R*01 are in the front matrix, *R*02 values are in the middle one and *R*03 in the matrix behind.


**Figure 1.** A sample 3D data grid.

During one day, a sequence of 3D data grids (as shown in Figure 1) corresponding to various time stamps is provided by the radar. Assuming that the radar records data every 6 min, 240 3D data grids are provided. For a certain location (*i*,*j*) on the map, a time moment *t* and a set *Prod* of radar products, we denote by *Vt*(*i*, *j*, *l*, *Prod*) the vector representing the linearized 3D data subgrid containing the radar products' values for a neighboring area of a certain length *l* surrounding the point (*i*,*j*), at time moment *t*.

As an example, let us consider a 5 × 5 dimensional data grid, the set *Prod* = {*R*01, *R*02} of radar products, a time stamp *t*, the location (3, 3) (*i* = 3, *j* = 3) and a length *l* of 3 for the neighboring subgrid. Figure 2 depicts the 3D data grid containing values for *R*01 (front) and *R*02 (behind) for each cell from the data grid surrounding the point (3, 3), at time stamp *t*. The point of interest, as well as the 3D data subgrid of length 3 surrounding the point (3, 3), are highlighted.

**Figure 2.** The 3D data grid at time stamp *t*.

The linearized 3D data subgrid (highlighted in Figure 2) is the 18-dimensional vector *Vt*(3, 3, 3, *Prod*)=(15, 0, 0, 10, 10, 15, 20, 10, 40, 5, 25, 30, 25, 25, 20, 40, 15, 20).
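The linearization described above can be sketched as follows (a minimal NumPy sketch; the function name `linearize_subgrid` and the row-major flattening order are our assumptions, as the text does not fix a particular ordering of the 18 components):

```python
import numpy as np

def linearize_subgrid(grid, i, j, l):
    """Extract the l x l neighborhood around (i, j) from a 3D data grid
    of shape (rows, cols, n_products) and flatten it into a vector.
    Indices are 1-based, as in the text; l is assumed to be odd."""
    r = l // 2
    # convert to 0-based indices and slice the neighborhood
    sub = grid[i - 1 - r:i + r, j - 1 - r:j + r, :]
    return sub.ravel()

# a hypothetical 5 x 5 grid with two products (R01, R02)
grid = np.random.randint(0, 50, size=(5, 5, 2))
v = linearize_subgrid(grid, 3, 3, 3)
print(v.shape)  # (18,) -- a 3 x 3 x 2 subgrid flattened
```

For the example of Figure 2 (*i* = *j* = 3, *l* = 3, two products), this produces an 18-dimensional vector such as *Vt*(3, 3, 3, *Prod*).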

### *5.3. NowDeepN Approach*

The regression problem we are focusing on is the following: to predict a sequence of values for a set *Prod* of radar products at a given time moment *t* on a certain location (*i*,*j*) on the map, considering the values for the neighboring locations of (*i*,*j*) at time moment *t*−1. The *NowDeepN* approach consists of three main stages, which will be further detailed: *data collection and cleaning*, *training* (building the learning model) and *testing*. *NowDeepN* uses an ensemble of DNNs for learning to predict the values of the radar products from the set *Prod* based on their historical values. The ensemble consists of *np* DNNs (*np* = |*Prod*|), one DNN for each radar product. We started from the intuition that using one network for each product would be more effective than using only one network for predicting all *np* values, as the mapping learned by the model should be specific to each radar product. Thus, we consider that the effectiveness of the learning process is increased by using a DNN for each radar product; this is empirically sustained by the experimental results (Section 6).

#### 5.3.1. Data Collection and Cleaning

We mention that in the data set used in our case study, values for the *R* and *V* products are available for only six elevations (i.e., *R*01–*R*04 and *R*06–*R*07, *V*01–*V*04 and *V*06–*V*07). The other elevations delivered by the radar are missing since they are not regularly used in operational services, thus they are not stored in the same format as the rest of the elevations. The data gathered by the radar and exported as shown in Section 5.1 contains a special value that represents "No Data". This value is usually represented by −999, but we decided to replace it with 0, as in most cases this value refers to air particles with 0 reflectivity (i.e., no significant water droplets). "No data" may also represent air volumes which have returned no signal, for example if a sector with high reflectivity lies between the radar and the respective location. In this case, replacing it with 0 is also correct, since the entire region is obturated and the data is not relevant for the learning process [45]. The radar data is also prone to different types of errors, meteorological and technical, which are implicitly found in the output data matrix. Meteorological errors (e.g., the underestimation of a particle's reflectivity) are difficult to identify and eliminate, but some errors occurring during the data conversion have been identified for *V*. For instance, the product *V* should only contain values from −33 to 33, but we found *invalid* values outside the range [−33, 33], such as −100. From a meteorological point of view, those erroneous values correspond to radar uncertainties in evaluating the direction and/or the speed of the particle, and they are not taken into account in operative service since they are punctiform values and are irrelevant to the characteristics of a region.

We note that the cleaning process described below is applied to the *V* values at each degree of elevation (i.e., *V*01–*V*04 and *V*06–*V*07). Thus, when we use *V* in the following, we refer to the *V* value at a certain degree of elevation.

For reducing the noise that the invalid values of *V* represent, a *data cleaning* step is proposed. The underlying idea behind the cleaning step is to replace the invalid value of *V* at a certain point (*i*, *j*) with the weighted average of the valid *V* values from a neighborhood of length 13 surrounding the point. The weight associated with a certain neighbor of the point is inversely proportional to the Euclidean distance between the neighbor and the point, such that the closest neighbors' values have more importance in estimating the value of the point. The reason for this cleaning step is that, from a meteorological viewpoint, *V* determines the direction and speed of air volumes, thus indicating which neighbors are more relevant for the future value of a point. The length 13 surrounding the point represents about 5 km in the physical world, a distance over which the meteorological parameters commonly display only small gradients.

Let us consider that (*i*, *j*) is the point having an erroneous value for *V* (e.g., −100) and this value has to be replaced with an approximation of its real value. We denote by *V*(*x*, *y*) the value of the product *V* for the point (*x*,*y*) and by N*l*(*i*, *j*) the 2D data subgrid representing the neighborhood of length *l* surrounding (*i*, *j*). For instance, the neighborhood N3(3, 3) of length 3 surrounding the point (3, 3) from the data matrix from Figure 3 is depicted in Figure 4.


**Figure 3.** The sample 2D data grid. The values for the product *V* are displayed.


**Figure 4.** The 2D data subgrid representing N3(3, 3).

For a certain point (*x*, *y*) ∈ N*l*(*i*, *j*), (*x*, *y*) ≠ (*i*, *j*), we denote by

$$\text{sim}\_{ij}(x,y) = \frac{1}{\sqrt{(x-i)^2 + (y-j)^2}}$$

the "similarity" between (*i*,*j*) and its neighbor (*x*,*y*). Certainly, the data points closer to (*i*,*j*) have a higher similarity degree and their value is more relevant in the cleaning process. Thus, the value *V*(*i*, *j*) will be approximated by the weighted average of its valid neighbors, as shown in Formula (1)

$$V(i,j) = \sum\_{\substack{(x,y) \in \mathcal{N}\_l(i,j) \\ (x,y) \text{ valid}}} (w\_{ij}(x,y) \cdot V(x,y)),\tag{1}$$

where *wij*(*x*, *y*) represents the weight of the point (*x*, *y*) and is computed as shown in Formula (2) by normalizing the similarity values *simij*(*x*, *y*) such that the weights of the valid neighbors from N*l*(*i*, *j*) sum to 1. We note that this normalization assures that the approximated *V* values are valid ones, that is, ranging in the interval [−33, 33].

$$w\_{ij}(x, y) = \frac{\text{sim}\_{ij}(x, y)}{\sum\_{\substack{(x', y') \in \mathcal{N}\_l(i, j) \\ (x', y') \text{ valid}}} \text{sim}\_{ij}(x', y')}.\tag{2}$$

As previously mentioned, in the cleaning process we considered a length *l* = 13 for the neighborhood. However, if all data points from the neighborhood of length *l* (N*l*(*i*, *j*)) are invalid, we incrementally increase *l* until the neighboring area contains at least one valid point. In this case, Formula (1) is applied again using the new length *l* for estimating the value *V*(*i*, *j*).
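The cleaning procedure of Formulas (1) and (2), including the neighborhood enlargement, can be sketched as follows (a NumPy sketch; the function name `clean_value` is hypothetical, and the clamping of the neighborhood at the grid edges is our assumption, as the text does not specify the boundary handling):

```python
import numpy as np

VALID_RANGE = (-33, 33)  # valid V values, per the radar specification

def clean_value(V, i, j, l=13):
    """Replace the invalid V value at (i, j) (0-based indices) with the
    weighted average of the valid neighbors in an l x l neighborhood,
    with weights inversely proportional to the Euclidean distance
    (Formulas (1)-(2)). If no valid neighbor exists, the neighborhood
    is enlarged; assumes V contains at least one valid value."""
    rows, cols = V.shape
    while True:
        r = l // 2
        weights, values = [], []
        for x in range(max(0, i - r), min(rows, i + r + 1)):
            for y in range(max(0, j - r), min(cols, j + r + 1)):
                if (x, y) == (i, j):
                    continue
                v = V[x, y]
                if VALID_RANGE[0] <= v <= VALID_RANGE[1]:
                    sim = 1.0 / np.hypot(x - i, y - j)  # similarity sim_ij(x, y)
                    weights.append(sim)
                    values.append(v)
        if weights:
            w = np.asarray(weights) / np.sum(weights)  # normalize: weights sum to 1
            return float(np.dot(w, values))
        l += 2  # no valid neighbor found: enlarge the neighborhood

V = np.full((5, 5), 10.0)  # hypothetical V grid, all neighbors valid
V[2, 2] = -100.0           # one invalid value to be replaced
print(clean_value(V, 2, 2, l=3))  # 10.0
```

Since the weights sum to 1, the replacement value is a convex combination of valid neighbors and therefore always falls inside [−33, 33].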

After the data has been cleaned as previously shown, the data set is prepared for further training the *NowDeepN* regressor, using the representation described in Section 5.2. We denote, in the following, by *Prod* = {*p*1, *p*2, ... , *pnp*} the set of radar products we are using in our approach. The radar data set cleaned as previously described is split into *np* subsets *Dk*, 1 ≤ *k* ≤ *np*, a data set corresponding to each radar product. A DNN will afterwards be trained on each *Dk*, for learning to predict the value of the radar product *pk* at a time moment *t* at a certain geographical location, based on the radar products' values from the neighborhood of that location at time *t* − 1.

A training example from a data set *Dk*, 1 ≤ *k* ≤ *np* is in the form < *xk*, *yk* >, where:

- *xk* = *Vt*−1(*i*, *j*, *l*, *Prod*) is the linearized 3D data subgrid containing the radar products' values for the neighborhood of length *l* surrounding a location (*i*, *j*), at time *t* − 1;
- *yk* is the value of the radar product *pk* at the location (*i*, *j*), at time *t*.
Each data set *Dk* will contain, for each data point from the analyzed map, examples in the form < *xk*, *yk* > (as previously described). As a preprocessing step before training, the data sets *Dk* are normalized to [0, 1] using *min*-*max* normalization.
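The *min*-*max* normalization step can be sketched as follows (the function name and the handling of constant features are our choices; in practice the minima and maxima would be computed on the training split and reused for testing):

```python
import numpy as np

def min_max_normalize(X, lo=None, hi=None):
    """Scale the features of X to [0, 1]; lo/hi default to the
    per-feature minima and maxima of X itself."""
    lo = X.min(axis=0) if lo is None else lo
    hi = X.max(axis=0) if hi is None else hi
    span = np.where(hi > lo, hi - lo, 1.0)  # avoid division by zero
    return (X - lo) / span, lo, hi

X = np.array([[0.0, 10.0], [5.0, 20.0], [10.0, 30.0]])
Xn, lo, hi = min_max_normalize(X)
print(Xn[1])  # [0.5 0.5]
```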

#### 5.3.2. Building the *NowDeepN* Model

Using the data model proposed in Section 5.2 and previously described, we aim to build a supervised learning model *NowDeepN* consisting of an ensemble of DNNs expressing *np* functions (hypotheses) *hk*, 1 ≤ *k* ≤ *np*, such that *hk*(*xk*) ≈ *yk*.

One of the difficulties regarding the regression problem previously formulated is that the training data sets *Dk* built as shown in Section 5.3.1 are highly *imbalanced*. More specifically, there are a lot of training instances labeled with zero (i.e., *yk* = 0) corresponding to points on the map without specific weather events and a much smaller number of instances with a non-zero label (i.e., corresponding to a severe meteorological phenomenon). The imbalanced nature of the data may lead to a regressor which is biased to predict zero values, as the majority of the training examples used for building the regressor were zero-labeled. A number of *np* DNN regressors *Nk*, 1 ≤ *k* ≤ *np* will be trained on the data sets *Dk*, such that the model *Nk* will learn to provide estimates for the radar product *pk*.

Each of these DNN regressors outputs a single value which represents the prediction for the next time step for that regressor's product. The output value is given by a linear activation function, while the hidden neurons use the ReLU activation function [46]. The loss function used was the mean squared error and the optimizer was Adam [47]. For regularization we used one drop-out layer with the default Keras parameters.

#### 5.3.3. Testing

For assessing the performance of *NowDeepN*, a *cross-validation* testing methodology is applied on each of the data sets *Dk*. The data sets *Dk* are randomly split into 5 folds. Subsequently, 4 folds are used for training and the remaining fold for testing, and this is repeated for each fold (5 times).

For each training-testing split, two evaluation measures are computed: *root mean squared error* (RMSE) and *normalized root mean squared error* (NRMSE) [48]. The RMSE computes the square root of the average of the squared errors obtained for the testing instances. The NRMSE represents the normalized RMSE, obtained by dividing the RMSE value by the range of the output, and is usually expressed as a percentage. The regression-related literature indicates NRMSE as a good measure for estimating the predictive performance of a regressor. Lower values for RMSE and NRMSE (closer to zero) indicate better regressors. For a more precise evaluation of the results, the values for the evaluation measures (RMSE and NRMSE) are also computed for the non-zero labeled instances (RMSE*non*−*zero*, NRMSE*non*−*zero*).
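The two evaluation measures can be computed as follows (a NumPy sketch; the 0–70 dBZ range used in the example is the *R*01 range mentioned in the figure captions):

```python
import numpy as np

def rmse(y_true, y_pred):
    """Root mean squared error over the testing instances."""
    return float(np.sqrt(np.mean((np.asarray(y_true) - np.asarray(y_pred)) ** 2)))

def nrmse(y_true, y_pred, value_range):
    """RMSE normalized by the range of the output, as a percentage."""
    return 100.0 * rmse(y_true, y_pred) / value_range

y_true = np.array([0.0, 10.0, 20.0, 30.0])
y_pred = np.array([0.0, 12.0, 18.0, 30.0])
print(rmse(y_true, y_pred))         # ~1.414
print(nrmse(y_true, y_pred, 70.0))  # ~2.02 (% of the 0-70 dBZ range)
```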

The RMSE and NRMSE values are computed for each data point from the grid (geographical area) and then averaged over all grid points. As multiple experiments (training-testing data splits) are performed, the values for the evaluation measures are averaged over the 5 runs and a 95% *confidence interval* (CI) [49] of the mean value is computed.
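The 95% CI of the mean over the 5 runs can be computed as below (a sketch assuming a Student *t* interval, the usual choice for such small samples; the per-fold values are hypothetical):

```python
import numpy as np

def mean_ci95(values):
    """Mean and 95% confidence half-width for a sample of 5 runs,
    using the Student t critical value t_{0.975, 4} ~= 2.776."""
    v = np.asarray(values, dtype=float)
    n = v.size
    assert n == 5, "the hardcoded t critical value below is for n = 5"
    mean = v.mean()
    half = 2.776 * v.std(ddof=1) / np.sqrt(n)
    return mean, half

rmse_runs = [4.1, 4.3, 4.0, 4.2, 4.4]  # hypothetical per-fold RMSE values
mean, half = mean_ci95(rmse_runs)
print(f"{mean:.3f} ± {half:.3f}")
```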

### **6. Experimental Results**

Experiments were performed by applying the *NowDeepN* model on real data sets provided by the Romanian National Meteorological Administration, collected over the central Transylvania region, using the methodology introduced in Section 5.

### *6.1. Data Set*

This study uses data provided by the WSR-98D weather radar [6] located in Bobohalma, Romania and stored in the NEXRAD Level III format, as described in Section 5.1. The day used as a case study is 5 June 2017, a day with moderate atmospheric instability manifested through thunderstorms accompanied by heavy rain and medium-size hail. In our study we selected an area from the central Transylvania region (parts of Mureș, Cluj, Alba and Sibiu counties) representing a grid having the geographical coordinates (46.076 N, 46.725 N, 23.540 E and 25.064 E). In the chosen geographical area, there were two distinct episodes with intense meteorological events on 5 June 2017: the first one between approximately 09:00 and 11:00 UTC, and the second one between approximately 12:00 and 17:00 UTC, with the most severe events taking place between 14:00 and 15:00 UTC. Concerning these phenomena, the National Meteorological Administration issued five severe weather warnings, code yellow.

The data grid provided by the radar for the selected geographical area at a given time moment is represented as a matrix. The radar provides one data matrix for each radar product. As stated in Section 5.1, the radar data is split into multiple time stamps, each time stamp representing data gathered by the radar every 6 min (the radar takes 6 min to gather the data for the area). The radar data used in our case study was recorded between 00:04:04 UTC and 23:54:02 UTC.

In the current study, we use only 13 products (i.e., *np* = 13): the base *reflectivity* (*R*) of particles on six elevations (*R*01–*R*04, *R*06–*R*07), the *velocity* (*V*) on six elevations (*V*01–*V*04, *V*06–*V*07) and the estimated quantity of water (*VIL*) contained by a one square meter column of air. Thus, the set of considered radar products is *Prod* = {*R*01, *R*02, *R*03, *R*04, *R*06, *R*07, *V*01, *V*02, *V*03, *V*04, *V*06, *V*07, *VIL*}. Accordingly, the *NowDeepN* ensemble of regressors will predict 13 values, corresponding to the products previously enumerated. Our study uses only the *R*, *V* and *VIL* products, as they are the ones mostly used by meteorologists for weather nowcasting.

As mentioned in Section 5.2, we consider each point in the grid an instance. For each point we consider its neighbours within a certain radius. We decided to select a value of 13 for the length *l* of the neighborhood surrounding each point (see Section 5.3.1). More exactly, the 2D data subgrid representing the neighborhood of a point is a 13 by 13 matrix. The reason for choosing 13 as the dimensionality of the neighborhood is that it represents about 5 km in the physical world, which, from a meteorological view, is a common distance over which only small gradients are observed. Thus, for each instance we consider a matrix of 13 by 13 points, each of which has 13 products. Therefore, each instance has 13 × 13 × 13 = 2197 attributes. For each timestamp we have a grid of size 400 × 312. We only used the instances for which we could get the entire neighbourhood (i.e., where the neighbourhood matrix would not exceed the limits of the grid), thus obtaining 116,400 instances per timestamp. The day used as a case study contains 231 timestamps, thus our data set consists of 26,888,400 instances. The data used in the experiments are publicly available at http://www.cs.ubbcluj.ro/~mihai.andrei/datasets/nowdeepn/.
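The instance counts above follow directly from the grid and neighborhood dimensions:

```python
l = 13                 # neighborhood length (about 5 km)
n_products = 13        # R01-R04, R06-R07, V01-V04, V06-V07, VIL
rows, cols = 400, 312  # grid size per timestamp
timestamps = 231

attributes = l * l * n_products
# only points whose full 13 x 13 neighborhood fits inside the grid
usable = (rows - (l - 1)) * (cols - (l - 1))
total = usable * timestamps

print(attributes)  # 2197
print(usable)      # 116400
print(total)       # 26888400
```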

### *6.2. Data Analysis*

In order to estimate the impact of the data cleaning step, we analyzed the data set before and after cleaning. For each of the data products *V*01, *V*02, *V*03, *V*04, *V*06 and *V*07 (which were possibly cleaned) and each time stamp *t*, 1 ≤ *t* ≤ 231, we computed the average values of the radar products over all the cells from the analyzed grid. Additionally, the average of the mean values for all *V* products and all the cells from the grid was calculated for each time stamp.

Figure 5 comparatively depicts the variation of the mean *V* value with respect to each time stamp (ranging from 1 to 231) for three cases: (1) *before the cleaning* step; (2) before cleaning but *ignoring the invalid values*; and (3) *after the cleaning*. A step of 10 was selected on the OX axis (i.e., one hour). Graphical representations similar to the ones from Figure 5 were created for *V*01–*V*07, as well. The time series plots for *V*01 and *V*06 are illustrated in Figures 6 and 7, respectively.

Analyzing the plots from Figures 5–7 and comparing the evolution of values *before cleaning* (red coloured), before cleaning but ignoring *invalid values* (blue coloured) and *after cleaning* (green coloured), we observe the following. At lower degrees of elevation (see the plot for *V*01 from Figure 6) there are many more invalid values than at higher degrees of elevation (see the graph for *V*06 from Figure 7). Moreover, from a meteorological viewpoint, higher degrees of elevation are related to higher altitudes, implying a less chaotic air circulation and leading to more precise radar soundings. The impact of the data cleaning step is noticeable from the time series plots, as for *V*01 the graph before data cleaning significantly differs from the graph after cleaning. The two graphs have very different shapes, and this suggests that the noise remaining in the data after the cleaning step is considerably smaller than the noise existing in the data before replacing the invalid *V* values. We note that a similar situation has been observed for *V*02–*V*04 as well.

Moreover, the time series plot before data cleaning but ignoring the invalid values (blue coloured) resembles the plot after the invalid values were cleaned (green coloured). Figure 8 depicts a zoomed-in version of the graphs from the upper side of Figure 6 (the time series before cleaning but ignoring the invalid values and the time series after cleaning). At higher degrees of elevation (*V*06), there is no significant difference between the evolution of the *V* values before and after cleaning (the red coloured and green coloured plots from Figure 7). The shapes of the two plots are very similar for *V*06, and this was also observed for *V*07.

**Figure 5.** Time series plot for mean *V* values: ignoring invalid values, before and after cleaning.

**Figure 6.** Time series plot for average *V*01 values: ignoring invalid values, before and after cleaning.

**Figure 7.** Time series plots for average *V*06 values: ignoring invalid values, before and after cleaning.

**Figure 8.** Time series plot for average *V*01 values: ignoring invalid values and after cleaning.

All the previous observations lead us to the hypothesis that the cleaning step would impact the overall performance of *NowDeepN*, and this should be visible at least at lower degrees of elevation for *V*.

### *6.3. Results*

This section presents the experimental results obtained by applying the *NowDeepN* approach on the data set described in Section 6.1. For the DNNs used in our experiments, the implementation from the Keras deep learning API [50] on top of the Tensorflow neural networks framework was employed. The code is publicly available at Reference [51]. Given that our data is quite high-dimensional, as mentioned in Section 6.1, we needed a relatively complex neural network. Each of the DNN regressors in the ensemble contains 12 hidden layers, with the following numbers of neurons: one layer with 200 neurons, one layer with 2000 neurons, 5 layers with 500 neurons and 5 layers with 100 neurons. These networks were trained for 30 epochs using 1024 instances in a training batch.
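A minimal Keras sketch of one regressor of the ensemble, based on the description above (the ordering of the hidden layers and the placement and rate of the drop-out layer are our assumptions, not taken from the paper):

```python
from tensorflow import keras
from tensorflow.keras import layers

def build_regressor(n_inputs=2197):
    """One DNN regressor of the ensemble: 12 ReLU hidden layers,
    one drop-out layer (placement and rate assumed), a single linear
    output, MSE loss and the Adam optimizer."""
    hidden = [200, 2000] + [500] * 5 + [100] * 5  # 12 hidden layers
    model = keras.Sequential()
    model.add(keras.Input(shape=(n_inputs,)))
    for units in hidden:
        model.add(layers.Dense(units, activation="relu"))
    model.add(layers.Dropout(0.5))  # rate 0.5 is an assumption
    model.add(layers.Dense(1, activation="linear"))
    model.compile(optimizer="adam", loss="mse")
    return model

# model = build_regressor()
# model.fit(X_train, y_train, epochs=30, batch_size=1024)
```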

As stated at the beginning of the paper, we aim to answer research question RQ1 by assessing the ability of *NowDeepN* to predict the values of the radar products at a given moment in a certain geographical location from the values of its neighboring locations at previous time moments. In addition, we intend to analyze how correlated our computational findings are with the meteorological evidence. Table 1 depicts the obtained results together with their 95% CIs. The columns of the table illustrate the evaluation measures computed for all 13 products (second column), the average values computed for all six *R* products (third column), the average values computed for all six *V* products (fourth column), as well as for *VIL* (fifth column). In order to allow an easier interpretation of the results from a meteorological perspective, we also illustrate in Table 1 the *mean of absolute errors* (the average of the absolute errors obtained for the testing instances) for all instances (MAE), as well as only for the non-zero labeled instances (MAE*non*−*zero*).


**Table 1.** Experimental results obtained using *NowDeepN*. 95% CI are depicted.

From a meteorological point of view, the MAE for both all and non-zero instances is satisfactory, meaning that the predicted value is on the same level or on a neighbouring level of the product value scale.

The following two figures depict a comparison between the real data gathered by the radar (Figure 9) and the prediction made by *NowDeepN* (Figure 10). The timestamp of the comparison was chosen so that there was significant meteorological activity (14:37 UTC is in the middle of the meteorological event described in Section 6.1). The figures depict product *R*01, chosen because it is one of the most relevant radar products for nowcasting, as it provides information on the precipitation and usually contains more non-zero data than the upper levels. Figure 9 represents the real data collected by the radar, while Figure 10 represents the predicted data, given the real data at 14:31:15 UTC. The figures represent the (real or predicted) values of *R*01 over a geographical area. Each pixel in the images represents a geographical location of roughly 1 km². The values on the OX and OY axes represent the coordinates of a pixel inside the image, relative to the (0, 0) origin in the upper-left corner. While the values along the axes do not represent actual longitude and latitude values, the OY axis runs along the latitudes and the OX axis runs along the longitudes (the (0, 0) origin of the images being the most north-western point of the geographical area the image represents).

**Figure 9.** Real data for product *R*01 (reflectivity at the lowest elevation angle, measured in dBZ with values ranging from 0 to 70) at 14:37:22 UTC.

**Figure 10.** Predicted data for product *R*01 (reflectivity at the lowest elevation angle, measured in dBZ with values ranging from 0 to 70) at 14:37:22 UTC.

In Table 1, an average NRMSE of less than 4% is reported for the *R* products, which entails a close resemblance between the predicted data and the real data; this resemblance can be observed in the comparison between Figures 9 and 10. The predicted pattern closely follows the real pattern, closely matching the shape and intensity of *R*01 in the studied area. There are, however, visible limitations of the prediction. A clear smoothing effect is present, most visible around areas with scattered non-zero points: the prediction tends to smooth out the area between them, resulting in much larger areas of small non-zero values (colored in dark green) than in the real data. This effect can also be seen in areas with higher-valued points, for example, around the point at roughly (170, 230), where the real data present a ragged shape of higher intensity points, but the predicted data present the area as a smooth, much less ragged shape, the space between the higher intensity points being predicted with a similar intensity. Another example can be seen at (110, 120), where, in the real data, there is a small shape with extremely high values; in the predicted data, while the shape is largely retained, the intensity is greatly diminished. Still, these kinds of effects are on a small scale relative to the entire area represented by the figure. Smoothing may be responsible for the higher NRMSE obtained for non-zero values, as areas with higher values are more affected by smoothing than areas with zero values.

### **7. Discussion**

In this section we analyse the performance of the *NowDeepN* approach in order to answer research questions RQ2 and RQ3. First, we assess the impact of the data cleaning step on the performance of *NowDeepN* (Section 7.1) and estimate the relevance of the manually engineered features used in the training process (Section 7.2). Then, we continue in Section 7.3 by comparing our results with similar results obtained in the literature.

### *7.1. Impact of the Data Cleaning Step*

As previously shown in Section 6.1, the training data set contains instances with errors, particularly at lower degrees of elevation. Obviously, the noisy training data may affect the performance of the learning task. In this regard, a data cleaning step motivated by the meteorological perspective has been introduced in Section 5.3.1 for replacing the invalid values in the data set, which were identified mostly for the product *V* at lower degrees of elevation (*V*01–*V*04). The analysis we performed on the data after it was cleaned (Section 6.2) led us to the hypothesis that the cleaning step would impact the overall performance of *NowDeepN*, and this should be visible at least at lower degrees of elevation for *V*. Intuitively, as the data cleaning is correlated with the meteorological evidence, we expect a better performance of *NowDeepN*, particularly at lower degrees of elevation.

In order to empirically validate the hypothesis that the cleaning step improves the predictive performance of *NowDeepN*, we have evaluated the model trained on the uncleaned data set, using the same methodology introduced in Section 5.

Table 2 depicts the results obtained by applying *NowDeepN* on the uncleaned data. Comparing the results with those obtained on the cleaned data (Table 1), we observe an improvement in the predictive performance of *NowDeepN* achieved on the cleaned data. The last column of Table 2 illustrates the improvement obtained using the cleaning step, computed for each evaluation measure *E* (i.e., RMSE, NRMSE, RMSE*non*−*zero*, NRMSE*non*−*zero*) for all 13 radar products. The improvement is computed as $\frac{E\_{uncleaned} - E\_{cleaned}}{E\_{uncleaned}}$.

Figure 11 depicts a very similar situation to Figure 10, the only difference being that the image represents values for *R*01 predicted by *NowDeepN* trained on *uncleaned* data. Similar to Figures 9 and 10, the image in Figure 11 represents a geographical area, with each pixel representing a geographical location of roughly 1 km². Again, the values on the OX and OY axes represent the coordinates of pixels inside the image relative to the (0, 0) origin in the upper-left corner, while the OX axis runs along the longitudes and the OY axis runs along the latitudes, with the (0, 0) origin being the most north-western point in the geographical area represented by the image. The most erroneous values are on the *V* products, yet these errors can still greatly affect the other products, such as *R*01, presented in the figure. While in the predicted data the shapes and intensities are largely retained as in the real data, some significant anomalies can be observed. For example, around the point at (140, 250) in the predicted data, an area of higher-valued points can be seen that is completely absent in the real data, with no means to attribute it to the smoothing effect. Also, in the upper right corner, at around (10, 380), the real data show a significant shape with quite high values in its interior that is simply removed in the predicted data, although the surrounding shapes are much better represented. There are other such anomalies, which cannot be explained by the smoothing effect, that appear in the data predicted by *NowDeepN* trained on uncleaned data. None of these anomalies appear in the data predicted by *NowDeepN* trained on cleaned data. This suggests that there were some erroneous values of *V* in those areas that also affected how *NowDeepN* predicted the other products, and thus, cleaning the data yields much better results.

**Figure 11.** Predicted data, using the model trained on uncleaned data, for product *R*01 (reflectivity at the lowest elevation angle, measured in dBZ with values ranging from 0 to 70) at 14:37:22 UTC.

### *7.2. Relevance of the Used Features*

We aim to further analyze the *NowDeepN* approach by determining the relevance of the features used in the training process. More precisely, our goal is to examine if the radar products' values from a neighborhood of a certain location at time *t* are suitable for predicting the products' values at time *t* + 1, for the same location. The analysis from this section is conducted for answering question RQ3.

For determining the significance of the features, we are comparing the results of *NowDeepN* using the original set of features with those obtained by applying *NowDeepN* after the prior application of a feature extraction step. Two feature extractors will be applied on the original set of features, for reducing the dimensionality of the input data.


**Table 2.** Experimental results obtained using *NowDeepN* on the uncleaned data. 95% CI are used for the results. The last column illustrates the improvement achieved applying *NowDeepN* on the cleaned data, considering all 13 radar products.

	- 2. The PCA algorithm is applied for a linear dimensionality reduction of the input data into a 250 dimensional space. For PCA we have used the existing scikit-learn Python implementation of the algorithm using 250 principal components and the default values for the other parameters [52].
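The PCA step can be sketched with scikit-learn as follows (the input matrix below is a random placeholder for the 2197-attribute instances; the real experiments use the data set from Section 6.1):

```python
import numpy as np
from sklearn.decomposition import PCA

rng = np.random.default_rng(0)
X = rng.normal(size=(500, 2197))  # placeholder for the 2197-attribute instances

# linear dimensionality reduction into a 250 dimensional space,
# with default values for the other PCA parameters
pca = PCA(n_components=250)
X_reduced = pca.fit_transform(X)
print(X_reduced.shape)  # (500, 250)
```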

Table 3 depicts the results obtained by applying *NowDeepN* with a previous feature extraction step (AE/PCA). Comparing the results with those obtained without applying a feature extraction step (Table 1), we observe an improvement in the predictive performance of *NowDeepN* achieved without a prior feature extraction step. The last column of the table illustrates the improvement obtained by *NowDeepN* on the original data (computed as the difference between the measure on the original and on the reduced data, divided by the value obtained on the reduced data).

The results depicted in the last column of Table 3 empirically demonstrate that the features (i.e., the radar products' values from a neighborhood of a certain location at time *t*) used in training *NowDeepN* are relevant for predicting the radar products' values at time *t* + 1, for the same location. The relevance of the features is validated by the fact that a dimensionality reduction technique (AE/PCA) applied prior to learning using *NowDeepN* does not improve the learning performance. The last column of Table 3 also reveals that AEs preserve the characteristics of the data better than PCA when reducing its dimensionality, which is to be expected, as AEs perform a non-linear mapping whilst PCA performs a linear one.

### *7.3. Comparison to Related Work*

The literature review from Section 4 revealed various approaches developed in the nowcasting literature using machine learning methods.


**Table 3.** Experimental results obtained using *NowDeepN* with a prior feature extraction step. 95% CIs are used for the results. The last column illustrates the improvement achieved by applying *NowDeepN* on the original data, considering all 13 radar products.

We start the comparison between *NowDeepN* and related work by comparing our model to a simple baseline, *linear regression* (LR). For an exact comparison, the data model used for *NowDeepN* (Section 5.2) was used for the LR model as well. By applying LR on the dataset described in Section 6.1, an overall RMSE for the non-zero values (RMSE*non*−*zero*) of 6.094 was obtained. Surprisingly, *NowDeepN* achieves only a 3% improvement on our data, with a higher improvement on *V* (about 12%). From a meteorological point of view, the minor improvement in reflectivity nowcasting (compared to the LR model) is probably explained by the fact that the model predicts the values of the radar products for a single time step, and on the particular day used in this study the convective structures detected by the radar displayed a relatively slow evolution due to the light to moderate wind; thus, no rapid modifications between two radar scans can be observed in terms of *R* value or location. Bigger improvements can be observed in the predictions of the *V* product values, since the evolution of this product has a more stochastic character. Future work dealing with prediction over more time steps should display greater improvements compared to the benchmark LR model.

Even if there are numerous machine learning-based methods developed for nowcasting purposes, few methods focus on nowcasting the values of radar base products, such as reflectivity. Most of the related work focuses on the precipitation nowcasting problem. We found four approaches with a goal similar to ours, namely predicting the future values of the radar products based on their historical values. The approaches from the literature which are the most similar to ours are those proposed by Yan Ji [39], Han et al. [31,41] and Yan et al. [42]. Even if the data sets used in the previously mentioned papers and the evaluation methodology differ from ours, we computed the evaluation measures reported in the literature, trying to reproduce as accurately as possible the experiments from the related work.

The work of Ji [39] focuses on predicting only the reflectivity values, which are further used for precipitation prediction. Experiments are performed only on radar data collected for time stamps when storms occurred, disregarding the periods with normal weather (i.e., for which the *R* value is 0). Besides the minimum and the maximum values for the RMSE, the *hit rate* (HR) is reported in Reference [39] as the percentage of instances for which the absolute error (between the predicted and the real value) is less than or equal to 5. For an accurate comparison with the work of Yan Ji [39], our *NowDeepN* model was trained only on the instances labeled with non-zero values for *R* and was tested only on non-zero instances. We also note that the evaluation from Reference [39] is performed only once, without cross-validation. Han et al. [31,41] focused in their works on predicting the *R* values (using SVM and CNN classifiers), more specifically on whether they exceed 35 dBZ. Considering that the positive class is the one for which the *R* values are larger than 35, the authors used three evaluation measures: (1) the *critical success index* (CSI), computed as

$CSI = \frac{TP}{TP+FN+FP}$; (2) the *probability of detection* (POD), $POD = \frac{TP}{TP+FN}$; and (3) the *false alarm rate* (FAR), $FAR = \frac{FP}{FP+TP}$. The MAR-CNN model proposed for precipitation nowcasting by Yan et al. [42] used radar reflectivity images on three elevation levels together with other numerical features and provided an RMSE of 7.90 for the predicted reflectivity values.
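The three measures follow directly from the confusion counts; a minimal sketch (the counts below are hypothetical, chosen only for illustration):

```python
def csi(tp, fn, fp):
    """Critical success index: TP / (TP + FN + FP)."""
    return tp / (tp + fn + fp)

def pod(tp, fn):
    """Probability of detection: TP / (TP + FN)."""
    return tp / (tp + fn)

def far(tp, fp):
    """False alarm rate, as defined in the text: FP / (FP + TP)."""
    return fp / (fp + tp)

# Hypothetical confusion counts for the "R > 35 dBZ" positive class.
tp, fn, fp = 80, 10, 20
print(csi(tp, fn, fp))  # 80/110 ≈ 0.727
print(pod(tp, fn))      # 80/90  ≈ 0.889
print(far(tp, fp))      # 20/100 = 0.2
```

Note that the FAR denominator used here (FP + TP) follows the definition quoted above, not the FP + TN variant sometimes seen elsewhere.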

Table 4 summarizes the comparison between *NowDeepN* and the related work. The best values for the evaluation measures are highlighted.


**Table 4.** Comparison to the work of Yan Ji [39], Han et al. [31,41] and Yan et al. [42].

Table 4 reveals that *NowDeepN* outperforms the approaches proposed by Han et al. [31,41] in terms of the CSI, POD and FAR evaluation measures. Overall, in 71% of the cases (5 out of 7 comparisons), the comparison is favorable to *NowDeepN*. Our proposal is outperformed only by the work of Yan Ji [39], which reported a better HR and a maximum RMSE slightly better than ours. This difference may occur due to the following: (1) the data sets used (both for training and testing) are different and have particularities due to the geographic area (country) on which they were collected (i.e., China [39] and Romania); (2) the testing data set from Reference [39] contains data collected on a relatively small area, which may lead to a biased evaluation and an overestimated performance.

As previously mentioned, a lot of work has been carried out in the literature on precipitation nowcasting. Tran and Song [40] tackled the precipitation nowcasting problem from a computer vision perspective, by applying certain thresholds on the reflectivity values (5/20/40 dBZ). In order to measure the performance of *NowDeepN* (in terms of CSI, POD and FAR) for the aforementioned thresholds, we transformed the predicted values into predicted classes, denoting each predicted value lower than or equal to a threshold as belonging to the negative class and each predicted value higher than the threshold as belonging to the positive class. We then applied the same transformation to the ground truth and computed the measures, following this process for each of the three thresholds (5/20/40 dBZ). Table 5 illustrates the comparison between *NowDeepN* and the model proposed by Tran and Song [40]. We mention that Tran and Song [40] provide ranges of values for each evaluation measure. Since an exact comparison cannot be provided (the datasets used for evaluation and the input data models are different), our comparison relies only on the magnitude of the CSI, POD and FAR evaluation metrics. The best values obtained for each reflectivity threshold and evaluation metric are highlighted.
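The thresholding transformation described above can be sketched as follows (the reflectivity values are hypothetical; only the procedure, not the data, comes from the paper):

```python
import numpy as np

def threshold_counts(y_true, y_pred, threshold):
    """Map continuous reflectivity values to binary classes at the given
    threshold (positive = value strictly above the threshold) and return
    the confusion counts (TP, FP, FN) needed for CSI, POD and FAR."""
    t = np.asarray(y_true) > threshold
    p = np.asarray(y_pred) > threshold
    tp = int(np.sum(p & t))
    fp = int(np.sum(p & ~t))
    fn = int(np.sum(~p & t))
    return tp, fp, fn

# Hypothetical reflectivity values (dBZ), evaluated at the thresholds from the text.
y_true = [0, 3, 12, 25, 41, 47, 18, 6]
y_pred = [1, 4, 10, 28, 39, 45, 22, 2]
for thr in (5, 20, 40):
    print(thr, threshold_counts(y_true, y_pred, thr))
```

Applying the same transformation to both predictions and ground truth keeps the comparison symmetric, as the text requires.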

The comparative results from Table 5 highlight that *NowDeepN* obtained better results than the model proposed by Tran and Song [40] in 77.7% of the cases (7 out of 9 comparisons). We note the good performance of *NowDeepN* at higher values of the reflectivity threshold, which indicates the ability of our model to detect moderate and heavy precipitation as well as medium and large hail.


**Table 5.** Comparison to the work of Tran and Song [40].

### **8. Conclusions and Future Work**

In this paper we introduced *NowDeepN*, a supervised learning-based regression model which uses an ensemble of *deep artificial neural networks* for predicting the values of meteorological products at a certain time moment based on their historical values. *NowDeepN* was intended as a proof of concept for the feasibility of learning to approximate a function between past values of the radar products extracted from radar observations and their future values.

The *NowDeepN* model consists of an ensemble of DNNs for radar products' values nowcasting. In this ensemble, one DNN model learns to approximate the value of each radar product at a given time moment and a certain geographical location from the radar products' values in the neighborhood of that location at previous time moments.

Experiments were conducted on real radar data provided by the Romanian National Meteorological Administration and collected in the Central Transylvania region. The obtained results provided empirical evidence that in both normal and severe weather conditions the values of a radar product at a given moment in a certain location are predictable from the values at the neighboring locations at previous time moments. This evidence is essential for further using convolutional neural network models for automatically extracting from radar data features which would be relevant for predicting the radar products' values at a certain time moment based on their historical values. A data cleaning step was introduced for correcting the erroneous input radar data, and its impact on increasing the predictive performance of the *NowDeepN* model was highlighted. In addition, the relevance of the features considered in the supervised learning task was empirically proven. More specifically, the experiments showed that the radar products' values from the neighboring area of a certain geographical location *l* at time *t*−1 are useful for predicting the radar products' values at location *l* at time *t*.

The experimental results highlighted that our *NowDeepN* model performs well particularly for high values of the reflectivity threshold, which indicates its ability to detect moderate and heavy precipitation as well as medium and large hail. While from a meteorological point of view the performance of *NowDeepN* in predicting radar reflectivity values one time step ahead is satisfactory, further development of the model for multiple time steps (e.g., 5 or 10 time steps, covering 30 or 60 min) is needed in order to assess its performance against techniques currently employed in weather nowcasting.

The experimental evaluation of *NowDeepN* will be further extended by enlarging the data set used for training the model. As future work we aim to investigate convolutional neural network models [53] as well as supervised classifiers based on *relational association rule mining* [54,55] for detecting relationships between the meteorological products' values which may distinguish between normal and severe meteorological phenomena. In addition, we will analyze the possibility of extending the features used in the learning process by combining radar data with other features (e.g., geographic and anthropic features).

**Author Contributions:** Conceptualization, G.C. and A.M.; methodology, G.C. and A.M.; software, G.C. and A.M.; validation, G.C., A.M. and E.M.; formal analysis, G.C. and A.M.; investigation, G.C., A.M. and E.M.; resources, G.C., A.M. and E.M.; data curation, G.C., A.M. and E.M.; writing—original draft preparation, G.C.; writing—review and editing, G.C., A.M. and E.M.; visualization, G.C., A.M. and E.M.; funding acquisition, G.C., A.M. and E.M. All authors have read and agreed to the published version of the manuscript.

**Funding:** The research leading to these results has received funding from the NO Grants 2014–2021, under Project contract no. 26/2020.

**Data Availability Statement:** Restrictions apply to the availability of these data. The data were obtained from the Romanian National Meteorological Administration and are available at http://www.meteoromania.ro/ with the permission of the Romanian National Meteorological Administration.

**Acknowledgments:** The research leading to these results has received funding from the NO Grants 2014–2021, under Project contract no. 26/2020. The authors acknowledge the assistance received from the National Meteorological Administration from Romania, for providing the meteorological data sets used in the experiments. The authors also thank the editor and the reviewers for their useful suggestions and comments that helped to improve the paper and the presentation.

**Conflicts of Interest:** The authors declare no conflict of interest. The funders had no role in the design of the study; in the collection, analyses, or interpretation of data; in the writing of the manuscript, or in the decision to publish the results.

#### **References**


## *Article* **Comparison of Instance Selection and Construction Methods with Various Classifiers**

**Marcin Blachnik <sup>1,</sup>\* and Mirosław Kordos <sup>2</sup>**


Received: 7 May 2020; Accepted: 1 June 2020; Published: 5 June 2020

**Abstract:** Instance selection and construction methods were originally designed to improve the performance of the k-nearest neighbors classifier by increasing its speed and improving the classification accuracy. These goals were achieved by eliminating redundant and noisy samples, thus reducing the size of the training set. In this paper, the performance of instance selection methods is investigated in terms of classification accuracy and reduction of training set size. The classification accuracy of the following classifiers is evaluated: decision trees, random forest, Naive Bayes, linear model, support vector machine and k-nearest neighbors. The obtained results indicate that for most of the classifiers compressing the training set affects prediction performance, and only a small group of instance selection methods can be recommended as a general-purpose preprocessing step. These are learning vector quantization-based algorithms, along with *Drop2* and *Drop3*. Other methods are less efficient or provide a low compression ratio.

**Keywords:** machine learning; classification; preprocessing; instance selection

### **1. Introduction**

Classification is one of the basic machine learning problems, with many practical applications in industry and other fields. The typical process of constructing a classifier consists of data collection, data preprocessing, training and optimizing the prediction models, and finally applying the best of the evaluated models. Although this scheme is straightforward, we face two types of problems. The first is that we increasingly often have to construct classifiers with limited resources, and the second is that we want to interpret and understand the data and the constructed model easily.

The first group of restrictions is mostly related to time and memory constraints, as machine learning algorithms are now often trained on mobile devices or microcomputers like the Raspberry Pi and similar devices. There are basically three approaches to overcome these restrictions:


In this paper we focus on the second approach: instead of redesigning the classification algorithm or sending the data to the cloud, we analyze how data filters (in other words, training set reduction methods) influence classification performance within the data processing pipeline depicted in Figure 1.

**Figure 1.** The pipeline of prediction model construction with data filtering.

The data filter has two goals. First, it can improve classifier performance by eliminating noisy samples from the training data, thus allowing higher classification accuracy to be achieved; the second goal is training set compression. Reducing the training set size speeds up the classifier construction process, but it also speeds up decision making when the classifier is already trained [1]. The speed-up in classifier training is rather obvious when the size of the input data is smaller, whereas the speed-up in the prediction phase results from the smaller number of support vectors in the support vector machine, shallower trees (earlier stopping of tree construction) and the lower number of reference vectors in *k*NN. Moreover, it speeds up model selection and optimization. Here, the gain should be multiplied by the number of evaluated classifiers and their hyperparameters, because the filtering stage is applied once and then the classifier selection and optimization are carried out.

Training set compression can also be used to improve model interpretability, where so-called prototype-based rules [2,3] can be applied. These rules are also based on limiting the size of the data, and in this case, even if the original dataset is small, further limiting the dataset size is still beneficial.

The main purpose of the paper and the research objective is the analysis of both aspects of data filtering, that is, the influence on the prediction accuracy of various classifiers and the influence on training set size reduction.

In this study we evaluate a set of the 20 most popular instance selection and construction methods used as data filters, together with 7 popular classifiers, on 40 datasets in terms of classification performance and training set compression. The evaluated data filters are: condensed nearest neighbor (*CNN*), edited nearest neighbor (*ENN*), repeated-edited nearest neighbor (*RENN*), *All-kNN*, instance based learning version 2 (*IB2*), relative neighbor graph editing (*RNGE*), the *Drop* family, iterative case filtering (*ICF*), modified selective subset selection (*MSS*), hit-miss network editing (*HMN-E*), hit-miss network iterative editing (*HMN-EI*), class conditional instance selection (*CCIS*), as well as learning vector quantization versions 1 and 2.1 (*LVQ1* and *LVQ2.1*, respectively), generalized learning vector quantization (*GLVQ*) and *k-Means*. These data filters were selected for the evaluation because they are the most frequently used instance selection methods. The datasets after filtering are used to train classifiers such as k-nearest neighbors (*k*NN), support vector machine (SVM), decision tree-based methods, a linear model and simple Naive Bayes. The hyperparameters of each of these classifiers are optimized with a grid search approach in order to achieve the highest possible prediction accuracy on the compressed data. Moreover, the obtained results are compared with the results of simple stratified random sampling, which defines an acceptance threshold below which particular methods are not beneficial for given classifiers.

The article is structured as follows: in the next section an overview of instance selection methods is provided, together with a literature overview and the research gap; Section 3 describes the experimental setup; and Section 4 presents the results. Finally, Section 5 summarizes the paper with general conclusions.

### **2. Research Background**

### *2.1. Problem Statement*

From the statistical point of view, reduction of the training set size will not affect the prediction accuracy of the final classifier as long as the conditional probability *P*(*c*|**x**) of predicting class *c* for a given vector **x** remains unchanged when estimated from the original training set **T** and from the set of prototypes **P** obtained from the data filtering process.

In the literature, one of the popular multidimensional probability estimation methods is based on the nearest neighbor approach [4]. Similarly, many data filtering methods were developed for *k*NN in order to select a suitable subset of the training set. These methods are called instance selection methods and instance construction (or prototype generation) methods, and they were mostly designed to overcome weaknesses of the *k*NN classifier. In instance selection, usually the performance of the *k*NN or even the 1-NN classifier is used to identify those training samples which are important for the classification. These are mostly border samples close to the decision boundary. Instances which represent larger groups of instances from the same class, as well as noise samples that reduce the performance of *k*NN, are usually filtered out. On the other hand, instance construction methods tend to find the optimal positions of the samples stored by 1-NN, so they do not need to represent samples from the original training set; they are usually new samples. An example effect of applying instance selection methods to a training set is presented in Figure 2.

**Figure 2.** Example effect of artificial 2D training set compression using *Drop3* instance selection methods. (**a**) Original training set with additional noise (**b**) Training set after compression.

The idea of applying instance selection and prototype generation methods as data filters is not new, and it is often considered a standard preprocessing step. In particular, in Reference [5] some of the instance selection methods evaluated in our study were considered the most effective preprocessing methods.

For example, we have applied instance selection in the optimization of metallurgical processes for data size limitation and rule extraction [6,7], but instance selection methods have also been applied in haptic modeling [8] as well as in automatic machine learning [9]. However, we could not find in the literature any comprehensive study analyzing the influence of instance selection methods on various classifiers. Most authors, when presenting new algorithms, report only the performance of 1-NN, *k*NN with a fixed *k* parameter, and sometimes other classifiers, but usually also with fixed hyperparameters. Such a comparison can be considered unfair, because the training set is changed during data filtering, so different parameters are required by the classifiers. Unfortunately, such a comparison requires much larger computational time, especially when using grid search with an internal cross-validation procedure, so the process is usually simplified and only classifiers with fixed parameters are used. Only in References [10,11] is some broader comparison available, but the experiments were conducted on only a few (6 and 8) datasets. To fill that gap, we provide a detailed analysis of the influence of data filtering based on instance selection and construction methods applied to 40 datasets.

### *2.2. Instance Selection and Construction Methods Overview*

One of the most important properties of data filtering methods is the relation between the instances in the original training set **T** and those in the selected prototype set **P**. If **P** ⊂ **T**, then the methods are called instance selection algorithms, because the prototypes **P** are selected directly from the training set **T**. This property does not hold for prototype construction (also called prototype generation) methods. In this case, the elements of **P** are new vectors which can constitute completely new instances that have never appeared in **T**.

This property is important when considering the comprehensibility of the selected samples. In the case of instance selection methods, the instances can be mapped to real objects, while in the case of instance construction methods, such a direct mapping is not possible. This is especially important when working with prototype-based rules [2,12] or other interpretable models.

Perhaps the best overview of instance selection methods in the literature can be found in Reference [13], where the authors provide a taxonomy of over 70 instance selection methods and empirically compare half of them in terms of compression and the prediction performance of the 1-NN classifier. The same group of authors, in Reference [14], performs a similar analysis for prototype generation methods, where 32 methods are discussed and 18 of them are empirically compared, again only in application to the 1-NN classifier.

The taxonomy of data filtering methods can be presented along the following aspects:

	- **incremental**, when the given method starts from an empty set **P** and iteratively adds new samples, as in *CNN* [15] or *IB2* [16];
	- **decremental**, when the given method starts from **P** = **T** and then samples from **P** are iteratively removed, as in *HMN-E* and *HMN-EI* [17], *MSS* [18], *Drop1*, *Drop2*, *Drop3* and *Drop5* [19], and *RENN* [20];
	- **batch**, when the instances are removed at once after an analysis of the input training set. Examples of such methods are *ENN* [21], *All-kNN* [20], *RNGE* [22], *ICF* [23] and *CCIS* [24];
	- **fixed**, when a fixed number of prototypes is given as a hyperparameter of the method. This group includes the *LVQ* family (*LVQ1*, *LVQ2.1*, *GLVQ*) [25,26], *k-Means* and random sampling;
	- **condensation**, when the algorithm tries to remove samples which are located far from the decision boundary, as in the case of *IB2*, *CNN*, *Drop1*, *Drop2* and *MSS*;
	- **edition**, when the algorithm is designed for noise removal, such as *ENN*, *RENN*, *All-kNN* and *HMN-E*;
	- **hybrid**, when the algorithm performs both steps, condensation and edition. Usually these methods start with noise removal and then perform condensation. This group includes *Drop3*, *Drop5*, *ICF* and *HMN-EI*;
	- **filters**, where the method uses an internal heuristic independent of the final classifier;
	- **wrappers**, when an external, dedicated classifier is used to identify important samples.

Whether an algorithm belongs to the filter or wrapper category depends on the final prediction model applied after data filtering. If the instance selection or construction method is followed by a 1-NN or *k*NN classifier, these methods can be considered wrappers, because internally all of them use a kind of nearest neighbor-based approach to decide whether an instance should be selected or rejected. On the other hand, they can also be considered filters when the data filter takes the entire training set as input and returns a selected subset which is then used to train any classifier, not only *k*NN. There are implementations which work as wrappers and allow all kinds of classifiers to be used, such as in Reference [27], where instead of *k*NN any other classifier can be used (in particular, an MLP neural network was used). The drawback of wrappers is the increase in computational complexity. In this article we only consider the standard instance selection methods without any generalization.
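To make the edition category concrete, here is a minimal sketch of Wilson's *ENN* rule (our own illustration, not the implementation used in the experiments): a sample is dropped when its class disagrees with the majority vote of its *k* nearest other samples.

```python
import numpy as np
from collections import Counter

def enn_filter(X, y, k=3):
    """Edited nearest neighbor sketch: keep only samples whose class matches
    the majority vote of their k nearest neighbors among the other samples."""
    X, y = np.asarray(X, dtype=float), np.asarray(y)
    keep = []
    for i in range(len(X)):
        d = np.linalg.norm(X - X[i], axis=1)
        d[i] = np.inf                        # exclude the sample itself
        nn = np.argsort(d)[:k]               # indices of the k nearest others
        vote = Counter(y[nn]).most_common(1)[0][0]
        if vote == y[i]:
            keep.append(i)
    return X[keep], y[keep]

# Illustrative data: two clusters plus one mislabeled (noisy) point.
X = np.array([[0, 0], [0.1, 0], [0, 0.1],
              [5, 5], [5.1, 5], [5, 5.1], [0.05, 0.05]])
y = np.array([0, 0, 0, 1, 1, 1, 1])          # the last point is noise
Xf, yf = enn_filter(X, y, k=3)
print(len(Xf))  # 6: the noisy sample is removed
```

This brute-force O(n²) distance computation is only for illustration; practical implementations use spatial indexes.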

### **3. Experimental Design**

There are several factors which determine the applicability of a given algorithm as a general-purpose training set filter. Among the most important are the compression level and the prediction accuracy of the final classifier. The compression is defined as:

$$comp = 1 - \frac{||\mathbf{P}||}{||\mathbf{T}||},\tag{1}$$

so that a higher compression value indicates that more samples were rejected and the resultant set **P** is smaller, while lower values (close to 0) indicate that the output training set is larger. The second property is the prediction accuracy of the classifier trained on **P**. This value depends on the applied classifier, so that for one classifier a given set **P** may result in high accuracy, while for another the accuracy can be worse. Here, the simple accuracy measure was evaluated:

$$acc = \frac{\text{\#correctly classified samples}}{\text{\#all evaluated samples}}. \tag{2}$$
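Both measures are straightforward to compute; a small sketch with hypothetical counts (the numbers are illustrative only):

```python
def compression(n_train, n_prototypes):
    """Equation (1): comp = 1 - ||P|| / ||T||."""
    return 1 - n_prototypes / n_train

def accuracy(n_correct, n_total):
    """Equation (2): acc = correctly classified / all evaluated samples."""
    return n_correct / n_total

# Hypothetical example: a filter keeps 120 prototypes out of 1000 training
# samples, and the classifier trained on them labels 850 of 1000 test
# samples correctly.
print(compression(1000, 120))  # 0.88
print(accuracy(850, 1000))     # 0.85
```

Note that a high compression is only useful when the accuracy of the classifier trained on **P** does not drop below the random-sampling baseline discussed later.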

In order to determine the applicability of instance selection and construction methods as universal training set filters, we designed experiments which mimic typical use cases of training set filtering. The scheme of the data processing pipeline is presented in Figure 3.

**Figure 3.** The pipeline of data processing used in the experiments.

It starts with data loading and attribute normalization; then the 10-fold cross-validation procedure is executed, which wraps the data filtering stage (our instance selection or construction method) followed by classifier training and a hyperparameter optimization procedure. Finally, the trained classifier is applied to the test set. During the process execution, prediction accuracy and compression were recorded.

In the experiments the most commonly used classifiers were evaluated. These are: the basic classifiers for which the evaluated data filters were designed, such as 1-NN and *k*NN; simple classifiers like Naive Bayes and a linear model (GLM); kernel methods such as SVM with a Gaussian kernel; and finally the decision tree-based methods, including C4.5 and Random Forest. Many of these methods require careful parameter selection, such as the value of *k* in *k*NN, *C* and *γ* in SVM, or the number of trees in Random Forest. All of the evaluated parameters are presented in Table 1. It is important to note that each of the applied data filters was evaluated independently for each classifier, because a particular filter may be beneficial for one classifier and unfavorable for another.
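A hedged sketch of such a grid search with scikit-learn (the dataset and parameter grids below are illustrative, not the exact grids from Table 1):

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import GridSearchCV
from sklearn.neighbors import KNeighborsClassifier
from sklearn.svm import SVC

X, y = load_iris(return_X_y=True)

# Grid search over k for kNN, with 10-fold cross-validation as in the paper.
knn_search = GridSearchCV(KNeighborsClassifier(),
                          {"n_neighbors": [1, 3, 5, 7, 9]}, cv=10)
knn_search.fit(X, y)
print("best kNN params:", knn_search.best_params_)

# The analogous search over C and gamma for an RBF-kernel SVM.
svm_search = GridSearchCV(SVC(kernel="rbf"),
                          {"C": [0.1, 1, 10], "gamma": [0.01, 0.1, 1]}, cv=10)
svm_search.fit(X, y)
print("best SVM params:", svm_search.best_params_)
```

In the paper's pipeline this search runs on the filtered training set inside each outer cross-validation fold, so the filter is applied once per fold and the tuning reuses the compressed data.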


**Table 1.** Parameter settings of the evaluated classifiers.

The entire group of instance selection and construction methods is very broad. As indicated in Section 2.2, some authors distinguish over 70 instance selection methods and over 32 prototype construction methods [13,14]. From these groups we selected the most popular ones, which can be found in many research papers as reference methods [10,11,28–30]. These are *CNN*, *ENN*, *RENN*, *All-kNN*, *IB2*, *RNGE*, *Drop1*, *Drop2*, *Drop3*, *Drop5*, *ICF*, *MSS*, *HMN-E*, *HMN-EI* and *CCIS* from the group of instance selection methods; from the group of prototype generation methods we selected three algorithms from the *LVQ* family, namely *LVQ1*, *LVQ2.1* and *GLVQ*. In the experiments we also evaluated the *k-Means* clustering algorithm, which is most often used to reduce the size of the training set [31]. The *k-Means* algorithm was applied independently to each class label, and then the obtained cluster centers were used as prototypes with the appropriate class labels [32]. All of the prototype generation methods belong to the group of fixed methods, so they require the compression to be set manually. For that purpose the experiments were carried out for two different initial sets of prototypes: 10% of the training samples, selected randomly, which corresponds to 90% compression, and 30% of the training samples, which corresponds to 70% compression. The 90% compression is the lower bound of the compression obtained by most of the instance selection methods.
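The per-class *k-Means* prototype generation described above can be sketched as follows (a minimal illustration assuming the cluster-center-as-prototype scheme from the text; the data and compression level are hypothetical):

```python
import numpy as np
from sklearn.cluster import KMeans

def kmeans_prototypes(X, y, compression=0.9):
    """Run k-Means independently within each class and use the cluster
    centers as labeled prototypes; the number of clusters per class is set
    so that the overall compression matches the requested level."""
    prototypes, labels = [], []
    for c in np.unique(y):
        Xc = X[y == c]
        k = max(1, int(round(len(Xc) * (1 - compression))))
        km = KMeans(n_clusters=k, n_init=10, random_state=0).fit(Xc)
        prototypes.append(km.cluster_centers_)
        labels.append(np.full(k, c))
    return np.vstack(prototypes), np.concatenate(labels)

rng = np.random.default_rng(0)
X = rng.normal(size=(200, 4))
y = rng.integers(0, 2, size=200)
P, y_p = kmeans_prototypes(X, y, compression=0.9)
print(P.shape)  # roughly 10% of 200 prototypes, 4 attributes each
```

Running *k-Means* per class guarantees that every prototype carries an unambiguous class label, which a single clustering over the mixed data would not.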

All evaluated methods were also compared with random stratified sampling (*Rnd*), which is the simplest solution that can be used as a data filter. As with the prototype construction methods, the experiments with *Rnd* were conducted for compressions of 70% and 90%, which correspond to *Rnd(0.3)* and *Rnd(0.1)* (the numbers represent the percentage of samples that remain). The accuracy obtained for *Rnd* constitutes the lower bound which allows beneficial data filters to be distinguished from the weak ones that are worse than simple random sampling.

The experiments were carried out on 40 datasets obtained from the Keel repository [33]. A list of the datasets is presented in Table 2. All the calculations were conducted using RapidMiner software with the Information Selection extension developed by the authors [34]. The extension is available at the RapidMiner Marketplace, and the most recent version is also available on GitHub (https://github.com/mblachnik/infoSel). Some of the evaluated algorithms, like *HMN-EI* and *CCIS*, were taken from the Keel framework [35] and integrated with the Information Selection extension.


**Table 2.** Datasets used in the experiments. The s/a/c denotes the number of samples, attributes and classes.

### **4. Results and Analysis**

Since simple averaging has limited interpretability, we used both the average performance and the average rank to assess the quality of the evaluated methods. To construct the ranking, for each dataset and each classifier the results obtained for the particular data filters were ranked from the best to the worst in terms of classification accuracy and compression. The highest rank (equal to 26, which is the number of evaluated algorithms) was given to the best filter method for a particular dataset and the lowest rank (1) was assigned to the worst method (rank with ties). Then the ranks were averaged over the datasets to indicate the final performance. Such a comparison does not reflect how much one method differs from another in terms of a given quality measure, but ranking, unlike averaging performances, is insensitive to the data distribution, where measures like accuracy can range from 40% on one dataset up to 99% on another. On the other hand, ranking does not provide information on how much the methods differ, so these two quality measures complement each other and should be considered together: the ranking answers which method was more often better, and the mean performance indicates how much better a given method was than its competitor. Moreover, the threshold obtained by random sampling should be applied simultaneously to both values, and when either of them is below the threshold, the given method should be rejected as useless.
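The average-rank computation above can be sketched as follows (the accuracy matrix is hypothetical and much smaller than the paper's 26 methods × 40 datasets):

```python
import numpy as np
from scipy.stats import rankdata

# Hypothetical accuracies of 4 filter methods (columns) on 3 datasets (rows).
acc = np.array([[0.90, 0.85, 0.88, 0.70],
                [0.75, 0.74, 0.76, 0.60],
                [0.95, 0.95, 0.93, 0.80]])

# Rank within each dataset: the highest accuracy gets the highest rank,
# ties receive the average of the tied ranks; then average over datasets.
ranks = rankdata(acc, axis=1)
avg_rank = ranks.mean(axis=0)
print(avg_rank)  # [3.5  2.5  3.   1. ]
```

The tie in the third row (two methods at 0.95) shows the "rank with ties" behavior: both receive the average of ranks 3 and 4.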

The obtained results, including both average ranks and average performances, are presented in Table 3. Moreover, the Wilcoxon signed-rank test [36] was used to check whether the results obtained by the classifier without any data filter differ significantly from the results obtained when a given data filter was used to clean up the dataset. The calculations were conducted at a significance level of *α* = 0.1. The data filters which did not lead to a significant decrease in prediction accuracy are marked with =. In a few cases the data filter increased the accuracy of a classifier; if the increase was statistically significant, we marked the results with a + sign. In this case the significance was measured using a one-tailed Wilcoxon signed-rank test.
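The significance check can be sketched with `scipy.stats.wilcoxon` at the stated α = 0.1; the per-dataset accuracies below are synthetic stand-ins, not the values from Table 3:

```python
import numpy as np
from scipy.stats import wilcoxon

rng = np.random.default_rng(0)
acc_reference = rng.uniform(0.70, 0.95, size=40)               # no data filter
acc_filtered = acc_reference + rng.normal(0.0, 0.01, size=40)  # with a filter

# Two-sided test: "=" marks filters whose accuracy does not differ
# significantly from the reference method.
_, p_two = wilcoxon(acc_reference, acc_filtered)
mark = "=" if p_two >= 0.1 else ""

# One-tailed test: "+" marks filters that significantly *increase* accuracy
# (H1: the reference accuracy is stochastically less than the filtered one).
_, p_less = wilcoxon(acc_reference, acc_filtered, alternative="less")
if p_less < 0.1:
    mark = "+"
```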



To increase readability, the results representing ranks are also presented graphically, independently for each classifier. In the figures, the dotted lines represent the performances obtained by random sampling, so that if any filter method appears within the space defined by the dotted lines, it is dominated by simple random sampling (*Rnd*).

In the following subsections, the term "reference method" is used to describe the algorithm without a data filter, that is, the classifier applied directly to the training data.

### *4.1. 1-NN*

The results obtained for 1-NN are visualized in Figure 4. They indicate that *GLVQ* and *LVQ2.1* significantly outperform the other methods, especially the reference solution without any data filter. From the group of instance selection methods the best ones are the noise filters *ENN* and *HMN-E*, and from the group of condensation methods *Drop2* and *Drop3* dominate. It is also noticeable that all of the data filters appear above the base rates defined by random sampling.

**Figure 4.** Results obtained for 1-NN classifier. (**a**) Average performance ranks. (**b**) Average performance.

### *4.2. kNN*

In the case of *k*NN the best results are similarly obtained for *GLVQ* and *LVQ2.1* (see Figure 5), but they do not differ significantly from the results obtained by *k*NN with an optimally tuned *k* parameter. All other filters appear to decrease classification accuracy. Also noticeable is the fact that *IB2* now appears to be dominated by random sampling, as do *All-kNN* and *Drop1* in terms of average accuracy.

**Figure 5.** Results obtained for *k*NN classifier. (**a**) Average performance ranks (**b**) Average performance.

### *4.3. Naive Bayes*

The results obtained for Naive Bayes are shown in Figure 6. There is no significant difference between the accuracy obtained for the two random sampling methods (one with 90% compression and the other with 70%); the difference between them in terms of ranks is less than 1. For Naive Bayes the highest accuracy is obtained by *ENN* and *RENN*, for which the comparison to the reference method is statistically significant. *All-kNN* also ranks very high, but the Wilcoxon test does not indicate a statistically significant difference. This is reasonable, because noise filters remove the noisy samples which affect the probability distributions estimated by the Naive Bayes classifier. Here the *LVQ* family, especially the *GLVQ*(0.3) algorithm, displays performance similar to the noise filters but with compression reaching 70%; unfortunately, the difference to the reference method is not significant. Other prototype construction methods, such as the remaining *LVQ* algorithms and the *k-Means* clustering method, likewise do not show significant differences. Among the evaluated methods, almost all instance selection algorithms are dominated by random sampling, so all these methods can be considered unhelpful.

**Figure 6.** Results obtained for Naive Bayes classifier. (**a**) Average performance ranks. (**b**) Average performance.

#### *4.4. GLM*

The linear model without instance selection provided the highest accuracy, as shown in Figure 7, so by applying any data filter we may expect a drop in accuracy. The highest accuracy using filters is obtained for *ENN* and *HMN-E*, and among the higher-compression methods with *LVQ1*, *GLVQ* and *Drop3*, but all these results are statistically significantly worse than the reference. It is worth mentioning that the linear model can be implemented efficiently, so data filters are not strictly required; they only extend the computation time. For GLM nine methods are dominated by random sampling: *Drop1*, *Drop2*, *Drop5*, *ICF*, *CCIS*, *CNN*, *MSS*, *HMN-EI* and *All-kNN*, and many others, such as *GLVQ*, *LVQ2.1* and *k-Means*, lie close to the border, so in general it is not recommended to perform any data filtering for the GLM model.

**Figure 7.** Results obtained for GLM classifier. (**a**) Average performance ranks. (**b**) Average performance.

### *4.5. C4.5 Decision Tree*

In the case of the C4.5 decision tree (see Figure 8), it could be expected that any dataset reduction may result in a drop of accuracy. This is due to the quality of the estimated statistics which are determined when selecting the split nodes. As a result, the majority of data filters are dominated by random sampling. Only the noise filters achieve accuracy comparable to that obtained by the reference method; moreover, the results for *ENN* and *All-kNN* are not statistically significantly different from it. This is because noise filters have very low compression, but also because regularizing the decision boundary by eliminating the noisy samples can positively influence the estimated quality measures of the decision tree nodes. Among the condensation methods only *Drop3*, *Drop2* and *ICF* achieve results not dominated by random sampling.

**Figure 8.** Results obtained for C4.5 classifier. (**a**) Average performance ranks. (**b**) Average performance.

### *4.6. Random Forest*

Random Forest is a classifier which is also based on the decision tree, but thanks to the properties of collective decision making it can overcome some of its weaknesses. As shown in Figure 9, only *ENN* achieves accuracy comparable to that obtained with the entire training set. However, almost all data filtering methods, especially the instance selection methods, are better than random sampling (except *HMN-EI* and *All-kNN*, which lie on the border), while almost all prototype construction methods (except *LVQ1*) lie on the bounds defined by random sampling. Note that here all methods are statistically significantly different from the reference method, that is, worse than the reference solution.

**Figure 9.** Results obtained for Random Forest classifier. (**a**) Average performance ranks. (**b**) Average performance.

### *4.7. SVM*

The last of the evaluated classifiers is SVM, which is one of the most robust classifiers (similarly to Random Forest). The results presented in Figure 10 indicate that all the examined instance selection methods lead to a decrease in prediction accuracy. For compression levels up to 70% the top data filters are *ENN*, *HMN-E*, *RNGE*, *CNN*, *k-Means* and *LVQ1*, which on average share a similar accuracy rank. A further increase in compression leads to a significant drop in accuracy rank, so the best methods with compression equal to 90%, like *k-Means* and *LVQ1*, have an accuracy rank almost 8 points lower. For SVM only *RENN*, *All-kNN* and *HMN-EI* (which all belong to the noise filters) are outperformed by random sampling. The reason for this is SVM's high tolerance to noise in the data, which can be controlled by the *C* parameter.

**Figure 10.** Results obtained for SVM classifier. (**a**) Average performance ranks. (**b**) Average performance.

### **5. Conclusions**

In this article we investigated the performance of popular classical instance selection and prototype generation methods in terms of the obtained compression of the dataset and their influence on the performance of various classifiers. To summarize the obtained results, we averaged them for each data filtering method over all classifiers. This allowed us to compare all the evaluated data filters. The results are presented in Figure 11. The red line in the plots indicates the methods which belong to the Pareto front, that is, those which are not dominated by the other methods. The following methods belong to the front: *ENN*, *HMN-E*, *GLVQ*(0.3), *LVQ2.1*(0.3), *HMN-EI*, *Drop2*, *Drop3* and *LVQ2.1*(0.1). Some other methods, like *k-Means*(0.3) and *LVQ1*, can also be considered top performers because they lie very close to the Pareto front. Among the top methods, two algorithms provide compression below 50%: *ENN* and *HMN-E*. These methods should be considered only when compression is not the primary need.
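The Pareto front used here collects the methods that are not dominated in the (average rank, compression) plane, both objectives being maximized. A minimal sketch with hypothetical method scores:

```python
def pareto_front(points):
    """Indices of points not dominated by any other point; point j dominates
    point i if j is >= i in both objectives and > i in at least one."""
    front = []
    for i, (a1, c1) in enumerate(points):
        dominated = any(
            a2 >= a1 and c2 >= c1 and (a2 > a1 or c2 > c1)
            for j, (a2, c2) in enumerate(points) if j != i
        )
        if not dominated:
            front.append(i)
    return front

# Hypothetical (average accuracy rank, compression %) per method.
methods = {"ENN": (22.0, 25.0), "GLVQ": (20.0, 70.0),
           "Drop3": (17.0, 80.0), "CNN": (15.0, 60.0)}
names = list(methods)
front = [names[i] for i in pareto_front(list(methods.values()))]
# CNN is dominated by GLVQ (worse rank and worse compression).
```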

**Figure 11.** Average results over all evaluated classifiers. Red line represents Pareto front. (**a**) Average performance ranks. (**b**) Average performance.

In theory, as indicated at the beginning of this article, the goal of data filtering methods is to keep the estimated conditional probabilities unchanged before and after data filtering, so that *P*(*c*|**x**)<sub>*T*</sub> = *P*(*c*|**x**)<sub>*P*</sub>, where *T* denotes the original training set and *P* the filtered set. In reality, however, each of these classifiers has its own probability estimation technique, so the one used by decision trees, based on instance frequency calculation within a bin, does not match that of the nearest neighbor classifier. Moreover, SVM and Random Forest are more robust than *k*NN, so they can deal with noisy samples better than the data filters, which internally use *k*NN to assess training instances.

The obtained results indicate that the size of the dataset matters. In general, applying any of the examined data filters results in a decrease in accuracy, and a large drop in prediction performance can be observed between 70% and 90% compression, so 70% compression can be considered a threshold beyond which we should not compress the dataset. We also observed, however, that for bigger datasets the instance selection methods proved more efficient, allowing for higher compression. Interestingly, on average the prediction performance slowly drops even for the noise filters. The exceptions are the 1-NN and Naive Bayes classifiers, where some of the tested instance filtering methods (in particular *ENN* and the *LVQ* family) increased the accuracy. For *k*NN with tuned *k* the accuracy may remain unchanged, so the benefit is the execution time of the prediction phase, which requires fewer distance calculations to make a decision.

The observed phenomenon can be interpreted by taking into account that all of the tested instance selection methods were developed to work with *k*NN. As indicated in Section 2, instance selection methods can be considered wrappers for the *k*NN classifier, while for the remaining classifiers they work as filters. Some authors design specific algorithms for particular classifiers. Examples are the works of Kawulok and Nalepa, who developed memetic [37] and genetic [38] algorithms for SVM; de Mello and others also developed an algorithm dedicated to SVM [39]. In [7] we developed generalized *CNN* and *ENN* algorithms which work as wrappers, in particular with an MLP network. But these methods are strictly designed for given classifiers and cannot be generalized, so they were not considered in this research.

In the literature, some authors use instance selection methods for balancing the data distribution in imbalanced classification problems [40]. In this scenario, instance selection methods are applied to down-sample the majority class while the minority classes remain unchanged, but this aspect was not considered in our study. The problem of applying instance selection methods to other tasks such as regression [41], multi-label learning [42] or stream mining [43] was also not covered and requires further studies. Another open question is a deeper analysis of why particular evaluated methods are better than their competitors. This requires independent analysis on a smaller number of methods and remains for future investigation.

In summary, when considering the use of initial data filtering for training set reduction, one should consider *GLVQ* and *LVQ2.1*; in the case where it is necessary to use the original training samples rather than newly constructed prototypes, one should consider *Drop2* and *Drop3* from the set of evaluated methods. These methods provide significant dataset size reduction and in general achieve higher prediction accuracy than the other methods with similar compression, but by applying them we should expect a drop in prediction accuracy for classifiers other than *k*NN.

**Author Contributions:** Conceptualization, M.B.; methodology, M.B.; software, M.B.; validation, M.B., M.Kordos.; formal analysis, M.B.; investigation, M.B. and M.K.; resources, M.B.; data curation, M.B.; writing—original draft preparation, M.B.; writing—review and editing, M.B., M.K.; visualization, M.B. and M.K.; supervision, M.B.; project administration, M.B.; funding acquisition, M.B. All authors have read and agreed to the published version of the manuscript.

**Funding:** The APC was funded by the Silesian University of Technology, grant BK-204/2020/RM4.

**Conflicts of Interest:** The authors declare no conflict of interest.

### **Abbreviations**

The following abbreviations are used in this manuscript:



### **References**


© 2020 by the authors. Licensee MDPI, Basel, Switzerland. This article is an open access article distributed under the terms and conditions of the Creative Commons Attribution (CC BY) license (http://creativecommons.org/licenses/by/4.0/).

## *Article* **Selection of Support Vector Candidates Using Relative Support Distance for Sustainability in Large-Scale Support Vector Machines**

### **Minho Ryu 1,2 and Kichun Lee 2,\***


**\*** Correspondence: skylee@hanyang.ac.kr; Tel.: +82-02-2220-0478

Received: 8 September 2020; Accepted: 29 September 2020; Published: 6 October 2020

**Abstract:** Support vector machines (SVMs) are a well-known classifier due to their superior classification performance. They are defined by a hyperplane, which separates two classes with the largest margin. Computing the hyperplane, however, requires solving a quadratic programming problem. The storage cost of a quadratic programming problem grows with the square of the number of training sample points, and the time complexity is in general proportional to the cube of that number. Thus, it is worth studying how to reduce the training time of SVMs without compromising performance, to prepare for sustainability in large-scale SVM problems. In this paper, we propose a novel data reduction method that shortens training time by combining decision trees and relative support distance. We apply a new concept, relative support distance, to select good support vector candidates in each partition generated by the decision trees. The selected support vector candidates improve the training speed for large-scale SVM problems. In experiments, we demonstrate that our approach significantly reduces the training time while maintaining good classification performance in comparison with existing approaches.

**Keywords:** support vector machine; decision tree; large-scale dataset; relative support distance; support vector candidates

### **1. Introduction**

Support vector machines (SVMs) [1] are a powerful machine learning algorithm developed for classification problems, which works by recognizing patterns via kernel tricks [2]. Because of their high performance and great generalization ability compared with other classification methods, SVMs are widely used in bioinformatics, text and image recognition, and finance, to name a few areas. Basically, the method finds a linear boundary (hyperplane) that represents the largest margin between two classes (labels) in the input space [3–6]. It can be applied not only to linear separation but also to nonlinear separation using kernel functions, which map the input space to a high-dimensional space called the feature space, where the optimal separating hyperplane is determined. The hyperplane in the feature space, which achieves a better separation of the training data, translates to a nonlinear boundary in the original space [7,8]. The kernel trick is used to associate the kernel function with the mapping function, bringing forth a nonlinear separation in the input space.

Due to the growing speed of data acquisition in various domains and the continued popularity of SVMs, large-scale SVM problems frequently arise: human detection using histograms of oriented gradients, large-scale image classification, disease classification using mass spectra, and so forth. Even though SVMs show superior classification performance, their computing time and storage requirements increase dramatically with the number of instances, which is a major obstacle [9,10]. As the goal of SVMs is to find the optimal separating hyperplane that maximizes the margin between two classes, they must solve a quadratic programming problem. In practice, the time complexity of the training phase of the SVM method is at least *O*(*n*<sup>2</sup>), where *n* is the number of data samples, depending on the kernel function [11]. Several approaches have been applied to improve the training speed of SVMs, among them sequential minimal optimization (SMO) [12], SVM-light [13], the simple support vector machine (SSVM) [14] and the library of support vector machines (LibSVM) [15]. Basically, they break the problem into a series of small problems that can be easily solved, reducing the required memory size.

Additionally, data reduction or selection methods have been introduced for large-scale SVM problems. Reduced support vector machines (RSVMs) are a random sampling method that, being quite simple, uses a small portion of the large dataset [16]. However, it needs to be applied several times, and unimportant observations are sampled with equal probability. The method presented by Collobert et al. efficiently parallelizes sub-problems, fitting very large SVM problems [17]. It uses cascades of SVMs in which the data are split into subsets to be optimized separately with multiple SVMs instead of analyzing the whole dataset. A method based on the selection of candidate vectors (CVS) was presented using relative pair-wise Euclidean distances in the input space to find the candidate vectors in advance [18]. Because only the selected samples are used in the training phase, it trains quickly. However, its classification performance is relatively worse than that of the conventional SVM, and the need for selecting good candidate vectors arises.

Besides, for large-scale SVM problems, joint approaches that combine SVMs with other machine learning methods have emerged. Many evolutionary algorithms have been proposed to select training data for SVMs [19–22]. Although they have shown promising results, these methods need to be executed multiple times to decide proper parameters and training data, which is computationally expensive. Decision tree methods have also been commonly proposed to reduce the training data, because their training time is proportional to *O*(*np*<sup>2</sup>), where *p* represents the number of discrete input variables [23], and is thus faster than traditional SVMs. The decision tree method recursively decomposes the input dataset into binary subsets through independent variables when the splitting condition is met. In supervised learning, decision trees, bringing forth random forests, are among the most popular models because they are easy to interpret and computationally inexpensive. Indeed, taking advantage of decision trees, several studies combining SVMs with decision trees have been proposed for large-scale SVM problems. Fu Chang et al. [24] presented a method that uses a binary tree to decompose an input data space into several regions and trains an SVM classifier on each of the decomposed regions. Another method, using decision trees and Fisher's linear discriminant, was also proposed for large-scale SVM problems, in which Fisher's linear discriminant is applied to detect 'good' data samples near the support vectors [25]. Cervantes et al. [26] also utilized a decision tree to select candidate support vectors, using the support vectors annotated by an SVM trained on a small portion of the training data. These approaches, however, are limited in that they cannot properly handle regions that have nonlinear relationships.

The ultimate aim in dealing with large-scale SVM problems is to reduce the training time and memory consumption of SVMs without compromising performance. For this goal, it is worth finding good support vector candidates as a data-reduction method. Thus, in this paper we present a method, based on decision trees, that finds support vector candidates better than previous methods. We determine the decision hyperplane using support vector candidates chosen from the training dataset. In this approach, we introduce a new concept, relative support distance, to effectively find candidates using decision trees in consideration of the nonlinear relationships between local observations and labels. Decision tree learning decomposes the input space and helps find subspaces of the data where the majority class labels are opposite to each other. Relative support distance measures the degree to which an observation is likely to be a support vector, using a virtual hyperplane that bisects the two centroids of the two classes and the nonlinear relationship between the hyperplane and each of the two centroids.

This paper is organized as follows. Section 2 provides an overview of the SVMs and decision trees that are exploited in our algorithm. In Section 3, we introduce the proposed method of selecting support vector candidates using relative support distance measures. Then, in Section 4, we provide the results of experiments comparing the performance of the proposed method with that of some existing methods. Lastly, in Section 5, we conclude this paper with future research directions.

### **2. Preliminaries**

In this section, we briefly summarize the concepts of support vector machines and decision trees. Building on these concepts, we then introduce relative support distance, a measure of how likely a training point is to be a support vector.

### *2.1. Support Vector Machines*

Support vector machines (SVMs) [1] are generally used for binary classification. Given *n* pairs of instances with input vectors {*x*<sub>1</sub>, *x*<sub>2</sub>, ... , *x<sub>n</sub>*} and response variables *y*<sub>1</sub>, *y*<sub>2</sub>, ... , *y<sub>n</sub>*, where *x<sub>i</sub>* ∈ R*<sup>p</sup>* and *y<sub>i</sub>* ∈ {−1, 1}, SVMs present a decision function in a hyperplane that optimally separates the two classes:

$$y = \operatorname{sign}(w^{T}x + b) \tag{1}$$

where *w* is a weight vector and *b* is a bias term. The margin is the distance between the hyperplane and the training data nearest the hyperplane. The distance from an observation *x* to the hyperplane is given by |*w<sup>T</sup>x* + *b*|/‖*w*‖. To find the hyperplane that maximizes the margin, we solve the problem by transforming it into its dual problem, introducing Lagrange multipliers. Namely, in soft-margin SVMs with penalty parameter *C*, we find *w* by the following optimization problem:

$$\begin{aligned} \max_{\alpha} \quad & \sum_{i}^{n} \alpha_{i} - \frac{1}{2} \sum_{i}^{n} \sum_{j}^{n} \alpha_{i} \alpha_{j} y_{i} y_{j} K(x_{i}, x_{j}), \\ \text{subject to} \quad & \sum_{i}^{n} \alpha_{i} y_{i} = 0, \\ & 0 \le \alpha_{i} \le C/n, \ i = 1, \dots, n, \end{aligned} \tag{2}$$

where *C* > 0 and α*<sub>i</sub>*, *i* = 1, ... , *n*, are the dual variables corresponding to *x<sub>i</sub>*; all the *x<sub>i</sub>* corresponding to nonzero α*<sub>i</sub>* are called support vectors. By numerically solving problem (2) for α*<sub>i</sub>*, we obtain α*<sub>i</sub>*<sup>∗</sup> and compute *w*<sup>∗</sup> = ∑*<sub>i</sub>* α*<sub>i</sub>*<sup>∗</sup>*y<sub>i</sub>x<sub>i</sub>* and *b*<sup>∗</sup> = *y<sub>i</sub>* − *w*<sup>∗*T*</sup>*x<sub>i</sub>* for 0 < α*<sub>i</sub>*<sup>∗</sup> < *C*/*n*. The kernel function *K*(*x<sub>i</sub>*, *x<sub>j</sub>*) is the inner product of the mapping function: *K*(*x<sub>i</sub>*, *x<sub>j</sub>*) = φ(*x<sub>i</sub>*)<sup>*T*</sup>φ(*x<sub>j</sub>*). The mapping function φ(*x*) maps the input vectors to high-dimensional feature spaces. Well-known kernel functions are polynomial kernels, tangent kernels, and radial basis kernels. In this research, we chose the radial basis kernel function (RBF) with a free parameter γ, denoted as

$$K(x_i, x_j) = \exp\left(-\gamma \|x_i - x_j\|^2\right) \tag{3}$$

Notice that the radial basis kernel, possessing the mapping function φ(*x*) with an infinite number of dimensions [27], is flexible and the most widely chosen.
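The role of the dual variables in (2) can be illustrated with an off-the-shelf solver. The sketch below uses scikit-learn's `SVC` (which parameterizes the box constraint as *C* rather than *C*/*n*) with the RBF kernel (3) on synthetic data; it is an illustration, not the authors' implementation:

```python
import numpy as np
from sklearn.svm import SVC

rng = np.random.default_rng(42)
X = rng.normal(size=(200, 2))
y = np.where(X[:, 0] + X[:, 1] > 0, 1, -1)  # labels in {-1, 1}

clf = SVC(kernel="rbf", C=1.0, gamma=1.0).fit(X, y)

# The support vectors are exactly the training points with nonzero dual
# variables; clf.dual_coef_ stores alpha_i * y_i for those points only.
n_sv = len(clf.support_)
print(f"{n_sv} of {len(X)} training points are support vectors")
```

Typically only the points near the separating boundary end up as support vectors, which is what motivates selecting good candidates near the boundary in advance.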

### *2.2. Decision Tree*

A decision tree is a general tool in data mining and machine learning used as a classification or regression model, in which a tree-like graph of decisions is formed. Among the well-known algorithms such as CHAID (chi-squared automatic interaction detection) [28], CART (classification and regression tree) [29], C4.5 [30], and QUEST (quick, unbiased, efficient, statistical tree) [31], we use CART, which is very similar to C4.5 since it uses a binary splitting criterion applied recursively and leaves no empty leaf. Decision tree learning builds its model by recursively partitioning the training data into pure or homogeneous sub-regions. The prediction process for classification or regression can be expressed by inference rules based on the tree structure of the built model, so it can be interpreted and understood more easily than other methods. The tree-building procedure begins at the root node, which includes all instances in the training data. To find the best possible variable to split the node into two child nodes, we check all possible splitting variables (called splitters), as well as all possible values of the variable used to split the node. This involves an *O*(*pn* log *n*) time complexity, where *p* is the number of input variables and *n* is the size of the training dataset [32]. In choosing the best splitter, we can use an impurity metric such as entropy or Gini impurity, for example the Gini impurity function *im*(*T*) = 1 − ∑*<sub>y</sub>* *p*(*T* = *y*)<sup>2</sup>, where *p*(*T* = *y*) is the proportion of observations where the class type *T* is *y*. Next, we define the difference between the impurity measure of the parent node and the weighted impurity measures of the two child nodes. Let us denote the impurity measure of the parent node by *im*(*T*); the impurity measures of the two child nodes by *im*(*T<sub>left</sub>*) and *im*(*T<sub>right</sub>*); the number of parent node instances by *X<sub>T</sub>*; and the numbers of child node instances by *X<sub>T,left</sub>* and *X<sub>T,right</sub>*.
We choose the best splitter by the query that decreases the impurity as much as possible:

$$
\Delta im(T) = im(T) - \frac{X\_{T,left}}{X\_T} im(T\_{left}) - \frac{X\_{T,right}}{X\_T} im(T\_{right}).\tag{4}
$$
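As a worked illustration of the Gini impurity and the decrease in Equation (4), the following self-contained sketch (with hypothetical labels) scores a candidate split:

```python
from collections import Counter

def gini(labels):
    """Gini impurity: im(T) = 1 - sum_y p(T = y)^2."""
    n = len(labels)
    return 1.0 - sum((count / n) ** 2 for count in Counter(labels).values())

def impurity_decrease(parent, left, right):
    """Equation (4): impurity of the parent node minus the size-weighted
    impurities of the two child nodes."""
    n = len(parent)
    return (gini(parent)
            - len(left) / n * gini(left)
            - len(right) / n * gini(right))

parent = ["a"] * 5 + ["b"] * 5                           # im(T) = 0.5
delta = impurity_decrease(parent, ["a"] * 5, ["b"] * 5)  # perfectly pure split
```

A perfect split of a balanced binary node yields the maximum decrease of 0.5; the splitter search picks the query maximizing this quantity.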

There are two methods, called pre-pruning and post-pruning, to avoid over-fitting in decision trees. The pre-pruning method uses stopping conditions before over-fitting occurs: it stops splitting a node when specified conditions are met. The post-pruning method first grows an over-fitted tree and then determines an appropriate tree size by backward pruning of the over-fitted tree [33]. Generally, post-pruning is known to be more effective than pre-pruning, so we use a post-pruning algorithm.
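Minimal cost-complexity pruning is the standard post-pruning scheme of CART; one readily available implementation is scikit-learn's `ccp_alpha`. The sketch below (an illustration, not the paper's exact procedure) grows a full tree on the iris data and prunes it back:

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)

# Grow an intentionally over-fitted tree, then apply minimal
# cost-complexity (post-)pruning with a penalty chosen from the path.
full = DecisionTreeClassifier(random_state=0).fit(X, y)
path = full.cost_complexity_pruning_path(X, y)
alpha = path.ccp_alphas[len(path.ccp_alphas) // 2]  # mid-range penalty
pruned = DecisionTreeClassifier(random_state=0, ccp_alpha=alpha).fit(X, y)
```

Larger `ccp_alpha` values prune more aggressively; in practice the penalty is chosen by cross-validation.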

### **3. Tree-Based Relative Support Distance**

In order to cope with large-scale SVM problems, we propose a novel selection method for support vector candidates using a combination of tree decomposition and relative support distance. We aim to reduce the training time of SVMs for the numerical computation of the α*<sub>i</sub>* in (2), which produces *w*<sup>∗</sup> and *b*<sup>∗</sup> in (1), by selecting in advance good support vector candidates that form a small subset of the training data. To illustrate our concept, we start with a simple example in Figure 1, which shows the distribution of the iris data; for details of the data, refer to Fisher [34]. In short, the iris dataset describes iris plants using four continuous features. The dataset contains 3 classes of 50 instances each: Iris Setosa, Iris Versicolor, and Iris Virginica. We decompose the input space into several regions by decision tree learning. After training an SVM model on the whole dataset, we mark support vectors with filled shapes. Each region has its own majority class label, and the boundaries lie between the two majority classes. The support vectors are close to the boundaries. In addition, we notice that they are located relatively far away from the centroid of the data points with the majority class label in a region.

**Figure 1.** The construction of a decision tree and SVMs for the iris data shows the boundaries and support vectors (the filled shapes). The support vectors in the regions are located far away from the majority-class centroid.

In light of this, we describe our algorithm to find a subset of support vectors that determine the separating hyperplane. We divide the training dataset into several decomposed regions of the input space by decision tree learning. This process causes each decomposed region to contain mostly data points with its majority class label. Next, we detect, for each region, adjacent regions in which the majority class is opposite; we call such a region a distinct adjacent region. Then we calculate a new distance measure, the relative support distance, for the data points in the selected region pairs. The procedure of the algorithm is as follows:


### *3.1. Distinct Adjacent Regions*

After applying decision tree learning to the training data, we detect adjacent regions. The decision tree partitions the input space into several leaves (also called terminal nodes) by reducing an impurity measure such as entropy. Following the approach for detecting adjacent regions introduced by Chau [25], we state mathematical conditions for being adjacent regions and relate them to the relative support distance. Firstly, we represent each terminal node of a learned decision tree as follows:

$$L_{q} = \bigcap_{j=1}^{p} b_{qj}, \quad l_{qj} \le b_{qj} \le h_{qj}, \tag{5}$$

where *L<sub>q</sub>* is the *q*th leaf in the tree structure and *b<sub>qj</sub>* is the boundary range for the *j*th variable of the *q*th leaf, with lower bound *l<sub>qj</sub>* and upper bound *h<sub>qj</sub>*. Recall that *p* is the number of input variables. We check whether each pair of leaves, *L<sub>o</sub>* and *L<sub>q</sub>*, meets the following criteria:

$$h_{os} = l_{qs} \ \text{ or } \ l_{os} = h_{qs}, \tag{6}$$

$$l_{qk} \le l_{ok} \le h_{qk} \ \text{ or } \ l_{qk} \le h_{ok} \le h_{qk}, \tag{7}$$

where *s* and *k* are input variables, 1 ≤ *s* ≤ *p*, 1 ≤ *k* ≤ *p*, and *s* ≠ *k*. That is to say, if two leaves *L<sub>o</sub>* and *L<sub>q</sub>* are adjacent regions, they have to share a boundary value on some variable *s*, as in Equation (6), and overlap on the remaining variables *k*, as in (7). Among all adjacent regions, we only consider distinct adjacent regions. For example, in Figure 2, the neighbors of *L*<sub>1</sub> are *L*<sub>2</sub>, *L*<sub>4</sub>, and *L*<sub>5</sub>: {*L*<sub>1</sub>, *L*<sub>5</sub>}, however, does not form an adjacent region pair. {*L*<sub>3</sub>, *L*<sub>5</sub>} is an adjacent region pair but not distinct, since those regions have the same majority class. Therefore, the distinct adjacent regions in the example are only {*L*<sub>1</sub>, *L*<sub>2</sub>}, {*L*<sub>1</sub>, *L*<sub>4</sub>}, {*L*<sub>2</sub>, *L*<sub>3</sub>}, {*L*<sub>2</sub>, *L*<sub>5</sub>}, and {*L*<sub>4</sub>, *L*<sub>5</sub>}. The distinct adjacent regions are summarized in Table 1. Now, we apply the measure of relative support distance to select support vector candidates in the distinct adjacent regions found for each region.

**Table 1.** Partition of input regions and distinct adjacent regions.


**Figure 2.** Distinct adjacent regions are {*L*1, *L*2}, {*L*1, *L*4}, {*L*2, *L*3}, {*L*2, *L*5}, and {*L*4, *L*5}.

### *3.2. Relative Support Distance*

Support vectors (SVs) play a substantial role in determining the decision hyperplane, in contrast to non-SV data points. We extract the data points in the training data that are most likely to be support vectors, constructing a set of support vector candidates. Given two distinct adjacent regions *L*<sup>1</sup> and *L*<sup>2</sup> from the previous step, let us assume the majority class label of *L*<sup>1</sup> is *y* = 1 and that of *L*<sup>2</sup> is *y* = 2, without loss of generality. First, we calculate the centroid (*mc*) for each majority class label as follows: for an index set *Sc* = {*i*|*x<sup>i</sup>* ∈ *Lc* and the label of *x<sup>i</sup>* = *c*},

$$m\_{c} = \frac{1}{n\_{c}} \sum\_{i \in S\_{c}} x\_{i}, \tag{8}$$

where *c* ∈ {1, 2} and *nc* is the cardinality of index set *Sc*.

In other words, *m<sup>c</sup>* is the majority-class centroid of data points in *Lc*, of which the labels are *y* = *c*. Next, we create a virtual hyperplane that bisects the line from *m*<sup>1</sup> to *m*2:

$$\begin{aligned} \mathbf{M} &= \frac{1}{2} (m\_1 + m\_2), \\ \mathbf{W} &= m\_1 - m\_2, \end{aligned} \tag{9}$$

where *M* is the middle point of the two majority-class centroids. The virtual hyperplane is given by *H*(*x*) = 0, where

$$H(\mathbf{x}) = \mathbf{W}^{T}(\mathbf{x} - \mathbf{M}).\tag{10}$$

Lastly, we calculate the distance *rx* between each data point *x* in *Sc* and *m<sup>c</sup>*, and the distance *hx* between each data point in *Sc* and the virtual hyperplane *H*(*x*) = 0:

$$\begin{aligned} r\_{\mathbf{x}\_{c,l}} &= \left\| \mathbf{x}\_{c,l} - m\_{c} \right\|, \\ h\_{\mathbf{x}\_{c,l}} &= \frac{\left| H\left(\mathbf{x}\_{c,l}\right) \right|}{\left\| \mathbf{W} \right\|} = \frac{\left| \mathbf{W}^{T}\left(\mathbf{x}\_{c,l} - \mathbf{M}\right) \right|}{\left\| m\_1 - m\_2 \right\|}, \end{aligned} \tag{11}$$

where *xc*,*<sup>l</sup>* is the *l*th data point belonging to *Sc*. Figure 3 shows a conceptual description of *r* and *h* using the virtual hyperplane in a leaf. After calculating *rx* and *hx*, we apply feature scaling to bring all values into the range between 0 and 1. Our observation is that data points lying close to the virtual hyperplane are likely to be support vectors. In addition, data points lying close to the centroid are less likely to be support vectors. In light of these observations, we select data points lying near the hyperplane and far away from the centroid. For this purpose, we define the relative support distance *T*(*rx*, *hx*) as follows:

$$T(r\_\mathbf{x}, h\_\mathbf{x}) = \frac{1}{(1 + e^{-r\_\mathbf{x}})h\_\mathbf{x}}.\tag{12}$$

**Figure 3.** Distances *r* from the center and *h* from the virtual hyperplane are shown. The two star shapes are the centroids of the two classes.

The larger *T*(*rx*, *hx*) becomes, the more likely that the associated *x* is a support vector.
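Equations (8)–(12) can be sketched in NumPy: compute the majority-class centroids, the virtual hyperplane, the distances *r* and *h*, and then the relative support distance *T*. The function name, the toy arrays, and the small ε (added purely as a numerical safeguard against division by zero after scaling) are illustrative assumptions, not the authors' code.

```python
import numpy as np

def relative_support_distance(X1, X2, eps=1e-9):
    """X1, X2: majority-class points of two distinct adjacent regions.

    Returns T(r, h) for every point in X1 (Equation (12)), after
    min-max scaling r and h into [0, 1] as described in the text.
    The eps term is a numerical safeguard, not part of the paper.
    """
    m1, m2 = X1.mean(axis=0), X2.mean(axis=0)       # Equation (8)
    M = 0.5 * (m1 + m2)                             # Equation (9)
    W = m1 - m2
    r = np.linalg.norm(X1 - m1, axis=1)             # distance to centroid
    h = np.abs((X1 - M) @ W) / np.linalg.norm(W)    # Equation (11)

    def scale(v):                                   # feature scaling to [0, 1]
        rng = v.max() - v.min()
        return (v - v.min()) / rng if rng > 0 else np.zeros_like(v)

    r, h = scale(r), scale(h)
    return 1.0 / ((1.0 + np.exp(-r)) * (h + eps))   # Equation (12)

# Toy example: the point of X1 closest to the virtual hyperplane
# (and farthest from its centroid) gets the largest T.
X1 = np.array([[0.0, 0.0], [0.0, 1.0], [1.0, 0.5]])
X2 = np.array([[3.0, 0.0], [3.0, 1.0]])
T = relative_support_distance(X1, X2)
```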

The relationship between support vectors and the distances *r* and *h* is illustrated in Figure 4. We use leaf *L*<sup>4</sup> with its distinct adjacent regions *L*<sup>1</sup> and *L*<sup>2</sup> in Figure 1. In Figure 4, the observations marked by circles are non-support vectors, while those marked by triangles (in red) are support vectors selected after training SVMs on all data. The distances *r* and *h* of *L*<sup>4</sup> relative to *L*<sup>1</sup> are shown in Figure 4a, and the relative support distance measures in Figure 4c. Likewise, those of *L*<sup>4</sup> relative to *L*<sup>2</sup> are shown in Figure 4b,d. We observe that the observations marked by triangles and surrounded by a red ellipsoid in Figure 1 correspond to the support vectors surrounded by a red ellipsoid in region *L*<sup>4</sup> in Figure 4a, and they have larger values of relative support distance, as shown in Figure 4c. Similarly, the observations marked by triangles and surrounded by a green ellipsoid in Figure 4b correspond to the support vectors surrounded by a green ellipsoid in region *L*<sup>4</sup> in Figure 1, and they also have larger values of relative support distance, as shown in Figure 4d. The support vectors in *L*<sup>4</sup> are obtained by collecting observations with large values of relative support distance, for example by the rule *T*(*rx*, *hx*) > 0.9, from both the pair {*L*<sup>4</sup>, *L*<sup>1</sup>} and the pair {*L*<sup>4</sup>, *L*<sup>2</sup>}. The results reveal that observations with a larger distance *r* and a shorter distance *h* are likely to be support vectors.


**Figure 4.** Illustration of the distances *r* and *h* according to distinct adjacent regions for the iris data in Figure 1. (**a**) Distances for leaf *L*<sup>4</sup> relative to *L*<sup>1</sup> are shown. Notice the support vectors captured by leaf *L*<sup>4</sup> to leaf *L*<sup>1</sup> are in the (dotted) red ellipsoid. (**b**) Distances for leaf *L*<sup>4</sup> relative to *L*<sup>2</sup> are shown. Notice the support vectors captured by leaf *L*<sup>4</sup> relative to leaf *L*<sup>2</sup> are in the (dashed) green ellipsoid. (**c**) Relative support distances of the observations *xi* in leaf *L*<sup>4</sup> relative to *L*<sup>1</sup> are shown. Notice the support vectors captured by leaf *L*<sup>4</sup> to leaf *L*<sup>1</sup> are in the (dotted) red ellipsoid. (**d**) Relative support distances of the observations *xi* in leaf *L*<sup>4</sup> relative to *L*<sup>2</sup> are shown. Notice the support vectors captured by leaf *L*<sup>4</sup> relative to leaf *L*<sup>2</sup> are in the (dashed) green ellipsoid.

For each region, we calculate the pairwise relative support distance with its distinct adjacent regions and select a fraction of the observations, denoted by the parameter β, in decreasing order of *T*(*rx*, *hx*) as a candidate set of support vectors. That is to say, for each region, we select the top β fraction of the training data based on *T*(*rx*, *hx*). The parameter β represents the proportion of selected data points, between 0 and 1. For example, when β is set to 1, all data points are included in the training of SVMs. When β = 0.1, we exclude 90% of the data points and reduce the training data set to 10%. Finally, we combine relative support distance with random sampling, which means that half of the training candidates are selected based on the proposed distance and the others are selected by random sampling. Though quite informative for selecting possible support vectors, the proposed distance is calculated locally within distinct adjacent regions. Therefore, random sampling compensates for this locality by providing information about the whole data distribution.
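The selection step above can be sketched as follows: for one region, take roughly β·*n* candidates, half of them as the largest relative support distances and the other half by random sampling. The function name and the budget-splitting details (e.g., giving the odd element to the distance half) are illustrative assumptions.

```python
import numpy as np

def select_candidates(T, beta, rng=None):
    """T: relative support distances of the n points in one region.

    Returns indices of roughly beta * n points: half chosen by the
    largest T values, half by random sampling from the rest.
    """
    rng = rng or np.random.default_rng(0)
    n = len(T)
    budget = max(1, int(round(beta * n)))
    k = budget // 2 + budget % 2                   # half via the distance
    by_distance = np.argsort(T)[::-1][:k]          # largest T first
    remaining = np.setdiff1d(np.arange(n), by_distance)
    k_rand = budget - k                            # half via random sampling
    random_part = rng.choice(remaining, size=k_rand, replace=False)
    return np.concatenate([by_distance, random_part])
```

With β = 0.4 on ten points, the two highest-distance points are always kept and two more are drawn at random, so the SVM still sees some of the global data distribution.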

### **4. Experimental Results**

In the experiments, we compare the proposed method, tree-based relative support distance (denoted by TRSD), with some previously suggested methods, specifically SVMs with candidate vector selection, denoted by CVS [18], and SVMs with Fisher linear discriminant analysis, denoted by FLD [25], as well as standard SVMs, denoted by SVM. For all compared methods, we use LibSVM [15], since it is one of the fastest implementations for training SVMs. The experiments are run on a computer with the following features: Core i5 3.4 GHz processor, 16.0 GB RAM, and the Windows 10 Enterprise operating system. The algorithms are implemented in the R programming language. We use 18 datasets, which are from the UCI Machine Learning Repository [35] and the LibSVM Data Repository [36], except for the checkerboard dataset [37]: a9a, banana, breast cancer, four-class, German credit, IJCNN-1 [38], iris, mushroom, phishing, Cod-RNA, skin segmentation, waveform, and w8a. The iris and waveform datasets are modified into binary classification problems by assigning one class as positive and the others as negative. Table 2 shows a summary of the datasets used in the experiments, where Size is the number of instances in the dataset and Dim is the number of features.


**Table 2.** Datasets for experiments.

For testing, we apply three-fold cross-validation, repeated three times with different seeds, by shuffling each dataset and dividing it into three parts, using two parts as the training dataset and the remaining part as the test dataset. We use the RBF kernel for training SVMs in all tested methods. For each experiment, cross-validation and grid search are used for tuning two hyper-parameters: the penalty factor C and the RBF kernel parameter γ in Equation (3). Hyper-parameters are searched over a two-dimensional grid with C ∈ {0.1, 1, 10, 100} and γ ∈ {0.001, 0.01, 0.1, 1, 1/*p*}, where *p* is the number of features. Table 3 shows the values used for each dataset in the experiments. Moreover, we vary the fraction β of data points selected in each region from 0.1 to 0.3 in steps of 0.1.
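The grid search described above can be sketched with scikit-learn (the paper's experiments use R with LibSVM; this Python sketch only mirrors the C and γ grids from the text). The dataset and fold count here are toy stand-ins, not the paper's setup.

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import GridSearchCV
from sklearn.svm import SVC

# Toy stand-in dataset; the paper uses the 18 benchmark datasets of Table 2.
X, y = make_classification(n_samples=300, n_features=4, random_state=0)
p = X.shape[1]

# The C and gamma grids from the text, including gamma = 1/p.
param_grid = {
    "C": [0.1, 1, 10, 100],
    "gamma": [0.001, 0.01, 0.1, 1, 1.0 / p],
}
search = GridSearchCV(SVC(kernel="rbf"), param_grid, cv=3)
search.fit(X, y)
best = search.best_params_
```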


**Table 3.** Hyper-parameters setting for the experiments.



We compare the performance of SVM, CVS, FLD, and TRSD in terms of classification accuracy and training time (in seconds), summarized in Table 4. We also depict the performance comparison of the proposed TRSD with CVS and FLD on the five largest datasets when β = 0.1 in Figure 5. We use a log-2 scale for the *y*-axis in Figure 5b. In Table 4, Acc is the accuracy on test data; σ is the standard deviation; and Time is the training time in seconds. Even though the accuracy of the proposed algorithm is slightly degraded in a few cases, it is higher than that of CVS and FLD in most cases. In addition, as β grows, the accuracy of the proposed algorithm improves substantially. For small datasets, there is no significant improvement in computation time compared to the standard SVM, since those datasets are already small enough. However, we notice that the training time of TRSD improves considerably on the large-scale datasets.

**Figure 5.** Comparison of the proposed tree-based relative support distance (TRSD) with candidate vectors (CVS) and Fisher linear discriminant analysis (FLD) on the five largest datasets.

For statistical analysis, we also performed the Friedman test to check whether there is a significant difference between the compared methods in terms of accuracy. If the null hypothesis of the Friedman test is rejected, we perform Dunn's test. Table 5 shows the summary of the Dunn's test results at the significance level α = 0.05. In Table 5, the entries (1) TRSD > CVS (FLD); (2) TRSD ≈ CVS (FLD); (3) TRSD < CVS (FLD), respectively, denote that: (1) the performance of TRSD is significantly better than CVS (FLD); (2) there is no significant difference between the performances of TRSD and CVS (FLD); and (3) the performance of TRSD is significantly worse than CVS (FLD). Each number in Table 5 is a count of datasets. At β = 0.1, our proposed method is significantly better than CVS and FLD in 10 and 9 cases, respectively, among the 18 datasets. These numbers increase to 11 and 10 at β = 0.3. On the other hand, our proposed method is significantly worse than CVS in only 2, 3, and 3 cases, and than FLD in 0, 1, and 0 cases, at β = 0.1, 0.2, and 0.3, respectively. Based on the observations in the experiments, we can conclude that our proposed method generates an effective reduction of the training datasets while producing better performance than the existing data reduction approaches.
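The Friedman test above can be sketched with SciPy; the accuracy table below is fabricated toy data, not the paper's results. Dunn's post hoc test (available in third-party packages such as scikit-posthocs, not in SciPy itself) would follow only when the null hypothesis is rejected.

```python
import numpy as np
from scipy.stats import friedmanchisquare

# Toy accuracies: rows are datasets, columns are methods
# (e.g., SVM, CVS, FLD, TRSD). Values are invented for illustration.
acc = np.array([
    [0.90, 0.85, 0.86, 0.89],
    [0.80, 0.74, 0.75, 0.79],
    [0.95, 0.91, 0.92, 0.94],
    [0.88, 0.82, 0.84, 0.87],
    [0.70, 0.66, 0.65, 0.69],
])

# One sample per method, measured across the same datasets.
stat, p_value = friedmanchisquare(*acc.T)
reject_null = p_value < 0.05   # significance level alpha = 0.05
```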


**Table 4.** Comparisons in terms of accuracy and time on the datasets (the results of support vector machines (SVM) are not bolded, because SVM has access to all examples).


**Table 5.** Dunn's test results for different β at the significance level α = 0.05.

Finally, to compare the training time of SVM, CVS, FLD, and TRSD in detail, we divide it into two parts, selecting candidate vectors (SC) and training the final SVM model (TS), when β = 0.3, summarized in Table 6. From Table 6, we notice that TRSD and FLD take longer than CVS to select candidate vectors on the w8a dataset. This is because the time complexity of building a decision tree is *O*(*pn* log *n*), where *p* is the number of features and *n* is the size of the training dataset. However, our proposed method is faster than FLD, since calculating the relative support distance is more computationally efficient than the Fisher linear discriminant, and it is the fastest overall. The results in Table 6 show that the proposed method efficiently selects support vector candidates while maintaining good classification performance.

**Table 6.** Time comparison in detail. SC and TS denote the computing time for selecting candidate vectors and training the final SVM model, respectively.


### **5. Discussion and Conclusions**

In this study, we have proposed a tree-based data reduction approach for solving large-scale SVM problems. In order to reduce time consumption in training SVM models, we apply a novel support vector selection method combining tree decomposition and the proposed relative support distance. We introduce the relative distance measure along with a virtual hyperplane between two distinct adjacent regions to effectively exclude non-SV data points. The virtual hyperplane, easily obtainable, takes advantage of the decomposed tree structures and is shown to be effective in selecting support vector candidates. In computing the relative support distance, we also use the distance between each data point and the centroid in each region and combine the two in consideration of the nonlinear characteristics of support vectors. In experiments, we have demonstrated that the proposed method outperforms some existing methods for selecting support vector candidates in terms of computation time and classification performance. In the future, we would like to investigate other large-scale SVM problems such as multi-class classification and support vector regression. We also envision an extension of the proposed method to under-sampling techniques.

**Author Contributions:** Investigation, M.R.; Methodology, M.R.; Software, M.R.; Writing—original draft, M.R.; Writing—review & editing, K.L. All authors have read and agreed to the published version of the manuscript.

**Funding:** This research was supported by the Ministry of Education of the Republic of Korea and the National Research Foundation of Korea (NRF-2020R1F1A1076278).

**Conflicts of Interest:** The authors declare no conflict of interest.

### **References**


© 2020 by the authors. Licensee MDPI, Basel, Switzerland. This article is an open access article distributed under the terms and conditions of the Creative Commons Attribution (CC BY) license (http://creativecommons.org/licenses/by/4.0/).

## *Article* **Solving Partial Differential Equations Using Deep Learning and Physical Constraints**

**Yanan Guo 1,2, Xiaoqun Cao 1,2,\*, Bainian Liu 1,2 and Mei Gao 1,2**


Received: 31 July 2020; Accepted: 22 August 2020; Published: 26 August 2020

**Abstract:** The various studies of partial differential equations (PDEs) are hot topics of mathematical research. Among them, solving PDEs is a very important and difficult task. Since many partial differential equations do not have analytical solutions, numerical methods are widely used to solve PDEs. Although numerical methods have been widely used with good performance, researchers are still searching for new methods for solving partial differential equations. In recent years, deep learning has achieved great success in many fields, such as image classification and natural language processing. Studies have shown that deep neural networks have powerful function-fitting capabilities and have great potential in the study of partial differential equations. In this paper, we introduce an improved Physics Informed Neural Network (PINN) for solving partial differential equations. PINN takes the physical information that is contained in partial differential equations as a regularization term, which improves the performance of neural networks. In this study, we use the method to study the wave equation, the KdV–Burgers equation, and the KdV equation. The experimental results show that PINN is effective in solving partial differential equations and deserves further research.

**Keywords:** partial differential equations; deep learning; physics-informed neural network; wave equation; KdV-Burgers equation; KdV equation

### **1. Introduction**

Partial differential equations (PDEs) are important tools for the study of all kinds of natural phenomena, and they are widely used to explain various physical laws [1–3]. In addition, many engineering and technical problems can be modeled and analyzed using partial differential equations, such as wake turbulence, optical fiber communications, atmospheric pollutant dispersion, and so on [4–6]. Therefore, advances in partial differential equations are often of great importance to many fields, such as aerospace, numerical weather prediction, etc. [7,8]. Currently, partial differential equations and many other disciplines are increasingly connected and mutually reinforcing. Therefore, the study of partial differential equations is of great significance. However, a major difficulty in the study of partial differential equations is that it is often impossible to obtain analytical solutions. Therefore, various numerical methods for solving partial differential equations have been proposed, such as the finite difference method, the finite element method, the finite volume method, etc. [9,10]. Numerical methods have greatly facilitated the study of partial differential equations. Nowadays, these methods are widely used and continuously improved. At the same time, researchers are also trying to develop new methods and tools to solve partial differential equations.

With the advent of big data and the enhancement of computing resources, data-driven methods have been increasingly applied [11,12]. In recent years, as a representative of data-driven methods, deep learning methods based on deep neural networks have made breakthrough progress [13–15]. Deep neural networks are excellent at mining various kinds of implicit information, and they have achieved great success in handling various tasks in science and engineering, such as image classification [16], natural language processing [17], and fault detection [18]. According to the universal approximation theorem, a multilayer feedforward network containing a sufficient number of hidden layer neurons can approximate any continuous function with arbitrary accuracy [19,20]. Therefore, neural networks also have tremendous advantages in function fitting. In recent years, neural network-based approaches have appeared in the study of partial differential equations. For example, Lagaris et al. [21] use artificial neural networks to solve initial and boundary value problems. They first construct a trial solution consisting of two parts, the first part satisfying the initial/boundary condition and the second part being a feedforward neural network, and then train the network to satisfy the differential equation. The experimental results show that the method has good performance. However, for high-dimensional problems, the training time increases due to the larger training set, which needs to be addressed by methods such as parallel implementation. Based on the work of Lagaris et al. [21], Göküzüm et al. [22] further propose a method based on an artificial neural network (ANN) discretization for solving periodic boundary value problems in homogenization. In contrast to Lagaris et al. [21], Göküzüm et al. [22] use a global energy potential to construct the objective to be optimized.
Numerical experiments show that the method can achieve reasonable physical results using a smaller number of neurons, thus reducing the memory requirements. However, this method still faces problems such as slow training speed and overfitting, which may be solved by dropout or regularization. Nguyen-Thanh et al. [23] propose a method to study finite deformation hyperelasticity using an energy functional and deep neural networks, which is named the Deep Energy Method (DEM). The method uses potential energy as a loss function and trains the deep neural network by minimizing the energy function. DEM has promising prospects for high-dimensional problems, ill-posed problems, etc. However, it also faces open questions, such as how to impose boundary conditions, which integration techniques to use, and so on. In addition, how best to build deep neural networks is also a problem for DEM to study in depth. Although there are still many problems, deep learning methods have gradually become a new way of solving partial differential equations.

Nowadays, there has been a growing number of researchers using deep learning methods to study partial differential equations [24–27]. For example, Huang [28] combines deep neural networks and the Wiener–Hopf method to study some wave problems. This combined research strategy has achieved excellent experimental results in solving two particular problems. It cannot be overlooked, however, that preparing the training dataset is an important and expensive task in their study. Among the many studies, an important work that cannot be ignored is the physics-informed neural network proposed by Raissi et al. [29–31]. This neural network model takes into account the physical laws contained in PDEs and encodes them into the neural network as regularization terms, which improves the performance of the neural network model. Nowadays, physics-informed neural networks are gaining more and more attention from researchers, and they are gradually being applied to various fields of research [32–36]. Jagtap et al. [37] introduce adaptive activation functions into deep and physics-informed neural networks (PINNs) to better approximate complex functions and the solutions of partial differential equations. Compared with traditional activation functions, adaptive activation functions have better learning ability, which improves the performance of deep and physics-informed neural networks. Based on previous studies on PINNs, this study further optimizes the method and constructs physics-informed neural networks for the wave equation, the KdV–Burgers equation, and the KdV equation, respectively.

The paper is structured as follows: Section 2 introduces the proposed algorithm for solving partial differential equations based on neural networks and physical knowledge constraints. Subsequently, Section 3 provides experimental validation of the proposed method, gives the partial differential equations used and the experimental scheme, and analyzes the experimental results. In Section 4, we discuss the experimental results. Finally, Section 5 summarizes the work in this paper and lists the future work to be done.

### **2. Methodology**

In this section, we begin with a brief introduction to neural networks. Subsequently, we present an overview of physics-informed neural networks that incorporate physical laws. The relevant algorithms and implementation framework of the PINNs are introduced.

#### *2.1. Artificial Neural Networks*

The artificial neural network (ANN) has been a research hotspot in the field of artificial intelligence since the 1980s [38,39]. It abstracts the neurons of the human brain from the perspective of information processing and models various networks according to different connection patterns. Specifically, an artificial neural network simulates the way neuron cells in the brain transmit information. It consists of multiple connected artificial neurons and can be used to mine and fit complex relationships hidden within data. Besides, the connections between different neurons are given different weights, each representing the amount of influence of one neuron on another. Figure 1 illustrates the structure of a feedforward neural network (FNN), a simple kind of artificial neural network [40]. As can be seen in Figure 1, a feedforward neural network consists of an input layer, one or more hidden layers, and an output layer. Within it, information is propagated from the input layer through the hidden layers to the output layer. When designing a neural network, the number of hidden layers, the number of neurons per layer, and the choice of activation functions are all important factors to consider.

As the number of hidden layers increases, an artificial neural network can be viewed as an adaptive nonlinear dynamic system consisting of a large number of neurons with various connections, which can be used to approximate a variety of complex functions. Although the structure of artificial neural networks is relatively simple, it is not easy to make them capable of learning. It was not until around 1980 that the backpropagation algorithm effectively solved the learning problem of multilayer neural networks and became the most popular neural network learning algorithm [39,41]. Because an artificial neural network can be used as a function approximator, it can be considered a learnable function and applied to solve partial differential equations. Theoretically, with enough training data and neurons, artificial neural networks can learn the solutions of partial differential equations.
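The function-fitting ability described above can be sketched in plain NumPy: a one-hidden-layer tanh network trained by backpropagation (here plain full-batch gradient descent) to fit a 1D target function. The network size, learning rate, and target function are arbitrary choices for illustration, not from the paper.

```python
import numpy as np

rng = np.random.default_rng(0)
x = np.linspace(-1, 1, 64).reshape(-1, 1)
y = np.sin(np.pi * x)                      # target function to approximate

n_hidden = 16
W1 = rng.normal(0, 1.0, (1, n_hidden)); b1 = np.zeros(n_hidden)
W2 = rng.normal(0, 1.0, (n_hidden, 1)); b2 = np.zeros(1)

lr = 0.05
for _ in range(3000):
    h = np.tanh(x @ W1 + b1)               # hidden layer activations
    y_hat = h @ W2 + b2                    # linear output layer
    err = y_hat - y
    # backpropagation: chain rule through the two layers
    gW2 = h.T @ err / len(x); gb2 = err.mean(axis=0)
    dh = (err @ W2.T) * (1 - h ** 2)       # tanh'(z) = 1 - tanh(z)^2
    gW1 = x.T @ dh / len(x); gb1 = dh.mean(axis=0)
    W2 -= lr * gW2; b2 -= lr * gb2
    W1 -= lr * gW1; b1 -= lr * gb1

mse = float(np.mean((np.tanh(x @ W1 + b1) @ W2 + b2 - y) ** 2))
```

Even this tiny network drives the mean squared error well below the variance of the target, illustrating the universal approximation behavior the text refers to.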

**Figure 1.** A structural diagram of a feedforward neural network (FNN), which consists of an input layer, one or more hidden layers, and an output layer, each containing one or more artificial neurons.

### *2.2. Physics-Informed Neural Networks*

In this section, we introduce the physics-informed neural networks (PINNs) and related settings in this study. Traditional neural networks are based entirely on a data-driven approach that does not take into account the physical laws contained in the data. Therefore, a large amount of data is often required to train the neural networks to obtain a reasonable model. In contrast, physics-informed neural networks introduce physical information into the network by forcing the network output to satisfy the corresponding partial differential equations. Specifically, by adding regularization about partial differential equations to the loss function, the model is made to consider physical laws during the training process. This makes the training process require less data and speeds up training. Physics-informed neural networks can be used to solve not only the forward problem, i.e., obtaining approximate solutions to partial differential equations, but also the inverse problem, i.e., obtaining the parameters of partial differential equations from training data [29,36,42,43]. In the following, the physics-informed neural network modified and used in this study is introduced for the forward problem of partial differential equations.

In this study, we consider the partial differential equation defined on the domain Ω with the boundary *∂*Ω:

$$\mathcal{D}(u(\mathbf{x})) = 0, \quad \mathbf{x} \in \Omega, \tag{1}$$

$$\mathcal{B}(u(\mathbf{x})) = 0, \quad \mathbf{x} \in \partial\Omega, \tag{2}$$

where *u* is the unknown solution, D denotes a linear or nonlinear differential operator (e.g., *∂*/*∂x*, *u* · *∂*/*∂x*, *∂*<sup>2</sup>/*∂x*<sup>2</sup>, etc.), and the operator B denotes the boundary condition of a partial differential equation (e.g., Dirichlet boundary condition, Neumann boundary condition, Robin boundary condition, etc.). A point to note is that, for partial differential equations that contain temporal variables, we treat *t* as a special component of *x*, i.e., the temporal domain is included in Ω. In this case, the initial condition can be treated as a special type of Dirichlet boundary condition on the spatio-temporal domain.

First, we construct a neural network for approximating the solution *u*(*x*) of a partial differential equation. This neural network is denoted by *u*ˆ(*x*; *θ*); it takes *x* as input and outputs a vector of the same dimension as *u*(*x*). Suppose this neural network contains an input layer, *L* − 1 hidden layers, and an output layer, where each hidden layer receives the output of the previous layer and the *k*th hidden layer contains *n<sub>k</sub>* neurons. *θ* represents the neural network parameters, i.e., the collection of weight matrices *W*<sup>[*k*]</sup> ∈ **R**<sup>*n<sub>k</sub>*×*n*<sub>*k*−1</sub></sup> and bias vectors *b*<sup>[*k*]</sup> ∈ **R**<sup>*n<sub>k</sub>*</sup> for each layer *k*. These parameters are continuously optimized during the training phase. The neural network *u*ˆ should satisfy two requirements: on the one hand, given a dataset of *u*(*x*) observations, the network should reproduce the observations when *x* is used as input; on the other hand, *u*ˆ should conform to the physics underlying the partial differential equation. We fulfill the second requirement by defining a residual network:

$$f(\mathbf{x}; \theta) := \mathcal{D}(\hat{u}(\mathbf{x}; \theta)). \tag{3}$$

To build this neural network, we need to use automatic differentiation (AD). Currently, automatic differentiation techniques are widely integrated into many deep learning frameworks, such as TensorFlow [44] and PyTorch [45]. Therefore, many researchers have commonly used automatic differentiation in their studies of PINNs [29]. In this study, for the surrogate network *u*ˆ, we compute the required derivatives of the neural network by automatic differentiation according to the chain rule. Moreover, since the network *f* shares its parameters with the network *u*ˆ, both networks are trained by minimizing a single loss function. Figure 2 shows a schematic diagram of a physics-informed neural network. The next main task is to find the neural network parameters that minimize the defined loss function. In a physics-informed neural network, the loss function is defined as follows:

$$J(\theta) = MSE\_u + MSE\_f \tag{4}$$

The calculation of the mean square error (MSE) is given by the following formula:

$$MSE\_{u} = \frac{1}{N\_{u}} \sum\_{i=1}^{N\_{u}} \left| \hat{u}^{i} - u\left(\mathbf{x}\_{u}^{i}, t\_{u}^{i}\right) \right|^{2}, \tag{5}$$

$$MSE\_{f} = \frac{1}{N\_{f}} \sum\_{i=1}^{N\_{f}} \left| f\left(\mathbf{x}\_{f}^{i}, t\_{f}^{i}\right) \right|^{2}, \tag{6}$$

Here, (*x<sup>i</sup><sub>u</sub>*, *t<sup>i</sup><sub>u</sub>*) denotes the training data from the initial and boundary conditions, and (*x<sup>i</sup><sub>f</sub>*, *t<sup>i</sup><sub>f</sub>*) denotes the training points in the space-time domain. Equation (5) requires the neural network to satisfy the initial and boundary conditions, while Equation (6) requires the neural network to satisfy the constraints of the partial differential equation, which corresponds to the physical information part of the neural network. Next, the optimization problem for Equation (4) is addressed by optimizing the parameters in order to find the minimum value of the loss function, i.e., we seek the following parameters:

$$w^\* = \underset{w \in \theta}{\text{arg min}} \, J(w), \tag{7}$$

$$b^\* = \underset{b \in \theta}{\text{arg min}} \, J(b). \tag{8}$$

In the last step, we use gradient-based optimizers, such as SGD, RMSprop, Adam, and L-BFGS [46–48], to minimize the loss function. It is found that, for smooth PDE solutions, L-BFGS can find a good solution faster than Adam, using fewer iterations. This is because the Adam optimizer relies only on first-order derivatives, whereas L-BFGS also exploits second-order information about the loss function [49]. However, one problem with L-BFGS is that it is more likely to get stuck in a bad local minimum. Considering their respective advantages, in this study we end up using a combination of the L-BFGS and Adam optimizers to minimize the loss function. Besides, we also use a residual-based adaptive refinement (RAR) [50] method to improve the training effect by increasing the number of residual points in regions with large residuals of the partial differential equations until the residuals are less than a threshold. By the above method, we obtain trained neural networks that can be used to approximate the solutions of partial differential equations. In the next part, we will use this method to study three important partial differential equations: the one-dimensional wave equation, the KdV–Burgers equation, and the KdV equation.
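The steps above can be sketched in PyTorch for the wave equation of Section 3.1 (with c = 1): the residual *u<sub>tt</sub>* − *u<sub>xx</sub>* is built with automatic differentiation, the loss is *MSE<sub>u</sub>* + *MSE<sub>f</sub>* as in Equation (4), and a few Adam steps stand in for the full Adam + L-BFGS schedule. The network size, sample counts, and the toy boundary data are illustrative assumptions, not the paper's configuration.

```python
import math
import torch

torch.manual_seed(0)
# Surrogate network u_hat(x, t; theta): input (x, t), output u.
net = torch.nn.Sequential(
    torch.nn.Linear(2, 20), torch.nn.Tanh(),
    torch.nn.Linear(20, 20), torch.nn.Tanh(),
    torch.nn.Linear(20, 1),
)

def wave_residual(xt):
    """f(x, t) = u_tt - u_xx via autograd (Equations (3) and (9), c = 1)."""
    xt = xt.clone().requires_grad_(True)
    u = net(xt)
    g = torch.autograd.grad(u, xt, torch.ones_like(u), create_graph=True)[0]
    u_x, u_t = g[:, 0:1], g[:, 1:2]
    u_xx = torch.autograd.grad(u_x, xt, torch.ones_like(u_x),
                               create_graph=True)[0][:, 0:1]
    u_tt = torch.autograd.grad(u_t, xt, torch.ones_like(u_t),
                               create_graph=True)[0][:, 1:2]
    return u_tt - u_xx

xt_data = torch.rand(32, 2)                      # toy initial/boundary samples
u_data = torch.sin(math.pi * xt_data[:, 0:1])    # toy observed values
xt_colloc = torch.rand(128, 2)                   # collocation points for MSE_f

opt = torch.optim.Adam(net.parameters(), lr=1e-3)
for _ in range(50):                              # short Adam warm-up; a full run
    opt.zero_grad()                              # would hand off to L-BFGS
    loss = (torch.mean((net(xt_data) - u_data) ** 2)      # MSE_u, Eq. (5)
            + torch.mean(wave_residual(xt_colloc) ** 2))  # MSE_f, Eq. (6)
    loss.backward()
    opt.step()
final_loss = float(loss)
```

RAR would then periodically evaluate `wave_residual` on a dense candidate set and append the worst points to `xt_colloc` until the residuals fall below a threshold.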

**Figure 2.** The schematic of physics-informed neural network (PINN) for solving partial differential equations.

### **3. Experiments and Results**

In this section, we study the one-dimensional wave equation, the KdV–Burgers equation, and the KdV equation using physics-informed neural networks. A neural network model is constructed for each of the three equations based on the given initial and boundary conditions, and the approximation results are compared with the true solutions to test the physics-informed neural networks. All experiments were run on Ubuntu 16.04, and we used the open-source TensorFlow framework to build and train the neural network models. We used PyCharm 2019.3 (JetBrains) as the development environment, and the experiments ran on an NVIDIA GeForce GTX 1080 Ti GPU. In the following, we present the experimental design and results for each of the three equations.

#### *3.1. Wave Equation*

This section presents an experimental study of the wave equation using a physics-informed neural network. The wave equation is a typical hyperbolic partial differential equation containing second-order partial derivatives of the independent variables. In physics, it describes the propagation of a wave through a medium and is used to study many kinds of wave phenomena; it appears in fields such as acoustic wave propagation, radio communication, and seismic wave propagation [51–53], and its study is therefore of great importance. In this study, we choose a one-dimensional wave equation [54] for our experiments, defined in mathematical form as follows:

$$u\_{tt} - c^2 u\_{xx} = 0, \quad x \in [0, 1], \quad t \in [0, 1] \tag{9}$$

where *u* is a function of the spatial variable *x* and time *t*, and the constant *c* is the wave propagation velocity, set to 1 in this study. For this wave equation, the initial conditions and the homogeneous Dirichlet boundary conditions are given as follows:

$$\begin{cases} u(0, \mathbf{x}) = \frac{1}{2} \sin(\pi \mathbf{x}) \\ u\_t(0, \mathbf{x}) = \pi \sin(3\pi \mathbf{x}) \\ u(t, 0) = u(t, 1) = 0 \end{cases} \tag{10}$$

The true solution of the above equation is $u(t, x) = \frac{1}{2} \sin(\pi x) \cos(\pi t) + \frac{1}{3} \sin(3\pi x) \sin(3\pi t)$, which we use to generate the data. The initial conditions, boundary conditions, and some random data in the space-time domain are used as training data for the neural network model. To test the performance of the trained model, we use it to make predictions and compare them with the true solution of the partial differential equation. The specific experimental setup and procedure are as follows.

First, a neural network is designed to approximate the solution of the partial differential equation, denoted as *u*ˆ(*t*, *x*). The network contains six hidden layers, each with 100 neurons, and the hyperbolic tangent tanh is chosen as the activation function. In addition, a physics-informed neural network *f*(*t*, *x*) is constructed to introduce the governing information of the equation. We build the networks in TensorFlow; as a widely used deep learning framework, it provides sophisticated automatic differentiation, which makes it easy to introduce the information of the equations. Specifically, the physics-informed neural network *f*(*t*, *x*) is defined as

$$f(t, x) := u\_{tt} - u\_{xx} \tag{11}$$

The next main task is to train the parameters of the neural network *u*ˆ(*t*, *x*) and *f*(*t*, *x*). We continuously optimize the parameters by minimizing the mean square error loss to obtain the optimal parameters.

$$J(\theta) = MSE\_u + MSE\_f \tag{12}$$

where

$$MSE\_u = \frac{1}{N\_u} \sum\_{i=1}^{N\_u} \left| u \left( t\_u^i, x\_u^i \right) - u^i \right|^2 \tag{13}$$

$$MSE\_f = \frac{1}{N\_f} \sum\_{i=1}^{N\_f} \left| f \left( t\_f^i, x\_f^i \right) \right|^2 \tag{14}$$

where *MSEu* is a loss function constructed from observations of the initial and boundary conditions, and *MSEf* is a loss function based on the partial differential equation, which introduces the physical information. Specifically, $\{(t\_u^i, x\_u^i, u^i)\}\_{i=1}^{N\_u}$ denotes the initial and boundary training data of *u*(*t*, *x*), with *Nu* the number of such data. In addition, $\{(t\_f^i, x\_f^i)\}\_{i=1}^{N\_f}$ denotes the training points in the spatio-temporal domain, with *Nf* the corresponding number of points. To fully exploit the physical information embedded in the equation, we select data in the spatio-temporal domain to train the neural network. The training points in the spatio-temporal domain are chosen randomly, with *Nf* = 40,000. The total number of training data from the initial and boundary conditions is comparatively small; the expected accuracy is achieved with *Nu* = 300, and this selection is also random. During optimization, we set the learning rate to 0.001 and, to balance convergence speed and global convergence, ran L-BFGS for 30,000 epochs and then continued the optimization with Adam until convergence. We used the Glorot normal initializer [55] for initialization. In this experiment, training the model took approximately fifteen minutes. After training, we tested the model. Figure 3 shows the prediction of the trained neural network model; the predicted solution is quite complex. We choose different moments to compare the prediction with the exact solution to test its accuracy.
Figure 4 shows the comparison between the exact solution and the prediction at times *t* = 0.2, 0.5, 0.8. As can be seen, the predictions of the neural network model agree very well with the exact solutions, indicating that the constructed model has a good ability to solve partial differential equations. In addition, the relative L2 error of this example is $5.16 \times 10^{-4}$, which further validates the effectiveness of the method. Although the solution of the chosen partial differential equation is complex, the neural network model can still approximate a result very close to the true solution from the training data, indicating that physics-informed neural networks have great potential and value, and are worthy of further research.
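Both the closed-form solution of Section 3.1 and the relative L2 metric are easy to reproduce. The sketch below (pure NumPy; `u_exact` follows the formula stated above, and the error is computed with the conventional definition ‖*û* − *u*‖₂/‖*u*‖₂, which we assume is the one used in the paper):

```python
import numpy as np

def u_exact(t, x):
    """True solution used in Section 3.1 to generate the training data."""
    return (0.5 * np.sin(np.pi * x) * np.cos(np.pi * t)
            + np.sin(3 * np.pi * x) * np.sin(3 * np.pi * t) / 3.0)

def relative_l2_error(u_pred, u_true):
    """Relative L2 error, the evaluation metric reported in the text."""
    return np.linalg.norm(u_pred - u_true) / np.linalg.norm(u_true)

x = np.linspace(0.0, 1.0, 101)
# Initial condition u(0, x) = (1/2) sin(pi x) and boundaries u(t, 0) = u(t, 1) = 0.
assert np.allclose(u_exact(0.0, x), 0.5 * np.sin(np.pi * x))
assert abs(u_exact(0.37, 0.0)) < 1e-12 and abs(u_exact(0.37, 1.0)) < 1e-12

# A prediction deviating by ~0.1% everywhere yields a relative L2 error of 1e-3,
# the order of magnitude reported for the trained PINN.
u_true = u_exact(0.5, x)
err = relative_l2_error(u_true * (1 + 1e-3), u_true)
```

This kind of check is useful before training: it confirms that the data generator really satisfies the stated initial and boundary conditions.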

**Figure 3.** Solution of the wave equation given by physics-informed neural networks.

**Figure 4.** Comparison of the prediction given by physics-informed neural networks with the exact solution.

### *3.2. KdV-Burgers Equation*

We have studied the KdV–Burgers equation to further analyze the ability of physics-informed neural networks to solve complex partial differential equations. The KdV–Burgers equation is a nonlinear partial differential equation containing higher-order derivatives that has been of interest to many researchers [56,57]. Today, the KdV–Burgers equation is widely studied and applied in many fields, such as the study of the flow of liquids containing bubbles, the flow of liquids in elastic tubes, and other problems [58,59]. In mathematical form, the KdV–Burgers equation is defined, as follows:

$$
u\_t + \alpha u u\_x + \beta u\_{xx} + \gamma u\_{xxx} = 0 \tag{15}
$$

where *α*, *β*, and *γ* are all non-zero real constants, i.e., *αβγ* ≠ 0. Equation (15) reduces to the Burgers equation [60] or the Korteweg–de Vries (KdV) equation [61] in special cases. Specifically, when *γ* = 0, Equation (15) simplifies to the Burgers equation.

$$
u\_t + \alpha u u\_x + \beta u\_{xx} = 0 \tag{16}
$$

The Burgers equation is a second-order nonlinear partial differential equation, which is used to simulate the propagation and reflection of shock waves. This equation is used in various fields of research, such as fluid dynamics and nonlinear acoustics (NLA) [60,62]. Besides, Equation (15) becomes the Korteweg-de Vries (KdV) equation when *β* is zero.

$$
u\_t + \alpha u u\_x + \gamma u\_{xxx} = 0 \tag{17}
$$

The Korteweg–de Vries (KdV) equation was first introduced in 1895 by Korteweg and de Vries. It is a very important equation, both mathematically and practically, for describing small-amplitude shallow-water waves, ion-acoustic waves, and fluctuation phenomena in biological and physical systems [63,64]. Unlike the Burgers equation, it contains no dissipation and can explain the existence of solitary waves, which makes it of great interest to physicists and mathematicians. The study of this equation is therefore of great scientific significance and research value.

The KdV–Burgers equation can be viewed as a combination of the Korteweg-de Vries equation and the Burgers equation, containing the nonlinearity *uux*, the dispersion *uxxx*, and the dissipation *uxx*, with high complexity. The equation has been applied in many fields and it has received a great deal of attention from many researchers. In this section, we use physics-informed neural networks to develop new methods for solving the KdV–Burgers equation.

In this experiment, an important task is to construct a high-quality training dataset based on the partial differential equation. For Equation (15), the values of *α*, *β*, *γ* are set to 1, −0.075, and *π*/1000, respectively, and deterministic initial conditions are given. In mathematical form, the nonlinear KdV–Burgers equation with periodic boundary conditions studied in this section is defined as follows

$$\begin{aligned} &u\_t + uu\_x - 0.075u\_{xx} + (\pi/1000)u\_{xxx} = 0, \quad x \in [-1, 1], \quad t \in [0, 1] \\ &u(0, x) = e^{0.005 \cos(\pi x)} \sin(\pi x) \\ &u(t, -1) = u(t, 1) \\ &u\_x(t, -1) = u\_x(t, 1) \end{aligned} \tag{18}$$

For Equation (18), we simulate it using conventional spectral methods, relying on the Chebfun package [65] in the implementation. Specifically, starting from the initial and periodic boundary conditions, we integrate Equation (18) from the initial moment *t* = 0 to the final time *t* = 1.0 with a time step $\Delta t = 10^{-6}$. We use a fourth-order explicit Runge–Kutta temporal integrator and a spectral Fourier discretization with 512 modes to ensure the accuracy of the integration.
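Chebfun is a MATLAB package; for readers without MATLAB, the same kind of reference data can be produced with a Fourier pseudo-spectral discretization and a classical RK4 stepper in NumPy. The sketch below is a deliberately coarse, hypothetical analogue of the paper's much finer simulation (128 modes instead of 512, a larger time step, a short final time, and no dealiasing):

```python
import numpy as np

# Periodic domain x in [-1, 1), initial condition from Equation (18).
N = 128
x = np.linspace(-1.0, 1.0, N, endpoint=False)
k = np.pi * np.fft.fftfreq(N) * N            # wavenumbers pi*n for period 2
u = np.exp(0.005 * np.cos(np.pi * x)) * np.sin(np.pi * x)

def rhs(u):
    """u_t = -u u_x + 0.075 u_xx - (pi/1000) u_xxx via spectral derivatives."""
    u_hat = np.fft.fft(u)
    ux = np.fft.ifft(1j * k * u_hat).real
    uxx = np.fft.ifft(-(k ** 2) * u_hat).real
    uxxx = np.fft.ifft(-1j * k ** 3 * u_hat).real
    return -u * ux + 0.075 * uxx - (np.pi / 1000.0) * uxxx

dt = 5e-5
for _ in range(200):                         # classical RK4 up to t = 0.01
    k1 = rhs(u)
    k2 = rhs(u + 0.5 * dt * k1)
    k3 = rhs(u + 0.5 * dt * k2)
    k4 = rhs(u + dt * k3)
    u = u + (dt / 6.0) * (k1 + 2 * k2 + 2 * k3 + k4)
```

The step size is chosen so that the explicit RK4 scheme remains stable for the dispersive term at this resolution; a production run would use the finer discretization described above.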

After obtaining the high-resolution dataset, we constructed a neural network to approximate the solution of Equation (18), denoted as *u*ˆ(*t*, *x*). The network contains seven hidden layers, each with 120 neurons, and the hyperbolic tangent tanh is chosen as the activation function. In addition, the physics-informed neural network *f*(*t*, *x*) is constructed using the automatic differentiation of TensorFlow to introduce the governing information of the equation. Specifically, *f*(*t*, *x*) is defined as follows

$$f(t, x) := u\_t + uu\_x - 0.075u\_{xx} + (\pi/1000)u\_{xxx} \tag{19}$$

The next main task is to train the parameters of the neural network *u*ˆ(*t*, *x*) and *f*(*t*, *x*). We continuously optimize the parameters by minimizing the mean square error loss to obtain the optimal parameters.

$$J(\theta) = MSE\_u + MSE\_f \tag{20}$$

where

$$MSE\_u = \frac{1}{N\_u} \sum\_{i=1}^{N\_u} \left| u\left(t\_u^i, x\_u^i\right) - u^i\right|^2 \tag{21}$$

$$MSE\_f = \frac{1}{N\_f} \sum\_{i=1}^{N\_f} \left| f \left( t\_f^i, x\_f^i \right) \right|^2 \tag{22}$$

Similar to the previous experiment, *MSEu* is a loss function constructed from observations of the initial and boundary conditions, and *MSEf* is a loss function that introduces the physical information of the partial differential equation. Specifically, $\{(t\_u^i, x\_u^i, u^i)\}\_{i=1}^{N\_u}$ corresponds to the initial and boundary condition data, and $\{(t\_f^i, x\_f^i)\}\_{i=1}^{N\_f}$ corresponds to the data in the space-time domain. In this experiment, the training data are also randomly selected from the generated dataset, with *Nf* = 30,000 points in the space-time domain and *Nu* = 400 points on the initial and boundary conditions. We use the Glorot normal initializer for initialization. During optimization, we set the learning rate to 0.001 and, to ensure global convergence and speed up the convergence process, ran L-BFGS for 30,000 epochs and then used Adam to continue optimizing until convergence. In this experiment, training the model took approximately twelve minutes. After training the neural network model, we tested its performance. Figure 5 shows the predictions of the neural network model; the resulting solution is quite complex. We compare the prediction with the exact solution at different moments to test its accuracy. Figure 6 shows the comparison between the exact solution and the prediction at *t* = 0.25, 0.5, 0.75. As can be seen, the predictions of the neural network model and the exact solutions are highly consistent, indicating that the constructed model can solve the KdV–Burgers equation well.
In addition, the relative L2 error of this example is $4.79 \times 10^{-4}$, which further validates the effectiveness of the method. Despite the high complexity of the KdV–Burgers equation, the neural network model can still obtain results very close to the true solution from the training data, which again shows that the method has great potential and value and is worthy of further research.

**Figure 5.** Solution of the Korteweg-de Vries (KdV)–Burgers equation given by physics-informed neural networks.

**Figure 6.** Comparison of the prediction given by physics-informed neural networks with the exact solution.

### *3.3. Two-Soliton Solution of the Korteweg-De Vries Equation*

As introduced in Section 3.2, the KdV equation is an important equation possessing soliton solutions. Its study matters for understanding the nature of solitons and the interaction of two or more solitons. Many scholars have studied the multi-soliton solutions of KdV equations [66] and, in this section, we employ the proposed physics-informed neural networks to study the following KdV equation:

$$
u\_t + 6uu\_x + u\_{xxx} = 0, \quad -\infty < x < \infty \tag{23}
$$

When given the initial condition $u(0, x) = 6\,\mathrm{sech}^2 x$, Equation (23) has the following two-soliton solution *u*(*x*, *t*)

$$u(\mathbf{x},t) = 12\frac{3+4\cosh(2\mathbf{x}-8t)+\cosh(4\mathbf{x}-64t)}{(3\cosh(\mathbf{x}-28t)+\cosh(3\mathbf{x}-36t))^2} \tag{24}$$
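The closed form in Equation (24) can be checked numerically: at *t* = 0 it must collapse to the initial condition 6 sech²*x*. A quick pure-Python sketch (function name hypothetical):

```python
import math

def u_two_soliton(x, t):
    """Two-soliton solution of the KdV equation, Equation (24)."""
    num = 12.0 * (3.0 + 4.0 * math.cosh(2 * x - 8 * t)
                  + math.cosh(4 * x - 64 * t))
    den = (3.0 * math.cosh(x - 28 * t) + math.cosh(3 * x - 36 * t)) ** 2
    return num / den

# At t = 0 the formula reduces to u(0, x) = 6 sech^2(x).
for x in (-2.0, -0.5, 0.0, 0.7, 1.5):
    assert abs(u_two_soliton(x, 0.0) - 6.0 / math.cosh(x) ** 2) < 1e-9
```

The same function can be used to generate the training and test data over *x* ∈ [−20, 20], *t* ∈ [−1, 1] described below.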

First, because this solution is valid for both positive and negative *t*, we generated data from the true solution for *x* ∈ [−20, 20] and *t* ∈ [−1, 1]. Some initial data and some random data in the space-time domain are selected as training data. Next, we constructed a neural network to approximate the solution of Equation (23), denoted as *u*ˆ(*t*, *x*). The network contains seven hidden layers, each with 100 neurons, and the hyperbolic tangent tanh is chosen as the activation function. In addition, the physics-informed neural network *f*(*t*, *x*) is constructed using the automatic differentiation of TensorFlow to introduce the governing information of the equation. Specifically, *f*(*t*, *x*) is defined as follows:

$$f(t, x) := u\_t + 6uu\_x + u\_{xxx} \tag{25}$$

Next, we obtain the optimal parameters of the neural network *u*ˆ(*t*, *x*) and *f*(*t*, *x*) by minimizing the following mean square error loss.

$$J(\theta) = MSE\_u + MSE\_f \tag{26}$$

where

$$MSE\_u = \frac{1}{N\_u} \sum\_{i=1}^{N\_u} \left| u \left( t\_u^i, x\_u^i \right) - u^i \right|^2 \tag{27}$$

$$MSE\_f = \frac{1}{N\_f} \sum\_{i=1}^{N\_f} \left| f \left( t\_f^i, x\_f^i \right) \right|^2 \tag{28}$$

Similar to the previous experiments, the loss function *MSEu* is constructed from the observation data of the initial and boundary conditions, while the loss function *MSEf* is constructed from the physical information of the partial differential equation. Specifically, $\{(t\_u^i, x\_u^i, u^i)\}\_{i=1}^{N\_u}$ corresponds to the initial and boundary condition data, and $\{(t\_f^i, x\_f^i)\}\_{i=1}^{N\_f}$ corresponds to the data in the space-time domain. In this experiment, we randomly select training data from the dataset: 60,000 points in the spatio-temporal domain and 400 points satisfying the initial and boundary conditions. We used the Glorot normal initializer for initialization and set the learning rate to 0.001. To ensure global convergence and speed up the convergence process, we ran L-BFGS for 30,000 epochs and then used Adam to continue the optimization until convergence. In this experiment, training the model took about eighteen minutes. After training the neural network model, we used it to make predictions and compared the results with the true solution. Figure 7 shows the prediction of the KdV equation given by physics-informed neural networks, and Figure 8 compares the prediction with the exact solution at *t* = −0.5, 0, 0.5 to test the performance of the model. As can be seen from Figure 8, the prediction is very close to the exact solution. It can also be seen that the tall wave propagates faster than the short one, which is consistent with the nature of the solitary wave solutions of the KdV equation. Thus, the neural network model can simulate the KdV equation well. The relative L2 error in this case is $4.62 \times 10^{-3}$, further validating the effectiveness of the method. Because physics-informed neural networks simulate the solitary wave solutions of the KdV equation well, we will apply this method to the study of multi-soliton solutions of other equations, such as the Boussinesq equation [66,67], in future work.

**Figure 7.** Two-soliton solution of the KdV equation given by physics-informed neural networks.

**Figure 8.** Comparison of the prediction given by physics-informed neural networks with the exact solution.

### **4. Discussions**

In this study, three partial differential equations are investigated using physics-informed neural networks. Based on the characteristics of the wave equation, the KdV-Burgers equation, and the KdV equation, the corresponding neural network models were constructed and each model was experimentally analyzed. We compare the predictions of the physics-informed neural network with the true solutions of the equations and derive the following discussion.


Physics-informed neural networks do not need to consider the discretization of partial differential equations and can learn their solutions from a small amount of data. It is well known that popular computational fluid dynamics methods, such as finite difference methods, require consideration of the discretization of the equations. In practice, many engineering applications also face this issue; for example, various numerical weather prediction models treat discretization schemes as an important part of their design. Physics-informed neural networks are well suited to sidestep this problem and, thus, this approach may have important implications for the development of computational fluid dynamics and even scientific computing.


### **5. Conclusions**

This paper introduces a method for solving partial differential equations with neural networks that fuse physical information. The method is an improvement on previous physics-informed neural networks: the physical laws contained in the partial differential equations are introduced into the neural networks as a regularization, which helps the network learn the solutions of the equations from a limited number of observations. Using the presented method, we performed an experimental analysis of three important partial differential equations and achieved good results, owing to the powerful function approximation ability of neural networks and the physical information contained in the equations. We believe that physics-informed neural networks will profoundly influence the study of solving partial differential equations and even promote the development of the whole field of scientific computing. However, open problems remain, such as how to better introduce physical information into neural networks and the possible non-convergence of the loss function optimization; these are the next focus of our research. In the future, we will also analyze the performance differences between PINN-based methods and FEM methods, comparing their accuracy, running time, and so on.

**Author Contributions:** Conceptualization, X.C. and Y.G.; methodology, Y.G. and X.C.; validation, M.G.; investigation, Y.G. and M.G.; writing—original draft preparation, Y.G.; writing—review and editing, B.L., M.G. and X.C.; visualization, Y.G. and M.G.; supervision, X.C. and B.L.; project administration, X.C. and B.L. All authors have read and agreed to the published version of the manuscript.

**Funding:** This research was funded by National Natural Science Foundation of China (Grant No.41475094) and the National Key R&D Program of China (Grant No.2018YFC1506704).

**Conflicts of Interest:** The authors declare no conflict of interest.

### **Abbreviations**

The following abbreviations are used in this manuscript:


### **References**


© 2020 by the authors. Licensee MDPI, Basel, Switzerland. This article is an open access article distributed under the terms and conditions of the Creative Commons Attribution (CC BY) license (http://creativecommons.org/licenses/by/4.0/).

## *Article* **Answer Set Programming for Regular Inference**

**Wojciech Wieczorek 1,\*, Tomasz Jastrzab <sup>2</sup> and Olgierd Unold <sup>3</sup>**


Received: 6 October 2020; Accepted: 25 October 2020; Published: 30 October 2020

**Abstract:** We propose an approach to non-deterministic finite automaton (NFA) inductive synthesis that is based on answer set programming (ASP) solvers. To that end, we explain how an NFA and its response to input samples can be encoded as rules in a logic program. We then ask an ASP solver to find an answer set for the program, which we use to extract the automaton of the required size. We conduct a series of experiments on some benchmark sets, using the implementation of our approach. The results show that our method outperforms, in terms of CPU time, a SAT approach and other exact algorithms on all benchmarks.

**Keywords:** answer set programming; non-deterministic automata induction; grammatical inference

### **1. Introduction**

The main problem investigated in this paper is as follows. Given a finite alphabet Σ, two finite subsets *<sup>S</sup>*+, *<sup>S</sup>*<sup>−</sup> ⊆ <sup>Σ</sup>∗, and an integer *<sup>k</sup>* > 0, find a *<sup>k</sup>*-state NFA *<sup>A</sup>* that recognizes a language *<sup>L</sup>* ⊆ <sup>Σ</sup><sup>∗</sup> such that *S*<sup>+</sup> ⊆ *L* and *S*<sup>−</sup> ⊆ Σ<sup>∗</sup> − *L*. In other words, we are dealing with the process of learning a finite state machine based on a set of labeled strings, thus building a model reflecting the characteristics of the observations. Machine learning of automata and grammars has a wide range of applications in such fields as syntactic pattern recognition, computational biology, systems modeling, natural language acquisition, and knowledge discovery (see [1–5]).

It is well known that NFA or regular expression minimization is computationally hard: it is PSPACE-complete [6]. Moreover, even if we specify the regular language by a deterministic finite automaton (DFA), the problem remains PSPACE-complete [7]. Angluin [8] showed that there is no polynomial-time algorithm for finding the shortest compatible regular expression for arbitrary given data (unless P = NP). Thus, we conjecture that the complexity of inferring a minimal-size NFA that matches a labeled set of input strings is probably exponential.

For the deterministic case, the problem is NP-complete [9]. Besides, in contrast to NFAs, for a given regular language there is always exactly one minimum-size DFA (i.e., there is no other non-isomorphic DFA with the same minimal number of states). Is NFA induction therefore harder than DFA induction? To answer this, let us compare the problem search space sizes, expressed as the number of automata with a fixed number of states. Let *c* be the size of the alphabet and *k* the number of automaton states. The number of pairwise non-isomorphic minimal *k*-state DFAs over a *c*-letter alphabet is of order $k 2^{k-1} k^{(c-1)k}$. The number of NFAs such that every state is reachable from the start state is of order $2^{ck^2}$ [10]. Thus, switching from determinism to non-determinism increases the search space enormously. On the other hand, it is well known that NFAs are more compact: a DFA can even be exponentially larger than a corresponding NFA for a given language.
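To get a feel for the gap, the two orders quoted above can be evaluated for a small instance. The sketch below uses the asymptotic orders from the text as if they were exact counts, so the numbers are only indicative:

```python
def dfa_order(k, c):
    """Order of pairwise non-isomorphic minimal k-state DFAs: k * 2^(k-1) * k^((c-1)k)."""
    return k * 2 ** (k - 1) * k ** ((c - 1) * k)

def nfa_order(k, c):
    """Order of NFAs with every state reachable from the start state: 2^(c k^2)."""
    return 2 ** (c * k * k)

# Five states over a binary alphabet: the NFA space is over a billion times larger.
ratio = nfa_order(5, 2) // dfa_order(5, 2)
```

Even at *k* = 5 the non-deterministic search space dwarfs the deterministic one, which is the point made above.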

The purpose of the present proposal is twofold. The first objective is to devise an algorithm for the smallest non-deterministic automaton problem. It entails preparing logical rules (this set of rules will be called an AnsProlog program) before starting the search process. The second objective is to show how ASP solvers help to tackle the regular inference problem for large instances and to compare our approach with the existing ones. In particular, we will refer to the following exact NFA identification methods [11]:


We will also refer to a SAT encoding given in [5]. All four above-mentioned methods and a SAT encoding are thoroughly described in Section 4.2. To enable comparisons with other methods in the future, the Python implementation of our approach is made available via GitHub. The Python scripting language is used only for generating the appropriate AnsProlog facts and running Clingo, an ASP solver.

Another line of research concerns the induction of DFAs. The original idea of SAT encoding in this context comes from the work made by Heule and Verwer [12]. Their work, in turn, was based on the idea of transformation from DFA identification into graph coloring, which was proposed by Coste and Nicolas [13]. Zakirzyanov et al. [14] proposed BFS-based symmetry breaking predicates, instead of the original max-clique predicates, which improved the translation-to-SAT technique. The improvement was demonstrated with the experiments on randomly generated input data. The core idea is as follows. Consider a graph *G*, the vertices of which are the states of an initial automaton and there are edges between vertices that cannot be merged. Finding minimum-size DFA is equivalent to a graph coloring with a minimum number of colors. The graph coloring constraints, on the other hand, can be efficiently encoded into SAT according to Walsh [15].

In a more recent approach, satisfiability modulo theories (SMT) is explored. Suppose that *A* = (Σ, *Q* = {0, 1, ... , *K* − 1}, *s* = 0, *F*, *δ*) is a target automaton and *P* is the set of all prefixes of *S*<sup>+</sup> ∪ *S*<sup>−</sup>. An SMT encoding proposed by Smetsers et al. [16] uses four functions: *δ* : *Q* × Σ → *Q*, *m* : *P* → *Q*, *λ<sup>A</sup>* : *Q* → {⊥, ⊤}, *λ<sup>T</sup>* : *S*<sup>+</sup> ∪ *S*<sup>−</sup> → {⊥, ⊤}, where {⊥, ⊤} represents logical {false, true}, and the following five constraints:

$$m(\varepsilon) = 0,$$

$$\forall \mathbf{x} \in \mathcal{S}\_{+} \cup \mathcal{S}\_{-} \quad \mathbf{x} \in \mathcal{S}\_{+} \iff \lambda^{T}(\mathbf{x}) = \top,$$

$$\forall \mathbf{x} a \in P \colon \mathbf{x} \in \Sigma^{\*}, a \in \Sigma \quad \delta(m(\mathbf{x}), a) = m(\mathbf{x} a),$$

$$\forall \mathbf{x} \in \mathcal{S}\_{+} \cup \mathcal{S}\_{-} \quad \lambda^{A}(m(\mathbf{x})) = \lambda^{T}(\mathbf{x}),$$

$$\forall q \in \mathcal{Q} \quad \forall a \in \Sigma \quad \bigvee\_{r \in \mathcal{Q}} \delta(q, a) = r.$$

They implemented the encodings using Z3Py, the Python front-end of an efficient SMT solver Z3.

This paper is organized into five sections. In Section 2, we present necessary definitions and facts originating from automata, formal languages, and declarative problem-solving. Section 3 describes our inference algorithm based on solving an AnsProlog program. Section 4 shows the experimental results of our approach and describes in detail all reference methods. Concluding comments are made in Section 5.

### **2. Preliminaries**

We assume the reader to be familiar with basic regular language and automata theory, for example, from [17], so that we introduce only some notations and notions used later in the paper.

#### *2.1. Words and Languages*

An *alphabet* Σ is a finite, non-empty set of symbols. A *word w* is a finite sequence of symbols chosen from an alphabet. The length of a word *w* is denoted by |*w*|. The *empty word ε* is the word of length zero. Let *x* and *y* be words. Then *xy* denotes the *concatenation* of *x* and *y*, that is, the word formed by making a copy of *x* and following it by a copy of *y*. As usual, Σ∗ denotes the set of words over Σ. A word *w* is called a *prefix* of a word *u* if there is a word *x* such that *u* = *wx*; it is a *proper* prefix if *x* ≠ *ε*. A set of words taken from some Σ∗, where Σ is a particular alphabet, is called a *language*.

A *sample S* is an ordered pair *S* = (*S*+, *S*−) where *S*+, *S*<sup>−</sup> are finite languages with an empty intersection (i.e., having no common word). *S*<sup>+</sup> is called the *positive part of S* (*examples*), and *S*<sup>−</sup> the *negative part of S* (*counter-examples*).

### *2.2. Non-Deterministic Finite Automata*

A *non-deterministic finite automaton* (NFA) is a five-tuple *A* = (Σ, *Q*,*s*, *F*, *δ*) where Σ is an alphabet, *Q* is a finite set of states, *s* ∈ *Q* is the initial state, *F* ⊆ *Q* is a set of final states, and *δ* is a relation from *Q* × Σ to *Q*. Members of *δ* are called *transitions*. A transition ((*q*, *a*),*r*) ∈ *δ* with *q*,*r* ∈ *Q* and *a* ∈ Σ, is usually written as *r* ∈ *δ*(*q*, *a*). Relation *δ* specifies the moves: the meaning of *r* ∈ *δ*(*q*, *a*) is that automaton *A* in the current state *q* reads *a* and can move to state *r*. If for given *q* and *a* there is no such *r* that ((*q*, *a*),*r*) ∈ *δ*, the automaton stops and we can assume it enters the rejecting state. Moving into a state that is not final is also regarded as rejecting but it may be just an intermediate state.

It is convenient to define δ̄ as a relation from *Q* × Σ∗ to *Q* by the following recursion: ((*t*, *ε*), *t*) ∈ δ̄ for every state *t* ∈ *Q*, and ((*q*, *ya*), *r*) ∈ δ̄ if ((*q*, *y*), *p*) ∈ δ̄ and ((*p*, *a*), *r*) ∈ *δ* for some *p* ∈ *Q*, where *a* ∈ Σ and *y* ∈ Σ∗. The *language accepted* by an automaton *A* is then

$$L(A) = \{ x \in \Sigma^* \mid \text{there is } q \in F \text{ such that } ((s, x), q) \in \bar{\delta} \}. \tag{1}$$

Two automata are *equivalent* if they accept the same language.

Let *A* = (Σ, *Q*, *s*, *F*, *δ*) be an NFA. Then we will say that *x* ∈ Σ∗ is: (a) *recognized by accepting* (or *accepted*) if there is *q* ∈ *F* such that ((*s*, *x*), *q*) ∈ δ̄, (b) *recognized by rejecting* if there is *q* ∈ *Q* − *F* such that ((*s*, *x*), *q*) ∈ δ̄, and (c) *rejected* if it is not accepted.

### *2.3. Answer Set Programming*

Let us briefly introduce the idea of answer set programming (ASP). Readers interested in the details of ASP, alternative definitions, and the formal specification of AnsProlog are referred to the handbooks [18–20].

Let A be a set of *atoms*. A *rule* is of the form:

$$a \leftarrow b_1, \ldots, b_k, \sim c_1, \ldots, \sim c_m. \tag{2}$$

where *a*, the *b<sub>i</sub>*'s, and the *c<sub>i</sub>*'s are atoms and *k*, *m* ≥ 0. The *head* of the rule, *a*, may be absent. The part on the right of '←' is called the *body* of the rule. The symbol ∼ is called *default negation* and, by analogy to database systems, in logic programming it refers to the absence of information. Informally, *a* ← ... , ∼*b* means: if ... holds and there is no evidence for *b*, then *a* should be included in a solution. A *program* Π is a finite set of rules.

Let *R* be the set of rules of the form:

$$a \leftarrow b\_1, \ldots, b\_k. \tag{3}$$

and A be the set of atoms occurring in *R*. The *model* of a set *R* of rules without negated atoms is a subset *M* ⊆ A which fulfills the following condition: for every rule *a* ← *b*1, ... , *bk*. in *R*, if {*b*1, ... , *bk*} ⊆ *M*, then *a* ∈ *M*.
Alternatively, if all atoms were treated as Boolean variables (i.e., presence is true, absence is false), *M* would be the model of an *R* exactly when all rules (i.e., clauses) are satisfied.

The semantics of a program is defined by an answer set as follows. The *reduct* Π*<sup>X</sup>* of a program Π relative to a set *X* of atoms is defined by

$$\Pi^X = \{a \leftarrow b_1, \ldots, b_k. \mid a \leftarrow b_1, \ldots, b_k, \sim c_1, \ldots, \sim c_m. \in \Pi \text{ and } \{c_1, \ldots, c_m\} \cap X = \emptyset\}. \tag{4}$$

The ⊆-smallest model of Π*<sup>X</sup>* is denoted by Cn(Π*<sup>X</sup>*). A set *X* of atoms is an *answer set* of Π if *X* = Cn(Π*<sup>X</sup>*).

For the sake of simplicity, AnsProlog programs are written using variables (by convention, variables start with uppercase letters). Such programs are then grounded, i.e., transformed to programs with no variables, by applying a Herbrand substitution. Note, however, that clever grounding discards rules that are redundant, i.e., that can never apply, because some atoms in their bodies have no possibility to be derived [19]. For example, the program:

*el*(*a*) ←. *el*(*b*) ←. *equal*(*L*, *L*) ← *el*(*L*). *neq*(*L*, *Y*) ← *el*(*L*), *el*(*Y*), ∼*equal*(*L*,*Y*).

can be transformed to Π:

*el*(*a*) ←. *el*(*b*) ←. *equal*(*a*, *a*) ← *el*(*a*). *equal*(*b*, *b*) ← *el*(*b*). *neq*(*a*, *a*) ← *el*(*a*), *el*(*a*), ∼*equal*(*a*, *a*). *neq*(*a*, *b*) ← *el*(*a*), *el*(*b*), ∼*equal*(*a*, *b*). *neq*(*b*, *a*) ← *el*(*b*), *el*(*a*), ∼*equal*(*b*, *a*). *neq*(*b*, *b*) ← *el*(*b*), *el*(*b*), ∼*equal*(*b*, *b*).

which has a single answer set: *X* = {*equal*(*a, a*)*, equal*(*b, b*)*, el*(*a*)*, el*(*b*)*, neq*(*b, a*)*, neq*(*a, b*)}. A reduct Π*<sup>X</sup>* becomes:

*el*(*a*) ←. *el*(*b*) ←. *equal*(*a*, *a*) ← *el*(*a*). *equal*(*b*, *b*) ← *el*(*b*). *neq*(*a*, *b*) ← *el*(*a*), *el*(*b*). *neq*(*b*, *a*) ← *el*(*b*), *el*(*a*).

Its minimal model Cn(Π*X*) is just *X*. In other words, a set *X* of atoms is an answer set of a logic program Π if: (i) *X* is a classical model of Π and (ii) all atoms in *X* are justified by some rule in Π.
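The reduct-and-check definition can be made concrete with a short Python sketch. The representation of a grounded rule as a (head, positive body, negated body) triple is our own convention for this illustration; the program and the answer set are the ones from the text above.

```python
def cn(positive_rules):
    """Least model of a positive program, by fixpoint iteration."""
    model, changed = set(), True
    while changed:
        changed = False
        for head, body in positive_rules:
            if body <= model and head not in model:
                model.add(head)
                changed = True
    return model

def reduct(program, X):
    """Gelfond-Lifschitz reduct: drop every rule whose default-negated
    atoms intersect X, then strip the negative bodies."""
    return [(head, pos) for head, pos, neg in program if not (neg & X)]

def is_answer_set(program, X):
    return cn(reduct(program, X)) == X

# The grounded program from the text: (head, positive body, negated body).
prog = [
    ("el(a)", set(), set()),
    ("el(b)", set(), set()),
    ("equal(a,a)", {"el(a)"}, set()),
    ("equal(b,b)", {"el(b)"}, set()),
    ("neq(a,a)", {"el(a)"}, {"equal(a,a)"}),
    ("neq(a,b)", {"el(a)", "el(b)"}, {"equal(a,b)"}),
    ("neq(b,a)", {"el(b)", "el(a)"}, {"equal(b,a)"}),
    ("neq(b,b)", {"el(b)"}, {"equal(b,b)"}),
]
X = {"el(a)", "el(b)", "equal(a,a)", "equal(b,b)", "neq(a,b)", "neq(b,a)"}
```

Here `is_answer_set(prog, X)` holds, while adding the unjustified atom *neq*(*a*, *a*) to *X* breaks the equality, illustrating condition (ii).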

Recently, Answer Set Programming has emerged as a declarative problem-solving paradigm. This particular way of programming in AnsProlog is well-suited for modeling and solving problems that involve common sense reasoning. It has been fruitfully used in a range of applications.

Early ASP solvers used backtracking to find solutions. With the evolution of Boolean SAT solvers, several ASP solvers were built on top of them. The approach taken by these solvers was to convert the ASP program into SAT propositions, apply the SAT solver, and then convert the solutions back to ASP form. Newer systems, such as Clasp (which is a part of the Clingo solver, https://potassco.org/clasp/), take advantage of conflict-driven algorithms inspired by SAT, without the complete conversion into a Boolean-logic form. These approaches improve performance significantly, often by an order of magnitude, over earlier backtracking algorithms [21].

### **3. Proposed Encoding for the Induction of NFA**

Our translation reduces NFA identification to an AnsProlog program. Suppose we are given a sample *S* over an alphabet Σ and a positive integer *k*. We want to find a *k*-state NFA *A* = (Σ, {*q*0, *q*1, ... , *qk*−1}, *q*0, *F*, *δ*) such that every *w* ∈ *S*<sup>+</sup> is recognized by accepting and every *w* ∈ *S*<sup>−</sup> is recognized by rejecting. The parameter *k* can be regarded as the degree of data generalization. The smallest *k*, say *k*0, for which our logic program has an answer set will give the most general automaton. As *k* increases, we obtain a set of nested languages, the largest for *k*<sup>0</sup> and the smallest for some *km* ≥ *k*0. Usually, the running time for *k* > *k*<sup>0</sup> is shorter than for *k*0.

Let Pref(*S*) be the set of all prefixes of *S*<sup>+</sup> ∪ *S*−. The relationship between an automaton *A* and a sample *S* in terms of ASP is constrained as shown below in seven groups of rules. In rules (5)–(24) the following convention for naming variables is used: *P* stands for a prefix, *N* stands for a number (state index), *I*, *J*, and *M* also represent state indexes, *C* stands for a character (the element of alphabet), *W* stands for word (which is also a prefix), *U* represents another prefix.

1. We have the following domain specification, i.e., our AnsProlog facts.

$$q(i) \leftarrow . \quad \text{for all } i \in \{0, 1, \ldots, k - 1\}. \tag{5}$$

$$symbol(a) \leftarrow . \quad \text{for all } a \in \Sigma. \tag{6}$$

$$prefix(p) \leftarrow . \quad \text{for all } p \in \mathrm{Pref}(S). \tag{7}$$

$$positive(s) \leftarrow . \quad \text{for all } s \in S_+. \tag{8}$$

$$negative(s) \leftarrow . \quad \text{for all } s \in S_-. \tag{9}$$

$$join(u, a, v) \leftarrow . \quad \text{for all } u, v \in \mathrm{Pref}(S) \text{ and } a \in \Sigma \text{ such that } ua = v. \tag{10}$$

Facts (5) and (6) define the set of states *Q* and the input alphabet Σ, while facts (7)–(9) describe the input sample. In particular, they define the prefixes as well as words to be recognized by accepting and rejecting, respectively.

Finally, fact (10) defines the concatenation operation, which given prefix *u* ∈ Pref(*S*) and symbol *a* ∈ Σ produces prefix *v* ∈ Pref(*S*).

2. The next rules ensure that in an automaton *A* every prefix goes to at least one state and every state is final or not.

$$x(P, N) \leftarrow prefix(P), q(N), \sim not\_x(P, N). \tag{11}$$

$$not\_x(P, N) \leftarrow prefix(P), q(N), \sim x(P, N). \tag{12}$$

$$has\_state(P) \leftarrow prefix(P), q(N), x(P, N). \tag{13}$$

$$\leftarrow prefix(P), \sim has\_state(P). \tag{14}$$

$$final(N) \leftarrow q(N), \sim not\_final(N). \tag{15}$$

$$not\_final(N) \leftarrow q(N), \sim final(N). \tag{16}$$

Rules (11) and (12) describe the reachability of states *q* ∈ *Q* by prefixes *p* ∈ Pref(*S*). State *q* is *reachable* by prefix *p iff* the prefix can be read by following a series of transitions from state *q*<sup>0</sup> to state *q* (this series of transitions builds a *path* for prefix *p*). The unreachable states are described by the default negation rule *not*\_*x*. Clearly, for every prefix *p* ∈ Pref(*S*) and every state *q* ∈ *Q*, either (11) or (12) holds. Here *P* (a prefix) and *N* (a number, state index) are variables, which means that during the grounding they will be substituted for, respectively, every *p* ∈ Pref(*S*) because of the atom *prefix*(*P*) in the body of the rule and for every *i* ∈ {0, 1, ... , *k* − 1} because of the atom *q*(*N*) in the body of the rule. Notice that for every *p* ∈ Pref(*S*) we already have fact *prefix*(*p*) and for every *i* ∈ {0, 1, ... , *k* − 1} we already have fact *q*(*i*), which are the sources of this substitution. Rules (13) and (14) declare that for every prefix *p* ∈ Pref(*S*) there has to be some reachable state *q* ∈ *Q*. These rules follow from the fact that the members of sets *S*<sup>+</sup> and *S*<sup>−</sup> have to be recognized by accepting or rejecting, respectively. In other words, for each *w* ∈ (*S*<sup>+</sup> ∪ *S*−) there has to be at least one path in the inferred NFA.

Finally, rules (15) and (16) ensure that each state *q* ∈ *Q* is either accepting (final) or rejecting (not final). Rules such as the pair (15) and (16) are recommended in ASP textbooks for specifying that each element either has some property or does not (refer, for example, to Chapter 4 of Chitta Baral's handbook [18]).

3. For encoding transitions we will use the predicates *delta* and *not*\_*delta*.

$$delta(I, C, J) \leftarrow q(I), symbol(C), q(J), \sim not\_delta(I, C, J). \tag{17}$$

$$not\_delta(I, C, J) \leftarrow q(I), symbol(C), q(J), \sim delta(I, C, J). \tag{18}$$

Rule (17) says that if there exists a transition between a pair of states *qi*, *qj* ∈ *Q*, marked with a symbol *c* ∈ Σ then *delta*(*I*, *C*, *J*) is in the model. Otherwise, the default negation rule *not*\_*delta* applies (rule (18)).

4. Without loss of generality, we can assume that *q*<sup>0</sup> is the initial state.

$$\leftarrow \sim x(\varepsilon, 0). \tag{19}$$

$$\leftarrow x(\varepsilon, N), q(N), N \neq 0. \tag{20}$$

Rules (19) and (20) mean that only state *q*<sup>0</sup> is reachable by the empty word *ε*.

5. Every counter-example has to be recognized by rejecting.

$$\leftarrow q(N), x(W, N), final(N), negative(W). \tag{21}$$

Recall that for a headless rule (a constraint), at least one atom in its body must remain unsatisfied in any answer set. Hence, rule (21) means that there is no final state that is reachable by any word *w* ∈ *S*−.

6. Every example has to be recognized by accepting. In this rule we use an extended syntax of ASP, a choice construction. Here, it means that the number of final states *qn* for which ((*q*0, *W*), *qn*) ∈ δ̄ cannot be equal to 0 for any example *w*.

$$\leftarrow positive(W), \{final(N) : q(N), x(W, N)\} = 0. \tag{22}$$

7. Finally, there are mutual constraints between *x* and *delta* predicates.

$$x(W, M) \leftarrow q(I), q(M), join(U, C, W), x(U, I), delta(I, C, M). \tag{23}$$

$$\leftarrow join(U, C, W), q(N), x(W, N), \{delta(J, C, N) : q(J), x(U, J)\} = 0. \tag{24}$$

Rule (23) says that state *qm* is reachable by word *w* = *uc* if there is some state *qi* reachable by word *u* and a transition from *qi* to *qm* with symbol *c*.

Similarly, rule (24) says that if a word *w* = *uc* leads to some state *qn* ∈ *Q*, then the number of transitions with symbol *c* that enter *qn* from a state reachable by word *u* cannot be zero.

**Example 1.** *Let us see an example. Suppose we are given S*<sup>+</sup> = {*abc*, *c*}*, S*<sup>−</sup> = {*a*, *ab*}*, and k* = 2*. Rules (5) to (10) concretize into:*

*q(0)* ←*. q(1)* ←*. symbol(a)* ←*. symbol(b)* ←*. symbol(c)* ←*. prefix(ε)* ←*. prefix(a)* ←*. prefix(ab)* ←*. prefix(c)* ←*. prefix(abc)* ←*. positive(c)* ←*. positive(abc)* ←*. negative(a)* ←*. negative(ab)* ←*. join(ε*, *a*, *a)* ←*. join(ε*, *c*, *c)* ←*. join(a*, *b*, *ab)* ←*. join(ab*, *c*, *abc)* ←*.*

*Rules (11) to (24) always remain unchanged. This program has an answer set* {*q(0),* ...*, delta(*1, *b*, 1*), final(0)*}*. In order to construct an associated NFA it is enough to take all final and delta predicates, which define, respectively, final states and transitions of the resultant automaton. So we have obtained an NFA depicted in Figure 1.*

Additionally, in Appendix A there is a description of how answer sets are determined. In Appendix B a larger illustration is given.

**Figure 1.** An inferred non-deterministic finite automaton (NFA).

### **4. Experimental Results**

In this section, we describe some experiments comparing the performance of our approach (the program can be found at https://gitlab.com/wojtek3dan/asp4nfa) with the methods mentioned in the introductory section and described in more detail in Section 4.2. We used an ASP solver, Clingo, which can be executed sequentially or in parallel [22]. While comparing our approach with RA-PS1, RA-PS2, OA-PS1, and OA-PS2, all programs ran on an 8-core processor. ASP vs. SAT comparison was performed using a single core. For these experiments, we used a set of 40 samples (the samples can be found at https://gitlab.com/wojtek3dan/asp4nfa/-/tree/master/samples) based on randomly generated regular expressions.

### *4.1. Benchmarks*

As far as we know, all standard benchmarks are too hard to be solved by pure exact algorithms. Thus, we generated problem instances using our own algorithm. This algorithm builds a set of words with the following parameters: the size |*E*| of a regular expression to be generated, the alphabet size |Σ|, the number |*S*| of words actually generated, and their minimum, *d*min, and maximum, *d*max, lengths. The algorithm proceeds as follows. First, construct a random regular expression *E*. Next, obtain the corresponding minimum-state DFA *M*. Then, as long as a sample *S* is not symmetrically structurally complete (refer to Chapter 6 of [3] for the formal definition of this concept) with respect to *M*, repeat the following steps: (a) using the Xeger library (https://pypi.org/project/xeger/) for generating random strings from a regular expression, get two words *u* and *w*; (b) truncate as few symbols from the end of *w* as possible in order to obtain a counter-example *w*¯; if this succeeds, add *u* to *S*<sup>+</sup> and *w*¯ to *S*<sup>−</sup>. Finally, accept *S* = (*S*+, *S*−) as a valid sample if it is not too small, too large, or highly imbalanced. In order to ensure that these conditions are fulfilled, the inequalities |*S*+| ≥ 8, |*S*−| ≥ 8, and |*S*+| + |*S*−| ≤ 1000 hold for all our samples. In generating a random word from a regex or from an automaton we encounter a problem with, respectively, the star operator and self-loops. Theoretically, there are infinitely many words matching these fragments, so we have to bound the number of repetitions. We set this bound to four.
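Step (b) of the generator can be sketched as follows. Here `accepts` stands for a membership test of the target language, and `ends_with_c` is a hypothetical example language; both names are assumptions of this illustration, not part of the authors' tool.

```python
def truncate_to_counterexample(w, accepts):
    """Step (b): drop as few symbols as possible from the end of w
    until the result falls outside the target language."""
    for i in range(len(w) - 1, -1, -1):
        candidate = w[:i]
        if not accepts(candidate):
            return candidate
    return None  # every proper prefix of w is in the language: step fails

# Example target language: words that end with the symbol 'c'.
ends_with_c = lambda s: s.endswith("c")
```

For the word *abc* and this example language, the function returns *ab*, i.e., the shortest possible truncation that yields a counter-example.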

In this manner we produced 40 sets with: |*E*| ∈ [27, 46], |Σ| ∈ {2, 4, 6, 8}, |*S*| ∈ [27, 958], *d*min = 0, and *d*max = 305. The file names with samples have the form 'a|Σ|words|*E*|.txt'. To give the reader a hint on the variability of the resulting automata, we show in Table 1 the numbers of states and transitions in each of the 40 NFAs found using our approach. We also show the size of *M* and the size of the minimal DFA *D* compatible with the sample data. Example solutions from each group of problems, defined by the size of the input alphabet |Σ|, are shown in Figure 2.


**Table 1.** Sizes of NFAs found by the answer set programming (ASP) solver (*k*0—number of states, *t*—number of transitions, |*M*|—number of states in deterministic finite automaton (DFA) *M*, |*D*|—number of states in minimal DFA *D* compatible with sample data).

**Figure 2.** Example solutions found by the ASP solver for problems a2words33, a4words28, a6words31, and a8words37. (**a**) Problem a2words33; (**b**) Problem a4words28; (**c**) Problem a6words31; (**d**) Problem a8words37.

### *4.2. Compared Algorithms*

As already mentioned, our algorithm was compared with a SAT-based algorithm and several exact parallel algorithms. To make the paper self-contained let us briefly describe these algorithms.

The SAT-based algorithm defines three types of binary variables, *xwq*, *yapq*, and *zq*, for *w* ∈ Pref(*S*), *a* ∈ Σ, *p*, *q* ∈ *Q*. Variable *xwq* = 1 *iff* state *q* is reachable by prefix *w*, otherwise *xwq* = 0. Variable *yapq* = 1 *iff* there exists a transition from state *p* to state *q* with symbol *a*, otherwise *yapq* = 0. Finally, *zq* = 1 *iff* state *q* is final, and *zq* = 0 otherwise. The constraints involving these variables are as follows:

1. All examples have to be accepted, while none of the counter-examples should be, which is described by

$$\forall_{w \in S_+ - \{\varepsilon\}} \; \sum_{q \in Q} x_{wq} z_q \geq 1, \tag{25}$$

$$\forall_{w \in S_- - \{\varepsilon\}} \; \sum_{q \in Q} x_{wq} z_q = 0. \tag{26}$$

2. All prefixes *w* = *a*, *w* ∈ Pref(*S*), *a* ∈ Σ, result from the transitions outgoing from state *q*<sup>0</sup>

$$\forall_{w = a} \; x_{wq} - y_{aq_0q} = 0. \tag{27}$$

3. For all states *q* ∈ *Q* reachable by prefixes *w* = *va*, *v*, *w* ∈ Pref(*S*), *a* ∈ Σ, there has to be some state *r* reachable by prefix *v*, and there has to be an outgoing transition from *r* to *q* with symbol *a*. By symmetry, if there exists a path for prefix *v* ending in some state *r* and there exists a transition from *r* to *q* with symbol *a* then there exists a path to state *q* with prefix *w* = *va*. These conditions are expressed as

$$\forall_{w = va} \; -x_{wq} + \sum_{r \in Q} x_{vr} y_{arq} \geq 0, \tag{28}$$

$$\forall_{q,r \in Q} \; x_{wq} - x_{vr} y_{arq} \geq 0. \tag{29}$$

Additionally, it holds that *z*<sub>*q*0</sub> = 1 when *ε* ∈ *S*+, *z*<sub>*q*0</sub> = 0 when *ε* ∈ *S*−, and *z*<sub>*q*0</sub> is not predefined when *ε* ∉ (*S*<sup>+</sup> ∪ *S*<sup>−</sup>). The solution to the presented problem formulation is sought by a SAT solver.

**Example 2.** *Let us consider Example 1 again. In the SAT-based formulation we have the following variables xaq*<sup>0</sup> *, xaq*<sup>1</sup> *,* ...*, xabcq*<sup>1</sup> *, yaq*0*q*<sup>0</sup> *, yaq*0*q*<sup>1</sup> *,* ...*, ycq*1*q*<sup>1</sup> *, zq*<sup>0</sup> *, and zq*<sup>1</sup> *. Constraints (25)–(29) remain unchanged. A set of assignments satisfying the constraints at hand is as follows: xaq*<sup>1</sup> = 1*, xabq*<sup>1</sup> = 1*, xcq*<sup>0</sup> = 1*, xabcq*<sup>0</sup> = 1*, yaq*0*q*<sup>1</sup> = 1*, ybq*1*q*<sup>1</sup> = 1*, ycq*0*q*<sup>0</sup> = 1*, ycq*1*q*<sup>0</sup> = 1*, zq*<sup>0</sup> = 1*. All remaining variables are zeros. The resulting NFA is shown in Figure 3. Note that even though the set of transitions in Figure 3 is smaller than in Figure 1 both solutions are valid.*
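One can verify mechanically that the assignment of Example 2 satisfies constraints (25)–(29). The sketch below encodes the assignment as dictionaries (all variable and helper names are our own) and checks each constraint group with plain arithmetic.

```python
Q = [0, 1]
S_plus, S_minus = ["abc", "c"], ["a", "ab"]
prefixes = ["a", "ab", "abc", "c"]   # non-empty prefixes of the sample

# Assignment from Example 2; every variable not listed is 0.
x = {("a", 1): 1, ("ab", 1): 1, ("c", 0): 1, ("abc", 0): 1}
y = {("a", 0, 1): 1, ("b", 1, 1): 1, ("c", 0, 0): 1, ("c", 1, 0): 1}
z = {0: 1, 1: 0}

xv = lambda w, q: x.get((w, q), 0)
yv = lambda a, p, q: y.get((a, p, q), 0)

# (25): every example reaches at least one final state.
assert all(sum(xv(w, q) * z[q] for q in Q) >= 1 for w in S_plus)
# (26): no counter-example reaches a final state.
assert all(sum(xv(w, q) * z[q] for q in Q) == 0 for w in S_minus)
# (27): one-symbol prefixes follow transitions out of q0.
assert all(xv(w, q) == yv(w, 0, q) for w in prefixes if len(w) == 1 for q in Q)
# (28) and (29): paths and transitions are mutually consistent.
for w in (p for p in prefixes if len(p) >= 2):
    v, a = w[:-1], w[-1]
    for q in Q:
        assert -xv(w, q) + sum(xv(v, r) * yv(a, r, q) for r in Q) >= 0
        for r in Q:
            assert xv(w, q) - xv(v, r) * yv(a, r, q) >= 0
```

All assertions pass, confirming that the assignment describes an NFA consistent with the sample.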

**Figure 3.** An inferred NFA.

Identification of a *k*-state NFA by means of the exact algorithms RA-PS1, RA-PS2, OA-PS1, and OA-PS2 is based on the SAT formulation given before. Assuming *k* is fixed, we only need to determine the set of final states *F* and the transition function *δ*. Let us recall that a set of final states is *feasible iff* the following conditions are satisfied: (i) *F* ≠ ∅, (ii) *q*<sub>0</sub> ∈ *F* if *ε* ∈ *S*+, (iii) *q*<sub>0</sub> ∉ *F* if *ε* ∈ *S*−. Clearly, an NFA without final states cannot accept any word, and if the empty word *ε* is in *S*<sup>+</sup> (resp. *S*−) the initial state *q*<sub>0</sub> has to be final (resp. not final). Since every feasible set *F* may lead to an NFA consistent with the sample *S* (as the NFAs need not be unique), we distribute the different sets *F* among processes and try to identify the *δ* function by means of a backtracking algorithm.

While searching for the values of *yapq* variables, we apply different search orders. This is so, because there is no universal ordering method assuring fast convergence to the solution. The orderings used in the analyzed algorithms are *deg*, *mmex*, and *mmcex*. The *deg* ordering is a static ordering method based on a variable degree, i.e., the number of constraints the variables are involved in. The ordering does not change as the algorithm progresses. The *mmex* and *mmcex* orderings change dynamically while the algorithm runs. They aim at satisfying first the equations related to examples, or counter-examples, respectively.

The Parallelization Scheme 1 (PS1) maximizes the number of sets *F* processed simultaneously. If the number of available processes is greater than the number of sets *F* to be analyzed, we assign multiple variable orderings (VOs) to each set. In the RA-PS1 algorithm this assignment is performed randomly, while in the OA-PS1 algorithm, the *deg*, *mmex*, and *mmcex* methods are ordered by their complexity and chosen in a round robin fashion.

The Parallelization Scheme 2 (PS2) maximizes the number of variable orderings applied to the same set *F*. This way we shorten the time needed to obtain an answer as to whether an NFA exists for the given set *F*. If the number of available processes is smaller than the product of the number of sets *F* and the number of variable orderings used, we need to choose the sets *F* to be processed first. In the RA-PS2 algorithm we choose them at random, while in the OA-PS2 algorithm, we analyze first the sets *F* of smaller size.

**Example 3.** *Let us consider the problem given in Example 1. Since k* = 2 *and ε* ∈/ (*S*<sup>+</sup> ∪ *S*−)*, the following sets F can be defined: F*<sup>1</sup> = {*q*0}*, F*<sup>2</sup> = {*q*1}*, and F*<sup>3</sup> = {*q*0, *q*1}*. Let us also assume that we can use the three VOs discussed before. Finally, let the number of processes p* = 3 *(denoted by pi, for i* = 0, 1, 2*). We can have the following example configurations of algorithms RA-PS1, RA-PS2, OA-PS1, OA-PS2:*

*1. Algorithm RA-PS1—process p*<sup>0</sup> *gets* (*F*1, *VO*3)*; process p*<sup>1</sup> *gets* (*F*2, *VO*2)*; process p*<sup>2</sup> *gets* (*F*3, *VO*3)*. Each process uses a single VO to analyze one of the possible sets Fi, i* = 1, 2, 3*. There is no guarantee that all VOs are used at least once.*


*Note that in Parallelization Scheme 2, obtaining a negative answer, i.e., that an NFA does not exist for the given set Fi, by means of one VO allows us to stop the execution of other VOs and move on to another set Fj, i* ≠ *j.*

### *4.3. Performance Comparison*

In all experiments, we used an Intel (Santa Clara, CA, USA) Xeon CPU E5-2650 v2, 2.6 GHz (8 cores, 16 threads), under the Ubuntu 18.04 operating system with 190 GB RAM. The time limit (TL) was set to 1000 s. The results are listed in Table 2. In order to determine whether the observed mean difference between ASP and the remaining methods is a real CPU time decrease, we used a paired samples *t* test ([23], pp. 1560–1565) for ASP vs. SAT, ASP vs. RA-PS1, ASP vs. RA-PS2, ASP vs. OA-PS1, and ASP vs. OA-PS2. As we can see from Table 3, the *p* value is low in all cases, so we can conclude that our results did not occur by chance and that using our ASP encoding is likely to improve CPU time performance on the prepared benchmarks.

Let us explain how the mean values were computed. All TL cells were substituted by 1000. Notice that this procedure does not violate the significance of the statistical tests, because our program completed computations within the time limit for all problems (files). Thus, determining all running times would even strengthen our hypothesis.

To make the advantage of the ASP-based approach over the exact parallel algorithms even more convincing let us analyze the largest sizes of automata analyzed by the algorithms within the time limit TL = 1000 s. The summary of obtained sizes is given in Table 4. Note that the table includes only the problems for which TL entries exist in Table 2. The entries marked with \* denote executions in which the algorithms started running for the given *k* but were terminated due to the time limit, without producing the final NFA.


**Table 2.** Execution times of exact solving NFA identification in seconds.


**Table 2.** *Cont*.

**Table 3.** Obtained *p* values from the paired samples *t* test.


**Table 4.** Sizes of NFAs reached by the parallel algorithms. The sign \* means that the time limit was exceeded.



**Table 4.** *Cont*.

### **5. Conclusions**

We have experimented with a model learning approach based on ASP solvers. The approach is very flexible, as proven by its successful adaptation for learning NFAs, implemented in the provided open source tool. Experiments indicate that our approach clearly outperforms the current state-of-the-art satisfiability-based method and all backtracking algorithms proposed in the literature. The approach does scale well (as far as non-deterministic acceptors are considered): we have shown that it can be used for learning models from up to a thousand words. In the future, we wish to develop more efficient encodings that will make the approach scale even better. We hope this paper encourages more interest in ASP-based problem solving since the presented approach has several benefits over traditional model learning algorithms. The ASP encoding is more readable than SAT encoding and the resulting program is much faster than its backtracking counterparts.

**Author Contributions:** Conceptualization, W.W.; methodology, O.U. and W.W.; software, T.J. and W.W.; validation, T.J. and O.U.; formal analysis, O.U.; investigation, W.W. and T.J.; resources, W.W.; writing—original draft preparation, W.W. and T.J.; writing—review and editing, O.U. and T.J.; supervision, O.U.; project administration, O.U.; funding acquisition, O.U. All authors have read and agreed to the published version of the manuscript.

**Funding:** This research was supported by the National Science Center (Poland), grant number 2016/21/B/ST6/02158.

**Conflicts of Interest:** The authors declare no conflict of interest.

#### **Appendix A. How Answer Sets Are Computed**

Let Π be a grounded program, and let A be a set of atoms occurring in Π. Assume that atom *z* is not in A. Observe that every grounded program Π that has rules

$$\leftarrow b_1, \ldots, b_k, \sim c_1, \ldots, \sim c_m. \tag{A1}$$

with empty head, can be transformed to a program without such rules by inserting *z* and ∼*z* in this manner:

$$z \leftarrow b_1, \ldots, b_k, \sim c_1, \ldots, \sim c_m, \sim z. \tag{A2}$$

A grounded program without empty-headed rules will be called *normal*.

A rule *r* of the form:

$$a \leftarrow b\_1, \ldots, b\_k. \tag{A3}$$

where there is no default negation in the body, and the head is not empty, will be called *positive*. A program that contains only positive rules will be called positive too. We will denote by head(*r*) the set {*a*}, and by body(*r*) the set {*b*1, ... , *bk*}. Now, let us define how a positive program *P* can act on a set of atoms *X* ⊆ A (here A is the set of atoms occurring in *P*):

$$X^P = \bigcup \{ \mathrm{head}(r) \mid r \in P \text{ and } \mathrm{body}(r) \subseteq X \}. \tag{A4}$$

This operation can be repeated and we define:

$$X^{P^1} = X^P \quad \text{and} \quad X^{P^i} = (X^{P^{i-1}})^P. \tag{A5}$$

It is easy to see that Cn(*P*) = $\bigcup_{i \geq 1} \emptyset^{P^i}$. Because for a certain *i* the equation $X^{P^i} = X^{P^{i+1}}$ holds, determining Cn(*P*) is straightforward and fast.

Consider any normal program Π. We recall from Section 2.3 that a set *X* ⊆ A is an answer set of Π if *X* = Cn(Π*<sup>X</sup>*) (please do not confuse the reduct with the program's acting on a set of atoms). Take two sets, *L* and *U*, such that *L* ⊆ *X* ⊆ *U* for an answer set *X* of Π. Observe that: (i) *X* ⊆ Cn(Π*<sup>L</sup>*), and (ii) Cn(Π*<sup>U</sup>*) ⊆ *X*. Thus we get:

$$L \cup \mathrm{Cn}(\Pi^{U}) \subseteq X \subseteq U \cap \mathrm{Cn}(\Pi^{L}). \tag{A6}$$

The last property is a recipe for expanding the lower bound *L* and cutting down the upper bound *U*. The procedure in which we replace *L* by *L* ∪ Cn(Π*<sup>U</sup>*) and then *U* by *U* ∩ Cn(Π*<sup>L</sup>*) as long as *L* or *U* change will be called *narrowing*. At some point we get *L* = *U* = *X*. When we start from *L* = ∅, *U* = A, there are also two more possibilities: *L* ⊈ *U* (there is no answer set), and *L* ⊂ *U*. In the latter case we can take any *a* ∈ *U* − *L* and check out two paths: *a* should be included into *L* or *a* should be excluded from *U*. This leads to Algorithm A1 [19]:

**Algorithm A1:** Final algorithm

```
SOLVE(Π, L, U)
   (L, U) ← narrowing(Π, L, U)
   if L ⊈ U then return
   if L = U then output L
   else
       choose a ∈ U − L
       SOLVE(Π, L ∪ {a}, U)
       SOLVE(Π, L, U − {a})
```
This algorithm outputs all answer sets of a program Π, provided it is invoked as SOLVE(Π, ∅, A). The pessimistic time complexity of this algorithm can be assessed by the recurrence relation *T*(*n*) = 2*T*(*n* − 1) + *n*<sup>2</sup>, where *n* = |*U*| − |*L*|, which gives the exponential complexity *T*(*n*) = *O*(2<sup>*n*</sup>).
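Algorithm A1 can be prototyped in a few lines. The sketch below is our own illustrative Python rendering, not the authors' implementation: a normal program is a list of (head, positive body, negated body) triples, and the grounded example from Section 2.3 is used as input.

```python
def cn(positive_rules):
    """Least model of a positive program, by fixpoint iteration."""
    model, changed = set(), True
    while changed:
        changed = False
        for head, body in positive_rules:
            if body <= model and head not in model:
                model.add(head)
                changed = True
    return model

def reduct(program, X):
    """Drop rules whose negated atoms intersect X; strip negation."""
    return [(h, pos) for h, pos, neg in program if not (neg & X)]

def narrowing(program, L, U):
    """Expand the lower bound L and shrink the upper bound U per (A6)."""
    while True:
        L2 = L | cn(reduct(program, U))
        U2 = U & cn(reduct(program, L2))
        if (L2, U2) == (L, U):
            return L, U
        L, U = L2, U2

def solve(program, L, U, out):
    L, U = narrowing(program, L, U)
    if not L <= U:
        return                      # no answer set on this branch
    if L == U:
        out.append(L)               # L is an answer set
        return
    a = min(U - L)                  # choose an undecided atom
    solve(program, L | {a}, U, out)
    solve(program, L, U - {a}, out)

# Grounded program from Section 2.3: (head, positive body, negated body).
prog = [
    ("el(a)", set(), set()),
    ("el(b)", set(), set()),
    ("equal(a,a)", {"el(a)"}, set()),
    ("equal(b,b)", {"el(b)"}, set()),
    ("neq(a,a)", {"el(a)"}, {"equal(a,a)"}),
    ("neq(a,b)", {"el(a)", "el(b)"}, {"equal(a,b)"}),
    ("neq(b,a)", {"el(b)", "el(a)"}, {"equal(b,a)"}),
    ("neq(b,b)", {"el(b)"}, {"equal(b,b)"}),
]
atoms = {a for h, pos, neg in prog for a in {h} | pos | neg}
answer_sets = []
solve(prog, set(), atoms, answer_sets)
```

For this program the narrowing alone converges without branching, and `answer_sets` contains exactly the single answer set given in Section 2.3.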

### **Appendix B. The Complete Example of an ASP Program for NFA Induction**

Suppose we are given Σ = {*a*, *b*}, *S*<sup>+</sup> = {*a*}, *S*<sup>−</sup> = {*b*}, and *k* = 2. After grounding rules (5)–(24) we get a program Π in a Clasp format (symbol :- denotes left arrow, symbol not denotes default negation, and symbol lambda denotes the empty word *ε*):

```
symbol(a).
symbol(b).
prefix(lambda).
prefix(a).
prefix(b).
positive(a).
negative(b).
join(lambda,a,a).
```

```
join(lambda,b,b).
q(0).
q(1).
delta(0,a,0):-not not_delta(0,a,0).
delta(1,a,0):-not not_delta(1,a,0).
delta(0,b,0):-not not_delta(0,b,0).
delta(1,b,0):-not not_delta(1,b,0).
delta(0,a,1):-not not_delta(0,a,1).
delta(1,a,1):-not not_delta(1,a,1).
delta(0,b,1):-not not_delta(0,b,1).
delta(1,b,1):-not not_delta(1,b,1).
not_delta(0,a,0):-not delta(0,a,0).
not_delta(1,a,0):-not delta(1,a,0).
not_delta(0,b,0):-not delta(0,b,0).
not_delta(1,b,0):-not delta(1,b,0).
not_delta(0,a,1):-not delta(0,a,1).
not_delta(1,a,1):-not delta(1,a,1).
not_delta(0,b,1):-not delta(0,b,1).
not_delta(1,b,1):-not delta(1,b,1).
x(lambda,0):-not not_x(lambda,0).
x(a,0):-not not_x(a,0).
x(b,0):-not not_x(b,0).
x(lambda,1):-not not_x(lambda,1).
x(a,1):-not not_x(a,1).
x(b,1):-not not_x(b,1).
not_x(lambda,0):-not x(lambda,0).
not_x(a,0):-not x(a,0).
not_x(b,0):-not x(b,0).
not_x(lambda,1):-not x(lambda,1).
not_x(a,1):-not x(a,1).
not_x(b,1):-not x(b,1).
x(a,0):-delta(1,a,0),x(lambda,1).
x(a,1):-delta(1,a,1),x(lambda,1).
x(b,0):-delta(1,b,0),x(lambda,1).
x(b,1):-delta(1,b,1),x(lambda,1).
x(a,0):-delta(0,a,0),x(lambda,0).
x(a,1):-delta(0,a,1),x(lambda,0).
x(b,0):-delta(0,b,0),x(lambda,0).
x(b,1):-delta(0,b,1),x(lambda,0).
:-x(a,0),0>=#count{0,delta(0,a,0):x(lambda,0),delta(0,a,0);
0,delta(1,a,0):delta(1,a,0),x(lambda,1)}.
:-x(a,1),0>=#count{0,delta(0,a,1):x(lambda,0),delta(0,a,1);
0,delta(1,a,1):x(lambda,1),delta(1,a,1)}.
:-x(b,0),0>=#count{0,delta(0,b,0):x(lambda,0),delta(0,b,0);
0,delta(1,b,0):x(lambda,1),delta(1,b,0)}.
:-x(b,1),0>=#count{0,delta(0,b,1):x(lambda,0),delta(0,b,1);
0,delta(1,b,1):x(lambda,1),delta(1,b,1)}.
final(0):-not not_final(0).
final(1):-not not_final(1).
not_final(0):-not final(0).
not_final(1):-not final(1).
:-0>=#count{0,final(0):final(0),x(a,0);0,final(1):final(1),x(a,1)}.
:-final(0),x(b,0).
:-final(1),x(b,1).
:-x(lambda,1).
```

```
:-not x(lambda,0).
has_state(lambda):-x(lambda,0).
has_state(a):-x(a,0).
has_state(b):-x(b,0).
has_state(lambda):-x(lambda,1).
has_state(a):-x(a,1).
has_state(b):-x(b,1).
:-not has_state(lambda).
:-not has_state(a).
:-not has_state(b).
```
One of the answer sets *X* is:

```
q(0) q(1) prefix(lambda) prefix(a) prefix(b) symbol(a) symbol(b)
negative(b) positive(a) join(lambda,a,a) join(lambda,b,b) x(lambda,0)
not_x(lambda,1) has_state(lambda) not_delta(0,a,0) not_delta(1,a,0)
delta(0,b,0) not_delta(1,b,0) delta(0,a,1) not_delta(1,a,1)
not_delta(0,b,1) not_delta(1,b,1) not_x(a,0) x(b,0) x(a,1)
not_x(b,1) not_final(0) final(1) has_state(a) has_state(b)
```
Note that the above answer set corresponds to a 2-state automaton having non-final state *q*<sup>0</sup> and final state *q*<sup>1</sup> (see predicates not\_final(0) and final(1)), and the following transitions *q*<sup>0</sup> ∈ *δ*(*q*0, *b*), *q*<sup>1</sup> ∈ *δ*(*q*0, *a*) (defined by predicates delta(0,b,0) and delta(0,a,1)).
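As a quick sanity check, the automaton read off from this answer set can be simulated directly. The sketch below is our own illustration (not part of the paper's toolchain); it confirms that the induced automaton accepts the example *a* and rejects the counter-example *b*:

```python
def nfa_accepts(delta, finals, start, word):
    """Simulate an NFA whose transitions are given as triples (p, symbol, q),
    mirroring the delta(p,symbol,q) predicates of the answer set."""
    states = {start}
    for c in word:
        states = {q for (p, s, q) in delta if p in states and s == c}
    return bool(states & finals)

# transitions delta(0,b,0) and delta(0,a,1); final state 1
delta = {(0, "b", 0), (0, "a", 1)}
print(nfa_accepts(delta, {1}, 0, "a"))  # True: a is in S+
print(nfa_accepts(delta, {1}, 0, "b"))  # False: b is in S-
```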

It can be easily verified that the reduct Π<sup>*X*</sup> = *P* is the following positive program:

```
symbol(a).
symbol(b).
prefix(lambda).
prefix(a).
prefix(b).
positive(a).
negative(b).
join(lambda,a,a).
join(lambda,b,b).
q(0).
q(1).
delta(0,b,0).
delta(0,a,1).
not_delta(0,a,0).
not_delta(1,a,0).
not_delta(1,b,0).
not_delta(1,a,1).
not_delta(0,b,1).
not_delta(1,b,1).
x(lambda,0).
x(b,0).
x(a,1).
not_x(a,0).
not_x(lambda,1).
not_x(b,1).
x(a,0):-delta(1,a,0),x(lambda,1).
x(a,1):-delta(1,a,1),x(lambda,1).
x(b,0):-delta(1,b,0),x(lambda,1).
x(b,1):-delta(1,b,1),x(lambda,1).
x(a,0):-delta(0,a,0),x(lambda,0).
x(a,1):-delta(0,a,1),x(lambda,0).
```

```
x(b,0):-delta(0,b,0),x(lambda,0).
x(b,1):-delta(0,b,1),x(lambda,0).
final(1).
not_final(0).
has_state(lambda):-x(lambda,0).
has_state(a):-x(a,0).
has_state(b):-x(b,0).
has_state(lambda):-x(lambda,1).
has_state(a):-x(a,1).
has_state(b):-x(b,1).
```
Cn(*P*) = *X*, since iterating the one-step consequence operator of *P* on the empty set gives ∅<sup>*P*</sup> ∪ (∅<sup>*P*</sup>)<sup>*P*</sup> = *X*, and further action of *P* on *X* does not change it.

### **References**


**Publisher's Note:** MDPI stays neutral with regard to jurisdictional claims in published maps and institutional affiliations.

© 2020 by the authors. Licensee MDPI, Basel, Switzerland. This article is an open access article distributed under the terms and conditions of the Creative Commons Attribution (CC BY) license (http://creativecommons.org/licenses/by/4.0/).

## *Article* **Parsing Expression Grammars and Their Induction Algorithm**

### **Wojciech Wieczorek 1,\*, Olgierd Unold <sup>2</sup> and Łukasz Strąk <sup>3</sup>**


Received: 27 October 2020; Accepted: 1 December 2020; Published: 7 December 2020

### **Featured Application: PEG library for Python.**

**Abstract:** Grammatical inference (GI), i.e., the task of finding a rule that lies behind given words, can be used in the analysis of amyloidogenic sequence fragments, which are essential in studies of neurodegenerative diseases. In this paper, we developed a new method that generates non-circular parsing expression grammars (PEGs) and compared it with other GI algorithms on sequences from a real dataset. The main contribution of this paper is a genetic programming-based algorithm for the induction of parsing expression grammars from a finite sample. The induction method was tested on a real bioinformatics dataset and its classification performance was compared with the achievements of existing grammatical inference methods. The evaluation of the generated PEG on an amyloidogenic dataset revealed its accuracy when predicting amyloid segments. We show that the new grammatical inference algorithm achieves the best ACC (accuracy), AUC (area under the ROC curve), and MCC (Matthews correlation coefficient) scores in comparison to five other automata or grammar learning methods.

**Keywords:** classification; genetic programming; grammatical inference; parsing expression grammar

### **1. Introduction**

The present work sits in the scientific field known as grammatical inference (GI), automata learning, grammar identification, or grammar induction [1,2]. The matter under consideration is the set of rules that lie behind a given sequence of words (so-called strings). The main task is to discover the rule(s) that will help us to evaluate new, unseen words. Mathematicians investigate infinite sequences of words, and for this purpose they have proposed a few inference models. In the most popular model, Gold's identification in the limit [3], learning happens incrementally. After each new word, the algorithm returns some hypothesis, i.e., an automaton or a grammar, and the entire process is regarded as successful when the algorithm returns a correct answer at a certain iteration and does not change it afterwards. However, very often in practice we deal only with a limited number of words (some of them being examples and others counter-examples). In such cases the best option is to use a selected heuristic algorithm, among which the most recognized instances include: evidence driven state merging [4], the *k*-tails method [5], the GIG method [6], the TBL (tabular representation learning) algorithm [7], the learning system ADIOS (automatic distillation of structure) [8], error-correcting grammatical inference [9], and alignment-based learning [10]. However, all of these methods output classical acceptors such as (non)deterministic finite state automata (FSAs) or context-free grammars (CFGs).

FSAs are fast in recognition but lack expressiveness. CFGs, on the other hand, are more expressive but require more computing time for recognition. We propose here using parsing expression grammars (PEGs), which are as fast as FSAs and can express more than CFGs, in the sense that they can represent some context-sensitive languages. To the best of our knowledge, no one has devised a similar induction algorithm before. As far as non-Chomsky grammars are considered for representing acceptors, Eyraud et al. [11] applied a string-rewriting system to the GI domain. However, as the authors claimed, pure context-sensitive languages can probably not be described with their tool. PEGs are relatively new, but have already been used in a few applications (e.g., Extensible Markup Language schema validation using Document Type Definition automatic transformation [12] and a text pattern-matching tool [13]).

The purpose of the present proposal is threefold. The first objective is to devise an induction algorithm that suits real biological data well: amyloidogenic sequence fragments. Amyloids are proteins capable of forming fibrils instead of the functional structure of a protein, and are responsible for a group of serious diseases. The second objective is to determine whether the proposed algorithm is as well suited to the benchmark data as selected comparative grammatical inference (GI) algorithms and a machine learning approach (SVM). We assume that the given strings do not contain periodically repeated substrings, which is why it has been decided to build up non-circular PEGs that represent finite sets of strings. The last objective is to write a Python library for handling PEGs and make it available to the community. Although there are at least three other Python packages for generating PEG parsers, namely Arpeggio (http://www.igordejanovic.net/Arpeggio), Grako (https://bitbucket.org/neogeny/grako), and pyPEG (https://fdik.org/pyPEG), our implementation (https://github.com/wieczorekw/wieczorekw.github.io/tree/master/PEG) is worth noting for its simple usage (integration with Python syntax via native operators) and because it is dozens of times faster in processing long strings, as will be shown in detail in Section 3.3. In addition to the Python libraries, to enrich the research, a library named EGG (https://github.com/bruceiv/egg/tree/deriv) written in C++ was used for comparison, in which an expression has to be compiled into machine code before it is used [14].

This paper is organized into five sections. Section 2 introduces the notion of parsing expression grammars and discusses their pros and cons in comparison with regular expressions and CFGs. Section 3 describes the induction algorithm. Section 4 discusses the experimental results. Section 5 summarizes the collected results.

### **2. Definition of PEGs**

PEGs build on *regular expressions* (REs) and *context-free grammars* (CFGs), both derived from formal language theory. We briefly introduce the most relevant definitions.

An alphabet Σ is a non-empty set of symbols (characters without any meaning). A string or word (*s*, *w*) is a finite sequence of symbols. A special case of a string is the empty string *ε* (the empty sequence of symbols). An example of an alphabet is the set {*a*, *b*, *c*}, and examples of strings over this alphabet are *a*, *aa*, *ab*, *ba*, and *abc*. A formal language *L* over an alphabet Σ is a subset of Σ<sup>∗</sup> (the Kleene star, i.e., the set of all strings over Σ). A regular expression is a formal way of describing the class of languages called *regular languages*. Let *r*, *r*1, and *r*2 be regular expressions over Σ, and *a*, *b* ∈ Σ; the following constructions are allowed in the syntax:

- the empty language ∅ and the empty string *ε*;
- a single symbol *a*;
- the concatenation *r*1*r*2;
- the alternation *r*1 | *r*2;
- the Kleene star *r*<sup>∗</sup>.
Given an alphabet Σ = {*a*, *b*}, a formal language *L* = {*w* ∈ Σ<sup>∗</sup> | *w* begins with *a* and ends with *a*} can be expressed as the regular expression *a*(*a* | *b*)<sup>∗</sup>*a*. A CFG is a tuple *G* = (*V*, Σ, *R*, *S*), where *V* is a finite set of nonterminal symbols, Σ is a finite set of terminal symbols disjoint from *V*, *R* is a finite relation *V* → (*V* ∪ Σ)<sup>∗</sup> that defines the rules, and *S* is the start symbol, chosen from *V*. Most commonly, *R* is given in production rule notation; for example, for the formal language *L* = {*w* ∈ Σ<sup>∗</sup> | *w* is a palindrome of odd length}, an equivalent context-free grammar is *G* = ({*S*}, {*a*, *b*}, *P*, *S*) with the productions:

$$\begin{aligned} \mathcal{S} &\to aSa \\ \mathcal{S} &\to bSb \\ \mathcal{S} &\to a \mid b \end{aligned}$$

The word *aba* can be accepted using the first production and the third one. The book by Hopcroft et al. [15] contains more information related to the formal language field.

The formalism of PEGs was introduced by Bryan Ford in 2004 [16]. However, herein we give definitions and notation compatible with the provided PEG library. Let us start with an informal introduction to parsing expression grammars (PEGs).

A *parsing expression grammar* (PEG) is a 4-tuple *G* = (*V*, *T*, *R*, *s*), where *V* is a finite set of nonterminal symbols, *T* is a finite set of terminal symbols (letters), *R* is a finite set of rules, *s* is a parsing expression called the *start expression*, and *V* ∩ *T* = ∅. Each rule *r* ∈ *R* is a pair (*A*, *e*), which we write as *A* ⇐ *e*, where *A* ∈ *V* and *e* is a parsing expression. For any nonterminal *A*, there is exactly one *e* such that *A* ⇐ *e* ∈ *R*. We define *parsing expressions* inductively as follows. If *e*, *e*1, and *e*2 are parsing expressions, then so are:

- the empty string *ε* and any terminal *a* ∈ *T*;
- any nonterminal *A* ∈ *V*;
- the sequence *e*1 ≫ *e*2;
- the prioritized choice *e*1 | *e*2;
- the one-or-more repetition +*e*;
- the not-predicate <sup>∼</sup>*e*.
The choice of the operators ≫, |, +, <sup>∼</sup>, and ⇐ is motivated by consistency with our Python library. The operators have their counterparts in Python (>>, |, +, ~, and <=) with the proper precedence. Thus the reader is able to implement expressions in a very natural way, using the native operators.

A PEG is an instance of a recognition system, i.e., a program for recognizing and possibly structuring a string. It can be written in any programming language and looks like a grammar combined with a regex, but its interpretation is different. Take as an example the following regex: (*a* | *b*)<sup>+</sup>*b*. We can write a "similar" PEG expression: +(*a* | *b*) ≫ *b*. The regex accepts all words over the alphabet {*a*, *b*} that end with the letter *b*. The PEG expression, on the contrary, does not recognize any word, since PEGs behave greedily: the part +(*a* | *b*) will consume all letters, including the last *b*. An appropriate PEG solution resembles a CFG:

$$\begin{aligned} A &\Leftarrow a \mid b \\ E &\Leftarrow b \gg \sim A \mid A \gg E \end{aligned}$$

The sign ⇐ associates an expression with a nonterminal. The sign ≫ denotes concatenation. What makes a difference are the ordered choice | and the not-predicate <sup>∼</sup>. The nonterminal *E* first tries to consume the final *b*; then, in the case of failure, it consumes *a* or *b* and recursively invokes itself. In order to write parsing expressions in a convenient way, we will freely omit unnecessary parentheses, assuming the following operator precedence (from highest to lowest): <sup>∼</sup>, +, ≫, |, ⇐. The Kleene star operation can be performed via +*e* | *ε* (Python does not have a unary star operator, so the PEG implementation library had to be adjusted). The power of PEGs is clearly visible in fast, linear-time parsing and in the possibility of expressing some context-sensitive languages [17].

From now on we will use the symbols *a*, *b*, and *c* to represent pairwise different terminals, *A*, *B*, *C*, and *D* for pairwise different nonterminals, *x*, *x*1, *x*2, *y*, and *z* for strings of terminals, where |*x*1| = *k* (*k* ≥ 0) and |*x*2| = *m* (*m* ≥ 0), and *e*, *e*1, and *e*2 for parsing expressions. To formalize the syntactic meaning of a PEG *G* = (*V*, *T*, *R*, *s*), we define a function consume(*e*, *x*), which outputs a nonnegative integer (the number of "consumed" letters) or nothing (None):

- consume(*ε*, *x*) = 0;
- consume(*a*, *x*) = 1 if *x* begins with *a*, and None otherwise;
- consume(*A*, *x*) = consume(*e*, *x*), where *A* ⇐ *e* ∈ *R*;
- consume(*e*1 ≫ *e*2, *x*1*x*2*y*) = *k* + *m* if consume(*e*1, *x*1*x*2*y*) = *k* and consume(*e*2, *x*2*y*) = *m*; None if either call returns None;
- consume(*e*1 | *e*2, *x*) = consume(*e*1, *x*) if this is not None, and consume(*e*2, *x*) otherwise;
- consume(+*e*, *x*) = the total number of letters consumed by repeatedly applying *e* as long as it succeeds; None if the first application fails;
- consume(<sup>∼</sup>*e*, *x*) = 0 if consume(*e*, *x*) = None, and None otherwise.
The language *L*(*G*) of a PEG *G* = (*V*, *T*, *R*, *s*) is the set of strings *x* for which consume(*s*, *x*) ≠ None. Please note that the definition of the language of a PEG differs fundamentally from that of the much better-known CFGs: in the former it is enough to consume any prefix of a word (including the empty one) to accept it, while in the latter the whole word should be consumed. Direct (like *A* ⇐ *A* ≫ *e*) as well as indirect left recursions are forbidden, since they can lead to an infinite loop while performing the consume function. It is worth emphasizing that the expression <sup>∼</sup>(∼*e*) works as non-consuming matching. As a consequence, we can perform the language intersection *L*(*G*1) ∩ *L*(*G*2) by writing <sup>∼</sup>(∼*s*1) ≫ *s*2, provided that *G*1 = (*V*1, *T*, *R*1, *s*1), *G*2 = (*V*2, *T*, *R*2, *s*2), and *V*1 ∩ *V*2 = ∅. Interestingly, it has not yet been proven whether there exist context-free languages that cannot be recognized by a PEG.
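The consume semantics can be prototyped with a tiny interpreter over nested tuples. This is our own sketch (the paper's library overloads native Python operators instead), with the tags `lit`, `seq`, `alt`, `not`, `plus`, and `ref` standing for *a*, ≫, |, <sup>∼</sup>, +, and nonterminal references, respectively:

```python
RULES = {}  # nonterminal name -> parsing expression

def consume(e, x):
    """Number of letters of x consumed by expression e, or None on failure."""
    kind = e[0]
    if kind == "lit":                 # terminal (or "" for epsilon)
        return len(e[1]) if x.startswith(e[1]) else None
    if kind == "seq":                 # e1 >> e2
        n1 = consume(e[1], x)
        if n1 is None:
            return None
        n2 = consume(e[2], x[n1:])
        return None if n2 is None else n1 + n2
    if kind == "alt":                 # ordered choice e1 | e2
        n = consume(e[1], x)
        return n if n is not None else consume(e[2], x)
    if kind == "not":                 # ~e: succeeds iff e fails, consumes 0
        return 0 if consume(e[1], x) is None else None
    if kind == "plus":                # +e: greedy one-or-more repetition
        total = consume(e[1], x)
        if total is None:
            return None
        while True:
            n = consume(e[1], x[total:])
            if n is None or n == 0:
                return total
            total += n
    if kind == "ref":                 # nonterminal reference
        return consume(RULES[e[1]], x)

AB = ("alt", ("lit", "a"), ("lit", "b"))
# +(a|b) >> b consumes everything greedily, so the final b always fails:
print(consume(("seq", ("plus", AB), ("lit", "b")), "aab"))  # None
# A <= a | b ;  E <= b >> ~A | A >> E  accepts words ending in b:
RULES["A"] = AB
RULES["E"] = ("alt",
              ("seq", ("lit", "b"), ("not", ("ref", "A"))),
              ("seq", ("ref", "A"), ("ref", "E")))
print(consume(("ref", "E"), "aab"))  # 3
print(consume(("ref", "E"), "aa"))   # None
# prefix semantics: consuming any prefix suffices for membership in L(G)
print(consume(("lit", "a"), "ab"))   # 1
```

The first call reproduces the greedy-failure example from the text, and the last one illustrates the prefix-based acceptance condition that distinguishes PEGs from CFGs.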

In the next section we deal with non-circular PEGs that will have to be understood as grammars without any recursions or repetitions. Note that such a non-circular PEG, say *G* = ({*A*}, *T*, {*A* ⇐ *e*}, *A*), can be written as a single expression *e* with no nonterminal and no + operation.

#### **3. Induction Algorithm**

The proposed algorithm is based on the genetic programming (GP) paradigm [18]. In GP, machine learning is viewed as the discovery of a computer program (an expression in our case) that produces some desired output (the decision class in our case) for particular inputs (strings representing proteins in our case). Viewed in this way, the process of solving problems becomes equivalent to searching a space of possible computer programs for the fittest individual computer program. In this paradigm, populations of computer programs are bred using the principle of survival of the fittest and a crossover (recombination) operator appropriate for mating computer programs.

This section is split into two subsections. In the first subsection, we will describe the scheme of the GP method adapted to the induction problem. In the second, a deterministic algorithm for the obtaining of an expression matched to the data will be presented. This auxiliary algorithm is used to feed an initial population of GP with promising individuals.

#### *3.1. Genetic Programming*

Commonly, genetic programming uses a generational evolutionary algorithm. In generational GP, there exist well-defined and distinct generations. Each generation is represented by a population of individuals. The newer population is created from and then replaces the older population. The execution cycle of the generational GP—which we used in experiments—includes the following steps:

	- Select two individuals in the current population using a selection algorithm.
	- Perform genetic operations on the selected individuals.
	- Insert the result of crossover, i.e., the better one out of two children, into the emerging population.
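A minimal sketch of one such generational step is given below. This is our illustration, not the authors' implementation: `fitness` and `crossover` are problem-specific callables, and the tournament size `r` is a parameter; the toy run uses numbers instead of parse trees.

```python
import random

def gp_generation(pop, fitness, crossover, r=3):
    """One generational GP step: tournament selection of two parents,
    crossover, and insertion of the better child, repeated until the
    new population reaches the old population's size."""
    def tournament():
        return max(random.sample(pop, r), key=fitness)
    new_pop = []
    while len(new_pop) < len(pop):
        child1, child2 = crossover(tournament(), tournament())
        new_pop.append(max(child1, child2, key=fitness))
    return new_pop

# toy run: individuals are numbers, fitness is the identity, and
# crossover returns the pair ordered (a stand-in for subtree exchange)
random.seed(0)
pop = [1, 5, 3, 2, 4]
new_pop = gp_generation(pop, fitness=lambda v: v,
                        crossover=lambda p, q: (min(p, q), max(p, q)))
print(new_pop)
```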

In order to put the above procedure to work, we have to define the following elements and routines of GP: the primitives (known in GP as the terminal set and the function set), the structure of an individual, the initialization, genetic operators, and the fitness function.

Individuals are parse trees composed of the PEG operators <sup>∼</sup>, ≫, and |, whose terminals are elements of Σ ∪ {*ε*}, where Σ is a finite alphabet (see an example in Figure 1).

**Figure 1.** Example of a genetic programming individual coded as the expression *a* ≫ *b* ≫ (*c* ≫ <sup>∼</sup>*d* | *b* | *a*).

An initial population is built upon *S*<sup>+</sup> (positive strings, examples) and *S*<sup>−</sup> (negative strings, counterexamples) so that each tree is consistent with a randomly chosen *k*-element subset, *X*, of *S*+ and a randomly chosen *k*-element subset, *Y*, of *S*−. An expression forming an individual in an initial population is created by means of a deterministic algorithm given further on. In a crossover procedure, two expressions given as parse trees are involved. A randomly chosen part of the first tree is replaced by another randomly chosen part from the second tree. The same operation is performed on the second tree in the same manner. We also used tournament selection, in which *r* (the tournament size) individuals are chosen at random and one of them with the highest fitness is returned. Finally, the fitness function measures an expression's accuracy based on an individual *e*, and the sample (*S*+, *S*−) with Equation (1):

$$f(e) = \frac{|\{w \in \mathbb{S}\_{+} \colon w \in L(G(e))\}| + |\{w \in \mathbb{S}\_{-} \colon w \notin L(G(e))\}|}{|\mathbb{S}\_{+}| + |\mathbb{S}\_{-}|},\tag{1}$$

where *G*(*e*) is a non-circular PEG *G*(*e*)=({*A*}, Σ, {*A* ⇐ *e*}, *A*).
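Equation (1) is plain classification accuracy over the sample. As a sketch (with a hypothetical `accepts` predicate standing in for the membership test *w* ∈ *L*(*G*(*e*))):

```python
def fitness(accepts, s_pos, s_neg):
    """Equation (1): correctly accepted examples plus correctly
    rejected counter-examples, over the total sample size."""
    hits = sum(1 for w in s_pos if accepts(w))
    hits += sum(1 for w in s_neg if not accepts(w))
    return hits / (len(s_pos) + len(s_neg))

# toy acceptor: words starting with 'a'
print(fitness(lambda w: w.startswith("a"), ["a", "ab"], ["b", "ba"]))  # 1.0
print(fitness(lambda w: w.startswith("a"), ["a", "ba"], ["b", "ab"]))  # 0.5
```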

### *3.2. Deterministic Algorithm Used in Initializing a GP Population*

For a set of strings *S* and a letter *r*, by the left quotient, denoted *r*<sup>−1</sup>*S*, we will mean the set {*w* : *rw* ∈ *S*}; for example, *a*<sup>−1</sup>{*ax*, *ax*1, *x*2} = {*x*, *x*1}. Let *X* and *Y* be disjoint, nonempty sets of words over an alphabet Σ. Our aim is to obtain a compact non-circular PEG *G* satisfying the following two conditions: (i) *X* ⊆ *L*(*G*), (ii) *Y* ∩ *L*(*G*) = ∅. Algorithm 1 (function *I*(*X*, *Y*)) does this recursively.


The "Append" method used in lines 7, 9, and 13 of Algorithm 1 concatenates the existing rule *e* with a new expression. The recursive call to *I* in line 9 restricts the sets *X* and *Y* to the words that start with the terminal symbol *a*, and these words are passed on without their first symbol *a* (according to the left quotient). Line 10 is used when the set *A* is empty. The execution of the algorithm is shown by the following example. The input is *X* = {*abba*, *bbbb*, *abaa*, *abbb*, *bbaa*, *bbab*}, *Y* = {*baaa*, *aaab*, *babb*, *aaba*, *aaaa*, *baba*}. Figure 2 shows the successive steps of the algorithm. At the beginning, the sets *A* and *B* are determined. The first symbols of the words in *X* and *Y* are both equal to {*a*, *b*}. The terminal symbol *a* of the set *A* belongs to the set *B*, so the string *a* is added to the rule *e* and the method is recursively invoked with the left quotients *a*<sup>−1</sup>*X* and *a*<sup>−1</sup>*Y* (left leaf from the root in Figure 2). In the next step, the symbol *a* is not in the set *B*. From the recursive call, the symbol *b* is returned and added to the rule *e*. After returning, the same procedure is repeated for the symbol *b*.
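The function *I*(*X*, *Y*) can be sketched in Python. The reconstruction below follows the textual description and the proofs of Lemmas 1 and 2, so the constructor and helper names are ours and details may differ from the authors' Algorithm 1; expressions are nested tuples, with a minimal `consume` included only to check consistency.

```python
def consume(e, x):
    """Minimal PEG semantics for the constructors used below."""
    kind = e[0]
    if kind == "lit":
        return len(e[1]) if x.startswith(e[1]) else None
    if kind == "seq":
        n1 = consume(e[1], x)
        if n1 is None:
            return None
        n2 = consume(e[2], x[n1:])
        return None if n2 is None else n1 + n2
    if kind == "alt":
        n = consume(e[1], x)
        return n if n is not None else consume(e[2], x)
    if kind == "not":
        return 0 if consume(e[1], x) is None else None

def alt(exprs):
    """Fold a list of expressions into a left-nested ordered choice."""
    result = exprs[0]
    for e in exprs[1:]:
        result = ("alt", result, e)
    return result

def quot(words, a):
    """Left quotient a^-1 S = {w : aw in S}."""
    return {w[1:] for w in words if w.startswith(a)}

def induce(X, Y):
    """Sketch of I(X, Y) for disjoint, nonempty X (examples) and
    Y (counter-examples): build a non-circular expression branch by
    branch over the first letters of X."""
    first_X = sorted({w[0] for w in X if w})
    first_Y = {w[0] for w in Y if w}
    branches = []
    for a in first_X:
        if a in first_Y:   # conflict on a: recurse on the left quotients
            branches.append(("seq", ("lit", a),
                             induce(quot(X, a), quot(Y, a))))
        else:              # the letter a alone separates X from Y here
            branches.append(("lit", a))
    if "" in X:            # epsilon in X: a branch that consumes nothing
        branches.append(("not", alt([("lit", b) for b in sorted(first_Y)])))
    return alt(branches)

X = {"abba", "bbbb", "abaa", "abbb", "bbaa", "bbab"}
Y = {"baaa", "aaab", "babb", "aaba", "aaaa", "baba"}
e = induce(X, Y)
print(all(consume(e, w) is not None for w in X))  # True
print(all(consume(e, w) is None for w in Y))      # True
```

On the example from the text the sketch reduces, as in Figure 2, to the expression *a* ≫ *b* | *b* ≫ *b*, which is consistent with all examples and counter-examples.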

The algorithm has the following properties: (i) the *X*, *Y* sets in successive calls are always nonempty, (ii) word lengths can be different, (iii) it always halts after a finite number of steps and returns some PEG, (iv) the resultant PEG is consistent with *X* (examples) and *Y* (counter-examples), and (v) for random words output PEGs are many times smaller than the input. Properties from (i) to (iii) are quite obvious, (iv) will be proven, and we have checked (v) in a series of experiments, and detailed results are given in the next subsection.

Let *n* = |*X*| + |*Y*|, *m* = |Σ|, let the length of every word (from *X* and *Y*) not exceed *d*, and let *T* be the running time of the algorithm. Then, for random input words, we can write *T*(*n*, *m*, *d*) = *O*(*n*) + *mT*(*n*/*m*, *m*, *d* − 1), which leads to *O*(*m*<sup>min{*d*, log *n*}</sup>*n*). In practice, *m* and *d* are small constants, so the running time of *I*(*X*, *Y*) is usually linear with respect to *n*.

**Lemma 1.** *Let* Σ *be a finite alphabet and X, Y be two disjoint, finite, nonempty sets of words over* Σ*. If x* ∈ *X and e is a parsing expression returned by I*(*X*, *Y*)*, then consume(e*, *x)* ≠ *None.*

**Proof.** We will prove the above by induction on *k*, where *k* ≥ 0 is the length of *x*.

Basis: We use *k* = 0 as the basis. Because *k* = 0, *x* = *ε*. Let us consider two cases: (1) *ε* is the only word in *X*, and (2) |*X*| ≥ 2. In the first case, lines 5–9 of the algorithm are skipped and (in line 11)

*e* = <sup>∼</sup>(*a* | *b* | *c* | ...) is returned, where *a*, *b*, *c*, ... are the first letters of the words of *Y* (since *ε* is in *X*, *Y* has to contain at least one nonempty word). consume(*a* | *b* | *c* | ... , *ε*) = None implies consume(*e*, *ε*) = 0. In the second case, the loop in lines 5–9 and line 13 are executed, so the returned expression *e* (in line 14) has the following form: *α*1 | *α*2 | ··· | *αj* | <sup>∼</sup>(*a* | *b* | *c* | ...), where *αi* is a single letter, say *ri*, or *ri* ≫ *βi* with *βi* being some parsing expression. For such an *e*, consume(*e*, *ε*) = 0 holds too.

Induction: Suppose that |*x*| = *k* + 1 and that the statement of the lemma holds for all words of length *j*, where 0 ≤ *j* ≤ *k*. Let *x* = *uw*, where *u* ∈ Σ. Obviously |*w*| = *k*. Again let us consider two cases: (1) *u* ∉ *B*, and (2) *u* ∈ *B*. In the first case, *e*, which is returned in line 14, has the form *u* | *α* or *α* | *u* | *β* or *α* | *u*, where *α* and *β* are some expressions. In either case consume(*e*, *uw*) ≥ 0 (at least *u* will not fail for *x* = *uw*). In the second case, *e*, which is returned in line 14, is a sequence of addends, one of which is *u* ≫ *α*, where *α* = *I*(*u*<sup>−1</sup>*X*, *u*<sup>−1</sup>*Y*). The suffix *w* is an element of the set *u*<sup>−1</sup>*X*, so we invoke the inductive hypothesis to claim that consume(*α*, *w*) ≠ None. Then consume(*e*, *uw*) ≠ None, because of the properties of the sequence and the prioritized choice operators (at least *u* ≫ *α* will not fail for *x*).

**Lemma 2.** *Let* Σ *be a finite alphabet and X, Y be two disjoint, finite, nonempty sets of words over* Σ*. If y* ∈ *Y and e is a parsing expression returned by I*(*X*,*Y*) *then consume(e*, *y)* = *None.*

**Figure 2.** Example of the proposed Induction algorithm.

**Proof.** We will prove the above by induction on *k*, where *k* ≥ 0 is the length of *y*.

Basis: We use *k* = 0 as the basis, i.e., *y* = *ε*. Because *ε* ∉ *X*, the expression *e* returned in line 13 has the following form: *α*1 | *α*2 | ··· | *αj*, where *αi* is a single letter, say *ri*, or *ri* ≫ *βi* with *βi* being some parsing expression. For such an *e*, consume(*e*, *ε*) = None.

Induction: Suppose that |*y*| = *k* + 1 and that the statement of the lemma holds for all words of length *j*, where 0 ≤ *j* ≤ *k*. Let *y* = *uw*, where *u* ∈ Σ. Naturally |*w*| = *k*. There are two main cases to consider: (1) *A* is empty (that happens only when *X* = {*ε*}), and (2) *A* is not empty. In the first case, *e* = <sup>∼</sup>(*a* | *b* | *c* | ... | *u* | ...) is returned, where *a*, *b*, *c*, ..., *u*, ... are the first letters of the words of *Y* (the position of *u* in the sequence is not important). consume(*a* | *b* | *c* | ... | *u* | ... , *uw*) = 1 implies consume(*e*, *y*) = None. In the second case (i.e., *A* is not empty), let us consider four sub-cases: (2.1) *u* ∈ *A* and *ε* ∈ *X*, (2.2) *u* ∈ *A* and *ε* ∉ *X*, (2.3) *u* ∉ *A* and *ε* ∈ *X*, and (2.4) *u* ∉ *A* and *ε* ∉ *X*. As for (2.1), the returned expression *e* is of the following form: *α*1 | *α*2 | ··· | *αj* | <sup>∼</sup>(*a* | *b* | *c* | ... | *u* | ...), where *αi* is a single letter, say *ri*, or *ri* ≫ *βi* with *βi* being some parsing expression. Exactly one of the *αi* has the form *u* ≫ *βi* (exactly one *ri* = *u*), where *βi* = *I*(*u*<sup>−1</sup>*X*, *u*<sup>−1</sup>*Y*). The suffix *w* is an element of the set *u*<sup>−1</sup>*Y*, so by the induction hypothesis consume(*βi*, *w*) = None. Then consume(*e*, *uw*) = None. Notice that the last addend, i.e., the one with <sup>∼</sup>, will also fail due to consume(<sup>∼</sup>(*a* | *b* | *c* | ... | *u* | ...), *uw*) = None. Sub-case (2.2) is provable similarly to (2.1). When *u* ∉ *A* (sub-cases 2.3 and 2.4), none of the *ri* is *u* and it is easy to see that consume(*e*, *y*) = None.

**Theorem 1.** *Let* Σ *be a finite alphabet and X, Y be two disjoint, finite, nonempty sets of words over* Σ*. If e is a parsing expression returned by I*(*X*,*Y*) *and G is a non-circular PEG defined by G* = ({*A*}, Σ, {*A* ⇐ *e*}, *A*) *then X* ⊆ *L*(*G*) *and Y* ∩ *L*(*G*) = ∅*.*

**Proof.** This result follows immediately from the two previous Lemmas.

### *3.3. Python's PEG Library Performance Evaluation*

In order to assess the fifth property of Algorithm 1 (for random words output PEGs are many times smaller than the input), we created random sets of words with different sizes, lengths and alphabets. Table 1 shows our settings in this respect.


**Table 1.** The settings of the generator of random input for our PEG algorithm.

Naturally, |*X*| and |*Y*| denote the numbers of examples and counter-examples, while the words' lengths vary between *dmin* and *dmax*. These datasets are publicly available along with the source code of our PEG library. Figure 3 depicts the number of symbols in a PEG and the number of letters in the respective test set.

The number of letters in an input file simply equals <sup>∑</sup>*w*∈*X*∪*<sup>Y</sup>* (|*w*| + 1), where +1 stands for a newline sign (i.e., the words' separator). As for PEGs, the *ε* symbol has not been counted, since it may be omitted. Outside the Python language, the concatenation of two symbols, for instance *a* and *b*, can be written as *ab* instead of *a* ≫ *b*. Notice also that in Figure 3 the ordinates are in logarithmic scale, because the differences are large.

The runtime of the Python implementation of the proposed PEG library was benchmarked against comparable libraries, i.e., Arpeggio and Grako. The pyPEG library was rejected because we were unable to define more complex expressions with it. As a testbed we chose Tomita's languages [19]. This test set contains seven different expressions that serve as rules for the generation of words over a binary alphabet. Their descriptions in natural language can be found in Table 2. Seven regular expressions corresponding to the rules were created, and then generators of random input words were implemented. Thus, for every language we had two sets: words matching (positive) and not matching (negative) a particular regular expression.


Equivalent PEG expressions were defined in every comparable library as well (see Table 2).



Table 3 summarizes CPU time results. Every row contains the means for 30 runs. In all experiments we used the implementation of algorithms written in Python (our PEG library, Grako, and Arpeggio) and C++ EGG. An interpreter ran on a four-core Intel i7-965, 3.2 GHz processor in a Windows 10 operating system with 12 GB RAM.

As can be seen, in all cases our library worked much faster than the other Python libraries. Grammar 3 was skipped because we were unable to define it by means of either the Arpeggio or the Grako library. It should be stated, however, that both of these libraries have more functionality than our PEG library, whose principal function is only the membership operation, i.e., matching or not matching a word against a PEG. As a result, Arpeggio, Grako, and pyPEG are relatively unintuitive, especially for users not familiar with formal language theory. The dash character in the EGG results denotes a runtime segmentation error. As expected, the C++ library (EGG) outperformed its Python counterparts.


**Table 3.** Average CPU times for available Python PEG libraries (in seconds).

**Figure 3.** The number of symbols in a PEG (red line) and the number of letters in a respective test set (blue line).

### **4. Results and Discussion**

The algorithm for generating non-circular parsing expression grammars (PEGs) was tested on a recently published amyloidogenic dataset [20]. The GP parameters are listed in Table 4 in the tableau form introduced by John Koza, a GP pioneer, and named after him. From it we can read, among others, that a population size of *P* = 5 individuals was used for the GP runs. The terminal set contains standard amino acid abbreviations; "A" stands for Alanine, "R" for Arginine, etc. Concerning the initialization method, see Section 3.2. The parameters were chosen in a trial-and-error manner until the values with the best classification quality were found. The dataset is composed of 1476 strings that represent protein fragments. The data came from four databases, as shown in Figures 4 and 5. A total of 439 strings are classified as amyloidogenic (examples), and 1037 as not (counter-examples). The shortest sequence length is 4 and the longest is 83. Such a wide range of sequence lengths was an additional impediment to the learning algorithms.



In order to compare our algorithm with other grammatical inference approaches, we took most of the methods mentioned in the introductory section as references. Error-correcting grammatical inference [9] (ECGI) and alignment-based learning [10] (ABL) are examples of substring-based algorithms. The former builds an automaton incrementally based on the Levenshtein distance between an inserted word and the closest word stored in the automaton. This process begins with an empty automaton and, for each word, adds the error rules (insertion, substitution, and deletion) belonging to the transition path with the least number of error rules. The algorithm thus produces a loop-free automaton that becomes increasingly general. The latter, ABL, is based on searching for identical and distinct parts of input words. This algorithm consists of two stages. First, all words are aligned so as to find the shared and distinct parts of all pairs of words, hypothesizing that the distinct parts have the same type. For example, consider the pair "abcd" and "abe". Here, "cd" and "e" are identified as instances of the same type. The second step, which takes the same corpus as input, tries to identify the right constituents. Because the constituents generated in the previous step might overlap, the correct ones have to be selected. Simple heuristics are used to achieve this, for example taking the constituent that was generated first (ABL-first) or taking the constituent with the highest score under some probabilistic function. We used another approach, in which all constituents are stored, but at the end only the minimum number of constituents that cover all examples is kept.
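The alignment step can be sketched with Python's standard library, using `difflib.SequenceMatcher` as a stand-in for the edit-distance alignment used by ABL; this is an illustration of the idea, not the reference implementation.

```python
# ABL-style alignment of a word pair: the "equal" spans are the shared
# structure, and the remaining spans are hypothesised to be constituents
# of the same type.
from difflib import SequenceMatcher

def align(w1, w2):
    """Return (shared, distinct) parts of two words."""
    sm = SequenceMatcher(None, w1, w2)
    shared, distinct = [], []
    for op, i1, i2, j1, j2 in sm.get_opcodes():
        if op == "equal":
            shared.append(w1[i1:i2])
        else:
            distinct.append((w1[i1:i2], w2[j1:j2]))
    return shared, distinct
```

On the pair from the text, `align("abcd", "abe")` yields `"ab"` as the shared part and pairs `"cd"` with `"e"` as same-type constituents.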

ADIOS uses statistical information present in sequential data to identify significant segments and to distill rule-like regularities that support structured generalization [8]. It also brings together several crucial conceptual components; the structures it learns are (i) variable-order, (ii) hierarchically composed, (iii) context dependent, (iv) supported by a previously undocumented statistical significance criterion, and (v) dictated solely by the corpus at hand.

Blue-fringe [21] and Traxbar [22], instances of state-merging algorithms, can be downloaded from an internet archive (http://abbadingo.cs.nuim.ie/dfa-algorithms.tar.gz). Both start by building a prefix tree acceptor (PTA) from the examples, and then iteratively select two states and merge them unless compatibility is broken. They differ in how the pair of states to merge is chosen. Trakhtenbrot and Barzdin [23] described an algorithm for constructing the smallest deterministic FSA consistent with a completely labeled training set. The PTA is squeezed into a smaller graph by merging all pairs of states that represent compatible mappings from word suffixes to labels. This algorithm for completely labeled trees was generalized by Lang (1992) [22] to produce a (not necessarily minimum) automaton consistent with a sparsely labeled tree. Blue-fringe grows a connected set of red nodes that are known to be unique states, surrounded by a fringe of blue nodes that will either be merged with red nodes or promoted to red status. Merges only occur between red and blue nodes. Blue nodes are known to be the roots of trees, which greatly simplifies the code for correct merging.
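The shared starting point of these algorithms, the prefix tree acceptor, is simple to construct; the sketch below (state numbering and representation are illustrative) maps every prefix of every positive example to a state.

```python
# Build a prefix tree acceptor (PTA): the common starting point of
# Traxbar and Blue-fringe before any state merging takes place.

def build_pta(examples):
    """Return (transitions, accepting): a transition table keyed by
    (state, symbol), and the set of accepting states."""
    transitions = {}      # (state, symbol) -> state
    accepting = set()
    next_state = 1        # state 0 is the root (empty prefix)
    for word in examples:
        state = 0
        for symbol in word:
            if (state, symbol) not in transitions:
                transitions[(state, symbol)] = next_state
                next_state += 1
            state = transitions[(state, symbol)]
        accepting.add(state)
    return transitions, accepting
```

State merging then repeatedly collapses pairs of states of this tree while the resulting automaton stays consistent with the labeled data.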

We also included one machine learning approach. An unsupervised data-driven distributed representation, called ProtVec [24], was applied and protein family classification was performed using a support vector machine classifier (SVM) [25] with the linear kernel.

**Figure 4.** Combined amyloid databases used in this work. Pos and Neg denote, respectively, positive and negative word counts in the database.

**Figure 5.** Combined amyloid databases used in this work. Pos and Neg denote, respectively, positive and negative word counts in the database.

The data were randomly split into two subsets: a training set (75% of the total) and a test set (25%). We used all algorithms to infer predictors (automata or grammars) on the training set, tested them on the test set, and computed their performance. Comparative analyses of five measures, Precision, Recall, F-score, AUC, and the Matthews correlation coefficient (MCC), are summarized in Table 5. The measures are given below:
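In terms of the confusion-matrix counts defined below, the first four measures have the standard definitions (the AUC is obtained from the ROC curve rather than from a closed-form count formula):

$$\mathit{Precision} = \frac{tp}{tp + fp}, \qquad \mathit{Recall} = \frac{tp}{tp + fn}, \qquad F\text{-score} = \frac{2 \cdot \mathit{Precision} \cdot \mathit{Recall}}{\mathit{Precision} + \mathit{Recall}}$$

$$MCC = \frac{tp \cdot tn - fp \cdot fn}{\sqrt{(tp + fp)(tp + fn)(tn + fp)(tn + fn)}}$$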


where the terms true positives (*tp*), true negatives (*tn*), false positives (*fp*), and false negatives (*fn*) compare the results of the classifier under test with trusted external judgments. Thus, in our case, *tp* is the number of correctly recognized amyloids, *fp* is the number of nonamyloids recognized as amyloids, *fn* is the number of amyloids recognized as nonamyloids, and *tn* is the number of correctly recognized nonamyloids. The last column gives the CPU time of the computations (induction plus classification, in seconds).
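These definitions can be verified with a few lines of Python; this is a self-check of the formulas, not the evaluation code used in the study.

```python
# Compute Precision, Recall, F-score, and MCC from confusion-matrix counts.
import math

def scores(tp, tn, fp, fn):
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    f_score = 2 * precision * recall / (precision + recall)
    mcc = (tp * tn - fp * fn) / math.sqrt(
        (tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
    return precision, recall, f_score, mcc
```

For instance, a classifier with *tp* = 8, *tn* = 8, *fp* = 2, *fn* = 2 scores 0.8 on Precision, Recall, and F-score, and 0.6 on MCC.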


**Table 5.** Results of classification quality for the test set by the decreasing AUC.

The results show that no single method outperformed all the others on every classification measure. However, the methods can be grouped as relatively strong or relatively weak from particular angles. As regards Recall and F-score, Blue-fringe, ADIOS, and PEG are relatively strong. As regards MCC, which is generally recognized as one of the best classification measures, PEG and ABL are relatively strong. Moreover, PEG achieved the best AUC, which in the case of binary prediction is equivalent to balanced accuracy.

To evaluate convergence toward an optimal solution, we studied the change in average fitness over generations along with the growth of expression sizes (see Figure 6). The shape of the plot shows no indication of premature convergence. Moreover, we did not observe the excessive tree-size expansion that is quite often seen in genetic programming.

**Figure 6.** Average error (1−fitness accuracy) vs. expression length for different generations based on random data with two letters in the alphabet, 100 words at each set of example and counter-example and word lengths between 2 and 20.

All programs were run on an Intel Xeon CPU E5-2650 v2, 2.6 GHz processor under the Ubuntu 16.04 operating system with 192 GB RAM. The computational complexity of all the algorithms is polynomially bounded; however, the differences in running time were quite significant, and our approach ranked at the top. The algorithm for PEG induction (https://github.com/wieczorekw/wieczorekw.github.io/tree/master/PEG) was written in the Python 3 programming language. The implementation languages of the six remaining methods were, in order: Python, Java, C, Python, C, and Python.

### **5. Conclusions**

We proposed a new grammatical inference (PEG-based) method and applied it to a real bioinformatics task, i.e., classification of amyloidogenic sequences. The evaluation of generated PEGs on an amyloidogenic dataset revealed the method's accuracy in predicting amyloid segments. We showed that the new grammatical inference algorithm gives the best ACC, AUC, and MCC scores in comparison to five other automata or grammar learning methods and the ProtVec/SVM method.

In the future, we will implement circular rules in the PEG library, which will improve the expressiveness of grammars and may improve the quality of classification.

**Author Contributions:** Conceptualization, W.W.; formal analysis, W.W., O.U. and Ł.S.; investigation, W.W. and Ł.S.; methodology, W.W. and O.U.; software W.W., Ł.S.; validation, W.W. and Ł.S.; writing—original draft, W.W. and Ł.S.; writing—review & editing, W.W., O.U., Ł.S.; data curation, O.U.; funding acquisition, O.U.; project administration, O.U.; supervision, O.U.; resources, Ł.S.; visualization, Ł.S. All authors have read and agreed to the published version of the manuscript.

**Funding:** This research was supported by the National Science Center, grant 2016/21/B/ST6/02158.

**Conflicts of Interest:** The authors declare no conflict of interest.

### **References**


**Publisher's Note:** MDPI stays neutral with regard to jurisdictional claims in published maps and institutional affiliations.

© 2020 by the authors. Licensee MDPI, Basel, Switzerland. This article is an open access article distributed under the terms and conditions of the Creative Commons Attribution (CC BY) license (http://creativecommons.org/licenses/by/4.0/).

## *Article* **Anticipatory Classifier System with Average Reward Criterion in Discretized Multi-Step Environments**

**Norbert Kozłowski \*,† and Olgierd Unold †**

Department of Computer Engineering, Faculty of Electronics, Wroclaw University of Science and Technology, 50-370 Wroclaw, Poland; olgierd.unold@pwr.edu.pl

**\*** Correspondence: norbert.kozlowski@pwr.edu.pl; Tel.: +48-792-922-331

† These authors contributed equally to this work.

**Abstract:** Initially, Anticipatory Classifier Systems (ACS) were designed to address both single and multistep decision problems. In the latter case, the objective was to maximize the total discounted rewards, usually based on Q-learning algorithms. Studies on other Learning Classifier Systems (LCS) revealed many real-world sequential decision problems where the preferred objective is the maximization of the average of successive rewards. This paper proposes a relevant modification of the learning component, allowing such problems to be addressed. The modified system is called AACS2 (Averaged ACS2) and is tested on three multistep benchmark problems.

**Keywords:** learning classifier systems; anticipatory classifier systems; reinforcement learning; genetic algorithms; OpenAI gym

### **1. Introduction**

Learning Classifier Systems (LCS) [1] comprise a family of flexible, evolutionary, rule-based machine learning systems that involve a unique tandem of local learning and global evolutionary optimization of the collective model localities. They provide a generic framework combining discovery and learning components. Despite the misleading name, LCSs are not only suitable for classification problems but may instead be viewed as a very general, distributed optimization technique. Because they represent knowledge locally as IF-THEN rules with additional parameters (such as predicted payoff), they have high potential in any problem domain that is best solved or approximated through a distributed set of local approximations or predictions. The main feature of an LCS is the employment of two learning components. The discovery mechanism uses the evolutionary approach to optimize the individual structure of each classifier. On the other hand, a credit assignment component approximates the classifier fitness estimation. Because these two components interact bidirectionally, LCSs are often perceived as hard to understand.

Nowadays, LCS research is moving in multiple directions. For instance, the BioHEL [2] and ExSTraCS [3] algorithms are designed to handle large amounts of data. They extend the basic idea by adding expert-knowledge-guided learning, attribute tracking for heterogeneous subgroup identification, and a number of other heuristics to handle complex and noisy data mining. On the other hand, advances have been made towards combining LCSs with artificial neural networks [4]. Liang et al. [5] took the approach of combining the feature selection of *Convolutional Neural Networks* with LCSs. Tadokoro et al. [6] have a similar goal: they want to use *Deep Neural Networks* for preprocessing in order to be able to use LCSs for high-dimensional data while preserving their inherent interpretability. An overview of recent LCS advancements is published yearly as a part of the *International Workshop on Learning Classifier Systems* (IWLCS) [7].

In the most popular LCS modification, XCS [8], where the classifier fitness is based on the *accuracy* of a classifier's payoff prediction instead of the prediction itself, the learning component responsible for local optimization follows the Q-learning [9] pattern. Classifier predictions are updated using the immediate reward and the discounted maximum payoff anticipated in the next time step. The difference is that, in XCS, it is the prediction of a general rule that is updated, whereas, in Q-learning, it is the prediction associated with an environmental *state–action* pair. In this case, both algorithms are suitable for multistep (sequential) decision problems in which the objective is to maximize the discounted sum of rewards received in successive steps.

**Citation:** Kozłowski, N.; Unold, O. Anticipatory Classifier System with Average Reward Criterion in Discretized Multi-Step Environments. *Appl. Sci.* **2021**, *11*, 1098. https://doi.org/10.3390/app11031098

Academic Editor: Grzegorz Dudek. Received: 28 October 2020; Accepted: 16 January 2021; Published: 25 January 2021

**Copyright:** © 2021 by the authors. Licensee MDPI, Basel, Switzerland. This article is an open access article distributed under the terms and conditions of the Creative Commons Attribution (CC BY) license (https://creativecommons.org/licenses/by/4.0/).

However, in many real-world situations, the discounted sum of rewards is not the appropriate objective. When the rewards received in all decision instances are equally important, the criterion to apply is *the average reward criterion*, introduced by Puterman [10]. He stated that the decision maker might prefer it when decisions are made frequently (so that the discount rate is very close to 1) or when the performance criterion cannot easily be described in other terms. Possible areas of application include situations where system performance is assessed by the throughput rate (such as making frequent decisions when controlling the flow of communication networks).

The average reward criterion was first implemented in XCS by Tharakunnel and Goldberg [11]. They called their modification AXCS and showed that it performed similarly to the standard XCS in the Woods2 environment. Later, Zang et al. [12] formally introduced the R-learning [13,14] technique to XCS and called the result XCSAR. They compared it with XCSG (where the prediction parameters are modified by applying the idea of gradient descent) and AXCS (maximizing the average of successive rewards) in large multistep problems (Woods1, Maze6, and Woods14).

In this paper, we introduce the average reward criterion to yet another family of LCSs, anticipatory learning classifier systems (ALCS). They differ from the others in that a predictive schema model of the environment is learned rather than reward prediction maps. In contrast to the usual classifier structure, classifiers in ALCS have a state prediction or *an anticipatory part* that predicts the environmental changes caused by executing the specified action in the specified context. Similarly to XCS, ALCSs derive classifier fitness estimates from the accuracy of predictions; however, it is the accuracy of the anticipatory state predictions that is considered rather than the reward prediction accuracy. Popular ALCSs use the discounted criterion in its original form, optimizing performance over the infinite horizon.

Section 2 starts by briefly describing the psychological insights behind the concepts of imprinting and anticipations and the most popular ALCS architecture, ACS2. Then, the RL and ACS2 learning components are described. The default discounted reward criterion is formally defined, and two versions of undiscounted (averaged) criterion integration are introduced. The created system is called AACS2, which stands for *Averaged ACS2*. Finally, three sequential testing environments of increasing difficulty are presented: the Corridor, Finite State Worlds, and Woods. Section 3 examines and describes the results of testing ACS2, AACS2, Q-learning, and R-learning in all environments. Finally, the conclusions are drawn in Section 4.

### **2. Materials and Methods**

#### *2.1. Anticipatory Learning Classifier Systems*

In 1993, Hoffmann proposed a theory of *"Anticipatory Behavioral Control"*, which was further refined in [15]. It states that higher animals form an internal representation of the environment and adapt their behavior by learning anticipations. The following points (visualized in Figure 1) can be distinguished:


**Figure 1.** The theory of anticipatory behavioral control; the figure was adapted from [16], p. 4.

This insight into the presence and importance of anticipations in animals and man leads to the conclusion that it would be beneficial to represent and utilize them in animats (artificial animals) as well.

Stolzmann took the first approach in 1997 [17]. He presented a system called ACS (*"Anticipatory Classifier System"*), enhancing the classifier structure with an anticipatory (or effect) part that anticipates the effects of an action in a given situation. A dedicated component realizing Hoffmann's theory was proposed—*Anticipatory Learning Process* (ALP), introducing new classifiers into the system.

The ACS starts with a population [*P*] of maximally general classifiers ('#' in the condition part), one for each available action. To ensure that there is always a classifier handling every consecutive situation, these classifiers cannot be removed. During each behavioral act, the current perception of the environment *σ*(*t*) is captured. Then, a match set [*M*](*t*) is formed, consisting of all classifiers from [*P*] whose condition matches the perception *σ*(*t*). Next, one classifier *cl* is drawn from [*M*](*t*) using some exploration policy. Usually, an epsilon-greedy technique is used, but [18] describes other options as well. Then, the classifier's action *cl*.*a* is executed in the environment, and a new perception *σ*(*t* + 1) and reward *φ*(*t* + 1) are presented to the agent. Knowing the classifier's anticipation and the current state, the ALP module can adjust the classifier *cl*'s condition and effect parts. Based on this comparison, the following cases might be present:


After the ALP application, the Reinforcement Learning (RL) module is executed (see Section 2.3 for details).

Later, in 2002, Butz presented an extension to the described system called ACS2 [16]. Most importantly, he modified the original approach by:


Figure 2 presents the complete behavioral act, and Refs. [19,20] describe the algorithm thoroughly.

**Figure 2.** A behavioral act in ACS2; the figure was adapted from [16], p. 27.

Some modifications were made later to the original ACS2 algorithm. Unold et al. integrated the action planning mechanism [21], Orhand et al. extended the classifier structure with *Probability-Enhanced-Predictions* introducing a system capable of handling non-deterministic environments and calling it PEPACS [22]. In the same year, they also tackled an issue of perceptual aliasing by building a *Behavioral Sequences*—thus creating a system called BACS [23].

#### *2.2. Reinforcement Learning and Reward Criterion*

Reinforcement Learning (RL) is a formal framework in which the agent can influence the environment by executing specific actions and receive corresponding feedback (reward) afterwards. Usually, it takes multiple steps to reach the goal, which makes the process much more complicated. In the general form, RL consists of:


In each trial, the agent perceives the environmental state *s*. Next, it evaluates all possible actions from *A* and executes an action *a* in the environment. The environment returns a reward signal *r* and the next state *s*′ as intermediate feedback.

The agent's task is to represent the knowledge, using the policy *π* mapping states to actions, therefore optimizing a long-run measure of reinforcement. There are two popular optimality criteria used in Markov Decision Problems (MDP)—a *discounted reward* and an *average reward* [24,25].

#### 2.2.1. Discounted Reward Criterion

In discounted RL, the future rewards are geometrically discounted according to a discount factor *γ*, where 0 ≤ *γ* < 1. The performance is usually optimized in the infinite horizon [26]:

$$\lim\_{N \to \infty} E^{\pi} \left( \sum\_{t=0}^{N-1} \gamma^t r\_t(s) \right) \tag{1}$$

The *E* expresses the expected value, *N* is the number of time steps, and *rt*(*s*) is the reward received at time *t* starting from state *s* under the policy.
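As a quick numeric illustration of Equation (1), with *γ* < 1 the tail of the series vanishes, so a long finite sum approximates the limit; the sketch below is illustrative and not taken from the paper.

```python
# Finite-horizon approximation of the discounted return in Eq. (1).

def discounted_return(rewards, gamma):
    """Sum of gamma^t * r_t over the given reward sequence."""
    return sum(gamma ** t * r for t, r in enumerate(rewards))
```

For a constant reward of 1 per step and *γ* = 0.5, the geometric series converges to 1/(1 − *γ*) = 2, which the finite sum approaches quickly.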

### 2.2.2. Undiscounted (Averaged) Reward Criterion

The *averaged reward criterion* [13], which is the undiscounted RL, is where the agent selects actions maximizing its long-run average reward per step *ρ*(*s*):

$$\rho^{\pi}(\mathbf{s}) = \lim\_{N \to \infty} \frac{E^{\pi}\left(\sum\_{t=0}^{N-1} r\_t(\mathbf{s})\right)}{N} \tag{2}$$

If a policy maximizes the average reward over all states, it is a *gain optimal policy*. Usually, the average reward *ρ*(*s*) can be denoted simply as *ρ*, which is state-independent [27], i.e., $\rho^{\pi}(x) = \rho^{\pi}(y) = \rho^{\pi}, \; \forall x, y \in S$, when the Markov chain resulting from policy *π* is ergodic (aperiodic and positive recurrent) [28].

To solve an average reward MDP problem, a stationary policy *π* maximizing the average reward *ρ* needs to be determined. To do so, the *average adjusted sum* of rewards earned following a policy *π* is defined as:

$$V^{\pi}(s) = \lim\_{N \to \infty} E^{\pi} \left( \sum\_{t=0}^{N-1} (r\_t - \rho^{\pi}) \right) \tag{3}$$

The *Vπ*(*s*) can also be called a *bias* or *relative value*. Therefore, the optimal relative value for a state–action pair (*s*, *a*) can be written as:

$$V(s, a) = r^a(s, s') - \rho + \max\_b V(s', b) \quad \forall s \in S \text{ and } \forall a \in A \tag{4}$$

where $r^a(s, s')$ denotes the immediate reward of action *a* in state *s* when the next state is *s*′, *ρ* is the average reward, and $\max\_b V(s', b)$ is the maximum relative value in state *s*′ among all possible actions *b*. Equation (4) is also known as the Bellman equation for an average reward MDP [28].
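A tabular R-learning step built around Equation (4) can be sketched as follows; the function name, learning rates, and dictionary representation are illustrative assumptions, not the exact algorithm used later in the paper.

```python
# One tabular R-learning update: the relative value V(s, a) is moved
# toward r - rho + max_b V(s', b) (Eq. 4), and the average-reward
# estimate rho is tracked after greedy actions.

def r_learning_step(V, rho, s, a, r, s_next, actions,
                    alpha=0.1, zeta=0.01, greedy=True):
    """V maps (state, action) to relative value; returns updated rho."""
    best_next = max(V.get((s_next, b), 0.0) for b in actions)
    best_here = max(V.get((s, b), 0.0) for b in actions)
    # move V(s, a) toward the average-adjusted target
    old = V.get((s, a), 0.0)
    V[(s, a)] = old + alpha * (r - rho + best_next - old)
    # update rho only when the action was chosen greedily
    if greedy:
        rho += zeta * (r + best_next - best_here - rho)
    return rho
```

Keeping the `rho` update conditional on greedy action selection is the usual convergence safeguard in average-reward RL.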

### *2.3. Integrating Reward Criterions in ACS2*

Despite the ACS's *latent-learning* capabilities, the RL is realized using two classifier metrics: the reward *cl*.*r* and the immediate reward *cl*.*ir*. The latter stores the immediate reward predicted to be received after acting in a particular situation and is used mainly for model exploitation, where the reinforcement might be propagated internally. The reward parameter *cl*.*r* stores the reward predicted to be obtained in the long run.

For the first version of ACS, Stolzmann proposed a *bucket-brigade* algorithm to update the classifier's reward $r\_c$ [20,29]. Let $c\_t$ be the active classifier at time *t* and $c\_{t+1}$ the active classifier at time *t* + 1:

$$r\_{c\_t}(t+1) = \begin{cases} (1 - b\_r) \cdot r\_{c\_t}(t) + b\_r \cdot r(t+1), & \text{if } r(t+1) \neq 0\\ (1 - b\_r) \cdot r\_{c\_t}(t) + b\_r \cdot r\_{c\_{t+1}}(t), & \text{if } r(t+1) = 0 \end{cases} \tag{5}$$

where $b\_r \in [0, 1]$ is the *bid-ratio*. The idea is that if there is no environmental reward at time *t* + 1, then the currently active classifier $c\_{t+1}$ gives a payment of $b\_r \cdot r\_{c\_{t+1}}(t)$ to the previously active classifier $c\_t$. If there is an environmental reward *r*(*t* + 1), then $b\_r \cdot r(t+1)$ is given to the previously active classifier $c\_t$.

Later, Butz adopted the Q-learning idea in ACS2 alongside other modifications [30]. For the agent to learn the optimal behavioral policy, both the reward *cl*.*r* and the immediate reward *cl*.*ir* are continuously updated. To approximate the maximal Q-value, the quality of a classifier is also taken into account, under the assumption that the reward converges together with the anticipation accuracy. The following updates are applied to each classifier *cl* in the action set [*A*] during every trial:

$$\begin{array}{rcl} cl.r &=& cl.r + \beta \left( \phi(t) + \gamma \max\_{cl' \in [M](t+1) \land cl'.E \neq \{\#\}^L} \left( cl'.q \cdot cl'.r \right) - cl.r \right)\\ cl.ir &=& cl.ir + \beta \left( \phi(t) - cl.ir \right) \end{array} \tag{6}$$

The parameter *β* ∈ [0, 1] denotes the learning rate and *γ* ∈ [0, 1) is the discount factor. With a higher *β* value, the algorithm takes less care of past encountered cases. On the other hand, *γ* determines to what extent the reward prediction measure depends on future reward.
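A hedged sketch of how the update in Equation (6) might look in code follows; the `Classifier` class is a minimal stand-in for the real ACS2 implementation, and the condition *cl*′.*E* ≠ {#}^*L* is modeled by checking whether the effect part consists only of '#' symbols.

```python
# Sketch of the ACS2 reward update (Eq. 6): the discounted target adds
# the best quality-weighted reward among next-match-set classifiers
# whose effect part anticipates some change (is not all '#').

class Classifier:
    def __init__(self, q=0.5, r=0.0, ir=0.0, effect="##"):
        self.q = q          # quality
        self.r = r          # long-run reward prediction
        self.ir = ir        # immediate reward prediction
        self.effect = effect

def update_action_set(action_set, next_match_set, reward,
                      beta=0.05, gamma=0.95):
    # classifiers whose effect part is not fully general
    specific = [cl for cl in next_match_set if set(cl.effect) != {"#"}]
    best = max((cl.q * cl.r for cl in specific), default=0.0)
    for cl in action_set:
        cl.r += beta * (reward + gamma * best - cl.r)   # Eq. (6), line 1
        cl.ir += beta * (reward - cl.ir)                # Eq. (6), line 2
```

Weighting each candidate's reward by its quality *q* is what couples the reward estimate to the anticipation accuracy.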

Thus, in the original ACS2, the calculation of the discounted reward estimation at a specific time *t* is described as *Q*(*t*), which is part of Equation (6):

$$Q(t) \leftarrow \phi(t) + \gamma \max\_{cl' \in [M](t+1) \land cl'.E \neq \{\#\}^L} (cl'.q \cdot cl'.r) \tag{7}$$

The modified ACS2 implementation replacing the discounted reward with the averaged version with the formula *R*(*t*) is defined below (Equation (8)):

$$R(t) = \phi(t) - \rho + \max\_{cl' \in [M](t+1) \land cl'.E \neq \{\#\}^L} (cl'.q \cdot cl'.r) \tag{8}$$

The definition above requires an estimate of the average reward *ρ*. Equation (4) showed that the maximization of the average reward is achieved by maximizing the relative value. The next sections propose two variants of estimating it, so that the average reward criterion can be used for internal reward distribution. The altered version is named AACS2, which stands for *Averaged ACS2*.

As the next operation in both cases, the reward parameter of all classifiers in the current action set [*A*] is updated using the following formula:

$$cl.r \gets cl.r + \beta(R - cl.r) \tag{9}$$

where *β* is the learning rate and *R* was defined in Equation (8).

### 2.3.1. AACS2-v1

The first variant of AACS2 represents the *ρ* parameter as the ratio of the total reward received along the path to the reward and the average number of steps needed. It is initialized as *ρ* = 0, and its update is executed as the first operation of the RL step using the Widrow–Hoff delta rule (Equation (10)). The update is also restricted to trials where the agent chooses the action greedily during the explore phase:

$$\rho \leftarrow \rho + \zeta [\phi - \rho] \tag{10}$$

The *ζ* parameter denotes the learning rate for average reward and is typically set at a very low value. This ensures a nearly constant value of average reward for the update of the reward, which is necessary for the convergence of average reward RL algorithms [31].
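The Widrow–Hoff step of Equation (10) is essentially a one-liner; the sketch below (illustrative names) also encodes the restriction to greedily chosen actions.

```python
# AACS2-v1 average-reward estimate: a Widrow-Hoff step toward the
# immediate reward (Eq. 10), applied only after greedy actions.

def update_rho(rho, reward, zeta=0.0001, greedy=True):
    if greedy:
        rho += zeta * (reward - rho)
    return rho
```

With the small default *ζ*, the estimate drifts slowly, which keeps *ρ* nearly constant between reward updates, as the convergence argument requires.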

### 2.3.2. AACS2-v2

The second version is based on the XCSAR proposition by Zang et al. [12]. The only difference from AACS2-v1 is that the estimate also depends on the maximum classifier fitness calculated from the previous and current match sets:

$$\rho \leftarrow \rho + \zeta \left[ \phi + \max\_{cl \in [M](t+1) \land cl.E \neq \{\#\}^L} (cl.q \cdot cl.r) - \max\_{cl \in [M](t) \land cl.E \neq \{\#\}^L} (cl.q \cdot cl.r) - \rho \right] \tag{11}$$

### *2.4. Testing Environments*

This section describes the Markovian environments chosen for evaluating the introduction of the average reward criterion. They are ordered from simple to more advanced, and each has different features allowing us to examine the difference between using discounted and undiscounted reward distribution.

#### 2.4.1. Corridor

The Corridor is a 1D multi-step, linear environment introduced by Lanzi to evaluate the XCSF agent [32]. In the original version, the agent's location was described by a value in [0, 1]. When one of the two possible actions (move left or move right) was executed, a predefined *step-size* adjusted the agent's current position. When the agent reaches the final state *s* = 1.0, the reward *φ* = 1000 is paid out.

In this experiment, the environment is discretized into *n* unique states (Figure 3). The agent can still move in both directions, and a single trial ends when the terminating state is reached or the maximum number of steps is exceeded.


**Figure 3.** The Corridor environment. The agent (denoted by "\*") is inserted randomly, and its goal is to reach the final state *n* by executing one of two actions: moving left or right.
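A minimal Gym-style sketch of the discretized Corridor may make the dynamics concrete; this is an illustration under the description above, not the repository's implementation.

```python
# Discretized Corridor of n states: actions 0 (left) and 1 (right);
# reaching state n ends the trial with reward 1000.
import random

class Corridor:
    def __init__(self, n=20):
        self.n = n
        self.state = 1

    def reset(self):
        # agent is inserted randomly in a non-terminal state
        self.state = random.randint(1, self.n - 1)
        return self.state

    def step(self, action):
        self.state += 1 if action == 1 else -1
        self.state = max(1, self.state)        # left wall
        done = self.state == self.n
        reward = 1000 if done else 0
        return self.state, reward, done
```

A trial would then loop over `step` until `done` is returned or a step limit is exceeded, matching the experimental setup described above.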

#### 2.4.2. Finite State World

Barry [33] introduced the *Finite State World* (FSW) environment to investigate the limits of XCS performance in long multi-step environments with a delayed reward. It consists of *nodes* and directed *edges* joining the nodes. Each node represents a distinct environmental state and is labeled with a unique state identifier. Each edge represents a possible transition path from one node to another and is labeled with the action(s) that cause the movement. An edge can also lead back to the same node. The graph layout used in the experiments is presented in Figure 4.

**Figure 4.** A Finite State World of length 5 (FSW-5) for a delayed reward experiment.

Each trial always starts in state *s*0, and the agent's goal is to reach the final state *sr*. After doing so, the reward *φ* = 100 is provided, and the trial ends. The environment has a couple of interesting properties. First, it is easily scalable: changing the number of nodes changes the action chain length. It also enables the agent to choose the optimal route at each state (where the sub-optimal edges do not prevent progress toward the reward state).

#### 2.4.3. Woods

The Woods1 [34] environment is a two-dimensional rectilinear grid containing a single configuration of objects that is repeated indefinitely in the horizontal and vertical directions (Figure 5). It is a standard testbed for classifier systems in multi-step environments. The agent's learning task is to find the shortest path to food.

There are three types of objects: food ("F"), rock ("O"), and empty cells ("."). In each trial, the agent ("\*") is placed randomly on an empty cell and can sense the environment by analyzing the eight nearest cells. Two encodings are possible. Using binary encoding, each cell type is assigned two bits, so the observation vector has a length of 16 elements. Alternatively, using an encoding over the alphabet {O, F, .}, the observation vector is compacted to a length of 8.

In each trial, the agent can perform eight possible moves. When the destination cell is empty, the agent changes its position. If the cell contains a rock, the agent stays in place, and one time-step elapses. The trial ends when the agent reaches the food, receiving the reward *φ* = 1000.

**Figure 5.** Environment Woods1 with an animat "\*". Empty cells are denoted by ".".

### **3. Results**

The following section describes the differences observed between the ACS2 with standard discounted reward distribution and the two proposed modifications. In all cases, the experiments were performed in an explore–exploit manner, with the mode alternating in each trial. Additionally, for better reference and benchmarking purposes, basic implementations of the Q-learning and R-learning algorithms were introduced and used with the same parameter settings as ACS2 and AACS2. The most important question was whether the new reward distribution proposition still allows the agent to update the classifiers' parameters successfully, enabling exploitation of the environment. To illustrate this, we plotted the number of steps to the final location, the change in the estimated average reward during learning, and the reward payoff-landscape across all possible state–action pairs.

For reproducibility, all the experiments were implemented in Python. The PyALCS framework (https://github.com/ParrotPrediction/pyalcs) [35] was used for implementing the additional AACS2-v1 and AACS2-v2 agents, and all the environments used are implemented according to the OpenAI Gym [36] interface in a separate repository (https://github.com/ParrotPrediction/openai-envs). Publicly available interactive Jupyter notebooks presenting all results can be found here (https://github.com/ParrotPrediction/pyalcs-experiments).

### *3.1. Corridor 20*

The following parameters were used: *β* = 0.8, *γ* = 0.95, *ε* = 0.2, *ζ* = 0.0001. The experiments were run for 10,000 trials in total. Because there is only one state attribute to be perceived by the agent, the genetic generalization feature was disabled. A corridor of size *ncorridor* = 20 was tested, but similar results were also obtained for greater sizes.

The average number of steps can be calculated as $\frac{\sum_{n=0}^{n_{corridor}} n}{n_{corridor} - 1}$, which for the tested environment gives the approximate value of 11.05. Thus, the average reward per step in this environment should be close to 90.47.
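This calculation can be checked directly; a minimal sketch (variable names are illustrative, values follow the text):

```python
# Numerical check of the average-steps formula for Corridor 20.
n_corridor = 20

avg_steps = sum(range(n_corridor + 1)) / (n_corridor - 1)  # 210 / 19
avg_reward_per_step = 1000 / avg_steps                     # reward phi = 1000

print(avg_steps, avg_reward_per_step)  # ~11.05 and ~90.48
```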

Figure 6 demonstrates that the environment is learned in all cases. The anticipatory classifier systems reached the optimal number of steps after about the same number of exploit trials (roughly 200). In addition, the AACS2-v2 updates the *ρ* value more aggressively in the earlier phases, but the estimate converges near the optimal reward per step.

For the payoff-landscape, all allowed state–action pairs in the environment were identified (38 in this case). The final population of classifiers, established after 100 trials, had the same size. Both the Q-Learning and R-Learning tables were filled in using the same parameters and number of trials.

**Figure 6.** Performance on the Corridor 20 environment. Plots are averaged over ten experiments. For the number of steps, a logarithmic ordinate and a moving average with window 250 were applied.

Figure 7 depicts the differences in the payoff-landscapes. By the relative distance between adjacent state–action pairs, the algorithms can be divided into three groups. The first relates to the discounted-reward agents (ACS2, Q-Learning); both generate almost identical reward payoffs for each state–action pair. Next is the R-Learning algorithm, which estimates the *ρ* value and separates the states evenly. Finally, the two AACS2 agents perform very similarly. The *ρ* value calculated by the R-Learning algorithm is lower than the average estimated by the AACS2 algorithms.

**Figure 7.** Payoff Landscape for Corridor 20 environment. Payoff values were obtained after 10,000 trials. For the Q-Learning and R-Learning, the same learning parameters were applied. The ACS2 and Q-Learning generate exactly the same payoffs for each state–action pair.

#### *3.2. Finite State Worlds 20*

The following parameters were selected: *β* = 0.5, *γ* = 0.95, *ε* = 0.1, *ζ* = 0.0001. The experiments were performed over 10,000 trials. As before, only one state attribute is observed, and the genetic generalization mechanism remains turned off. The size of the environment was chosen as *nfsw* = 10, resulting in 2*nfsw* + 1 = 21 distinct states.

Figure 8 shows that the agents are capable of learning a more challenging environment without any problems. It takes about 250 trials to reach the reward state with an optimal number of steps. As in the corridor environment from Section 3.1, the *ρ* parameter converges with the same dynamics.

The payoff-landscape in Figure 9 shows that the average value estimate is very close to the one calculated by the R-Learning algorithm. The difference is mostly visible for state–action pairs located far from the final state. The discounted versions of the algorithms performed exactly the same.

**Figure 8.** Performance on the FSW-10 environment. Plots are averaged over ten experiments. For the number of steps, a moving average with window 25 was applied. Notice that the abscissa on both plots is scaled differently.

**Figure 9.** Payoff Landscape for the FSW-10 environment. Payoff values were obtained after 10,000 trials. For the Q-Learning and R-Learning, the same learning parameters were applied.

### *3.3. Woods1*

The following parameters were used: *β* = 0.8, *γ* = 0.95, *ε* = 0.8, *ζ* = 0.0001. Each environmental state was encoded using three bits, so the perception vector passed to the agent has a length of 24. The genetic generalization mechanism was enabled with the following parameters: mutation probability *μ* = 0.3, cross-over probability *χ* = 0.8, genetic algorithm application threshold *θga* = 100. The experiments were performed over 50,000 trials and repeated five times.

The optimal number of steps in the Woods1 environment is 1.68, so the maximum average reward can be calculated as 1000/1.68, i.e., 595.24.

Figure 10 shows that the ACS2 did not manage to learn the environment successfully; the number of steps performed in the exploit trials is not stable and varies far above the optimal value. On the other hand, both AACS2 versions functioned better; the AACS2-v2 stabilizes faster and with weaker fluctuations. The best performance was obtained by the Q-Learning and R-Learning algorithms, which managed to learn the environment in fewer than 1000 trials. The estimated average reward *ρ* converges to a value of 385 in both cases after 50,000 trials, which is still not optimal.

**Figure 10.** Performance in the Woods1 environment. For clarity, the number of steps is averaged over the 250 latest exploit trials. Both AACS2 variants managed to converge to the optimal number of steps.

Interestingly, neither the ACS2 nor the AACS2 population settled to a final size. Figure 11 demonstrates, for each algorithm, the difference between the total population size and the number of reliable classifiers. Even though the AACS2 manages to find the shortest path, the number of created rules is greater than the number of all unique state–action pairs in the environment, which is 101. The experiment was also run ten times longer (one million trials) to see whether the correct rules would be discovered, but that did not happen.

**Figure 11.** Comparison of classifier populations in the Woods1 environment. None of the algorithms managed to create a stable population size. The plot is narrowed to the first 5000 exploit trials, and the curves are averaged over the 50 latest values for clarity.

Finally, the anticipatory classifier systems' inability to solve the environment is depicted in the payoff-landscape in Figure 12. The Q-Learning and R-Learning plots show three spaced payoff levels, corresponding to states where the required number of steps to the reward state is 1, 2, and 3. All ALCS struggle to learn the correct behavior anticipations; the number of classifiers detected for each state–action pair is greater than optimal.

**Figure 12.** Payoff-landscape in the Woods1 environment. Three threshold levels are visible for the Q-Learning and R-Learning algorithms representing states in the environment with a different number of steps to the reward state.

#### **4. Discussion**

The experiments performed indicate that anticipatory classifier systems with the averaged reward criterion can be used in multi-step environments. The new system, AACS2, differs only in the way the classifier reward *cl*.*r* is calculated. The clear difference from the discounted criterion is visible in the payoff landscapes generated from the testing environments. The AACS2 produces a distinct payoff-landscape with uniformly spaced payoff levels, very similar to the one generated by the R-Learning algorithm. On closer inspection, all algorithms generate step-like payoff-landscape plots, but the particular state–action pairs are more distinguishable when the averaged-reward criterion is used. The explanation of why the agent moves toward the goal at all can be found in Equation (8): it is able to find the next best action by using the best classifiers' fitness from the next match set.

In addition, the rate at which the average reward estimate *ρ* converges differs between AACS2-v1 and AACS2-v2. Figures 6, 8, and 10 demonstrate that the AACS2-v2 settles to its final value faster but also shows greater fluctuations. This is caused by the fact that the maximum fitness of both match sets is considered when updating the values. Zang also observed this and proposed that the learning rate *ζ* in Equation (11) could decay over time [12]:

$$\zeta \leftarrow \zeta - \frac{\zeta^{max} - \zeta^{min}}{NumOfTrials} \tag{12}$$

where *ζmax* is the initial value of *ζ*, and *ζmin* is the minimum learning rate required. The update should take place at the beginning of each exploration trial.
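A minimal sketch of this decay schedule, assuming illustrative values for *ζmax*, *ζmin*, and the trial count:

```python
# Decaying learning-rate schedule from Equation (12).
# zeta_max, zeta_min, and num_of_trials are illustrative values.
zeta_max, zeta_min = 1e-4, 1e-5
num_of_trials = 10_000

zeta = zeta_max
step = (zeta_max - zeta_min) / num_of_trials

for _ in range(num_of_trials):
    # applied at the beginning of each exploration trial
    zeta = max(zeta - step, zeta_min)
```

After all trials, `zeta` has decayed linearly from `zeta_max` down to `zeta_min`.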

In addition, the fact that the estimated *ρ* value was not optimal might be caused by the exploration strategy adopted. The selected policy was *ε*-greedy. Because the estimated average reward is updated only when the greedy action is executed, the number of greedy actions performed during an exploration trial is uncertain. Moreover, the probability of executing a greedy action when the agent observes the rewarding state might be too low for the estimated average reward to reach the optimal value. This was observed during experimentation: the *ρ* value was highly dependent on the *ε* parameter used.
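The *ε*-greedy policy discussed above can be sketched as follows (a generic illustration, not the PyALCS implementation):

```python
import random

def epsilon_greedy(q_values, epsilon, rng=random):
    """Return a random action with probability epsilon, else the greedy one.

    Illustrative sketch of the exploration policy discussed above;
    q_values maps action labels to estimated payoffs.
    """
    if rng.random() < epsilon:
        return rng.choice(sorted(q_values))
    return max(q_values, key=q_values.get)

q = {"N": 1.0, "E": 5.0, "S": 2.0}
assert epsilon_greedy(q, epsilon=0.0) == "E"  # epsilon = 0: always greedy
```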

To conclude, additional research would be beneficial, paying extra attention to:


**Author Contributions:** Conceptualization, O.U. and N.K.; data curation, N.K.; formal analysis, N.K. and O.U.; funding acquisition, O.U.; investigation, N.K.; methodology, N.K. and O.U.; project administration, O.U.; resources, N.K. and O.U.; software, N.K.; supervision, O.U.; validation, O.U. and N.K., visualization, N.K.; writing—original draft preparation, N.K. and O.U.; writing—review and editing, N.K. and O.U. Both authors have read and agreed to the published version of the manuscript.

**Funding:** This research received no external funding.

**Institutional Review Board Statement:** Not applicable.

**Informed Consent Statement:** Not applicable.

**Conflicts of Interest:** The authors declare no conflict of interest.

### **Abbreviations**

The following abbreviations are used in this manuscript:


#### **References**


## *Article* **Gaussian Mixture Models for Detecting Sleep Apnea Events Using Single Oronasal Airflow Record**

### **Hisham ElMoaqet 1,\*, Jungyoon Kim 2,\*, Dawn Tilbury 3, Satya Krishna Ramachandran 4, Mutaz Ryalat <sup>1</sup> and Chao-Hsien Chu <sup>5</sup>**


Received: 9 October 2020; Accepted: 2 November 2020; Published: 6 November 2020

**Abstract:** Sleep apnea is a common sleep-related disorder that affects a significant share of the population. It is characterized by repeated breathing interruptions during sleep. Such events can induce hypoxia, which is a risk factor for multiple cardiovascular and cerebrovascular diseases. Polysomnography, the gold standard for diagnosis, is expensive, inaccessible, and uncomfortable, and an expert technician is needed to score sleep-related events. To address these limitations, many previous studies have proposed and implemented automatic scoring processes based on fewer sensors and machine learning classification algorithms. However, alternative device technologies developed for both home and hospital use still have limited diagnostic accuracy for detecting apnea events, even though many of the previous investigational algorithms are based on multiple physiological channel inputs. In this paper, we propose a new probabilistic algorithm based (only) on the oronasal respiration signal for automated detection of apnea events during sleep. The proposed model leverages AASM recommendations for characterizing apnea events with respect to dynamic changes in the local respiratory airflow baseline. Unlike classical threshold-based classification models, we use a Gaussian mixture probability model for detecting sleep apnea based on the posterior probabilities of the respective events. Our results show significant improvement in the ability to detect sleep apnea events compared to a rule-based classifier that uses the same classification features, and also compared to two previously published studies for automated apnea detection using the same respiratory flow signal. We use data from 96 sleep patients with different apnea severity levels as reflected by their Apnea-Hypopnea Index (AHI). The performance was analyzed not only over obstructive sleep apnea (OSA) but also over other types of sleep apnea events, including central and mixed sleep apnea (CSA, MSA).
The performance was also comprehensively analyzed and evaluated over patients with varying disease severity conditions, achieving an overall performance of *TPR* = 88.5%, *TNR* = 82.5%, and *AUC* = 86.7%. The proposed approach contributes a new probabilistic framework for detecting sleep apnea events using a single airflow record, with an improved capability to generalize over different apnea severity conditions.

**Keywords:** sleep apnea; airflow signal; Gaussian Mixture Models (GMM)

#### **1. Introduction**

Sleep apnea is a highly prevalent sleep disorder that can cause significant daytime sleepiness and result in many cardiovascular comorbidities [1–4]. It is characterized by repetitive significant airflow reductions during sleep causing recurrent hypoxia and sleep fragmentation [5–7]. When breathing does not completely stop but the volume of air going into the lungs is significantly reduced, the respiratory event is called a hypopnea. More than 200 million patients worldwide are affected by sleep apnea [8].

Sleep apnea events are of three types: obstructive, central, and mixed [9]. Obstructive sleep apnea (OSA) is characterized by repetitive upper airway obstruction that limits airflow into the lungs despite continued respiratory effort. Central sleep apnea (CSA) is characterized by the loss of all respiratory effort during sleep due to a neurological disorder. Mixed sleep apnea (MSA) is a combination of obstructive and central sleep apnea symptoms.

Polysomnography (PSG), often called a sleep study, is the gold standard for detecting sleep apnea. Polysomnography records basic human body activities during sleep in an attended setting (sleep laboratory). This includes electrocardiogram (ECG) for heart, oronasal thermal airflow signal (FlowTh) and nasal pressure signal (NPRE) for respiration, electroencephalogram (EEG) for brain, electromyogram (EMG) for muscles, and oxygen level in the blood (SpO2) [10,11]. Connecting a large number of sensors and wires to a subject in a dedicated sleep lab makes PSG uncomfortable, expensive, and unavailable to a large number of sleep patients in many parts of the world [9]. Moreover, clinicians need an offline inspection of the recordings to score apnea and derive the apnea-hypopnea index (AHI), which is the parameter used to establish sleep apnea and its severity [12]. Thus, the analysis process is labor-intensive and time-consuming, leading to a delayed diagnostic process and increased patient waiting lists [13–16] as well as being highly susceptible to human errors [17].

To overcome limitations of PSG, several studies have been proposed for automated detection of sleep apnea using a limited subset of signals among those involved in PSG [18]. This includes respiratory signals [19–29], ECG [16,30–32], SpO2 [33–35], tracheal sound signals [36], or some combinations of signals listed above [37]. A number of portable devices for sleep apnea monitoring and diagnosis have been developed and are available. LifeShirt, SleepStrip, and ApneaLink are among the most popular products [38].

The respiratory airflow signal is a straightforward choice when looking for simpler alternatives to PSG, since apneas are primarily defined on the basis of its amplitude oscillations [12]. According to the American Academy of Sleep Medicine (AASM), the primary sensor for identifying apnea in sleep diagnostic studies is the oronasal thermal airflow sensor [11]. Thus, several studies have focused on automated detection of sleep apnea events based exclusively on the analysis of this signal. In these studies, the airflow signal is analyzed in different analytic domains (linear, nonlinear, time, and frequency domains) to extract features, which are then used in a rule-based threshold classifier or in a "black box" machine learning model. Rule-based threshold classification has been used in [24,39–42]. Support Vector Machines (SVM) [43,44], Artificial Neural Networks (ANN) [20,45–47], linear discriminant analysis (LDA), and regression trees (CART) with AdaBoost (AB) [22] are among the most popular machine learning models applied to the respiratory flow signal.

Despite their popularity in sleep apnea problems, a major limitation of classical rule-based threshold detectors is that they classify by simple comparison of features against experimentally derived thresholds, overlooking the statistical distributions of the input features as well as of the output classes. Even the more complex discriminative ("black box") methods are based on learning a function that maps features directly into decisions. There has not been much research taking a probabilistic view of classification for sleep apnea detection.

The Gaussian Mixture Model (GMM) is a probabilistic machine learning framework that provides a richer class of density models than a single Gaussian by using a finite weighted mixture of Gaussian densities. It is well known as a rich framework capable of characterizing any continuous density. This framework has also shown promising results in classification problems involving noisy features [48]. Nevertheless, it has not been well evaluated for sleep apnea detection problems.

The contribution of this paper is threefold. First, we develop a probability-based classification approach for automated detection of sleep apnea events using a single oronasal airflow record. Second, we study the performance of the proposed approach over a large data set of 96 patients with different sleep apnea severity levels. Finally, we conduct a comprehensive evaluation and comparison of the proposed probabilistic framework against a rule-based classifier with the same input features as well as against two previously published algorithms for apnea detection using the airflow signal.

This paper is organized as follows. Section 2 describes the data set, the proposed algorithm, classification methods, and evaluation metrics. Section 3 presents results for the proposed algorithm along with a detailed comparison with related works. Section 4 discusses the results and lessons learned, and Section 5 summarizes the conclusions of the paper.

### **2. Materials and Methods**

#### *2.1. Data Set*

The Institutional Review Board (IRB) at the University of Michigan approved this study (IRB#HUM00069035). Full polysomnography (PSG) data was collected for 96 patients at the University of Michigan Sleep Disorders Center. For each patient, polysomnography consisted of electroencephalography (EEG), electrooculography (EOG), submental and tibial electromyography (EMG), electrocardiography (ECG), two piezoelectric belts for recording plethysmography (PPG), an oronasal airflow sensor (FlowTH), an air pressure transducer (NPRE), a digital microphone, and a pulse oximeter.

The oronasal airflow sensor used in this study is a thermocouple-based sensor from Respironics, model Pro-Tech P1273 (Philips Healthcare, Eindhoven, The Netherlands). Clinical annotations for respiratory events were carried out by expert clinicians from the Sleep Disorders Center at the University of Michigan (Ann Arbor, MI) according to the recommendations of the AASM [11]. Apneic events in the data set are either obstructive (OSA), central (CSA), or mixed (MSA). The data set spans different apnea severity levels as reflected by the apnea-hypopnea index (AHI) computed overnight for patients in the study: 10 are non/minimal sleep apnea patients (AHI < 5), 36 are mild sleep apnea patients (5 ≤ AHI < 15), 27 are moderate sleep apnea patients (15 ≤ AHI < 30), and 23 are severe sleep apnea patients (AHI ≥ 30). Table 1 provides the distribution of the numbers and types of apneic events per patient severity class.

The oronasal airflow signals of 66 patients (with different apnea severity levels) were used for training the proposed modeling framework. The developed framework was then tested on 30 patients (distinct from the training data) composed of 5 non/minimal apnea, 5 mild, 5 moderate, and 15 severe sleep apnea patients.


**Table 1.** Distribution of Number and Types of Apneic Events per each Class of the Severity Levels.

### *2.2. A Data-Driven Approach for Characterizing Changes in Respiratory Baseline*

According to the AASM, an apnea event is scored if there is a drop in peak thermal airflow signal excursions by ≥90% of the corresponding baseline for a duration ≥10 s [11]. Nevertheless, the airflow baseline is not precisely defined in either the AASM Scoring Manual or the sleep literature. To overcome this limitation, a data-driven approach will be used to derive the respiratory flow baseline from the airflow signal (FlowTH). The derived baseline will then be used to characterize dynamic changes in respiration with respect to this (dynamic) baseline in order to detect the occurrence of apneic events. To establish the respiratory baseline, we will consider two important respiratory features: inter-breath intervals and breath amplitudes.

A sliding window method will be used for detecting apnea events in the oronasal airflow signal (FlowTH). At time step *t*, two windows will be established. The first (the baseline window, *Wb*) will be used to derive the local respiratory baseline. The second (the detection window, *Wm*) will be used to detect apneic events based on relative changes in inter-breath intervals and breath amplitudes in *Wm* with respect to those in *Wb*. In this study, we considered a *Wb* of length *Lb* = 600 s that contains the airflow measurements up to time *t* and a *Wm* of length *Lm* = 100 s that contains airflow measurements starting from time step *t* + 1.
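The window construction can be sketched as follows (the helper name and sampling-rate handling are assumptions, not the authors' code):

```python
def split_windows(signal, t_idx, fs, lb=600, lm=100):
    """Slice the baseline window Wb and detection window Wm at time step t.

    Sketch using the window lengths from the text (Lb = 600 s, Lm = 100 s);
    `signal` is a flat list of airflow samples, `t_idx` the sample index of
    time t, and `fs` the sampling rate in Hz.
    """
    wb = signal[max(0, t_idx - lb * fs):t_idx]   # samples up to time t
    wm = signal[t_idx + 1:t_idx + 1 + lm * fs]   # samples from t + 1 onward
    return wb, wm

sig = [0.0] * 10_000                             # dummy 1000 s record at 10 Hz
wb, wm = split_windows(sig, t_idx=7000, fs=10)
assert len(wb) == 6000 and len(wm) == 1000
```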

After constructing *Wb* and *Wm*, peaks and valleys of the respiratory airflow signal are detected in both windows. An example of *Wb* and *Wm* that both include an apneic event along with peak and valley detections is illustrated in Figure 1a. The inter-breath intervals and the breath amplitudes can now be extracted from *Wb* and *Wm* as follows:

$$PP\_i = t\_{i+1} - t\_i \tag{1}$$

$$PV\_i = P\_i - V\_i \tag{2}$$

where the airflow breath *i* has a peak *Pi* that occurs at time instance *ti*, a valley *Vi*, an inter-breath interval *PPi*, and a breath amplitude *PVi*. These equations generate sequences of inter-breath intervals and breath amplitudes in *Wb* and *Wm*, as illustrated in Figure 1b,c,f,g.
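Given detected peaks and valleys, Equations (1) and (2) amount to simple differences; a sketch (the helper name and per-breath pairing of peaks with valleys are assumptions):

```python
def breath_features(peak_times, peak_vals, valley_vals):
    """Inter-breath intervals (Eq. (1)) and breath amplitudes (Eq. (2)).

    Assumes peaks and valleys were already detected and paired per breath
    (one valley per peak), as illustrated in Figure 1a.
    """
    pp = [t2 - t1 for t1, t2 in zip(peak_times, peak_times[1:])]
    pv = [p - v for p, v in zip(peak_vals, valley_vals)]
    return pp, pv

pp, pv = breath_features([0.0, 3.5, 7.0], [1.0, 0.5, 1.5], [-1.0, -0.5, -0.5])
assert pp == [3.5, 3.5] and pv == [2.0, 1.0, 2.0]
```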

After getting these sequences, it is required to extract the (inter-breath) intervals and (breath) amplitudes in *Wb* that contribute most to the respiratory baseline estimate of this window. Similarly, it is required to extract the intervals and amplitudes in *Wm* that belong to the apneic event to be detected. This can be done effectively by sorting the sequences of intervals and amplitudes in both *Wb* and *Wm* by their values. For convenience, *PPi* and *PVi* in *Wb* are sorted in descending order while those in *Wm* are sorted in ascending order, as illustrated in Figure 1d,e,h,i. This process generates the ordered sequences $PP\_b = \{PP\_i^d\}$, $PV\_b = \{PV\_i^d\}$, $PP\_m = \{PP\_i^a\}$, and $PV\_m = \{PV\_i^a\}$, where subscripts *b* and *m* denote *Wb* and *Wm*, respectively, and superscripts *a* and *d* denote ascending and descending order, respectively.

Although the length of *Wb* and *Wm* are fixed, the number of airflow breaths in these windows typically vary during different sleep stages and across different patients. Thus, the ordered sequences (*PPb*, *PVb*, *PPm*, and *PVm*) will be filtered to keep only the intervals and amplitudes that contribute most to the baseline estimate in *Wb* and the apneic events in *Wm* respectively. This can be mathematically expressed as follows:

$$L\_{PP\_b} = \lfloor F\_{PP\_b} N\_{PP\_b} \rfloor \tag{3}$$

$$L\_{PV\_b} = \lfloor F\_{PV\_b} N\_{PV\_b} \rfloor \tag{4}$$

$$L\_{PP\_m} = \lfloor F\_{PP\_m} N\_{PP\_m} \rfloor \tag{5}$$

$$L\_{PV\_m} = \lfloor F\_{PV\_m} N\_{PV\_m} \rfloor \tag{6}$$

where *Fs* is the cut-off filter applied to the ordered sequence *s* of length *Ns* in order to generate a filtered sequence of length *Ls*, with *s* ∈ {*PPb*, *PPm*, *PVb*, *PVm*}. Accordingly, the filtered sequences include the highest *LPPb* inter-breath intervals and *LPVb* breath amplitudes in *Wb* and the lowest *LPPm* intervals and *LPVm* amplitudes in *Wm*. The filter values were defined individually for each sequence so that they could be tuned separately to maximize the ability to detect apneic windows. The mathematical means of the filtered sequences can now be expressed as follows:

$$B\_{PP\_b} = \frac{1}{L\_{PP\_b}} \sum\_{i=1}^{L\_{PP\_b}} PP\_i^d \tag{7}$$

$$B\_{PV\_b} = \frac{1}{L\_{PV\_b}} \sum\_{i=1}^{L\_{PV\_b}} PV\_i^d \tag{8}$$

$$B\_{PP\_m} = \frac{1}{L\_{PP\_m}} \sum\_{i=1}^{L\_{PP\_m}} PP\_i^a \tag{9}$$

$$B\_{PV\_m} = \frac{1}{L\_{PV\_m}} \sum\_{i=1}^{L\_{PV\_m}} PV\_i^a \tag{10}$$

where *Bs* is the mathematical mean of the filtered sequence *s*. The relative changes in the inter-breath intervals (*Ic*) and the amplitude of the breaths (*Ac*), with respect to the respiratory baseline, can now be computed as follows:

$$I\_c = \frac{B\_{PP\_b} - B\_{PP\_m}}{B\_{PP\_b}} \tag{11}$$

$$A\_c = \frac{B\_{PV\_b} - B\_{PV\_m}}{B\_{PV\_b}} \tag{12}$$
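Equations (3)–(12) can be combined into one small routine; a sketch using a single illustrative cut-off fraction for all four sequences (the paper tunes each filter *Fs* separately on training data):

```python
import math

def relative_changes(pp_b, pv_b, pp_m, pv_m, f=0.5):
    """Compute Ic and Ac following Equations (3)-(12).

    `f` is a single illustrative cut-off fraction standing in for the
    individually tuned filters F_s of the paper.
    """
    def filtered_mean(seq, frac, descending):
        ordered = sorted(seq, reverse=descending)
        length = max(1, math.floor(frac * len(seq)))   # Eqs. (3)-(6)
        kept = ordered[:length]
        return sum(kept) / len(kept)                   # Eqs. (7)-(10)

    b_pp_b = filtered_mean(pp_b, f, descending=True)   # highest intervals in Wb
    b_pv_b = filtered_mean(pv_b, f, descending=True)   # highest amplitudes in Wb
    b_pp_m = filtered_mean(pp_m, f, descending=False)  # lowest intervals in Wm
    b_pv_m = filtered_mean(pv_m, f, descending=False)  # lowest amplitudes in Wm

    ic = (b_pp_b - b_pp_m) / b_pp_b                    # Eq. (11)
    ac = (b_pv_b - b_pv_m) / b_pv_b                    # Eq. (12)
    return ic, ac

ic, ac = relative_changes(pp_b=[4, 4, 3, 3], pv_b=[2, 2, 1, 1],
                          pp_m=[1, 1, 8, 8], pv_m=[0.2, 0.2, 2, 2])
```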

**Figure 1.** (**a**) Oronasal respiration signal with peak/ valley detections illustrated by (red '\*')/ (blue 'o') respectively. (**b**) Sequence of extracted breath amplitudes of *Wb*. (**c**) Sequence of extracted breath amplitudes of *Wm*. (**d**) *PVb* Sequence of descending ordered amplitudes in *Wb*. (**e**) *PVm* Sequence of ascending ordered amplitudes of *Wm*. (**f**) Extracted inter-breath intervals of *Wb*. (**g**) Extracted inter-breath intervals of *Wm*. (**h**) *PPb* sequence of descending ordered inter-breath intervals of *Wb*. (**i**) *PPm* sequence of ascending ordered intervals of *Wm*. Annotations in subplot (**a**) illustrate example of apnea events present in *Wb*, *Wm*. Arrows point to the breath amplitudes from apnea events shown in *Wb*, *Wm*.

### *2.3. Detection of Apnea Events based on Relative Changes in Respiratory Baseline*

For the classification part, we propose a probabilistic view of classification for automated detection of apnea events. We leverage a Gaussian Mixture Model (GMM) to derive a decision boundary based on probabilistic assumptions about the underlying distribution of the respiratory input features. We denote this modeling scheme as a Gaussian Mixture Model (GMM) classifier. In order to demonstrate the improvement achieved by considering GMM as a generative machine learning model, we compare results with a classical threshold-based detector that uses the same input features for automated detection of apnea events.

To prepare the data for classification, a sliding window successively updated every 20 s (step size = 20 s) was applied over the oronasal airflow signal. At each step, *Wb* and *Wm* are constructed. Then, Equations (1)–(12) are applied to compute *Ic* and *Ac* (the input features to the classification model), while *Wm* provides the classification label based on whether or not an apnea event was clinically scored in this window. Considering our data set with 66 patients for training and 30 for testing, Table 2 shows the distribution of data segments and corresponding labels for both the training and test sets.

**Table 2.** Segments for training and testing classification models.


### 2.3.1. Rule-Based Threshold Classification

The rule-based classifier detects apnea in a detection window (*Wm*) when the input features *Ic* and *Ac* activate both classification rules $\hat{I}\_{lb} \leq I\_c \leq \hat{I}\_{ub}$ and $A\_c > \hat{A}$, where $\hat{I}\_{lb}$ and $\hat{I}\_{ub}$ are the classification thresholds for *Ic* and $\hat{A}$ is the classification threshold for *Ac*. An exhaustive search is applied for each of these thresholds in order to learn their optimal values, using a two-step optimization method. The classification thresholds for *Ic* are optimized first to obtain the receiver operating characteristic curve (*ROC*) with the maximum area under the *ROC* (*AUC*) over the training data. Once the *Ic* classification rule is learned, the optimal classification threshold for *Ac* is learned by searching along the selected curve for the threshold that provides the maximum sensitivity (*TPR*) while constraining the false positive rate not to exceed the maximum acceptable limit of 20% (*FPR* ≤ 20%). More details about the derivation and tuning of the rule-based classifier can be found in our recent study [49].
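Once the thresholds are learned, applying the rules to a detection window is a direct comparison; a sketch with placeholder threshold values (the actual values come from the ROC search described above):

```python
def rule_based_apnea(ic, ac, i_lb, i_ub, a_hat):
    """Apply the learned classification rules to one detection window.

    Returns True when both rules fire: I_lb <= Ic <= I_ub and Ac > A_hat.
    The threshold values used below are placeholders, not the paper's.
    """
    return (i_lb <= ic <= i_ub) and (ac > a_hat)

assert rule_based_apnea(ic=0.4, ac=0.8, i_lb=0.1, i_ub=0.6, a_hat=0.5)
assert not rule_based_apnea(ic=0.05, ac=0.8, i_lb=0.1, i_ub=0.6, a_hat=0.5)
```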

### 2.3.2. Classification with Gaussian Mixture Models (GMM)

A Gaussian mixture model (GMM) is a probabilistic modeling framework in which the probability density function (PDF) of $\mathbf{x} \in R^d$ is defined as a finite weighted sum of *k* Gaussian distributions:

$$p(\mathbf{x}|\boldsymbol{\Theta}) = \sum\_{i=1}^{k} \gamma\_i \, p(\mathbf{x}|\boldsymbol{\theta}\_i) \tag{13}$$

such that **x** is the 2-dimensional feature vector $[I\_c, A\_c]^T$ computed every time *Wb* and *Wm* are constructed, $\sum\_{i=1}^{k} \gamma\_i = 1$, **Θ** is the mixture model, *γi* corresponds to the weight of component *i*, and the density of each component is given by the normal probability distribution:

$$p(\mathbf{x}|\boldsymbol{\theta}\_i) = \frac{|\boldsymbol{\Sigma}\_i|^{-\frac{1}{2}}}{(2\pi)^{d/2}} \exp\left\{-\frac{1}{2}(\mathbf{x} - \boldsymbol{\mu}\_i)^T \boldsymbol{\Sigma}\_i^{-1} (\mathbf{x} - \boldsymbol{\mu}\_i)\right\} \tag{14}$$

The parameters *γ*, the means *μ*, and the covariances **Σ** are optimized during training using the expectation maximization algorithm [50] such that the log-likelihood of the model is maximized. During testing, a likelihood estimate is obtained for the apnea class, defined by the model **Θ***A*, and for the non-apneic (normal respiration) class, defined by the model **Θ***N*. Using the Bayesian classification formula, the likelihood estimates are combined to compute the posterior probability of apnea for the sample **x**:

$$P(A|\mathbf{x}) = \frac{p(\mathbf{x}|\Theta\_A)P(A)}{p(\mathbf{x}|\Theta\_A)P(A) + p(\mathbf{x}|\Theta\_N)P(N)}\tag{15}$$

where *P*(*A*) and *P*(*N*) are the prior probabilities of the apnea and non-apnea (normal) classes, respectively. Assuming no prior knowledge, these probabilities were set to be equal, *P*(*A*) = *P*(*N*) = 0.5. The combination of the two GMMs and the Bayesian classification formula in Equation (15) forms the GMM classifier [51].
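The GMM classifier of Equations (13)–(15) can be sketched with stand-in parameters (for brevity, hand-written single-component mixtures replace the EM-fitted models; all numbers below are illustrative):

```python
import math

def gauss2d_pdf(x, mu, cov):
    """Bivariate normal density, Equation (14), for d = 2."""
    det = cov[0][0] * cov[1][1] - cov[0][1] * cov[1][0]
    inv = [[ cov[1][1] / det, -cov[0][1] / det],
           [-cov[1][0] / det,  cov[0][0] / det]]
    dx = [x[0] - mu[0], x[1] - mu[1]]
    quad = (dx[0] * (inv[0][0] * dx[0] + inv[0][1] * dx[1])
            + dx[1] * (inv[1][0] * dx[0] + inv[1][1] * dx[1]))
    return math.exp(-0.5 * quad) / (2 * math.pi * math.sqrt(det))

def apnea_posterior(x, mix_apnea, mix_normal, p_a=0.5, p_n=0.5):
    """Posterior P(A|x), Equation (15), with equal priors.

    Each mixture is a list of (weight, mu, cov) components, Equation (13).
    """
    lik_a = sum(w * gauss2d_pdf(x, mu, cov) for w, mu, cov in mix_apnea)
    lik_n = sum(w * gauss2d_pdf(x, mu, cov) for w, mu, cov in mix_normal)
    return lik_a * p_a / (lik_a * p_a + lik_n * p_n)

# One-component (k = 1) mixtures per class, for illustration only:
cov = [[0.05, 0.0], [0.0, 0.05]]
theta_a = [(1.0, (0.8, 0.8), cov)]   # stand-in apnea class model
theta_n = [(1.0, (0.1, 0.1), cov)]   # stand-in normal class model
p = apnea_posterior((0.7, 0.75), theta_a, theta_n)
assert p > 0.99   # a point near the apnea mean gets a high posterior
```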

#### *2.4. Evaluation of Apnea Detection Results*

#### 2.4.1. Classification Performance over Detection Windows

In this paper, five statistical metrics of accuracy (*ACC*), true positive rate (*TPR*), true negative rate (*TNR*), positive predictive value (*PPV*), and *F*<sup>1</sup> score are applied to assess the performance of the proposed modeling framework over all detection windows (*Wm*):

$$\text{ACC} = \frac{\sum TP + \sum TN}{\sum TP + \sum FP + \sum FN + \sum TN} \times 100\% \tag{16}$$

$$TPR \quad = \frac{\sum TP}{\sum TP + \sum FN} \times 100\% \tag{17}$$

$$TNR \quad = \frac{\sum TN}{\sum FP + \sum TN} \times 100\% \tag{18}$$

$$PPV \quad = \frac{\sum TP}{\sum TP + \sum FP} \times 100\% \tag{19}$$

$$F\_1 = -2\frac{TPR.PPV}{TPR + PPV} \times 100\% \tag{20}$$

where *TP* (true positive) is the number of apneic windows that were correctly classified as such, *TN* (true negative) is the number of normal windows that were correctly classified as such, *FP* (false positive) is the number of normal windows that were falsely classified as apneic, and *FN* (false negative) is the number of apneic windows that were missed by the classifier. *ACC* is a classical measure for binary classification but is not sufficient for this problem due to the class imbalance between the apnea and normal classes [52–54]. Thus, *TPR*, *TNR*, and *PPV* are used to report more detailed performance in detecting apneic and normal windows. The *F*<sup>1</sup> score considers *TP* and *FP* detections simultaneously, and thus accounts for the *TPR*/*PPV* tradeoff, giving a more comprehensive picture of the overall performance of the proposed model.
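The five metrics reduce to a few lines of code given the aggregated confusion counts; the counts below are illustrative only.

```python
# Equations (16)-(20) from aggregated confusion counts.
def metrics(tp, tn, fp, fn):
    acc = 100.0 * (tp + tn) / (tp + fp + fn + tn)   # Equation (16)
    tpr = 100.0 * tp / (tp + fn)                    # Equation (17)
    tnr = 100.0 * tn / (fp + tn)                    # Equation (18)
    ppv = 100.0 * tp / (tp + fp)                    # Equation (19)
    f1  = 2.0 * tpr * ppv / (tpr + ppv)             # Equation (20), already in percent
    return acc, tpr, tnr, ppv, f1

# Hypothetical counts for an imbalanced test set (apnea is the minority class).
acc, tpr, tnr, ppv, f1 = metrics(tp=80, tn=890, fp=10, fn=20)
```

Note how, on this imbalanced example, *ACC* is high (97%) even though the *TPR* is only 80%, which is exactly why the per-class metrics are reported.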

### 2.4.2. Receiver Operating Characteristics (*ROC*) Curve

The receiver operating characteristics (*ROC*) curve is an effective tool to graphically illustrate the diagnostic ability of a binary classification system as its classification threshold is varied [55]. The curve plots the *TPR* against the false positive rate (*FPR* = 100% − *TNR*) at various discrimination threshold settings. The area under the *ROC* curve (*AUC*) is used as a measure of the overall ability of the classification model to automatically detect sleep apnea events; a greater *AUC* indicates a more effective classification model. Additionally, the *ROC* curve can be used for optimizing classification models by finding the operating threshold that provides the highest *TPR* for the allowable *FPR* level. This approach was used in learning the classification rules of the rule-based classifier.
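The threshold-selection step described above can be sketched as follows; the scores, labels, and *FPR* cap are hypothetical placeholders.

```python
# Sweep candidate thresholds over classifier scores and keep the one with
# the highest TPR whose FPR stays under an allowable cap.
import numpy as np

def best_threshold(scores, labels, max_fpr=20.0):
    best_t, best_tpr = None, -1.0
    for t in np.unique(scores):
        pred = scores >= t
        tp = np.sum(pred & (labels == 1)); fn = np.sum(~pred & (labels == 1))
        fp = np.sum(pred & (labels == 0)); tn = np.sum(~pred & (labels == 0))
        tpr = 100.0 * tp / (tp + fn)
        fpr = 100.0 * fp / (fp + tn)          # FPR = 100% - TNR
        if fpr <= max_fpr and tpr > best_tpr:
            best_tpr, best_t = tpr, t
    return best_t, best_tpr

# Toy scores for 4 normal (label 0) and 4 apneic (label 1) windows.
scores = np.array([0.1, 0.2, 0.35, 0.4, 0.6, 0.7, 0.8, 0.9])
labels = np.array([0,   0,   0,    1,   0,   1,   1,   1])
t, tpr = best_threshold(scores, labels, max_fpr=25.0)
```

Sweeping over the unique score values is equivalent to walking along the empirical *ROC* curve point by point.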

### **3. Results**

For the proposed GMM classifier (AICPV with GMM) and the rule-based threshold classifier (AICPV with Threshold), we used the PSGs of 66 patients for training, tuning, and optimizing the classifiers. The trained classifiers were then tested on the PSGs of the remaining 30 patients.

### *3.1. Rule-Based Threshold Classifier*

The optimal filter values were set to *FPPb* = 0.3, *FPVb* = 0.4, *FPPm* = 0.1, and *FPVm* = 0.3. The classification rules for detecting apneic windows were identified as follows:

$$0.05 \le l_c \le 0.95 \tag{21}$$

$$0.957 \le A_c \tag{22}$$

such that an apneic window is detected by the rule-based classifier whenever both rules are active.
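A direct transcription of these rules is shown below; the feature values passed in are hypothetical.

```python
# Rule-based threshold classifier: a window is flagged apneic only when
# both rules (Equations (21) and (22)) are active.
def is_apneic(l_c, A_c):
    rule1 = 0.05 <= l_c <= 0.95   # Equation (21)
    rule2 = A_c >= 0.957          # Equation (22)
    return rule1 and rule2

# Three hypothetical windows: both rules active, rule 2 violated, rule 1 violated.
flags = [is_apneic(0.5, 0.98), is_apneic(0.5, 0.90), is_apneic(0.99, 0.98)]
```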

#### *3.2. GMM Classifier*

Cross-validation over the training data set was used to investigate and select the parameters of the GMM models. These parameters include the number of Gaussian distributions needed to model each class and the type of covariance matrix used (diagonal or full symmetric). Our results show that 12 Gaussian components are needed to model the GMM of the apneic class and 11 Gaussian components to model the GMM of the normal class, with diagonal covariance matrices for both GMMs.
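This kind of model selection can be sketched as follows; as a stand-in for the paper's cross-validation procedure, this illustration scores candidate component counts with the Bayesian information criterion (BIC) on synthetic data, and every number in it is made up.

```python
# Select the number of GMM components by minimizing BIC over a candidate
# range (an information-criterion stand-in for cross-validation).
import numpy as np
from sklearn.mixture import GaussianMixture

rng = np.random.default_rng(1)
# Synthetic 2D data with two well-separated modes.
X = np.vstack([rng.normal(-2.0, 0.4, (200, 2)),
               rng.normal(+2.0, 0.4, (200, 2))])

bics = {k: GaussianMixture(n_components=k, covariance_type="diag",
                           random_state=0).fit(X).bic(X)
        for k in range(1, 5)}
best_k = min(bics, key=bics.get)   # component count with the lowest BIC
```

Diagonal covariances, as chosen in the paper, keep the parameter count linear in the feature dimension, which in turn keeps the BIC penalty small.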

#### *3.3. Classification Performance Comparison over the Testing Data Set*

We performed an overall evaluation of the proposed model (AICPV with GMM, AICPVwGMM) over the 30-patient test data and compared the performance results with the rule-based classification model that uses the same input features (AICPV with Threshold, AICPVwTH). We also compared against two well-known published algorithms for automated apnea detection using the oronasal thermal airflow signal and similar time-domain features [39,45]. The first algorithm implements a classical threshold-based classification model [39], while the other uses an artificial neural network classification model [45]. Note that [39] includes an additional module for classifying the type of detected apnea using a neural network classifier trained on the thoracic effort signal. Nevertheless, we included only the apnea detection module from this study, since the proposed algorithm uses only a single channel of oronasal airflow.

For a fair comparison, the four algorithms AICPVwGMM, AICPVwTH, and Refs. [39,45] were all trained, evaluated, and tested on identical data within our data set. Table 3 comprehensively compares the classification performance of these four algorithms over the test data. First, it can be noticed that AICPVwGMM and AICPVwTH outperform the two previously published algorithms on all performance measures in this paper. The AICPV algorithms demonstrate a higher ability to detect apnea events (reflected by their *TPR*) as well as a higher ability to detect normal respiration patterns (reflected by their *TNR*) compared to [39,45]. The overall better classification performance of the AICPV algorithms is demonstrated by their *F*<sup>1</sup> and *AUC* values compared to [39,45]. The improvement achieved with the AICPV algorithms is mainly due to the dynamic approach adopted in these algorithms, in which apneic events are characterized based on relative changes in airflow breath amplitudes and inter-breath intervals with respect to the local respiration baseline.


**Table 3.** Comparison of performance over a 30-patient test data set between the rule-based threshold classifier (AICPVwTH), the proposed Gaussian Mixture Models (GMM) classifier (AICPVwGMM), and two previous related studies.

Comparing the performance obtained with the proposed GMM-based classifier (AICPVwGMM) and the rule-based classifier (AICPVwTH), we notice a 65.2% increase in the *PPV* of the apnea detections obtained with the GMM-based classifier. Recognizing the *TPR*/*PPV* tradeoff, and that *TPR* and *PPV* are performance metrics of competing natures, it can be noticed that with the GMM-based classifier we achieve a high *TPR* together with a significantly improved *PPV* compared to the rule-based classifier. This can also be noticed in the significant increase in the *F*<sup>1</sup> score of the detections with the proposed GMM model compared to the rule-based one. Higher *TNR* and *ACC* are also obtained with the proposed algorithm, reflecting a higher ability to detect normal respiratory patterns. The overall classification performance indicated by the *AUC* likewise shows improved detection with the proposed GMM modeling framework.

To provide an in-depth analysis and understand the sources of improvement of the proposed model over the rule-based classification model, we performed a comprehensive analysis of the test performance of the proposed algorithms over different apnea types and different apnea severity conditions. Tables 4 and 5 provide a detailed comparison between the GMM-based classifier and the rule-based classifier over different apnea types and different apnea severity levels. A clear increase in the ability to detect OSA and CSA events can be noticed in the *TPR* of these detections using the proposed algorithm, along with a significant increase in the *PPV* of all types of apnea detections. It can also be noticed that the *TPR* of the MSA detections is superior with both algorithms and did not change between them, mainly because MSA events are a minority compared to OSA and CSA events.

Looking at the detailed performance of the proposed GMM model and the rule-based classifier over different apnea severity conditions, we can notice that the best performance of the rule-based threshold classifier is in severe apnea patients and that its performance degrades significantly over less severe cases. On the other hand, the GMM-based classifier maintains a high ability to detect apnea events in severe patients but, more importantly, has a significantly higher ability to detect apneic events in less severe apnea patients, which can be clearly seen in the *TPR* of these detections. Moreover, looking at the *AUC* values, we can also notice an excellent overall ability of the GMM modeling framework to detect apnea events in severe patients, as well as a significantly improved overall ability to detect apnea in less severe patients compared to the rule-based classification model.


**Table 4.** Test performance of AICPV with GMM classifier over different types of apnea and different apnea severities.


**Table 5.** Test performance of AICPV with threshold classifier over different types of apnea and different apnea severities.

Finally, we performed a comprehensive analysis of the performance per apnea type class among different patient severity levels. Table 6 evaluates how the proposed model performs in detecting different apnea types in each apnea severity class. As can be seen in the table, the proposed model AICPVwGMM maintains a high ability to detect different apnea types regardless of disease severity in the test patients. MSA events are very rare in mild patients, which caused low and skewed detection rates for these patients. Also, the final row in Table 6 was left empty since there are no MSA events in the class of none/minimal apnea patients. In general, the class of none/minimal apnea reflects patients who are healthy or have few significant apnea events, and so it is a class of less interest compared to other disease states. Nevertheless, we kept performance results on all disease severity classes to report a comprehensive assessment of the modeling framework.

It is worth mentioning that the implemented algorithms were trained and tested using full overnight PSG records. Our goal was to test the proposed framework and compare it with existing works in a more practical setting, as opposed to many previous studies that considered shorter records, avoiding the class imbalance problem and excluding segments with a low signal-to-noise ratio (SNR). False positives were affected by the class imbalance problem, segments of low airflow signal quality, and irregular breathing patterns and artifacts. High airflow peak amplitudes resulting from increased respiratory effort after the end of apnea events may contribute falsely to the respiratory baseline and thus increase false positives. Future work may consider adding more advanced signal filtration algorithms that allow more accurate detection in the case of artifacts, as well as rejecting airflow segments with very low SNR.


**Table 6.** Detailed test performance of AICPV with GMM classifier (AICPVwGMM) over different Apnea–Hypopnea Index (AHI) severities per apnea class.

#### **4. Discussion of Results**

A probabilistic framework was developed in this study for automated apnea detection using single-channel data from the oronasal airflow record (FlowTH). The proposed framework leverages AASM recommendations to define apnea based on relative changes with respect to the respiratory airflow baseline. To overcome the absence of a precise mathematical definition of the airflow baseline, a data-driven method was developed to represent it based on two features: the breath amplitudes and the inter-breath intervals. Apnea is then characterized based on relative changes in these features between two sliding windows: the baseline window, which represents the current respiratory baseline, and the detection window, where an apneic event is to be detected.

For automatic detection of apneic events, we considered classification from a probabilistic view using a GMM-based modeling framework. The proposed framework showed significantly improved performance in detecting apnea compared to a rule-based classifier that uses the same input features, as well as to two previously published algorithms that respectively use threshold-based classification and neural networks applied to time-domain features from the same respiratory signal. Using relative changes in respiratory features to define apnea enabled a dynamic approach that accounts for continuous changes in the respiratory baseline, making the AICPVwTH and AICPVwGMM algorithms significantly more capable of detecting apneic events than previous studies that considered absolute changes in classification features while overlooking relative ones. Comparing the proposed model AICPVwGMM with AICPVwTH, which uses a rule-based classifier, showed significantly improved performance for the GMM model, characterized by achieving a high *TPR* with a significantly improved overall *PPV* for all types of apnea detections. The proposed model also performs much better over different apnea severity levels than the rule-based classifier, which showed its best performance over severe apnea patients only, with significantly degraded performance over less severe disease classes.

In recent years, many studies have considered the oronasal thermal airflow signal for automated apnea detection [19–21,39,42–45,56]. Nevertheless, patients with severe and moderate OSA conditions have been the major focus of many of these studies, while less severe patient populations and other types of apnea events have not received sufficient attention. The present results highlight the importance of evaluating apnea models on patients of varying severity conditions as well as on different apnea types. This agrees with previous literature demonstrating that high performance accuracies achieved with patients of high severity levels may not generalize to other groups of patients [57,58]. The detailed analysis presented in this paper, using patients of varying apnea severity conditions and different apnea types, provides the basis for a more comprehensive understanding of the performance of apnea detection systems.

Comparative analysis between the performance of the proposed modeling framework and previously published research highlights some of the previously reported limitations of single-respiration-channel apnea detection methods. In particular, previous studies have reported a significant fall in the diagnostic accuracy of automated sleep systems that measure two or fewer physiological parameters compared to those that measure three or more physiologic variables [38,57]. Although many automated sleep systems have proved effective in assisting lab PSGs and at-home studies, they still cannot completely replace dedicated centers for sleep studies [59].

The present study leverages AASM recommendations for apnea detection in many aspects. We considered the AASM-recommended sensor for scoring apneic events in PSG diagnostic studies, which is the oronasal thermal airflow sensor. The criteria in the AASM manual were also employed for scoring an event after a drop in peak thermal airflow signal excursions by ≥90% of the corresponding baseline for a duration ≥10 s [11]. Nevertheless, there are many sources of uncertainty in the criteria defined in the literature. First, the AASM does not provide a precise definition of the respiratory baseline. Second, the published criteria for mathematically scoring apnea are not consistent and vary across standards [60]. Moreover, the high intraobserver and interobserver variability due to human scoring and human errors [17] significantly affects the robustness and performance of automated sleep systems. Adopting a probabilistic framework in the present study provides an efficient approach to propagate these different sources of uncertainty using a data-driven modeling framework optimized with respect to the ability to detect apnea events. Interestingly, the proposed GMM-based classification system showed an overall improved performance in apnea detection and a more consistent and generalizable performance across patients with different severity levels.

Finally, the presented approach focuses on automated detection of apneic events using the oronasal airflow respiration signal. Future work may extend the algorithm by adding an additional input, the nasal pressure signal, in order to study dynamical changes in this signal during hypopnea, as recommended by the AASM. Using the oronasal airflow and nasal pressure signals will enable detecting both apnea and hypopnea events, as needed to estimate the AHI and provide a diagnosis of sleep apnea severity. Additionally, this will improve the ability to estimate the respiratory baseline by incorporating two signals instead of one, potentially leading to a decreased number of false positive detections. Moreover, future work may consider adding input from the respiratory effort signal to diagnose the type of detected events (OSA, CSA, or MSA).

### **5. Conclusions**

In this study, a new algorithm was developed for automated detection of sleep apnea events using single-channel data from the oronasal thermal airflow sensor (FlowTH). The algorithm leverages AASM recommendations by defining apnea using relative changes in the oronasal airflow signal with respect to the evolving respiratory airflow baseline. A novel approach was developed to represent the respiratory airflow baseline based on two features: the inter-breath intervals and the breath amplitudes. To detect apneic events, we considered a probabilistic view of classification using a GMM modeling framework. The proposed framework showed significantly improved apnea detection performance for different apnea types and severity conditions compared to a rule-based classifier that uses the same input features, as well as to two previous methods for automated detection using the same respiratory source. The proposed modeling framework achieved an overall summary performance of *TPR* = 88.5%, *TNR* = 82.5%, and *AUC* = 86.7%.

**Author Contributions:** Conceptualization, H.E. and J.K.; methodology, H.E and J.K.; software, J.K.; validation, H.E. and J.K; formal analysis, J.K.; investigation, H.E., J.K., and M.R.; resources, H.E., J.K., D.T., and S.K.R.; data curation, H.E. and J.K.; writing—original draft preparation, H.E.; writing–review and editing, H.E, J.K., and M.R; visualization, H.E and J.K.; supervision, D.T., S.K.R., and C.-H.C.; project administration, D.T. and S.K.R. All authors have read and agreed to the published version of the manuscript.

**Funding:** This research received no external funding.

**Conflicts of Interest:** The authors declare no conflict of interest.

### **References**


**Publisher's Note:** MDPI stays neutral with regard to jurisdictional claims in published maps and institutional affiliations.

© 2020 by the authors. Licensee MDPI, Basel, Switzerland. This article is an open access article distributed under the terms and conditions of the Creative Commons Attribution (CC BY) license (http://creativecommons.org/licenses/by/4.0/).

## *Article* **Optimization of Warehouse Operations with Genetic Algorithms**

### **Mirosław Kordos 1,\*, Jan Boryczko 1, Marcin Blachnik <sup>2</sup> and Sławomir Golak <sup>2</sup>**


Received: 7 June 2020; Accepted: 9 July 2020; Published: 13 July 2020

**Abstract:** We present a complete, fully automatic solution based on genetic algorithms for the optimization of discrete product placement and of order picking routes in a warehouse. The solution takes as input the warehouse structure and the list of orders, and returns the optimized product placement, which minimizes the sum of the order picking times. The order picking routes are optimized mostly by genetic algorithms with a multi-parent crossover operator, but in some cases permutations and local search methods can also be used. The product placement is optimized by another genetic algorithm, in which the sum of the lengths of the optimized order picking routes is used as the cost of the given product placement. We present several ideas that improve and accelerate the optimization, such as the proper number of parents in crossover, a caching procedure, multiple restarts, and order grouping. In the presented experiments, in comparison with random product placement and a random product picking order, the optimization of order picking routes reduced the total order picking time to 54%, optimization of product placement with the basic version of the method reduced it to 26%, and optimization of product placement with the improvements, such as multiple restarts and multi-parent crossover, reduced it to 21%.

**Keywords:** warehouse optimization; genetic algorithms; crossover

### **1. Introduction**

A large share of the operating costs related to product storage is connected with order picking. Many studies have established that about 60% of warehouse operation costs are the costs of picking up goods when completing orders [1]. Since the speed of this operation is a decisive factor in the response time to customers' orders, and is one of the factors contributing to their decision whether to choose the same company at the next purchase, the role of order completion speed is arguably even greater.

Thus, shortening the time of order picking is the most important and most beneficial factor in reducing the costs of operating the warehouse. It can be achieved without significant costs by optimizing the locations for particular products in a warehouse and then determining the fastest order completion routes in the optimized product placement.

Some aspects of the optimization seem obvious; for example, items that are often ordered together should be placed close to each other, and frequently purchased items should be located close to the delivery point. However, for the case of discrete variables considered in this paper, where the location sizes are fixed (some goods, e.g., sand or gravel, can be described by continuous variables, but this problem is not considered here), with *N* items in the warehouse the number of all their possible placements is *N*!. For *N* = 100 this gives *N*! ≈ 9.33 × 10<sup>157</sup>, and even if a computer could analyze one billion permutations per second, it would take about 3 × 10<sup>141</sup> years to examine them all. So in practice, manually designing an optimal product placement is impossible, as the number of possible arrangements significantly exceeds what can be analyzed by humans, or even by a computer program that tries all possible configurations.
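The search-space arithmetic above can be checked directly with Python's arbitrary-precision integers:

```python
# Sanity-check the size of the placement search space for N = 100 items.
import math

n_placements = math.factorial(100)        # ~9.33e157 possible placements
seconds = n_placements / 1e9              # at 1e9 placements checked per second
years = seconds / (365.25 * 24 * 3600)    # ~3e141 years of exhaustive search
```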

In this paper, we present a complete, fully automated system based on artificial intelligence, particularly on genetic algorithms, which can overcome the limitations of the search space by an intelligent search. The system usually analyzes only several thousand to several tens of thousands of possible product placements to find the optimal solution, or at least a solution so close to the optimal one that in practice it makes no difference. Moreover, the system also returns the quickest order picking routes. The advantages of applying such a solution are speeding up the operations, and thus reducing warehouse operating costs (of which typically 60% are the costs of order picking [1]), and the possibility of serving more customers with the same employees at the same time, thus further increasing sales and profits.

Although this issue has been analyzed for a long time, and has undergone especially intensive development in recent years [2–4], we were not able to find in the literature a complete, accurate, fully automatic solution that considers the real distribution of orders in the optimization of product placements. All the papers we were able to find presented only partial approaches with large simplifications; for example, that the route length is determined only by the one product with the longest distance from the entrance [2], or that the user is responsible for re-allocating products by placing the more frequently used ones closer to the entrance [5].

Moreover, the list of products within a single order is frequently quite long, especially in warehouses that sell goods mostly to retailers, which makes the optimization even more important. The purpose of this paper is to fill this gap by presenting our solution and by discussing its particular aspects and their influence on the accuracy and speed of the warehouse optimization process.

First we define the problem (Section 2), then we list the main points of our contribution (Section 3), and next we come to the details in the following order: the literature review (Section 4), the presentation of the proposed solution (Sections 5.1–5.4), the experimental results (Section 6), and the conclusions (Section 7).

### **2. Problem Statement**

The product placement determines the locations (usually shelves) of particular products in the warehouse. We define the cost of a product placement as the sum of the shortest order picking routes over all orders included in the order list for this product placement (Equation (1)). The assumption behind this is that if robots are used to collect the orders, they will exactly follow the shortest routes found. If humans pick the orders, they will in most cases follow the same routes; however, they may sometimes decide to change the path. This can, however, be treated as a random process and thus cannot be taken into account in the optimization.

$$\text{Cost} = \sum_{n=1}^{N_{ord}} Route_{min}(n) \tag{1}$$

where *Nord* is the number of orders in the order list and *Routemin*(*n*) is the shortest order picking (order completion) route found for the *n*-th order.

The problem to solve is to find a product placement with as low a cost as possible. In other words, we need to minimize the cost given by Equation (1) by properly placing products at particular locations in the warehouse.
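A minimal sketch of this cost, assuming a depot-based picking route, Manhattan distances on a grid, and brute-force route search for small orders; the placement and orders below are entirely made up.

```python
# Cost of a placement (Equation (1)): sum of the shortest picking routes
# over all orders, each route starting and ending at a depot.
from itertools import permutations

def dist(a, b):
    """Manhattan distance between two grid locations."""
    return abs(a[0] - b[0]) + abs(a[1] - b[1])

def shortest_route(depot, stops):
    """Brute-force TSP over the (small) set of pick locations of one order."""
    best = float("inf")
    for perm in permutations(stops):
        d = (dist(depot, perm[0])
             + sum(dist(perm[i], perm[i + 1]) for i in range(len(perm) - 1))
             + dist(perm[-1], depot))
        best = min(best, d)
    return best

def placement_cost(placement, orders, depot=(0, 0)):
    return sum(shortest_route(depot, [placement[p] for p in order])
               for order in orders)

# Toy placement (product -> shelf coordinates) and order list.
placement = {"A": (1, 0), "B": (2, 0), "C": (0, 3)}
orders = [["A", "B"], ["A", "C"], ["B", "C"]]
cost = placement_cost(placement, orders)
```

Brute force is only feasible for short orders; this is precisely why the paper replaces it with genetic algorithms, permutations, and local search for the route sub-problem.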

For that purpose we developed the product placement optimization algorithm, which we describe in the following sections.

The inputs to the algorithm are:


The outputs of the algorithm are:


Here we only present the main idea; the details about each input and output can be found in the subsequent sections. The order list contains all considered orders, which should be the orders expected in a certain future period of time. In most cases these can be the past orders, as we can expect the future orders to have the same distribution of products. Otherwise, these can be the past orders from the corresponding season of the last year, or the predicted future orders. Each order consists of several (at least one) products. To complete an order, the locations of all the products included in the order must be visited and the products must be picked. Thus, a sub-problem of the main problem of product placement optimization is to find the quickest (shortest) order completion route for each order. This is necessary, as the sum of the quickest order picking routes is the main objective that we want to minimize by optimizing the placement of particular products in the warehouse.

This sub-problem of finding the quickest order completion route is equivalent to the Traveling Salesman Problem (TSP). The whole optimization of product placement, however, is a different, much more complex problem. The main differences are:


### **3. Contribution**

The main points of the contribution of this paper are:


### **4. Related Works**

#### *4.1. Warehouse Planning and Operations*

Much of the literature has been devoted to warehouse operations. Van Gils et al. [4] provided a review and classification of the scientific literature investigating combinations of tactical and operational order picking planning problems. Grosse et al. [6] analyzed human factors in order picking. Dijkstra and Roodbergen [7] discussed predetermined order picking sequences, including various layouts of the warehouse and its aisles. They also discussed a dynamic programming approach that determines storage location assignments for those layouts, using route length formulas and optimal properties. Bolaños Zuñiga [3] presented a formal mathematical model for simultaneously determining storage location assignment and picker routing, considering precedence constraints based on the weight of the products and the location of each product in a general warehouse. Bartholdi [1] presented practical considerations for warehouse planning and construction in his book. Rakesh [8] discussed methods that determine the optimal lane depth, number of storage levels, and other parameters of the warehouse layout to minimize space and material handling costs. The book of Davarzani [9] discussed warehouse planning, technology, equipment, human resource management, and connections to other departments and companies. Zunic [10] considered various warehouse designs, especially V-shaped aisles, and calculated the order picking routes for these designs. Dharmapriya [11] discussed the use of simulated annealing for warehouse layout optimization, taking into account the total demand and traveling cost, but without considering the co-occurrence of various products in the same orders.

### *4.2. Genetic Algorithms in Warehouse Optimization*

Artificial intelligence methods, in particular genetic algorithms, have numerous successful implementations, have been rapidly gaining popularity in recent years, and have also been applied to warehouse operation optimization [12–14].

Genetic algorithms have two important advantages: fast intelligent search used to find the product placement and a global cost function.

Their first advantage is that, due to intelligent searching, genetic algorithms do not need to analyze all solutions (all possible permutations of product locations), which is impossible due to their number. They usually analyze only a few thousand up to a few hundred thousand product placements (and not all *N*! possibilities) to find the solution. Although genetic algorithms do not guarantee finding the optimal solution every time, the solutions found are close enough to the optimal one that in practice this does not make a significant difference.

Their second advantage is that, using genetic algorithms, we do not have to define the rules that characterize good solutions ourselves. This is very important, because usually we do not know how to define the rules optimally and only have some intuitive knowledge (e.g., goods often purchased together should be close to each other). However, expressing this knowledge in strict mathematical form is impossible because of the complexity of the system and the frequently opposing optima of various order completions. With genetic algorithms it is enough to formulate the cost function, which is expressed here as the sum of all order picking route lengths, or as the total time required to complete all orders from a certain period.

In genetic algorithms, the problem is coded using arrays called chromosomes by analogy with encoding in the chromosomes of biological organisms [12,13]. Each product placement encoded by a chromosome represents one solution (one individual). Initially, a pool of random solutions is generated (in each solution the products are randomly assigned to the locations in a warehouse). Then an intelligent search is applied with the help of three basic operations—selection, crossover and mutation. The selection mechanism selects individuals for the crossover operator. It is organized by analogy to the biological process, where the better individuals have a higher probability of becoming parents and exchanging information to create offspring.

The crossover operator generates a new individual (child) from existing ones (parents). In this way, it allows the information encoded in the chromosomes of different individuals to be combined into one new individual. The mutation mechanism exchanges the values between some positions in an individual's chromosome. Then, selected individuals are promoted to the next generation. The process is repeated iteratively until a satisfactory solution is found or as long as the solutions keep improving.
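The loop described above can be sketched as a minimal permutation-encoding genetic algorithm. The cost function, selection scheme (simple truncation), and all parameters here are toy placeholders rather than the paper's actual configuration.

```python
# Minimal GA over permutations: order crossover (OX variant), swap
# mutation, truncation selection with elitism.
import random

def order_crossover(p1, p2):
    """Copy a random slice from p1; fill the rest in p2's order."""
    n = len(p1)
    a, b = sorted(random.sample(range(n), 2))
    child = [None] * n
    child[a:b] = p1[a:b]
    rest = [g for g in p2 if g not in child]
    for i in range(n):
        if child[i] is None:
            child[i] = rest.pop(0)
    return child

def mutate(ind, rate=0.2):
    """With probability `rate`, swap two random positions."""
    if random.random() < rate:
        i, j = random.sample(range(len(ind)), 2)
        ind[i], ind[j] = ind[j], ind[i]
    return ind

def evolve(cost, n_genes, pop_size=30, generations=100):
    random.seed(0)
    pop = [random.sample(range(n_genes), n_genes) for _ in range(pop_size)]
    for _ in range(generations):
        pop.sort(key=cost)
        elite = pop[: pop_size // 2]                  # truncation selection
        children = [mutate(order_crossover(*random.sample(elite, 2)))
                    for _ in range(pop_size - len(elite))]
        pop = elite + children
    return min(pop, key=cost)

# Toy cost favoring the identity permutation (a stand-in for route length).
best = evolve(lambda ind: sum(abs(g - i) for i, g in enumerate(ind)), n_genes=8)
```

Permutation encoding is what makes OX-style crossover necessary: a plain one-point crossover of two permutations would generally duplicate some genes and drop others, producing an invalid placement.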

To sum up, the use of genetic algorithms or other evolutionary optimization methods in warehouse systems, including the optimization of the distribution of goods and of order picking routes, can bring measurable benefits to companies using these solutions, accelerating their work and reducing operating costs.

Although the idea of applying genetic algorithms to warehouse optimization or order picking route optimization has been presented in the literature, we have not found a complete automatic solution that considers the order distribution, as we present in this paper.

As the optimization of the order picking route (which is equivalent to the traveling salesman problem, see Section 2) is a much simpler problem than the optimization of product placement in the warehouse, it can be expected that many more papers are dedicated to order picking route optimization, and only a few discuss product placement optimization. Below we briefly present some of them.

Wang [5] applied genetic algorithms to optimize a fitness function consisting of three weighted terms: the turnover of goods, the gravity center of storage racks and the relevance of the goods. Wei [15] used a similar approach with a genetic algorithm with PMX crossover, where the fitness was defined in terms of aisle access, the weight of the items and the dimensions of the shelves.

Avdeikins and Savrasovs [2] applied genetic algorithms to warehouse optimization using order crossover (OX). Each individual in the population represented a warehouse layout. Each gene was a unique item. In their solution the fitness of each individual was calculated as the sum of the maximal picking distances over all orders: *fitness*(*i*) = ∑ *dmax*(*O*) over the orders *O* = 1, 2, ..., *k*, where *dmax*(*O*) is the distance to the furthest picking position of order *O*. Distances were expressed by an integer value that was 1 for the first item, 2 for the second, and so forth, increasing by 1 from one SKU to another. Such a fitness function moves the most frequently sold items closer to the warehouse entry, but it does not place items that are frequently sold together in neighboring locations, as this aspect was not considered; neither does it determine the real order picking routes, and thus it cannot be considered a complete solution.

As the solution presented by Avdeikins and Savrasovs [2] may at first glance seem similar to ours, it is worth pointing out the differences. In our solution the fitness is calculated as the sum of all order picking route lengths. This is a fundamental difference, as it allows us to minimize the time of the real order picking operations and thus to obtain very accurate solutions, which also minimizes the distances between items contained in one order. The next difference is that we use real transition costs (or real distances), not an approximation that always increases the distance by one. Another difference is that our solution determines both the product placement and the fastest order picking routes. Moreover, we use newer, effective crossover operators and propose many improvements to the accuracy and speed of the optimization.

### *4.3. Crossover Operators*

Proper design of the crossover operator is a crucial factor in genetic algorithm performance. In this subsection we review the crossover operators and present in detail the AEX and HGreX operators, which are used in our solution.

Crossover combines the most valuable information from two or more different chromosomes (parents) into one chromosome (child) that can represent a better solution than its parents. For problems where each item can occupy only one location at a time and each location must be occupied by exactly one item, such as order picking route optimization or product placement optimization in a warehouse, special crossover operators must be used to ensure that the newly created individual contains each element exactly once. Several such crossover operators have been developed.

Hassanat and Alkafaween [16] proposed several crossover operators, such as cut on worst gene crossover (COWGC) and collision crossover, and selection approaches, such as select the best crossover (SBC). COWGC exchanges genes between parents by cutting the chromosome at the point that maximally decreases the cost. The collision crossover uses two selection strategies for the crossover operators: the first selects, from several examined operators, the one that maximally improves the fitness, and the other randomly selects any operator. The SBC algorithm applies multiple crossover operators at the same time to the same parents, and finally selects the best two offspring to enter the population. Hwang presented the order crossover (OX) and cycle crossover (CX) operators [17]. Tan proposed heuristic greedy crossover (HGreX) and its variants HRndX and RProX [18]. Other popular crossover operators include partially mapped crossover (PMX), edge recombination crossover (ERX) and alternating edges crossover (AEX) [19].

Several comparisons of the performance of these crossover operators can be found in the literature [19]. Based on these comparisons, the best performing methods for the traveling salesman problem were most frequently the variants of the HGreX crossover operator, and the second best was the AEX crossover operator. For this reason we decided to start by applying these two crossover operators to our warehouse optimization problem.

HGreX is only suitable for problems where the cost of transition between two positions can be defined. For example, it can be used to find the shortest order picking path, as we can define the distances (costs) between the particular locations that contain the products from the orders and thus must be visited. However, it cannot be applied to the optimization of product placement in the warehouse, because the distribution of products is not directly related to any single order picking route, but to a whole set of different routes. Thus in this case we can define the global goal, which is the minimization of the sum of the lengths of all order picking routes, but we cannot express the cost between any two positions. AEX, on the other hand, does not use the transition cost and can therefore also be applied to problems where such a cost cannot be defined, such as product placement optimization in a warehouse. First we present the AEX operator and then the HGreX operator.

AEX creates the child from two parents by starting from the value at the first position in the first parent. Then it adds the value from the second parent that follows the value just taken from the first parent. Then again a value from the first parent that follows the value just selected from the second parent, and so on. If this is impossible, because some element would repeat, a random element not selected so far is chosen. In the presented example each position in the chromosome represents one location in the warehouse and each letter represents one product.

Let us assume we have two parents P1 and P2:

```
P1 = [ A B C D E F G H ]
P2 = [ H A D B G F E C ]
```

To create the child, we start from any position of the first parent P1. Let us start from A. Then we add the value that follows A in the second parent, so we add D:

```
Ch = [ A D _ _ _ _ _ _ ]
```

In each step, the values already placed in the child are treated as used and are skipped in both parents.

Next we add to the child the value that follows D in P1, that is E:

```
Ch = [ A D E _ _ _ _ _ ]
```

Next we add to the child the value that follows E in P2, that is C:

```
Ch = [ A D E C _ _ _ _ ]
```

Next we would add to the child the value that follows C in P1, that is D. However, D has already been used, so this is not a valid choice. In such a case we select randomly one of the remaining values. Let us select G:

```
Ch = [ A D E C G _ _ _ ]
```

Next we add to the child the value that follows G in P2, that is F:

```
Ch = [ A D E C G F _ _ ]
```

Next we add to the child the value that follows F among the unused values of P1, which is currently H (G, which directly follows F, has already been used):

```
Ch = [ A D E C G F H _ ]
```

Finally, we add to the child the value that follows H among the unused values of P2, which is currently B, and the child becomes:

```
Ch = [ A D E C G F H B ]
```
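The stepwise procedure above can be sketched in Python as follows. This is a minimal illustration under our own naming, not the implementation used in the paper; the random fallback makes the tail of the child non-deterministic, while the first steps follow the worked example exactly.

```python
import random

def aex_crossover(p1, p2, rng=random):
    """Alternating edges crossover (AEX): take successors alternately from
    the two parents; on a conflict fall back to a random unused value."""
    n = len(p1)
    # succ[k][x] is the value that follows x (cyclically) in parent k
    succ = [{p[i]: p[(i + 1) % n] for i in range(n)} for p in (p1, p2)]
    child, used = [p1[0]], {p1[0]}
    parent = 1  # the first successor is taken from the second parent
    while len(child) < n:
        nxt = succ[parent][child[-1]]
        if nxt in used:  # conflict: choose a random not yet used value
            nxt = rng.choice([v for v in p1 if v not in used])
        child.append(nxt)
        used.add(nxt)
        parent = 1 - parent  # alternate between the parents
    return child

P1 = list("ABCDEFGH")
P2 = list("HADBGFEC")
child = aex_crossover(P1, P2)
# the first steps are deterministic: A, then D (after A in P2), then E, then C
```

The successor maps make each lookup O(1), so one crossover is linear in the chromosome length except for the occasional random fallback.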

The HGreX crossover operator works in a similar way to AEX. The difference is that it does not take the elements from the two parents alternately, but always chooses, from the two parents, the element to which the distance (cost) from the current element is lower.

Let us assume we have the same two parents P1 and P2:

```
P1 = [ A B C D E F G H ]
P2 = [ H A D B G F E C ]
```

To create the child, we start from any position of the first parent P1; let us start from A. Then we add the value with the lower transition cost (shorter distance) from A, of the two values that appear directly after A in the parents, that is B and D. Let us assume that the cost of going from A to B is 12, and from A to D is 15. So we choose B as the next position in the child.

```
Ch = [ A B _ _ _ _ _ _ ]
```

However, if the costs were A to B: 18 and A to D: 15, then we would have chosen D as the next position in the child. Conflicts are resolved identically as in AEX.
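A sketch of the two-parent HGreX step in the same style. The distance function `d` below is hypothetical, chosen only so the example runs; it is not the warehouse cost matrix.

```python
import random

def hgrex_crossover(p1, p2, dist, rng=random):
    """Heuristic greedy crossover (HGreX): of the values following the
    current one in the two parents, take the one with the lower transition
    cost; on a conflict fall back to a random unused value."""
    n = len(p1)
    succs = [{p[i]: p[(i + 1) % n] for i in range(n)} for p in (p1, p2)]
    child, used = [p1[0]], {p1[0]}
    while len(child) < n:
        cur = child[-1]
        cands = {s[cur] for s in succs} - used
        if cands:
            nxt = min(cands, key=lambda v: dist(cur, v))
        else:
            nxt = rng.choice([v for v in p1 if v not in used])
        child.append(nxt)
        used.add(nxt)
    return child

# Hypothetical distances: place the letters on a line, cost = difference
pos = {c: i for i, c in enumerate("ABCDEFGH")}
d = lambda a, b: abs(pos[a] - pos[b])
child = hgrex_crossover(list("ABCDEFGH"), list("HADBGFEC"), d)
```

Note that, unlike AEX, the operator needs the cost between arbitrary pairs of values, which is exactly why it applies to route optimization but not to product placement.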

In Section 5.4.1 we introduce multi-parent versions of the crossover operators.

### **5. The Proposed Method**

In this section, we describe the proposed genetic algorithm based method that optimizes the product placement and order picking routes in the warehouse. The purpose of the method, as described in Section 2, is to minimize the product placement cost given by Equation (1), that is, to find such an assignment of particular products to positions in the warehouse that minimizes the sum of the shortest order picking routes over all orders from the order list. Thus, the optimization process consists of the outer procedure (main process), which is the product placement optimization (presented in Section 5.2), and the inner procedure (the sub-process), which is the order picking route optimization for each considered product placement (presented in Section 5.3).

#### *5.1. Data Format and Problem Encoding*

This subsection explains the input data format and the encoding of the problem in genetic algorithm chromosomes. At the end, the calculation of the transition cost matrix is also explained.

To determine the quality of a given product placement (see Algorithm 1), first we must find the order picking routes for each order (see Algorithm 2). To find them, we must calculate the transition cost matrix (costs of moving between each pair of locations in the warehouse), and to calculate the full matrix, the user must provide the transition costs (or distances) between the adjacent locations.

Figure 1 shows a sample layout of a very small warehouse, which we use to explain the data format and problem encoding. Of course, the real warehouses for which the methodology was created are much larger. In this example the distances entered by the user are shown with colored lines in Figure 1, and the distance matrix with these entries is shown in Table 1. As the distances are symmetric (e.g., *dist*(*loc*5, *loc*8) = *dist*(*loc*8, *loc*5)), it is enough to fill in the distances above the diagonal. The remaining distances (e.g., *dist*(*loc*7, *loc*10)) are calculated automatically.

**Figure 1.** Sample warehouse structure used to explain the data format and problem encoding. Distances between neighboring locations: in blue the distances of 1 unit, in red of 1.5 unit, in green of 3 units, in violet of 4 units.

The warehouse layouts with product placements and the order picking routes are encoded in the chromosomes. Let us assume that the number of available products equals the number of locations in the warehouse, that is, 13 for this sample warehouse and the products names are: A,B,C,D,E,F,G,H,I,J,K,L,M. Let us also assume that there are three different orders, which occur with the same frequency and which consist of the following products:

Order1: A,B,C,D,E,F,G Order2: G,H,I,J,K Order3: A,B,K

For the product placement optimization we will encode the problem in a chromosome with 13 positions. At the beginning we generate a population of random individual chromosomes representing the product placements (see Algorithm 1). Let us assume, the 15th randomly generated individual looks as follows:

Layout15 = [G H E F I J A C D K B L M]

In this case the product G is at the location 1 in Figure 1, product H at location 2, and so on. We need to find the shortest order picking routes for each individual (each product placement) to calculate its fitness. For this purpose, we generate a population of random individual chromosomes representing the routes. There is always 0 (which represents the entrance) at the first position and the other positions are occupied by randomly ordered products from this order. Let us assume, the 12th randomly generated individual for Order3 looks as follows:

Order3-Route12 = [0 A K B 0]

The length of this route *lengthR* is given by the formula:

$$lengthR = dist(loc\_0, loc\_7) + dist(loc\_7, loc\_{10}) + dist(loc\_{10}, loc\_{11}) + dist(loc\_{11}, loc\_0)$$

as in Layout15, *loc*0 is the entrance, product A is at location 7, product K at location 10 and product B at location 11.
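As a sketch, the route length can be computed from the layout chromosome and the full distance matrix roughly like this. The distance values below are made up for illustration; in the real system they come from the matrix **D** described in this subsection.

```python
# Hypothetical symmetric distances between the locations on this route;
# the real values come from the full matrix D computed in Section 5.1.
dist = {(0, 7): 2.0, (7, 10): 1.5, (10, 11): 1.0, (11, 0): 3.0}
dist.update({(b, a): c for (a, b), c in list(dist.items())})

layout15 = list("GHEFIJACDKBLM")                  # product G at location 1, etc.
loc = {p: i + 1 for i, p in enumerate(layout15)}  # product -> location (1-based)

def route_length(route):
    """Length of an order picking route given as a product sequence;
    the route starts and ends at the entrance (location 0)."""
    stops = [0] + [loc[p] for p in route] + [0]
    return sum(dist[(a, b)] for a, b in zip(stops, stops[1:]))

length = route_length("AKB")  # Order3-Route12: entrance -> A -> K -> B -> entrance
```

With Layout15, the products A, K and B indeed map to locations 7, 10 and 11, matching the formula above.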

**Table 1.** The original distance matrix corresponding to the warehouse structure shown in Figure 1 containing only the values required by the algorithm.

This input data format was specially designed to require minimal effort from the user entering the data and, at the same time, to allow for maximal accuracy of the calculations. Only the costs of transitions between neighboring locations are required in the input data. However, if the user wants to also enter the transition costs between some more distant locations, they are free to do so. The program preserves all distances entered by the user and only calculates the remaining ones.

The transition cost between locations can be entered as a distance in meters, but also in seconds as the time needed to cover this distance. This takes into account, for example, that covering the same distance vertically is more costly than horizontally, or that turning around a corner requires more time than covering the same distance along a straight line, and thus allows higher accuracy of the order picking times to be obtained. However, the units make no difference to the proposed method, which simply treats them as units of cost.

As the available plans of different warehouses can come in many formats of varying usability, creating separate software for the preparation of the input data for each individual warehouse is not practical, as it would take more time than entering the transition costs manually. The program does not need to know the geometrical structure of the warehouse. This is an additional advantage, because much less work is required to prepare the input data.

**Algorithm 1** Product Placement (PP) Optimization Process

**Input 1:** Warehouse layout in the form of transition costs between neighboring locations (see Table 1) **Input 2:** The list of orders

**Output 1:** The optimized product placement (PP) in the warehouse

**Output 2:** Shortest order picking routes for each order for the optimized PP

1: With the Dijkstra algorithm calculate the full matrix **D** of transition costs between product locations
2: **for** *k* = 1 **to** *numberOfProcessRestarts* **do**



To calculate the remaining transition costs between each pair of locations (line 1 in Algorithm 1), any algorithm that can do this can be used, for example the Dijkstra [20], Floyd–Warshall [21] or Bellman–Ford algorithm [22]. We use the Dijkstra algorithm because it is the fastest one, especially for sparse graphs (as is the case here), where each vertex is connected to only a few other vertices. For a graph of *v* vertices (locations) and *e* connecting edges (transition costs), calculating all the distances with the Dijkstra algorithm with a priority queue has complexity *O*(*v*(*e* + *v* log *v*)), while the complexities of the two other algorithms are *O*(*v*3) and *O*(*ev*2), respectively.
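A minimal sketch of this step (our own code, not the authors' implementation): running Dijkstra from every location over the sparse user-entered costs yields the full matrix **D**.

```python
import heapq

def complete_cost_matrix(n, edges):
    """Turn sparse symmetric transition costs between adjacent locations
    into a full n x n cost matrix by running Dijkstra from every location."""
    adj = {i: [] for i in range(n)}
    for (i, j), c in edges.items():
        adj[i].append((j, c))
        adj[j].append((i, c))  # costs are symmetric
    D = [[float("inf")] * n for _ in range(n)]
    for src in range(n):
        D[src][src] = 0.0
        heap = [(0.0, src)]
        while heap:
            d, u = heapq.heappop(heap)
            if d > D[src][u]:
                continue  # stale queue entry, already relaxed
            for v, c in adj[u]:
                if d + c < D[src][v]:
                    D[src][v] = d + c
                    heapq.heappush(heap, (d + c, v))
    return D

# Toy example: 4 locations in a line with unit costs between neighbors
D = complete_cost_matrix(4, {(0, 1): 1.0, (1, 2): 1.0, (2, 3): 1.0})
```

Each of the *v* Dijkstra runs uses a binary heap, giving the per-source behavior the complexity discussion above refers to.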

Calculating the cost matrix with the Dijkstra algorithm takes only a very small, practically negligible fraction of the time of the whole optimization of the product locations. Moreover, once calculated, the cost matrix can be re-used for other product placements as long as the physical layout of the warehouse does not change. The A\* algorithm [23] can be faster than Dijkstra only when the approximate cost from the current node to the target node can be assessed. However, in this problem we are not able to assess the approximate cost, because we do not know the coordinates of the particular locations, only the transition costs between neighboring locations. Considering these two factors, it is not justified to demand that the user prepare additional data with the coordinates of each location in order to use the A\* algorithm, as in this case the gain in CPU time (usually less than a second) would not compensate for the loss of the user's time (usually several hours) spent preparing the additional data.

### *5.2. Product Placement Optimization*

In this subsection we present the algorithm used to optimize the placement of particular products in the warehouse in order to minimize the total time of completing the orders from the order list. The main optimization process is shown in pseudo-code in Algorithm 1 and as diagrams in Figures 2 and 3. The sub-process, which determines the shortest order picking routes is discussed in the next section.

**Figure 2.** Product Placement Optimization (main process).

**Figure 3.** Product Placement Optimization (inner block of the algorithm shown in Figure 2).

We will now explain the base version of the algorithm with *numberOfProcessRestarts* = 1 (line 2 in Algorithm 1); in Section 5.4.3 we will explain the use and purpose of multiple process restarts (*numberOfProcessRestarts* > 1).

The process starts in line 1, where the Dijkstra algorithm calculates the full matrix **D** of transition costs (or distances) between product locations (see Section 5.1 for details). In line 3 the initial random population of product placements (PP) is generated. In line 7 the genetic algorithm starts the optimization. In line 10 the algorithm checks if any of the current individuals existed previously in the current or any past epoch; if so, the fitness of such an individual is not calculated but retrieved from the cache. In line 15 the shortest route for completing each different order for the current PP is calculated with Algorithm 2. Since we need to minimize the sum of order completion times, for each considered placement of products in the warehouse we need to calculate the time of each order completion and then add the times. To minimize the computational complexity of this step, we group orders consisting of the same products together and assign to such an aggregated order a higher weight, which equals the number of single orders of which the aggregated order is composed. In line 18 the cost of the current PP is calculated, next the fitness is calculated from the cost and the cache is updated. In line 25 the parents for each child are selected, and in the next line the child is created with the AEX crossover operator. If the fitness of the best PP has not improved for *NBIPP* iterations (line 28), the optimization is finished. In line 32 the promotion of children and best parents to the next generation takes place. In the next line the mutation is applied and the cache is updated. Finally, in the last line, the algorithm returns the best product placement and the corresponding set of shortest order picking routes.
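The cache lookup described above can be sketched as follows. This is illustrative only; `evaluate` stands in for the costly route optimization of Algorithm 2 and is a hypothetical name.

```python
cost_cache = {}  # maps a product placement to its already computed cost

def cached_cost(placement, evaluate):
    """Return the cost of a placement, computing it at most once across
    all epochs (the fitness cache of Algorithm 1)."""
    key = tuple(placement)  # tuples are hashable, lists are not
    if key not in cost_cache:
        cost_cache[key] = evaluate(placement)
    return cost_cache[key]

calls = []
def evaluate(placement):  # stand-in for the expensive evaluation
    calls.append(placement)
    return float(len(placement))

cached_cost(list("GHEF"), evaluate)
cached_cost(list("GHEF"), evaluate)  # second call is served from the cache
```

Because identical chromosomes recur across epochs (selection and promotion copy individuals forward), this lookup saves the dominant cost of the method, the repeated route optimization.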

In order to keep the selection pressure constant, and thus the exploration of the search space stronger at earlier stages and the convergence faster at later stages of the process [24], both in product placement optimization and in the sub-process of order picking route optimization, we use fitness function normalization. Thus the cost of a given product placement and the lengths of the order picking routes are not used directly as fitness values, but are re-scaled. We used roulette wheel selection both for the product placement optimization and for the order picking route optimization. According to Razali [25] and to our experiments, roulette wheel selection and tournament selection give comparable results. The selection pressure can also be controlled in both (by the number of candidates for a parent in tournament selection and by the shape of the fitness function in roulette wheel selection). We chose roulette wheel selection for debugging purposes, as with this selection it was easier for us to analyze the detailed algorithm behaviour and to fine-tune it.

The cost of the *i*-th product placement *costPP*(*i*) expresses the sum of the shortest routes found for each order picking and is given by Equation (2):

$$costPP(i) = \sum\_{j=1}^{N\_{diff}} N\_{rep}(j) \cdot lengthR\_{min}(j) \tag{2}$$

where *Ndiff* is the number of different orders, *Nrep*(*j*) is the number expressing how many times the *j*-th order is repeated on the order list and *lengthRmin*(*j*) is the best (shortest) route found for completing the *j*-th order. The value *costPP* is used to assess the progress and the results of the optimization (see Figure 3).

The fitness *fitnessPP*(*i*) of the *i*-th product placement used by the roulette wheel selection is given by Equation (3):

$$fitnessPP(i) = c\_1 + \frac{costPP\_{max} - costPP(i)}{costPP\_{max} - costPP\_{min}},\tag{3}$$

where *c*1 is a coefficient (the lower *c*1, the stronger the preference for individuals with lower cost), *costPPmax* is the maximal cost, *costPPmin* is the minimal cost and *costPP*(*i*) is the cost of the *i*-th product placement in the population.

The variable *fitnessPP* takes larger values for better individuals, to ensure that they have a higher probability of being selected as parents, while *costPP* takes smaller values for better individuals, as it expresses the product placement cost, which equals the sum of the lengths of the shortest routes.
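Equation (3) can be sketched directly; the value of *c*1 below is an arbitrary example, not the setting used in the paper.

```python
def fitness_pp(costs, c1=0.25):
    """Equation (3): re-scale placement costs so that the best individual
    gets fitness c1 + 1 and the worst gets c1, keeping the selection
    pressure constant across epochs. c1 = 0.25 is an arbitrary example."""
    cmax, cmin = max(costs), min(costs)
    if cmax == cmin:  # degenerate case: all placements equally good
        return [c1 + 1.0] * len(costs)
    return [c1 + (cmax - c) / (cmax - cmin) for c in costs]

f = fitness_pp([120.0, 100.0, 110.0])
# the cheapest placement (cost 100) gets the highest fitness
```

The ratio between the best and worst fitness is always (*c*1 + 1)/*c*1, regardless of the absolute cost scale, which is exactly the constant-pressure property discussed above.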

We use a dynamic mutation probability (the probability of exchanging some places in the chromosome), which increases gradually during the optimization. The probability of mutation is also higher for individuals with lower fitness. This minimizes the chance of disrupting a high-fitness individual and enhances the exploratory role of low-fitness individuals. Lower mutation rates also allow for more effective caching of the fitness values (see Algorithm 1). The effectiveness of this approach was based on various observations [26]. We use two different mutation operators: Reverse Sequence Mutation (RSM) and Partial Shuffle Mutation (PSM), with the probability of applying RSM being three times higher. The choice of these mutation operators is based on the experimental study by Otman et al. [27]. The total probability *mutationProb*(*i*) of applying mutation to the *i*-th chromosome is expressed by Equation (4).

$$mutationProb(i) = \left(c\_i \sqrt{iter} + c\_n \cdot iter\_{NBI}\right) \frac{Fitness\_{max}}{Fitness(i) + c\_f} \tag{4}$$

where *ci*, *cn* and *cf* are coefficients, *iter* is the current iteration (epoch) of the genetic algorithm and *iterNBI* is the number of iterations without improvement of the best individual. Default universal values of the coefficients for our purposes were experimentally set to *ci* = 0.00001, *cn* = 0.00001 and *cf* = 0.3. Further refinement of the mutation scheme, together with the mutation and crossover interactions, is quite a complex issue and will be one of our future research topics, in which we will attempt to find optimal schemes for different situations.
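For illustration, Equation (4) with the default coefficients (our own sketch; the input values in the example are arbitrary):

```python
import math

def mutation_prob(fitness_i, fitness_max, iteration, iter_nbi,
                  ci=0.00001, cn=0.00001, cf=0.3):
    """Equation (4): the mutation probability grows with the iteration
    number and with the number of iterations without improvement, and is
    higher for individuals with lower fitness."""
    return (ci * math.sqrt(iteration) + cn * iter_nbi) * \
        fitness_max / (fitness_i + cf)

# Late in a stalled run, a weak individual mutates more often than a strong one
weak = mutation_prob(fitness_i=0.05, fitness_max=1.05,
                     iteration=10000, iter_nbi=500)
strong = mutation_prob(fitness_i=1.05, fitness_max=1.05,
                       iteration=10000, iter_nbi=500)
```

The *cf* term keeps the probability bounded even when an individual's fitness approaches zero.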

#### *5.3. Optimization of Order Picking Routes*

As previously discussed, the task of optimization of order picking routes is a part of the optimization of product placement in the warehouse. The sum of the lengths of the order picking routes for a given product placement is its cost *costPP*—the lower the better (see Equation (1)). The process is presented in the pseudo-code in Algorithm 2 and in the diagram in Figure 4.

During the order picking route optimization the locations of products are constant. The product locations are changed by Algorithm 1 only before each round of order picking route optimization. As discussed in Section 5.1, the order picking route starts from the entrance, then visits all locations of products listed in the current order and returns to the warehouse entrance. The task is to optimize the sequence of visiting the locations to obtain the shortest route.

There are two main families of approaches to finding the shortest routes connecting a list of locations: local search methods (e.g., nearest neighbor or k-opt [28]) and population methods (e.g., genetic algorithms or ant colony optimization [28]). In the Nearest Neighbor Algorithm, we always go from the current location to the nearest not yet visited location. In this way the algorithm implements a local search, which guarantees finding the location nearest to the current one. However, the drawback of local search approaches is that they do not include a global view of the situation. Although there has been some research on improving these methods, the population methods still have the advantage of performing a global search. A sample route determined with the nearest neighbor method, which shows the problems of this approach, is shown in Figure 5.

**Figure 4.** Order picking route optimization.

**Figure 5.** A sample route (in red) connecting the locations 0, 4, 6, 2, 1, 12 determined with the Nearest Neighbor Algorithm.

As will be shown in the experimental evaluation in Section 6, and as is also known from previous studies [29], when the Nearest Neighbor Algorithm is used instead of genetic algorithms for this kind of problem, the calculation time can be dramatically reduced, however at the cost of worse results (on average 10% longer routes).
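The Nearest Neighbor heuristic can be sketched as follows. This is illustrative code; the toy distance table below is made up and is not the warehouse of Figure 1.

```python
def nearest_neighbor_route(entrance, stops, dist):
    """Greedy route construction: always move to the closest not yet
    visited location, then return to the entrance."""
    route = [entrance]
    remaining = set(stops)
    while remaining:
        nxt = min(remaining, key=lambda s: dist[(route[-1], s)])
        remaining.remove(nxt)
        route.append(nxt)
    route.append(entrance)
    return route

# Toy example: locations on a line, distance = difference of coordinates
pos = {0: 0, 1: 5, 2: 2, 3: 9}
dist = {(a, b): abs(pos[a] - pos[b]) for a in pos for b in pos}
route = nearest_neighbor_route(0, [1, 2, 3], dist)
# greedy order: location 2 first (closest to the entrance), then 1, then 3
```

Each step is locally optimal, which is exactly why the method can be led into the globally poor routes illustrated in Figure 5.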

To determine the shortest route for completing each order, our method can use three different algorithms:

• Evaluation of all possible permutations of the product locations in the order (in practice half of them, as with symmetric distances a route and its reverse have the same length). This guarantees finding the shortest route, but is only feasible for short orders.

• Genetic algorithms, with the fitness of the *i*-th route given by Equation (5):

$$FitnessR(i) = c\_2 + \frac{lengthR\_{\max} - lengthR(i)}{lengthR\_{\max} - lengthR\_{\min}},\tag{5}$$

where *c*2 is a coefficient, which determines the strength of the selection (the lower *c*2, the stronger the preference for individuals representing shorter routes), *lengthRmax* is the maximal and *lengthRmin* is the minimal length of the order picking route in the population. This ensures that the re-scaled proportion between the maximal and minimal fitness is constant during the optimization (see Section 5.2 for explanations).

• Nearest Neighbor Algorithm. This is the fastest method, but it does not guarantee finding the best solution, and in application to our problem it usually finds worse solutions than genetic algorithms.

To provide the optimal trade-off between the accuracy and the speed of the route optimization, the three above algorithms can be applied and the values *Threshold*1 and *Threshold*2 are used to determine, which particular algorithm will be used for a given order, depending on the number of products in the order (see Algorithm 2 and Figure 4).

The number of possible order picking routes is equal to the number of permutations of a *k*-element set, which is *k*!. If there are fewer than *k* = 7 products in a given order, then it is faster to evaluate half of the possible permutations (for *k* = 6 there are 6!/2 = 360 various routes to examine) than to use genetic algorithms. For *k* = 7 we need to evaluate 7!/2 = 2520 permutations. On the other hand, genetic algorithms are usually able to find the solution evaluating fewer routes (e.g., with a population of 50 individuals and 10 iterations, which gives only 500 evaluations). However, genetic algorithms have an additional time overhead for operations such as selection, crossover, generating random numbers, and so forth. So for 7 products in the order the calculation times are comparable, and for more than seven products genetic algorithms are faster. Thus we propose to set *Threshold*1 = 7.

As can be seen in Section 6, the results obtained with genetic algorithms with multi-parent HGreX crossover are better than those obtained with the Nearest Neighbor Algorithm. On the other hand, the Nearest Neighbor Algorithm is faster than genetic algorithms. However, its speed advantage is much higher in a single optimization of a route, where it can be two orders of magnitude faster. In our system, where the route optimization is an iteratively performed sub-process of the product placement optimization, the difference in speed between these two methods is much lower, below one order of magnitude. There are two reasons for this. The first is the caching of the already calculated routes (see details in Section 5.4.2): the caching overhead is comparable to the computational effort of the Nearest Neighbor Algorithm, so it can only accelerate the genetic algorithm based route calculation. The second reason is that the time of running the main process (product placement optimization) is the same in both cases.

*Threshold*2 indicates above which number of products in the order the Nearest Neighbor Algorithm should be used to optimize the order picking route. To obtain the best results, the recommendation is to set *Threshold*2 to such a high value that the Nearest Neighbor Algorithm will not be used at all (e.g., *Threshold*2 = 1000). However, if the data is very large and computational and time resources are limited, we can set *Threshold*1 = 0 and *Threshold*2 = 0, so that only the Nearest Neighbor Algorithm will be used for the route optimization. Sometimes there are only very few long orders (e.g., two orders of 50 products and all remaining orders below 20 products); these few orders will not have a significant influence on the final product placement, so we can use the Nearest Neighbor Algorithm for them to accelerate the calculations and genetic algorithms for all other orders (in this case by setting, for example, *Threshold*2 = 30).
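The threshold dispatch described above can be summarized as follows; the function and string names are ours, purely illustrative.

```python
def choose_route_optimizer(order_size, threshold1=7, threshold2=1000):
    """Pick the route optimization method for an order, as in Algorithm 2:
    exhaustive permutation search for short orders, genetic algorithms in
    the middle range, Nearest Neighbor above threshold2."""
    if order_size <= threshold1:
        return "permutations"
    if order_size <= threshold2:
        return "genetic"
    return "nearest_neighbor"

assert choose_route_optimizer(6) == "permutations"       # 6!/2 = 360 routes
assert choose_route_optimizer(20) == "genetic"
assert choose_route_optimizer(50, threshold2=30) == "nearest_neighbor"
assert choose_route_optimizer(5, 0, 0) == "nearest_neighbor"  # NN-only mode
```

Setting both thresholds to 0 reproduces the resource-constrained mode in which only the Nearest Neighbor Algorithm is used.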

### *5.4. Improvements and Accelerations of the Process*

We use the following improvements to obtain better results and to accelerate the process: multi-parent crossover operator, order grouping, caching product placement costs and order picking routes of evaluated individuals, multiple restart, switching among permutations/genetic algorithms/nearest neighbor for route optimization, and parallelization of the process. In the following subsections we present particular improvements. Influence of these improvements on the obtained results is evaluated experimentally and presented in tables and figures in Section 6.

### 5.4.1. Multi-Parent Crossover Operators

As the use of multi-parent crossover operators can significantly accelerate (up to three times) the convergence of classical genetic algorithms [30,31] (both in their single-objective and multi-objective versions built upon the NSGA-II algorithm [32]), one of the ideas of this work was to apply the multi-parent approach to the crossover operators in the route and product placement optimization problems, in the hope of obtaining better results.

There is also another rationale behind increasing the number of parents in the HGreX crossover operator. In the Nearest Neighbor Algorithm the positions are added to the route one by one; each time the closest position is appended after the last one. In the extreme case, when the population is so big that almost every possible two-element sequence exists in it and the number of parents in the multi-parent HGreX crossover equals the population size, the genetic algorithm constructed in this way becomes equivalent to the nearest neighbor search. On the other hand, increasing the number of parents only slightly may add a local search component to the genetic algorithm and thus improve the results.

Let us assume that we will use four parents to create each child.

P1 = [ A B C D E F G H ] P2 = [ E G F H A C B D ] P3 = [ G H A E B F C D ] P4 = [ E F H D B A G C ]

Let us start from the first position in P1, that is, from A. The successors of A in the parents are B (P1), C (P2), E (P3) and G (P4). Let us assume the following distances: d(A,B) = 12, d(A,C) = 15, d(A,E) = 18, d(A,G) = 11. Since the distance d(A,G) is the smallest, the next position in the child will be G.

Ch = [ A G \_ \_ \_ \_ \_ \_ ] and the values remaining in the parents: P1 = [ B C D E F H ] P2 = [ E F H C B D ] P3 = [ H E B F C D ] P4 = [ E F H D B C ]

Then let us assume the following distances to the successors of G in the parents: d(G,H) = 7, d(G,F) = 8, d(G,C) = 12. Since the distance d(G,H) is the smallest, the next position in the child will be H, and so on. Conflict resolving is implemented in the same way as in the two-parent version of the operator. In the case of AEX, we append the consecutive positions to the child sequentially from consecutive parents.
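The selection step illustrated above can be sketched as follows (a Python illustration of the idea, not the authors' C# code; `d` is an assumed distance function, and the conflict-resolution fallback shown here is one simple possibility):

```python
def next_gene(current, parents, child, d):
    """Pick the cheapest not-yet-used successor of `current` across all parents."""
    candidates = set()
    for p in parents:
        succ = p[(p.index(current) + 1) % len(p)]  # successor of `current` in this parent
        if succ not in child:
            candidates.add(succ)
    if not candidates:                              # conflict: all successors already used
        candidates = {g for g in parents[0] if g not in child}
    return min(candidates, key=lambda g: d(current, g))
```

Applied to the example above, the candidates after A are B, C, E and G, and the cheapest transition (d(A,G) = 11) selects G.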

We conducted experiments with various numbers of parents and concluded that for this problem about 8 parents is the optimal number for the modified HGreX crossover operator. As a result of applying multiple parents in the HGreX crossover, about a two-fold reduction of the number of iterations was observed and, what is more important, shorter order picking routes were also obtained (see the experimental results in Section 6 for details). On the other hand, increasing the number of parents for the AEX crossover did not significantly change the results.

### 5.4.2. Caching Cost of Product Placements and Lengths of Order Picking Routes

In product placement optimization the computational effort of calculating the cost *costPP* (see Algorithm 1) of a given product placement, and then the fitness value of an individual based on it, is high, as it requires finding the shortest picking routes for all orders. Practically always, either some parents are promoted to the next generation or some children are identical to some parents (even more so in the final stages of the optimization). In this case, we do not calculate the cost of such an individual, but instead directly assign the already calculated and cached cost of the previous identical individual (see Algorithm 1 and Figure 2). We also check the cache after mutation.

The situation with order picking route optimization with genetic algorithms is different. Here, the computational effort of calculating the route length and determining the fitness of an individual is low, and there is no point in implementing caching for that. However, the cost of calculating the shortest route for an order (which may require several thousand calculations of route lengths represented by all individuals in all iterations of the optimization) is much higher, and it makes sense to implement a cache here. Thus, before calculating the shortest route for a given order *routemin*(*i*, *j*) (see Algorithm 2 and Figure 4), the cache is checked to determine whether it already contains the route for the current order *j* with all the positions of the products in the warehouse being the same. To clarify: if the whole product placement can be found in the cache, the sub-process of route calculations is not invoked from the main process, as the cost of this product placement *costPP* is retrieved from the cache. However, if the product placement differs at some positions, the sub-process is invoked and, for each order, it is checked whether the positions in the product placement occupied by the products contained in the current order already exist in the cache. If so, the route length *lengthRmin*(*i*, *j*) is retrieved from the cache. Otherwise, it is calculated and the cache is updated. The order cache is used only for route calculation with genetic algorithms (see Algorithm 1 and Figure 2).

Caching is not used for the Nearest Neighbor Algorithm, as the time overhead of the cache is comparable to the time used by the nearest neighbor search. For the same reason, caching is not used with very short orders, for which the shortest route is determined by permutations.

Caching obviously does not influence the results of the optimization; it only accelerates it. It is also worth noticing that caching is more effective in the later stages of the optimization, as at the beginning the individuals change rapidly. In our experiments, caching accelerated the product placement optimization several times (see Section 6).
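The two caches can be sketched as follows (illustrative Python with made-up names, not the paper's C# identifiers; a placement is assumed to be a list of product IDs indexed by location, and the route cache key combines an order's products with their current locations, so a cached length stays valid whenever none of those locations has changed):

```python
placement_cache = {}   # placement tuple -> costPP
route_cache = {}       # (product, location) pairs of an order -> shortest route length

def cost_pp(placement, orders, shortest_route_length):
    """Total cost of a placement over all (order, Nrep) pairs, with two-level caching."""
    key = tuple(placement)
    if key in placement_cache:                 # identical individual already evaluated
        return placement_cache[key]
    total = 0
    for order, n_rep in orders:
        locs = tuple(sorted((p, placement.index(p)) for p in order))
        if locs not in route_cache:            # these product locations are new
            route_cache[locs] = shortest_route_length(order, placement)
        total += n_rep * route_cache[locs]
    placement_cache[key] = total
    return total
```

Evaluating the same individual twice then costs one dictionary lookup instead of a full set of route optimizations.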

### 5.4.3. Multiple Restart

It may take many iterations for genetic algorithms to converge to the optimal solution. However, the fastest progress occurs at the beginning of the optimization. Genetic algorithms use random numbers and are thus a stochastic process; as a result, different solutions can be found in consecutive runs of the optimization. We observed that in product placement optimization the best approach is to run the optimization several times for only a few iterations each and save the current populations. Then the optimization continues only with the population of the best solution. This is a reasonable approach because, most frequently, the run that starts as the best also ends as the best. Thus, this method combines time-efficient optimization with good results (see the experimental verification in Section 6).
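In pseudocode-like Python (function names such as `init_population` and `run_ga` are illustrative placeholders, not the paper's API), the scheme reads:

```python
def multiple_restart(n_restarts, warmup_iters, init_population, run_ga, fitness):
    """Run several short optimizations, then continue only with the best population."""
    best_pop, best_fit = None, float("inf")
    for _ in range(n_restarts):                        # e.g., 5 restarts
        pop = run_ga(init_population(), warmup_iters)  # e.g., 10 iterations each
        f = min(fitness(ind) for ind in pop)
        if f < best_fit:                               # keep the most promising run
            best_pop, best_fit = pop, f
    return run_ga(best_pop, None)                      # None: run until convergence
```

Only the warm-up phase is repeated; the long final run is performed once, on the most promising population.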

### 5.4.4. Order Grouping

All orders containing the same set of products are grouped together into one order. In this way the optimal picking route for this order has to be determined only once. For the purpose of calculating the cost and fitness of a given product placement, the length of the obtained route is multiplied by the number of the orders consisting of the same products (see Algorithm 1).
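For example (a Python sketch, assuming orders are compared by their set of products):

```python
from collections import Counter

def group_orders(orders):
    """Map each distinct set of products to the number of orders containing it (Nrep)."""
    return Counter(frozenset(order) for order in orders)

grouped = group_orders([[41, 99, 7], [99, 7, 41], [12, 24]])
# the first two orders form one group with Nrep = 2, so their optimal route
# is computed once and its length multiplied by 2 in the placement cost
```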

### 5.4.5. Three Route Optimization Methods

As described in Section 5.3, for the optimal balance between calculation speed and accuracy of the results, the order picking route in Algorithm 2 can be optimized with permutations (only short routes), genetic algorithms or the Nearest Neighbor Algorithm.

### 5.4.6. Process Parallelization

Genetic algorithms scale well in parallel implementations in cases where the cost of calculating the fitness function is high, because then there is no need for frequent communication among threads. This is exactly the case in product placement optimization, where using any number of CPU cores up to the number of individuals in the population results in a practically linear increase of performance as a function of the number of CPU cores. Moreover, if there are more CPU cores available than the population size, it makes sense to increase the population size, at least up to three times, to use more CPU cores. Although after exceeding the optimal population size the scaling with the growing number of CPUs is no longer linear, this implementation is very simple. An alternative, when a few hundred CPU cores are available, is to parallelize the calculation of the particular order picking routes, but since we did not have access to such computational resources, we were not able to verify the efficiency of this approach.
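The population-level parallelism can be sketched as follows (a Python stand-in for the paper's C# implementation; note that in Python a process pool would be needed for true CPU parallelism because of the GIL, whereas in C# threads on separate cores suffice):

```python
from concurrent.futures import ThreadPoolExecutor

def evaluate_population(population, fitness, workers=8):
    """Evaluate each individual's fitness in parallel; results keep population order."""
    # One task per individual; workers only synchronize when results are
    # collected, so with an expensive fitness function the speed-up is close
    # to linear up to the population size.
    with ThreadPoolExecutor(max_workers=workers) as pool:
        return list(pool.map(fitness, population))
```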

### **6. Experimental Results**

In this section we experimentally evaluate the method and improvements presented in the previous sections.

We conducted the experiments with our own software, written in C#. The source code and the data used in the experiments (warehouse plans with lists of corresponding orders) are available from the web page www.kordos.com/appliedsciences2020. Three of these warehouse structures (floor plans) and some sample orders for the warehouse *w*3 are additionally presented in Figures 6 and 7.

The following order picking route optimization algorithms were evaluated: nearest neighbor, genetic algorithms with the HGreX crossover, and genetic algorithms with the multi-parent HGreX crossover.

The following product placement optimization algorithms were evaluated: genetic algorithms with the AEX crossover, genetic algorithms with the multi-parent AEX crossover, and genetic algorithms with the multi-parent AEX crossover and multiple restart.

As we could not find in the literature a complete automatic solution for product placement optimization that considers the order picking routes, as the solution presented here does (see Sections 1 and 4.2), we obviously cannot compare our solution numerically to other solutions on the same data.

First, we evaluated the multi-parent modifications of the HGreX crossover operator to determine the optimal number of parents (see Sections 5.3 and 5.4.1). The results are presented in Table 2 and in Figure 8. Based on our tests, population sizes of about *N* = 80–120 allowed for the fastest convergence of the process (the lowest number of fitness value calculations). For larger populations, the fitness function had to be evaluated more times to reach the same results, so even if fewer iterations were needed, the time to reach the results was longer [33]. Smaller populations also required more evaluations of the fitness function, and if the populations were too small, the convergence of the algorithm was impossible. Only for route optimization, when the number of products in the order was 20 or less, smaller populations were used; larger populations may be useful for longer chromosomes than those we used in the experimental evaluation.

**Figure 6.** Sample warehouse structures (floor plans) of the warehouses w5 and w1 used in the experiments. Each numbered cell represents one product location. Blue lines show a distance of 1 unit, red lines of 2 units and green lines of 3 units.


**Figure 7.** A sample warehouse structure (floor plan) of the warehouse w2 used in the experiments. Each numbered cell represents one product location. Blue lines show the distance of 1 unit, red lines of 2 units and green lines of 3 units.

**Figure 8.** The obtained route length *lengthR* (the lower the better) and product placement cost *costPP* (the lower the better) and the number of iterations for route optimization *iterR* with MP-HGreX and for product placement optimization *iterP* with MP-AEX (graphical representation of the data from Tables 2 and 3).

The stopping criterion for the experiments shown in Table 2 was 20 iterations without improvement of the best individual. The number of reported iterations is the number after which the best individual was found. As can be seen, increasing the number of parents in the crossover operator definitely reduces the number of required iterations (up to two times for 8 parents in this case) and, what is more important, allows for obtaining better fitness values. However, when using more parents than the optimal number, the optimization slows down again, and with more than 20 parents the obtained route lengths also begin to deteriorate (increase). As discussed in Section 5.3, for a large number of parents the HGreX operator behaves almost like the nearest neighbor search method, and its performance also tends to the same value.

**Table 2.** A sample route length *lengthR* and number of iterations *iterR* to obtain this length for a modified multi-parent HGreX crossover operator for an order of 60 products with fixed product placement (averages of 10 runs). *NN* in the last column denotes the result obtained for the Nearest Neighbor Algorithm.


The number of iterations used by the genetic algorithm with the HGreX crossover is comparable to that required by other modern crossover operators to find the shortest route for a comparable population size [34]. Even if running the optimization for more iterations may find a slightly shorter route, there is usually no further gain in the product placement cost, as this only very rarely triggers a change of product locations. For very rarely occurring orders it is enough to run the optimization for fewer epochs or to use the Nearest Neighbor Algorithm, independently of the order length, because the quickest improvement occurs at the beginning of the optimization and the influence of very rare orders on the optimal product placement is also very low, as they are dominated by the more frequent orders.

Next, we tested the usefulness of the multi-parent AEX crossover in product placement optimization. Since AEX does not use the cost of transitions between two elements, the parents for each element were chosen randomly. The results for a warehouse with 232 locations (chromosome length of 232 elements) are presented in Table 3. Based on the experiments, we concluded that it is enough to use the two-parent AEX crossover in product placement optimization, as increasing the number of parents did not yield any gain, and with 20 or more parents a drop in the method's effectiveness was observed.

**Table 3.** The product placement cost *costPP* as sum of order picking route lengths (the lower the better—see Equation 2) and number of iterations *iterP* to obtain this cost for a modified multi-parent AEX crossover operator for the warehouse size of 232 locations and a list of 80 orders, using an 8-parent HGreX for route optimization (averages of 10 runs).


Multiple restart of the product placement optimization with the AEX crossover (MR-AEX) proved quite useful (see Algorithm 1). In the last row of Table 4, the optimization was restarted 5 times, each time running for 10 iterations, and then we continued only with the population of the best individual, as described in Section 5.4.3. This not only produced a lower cost, but also roughly halved the standard deviation of the results.

In the experiments presented in Table 4, we used a population size of 100 individuals for product placement optimization. For order picking route optimization, we used 100 individuals if the number of items was 25 or more, and four times the number of items for shorter orders. The reason for choosing this population size is the fact that the main cost of genetic algorithms is the evaluation of the fitness function (especially for product placement optimization). The number of fitness function evaluations can be expressed as the population size multiplied by the number of epochs. Using larger populations, we can obtain the same results in fewer epochs; for smaller populations we also need to increase the mutation rate. However, the dependence between the population size and the number of required epochs is not linear, and there exists an optimal population size which allows for the lowest number of fitness function evaluations [33]. This number also depends on the problem and on other parameters of the genetic algorithms. In our experiments, the minimum was usually obtained for population sizes between 80 and 120 individuals for warehouses with between 60 and 300 locations, and then it grew very slowly, much slower than linearly, with the warehouse size. The dependence was very flat around the minimum (changing the population size, e.g., from 80 to 100 individuals, did not make a statistically significant difference in the number of required fitness function evaluations). However, outside of this range the dependence was more significant; for example, using 1000 individuals decreased the number of epochs only about 3 times, which effectively increased the number of fitness function evaluations about 3-fold. Moreover, when the population was too small, not only did the process time increase, but the process also began to be unstable and frequently was not able to converge. For this reason, a population size of 100 individuals was a safer choice than, for example, 80 individuals.

We used the default mutation coefficients in Equation (4): *ci* = 0.00001, *cn* = 0.00001, *cf* = 0.3, both in product placement and in route optimization.

Below we present some sample orders for the warehouse *w*3. The products in the orders are encoded by numbers, which are the product IDs (we cannot use letters as in the examples in the previous sections, because there are not enough letters in the alphabet). The last number (*Nrep*) of each order shows how many times such an order occurs in the order list, so its completion route length can be evaluated only once and then multiplied by *Nrep* when calculating the final product placement cost.

**order1**: 41 99 97 7 20 89 12 24 51 66 79 61 1 56 109; *Nrep* = 40

**order2**: 71 90 9 29 84 94 19 26 64 114 100 42 81 30 108 107 101 47 6 32 96 33 28 7; *Nrep* = 20

**order3**: 78 31 91 35 93 87 22 50 100 1 28 38 84 16 48 112 76 110 95 47 72 113 23 61 101 68 67 53 45 41 97 18 109 89 65 74; *Nrep* = 3

**order4**: 53 61 18 36 94 24 103 38 35 12 42 89 6 30 50 14 84 114 29 15 79 95 48 52 28 25 110 22 64 109 44 11 73 33 98 97 23 75 99 87 7 51 92 93 72 17 3; *Nrep* = 1

**Table 4.** The obtained product placement cost (the lower the better) as the sum of all order picking routes (see Equation (2)) for product placement and order picking route optimization methods with various improvements (see Section 5.4) for the six sample warehouses: w1, w2, w3, w4, w5, w6 with corresponding lists of orders, averaged over 10 optimization runs. The running times are presented in Table 5.


Table 4 presents the detailed results for six sample warehouse structures and order lists (this is the maximum number of warehouses that can fit in one row of the table). MP-HGreX stands for multi-parent HGreX with 8 parents; MR-AEX is multiple-restart AEX with 5 restarts, saving the population after 10 iterations and then continuing with the population of the best individual (see Section 5.4.3 and Figure 2 for explanations). Table 5 presents the real running times of the optimization processes (including I/O operations and the Dijkstra Algorithm).

**Table 5.** The real running time of the optimization processes in seconds using a computer with two Xeon X5-2696-v2 CPUs, averaged over 10 optimization runs for the experimental data presented in Table 4 and additionally for AEX/HGreX without cache (see Section 5.4.2).


The possible reduction of the product placement cost depends on the character of the orders. The biggest improvement due to route optimization can be obtained for orders containing long lists of products. The highest improvement due to product placement optimization can be achieved when the orders frequently contain products of particular groups and the frequencies with which particular products appear in the orders differ a lot. Thus, the achievable improvement is determined mostly by the properties of the orders, and particular methods must be compared with each other on the same warehouse structure and the same list of orders. (This is similar to classification, where the achievable accuracy depends on the dataset properties and various classifiers must be compared on the same data.)

The *cost vs. rnd.* column in Table 4 contains the average relative reduction of the product placement cost, calculated as *cost vs. rnd.* = *Average*(*Sum*(*F*1*p*(*w*)/*random*(*w*))), where *w* = 1...6 is the warehouse number. The last column contains statistical significance tests calculated on the whole data between two adjacent methods, and is therefore printed in between the rows of the compared methods. Since some readers prefer the t-test and others the Wilcoxon signed-rank test for this kind of data, we used both tests. As all the p-values in the last column of Table 4 are smaller than 0.05, it can be assumed that all the methods are significantly different from one another.
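For illustration, the paired t-statistic underlying one of these tests can be computed as below (a pure-Python sketch; the per-run cost values are invented for the example and are not the paper's results):

```python
import math

def paired_t(a, b):
    """Paired t-statistic for two equally long samples of per-run costs."""
    d = [x - y for x, y in zip(a, b)]
    n = len(d)
    mean = sum(d) / n
    var = sum((x - mean) ** 2 for x in d) / (n - 1)   # unbiased variance of differences
    return mean / math.sqrt(var / n)                   # compare against t(n-1) quantiles

t = paired_t([105.2, 103.8, 106.1, 104.5], [101.1, 100.7, 102.0, 101.4])
```

The Wilcoxon signed-rank test makes the analogous comparison on the ranks of the differences instead of their raw values.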

As can be seen in Figures 9 and 10, the best results are obtained with the multiple restart of genetic algorithms with the AEX crossover operator (MR-AEX) for product placement optimization (see Section 5.4.3 and Algorithm 1), together with the genetic algorithm with the multi-parent HGreX crossover operator (MP-HGreX) for order picking route optimization (see Figure 4).

**Figure 9.** Graphical representation of the data from Table 4. On the horizontal axis: the optimization methods with various improvements (see Section 5.4). On the vertical axis: the product placement cost *costPP* (the lower the better) obtained with particular methods for the warehouses w1, w2, w3, w4, w5, w6 with corresponding lists of orders.

**Figure 10.** Comparison of the performance of the presented methods with various improvements (see Section 5.4). On the horizontal axis: the optimization method. On the vertical axis: the average obtained product placement cost (the lower the better) as a percentage of the cost with random product placement and random routes, over the six warehouses with order lists presented in Table 4.

#### **7. Conclusions**

Shortening the time of order picking is the most important and most beneficial factor in reducing the costs of operating a warehouse (where typically 60% of the costs are generated by order picking [1]). It can be achieved without significant investment by optimizing the locations of particular products in the warehouse and then determining the fastest order completion routes. As the search space of solutions is enormous (9.3 × 10<sup>157</sup> possible placements of 100 products, 3.1 × 10<sup>614</sup> of 300 products), the problem cannot be analyzed by brute force methods. Thus, we presented a complete, fully automatic system based on genetic algorithms which, thanks to intelligent search, finds optimal product placements within minutes or tens of minutes for problems of that size (depending on the computer hardware and process parameters). Even though genetic algorithms do not guarantee finding the optimal solution, it is possible to find a solution so close to the optimal one that in practice the difference is not significant.
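The quoted search-space sizes are simply factorials (n products can be assigned to n locations in n! ways), which is easy to verify:

```python
import math

p100 = math.factorial(100)   # about 9.3 * 10**157 placements of 100 products
p300 = math.factorial(300)   # about 3.1 * 10**614 placements of 300 products
```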

The presented system takes as inputs the warehouse structure (in the form of partial transition costs) and the list of orders, and returns the optimal product placement and the corresponding shortest order picking routes. Implementation of such a system can accelerate order picking and thus reduce warehouse operating costs. It allows the same number of employees to serve more customers in the same time, and thus further increases sales and profits.

The experiments showed that using the multi-parent HGreX crossover improves the results, while for the AEX crossover adding more parents does not change its efficiency. The best results were obtained with the multiple-restart genetic algorithm with the AEX crossover operator (MR-AEX) for product placement optimization, together with the genetic algorithm with the multi-parent HGreX crossover operator (MP-HGreX) for order picking route optimization. Caching costs or route lengths can be used to accelerate the process. Additionally, the Nearest Neighbor Algorithm can be used for route optimization to accelerate the process even further, but usually at the expense of slightly worse results.

In future work, we are planning to implement other modifications to further improve the speed of the optimization and the quality of the obtained solutions. First, we want to evaluate new crossover operators and mixtures of various operators. In the experimental comparison by Puljic [19], a mix of different crossover operators performed slightly better than HGreX. More advanced mixes of genetic operators have also been proposed [35,36]. However, we decided not to use this approach because its implementation is definitely more complex; instead, we modified HGreX to use multiple parents, which significantly improved the results. Łapa et al. [37] proposed the use of different operators (not only crossovers but also different mutations and other operators) for different individuals in standard genetic algorithms. We are going to adapt these approaches to the route and product placement optimizations and investigate various options.

The other branch of our future research concerns constraints in the genetic algorithm operations, as in some warehouses constraints may exist that limit the possible locations of particular products. In the literature, the typical approach to constraints in genetic algorithms is the use of penalty functions [38]; sometimes dominance-based methods are also used [39]. However, we are going to implement this differently, by embedding the mechanism directly into the problem-specific crossover and mutation operators, in order to enforce the constraints effectively and to limit the computational complexity of the optimization.

**Author Contributions:** Conceptualization, M.K. and S.G.; Formal analysis, M.B. and S.G.; Funding acquisition, M.B. and S.G.; Investigation, M.K. and J.B.; Methodology, M.K., S.G. and J.B.; Software, M.K. and J.B.; Validation, M.B. and S.G.; Visualization, M.B. and S.G.; Writing—original draft, M.K., J.B., M.B. and S.G. All authors have read and agreed to the published version of the manuscript.

**Funding:** This work was supported by the Silesian University of Technology project: BK-204/2020/RM4.

**Acknowledgments:** The authors want to thank Michał Krzyzowski, Łukasz Mysłajek, Antoni Kopeć and Jakub Gawęda for their help in collecting and preparing the data used in this study.

**Conflicts of Interest:** The authors declare no conflict of interest. The funding sponsors had no role in the design of the study; in the collection, analyses, or interpretation of data; in the writing of the manuscript; or in the decision to publish the results.

### **References**


© 2020 by the authors. Licensee MDPI, Basel, Switzerland. This article is an open access article distributed under the terms and conditions of the Creative Commons Attribution (CC BY) license (http://creativecommons.org/licenses/by/4.0/).

*Article*

## **Application of Machine Learning Techniques to Delineate Homogeneous Climate Zones in River Basins of Pakistan for Hydro-Climatic Change Impact Studies**

### **Ammara Nusrat \*, Hamza Farooq Gabriel, Sajjad Haider, Shakil Ahmad, Muhammad Shahid and Saad Ahmed Jamal**

School of Civil and Environmental Engineering, National University of Sciences and Technology, Islamabad 44000, Pakistan; hamza.gabriel@nice.nust.edu.pk (H.F.G.); sajjadhaider@nice.nust.edu.pk (S.H.); shakilahmad@nice.nust.edu.pk (S.A.); m.shahid@nice.nust.edu.pk (M.S.); sahmed.be16igis@igis.nust.edu.pk (S.A.J.) **\*** Correspondence: ammara.phd14nice@student.nust.edu.pk; Tel.: +92-346-500-1107

Received: 15 August 2020; Accepted: 28 September 2020; Published: 1 October 2020

**Abstract:** Climatic data archives, including grid-based remote-sensing and general circulation model (GCM) data, are used to identify future climate change trends. The performances of climate models vary in regions with spatio-temporal climatic heterogeneities because of uncertainties in model equations, anthropogenic forcing or climate variability. Hence, GCMs should be selected from climatically homogeneous zones. This study presents a framework for selecting GCMs and detecting future climate change trends after regionalizing the Indus river sub-basins in three basic steps: (1) regionalization of large river basins, based on spatial climate homogeneities, for four seasons using different machine learning algorithms and daily gridded precipitation data for 1975–2004; (2) selection of GCMs in each homogeneous climate region based on performance to simulate past climate and its temporal distribution pattern; (3) detecting future precipitation change trends using projected data (2006–2099) from the selected model for two future scenarios. The comprehensive framework, subject to some limitations and assumptions, provides divisional boundaries for the climatic zones in the study area, suitable GCMs for climate change impact projections for adaptation studies and spatially mapped precipitation change trend projections for four seasons. Thus, the importance of machine learning techniques for different types of analyses and managing long-term data is highlighted.

**Keywords:** climate zone; climate change impact; Jhelum River Basin; Chenab River Basin

### **1. Introduction**

Climate change has emerged as a major driving force behind Pakistan's flooding over the past several decades, mainly by affecting snow melting and interfering with summer monsoon patterns [1]. Previous studies have also shown a growing interest in an accurate and calculated evaluation of climate change that can provide a rational system to adapt to a rapidly changing environment. Climate change impact analyses use long records of climate data from numerous gauging sites to estimate the historical and projected climate change trends. There is a growing institutional and social need to predict the potential impacts of climate change and the damage due to floods and droughts owing to spatio-temporal variability in the climate, specifically the rainfall. Previous studies have established that spatio-temporal dynamics in atmospheric covariates are strongly linked to the precipitation characteristics on a regional scale. The atmospheric circulation covariates are important in defining the precipitation patterns across the region [2]. Different atmospheric oscillations are responsible for this variability, which has been characterized by various climate models/simulations, including several parameters related to atmospheric oscillation, moisture and wind fluxes [3]. The driver of seasonal rainfall in southeast Asia is the monsoon circulation [4], and increases in the probability and severity of extreme events (e.g., droughts, heatwaves and floods) have been documented in several studies [5,6]. The spatio-temporal variability of the hydrological cycle parameters affects social and environmental sustainability [7]. Therefore, the patterns and projected trends of precipitation, which is the most important variable of the hydrological cycle and varies spatially as well as temporally, must be studied.

Simulation results from general circulation models (GCMs) are used to study the implications of hydro-climatic change and are considered the most comprehensive and valid instruments, with the ability to reproduce climate variables in chronological order as well as yield future projections using the scenarios [8–11] recommended by the Intergovernmental Panel on Climate Change (IPCC). Numerous outputs from GCMs are available to gain insight into climate processes and the impact of various scenarios on these processes [12,13]. The question arises as to how to use these climate simulations to obtain meteorological inputs for impact models while characterizing all the uncertainties involved in the models. Previous studies have reported attempts to investigate the suitability of climate models when simulating different climate parameters, such as precipitation [14]. The spatio-temporal distribution simulated by a GCM is an important aspect that must be considered for the selection of the model [15]. Further discussion regarding the selection of appropriate GCMs is presented in Section 3.2.

Climate model simulations should have adequate precision to assimilate crucial information in the observational data, which forms the basis to forecast and more efficiently predict climatic conditions. The sources of uncertainty in climate simulations and predictions are (1) model equations, parameterization and initial conditions associated with chaotic systems [16]; (2) the variability in the predicted scenario caused by anthropogenic forcing, which is unknown in the future; and (3) climate variability [17]. To account for these uncertainties, multiple GCM assessment studies analyze model efficiency at simulating the dynamic and thermodynamic variables and the environmental variables related to atmospheric chemistry [4,18,19]. The common trend among the climate research community, however, is the statistical performance evaluation of GCMs using climate output variables from their simulations [7,19–23]. These climate variables are averaged temporally and spatially on various scales for the assessment of the GCM's ability to either simulate historical data or present the range of climate output variables to detect a plausible future [22,24]. The majority of climate change assessment studies used the following two approaches to evaluate the skill of a climate model to represent past observational climate data:


In the present study, we identified climatologically homogeneous zones and then selected a GCM in each zone through a past-performance assessment. These homogeneous sub-regions are based on similar spatial and temporal climate patterns, as depicted by the highly resolved observation data. Following the philosophy of McSweeny et al. [4] for the regional suitability of GCMs, we selected the GCMs in every homogeneous precipitation region and validated our selection through a comparative analysis with the Simple Composite Method (SCM). The SCM is broadly used for generating a multi-model ensemble, that is, to obtain the equally weighted mean of all the ensemble members' data at each grid point [7].

With the support of machine learning techniques, long records of climate data from numerous gauging sites and web sources can be easily analyzed and used to determine historical and projected trends of climate change. In the present study, modules using machine learning techniques were developed throughout the various steps of the framework. The framework comprises five important steps: (1) developing homogeneous precipitation zones using highly resolved observation data; (2) selecting the most suitable GCM for each zone based on the correlation coefficient estimated between the GCM data and the highly resolved observed data in every homogeneous climatic region; (3) comparing and validating the selected GCM; (4) sampling the daily precipitation data for the projected period (2006–2099) based on the forcing scenarios; and (5) performing future precipitation trend detection. The outcomes of this framework are the estimated divisional boundaries of the climate zones, suitable GCMs for the climate impact studies and seasonal rainfall trend projections. It is necessary to mention that the results obtained using the present framework are subject to specific assumptions: (1) daily data at a 0.25° spatial resolution were used for climate zoning; (2) seasonal data at a 0.25° spatial resolution were used for the climate projection trends; and (3) the uncertainty of the results depends on the uncertainty involved in the climate variable outputs from the GCMs. It is very important to understand the limitations and uncertainties involved in the simulation outputs of different GCMs. The present study does not aim to quantify the uncertainties present in the GCMs' simulations but only uses the GCMs' outputs for their selection in each climate zone and for trend projections. Nonetheless, the framework can be replicated using other versions of GCM ensembles with finer resolution and accuracy.
The scientific data associated with the present framework can be shared on request.

This paper is structured as follows: Section 2 describes the study area, stations and data details; Section 3 provides a step-by-step presentation of the methodology; Section 4 presents the results and discussion; and Section 5 presents the conclusions.

### **2. Study Area and Data**

### *2.1. Study Area*

The basin areas of the Jhelum and Chenab Rivers are 33,330 and 67,515 km², respectively. The elevations in the two basins vary between 146 and 6915 m. Figure 1 shows a topographic map with APHRODITE grid points of the basins of the Jhelum and Chenab rivers in Pakistan. The Southwest Monsoon triggers approximately 40% of the annual precipitation, whereas the remaining ~60% is due to westerly disturbances [24]. Figure 2 presents the average annual seasonal precipitation in various parts of the two basins. The spatially and temporally variable climatic conditions, atmospheric circulation, advected moisture and topography, with high elevations in the north and flat plains in the south, demand a flexible framework for the evaluation of climate change effects and future predictions. Diminishing water resources due to high hydro-climatic variability and extreme climate events in Pakistan are some of the factors that have inspired the research community to investigate sustainable solutions after analyzing future needs and availability.

**Figure 1.** Topography of the study area with the grid points/stations and observed gauging station used in the study.

**Figure 2.** The distribution of average total seasonal precipitation based on APHRODITE (1970–2004) for (**a**) warm-wet (**b**) cold-dry (**c**) cold-wet and (**d**) warm-dry seasons.

#### *2.2. APHRODITE Data*

To delineate the climate zones in the study area, a dense network of data stations was required. There has been an increasing trend of using easily available gridded precipitation datasets for hydrologic and climatic assessments [38–40]. Previous studies have identified the APHRODITE dataset (version V1101) as the best gridded product over the high mountainous regions of Asia [41]. To justify the use of the APHRODITE gridded dataset, we performed a comparative analysis of the APHRODITE dataset, the European Reanalysis gridded dataset (ERA5) [42] and the Global Meteorological Forcing Dataset for land surface modeling (GMFD) [43] against the observed dataset at a monthly temporal scale, at 11 gauging stations at different altitudes, as shown in Figure 1. ERA5 is also popular for its accuracy. The results of the comparative analysis are shown in Table 1. The Pearson correlation coefficient and the Kolmogorov–Smirnov (KS) test (the method is discussed in Section 3.3.2) were used to compare the monthly precipitation datasets. The results show the strongest agreement with the APHRODITE dataset, with higher correlation coefficients at all the stations. For the KS test, *p*-values greater than 0.05 have been shaded, indicating that the null hypothesis of a similar probability distribution is not rejected. The comparative analysis suggests that the APHRODITE dataset has a higher correlation with the observed data, that is, more than 0.8 at nearly all the observed gauging stations, as shown in Table 1. Therefore, the APHRODITE dataset [40], with a 0.25° × 0.25° spatial resolution for the period 1970–2005, was used at a daily temporal scale for the delineation of climate zones. Nevertheless, the regionalization framework is equally applicable to other datasets of comparatively higher resolution and accuracy.

**Table 1.** Comparative analysis of the APHRODITE and ERA5 monthly datasets: Pearson correlation coefficients and Kolmogorov–Smirnov test (KS test) results. (Shaded *p*-values are >0.05, indicating that the null hypothesis of a similar distribution is not rejected.)


Figure 2 presents the total seasonal average precipitation sampled from APHRODITE. The APHRODITE product (V1101) provides daily gridded precipitation data from 1951 onwards. It provides the data at the continental scale, drawing on a dense rain gauge network across Asia, comprising the Middle East, South and Southeast Asia, with 5000 to 12,000 valid stations [40].

### *2.3. NEX-GDDP-GCMs-CMIP5 Data*

For future climate change analysis, this study used the newly developed NASA Earth Exchange Global Daily Downscaled Projections (NEX-GDDP) dataset. The experiments included in the Coupled Model Intercomparison Project 5 (CMIP5) were devised to address the research questions arising from the IPCC's AR4 [44]. The original resolution of most GCMs in CMIP5 is >100 km; such a coarse resolution cannot provide fine-scale information for impact studies and decision-making at a local scale. The NEX-GDDP dataset is a collection of 21 GCMs. These datasets are bias-corrected and downscaled to a finer resolution (0.25° × 0.25°) using the bias-corrected spatial disaggregation (BCSD) method [45]. The Global Meteorological Forcing Dataset (GMFD), developed by Princeton University, provides the gridded observed climate data used for the bias correction and downscaling of NEX-GDDP [46,47]. Table S1 lists the 21 GCMs in NEX-GDDP, along with information on the research centers that produce them.

The NEX-GDDP data consist of daily precipitation and minimum and maximum temperatures for the historical period (1950–2006) and future projections (2006–2099). The future projections are available for two representative concentration pathways (RCPs), which are global greenhouse gas emissions scenarios. These scenarios were employed by the IPCC Fifth Assessment Report (AR5) [44]. The NEX-GDDP dataset is publicly available. Comprehensive guidance on using these scenario datasets for climate change impact evaluations and adaptation has been provided by the IPCC [48]. CMIP5 provides the multi-model datasets of climate variability that were used to develop AR5 [44].

### **3. Methodology**

Machine learning algorithms are becoming more popular owing to their ability to identify the patterns and variances in large datasets composed of multivariate atmospheric covariates, location parameters, meteo-climatic information, tele-connection indices and attributes that influence precipitation [2,49,50] or the precipitation and temperature statistics [51,52]. In this study, we used the Python module scikit-learn [53], which integrates a broad range of machine learning algorithms for supervised as well as unsupervised learning. As mentioned above, this framework consists of five steps: delineation of homogeneous climate zones; selection of GCMs; validation of the selected GCMs; selection of the forcing scenarios to be used for the predictions of climate change; and climate change projections and trend detection. Figure 3 presents the methodology flow diagram adopted in this study. Further details of each step are discussed in the subsequent sections.

**Figure 3.** Flow chart of the methodology employed in this study.

#### *3.1. Climate Zoning*

Climate zoning/regionalization is the most important step in this framework and the basis for the subsequent steps. For regionalization, solid evaluations of different precipitation attributes require access to long records of historical measurements at various stations within an area [2]. For climate zoning, 35 years (1970–2005) of daily APHRODITE precipitation data were used. These data were resampled seasonally to develop the climate zones for each of the seasons (defined in the subsequent sections). Regionalization was performed by grouping the rain gauges with homogeneous precipitation statistics in an area [54]. This step enables partitioning of the entire area into several climate zones for every season.

The regionalization of precipitation statistics has numerous applications in various water resource management fields, such as agricultural practices, spatial and temporal rainfall patterns, hydrological analysis, extreme event forecasting [55] and watershed management. Several methods are available to delineate climate regions, for example, subjective and objective partitioning, geographical convenience and multivariate analysis [54,56,57]. The arbitrary and slightly misleading demarcation approach, based on administrative boundaries and physical and geographical groupings, is referred to as the geographical convenience method. The subjective partitioning method is based on homogeneous statistical characteristics of the rainfall and previous knowledge of the region. Objective partitioning demarcates a region by grouping the sites of similar climates. Principal component analysis (PCA), clustering techniques and correlation analysis are examples of multivariate analysis techniques, which are widely used for climate zone delineation [51,58,59]. These techniques require the development of a large matrix that defines the characteristics of the climate in the region. We adopted PCA and agglomerative hierarchical clustering (AHC) to group the sites of homogeneous precipitation. These groups of stations were validated using different cluster validity indices.

### 3.1.1. Seasonal Data Resampling

The daily precipitation data in the APHRODITE time series, at 138 grid stations in the study area from 1970–2005, were resampled into hydrological seasons: warm-wet (July, August and September), cold-dry (October, November and December), cold-wet (January, February and March) and warm-dry (April, May and June). The daily precipitation data of 35 years were resampled in a seasonal cycle using simple Python code.
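As an illustrative sketch of this resampling step (the precipitation series below is synthetic and the variable names are hypothetical, but the season mapping follows the definitions above):

```python
import numpy as np
import pandas as pd

# Synthetic daily precipitation for one grid station (illustrative only).
rng = pd.date_range("1970-01-01", "2004-12-31", freq="D")
precip = pd.Series(np.random.default_rng(0).gamma(0.5, 4.0, len(rng)), index=rng)

# Map calendar months onto the four hydrological seasons defined above.
SEASONS = {7: "warm-wet", 8: "warm-wet", 9: "warm-wet",
           10: "cold-dry", 11: "cold-dry", 12: "cold-dry",
           1: "cold-wet", 2: "cold-wet", 3: "cold-wet",
           4: "warm-dry", 5: "warm-dry", 6: "warm-dry"}
season = precip.index.month.map(SEASONS)

# Seasonal totals per year: 35 years x 4 seasons of aggregated precipitation.
seasonal_totals = precip.groupby([precip.index.year, season]).sum()
```

Aggregating to seasonal means instead of totals only requires replacing `.sum()` with `.mean()`.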

### 3.1.2. Principal Component Analysis (PCA)

The benefits associated with PCA are as follows: (1) identification of spatially coherent variations that can improve the signals, where the leading components represent the maximum variance in the patterns; (2) explanation of the covariance between the variables at various places; and (3) reduction of dimensionality, thereby providing the means to describe the variability in large spatially dimensional datasets with a reduced number of principal components [60]. PCA can present all the data in a smaller matrix while attempting to retain as much information about the data as possible. We used the dimensions from the 138 APHRODITE grid point datasets for PCA. The aim was to project the data onto orthogonal axes known as principal components. A linear transformation into the principal orthogonal components is performed using a symmetric covariance matrix developed from the dataset. The direction of each principal component is represented by an eigenvector and the corresponding eigenvalue represents the magnitude of the stretch of the axis; together they give the direction and magnitude of the spread of the dataset. The leading principal component, which presents the maximum spread of the dataset, is used for further analysis. The number of principal components to use in hierarchical clustering was selected using the scree plot, which shows the percentage of variance explained by each of the principal components. Each principal component has component scores for every station, based on the eigenvectors and eigenvalues. These component scores depict the climate change pattern at the respective grid point/station and can be treated as meteorological parameters that are stochastically independent of each other [61]. This analysis was performed using the PCA function of scikit-learn [53].
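A minimal sketch of this step with scikit-learn, using a synthetic stand-in for the 138-station seasonal matrix (rows are days, columns are stations; the shapes and values are illustrative assumptions):

```python
import numpy as np
from sklearn.decomposition import PCA

# Synthetic daily precipitation matrix: rows are days within a season,
# columns are the 138 grid stations (illustrative values only).
rng = np.random.default_rng(42)
X = rng.gamma(0.5, 4.0, size=(3000, 138))

pca = PCA(n_components=20)
scores = pca.fit_transform(X)      # projections of each day onto the 20 PCs
loadings = pca.components_         # eigenvectors: one weight per station per PC

# Scree-plot information: fraction of total variance explained by each PC.
explained = pca.explained_variance_ratio_
print(f"20 PCs explain {explained.sum():.1%} of the total variance")
```

In the study, the per-station entries of the eigenvectors supply the component scores used for clustering; with real, spatially coherent data the cumulative explained variance of 20 PCs would approach the ~95% reported in Section 4.1.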

### 3.1.3. Agglomerative Hierarchical Clustering

The purpose of this step is to obtain groups of sites/stations with similar climate change patterns, as depicted by the component scores of the leading principal components obtained in the previous step. This is achieved using clustering algorithms, whose behavior can be expected to differ depending on the data features, dimensions and input variable values. Various clustering algorithms have been used in previous studies; there is extensive literature on these clustering techniques [51,61–63]. Recently, the agglomerative method of hierarchical clustering [61,64] has received increased attention in the climate literature. In this method, smaller clusters are developed according to a bottom-up approach and then sequentially combined into larger clusters depending on the Euclidean distance [65] between the clusters.

In the algorithm, each observation (climate change pattern) was initially allocated to its own cluster. At each iteration, the two closest clusters were identified and merged. This agglomeration continued until a single cluster remained. The hierarchical clustering produces a tree-like structure (dendrogram) of similarity, which provides meaningful information regarding the correlations/distances among different clusters. The valid number of clusters (CN) was determined prior to the identification of the maximum valid Euclidean distance. The agglomerative clustering algorithm of scikit-learn [53] was utilized to group the sites for each season in this study.
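The clustering step can be sketched as follows (synthetic component scores; Ward linkage is an assumption consistent with the Euclidean distance named above):

```python
import numpy as np
from scipy.cluster.hierarchy import fcluster, linkage
from sklearn.cluster import AgglomerativeClustering

# Synthetic component scores: one row per station, columns are leading PCs
# (three well-separated groups of 46 stations each, illustrative only).
rng = np.random.default_rng(0)
station_scores = np.vstack([rng.normal(m, 0.3, size=(46, 3)) for m in (-2, 0, 2)])

# scikit-learn route: bottom-up merging with Ward linkage (Euclidean distance).
labels = AgglomerativeClustering(n_clusters=3, linkage="ward").fit_predict(station_scores)

# SciPy route: the linkage matrix also supports plotting and truncating the
# dendrogram; fcluster cuts it into the validated number of clusters.
Z = linkage(station_scores, method="ward")
labels_scipy = fcluster(Z, t=3, criterion="maxclust")
```

`scipy.cluster.hierarchy.dendrogram(Z)` draws the tree from which the cut-off distance is read.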

### 3.1.4. Optimal Clustering and Climate zone Formation

In this step, the optimal CN of stations for each season is identified. Cluster validity methods must be employed to verify that the grouping reflects real structure in the dataset. Comprehensive comparative analyses of cluster validity indices have been performed in the literature. The silhouette index (*S*) [66], the S_Dbw index [67] and the Calinski–Harabasz index (*CH*) [68] were used in this study, following the recommendations of these comparative evaluations [67,69].

The optimal CN was identified using the average of the three values obtained from the aforementioned validity indices. The numbering and identification of stations in each cluster were performed via truncation of the dendrogram. The truncation (cut-off) bar is placed at the maximum valid Euclidean distance corresponding to the validated value of CN [20]. This was performed using the clustering algorithms of scikit-learn [53].

### Silhouette Score

The silhouette score method [66] is employed in the analysis to obtain the optimum CN. The silhouette score measures how similar each object is to its own cluster compared with the nearest neighboring cluster, averaged over all clusters. In this study, a maximum of 14 clusters was considered for climate zoning for all seasons. The CN corresponding to the maximum silhouette score is the basis for the decision. The silhouette score (*S*) is obtained as

$$S = \frac{1}{\mathrm{CN}} \sum_{i} \frac{1}{n_i} \sum_{r \in C_i} \frac{b(r) - a(r)}{\max[b(r),\, a(r)]} \tag{1}$$

and

$$a(r) = \frac{1}{n_i - 1} \sum_{s \in C_i,\, s \neq r} d(r, s), \qquad b(r) = \min_{j \neq i} \left[ \frac{1}{n_j} \sum_{s \in C_j} d(r, s) \right] \tag{2}$$

where CN denotes the number of clusters; *Ci* represents the *i*th cluster; *ni* represents the number of objects in *Ci*; *ci* denotes the center of *Ci*; and *d*(*r*, *s*) is the distance between *r* and *s* [66].
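A direct transcription of Equations (1) and (2), shown on synthetic data; note that scikit-learn's `silhouette_score` averages per sample rather than per cluster, so the two agree exactly only when all clusters have equal size:

```python
import numpy as np
from scipy.spatial.distance import cdist

def silhouette_eq1(X, labels):
    # Silhouette score per Equations (1)-(2): the average over clusters of the
    # mean silhouette of each cluster's members.
    clusters = np.unique(labels)
    per_cluster = []
    for ci in clusters:
        members = X[labels == ci]
        n_i = len(members)
        # a(r): mean distance to the other members of the same cluster
        a = cdist(members, members).sum(axis=1) / (n_i - 1)
        # b(r): smallest mean distance to the members of any other cluster
        b = np.min([cdist(members, X[labels == cj]).mean(axis=1)
                    for cj in clusters if cj != ci], axis=0)
        per_cluster.append(np.mean((b - a) / np.maximum(a, b)))
    return float(np.mean(per_cluster))

# Three equally sized synthetic station groups (illustrative only).
rng = np.random.default_rng(1)
X = np.vstack([rng.normal(m, 0.3, size=(46, 3)) for m in (-2, 0, 2)])
labels = np.repeat([0, 1, 2], 46)
score = silhouette_eq1(X, labels)
```

On well-separated groups such as these, the score approaches 1; values near 0 indicate overlapping clusters.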

### S Dbw Validity Index

S_Dbw considers inter-cluster density to calculate inter-cluster separation [67]: for each pair of clusters, the density at the midpoint between the two cluster centers is compared with the density at the centers themselves. The index is the sum of the scatter within the clusters and the density between the clusters. Previous studies have investigated the validation properties of different cluster validity indices using synthetic experimental data [70]; they conclude that S_Dbw performed best with respect to noise, monotonicity, density, skewed distributions and sub-clusters. Equations (3)–(5) give the density between clusters, the scatter within the clusters and the S_Dbw index, respectively:

$$Dens\_bw(\mathrm{CN}) = \frac{1}{\mathrm{CN}(\mathrm{CN}-1)} \sum_{i=1}^{\mathrm{CN}} \left[ \sum_{j=1,\, j \neq i}^{\mathrm{CN}} \frac{\sum_{x \in C_i \cup C_j} f(x, \mu_{ij})}{\max\left\{ \sum_{x \in C_i} f(x, c_i),\ \sum_{x \in C_j} f(x, c_j) \right\}} \right] \tag{3}$$

$$Scat(\mathrm{CN}) = \frac{1}{\mathrm{CN}} \sum_{i=1}^{\mathrm{CN}} \frac{\|\sigma(C_i)\|}{\|\sigma(D)\|} \tag{4}$$

$$S\_Dbw(\mathrm{CN}) = Scat(\mathrm{CN}) + Dens\_bw(\mathrm{CN}), \tag{5}$$

where *D* is the dataset; CN is the number of clusters; σ(*Ci*) is the variance vector of *Ci*; σ(*D*) is the variance vector of the dataset; *d*(*ci*, *cj*) is the distance between the centers *ci* and *cj*; and ‖*x*‖ = (*x*<sup>T</sup>*x*)<sup>1/2</sup>.

### Calinski–Harabasz Index

This cluster validation scheme, known as CH [68], depends on the ratio of the between-cluster sum of squares to the within-cluster sum of squares. It is obtained as shown in Equation (6):

$$CH = \frac{\sum_{i} n_i\, d^2(c_i, c)\, /\, (\mathrm{CN} - 1)}{\sum_{i} \sum_{x \in C_i} d^2(x, c_i)\, /\, (n - \mathrm{CN})}. \tag{6}$$

Here, *n* represents the number of points/objects in *D*, *c* denotes the center of *D* and *d*(*x*, *ci*) denotes the distance between *x* and *ci*.
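As a sketch of the index-based selection of CN (synthetic data; S_Dbw has no scikit-learn implementation, so only the silhouette and CH indices are shown here):

```python
import numpy as np
from sklearn.cluster import AgglomerativeClustering
from sklearn.metrics import calinski_harabasz_score, silhouette_score

# Synthetic component scores with three well-separated groups (illustrative).
rng = np.random.default_rng(2)
X = np.vstack([rng.normal(m, 0.3, size=(46, 3)) for m in (-2, 0, 2)])

# For each candidate CN, cluster the stations and record both validity indices.
best = {}
for name, metric in [("silhouette", silhouette_score),
                     ("CH", calinski_harabasz_score)]:
    vals = {cn: metric(X, AgglomerativeClustering(n_clusters=cn).fit_predict(X))
            for cn in range(2, 15)}          # up to 14 clusters, as in the study
    best[name] = max(vals, key=vals.get)     # CN that maximizes the index
```

In the study, the optimal CN is then taken as the average of the values suggested by the three indices.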

### 3.1.5. Climate Zone Polygon Formation

This method yields a realistic partition of the climatic stations into coherent climate regions. The delineation of these regions depends on the similarity of the statistical climate characteristics of all the stations within a region.

The clusters of stations for every season were plotted using the ArcGIS (Environmental Systems Research Institute, CA, USA) interface. After the identification of the groups of stations in the different clusters, approximate cluster boundaries were drawn on the maps using the ArcMap tool to enable a clearer presentation.

### *3.2. GCM Selection*

The criteria used to select the GCMs are the availability of the latest generation of GCMs, good spatial resolution, past performance of GCM to replicate the historical data and the representativeness of the GCM for a wide range of climatic variable (precipitation) projections [7]. The most commonly adopted criteria are the assessment and selection of GCMs based on their capability to simulate the historical and present climate, which have been adopted by numerous studies [8–10]. An efficient and sensible selection of GCMs is required to generate reliable and diverse meteorological inputs for the impact models. Most previous studies show that a past performance assessment is among the most effective methods for the selection of GCMs, as the GCMs thus selected can be better predictors of future climatic conditions [31].

Several selection methods have been reported in the literature. Aghakhani et al. [71] extracted GCM data at four points in the vicinity of observation stations and performed a comparative analysis between the GCM and the observed data, using the averages of the time-series data at each station. Xuan et al. [72] selected GCMs using an average of the climate parameters for the entire study area in the Zhejiang Province of Southeast China. Najeebullah et al. [32] ranked the GCMs at every grid point after re-gridding the GCM data to the same spatial resolution, comparing the data with the Asian Precipitation-Highly Resolved Observational Data Integration Towards Evaluation (APHRODITE) gridded product for GCM performance evaluation across all of Pakistan. When comparing the best GCM with the highly ranked GCM at each grid station, they found a significant difference between the two with respect to precipitation. Maxino et al. [15] analyzed the climate models from the Fourth Assessment Report (AR4) of the United Nations Intergovernmental Panel on Climate Change (IPCC) for two regions in the Murray–Darling Basin, divided based on rainfall classification. For different climate variables, they demonstrated that models that are flawed in one region may be better for another; model selection should therefore be area specific. They calculated the skill scores associated with the different climate models in the study region. Lutz et al. [22] averaged the climate data over 2.5° × 2.5° grid cells in three river basins (i.e., the Indus, Brahmaputra and Ganges) and presented their model with the selected GCMs. Latif et al. [73] also suggested certain Coupled Model Intercomparison Project 5 (CMIP5) GCMs after evaluating their performance using seasonal rainfall spatial correlations in the Indo-Pak region. Ahmed et al. [7] evaluated the spatial accuracy of the CMIP5 GCMs using different spatial metrics for all of Pakistan, performing the analysis after re-gridding the GCM data to a 2° × 2° grid size. Srinivasa et al. [31] ranked the GCMs for the maximum and minimum temperatures across India; their framework comprised the evaluation of GCM suitability at every grid point and then ranking the GCMs using compromise programming.

In the present study, after defining the clusters of stations with homogeneous precipitation, one representative station was selected in every climate zone using quota sampling [61]. The representative station was selected based on the average climate signals (component scores) in the respective climate zone. The comparative analysis of the GCM data and the APHRODITE data at the representative station of each climate zone was performed using the coefficient of determination (R<sup>2</sup>).

The correlations between the daily outputs of the GCMs and the in situ data were low. However, better correlations were obtained when the seasonal GCM data were compared with the seasonal observation data. Therefore, using the seasonal data, the correlations between the GCM data and the observed gridded data were evaluated at each representative station in the respective climate zone.

### *3.3. Validation of Selected GCMs*

This step is the comparative analysis of the climatic data generated through the present GCM selection method and through the contemporary method of sampling via a simple-mean-based multi-model ensemble (MME). The KS test [74], a nonparametric test, was used to determine the validity of the selected GCM by detecting changes in the distributions of the two datasets (the dataset generated through the present selection method and the simple-mean-based MME data) relative to APHRODITE. For every grid station, this test used the observational precipitation data (APHRODITE), the precipitation data simulated by the selected GCM and the MME mean daily precipitation data of all 21 GCMs of the NASA Earth Exchange Global Daily Downscaled Projections (NEX-GDDP) CMIP5, for the same baseline period of 1970–2005.

#### 3.3.1. Seasonal Data Sampling

The daily precipitation data of the selected GCM, corresponding to each homogeneous climate zone, were seasonally resampled for every grid station (i.e., the grid station within each grid cell). The multi-model ensemble mean data were generated by resampling the data as the average of the GCM daily precipitation data at every grid station.

### 3.3.2. Kolmogorov–Smirnov Test

In this study, the KS test statistic was derived as the largest vertical difference between the two cumulative distribution functions (CDFs) of the time-series data. For the CDF of the APHRODITE data, *φn*(*x*), and the CDF of the selected GCM or MME mean at a grid station, *φ*(*x*), the KS test statistic was obtained as given in Equations (7) and (8).

$$S_n = \sup_{x} \left| \phi_n(x) - \phi(x) \right|, \tag{7}$$

where

$$\phi_n(x) = \frac{1}{n} \sum_{i=1}^n I_{x_i \le x}. \tag{8}$$

Here, sup*x* denotes the supremum of the set of distances over *x*. The indicator function *I* equals 1 if *xi* ≤ *x* and 0 otherwise. The null hypothesis of a similar distribution is rejected at the 0.05 significance level if *Sn* exceeds the corresponding critical value. The KS test was applied at every grid station in Python.
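A minimal sketch of the two-sample KS comparison using SciPy's `ks_2samp` (the series are synthetic stand-ins for the APHRODITE, selected-GCM and MME data):

```python
import numpy as np
from scipy import stats

# Synthetic daily precipitation: an "observed" series, a candidate drawn from
# the same distribution and a biased alternative (illustrative only).
rng = np.random.default_rng(3)
observed = rng.gamma(0.5, 4.0, 5000)
candidate = rng.gamma(0.5, 4.0, 5000)
biased = rng.gamma(0.5, 8.0, 5000)

# S_n is the largest vertical gap between the two empirical CDFs (Eq. (7)).
stat_good, p_good = stats.ks_2samp(observed, candidate)
stat_bad, p_bad = stats.ks_2samp(observed, biased)

# A small p-value (< 0.05) rejects the null hypothesis of a common distribution.
print(stat_good, stat_bad)
```

The biased series yields a much larger statistic and a *p*-value well below 0.05, so it would be rejected as a match for the observed distribution.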

### *3.4. Selection of Forcing Scenarios*

Climate forcing scenarios are commonly used in climate research for climate simulation studies on both long- and short-term scales. There are four common Representative Concentration Pathways (RCPs): RCP 2.6 is a mitigation scenario; RCP 4.5 and RCP 6.0 are medium stabilization scenarios; and RCP 8.5 is a high baseline emissions scenario [75]. We did not include RCP 2.6 in the ensemble of climate models, as it is not considered a realistic or practical scenario for promoting adaptation planning. The remaining three scenarios represent the full extent of radiative forcing. In this study, we used RCP 4.5 and RCP 8.5; the same framework applies to other forcing scenarios [22].

### Extraction of Selected GCMs Data for RCP 4.5 and 8.5

The daily precipitation data for the selected GCM in each climate zone, for the projected period 2006–2099 under the forcing scenarios RCP 4.5 and RCP 8.5, were sampled for all seasons at every grid station.

### *3.5. Climate Change Projections and Trend Detection*

The highly correlated GCMs, which can replicate the gridded dataset of seasonal precipitation for the period 1970–2005, were selected in each climate zone for use in the future prediction of precipitation change trends under RCP 4.5 and RCP 8.5. The precipitation change trends were detected using the Mann–Kendall (MK) test over the projected period (2006–2099). The trend was quantified by employing Sen's slope estimator.

### 3.5.1. Mann-Kendall Test (MK Test)

The MK test [76] was used to evaluate the statistically significant seasonal trends in the long-term precipitation data. The null hypothesis (Ho) of the test assumes no trend in precipitation over time; the alternative hypothesis (Ha) assumes a trend. Equations (9)–(12) give the test statistic *T* of the MK test and its standardized form *Z*.

$$T = \sum_{i=1}^{n-1} \sum_{j=i+1}^{n} \mathrm{sgn}(D_j - D_i) \tag{9}$$

$$\mathrm{sgn}(D_j - D_i) = \begin{cases} +1 & \text{if } (D_j - D_i) > 0 \\ 0 & \text{if } (D_j - D_i) = 0 \\ -1 & \text{if } (D_j - D_i) < 0 \end{cases} \tag{10}$$

$$\sigma(T) = \frac{1}{18} \left[ n(n-1)(2n+5) - \sum_{p=1}^{q} t_p (t_p - 1)(2t_p + 5) \right] \tag{11}$$

$$Z = \begin{cases} \dfrac{T-1}{\sqrt{\sigma(T)}} & \text{if } T > 0 \\ 0 & \text{if } T = 0 \\ \dfrac{T+1}{\sqrt{\sigma(T)}} & \text{if } T < 0 \end{cases} \tag{12}$$

where *Di* and *Dj* are the consecutive time-series observations, organized chronologically over the data length *n*; *tp* represents the number of data points in the *p*th tied group; *q* denotes the total number of tied groups; and σ(*T*) is the variance of *T*. An upward trend in the time series is represented by a positive *Z* value and a downward trend by a negative one. If |*Z*| > *Z*1−α/2, the null hypothesis (Ho) is rejected, indicating a statistically significant trend. The critical value *Z*1−α/2 is taken at a *p*-value of 0.05.
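The MK test of Equations (9)–(12) can be transcribed directly (a sketch on a synthetic trending series; the helper name is hypothetical):

```python
import numpy as np

def mann_kendall(D):
    # Mann-Kendall test per Equations (9)-(12), with tie correction in sigma(T).
    D = np.asarray(D, dtype=float)
    n = len(D)
    # Equation (9): sum of the signs of all forward differences.
    T = sum(np.sign(D[j] - D[i]) for i in range(n - 1) for j in range(i + 1, n))
    # Equation (11): variance of T, corrected for tied groups of size t_p.
    _, counts = np.unique(D, return_counts=True)
    ties = sum(t * (t - 1) * (2 * t + 5) for t in counts if t > 1)
    var_T = (n * (n - 1) * (2 * n + 5) - ties) / 18.0
    # Equation (12): standardized statistic with continuity correction.
    if T > 0:
        Z = (T - 1) / np.sqrt(var_T)
    elif T < 0:
        Z = (T + 1) / np.sqrt(var_T)
    else:
        Z = 0.0
    return T, Z, abs(Z) > 1.96   # 1.96 is Z at the two-sided 0.05 level

# Synthetic seasonal series (94 values, one per year of 2006-2099) with an
# imposed upward trend, for illustration only.
rng = np.random.default_rng(4)
series = 0.5 * np.arange(94) + rng.normal(0, 3, 94)
T, Z, significant = mann_kendall(series)
```

For this strongly trending series, *T* and *Z* are positive and the trend is flagged as significant.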

### 3.5.2. Sen's Slope Evaluation

A Python module was applied to obtain the slope of the trend, if one exists, according to Sen's method [77]. In this method, the set of linear slopes is estimated as shown in Equation (13):

$$T_i = \frac{D_j - D_k}{j - k} \quad \text{for } (1 \le k < j \le n), \tag{13}$$

where *Ti* is the slope of the *i*th pair; *Dj* and *Dk* are the data values at time steps *j* and *k*, respectively; and *n* represents the number of data points. Sen's slope is estimated as the median of all the slopes.
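Sen's estimator of Equation (13) reduces to the median of the pairwise slopes, for example (the function name is illustrative):

```python
import numpy as np

def sens_slope(D):
    # Sen's slope (Equation (13)): the median of the slopes between all pairs
    # of points (k < j) in the series.
    D = np.asarray(D, dtype=float)
    n = len(D)
    slopes = [(D[j] - D[k]) / (j - k)
              for k in range(n - 1) for j in range(k + 1, n)]
    return float(np.median(slopes))

# A noiseless linear series recovers its slope exactly.
print(sens_slope([1.0, 3.0, 5.0, 7.0]))  # -> 2.0
```

Because the median is robust, a few outlying values shift the estimate far less than they would an ordinary least-squares slope.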

### **4. Results and Discussion**

Taking the spatiotemporal average [22,25,73] of the climate dataset at all the grid points of the area at different temporal scales (daily, seasonal or annual), or applying various spatial metrics to the results of the individual grid points [21,31,32], are the conventional methods used for the selection of GCMs. However, these methods do not address the spatial coherence of the climate variability patterns in the study area, which introduces uncertainty when using them. Several studies support the association between atmospheric covariates and precipitation. In GCMs, these atmospheric covariates are the input variables. To address the uncertainty arising from the input variables of the GCMs, it is imperative to assess GCM performance in reproducing the regional climate pattern. Our results show that the climate variability pattern of the study area is spatially heterogeneous in all seasons, and a single GCM or an ensemble of GCMs may not represent this spatial climatic heterogeneity. Therefore, we classified/regionalized the large study area into several climate zones based on the similarity of climate variability patterns, and the selection of a GCM was performed separately in each homogeneous climate zone.

This section presents the detailed results along with the discussion. The results were visualized using ArcGIS, while the underlying analyses integrate different machine learning techniques. Section 4.1 describes the results of the analyses performed for the development of climate zones. Section 4.2 shows the results of the GCM selection procedure and Section 4.3 relates to the validation of the method used for GCM selection. Section 4.4 describes the results of the MK test for trend detection and Sen's slope estimation; these tests determined the precipitation change trend projections for the period 2006 to 2099.

### *4.1. Climate Zones*

The Principal Component Analysis (PCA) performed as the first step of the climate zoning provided details of the climate pattern across the entire study area. Statistically significant heterogeneity in the climate pattern was observed in the region when the PCA was performed on a seasonal basis.

#### 4.1.1. Principal Component Analysis

Principal Component Analysis (PCA) was applied to the daily precipitation series from APHRODITE at a network of 138 stations distributed across the Jhelum and Chenab river basins. This yielded 20 primary, physically significant principal components (PCs), which collectively explained approximately 95% of the total precipitation variability throughout the study area; these were retained for the cluster analysis. The scree plots produced by the PCA algorithm, shown in Figure 4, guided the choice of the number of PCs for the cluster analysis. The first 20 PCs explain 95% of the variance in the data from the station network in the warm-wet season (Figure 4a), 94% in the cold-dry season (Figure 4b), 95% in the cold-wet season (Figure 4c) and 95% in the warm-dry season (Figure 4d). The spatial distribution of the component scores for the first two PCs, which explain the highest percentage of variance in the data, is shown in Figure 5.
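The retain-enough-PCs step described here might look as follows with scikit-learn; the precipitation matrix is a random placeholder, not the APHRODITE data:

```python
import numpy as np
from sklearn.decomposition import PCA

rng = np.random.default_rng(0)
# Placeholder for a seasonal precipitation matrix: rows = days, columns = 138 stations.
X = rng.gamma(shape=2.0, scale=3.0, size=(900, 138))

pca = PCA().fit(X)
cum_var = np.cumsum(pca.explained_variance_ratio_)
n_pcs = int(np.searchsorted(cum_var, 0.95) + 1)    # smallest PC count reaching 95%
scores = PCA(n_components=n_pcs).fit_transform(X)  # component scores for clustering
print(n_pcs, scores.shape)
```

Plotting `pca.explained_variance_ratio_` against the component index reproduces the scree plot of Figure 4; the component scores per station are what feed the clustering step.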

**Figure 4.** Scree and cumulative plots showing the percentage of explained variance by each principal component (PC) in the (**a**) warm-wet, (**b**) cold-dry, (**c**) cold-wet and (**d**) warm-dry seasons. The red line shows the explained variance corresponding to 20 principal components.

**Figure 5.** The spatial patterns of the first two PCs for the daily seasonal precipitation across the study area. The explained variance % is presented on each panel. The component scores are presented by the color distributions across the study area (see color bar for reference): (**a**) warm-wet (PC-1), (**b**) cold-dry (PC-1), (**c**) cold-wet (PC-1), (**d**) warm-dry (PC-1), (**e**) warm-wet (PC-2), (**f**) cold-dry (PC-2), (**g**) cold-wet (PC-2) and (**h**) warm-dry (PC-2) seasons.

These two PCs cumulatively explained 44.8% of the variance in the warm-wet season, 41.2% in the cold-dry season, 45.7% in the cold-wet season and 41.5% in the warm-dry season. Regarding the component scores, the first PC for the warm-wet season (Figure 5a, 25% of the explained variance) has high negative scores in the north and southwest of the study region and medium-to-high positive scores in the central region. Strong positive and negative signals indicate higher variation in the precipitation amount, while weak scores indicate low variability. The second PC for the warm-wet season (Figure 5e, 19.8% of the explained variance) has mostly positive component scores in the southeast, except for medium and strong negative signals in several southeastern and central regions, respectively. In the cold-dry season, PC1 (Figure 5b) shows high negative variability in the southwest and in small areas of the southeast, with strong positive signals in the northwest of the study region. For PC2 (Figure 5f, 18.8% of the explained variance), high positive component scores occur in the southeast, whereas strong negative scores occur in the southwest. In the cold-wet season, PC1 (Figure 5c, 27% of the explained variance) shows high variability in the precipitation amount in the southwest, with high positive signals in the southeast and central areas of the region. PC2 for the cold-wet season (Figure 5g, 18.7% of the explained variance) has strong positive scores in the southeast but medium-to-high negative scores in the rest of the study area. In the warm-dry season, PC1 (Figure 5d, 22.7% of the explained variance) shows high precipitation variability in the north and southwest of the area. PC2 (Figure 5h) has high positive variability in the southwest and medium negative signals in the rest of the region. Overall, the first two PCs explain a substantial share of the variability in the precipitation data in all seasons.

### 4.1.2. Agglomerative Hierarchical Clustering (AHC)

We then identified the potentially homogeneous regions within the study area and quantified their homogeneity using several cluster validity indices. The first 20 PCs, obtained in the previous step using the Python scikit-learn module, were employed in the AHC analysis. The validity of different station clusterings was evaluated using the silhouette score, the Calinski–Harabasz (CH) score and the S_Dbw index; Figure 6 summarizes the cluster validation results. The clustering algorithms were run repeatedly with different input values of the maximum Euclidean distance, and the resulting clusterings were compared to assess their validity. According to the CH test, the optimal partitions correspond to the lowest scores, obtained at 13, 15, 14 and 14 clusters for the warm-wet, cold-dry, cold-wet and warm-dry seasons, respectively (Figure 6a). The S_Dbw test suggests optimal partitions of 13, 13, 14 and 13 clusters for the warm-wet, cold-dry, cold-wet and warm-dry seasons, respectively, selected at the lowest S_Dbw scores (Figure 6b). Partitions of 11, 15, 13 and 13 clusters correspond to the highest silhouette scores for the warm-wet, cold-dry, cold-wet and warm-dry seasons, respectively (Figure 6c).
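A sketch of the cluster-validation loop, assuming SciPy's Ward linkage and scikit-learn's silhouette and Calinski–Harabasz scores (the S_Dbw index is not part of scikit-learn and would need a third-party implementation); the PC scores are random stand-ins for the station data:

```python
import numpy as np
from scipy.cluster.hierarchy import linkage, fcluster
from sklearn.metrics import silhouette_score, calinski_harabasz_score

rng = np.random.default_rng(1)
pcs = rng.normal(size=(138, 20))   # stand-in for the first 20 PC scores per station

Z = linkage(pcs, method="ward")    # agglomerative tree built on Euclidean distance
for k in range(10, 16):            # candidate numbers of clusters
    labels = fcluster(Z, t=k, criterion="maxclust")
    sil = silhouette_score(pcs, labels)
    ch = calinski_harabasz_score(pcs, labels)
    print(k, round(sil, 3), round(ch, 1))
```

On real data, the candidate with the best index values across seasons would be retained, as in Figure 6.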

**Figure 6.** The cluster validity indices/scores for the cold-dry (cd), cold-wet (cw), warm-wet (ww) and warm-dry (wd) seasons: (**a**) Calinski–Harabasz score, with the valid number of clusters marked at the minimum score in each season; (**b**) S_Dbw score, with the valid number of clusters marked at the minimum score in each season; and (**c**) silhouette score, with the valid number of clusters marked at the maximum score in each season.

Based on Figure 6, the validity indices do not all indicate the same number of clusters. Therefore, following the recommendations of previous studies [70], we partitioned the stations into 13, 13, 14 and 13 clusters for the warm-wet, cold-dry, cold-wet and warm-dry seasons, respectively. The members/stations of the different clusters were obtained by truncating the dendrograms, produced by the AHC of the selected PCs, at the Euclidean distances corresponding to the optimized cluster counts. Figure 7 shows the resulting cluster trees for all four seasons, with the cut-off/truncation bars for each season; based on this truncation, we identified the groups of stations/sites in each cluster.
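Truncating a dendrogram at the Euclidean distance that yields a target number of clusters can be sketched with SciPy (random stand-in data for the PC scores):

```python
import numpy as np
from scipy.cluster.hierarchy import linkage, fcluster

rng = np.random.default_rng(2)
pcs = rng.normal(size=(138, 20))   # stand-in for the station PC scores
Z = linkage(pcs, method="ward")    # Z[:, 2] holds the merge heights, ascending

# A cut-off Euclidean distance that truncates the tree into exactly 13 clusters:
# any value between the 13th- and 12th-largest merge heights works.
cutoff = (Z[-13, 2] + Z[-12, 2]) / 2
labels = fcluster(Z, t=cutoff, criterion="distance")
print(len(np.unique(labels)))      # -> 13
```

The horizontal bars in Figure 7 play the role of `cutoff` here: every branch below the bar becomes one station group.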

**Figure 7.** Dendrograms with cut-off bars grouping the stations into clusters as a function of the Euclidean distance; the numbers in brackets indicate the number of branches (groups of stations) in each truncated branch of the dendrogram: (**a**) warm-wet, (**b**) cold-dry, (**c**) cold-wet and (**d**) warm-dry seasons. The truncation is performed at the black cut-off bar corresponding to the Euclidean distance for the required optimum number of clusters.

### 4.1.3. Climate Zones and Reference Sites

This study reveals useful details on the efficiency of machine learning algorithms in delineating large, homogeneous precipitation regions for different seasons. Homogeneous precipitation sites were identified based on the validated numbers of clusters. These station clusters were then plotted on a map of the study area, with each homogeneous region differentiated by color to show the general spatial patterns associated with precipitation; approximate boundaries are drawn on the maps for a clearer representation. The transformation of the station clusters into climate zones for all seasons is shown in Figure 8.

Based on this analysis, the Jhelum and Chenab river basins were partitioned into 13, 13, 14 and 13 clusters for the warm-wet, cold-dry, cold-wet and warm-dry seasons, respectively. There were three outliers among the warm-wet clusters and one each among the cold-dry, cold-wet and warm-dry clusters. After merging the outliers with the neighboring clusters, the final demarcations of the basins were mapped for every season, classifying the basins into 10, 12, 13 and 13 climate zones for the warm-wet, cold-dry, cold-wet and warm-dry seasons, respectively. A reference station was identified in each climate zone for selecting the best-correlated GCM based on its past performance at that station.

**Figure 8.** Climate zones in the Jhelum and Chenab river basins in the (**a**) warm-wet, (**b**) cold-dry, (**c**) cold-wet and (**d**) warm-dry seasons.

### *4.2. GCM Selection*

We investigated the capabilities of 21 CMIP5 GCMs to reproduce the highly resolved observational precipitation data and, using Pearson correlation analysis, identified the most correlated GCM for a dependable climate projection in each climate zone. The GCMs were selected in every climate zone of the study area using the GCM and APHRODITE seasonal precipitation data at the reference station; the reference stations for every season are shown in Figure 8. The selected GCMs were then used to project changes in the precipitation trends at every grid point throughout the respective zone of the river basins. Figure 9 shows the spatial distribution of the most highly correlated GCMs in each climate zone. For the warm-wet season, we selected the BNU-ESM, MIROC5, CESM, IPSL-CM5A-MR and GFDL-ESM2G GCMs. For the cold-dry season, we selected the CCSM4, MIROC-ESM, IPSL-CM5A-LR, MIROC5 and MPI-ESM-LR GCMs. For the cold-wet season, we selected the MRI-CGCM3, MIROC-ESM, NorESMI, CESM, IPSL-CM5A-MR and GFDL-ESM2G GCMs. For the warm-dry season, we selected the IPSL-CM5A-MR, BNU-ESM, GFDL-CM3, bcc-CSM1-1 and inmcm4 GCMs. Descriptions of the aforementioned GCMs are provided in Table S1 of the Supplementary Materials.
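The selection step amounts to picking, per reference station, the candidate series with the highest Pearson correlation against the observations; the three toy "GCM" series below are synthetic placeholders:

```python
import numpy as np
from scipy.stats import pearsonr

rng = np.random.default_rng(3)
obs = rng.gamma(2.0, 3.0, size=360)   # stand-in for the APHRODITE series at a reference station
gcms = {                              # toy candidate series with increasing noise levels
    f"GCM-{i}": obs + rng.normal(scale=s, size=360)
    for i, s in enumerate([1.0, 4.0, 8.0])
}

best = max(gcms, key=lambda name: pearsonr(obs, gcms[name])[0])
print(best)
```

In the study this comparison is repeated for each of the 21 CMIP5 candidates, in each climate zone and season, using the seasonal precipitation series at the zone's reference station.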

### *4.3. Validation of Selected GCMs*

The daily precipitation data of the selected GCMs and of the simple mean of the MME (21 GCMs) at every grid station were resampled into seasonal cycles. The KS test was applied at every grid station to compare each data distribution with the APHRODITE distribution: when the *p*-value is below 0.05, the hypothesis of similar distributions is rejected, and otherwise it is not. Figure 10 shows the spatial distribution of the KS-test *p*-values. The results show that the selected GCMs reproduce the daily seasonal APHRODITE distribution much more closely than the simple mean of the MME. For the warm-wet season, 45% of the time-series distributions of the selected GCMs did not fall in the rejection zone (at α = 5%), compared with 28% for the simple-mean MME. For the cold-dry season, 76% of the selected GCMs' distributions did not fall in the rejection zone, versus only 55% for the simple-mean MME. For the cold-wet season, 76% of the selected GCMs' distributions did not fall in the rejection zone, versus 0% for the simple-mean MME. For the warm-dry season, 73% of the selected GCMs' distributions did not fall in the rejection zone, versus only 10% for the simple-mean MME. These results validate the GCM selection procedure when assessed against past performance.
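A minimal sketch of the per-grid-point KS comparison with SciPy's `ks_2samp`; the three samples are synthetic stand-ins, with the "MME" sample deliberately drawn from a different distribution to mimic the distortion introduced by ensemble averaging:

```python
import numpy as np
from scipy.stats import ks_2samp

rng = np.random.default_rng(4)
obs = rng.gamma(2.0, 3.0, size=1000)   # stand-in for APHRODITE daily seasonal precipitation
gcm = rng.gamma(2.0, 3.0, size=1000)   # a selected GCM: drawn from the same distribution
mme = rng.gamma(4.0, 3.0, size=1000)   # simple-mean MME: a distorted distribution

for name, sample in [("selected GCM", gcm), ("simple-mean MME", mme)]:
    stat, p = ks_2samp(obs, sample)
    verdict = "not rejected" if p >= 0.05 else "rejected"
    print(f"{name}: p = {p:.3f}, similar-distribution hypothesis {verdict}")
```

Mapping these *p*-values over all grid points yields the spatial view of Figure 10.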

**Figure 9.** The spatial distribution of the highly correlated general circulation models (GCMs) with the observed data in each respective climate zone for the (**a**) warm-wet, (**b**) cold-dry, (**c**) cold-wet and (**d**) warm-dry seasons.

**Figure 10.** The *p*-values of the KS test comparing the APHRODITE data with the selected GCMs' data and with the conventional composite mean data of the 21 GCMs for the warm-wet, cold-dry, cold-wet and warm-dry seasons: (**a**) selected GCMs; (**b**) simple-mean multimodel ensemble. The green bands indicate the zones where the null hypothesis of "similar distribution" is not rejected.

### *4.4. Seasonal Precipitation Trend Projection*

The significance of the seasonal precipitation trends for the period 2006–2099, as detected by the MK trend test for forcing scenarios RCP 4.5 and RCP 8.5, is presented in Figure 11, which shows the spatial distribution of the MK-test *p*-values and thus distinguishes significant from non-significant trends. The spatial patterns of the estimated Sen's slopes exhibit negative and positive trends in the different seasons, as shown in Figure 12. For the warm-wet season under RCP 4.5 (Figure 12a), 28% of the study region yielded very strong and 5% strong evidence for decreasing precipitation, at rates of 2.2–3.29 mm yr<sup>−1</sup>, over the central and northwest areas; the remaining 67% of the region showed no or weakly decreasing trends. Under RCP 8.5 (Figure 12e), weak evidence for increasing trends of 0.5–2 mm yr<sup>−1</sup> was detected in the central and southeast regions, except for certain central-to-western areas, where a non-significant decreasing trend with a slope of 1–3.29 mm yr<sup>−1</sup> was detected.
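The MK statistic and its normal approximation (shown here without the tie correction that real precipitation data would require) can be sketched as follows; the 94-value series is a synthetic stand-in for the 2006–2099 annual values at one grid point:

```python
import numpy as np
from scipy.stats import norm

def mann_kendall(x):
    """Mann-Kendall trend test (no tie correction): returns (S, Z, two-sided p)."""
    x = np.asarray(x, dtype=float)
    n = len(x)
    s = sum(np.sign(x[k] - x[j]) for j in range(n - 1) for k in range(j + 1, n))
    var_s = n * (n - 1) * (2 * n + 5) / 18
    if s > 0:
        z = (s - 1) / np.sqrt(var_s)
    elif s < 0:
        z = (s + 1) / np.sqrt(var_s)
    else:
        z = 0.0
    p = 2 * (1 - norm.cdf(abs(z)))
    return s, z, p

rng = np.random.default_rng(5)
trend = np.arange(94) * 0.5 + rng.normal(scale=5, size=94)  # 94 "years" with an upward trend
s, z, p = mann_kendall(trend)
print(p < 0.05)   # the null hypothesis of "no trend" is rejected here
```

A positive S with p < 0.05 corresponds to the "strong evidence of increasing trend" areas in Figure 11; Sen's slope then quantifies the magnitude of that trend.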

**Figure 11.** The *p*-values of the MK test for the best-correlated models for the period 2006–2099 for the warm-wet, cold-dry, cold-wet and warm-dry seasons under (**a**) the RCP 4.5 forcing scenario and (**b**) the RCP 8.5 forcing scenario. The blue and green shades indicate where the null hypothesis of "no trend" is rejected.

For the cold-dry season under RCP 4.5 (Figure 12b), nearly all regions of the study area had weak increasing trends with Sen's slopes of 0–0.5 mm yr<sup>−1</sup>, which are not significant, as depicted in Figure 11. Under RCP 8.5 (Figure 12f), weak decreasing trends with Sen's slopes of 0.2 mm yr<sup>−1</sup> were detected in the central and southwest regions and certain areas of the southeast, while 80% of the total area showed weak evidence of increasing trends at 0.5–1.5 mm yr<sup>−1</sup>.

For the cold-wet season under RCP 4.5 (Figure 12c), a weak decreasing trend with Sen's slopes of 0.2–0.5 mm yr<sup>−1</sup> was detected in the central and southwest parts of the area, while high and weak increasing trends were detected in the northwest and southeast parts, respectively, with positive slopes of 0.5–2 mm yr<sup>−1</sup>.

For the cold-wet season under RCP 8.5 (Figure 12g), a non-significant decreasing trend with Sen's slopes between 0.2 and 0.5 mm yr<sup>−1</sup> was detected in the central area, whereas in the other regions a weak increasing trend with slopes between 0.5 and 2 mm yr<sup>−1</sup> was detected.

**Figure 12.** The seasonal precipitation trends in the study area for (**a**) warm-wet RCP 4.5, (**b**) cold-dry RCP 4.5, (**c**) cold-wet RCP 4.5, (**d**) warm-dry RCP 4.5, (**e**) warm-wet RCP 8.5, (**f**) cold-dry RCP 8.5, (**g**) cold-wet RCP 8.5 and (**h**) warm-dry RCP 8.5.

In contrast, for the warm-dry season under RCP 4.5 (Figure 12d), there is strong evidence for an increase in precipitation over the whole study area, with MK-test *p*-values below 0.05 for nearly the entire region and Sen's slopes indicating an increase of 1.5–3.5 mm yr<sup>−1</sup>. The same holds under RCP 8.5, with strong evidence of increasing precipitation throughout the area and Sen's slopes of 1.5–3.5 mm yr<sup>−1</sup> (Figure 12h).

### **5. Conclusions**

Previous studies have not yet reached a consensus on a universally accepted method and criteria for GCM selection [14,78]. Studies continue to investigate methods, whether dynamic or statistical, to reduce the uncertainty in predicting climate change impacts. With this aim, we presented a novel and flexible framework for selecting GCMs in homogeneous climate zones, based on statistics of daily seasonal precipitation data spanning the baseline/historical period (1970–2005) for the Jhelum and Chenab river basin study area. The GCM selected in each climate zone of every season can reproduce the climate distribution patterns in the study area, as demonstrated by the comparative analysis of the precipitation datasets of the selected GCMs and the simple composite mean of the ensemble of 21 CMIP5 GCMs. Using these homogeneous precipitation regions, we suggest a different approach for selecting GCMs: the GCMs were selected based on prior performance against the highly resolved, gridded APHRODITE data, and the selection was validated by its ability to emulate the spatial precipitation patterns during the baseline period (1970–2005). We then sampled the daily precipitation data for the projected period using the selected GCMs. The trends in precipitation variability, based on the MK trend detection statistics, show that regions in the Jhelum and Chenab river basins exhibit negative, positive and no trends from 2006 to 2099.

The Jhelum and Chenab river basins are sparsely gauged; hence, the APHRODITE gridded datasets were used in this study, because this type of climate zoning requires a dense network of climate observation stations. Numerous previous studies have verified the reliability of the APHRODITE data in this region [22,40]. However, at high altitudes the deviation between the gridded data and in situ observations is larger than in flat regions, owing to the smaller number of gauging stations in mountainous areas, which hampers precise interpolation [79]. Uncertainties in the forcing datasets persist due to the scarcity of gauging stations in high-elevation zones. Furthermore, gauging stations in valleys cannot represent the surrounding higher-elevation zones because of the high vertical lapse rate of precipitation [41]. Currently, there is no consensus in the climate research community regarding bias-correction methods for high-elevation areas. The comparative analysis shows that APHRODITE correlates well with the observed datasets of some high-altitude gauging stations, supporting the assumption that it is sufficiently accurate for the regionalization process. Nevertheless, this regionalization framework is equally applicable to other datasets of comparatively higher resolution and accuracy.

All the analyses were conducted using the Python programming language, which offers powerful machine learning support. Several machine learning modules from the scikit-learn library, including Principal Component Analysis (PCA), Agglomerative Hierarchical Clustering (AHC), clustering validity indices, the correlation coefficient and the KS test, were used in the analysis. All these machine learning algorithms were combined into a program for regionalization and climate change trend detection. The program can be provided to the climate research community to support decision- and policymaking for water resource planning and management.

GCM projections, however, are fundamentally uncertain, and decision-makers still find such predictions of climate change difficult to understand and use. The major cause of this uncertainty derives from the inherent uncertainties in GCM outputs [80–83]. Based on the results and interpretations presented in this study, we can draw the following conclusions:


region. Infrastructure development and climate forecasting are based on effective regionalization using reliable estimates of climate variables. The region should be conditionally considered as climatically homogeneous to proceed with the selection of GCMs for the climate change impact assessments.


Extensive literature addresses the uncertainties in climate projections through climate models [5,17]. The uncertainty in predicting the near-future, day-to-day weather of dynamic, chaotic weather systems was discussed by the meteorologist Edward N. Lorenz [16], who emphasized chaos theory, that is, the sensitivity of a plausible weather forecast to the simulated initial state and the parametrization of the model. Precise knowledge of the initial state of the weather is essential for a robust weather forecast, and initial weather conditions are likewise useful for predicting the climate. However, the predictability of the weather is limited to a few days because of the chaotic atmospheric dynamics; weather patterns are affected by atmospheric oscillations such as El Niño and La Niña. Climate projections, by contrast, are driven by heat-energy changes at the Earth's surface and by changes in greenhouse gas concentrations in the atmosphere. Projecting these changes to eventually project the climate is significantly easier than forecasting the weather a week ahead. To summarize, long-term changes in atmospheric composition are significantly more predictable than the short-term weather [84].

The sources of confidence in the models' future projections are their physical basis and their performance in simulating the past/historical climate. According to the IPCC, the models have proven to be important tools for projecting the future climate [84]. Their performance in representing large-scale climate features has been extensively evaluated by comparing their simulations with observations, and they have been used to reproduce ancient climates such as the warm Holocene period, the last glacial maximum and the last ice age [84].

This paper is envisioned to provide a framework that can easily be replicated to project climate change trends for an area by utilizing the stochastic analysis strengths of machine learning together with the available observed and GCM data. It is intended as a reference for those working on impact modeling. However, it does not aim to discuss the dynamic processes and associated uncertainties present in all GCMs; these are covered in the literature, and the GCMs' simulation outputs have been used in the analysis as provided. To gain a higher confidence level, end-users of the present framework are advised to use a large ensemble of GCMs when selecting a suitable model in each zone of the study area. The code prepared in this study is available upon request.

Further studies are required to analyze the seasonal climate shift and how accurately the GCM outputs can project the present climatic zones, which are delineated based on the observed gridded historical data.

**Supplementary Materials:** The following are available online at http://www.mdpi.com/2076-3417/10/19/6878/s1, Table S1: Descriptions of the GCMs used in the study [43].

**Author Contributions:** Conceptualization, A.N., M.S. and H.F.G.; methodology, A.N., M.S. and S.A.; software, S.H.; validation, S.A., S.H. and H.F.G.; formal analysis, A.N. and M.S.; investigation, A.N.; resources, H.F.G.; data curation, A.N., M.S. and S.A.J.; writing—original draft preparation, A.N.; writing—review and editing, S.H., S.A. and H.F.G.; visualization, A.N. and S.A.J.; supervision, H.F.G.; project administration, S.S.J. and S.A. All authors have read and agreed to the published version of the manuscript.

**Funding:** This research received no external funding.

**Acknowledgments:** We gratefully acknowledge the support of Jahangir Ali, who downloaded the extracted NEX GDDP data and provided guidance on the machine learning algorithms.

**Conflicts of Interest:** The authors declare no conflict of interest.

### **References**


© 2020 by the authors. Licensee MDPI, Basel, Switzerland. This article is an open access article distributed under the terms and conditions of the Creative Commons Attribution (CC BY) license (http://creativecommons.org/licenses/by/4.0/).

## *Article* **Detection of Precipitation and Fog Using Machine Learning on Backscatter Data from Lidar Ceilometer**

**Yong-Hyuk Kim 1, Seung-Hyun Moon <sup>1</sup> and Yourim Yoon 2,\***


Received: 21 August 2020; Accepted: 14 September 2020; Published: 16 September 2020

**Abstract:** The lidar ceilometer estimates cloud height by analyzing backscatter data. This study examines weather detectability using a lidar ceilometer by making an unprecedented attempt at detecting weather phenomena through the application of machine learning techniques to the backscatter data obtained from a lidar ceilometer. This study investigates the weather phenomena of precipitation and fog, which are expected to greatly affect backscatter data. In this experiment, the backscatter data obtained from the lidar ceilometer, CL51, installed in Boseong, South Korea, were used. For validation, the data from the automatic weather station for precipitation and visibility sensor PWD20 for fog, installed at the same location, were used. The experimental results showed potential for precipitation detection, which yielded an F1 score of 0.34. However, fog detection was found to be very difficult and yielded an F1 score of 0.10.

**Keywords:** backscatter data; lidar ceilometer; weather detection; machine learning

### **1. Introduction**

The lidar ceilometer is a remote observation device used to measure cloud height at the location in which it is installed. Many studies [1–7] have obtained planetary boundary layer height (PBLH) by analyzing backscatter data: the raw data obtained using a lidar ceilometer. However, there is considerable room for improvement due to the limited accuracy of past methodologies. In the past, backscatter data from a lidar ceilometer were used primarily for PBLH measurements, but recently they have been used for radiation fog alerts [8], optical aerosol characterization [9], aerosol dispersion simulation [10], and studies of the relationship between cloud occurrence and precipitation [11].

Machine learning techniques have been actively applied to the meteorology field in recent years. For example, they are used for forecasting very short-range heavy precipitation [12,13], quality control [14,15] and correction [15–18] of observed weather data, and predicting winter precipitation types [19].

This study attempts to conduct an unprecedented analysis of backscatter data obtained using a lidar ceilometer. Beyond the conventional use of backscatter data for PBLH measurement, their correlations with weather phenomena are analyzed and weather detectability is examined. The weather phenomena of precipitation and fog, which are expected to affect backscatter data, were examined. Cloud occurrence is related to precipitation [11], and backscatter measurements can be used to predict fog [8]. Machine learning techniques are applied to weather detection: to detect the aforementioned weather phenomena (precipitation and fog) from the backscatter data obtained from the lidar ceilometer, three machine learning models were applied: random forest, support vector machine, and artificial neural network.

This paper is organized as follows: Section 2 describes the three machine learning models used in this study. Section 3 introduces backscatter data obtained from the lidar ceilometer, and observational data of precipitation and fog. Section 4 presents machine learning methods for detecting the weather phenomena. Finally, conclusions are drawn in Section 5.

### **2. Preliminaries**

### *2.1. Random Forest*

Random forest, an ensemble learning method used in classification and regression analysis, outputs the mode of the classes (for classification) or the mean prediction (for regression) of multiple trained decision trees. The first proper random forest was introduced by Breiman [20]. To build a forest of uncorrelated trees, random forest uses the CART (Classification and Regression Tree) procedure combined with randomized node optimization and bagging. Its two key parameters are the number of trees and the maximum allowed depth. As the number of trees increases, a random forest generalizes better, but the training time increases. The maximum allowed depth is the number of nodes on the path from the root node to a terminal node; under-fitting may occur if it is too small, and over-fitting may occur if it is too large. This study set the number of trees to 100 and did not limit the maximum allowed depth.
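The configuration stated above (100 trees, unlimited depth) can be sketched with scikit-learn's `RandomForestClassifier`; the dataset is a synthetic stand-in for the backscatter feature vectors and precipitation labels:

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split

# Toy stand-in for the backscatter features and binary precipitation labels.
X, y = make_classification(n_samples=500, n_features=20, random_state=0)
X_tr, X_te, y_tr, y_te = train_test_split(X, y, random_state=0)

# Same settings as in the text: 100 trees, unlimited depth (max_depth=None).
rf = RandomForestClassifier(n_estimators=100, max_depth=None, random_state=0)
rf.fit(X_tr, y_tr)
print(round(rf.score(X_te, y_te), 2))
```

Each tree votes on the class of a sample, and the forest reports the majority vote, as described above.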

### *2.2. Support Vector Machine*

Support vector machine (SVM) [21] is a supervised learning model for pattern recognition and data analysis, used mainly for classification and regression analysis. Given a binary classification problem, SVM builds a non-probabilistic binary linear classifier that assigns data to one of the two categories. SVM constructs the hyperplane that best separates the training data points with the maximum margin. In addition to linear classification, SVM can efficiently perform nonlinear classification using a kernel trick that implicitly maps the data into a high-dimensional space.
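The kernel trick mentioned above can be illustrated with scikit-learn's `SVC` on a synthetic dataset that defeats a linear classifier:

```python
from sklearn.datasets import make_circles
from sklearn.svm import SVC

# Concentric circles are not linearly separable in the original 2-D space.
X, y = make_circles(n_samples=400, noise=0.05, factor=0.4, random_state=0)

linear = SVC(kernel="linear").fit(X, y)  # linear hyperplane: near-chance accuracy
rbf = SVC(kernel="rbf").fit(X, y)        # RBF kernel: separates the rings
print(round(linear.score(X, y), 2), round(rbf.score(X, y), 2))
```

The RBF kernel implicitly lifts the points into a space where a separating hyperplane exists, which is exactly what makes SVM usable on nonlinear weather-detection features.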

### *2.3. Artificial Neural Networks*

Artificial neural network (ANN) [22,23] is a statistical learning algorithm inspired by biological neural networks. ANN generally refers to a model in which neurons (nodes) form a network through synaptic connections, acquiring problem-solving ability by adjusting the strengths of those connections during training. This study used the multilayer perceptron (MLP). The basic structure of an ANN comprises the input, hidden, and output layers, each made up of multiple neurons. Training is divided into two steps: forward and backward calculations. In the forward step, a linear function composed of the weights and thresholds of each layer is evaluated, and the result is passed through a nonlinear output function; the ANN can thus perform nonlinear classification because it combines linear and nonlinear functions. In the backward step, it seeks the optimal weights that minimize the error between the predicted and target value (the answer).
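A minimal `MLPClassifier` sketch of the forward/backward training described above, on a synthetic nonlinear problem (not the backscatter data):

```python
from sklearn.datasets import make_moons
from sklearn.neural_network import MLPClassifier

# Two interleaved half-moons: a nonlinear boundary an MLP can learn.
X, y = make_moons(n_samples=400, noise=0.1, random_state=0)

# Two hidden layers; each iteration runs a forward pass and a backward
# (gradient) pass that adjusts the weights to reduce the prediction error.
mlp = MLPClassifier(hidden_layer_sizes=(16, 16), max_iter=2000, random_state=0)
mlp.fit(X, y)
print(round(mlp.score(X, y), 2))
```

The hidden-layer sizes and iteration budget here are illustrative choices, not the configuration used in the paper.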

### **3. Weather Data**

Three types of weather data were obtained from the Korea Meteorological Administration for this study [24]. This section describes the details of each dataset.

### *3.1. Backscatter Data from Lidar Ceilometer*

Backscatter data were collected by a lidar ceilometer installed in Boseong, South Korea, from 1 January 2015 to 31 May 2016. The CL51 [25] is a ceilometer manufactured by Vaisala; it provides a backscatter profile and can detect clouds up to 13 km, twice the range of the previous model, the CL31. Table 1 gives basic information on the backscatter data; missing backscatter data were not used. The measurable range of the CL51 is 15 km and the vertical resolution is 10 m; therefore, 15,000 backscatter values are recorded in each observation. However, only the bottom 450 values were used, considering that the planetary boundary layer is generally formed within 3 km. The CL51 provided input to our scheme by calculating the cloud height and volume using its *sky-condition algorithm*, which constructs an image of the entire sky based only on the ceilometer measurements (raw backscatter data) from a single point; no further details of this algorithm have been released. The mechanical characteristics of the CL51 ceilometer are outlined in Table 2.


**Table 1.** Information on the collected backscatter data.


**Table 2.** Specification of lidar ceilometer CL51.

The BL-View software [26] estimates the PBLH from two types of ceilometer data: levels 2 and 3. In level 2, backscatter data are stored at intervals of 16 s after post-processing (a cloud-and-rain filter, moving averaging, thresholding, and removal of abnormal values). In level 3, the PBLHs calculated from the level-2 data are stored. As the raw backscatter data and the level-2 BL-View data have different measurement intervals, the records nearest in time were matched and used.

For the raw backscatter data, the denoising method of [27] was applied: noise was eliminated through linear interpolation followed by a denoising autoencoder [28]. Considering that the backscatter signals from aerosol particles are mostly similar, relatively large backscatter spikes were removed through linear interpolation. The linearly interpolated backscatter data were then used as input to the denoising autoencoder, with the moving average of the backscatter data serving as its target. We used the denoised backscatter data in our experiments. More details about the backscatter data, the denoising, and the weather phenomena are given in Appendix A.
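A simplified sketch of this two-stage idea (spike removal by linear interpolation, then smoothing) is shown below; the outlier threshold and window size are assumptions, and a simple moving average stands in for the trained denoising autoencoder.

```python
import numpy as np

def despike(profile, z=3.0):
    """Replace unusually large backscatter values by linear interpolation
    (the z-score threshold is an illustrative assumption)."""
    x = profile.astype(float)
    bad = x > x.mean() + z * x.std()
    idx = np.arange(len(x))
    x[bad] = np.interp(idx[bad], idx[~bad], x[~bad])
    return x

def moving_average(profile, k=5):
    """Moving average used here as a stand-in for the autoencoder stage."""
    kernel = np.ones(k) / k
    return np.convolve(profile, kernel, mode="same")

# Hypothetical 450-value profile with one artificial spike.
noisy = np.abs(np.random.default_rng(1).normal(100.0, 10.0, 450))
noisy[200] = 5000.0
clean = moving_average(despike(noisy))
```

The spike at index 200 is replaced by a value interpolated from its neighbours before smoothing.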

### *3.2. Data from Automatic Weather Station*

This study used the data collected from 1 January 2015 to 31 May 2016 from an automatic weather station (AWS) installed in Boseong, South Korea. An AWS is a device that enables automatic weather observation. Observation elements include temperature, accumulated precipitation, precipitation sensing, wind direction, wind speed, relative humidity, and sea-level air pressure. In this study, AWS data with a 1 h observation interval were used; the collected information is listed in Table 3. The installation information of the AWS in Boseong is outlined in Table 4.


**Table 3.** Information on the collected AWS data.

**Table 4.** Information on the used AWS.


As the observation interval of AWS data is different from that of backscatter data, the data of the nearest time slot based on the backscatter data were matched and used. The used elements included precipitation sensing, accumulated precipitation, relative humidity, and sea-level air pressure.
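Nearest-time-slot matching of this kind can be done, for example, with pandas `merge_asof`; the timestamps and column names below are hypothetical.

```python
import pandas as pd

# Hypothetical timestamps: ceilometer every ~30 s, AWS hourly.
ceil = pd.DataFrame({
    "time": pd.to_datetime(["2015-01-01 00:00:10", "2015-01-01 00:29:40",
                            "2015-01-01 01:00:05"]),
    "backscatter_10m": [1.2, 1.4, 1.1],
})
aws = pd.DataFrame({
    "time": pd.to_datetime(["2015-01-01 00:00", "2015-01-01 01:00"]),
    "humidity": [85.0, 90.0],
})

# Match each backscatter profile with the AWS record nearest in time.
merged = pd.merge_asof(ceil, aws, on="time", direction="nearest")
print(merged["humidity"].tolist())   # [85.0, 85.0, 90.0]
```

Both inputs must be sorted by the key column; `direction="nearest"` picks the closest record on either side.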

### *3.3. Data from Visibility Sensor*

Visibility data were collected by a PWD20 [29] installed in Boseong, South Korea (see Table 5). The PWD20, manufactured by Vaisala, is a device used to observe the MOR (meteorological optical range) and the current weather condition. As its observation range is 10–20,000 m and its resolution is 1 m, it allows determination of long-range visibility. The PWD20 can be fixed to various types of towers because the device is short, compact, and lightweight. The mechanical properties of the PWD20 are outlined in Table 6.


**Table 5.** Information on the collected visibility data.



Fog reduces visibility below 1000 m and occurs at a relative humidity near 100% [30]. The visibility sensor data were used to determine fog presence (low visibility). The criteria for fog were visibility of 1000 m or lower for 20 min, relative humidity of 90% or higher, and no precipitation. The AWS data were used for precipitation sensing and relative humidity. As the ceilometer backscatter, AWS, and visibility sensor data have different observation intervals, the records nearest in time were matched, using the ceilometer backscatter data as the reference.
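The fog criteria can be expressed as a labeling rule over time-matched records; the sketch below uses hypothetical one-minute data and a rolling 20 min window.

```python
import pandas as pd

# Hypothetical one-minute samples; humidity and precipitation are assumed
# to be already matched from the AWS to the same time slots.
df = pd.DataFrame({
    "visibility_m": [800] * 25 + [5000] * 5,
    "humidity_pct": [95] * 30,
    "precip": [0] * 30,
}, index=pd.date_range("2015-01-01", periods=30, freq="min"))

low = df["visibility_m"] <= 1000
# Fog requires low visibility sustained over a 20-minute window,
# relative humidity of at least 90%, and no precipitation.
sustained = low.astype(int).rolling("20min").min().eq(1)
df["fog"] = sustained & (df["humidity_pct"] >= 90) & (df["precip"] == 0)
```

In this toy series, the fog flag switches off as soon as visibility rises above 1000 m.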

### **4. Weather Detection**

In this section, we use the denoised backscatter data, cloud volume, and cloud height as training data, and describe how to detect weather phenomena using three machine learning (ML) models: random forest, SVM, and ANN.

### *4.1. Data Analysis*

The observational data of the AWS range from 1 January 2015 to 31 December 2016. The performance of learning algorithms may degrade if they are not provided with enough training data. For precipitation, the presence or absence of precipitation in the hourly AWS data was used. In general, the lower the visibility sensor value, the higher the probability of fog. Hence, the visibility sensor data were categorized into 1000 m or below, between 1000 m and 20,000 m, and 20,000 m or higher. Table 7 shows that precipitation accounts for 6.38% of all data (1120 cases), and Table 8 shows that visibility sensor data of 1000 m or below account for 1.16% of all data (11,082 cases).

**Table 7.** Statistics of hourly AWS data related to precipitation.



**Table 8.** Statistics of visibility data.

The precipitation data have a value of 0 or 1, indicating only presence or absence, and thus follow a binomial distribution. The visibility sensor values, however, range from 0 m to 20,000 m; their distribution is shown as a histogram in Figure 1. The MOR in the figure denotes the visibility sensor value, and the values at the bottom of the figure are the mean (μ) and standard deviation (σ) of all visibility sensor values. The line indicates a normal distribution, and the values on the *X* axis are obtained by adding or subtracting the standard deviation to or from the mean.

**Figure 1.** Histogram on visibility data.

The ceilometer backscatter data are recorded at irregular intervals of approximately 30 s, the AWS data at intervals of 1 h, and the visibility sensor data at irregular intervals of 1–3 min. The intervals of the observational data were adjusted to those of the backscatter data.

### *4.2. Training Data Generation*

The training data were generated using denoised backscatter data and weather phenomenon (presence/absence) data. For precipitation, the precipitation sensing value of AWS was used. For fog, AWS and visibility data were used to indicate whether fog occurred. As shown in Table 9, the backscatter coefficients of all heights were used for AWS hourly data.


**Table 9.** Field information on train data for weather phenomenon detection.

### *4.3. Under-Sampling*

The number of absent examples was much higher than that of present examples in the training data. Such highly imbalanced data can hinder the training process of machine learning algorithms, making the resulting prediction model rarely predict the present examples. Therefore, we used under-sampling to balance the training data as in [12,31].

The training data comprised data from the first day to the 15th day of each month, and the validation data were composed of data ranging from the 16th to the last day of each month. Random forest was used to find the optimal under-sampling ratio by varying the presence to absence ratio from 1:1 to 1:7. Note that under-sampling is applied only to the training data.
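Under-sampling at a given presence-to-absence ratio amounts to keeping all present examples and a random subset of absent ones; a sketch with hypothetical data (the model fitting and ratio search are omitted):

```python
import numpy as np

def undersample(X, y, ratio, rng):
    """Keep all 'present' examples (y == 1) and `ratio` times as many
    randomly chosen 'absent' examples (a 1:ratio presence-to-absence mix)."""
    pres = np.flatnonzero(y == 1)
    absn = np.flatnonzero(y == 0)
    keep = rng.choice(absn, size=min(len(absn), ratio * len(pres)),
                      replace=False)
    idx = np.concatenate([pres, keep])
    return X[idx], y[idx]

rng = np.random.default_rng(0)
X = rng.normal(size=(1000, 47))                 # hypothetical features
y = (rng.random(1000) < 0.06).astype(int)       # ~6% presence, as for rain
Xb, yb = undersample(X, y, ratio=2, rng=rng)
```

In the paper the ratio is varied from 1 to 7 and the value with the best validation F1 score (here 1:2) is kept; under-sampling is applied to the training split only.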

To validate the results, we calculated and compared the accuracy, precision, false alarm rate (FAR), recall (or probability of detection, POD), and F1 score, which are measures frequently used in machine learning and meteorology (see Table 10) [32]. Accuracy is the fraction of all data for which the observed and predicted values coincide ((*a* + *d*)/*n*). For both precipitation and fog detection, precision is the probability that a predicted occurrence is correct (*a*/(*a* + *b*)), the FAR is the number of false alarms over the total number of alarms (*b*/(*a* + *b*)), and recall (POD) is the fraction of actual occurrences that were correctly predicted (*a*/(*a* + *c*)). The F1 score is the harmonic mean of precision and recall (2 × (*precision* × *recall*)/(*precision* + *recall*)). In imbalanced classification, accuracy can be a misleading metric; therefore, the F1 score, which considers both precision and recall, is widely used as a major assessment criterion [33,34].


**Table 10.** Contingency table for prediction of a binary event. The numbers of occurrences in each category are denoted by *a*, *b*, *c*, and *d*.
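The measures above follow directly from the contingency-table entries *a*, *b*, *c*, and *d*; for example:

```python
def scores(a, b, c, d):
    """a: hits, b: false alarms, c: misses, d: correct negatives
    (the contingency-table entries of Table 10)."""
    n = a + b + c + d
    accuracy = (a + d) / n
    precision = a / (a + b)
    far = b / (a + b)
    recall = a / (a + c)          # probability of detection (POD)
    f1 = 2 * precision * recall / (precision + recall)
    return accuracy, precision, far, recall, f1

# Illustrative counts only.
acc, prec, far, pod, f1 = scores(a=30, b=20, c=70, d=880)
# acc = 0.91, prec = 0.6, far = 0.4, pod = 0.3, f1 = 0.4
```

Note how accuracy stays high (0.91) even though the rare event is mostly missed, which is why F1 is the primary criterion here.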

In Tables 11 and 12, the F1 score is highest when the under-sampling ratio is 1:2. At a ratio of 1:7, accuracy is high but precision is very low. High precision is desirable for detecting precipitation or low visibility, but a model tuned for precision alone tends to misestimate the frequency of the phenomenon. We therefore selected the ratio with the highest F1 score, at which precision and recall are balanced; in other words, we under-sampled both precipitation and fog (low visibility) at a ratio of 1:2.

**Table 11.** Results of precipitation detection according to the under-sampling ratio. The bold number is the best result.


**Table 12.** Results of low visibility detection according to the under-sampling ratio. The bold number is the best result.


We also compared our results with two versions of random prediction (see the two bottom rows of Tables 11 and 12): one method called "Rand" evenly predicts the presence or absence of weather phenomenon at random, and the other method called "W-rand" randomly predicts the presence or absence of weather phenomenon with the weights according to the probability of actual observation of weather phenomenon. We could clearly see that random forest with under-sampling is superior to random prediction with respect to F1 score.

### *4.4. Feature Selection*

A large number of input features significantly increases the computation time of machine learning algorithms and requires an enormous amount of training data to ensure sufficient training. There are 452 input features as shown in Table 9, and these need to be reduced. Figures 2 and 3 show the analyses of denoised backscatter data from 1 January 2015 to 31 May 2016 for precipitation and visibility sensor data, respectively. In Figure 2, 'True' indicates the case of precipitation, and 'False' indicates the case of non-precipitation. In Figure 3, 'True' indicates that the visibility sensor value is equal to or lower than 1000 m, and 'False' indicates that the visibility sensor value is higher than 1000 m. The line indicates the backscatter mean value according to height, and the colored part indicates the area of mean ± standard deviation. Above certain heights, it seems difficult to predict the weather phenomena using backscatter data.

**Figure 2.** Backscatter data with precipitation (True) and without precipitation (False).

**Figure 3.** Backscatter data with low visibility (True) or without low visibility (False).

Therefore, we do not need to use all heights to detect weather phenomena. Random forest was applied after dividing the total height of 4500 m into the ranges 10–300 m, 10–600 m, ..., and 10–4500 m, while maintaining the under-sampling ratio at 1:2.
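This height-range search can be sketched as slicing the feature matrix by height before fitting each forest; the data are hypothetical and the scoring call is left as a placeholder.

```python
import numpy as np

heights = np.arange(10, 4501, 10)          # the 450 backscatter heights (m)
X = np.random.default_rng(0).normal(size=(500, 450))   # hypothetical data

def up_to(X, heights, h_max):
    """Keep only backscatter features measured at or below h_max metres."""
    return X[:, heights <= h_max]

# Candidate ranges 10-300 m, 10-600 m, ..., 10-4500 m; in the paper each
# one is scored with a random forest at a 1:2 under-sampling ratio.
for h_max in range(300, 4501, 300):
    X_sub = up_to(X, heights, h_max)
    # f1 = evaluate_random_forest(X_sub, y)   # model fitting omitted here
```

The range with the best F1 score (10–4500 m for precipitation, 10–3300 m for visibility) is then kept.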

Tables 13 and 14 show that using all the height values ranging from 10 m to 4500 m for precipitation produced satisfactory results. In the case of visibility sensor data, using heights that ranged from 10 m to 3300 m yielded better results than using all the heights that ranged from 10 m to 4500 m. Therefore, the visibility sensor data could yield better results with smaller inputs.


**Table 13.** Results of precipitation detection according to feature selection. The bold number is the best result.

**Table 14.** Results of low visibility detection according to feature selection. The bold number is the best result.


In the case of precipitation, feature selection did not reduce the number of input features, and in the case of the visibility sensor, it reduced them only to 332, which is still too many. To train the SVM and ANN, we therefore did not use all heights of the backscatter data: the height interval of the input features was increased from 10 m to 100 m. In other words, we used the backscatter data at 10 m, 110 m, 210 m, ..., and 4410 m for precipitation, and at 10 m, 110 m, ..., and 3210 m for the visibility sensor. This left 47 input features for predicting precipitation and 35 for predicting fog.
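The arithmetic behind the 47 and 35 features can be checked by subsampling the 450 backscatter heights at 100 m intervals and, assuming Table 9's two extra inputs (cloud height and cloud volume), adding them back:

```python
import numpy as np

heights = np.arange(10, 4501, 10)          # 10, 20, ..., 4500 m
coarse_rain = np.arange(10, 4411, 100)     # 10, 110, ..., 4410 m
coarse_fog = np.arange(10, 3211, 100)      # 10, 110, ..., 3210 m

rain_cols = np.isin(heights, coarse_rain)  # 45 backscatter features
fog_cols = np.isin(heights, coarse_fog)    # 33 backscatter features

# Adding cloud height and cloud volume gives the quoted input counts.
print(int(rain_cols.sum()) + 2, int(fog_cols.sum()) + 2)   # 47 35
```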

With our final random forest model, trained with under-sampling and feature selection, we provide observations of representative cases for precipitation and fog in Appendix B.

Tables 15 and 16 show the results of SVM and ANN. ANN1 is an MLP with one hidden layer whose number of nodes is half that of the input layer. ANN2 is an MLP with two hidden layers, each with half as many nodes as the layer before it. For both precipitation and fog, random forest classified the weather phenomena best, yielding the highest F1 score.


**Table 15.** Results of precipitation detection according to other ML techniques.

**Table 16.** Results of low visibility detection according to other ML techniques.


### **5. Concluding Remarks**

In this study, we made the first attempt to detect weather phenomena using raw backscatter data obtained from a lidar ceilometer. For weather detection, various machine-learning techniques including under-sampling and feature selection were applied to the backscatter data. The AWS provided observational data for precipitation and the visibility data from PWD20 provided observational data for fog.

Our prediction results were not remarkably good; however, considering the difficulty of weather prediction/detection reported in the literature (e.g., precision and recall of about 0.5 and 0.3 for heavy rainfall prediction [13], and about 0.2 and 0.5 for lightning forecasting [31]), they show potential for precipitation detection (precision, recall, and F1 score of about 0.5, 0.2, and 0.3, respectively). Fog detection (in which precision, recall, and F1 score are all about 0.1) was found to be very difficult, although it was still better than random prediction.

In future work, we expect to improve the accuracy of planetary boundary layer height (PBLH) measurements by classifying backscatter data according to precipitation occurrences.

**Author Contributions:** Conceptualization, Y.-H.K. and S.-H.M.; methodology, Y.Y. and Y.-H.K.; validation, S.-H.M. and Y.Y.; formal analysis, Y.Y.; investigation, Y.Y. and S.-H.M.; resources, Y.-H.K.; data curation, S.-H.M.; writing—original draft preparation, Y.Y.; writing—review and editing, S.-H.M.; visualization, Y.-H.K.; supervision, Y.-H.K.; project administration, Y.-H.K.; funding acquisition, Y.-H.K. All authors have read and agreed to the published version of the manuscript.

**Funding:** This work was supported by the Technology Development for Supporting Weather Services, through the National Institute of Meteorological Sciences of Korea, in 2017. This research was also a part of the project titled 'Marine Oil Spill Risk Assessment and Development of Response Support System through Big Data Analysis', funded by the Ministry of Oceans and Fisheries, Korea.

**Acknowledgments:** The authors would like to thank Junghwan Lee and Yong Hee Lee for their valuable help in greatly improving this paper.

**Conflicts of Interest:** The authors declare no conflict of interest regarding the publication of this article.

### **Appendix A Details of the Used Backscatter Coefficients**

In this appendix, we provide details of the backscatter data used, in relation to denoising and weather phenomena, through some representative cases. Figure A1 shows an example of raw backscatter data and the corresponding denoised data observed at a single moment; the noise is successfully removed. Figure A2 shows an extended time-height plot of raw and denoised backscatter data observed over one day (18 March 2015, when the daily precipitation was 61 mm); Figure A1 can be understood as a cross-section of Figure A2 at one point in time. For Figure A2, we chose a day on which both precipitation and non-precipitation occurred while clouds were present throughout. In the right side of the figure, a gray box marks a period of continuous rain. We can find clear differences between each box boundary point and its adjacent one, but it does not seem easy to distinguish the two phenomena from the values alone.

**Figure A1.** Example data of backscatter coefficients at a moment ((**left**): raw backscatter data and (**right**): denoised backscatter data).

**Figure A2.** Example time-height plotting of backscatter coefficients on 18 March 2015, when both of precipitation and non-precipitation are mixed: a gray box means a precipitation period ((**left**): raw backscatter data and (**right**): denoised backscatter data).

Figure A3 shows an example of the denoised backscatter data of the CL51 and the observation range data of the PWD20 observed over one day (31 March 2015). For this figure, we chose a day on which both low visibility and not-low visibility occurred while clouds were present throughout. In the left side of the figure, a gray box marks a period of continuously low visibility (i.e., less than 1000 m). As in the precipitation case of Figure A2, it is hard to distinguish the two phenomena from the values alone; moreover, in the middle period it is not easy even to find a difference between the box boundary point and its adjacent one.

**Figure A3.** Example time-height plotting of backscatter coefficients on 31 March 2015, when both of low visibility and not-low visibility are mixed: a gray box means a low visibility period ((**left**): denoised backscatter data and (**right**): visibility data).

### **Appendix B Case Observation**

Figure A4 shows an example of detecting precipitation from backscatter data using random forest. In the case of the left side, the observed weather phenomenon at 15:09:36 (hh:mm:ss) on 21 January 2015 is precipitation and the predicted weather phenomenon is also precipitation. In the case of the right side, the observed weather phenomenon at 23:16:48 on 21 January 2015 is non-precipitation, but the predicted weather phenomenon is precipitation. The blue line indicates the mean backscatter value for each observation class, and the colored band spans one standard deviation around the mean. The distributions of backscatter data are generally similar regardless of precipitation. Since the input of our machine learning model is only backscatter data, it is natural for the model to output the same prediction for both cases with similar distributions. This supports the difficulty of predicting precipitation.

**Figure A4.** Example of backscatter data predicted as precipitation by machine learning ((**left**): actual precipitation and (**right**): actual non-precipitation).

Figure A5 shows an example of detecting non-precipitation from backscatter data using random forest. In the case of the left side, the observed weather phenomenon at 23:55:12 is non-precipitation and the predicted weather phenomenon is also non-precipitation. In the case of the right side, the observed weather phenomenon at 15:55:48 on 21 January 2015 is precipitation, but the predicted weather phenomenon is non-precipitation. Likewise, the distributions of backscatter data are very similar regardless of precipitation. As mentioned above, the machine learning model naturally outputs the same prediction for both cases. This also shows that it is difficult to predict non-precipitation.

**Figure A5.** Example of backscatter data predicted as non-precipitation through machine learning ((**left**): actual non-precipitation and (**right**): actual precipitation).

Figure A6 shows an example of detection from backscatter data using random forest when the visibility sensor value is equal to or less than 1000 m. In the case of the left side, the observed and predicted visibility sensor values at 01:01:46 are both equal to or lower than 1000 m. In the case of the right side, the observed visibility sensor value at 00:46:46 on 22 January 2015 is greater than 1000 m and lower than 20,000 m, but the predicted value is equal to or lower than 1000 m. It is not easy to find a clear difference between the two cases.

**Figure A6.** Example of backscatter data predicted as low visibility by machine learning ((**left**): actual low visibility and (**right**): actual not-low visibility).

Figure A7 shows an example of detection from backscatter data using random forest when the visibility sensor value is greater than 1000 m and lower than 20,000 m. In the case of the left side, the observed and predicted visibility sensor values at 00:35:23 on 22 January 2015 are both greater than 1000 m and lower than 20,000 m. In the case of the right side, the observed visibility sensor value at 00:00:00 on 22 January 2015 is 1000 m or lower, but the predicted value is greater than 1000 m and lower than 20,000 m. Analogously, we cannot see a distinct difference between the two cases. This suggests that it is very hard to detect fog using only backscatter data.

**Figure A7.** Example of backscatter data predicted as not-low visibility by machine learning ((**left**): actual not-low visibility and (**right**): actual low visibility).

### **References**


© 2020 by the authors. Licensee MDPI, Basel, Switzerland. This article is an open access article distributed under the terms and conditions of the Creative Commons Attribution (CC BY) license (http://creativecommons.org/licenses/by/4.0/).

### *Article*

## **A Machine Learning-Assisted Numerical Predictor for Compressive Strength of Geopolymer Concrete Based on Experimental Data and Sensitivity Analysis**

**An Thao Huynh 1, Quang Dang Nguyen 2, Qui Lieu Xuan 3,4, Bryan Magee 1, TaeChoong Chung 5, Kiet Tuan Tran 6 and Khoa Tan Nguyen 2,7,\***


Received: 8 October 2020; Accepted: 28 October 2020; Published: 31 October 2020

**Abstract:** Geopolymer concrete offers a favourable alternative to conventional Portland concrete due to its reduced embodied carbon dioxide (CO2) content. Engineering properties of geopolymer concrete, such as compressive strength, are commonly characterised based on experimental practices requiring large volumes of raw materials, time for sample preparation, and costly equipment. To help address this inefficiency, this study proposes machine learning-assisted numerical methods to predict compressive strength of fly ash-based geopolymer (FAGP) concrete. Methods assessed included artificial neural network (ANN), deep neural network (DNN), and deep residual network (ResNet), based on experimentally collected data. Performance of the proposed approaches was evaluated using various statistical measures including R-squared (R2), root mean square error (RMSE), and mean absolute percentage error (MAPE). Sensitivity analysis was carried out to identify effects of the following six input variables on the compressive strength of FAGP concrete: sodium hydroxide/sodium silicate ratio, fly ash/aggregate ratio, alkali activator/fly ash ratio, concentration of sodium hydroxide, curing time, and temperature. Fly ash/aggregate ratio was found to significantly affect compressive strength of FAGP concrete. Results obtained indicate that the proposed approaches offer reliable methods for FAGP design and optimisation. Of note was ResNet, which demonstrated the highest R<sup>2</sup> and lowest RMSE and MAPE values.

**Keywords:** geopolymer concrete; artificial neural network; machine learning; deep neural network; ResNet; compressive strength; fly ash

### **1. Introduction**

Emission of carbon dioxide caused by various sectors, including construction, industrial processes, transport, residential, and agriculture, has emerged as a severe problem that dramatically affects global climate change. Calcining limestone in Portland cement production represents 8% of global anthropogenic CO2 emission [1]. Global production of cement increased rapidly from 1.5 billion tonnes in 1998 [2] to 4.1 billion tonnes in 2018 [3], which has significantly impacted emissions linked to the construction sector. This justifies the need for more sustainable alternatives sourced from industrial by-products/wastes with minimal embodied carbon, offering a balance of technical, environmental, and economic benefits.

In this context, geopolymer concrete using industrial by-products (e.g., fly ash and ground-granulated blast-furnace slag) has been reported to reduce up to 80% of CO2 emission relative to conventional concrete [4]. Geopolymer products are synthesised through the reaction of alkali liquid with silica and alumina contained in aluminosilicate precursors. Depending upon local resources, solid aluminosilicate precursors can be sourced from industrial by-products such as fly ash, metakaolin, red mud, and waste glass [5–8]. According to previous studies [9–13], fly ash-based geopolymer (FAGP) concrete showed the ability to achieve high compressive strengths up to 68 MPa. To assess the compressive strength of concrete, universal compression testing machines are typically used to apply compression load on cylindrical or cube specimens at a prescribed rate (e.g., 20–50 psi/s or 0.14–0.35 MPa/s based on American Society for Testing and Materials (ASTM) C39/C39M-18 [14]). In addition to direct testing, non-destructive testing methods, such as ultrasonic pulse velocity and rebound hammer, are also used to predict the compressive strength of concrete products [15–17]. However, these experimental methods rely heavily on costly equipment and time-consuming preparation of specimens.

As such, artificial intelligence approaches including artificial neural network, adaptive neuro fuzzy inference, and deep learning have been employed to predict the mechanical properties of FAGP concrete by several researchers [18–21]. Inspired by the biological neural system, the artificial neural network (ANN) algorithm with three neuron layers has been widely applied in different research fields such as civil engineering, biochemistry, pharmaceutics, and biology owing to its ability to learn complex relationships among values in its training patterns. Dao et al. [19] investigated the compressive strength of FAGP concrete consisting of steel slag aggregates using ANN and neuro fuzzy inference approaches. Mean absolute error, R-squared, and root mean square error (RMSE) were employed to evaluate the performance of the proposed approaches. Three input parameters including sodium hydroxide (NaOH) concentration, alkali activator/fly ash ratio, and sodium hydroxide-to-sodium silicate ratio were used to predict the compressive strength of FAGP concrete. Results obtained from the ANN were in substantial agreement with experiment data. Sensitivity analysis for ANN and adaptive neuro fuzzy inference were adopted in a study by the same authors [18] to evaluate the impact of each input factor including the mass of fly ash, sodium silicate (Na2SiO3), NaOH, and water on the accuracy of the proposed models. The two approaches effectively predicted the compressive strength of the geopolymer concrete using only three input parameters in the mixture proportion. Curing condition factors were neglected in these studies even though they undoubtedly play essential roles in the compressive strength of geopolymer concrete [22–25]. In addition to mixture proportion factors, curing time and temperature values were added in the training dataset for compressive strength prediction of FAGP concrete using ANN in a study by Ling et al. [21]. 
Results from ANN modelling methods showed good agreement with those obtained from experiments. The authors concluded that compressive strength of geopolymer concrete was profoundly influenced by mixture proportion and curing conditions. Performance of deep neural network (DNN) and deep residual network (ResNet) approaches in predicting the compressive strength of FAGC was investigated in a study by Nguyen et al. [20]. With high rate of recognition accuracy within a complex network, the ResNet model showed better performance than the DNN models, with two main forward and backward passes; therefore, it has been used in several advanced engineering problems [26,27]. ResNet and DNN approaches were also employed to predict compressive strength of conventional and foamed concrete in the studies by Jang et al. [28] and Nguyen et al. [29], respectively. Against this background, current solutions to predict compressive strength of FAGP concrete have not been dealt with in depth within existing literature. Although various machine learning approaches including ANN and

DNN have been separately introduced in several studies [19,21] as numerical predictors for FAGP strength, a thorough search of relevant published literature yielded a mere presence of the ResNet approach in FAGP property prediction. The lack of studies on impacts of input parameters (e.g., mix proportion ratios, NaOH concentration, and curing conditions) on geopolymer strength indicates possible improvements for upcoming research. More comprehensive research needs to be carried out to investigate the effectiveness of various machine learning methods in predicting compressive strength of FAGP concrete, considering a wider variety of input parameters and sensitivity analysis.

As such, this study aims to offer advancements to the existing literature by employing ANN, DNN, and ResNet approaches integrated with sensitivity analysis to predict the compressive strength of FAGP concrete. These models were trained through 263 pairs of input/target values obtained from experiments. Performance of FAGP strength prediction of the three proposed approaches was investigated in two phases. In the first phase, the models were trained and validated using randomly shuffled datasets. Additional training and assessment under K-fold cross validation schemes were then carried out to confirm the results obtained from the first phase. Impacts of six input parameters (including NaOH/Na2SiO3 ratio, fly ash/aggregate ratio, alkali liquid/fly ash ratio, NaOH concentration, curing time, and temperature) on prediction models were investigated using sensitivity analysis. Outcomes from sensitivity analysis are expected to identify the critical input parameters in FAGP strength prediction and control them carefully during geopolymer production. Three measures including R-squared (R2), root mean square error (RMSE), and mean absolute percentage error (MAPE) were employed to evaluate the accuracy of the proposed machine learning techniques.
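The three evaluation measures can be written directly from their definitions; the strength values below are illustrative only, not experimental data from the study.

```python
import numpy as np

def r2(y, p):
    """Coefficient of determination: 1 - SS_res / SS_tot."""
    ss_res = np.sum((y - p) ** 2)
    ss_tot = np.sum((y - np.mean(y)) ** 2)
    return float(1.0 - ss_res / ss_tot)

def rmse(y, p):
    """Root mean square error."""
    return float(np.sqrt(np.mean((y - p) ** 2)))

def mape(y, p):
    """Mean absolute percentage error (%)."""
    return float(np.mean(np.abs((y - p) / y)) * 100.0)

y = np.array([40.0, 50.0, 60.0])   # measured strengths (MPa, illustrative)
p = np.array([42.0, 49.0, 57.0])   # model predictions
# r2(y, p) = 0.93, rmse(y, p) ~ 2.16 MPa, mape(y, p) = 4.0%
```

A better model drives R2 toward 1 and RMSE/MAPE toward 0, which is the criterion used to rank ANN, DNN, and ResNet here.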

### **2. Machine Learning Approaches**

### *2.1. Artificial Neural Network (ANN)*

Inspired by the biological neuron system, ANN is based on a suite of mutually connected units, known as perceptrons, which replicate the functions of neurons in the human brain. ANN is one of the main models used in machine learning where its structure is formed by three layers of neurons including input, hidden, and output layers. Independent variables enter the system through the input layer and are processed in the hidden layer, while predicted values are generated in the output layer. Figure 1 presents the basic concept of ANN.

**Figure 1.** The construction of the artificial neural network (ANN) [20].

### *2.2. Deep Neural Network (DNN)*

DNN consists of more layers and neurons than ANN, leading to its ability to learn functions with a high degree of complexity. DNN possesses a powerful representational ability of input data and can reduce over-fitting issues in regression performance [30]. With powerful representational ability, DNN is able to achieve high accuracy in various tasks [31]. A typical DNN network structure is presented in Figure 2, including two main forward and backward phases.

**Figure 2.** Deep neural network (DNN) with two hidden layers [20].

### *2.3. Deep Residual Network (ResNet)*

ResNet was developed to overcome a limitation in training deep networks where training errors can increase as the number of layers increases [20]. Owing to modified architectures, ResNet models have been empirically confirmed to enhance learnability of neural networks with less error on defined tasks using a limited number of layers [32]. ResNet consists of residual blocks with shortcut connections as shown in Figure 3, where the formulation H(x) is the desired mapping output of a specific layer and x is the input data. Given the presence of shortcut connections, gradient-based optimisation algorithms work effectively under ResNet-based architectures and improve the learnability of weight layers representing the function F(x) [33].
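A minimal NumPy sketch of one residual block follows, assuming two weight layers with a ReLU between them; the identity weights in the example are illustrative only:

```python
import numpy as np

def relu(z):
    return np.maximum(0.0, z)

def residual_block(x, W1, W2):
    """Compute H(x) = F(x) + x: F is the residual mapping learned by
    two weight layers, and the shortcut adds the input back."""
    f = relu(x @ W1) @ W2   # F(x), the residual the weight layers learn
    return relu(f + x)      # shortcut connection

# With identity weights, F(x) reproduces x, so H(x) = 2x here
x = np.array([1.0, 2.0, 3.0])
print(residual_block(x, np.eye(3), np.eye(3)))  # → [2. 4. 6.]
```

The shortcut means the layers only need to learn the residual F(x) = H(x) − x, which is what keeps gradients well-behaved in deep stacks.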

**Figure 3.** A block in a deep residual network [20].

### **3. Experimental Programme**

### *3.1. Materials and Mixing Process*

Constituent materials of the FAGP concretes considered were fly ash, coarse and fine aggregates, alkali activator, and water. Low-calcium fly ash (class F) with a density of 2500 kg/m³ was used as the main aluminosilicate precursor. The chemical composition of the fly ash used is presented in Table 1, which conforms to the requirements of ASTM 618 [ASTM]. The FAGP concrete mix designs and mixing processes were based on a previous study by Nguyen et al. [20]. Geopolymer mix designs were formulated based on various binder and aggregate contents, concentrations of sodium hydroxide, and curing conditions. The ratio of fly ash mass to total aggregate mass (fly ash/aggregate) varied from 0.13 to 0.37. Densities of the coarse and fine aggregates were 2700 kg/m³ and 2650 kg/m³, respectively.


**Table 1.** Chemical compositions of fly ash class F.

Sodium silicate solution consisting of 36% Na2O and 38% SiO2 by mass was mixed with sodium hydroxide over a wide range of concentrations (4 M, 8 M, 11 M, 12 M, 15 M, and 18 M) to prepare the alkali liquid (AL). The NaOH/Na2SiO3 and AL/fly ash ratios ranged from 0.4 to 2.5 and from 0.3 to 0.7, respectively.

Fly ash and aggregates were mixed together at a slow speed for about three minutes. The alkali solution was then added and mixed for a further four minutes before casting. Fresh FAGP concrete was cast in standard cylinder moulds (100 mm diameter, 200 mm high), de-moulded after 24 h, and then cured in an oven at temperatures of 40, 60, 80, 90, 100, and 120 °C for 2, 4, 6, 8, 10, and 12 h. The processing and testing procedure is represented in Figure 4.

**Figure 4.** Schematic illustration of experimental works.

### *3.2. Data Preparation for Machine Learning Approaches*

According to previous studies [10,22,34], FAGP concrete properties depend on constituent material proportioning, the concentration of sodium hydroxide (CM), and curing conditions. In this study, a total of 263 pairs of input/target values generated from different geopolymer mix proportions, NaOH concentrations, and curing conditions were designed to provide the data for running the machine learning-based models. In these models, the six input variables considered to estimate the compressive strength of FAGP concrete were: NaOH/Na2SiO3, fly ash/aggregate, AL/fly ash, CM, curing time, and curing temperature. For compressive strength measurement, FAGP concrete cylinders were subjected to axial compression at a loading rate of up to 0.35 MPa/s according to ASTM C39/C 39M-18 [14] after seven days. At least three specimens were tested for each mix design of FAGP concrete to obtain the mean value of the compressive strength. The test data from the experimental works are given in Table 2.


**Table 2.** Statistical parameters of fly ash-based geopolymer (FAGP) concrete used in the training dataset.

### **4. Research Methodology**

In this study, 263 datasets (each comprising six inputs and one output) were used to train and validate the ANN, DNN, and ResNet models. In terms of inputs, each dataset comprised a unique combination of the six mix design values considered, as summarised in Table 2. The output for each dataset was the corresponding average 7-day compressive strength result obtained from experimental testing. The range of strength values recorded for the 263 combinations considered was 5.55–67.86 MPa.

A data division scheme was applied to reduce the possibility of error and improve the reliability of the predicted results. About 90% of the datasets (235) were randomly selected from the original data collection to train the network, while the remaining 28 datasets were withheld as a validation set to confirm the accuracy of the trained network. The structures of the three machine learning approaches (ANN, DNN, and ResNet) are presented in the schematic flowchart in Figure 5.
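This 90/10 division can be sketched as a simple random permutation of indices; the seed below is an arbitrary assumption, not from the study:

```python
import numpy as np

# Random 90/10 division of the 263 experimental datasets into
# training and validation indices (seed chosen arbitrarily).
rng = np.random.default_rng(42)
idx = rng.permutation(263)
train_idx, val_idx = idx[:235], idx[235:]

print(len(train_idx), len(val_idx))  # → 235 28
```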

FAGP concrete compressive strength was predicted by employing ANN, DNN, and ResNet architectures comprising weight, normalisation, and activation layers in regression tasks. For comparative purposes, the DNN and ResNet models consisted of the same number of nodes, with 128 nodes in Weight Layer 1 and 256 nodes in Weight Layer 2. A layer with 256 nodes, known as Weight Layer 3, was included to enable an additional operation at the end of the ResNet implementation. The ANN model, with one weight layer, comprised 384 nodes. Adam [35], a stochastic gradient descent method, was used as the optimisation method to update the neural network coefficients, since it integrates advanced features from different optimisation algorithms, including AdaGrad and RMSProp. The layer normalisation method introduced by Ba et al. [36] was employed to ensure inputs to layers fell within specific ranges, since it exhibits more efficient training times in neural network architectures compared to traditional batch normalisation. Models were also trained without normalisation to validate the effectiveness of the model integrated with layer normalisation. During the training process, dropout with a keep probability of 0.2 was applied in the architectures to prevent overfitting. Table 3 presents details of the settings of the six architectures (architectures 1–6) implemented in this study.
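As a rough illustration of the two regularisation mechanisms mentioned above, layer normalisation and (inverted) dropout can be sketched in NumPy. This is a minimal sketch, not the study's implementation, and the reading of 0.2 as a keep probability follows the text:

```python
import numpy as np

def layer_norm(h, eps=1e-5):
    """Layer normalisation (Ba et al.): each sample is normalised by the
    mean and variance of its own activations, independent of the batch."""
    mu = h.mean(axis=-1, keepdims=True)
    var = h.var(axis=-1, keepdims=True)
    return (h - mu) / np.sqrt(var + eps)

def dropout(h, keep_prob, rng):
    """Inverted dropout: keep each unit with probability keep_prob and
    rescale so the expected activation is unchanged."""
    mask = rng.random(h.shape) < keep_prob
    return h * mask / keep_prob

h = np.array([[1.0, 2.0, 3.0, 4.0]])
print(layer_norm(h))                              # zero mean, unit variance per row
print(dropout(h, 0.2, np.random.default_rng(0)))  # most units zeroed out
```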

**Figure 5.** Schematic flowchart presenting the three machine learning approaches used to estimate FAGP concrete compressive strength.

**Table 3.** Details of the setting of six investigated architectures from artificial neural network (ANN), deep neural network (DNN) and deep residual network (ResNet).


Three statistical measures including R2, RMSE, and MAPE were applied to evaluate the accuracy of the proposed machine learning approaches under the K-fold cross validation scheme. These parameters provide insights into differences between original and estimated values. Higher R<sup>2</sup> value and/or lower MAPE and RMSE values indicate better prediction performance of machine learning approaches [19]. The three statistical measures were calculated using the following equations:

$$R^2 = \frac{\left(n\sum_{j=1}^{n} y_j y'_j - \sum_{j=1}^{n} y_j \sum_{j=1}^{n} y'_j\right)^2}{\left(n\sum_{j=1}^{n} y_j^2 - \left(\sum_{j=1}^{n} y_j\right)^2\right)\left(n\sum_{j=1}^{n} y'^{\,2}_j - \left(\sum_{j=1}^{n} y'_j\right)^2\right)},\tag{1}$$

$$MAPE = \frac{1}{n} \sum_{j=1}^{n} \left| \frac{y_j - y'_j}{y_j} \right| \times 100,\tag{2}$$

$$RMSE = \sqrt{\frac{1}{n} \sum_{j=1}^{n} \left( y_j - y'_j \right)^2},\tag{3}$$

where *y<sub>j</sub>* and *y'<sub>j</sub>* are the compressive strengths obtained from experiments and predictions, respectively, and *n* is the number of datasets.
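The three measures can be computed directly from their definitions; a minimal NumPy sketch with made-up strength values for illustration:

```python
import numpy as np

def r_squared(y, y_pred):
    """R^2 as the squared Pearson correlation (Equation (1))."""
    n = len(y)
    num = (n * np.sum(y * y_pred) - np.sum(y) * np.sum(y_pred)) ** 2
    den = ((n * np.sum(y_pred ** 2) - np.sum(y_pred) ** 2)
           * (n * np.sum(y ** 2) - np.sum(y) ** 2))
    return num / den

def mape(y, y_pred):
    """Mean absolute percentage error (Equation (2))."""
    return np.mean(np.abs((y - y_pred) / y)) * 100

def rmse(y, y_pred):
    """Root mean square error (Equation (3))."""
    return np.sqrt(np.mean((y - y_pred) ** 2))

# Invented experimental vs predicted strengths (MPa), for illustration
y = np.array([10.0, 20.0, 30.0, 40.0])
y_hat = np.array([12.0, 19.0, 33.0, 38.0])
print(r_squared(y, y_hat), mape(y, y_hat), rmse(y, y_hat))
```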

The K-fold cross validation method divides the data into *K* equal folds and then performs *K* independent training runs of the prediction model on (*K* − 1) folds while leaving the remaining fold out for validation. In this experiment, the common value *K* = 10 was used. The performance of the prediction model was judged by averaging the metric measurements (R2, MAPE, and RMSE) over the *K* training and evaluation iterations as follows:

$$M_{K\text{-}fold} = \frac{1}{K} \sum_{k=1}^{K} m_k,\tag{4}$$

where *M<sub>K-fold</sub>* denotes a general metric measurement when K-fold cross validation is applied, and *m<sub>k</sub>* is the metric measurement in fold *k* of the procedure.
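A minimal sketch of this scheme (Equation (4)) in NumPy follows; the toy mean-predictor scorer is a hypothetical stand-in for the trained networks:

```python
import numpy as np

def k_fold_metric(X, y, train_and_score, K=10, seed=0):
    """Average a metric over K folds: each fold is held out once for
    validation while the other K-1 folds train the model."""
    rng = np.random.default_rng(seed)
    folds = np.array_split(rng.permutation(len(y)), K)
    scores = []
    for k in range(K):
        val = folds[k]
        train = np.concatenate([folds[j] for j in range(K) if j != k])
        scores.append(train_and_score(X[train], y[train], X[val], y[val]))
    return float(np.mean(scores))

# Toy scorer standing in for a trained network: RMSE of predicting
# the training-set mean (purely illustrative)
def mean_predictor(X_tr, y_tr, X_val, y_val):
    return float(np.sqrt(np.mean((y_val - y_tr.mean()) ** 2)))

X = np.arange(50, dtype=float).reshape(-1, 1)
y = 2.0 * X.ravel()
print(k_fold_metric(X, y, mean_predictor, K=10))
```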

An important note is that the same training and validation sets in each fold were used to train and validate each model. A hypothesis test (e.g., a paired *t*-test) with a significance level α = 0.05 was then applied to the accuracy measurements of each model on the validation sets in the 10 folds to confirm the statistical significance of the results. The null hypothesis was that these measurements all belong to the same population (or the same model), suggesting there is no difference between the performance of the two evaluated models. From the *t*-test, a *p*-value less than the chosen significance level (α = 0.05) statistically confirms the advantage of one prediction model over the others (rejecting the null hypothesis), while a *p*-value greater than this significance level suggests that the observed difference may have occurred by chance (not rejecting the null hypothesis).
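Such a comparison can be sketched with SciPy's paired t-test; the per-fold RMSE values below are invented for illustration, not results from the study:

```python
import numpy as np
from scipy.stats import ttest_rel

# Hypothetical per-fold RMSE values for two models evaluated on the
# same 10 validation folds (invented numbers, illustration only)
rmse_resnet = np.array([2.1, 2.8, 3.0, 2.5, 2.4, 3.2, 2.6, 2.9, 2.7, 3.3])
rmse_ann    = np.array([4.5, 4.9, 5.1, 4.2, 4.8, 5.3, 4.4, 4.7, 5.0, 4.6])

t_stat, p_value = ttest_rel(rmse_resnet, rmse_ann)
# p < 0.05 rejects the null hypothesis: the per-fold errors of the two
# models differ by more than chance alone would explain
print(p_value < 0.05)  # → True
```

Pairing the samples by fold matters here: it removes the fold-to-fold variation that both models share, which an unpaired test would count as noise.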

### **5. Results and Discussion**

### *5.1. Estimative Performance of ANN, DNN, and ResNet Approaches*

In the first phase, six predictive models based on the three proposed machine learning approaches (ANN, DNN, and ResNet) were trained and validated using randomly shuffled datasets obtained from experimental works. The input variables in the dataset consisted of six parameters: mixture proportions (i.e., NaOH/Na2SiO3, fly ash/aggregate, AL/fly ash), NaOH concentration, and curing conditions. The compressive strength of the FAGP specimens was regarded as the output variable. The results from the first phase were intended to provide a shortlist of models for further testing with K-fold cross validation and the *t*-test method described in Section 4.

R2, RMSE, and MAPE values for the ANN (architectures 1 and 2), DNN (architectures 3 and 4), and ResNet (architectures 5 and 6) models are summarised in Table 4, with the bold numbers representing the best predictive model of each approach. As shown, architectures 1, 3, and 6 were found to be the best ANN, DNN, and ResNet models, respectively. Of the six architectures presented in Table 4, ResNet-based architecture 6 was the best model for determining FAGP concrete compressive strength, with the highest R2 (0.937) and the lowest RMSE and MAPE values (1.987 and 6.6, respectively). Apart from the ResNet models, ANN-based architecture 1 (R<sup>2</sup> = 0.889; RMSE = 4.711; MAPE = 14.06) and DNN-based architecture 3 (R<sup>2</sup> = 0.898; RMSE = 2.521; MAPE = 9.496) showed better predictive performance than the other models (architectures 2, 4, and 5). Based on these observations, ANN-based architecture 1, DNN-based architecture 3, and ResNet-based architecture 6 were selected for further investigation.


**Table 4.** Performance comparison of six architectures for FAGP compressive strength prediction in terms of R-squared (R2), root mean square error (RMSE) and mean absolute percentage error (MAPE).

In the second phase, further investigation into the performance of the proposed approaches was carried out through additional training and assessment under a 10-fold scheme with three architectures: 1, 3, and 6. Results of the statistical measures (R2, MAPE, and RMSE) for each fold and the average (Avg.) values with standard deviations are presented in Table 5. The same training and validation sets of each fold were applied to the three models 1, 3, and 6. As shown in this table, ResNet-based architecture 6 obtained the best strength prediction performance in terms of R<sup>2</sup> (0.934 ± 0.021), RMSE (2.750 ± 0.573), and MAPE (8.552 ± 1.333). A further paired *t*-test with α = 0.05 was applied to prove the statistical significance of this observation. As presented in Table 6, *p*-values from the comparisons of the ResNet model with the ANN and DNN models were lower than the chosen significance level (α = 0.05), providing statistical confirmation that the ResNet model outperformed the ANN and DNN models in terms of FAGP strength prediction.

Relationships between the experimental and predicted strength values from architectures 1, 3, and 6 are illustrated in Figure 6. As shown, compressive strength values predicted by all machine learning models were close to the actual values obtained from compression experiments, indicating that the proposed approaches were successfully trained to predict FAGP compressive strength. The ResNet model outperformed the other models, with the strongest relationship between actual and predicted values. This observation is confirmed in Figure 7, which presents the correlation coefficient (R) of the three approaches on the validation data. Minimal variation between actual and predicted values existed for the ANN, DNN, and ResNet models, albeit with the highest variance associated with the ANN model (architecture 1).

The relationship between iterations of the three best-performing architectures (1, 3, and 6) and validation RMSE is shown in Figure 8. The highest convergence speed was observed in the ResNet-based architecture 6 model, which required only 2000 iterations to reach a validation RMSE of 4.8 MPa. For the same RMSE, higher iteration numbers of 6000 and 148,000 were required for the DNN and ANN models, respectively. After convergence, sufficiently low RMSE values were observed in the ResNet and DNN models, indicating better performance than the ANN model. For instance, at the same iteration count of 152,000, the ANN model converged at an RMSE of 4.7 MPa while the ResNet and DNN models achieved lower RMSE values (2.1 MPa and 2.7 MPa, respectively).


**Table 5.** Accuracy measurements on validation sets for the proposed machine learning approaches (architectures 1, 3, and 6) under K-fold cross validation scheme.



**Figure 6.** Relationship between compressive strength values obtained from experiments (actual value) and machine learning approaches (predicted value): (**a**) training ANN; (**b**) validation ANN; (**c**) training DNN; (**d**) validation DNN; (**e**) training ResNet; (**f**) validation ResNet.

**Figure 7.** Correlation coefficients R of three proposed approaches: (**a**) training ANN; (**b**) validation ANN; (**c**) training DNN; (**d**) validation DNN; (**e**) training ResNet; (**f**) validation ResNet.

**Figure 8.** Relationship between validation RMSE and iterations.

Figure 9 presents the distribution of error rates at 5% increments for predicted results obtained from architectures 1, 3, and 6. It is noted that the majority of datasets (61%) from the ResNet model exhibited error levels less than 5%. Corresponding frequencies of errors less than 5% for the ANN and DNN models were significantly lower (approximately 21 and 35%, respectively). In terms of errors less than 20%, frequencies for the ANN, DNN, and ResNet models were 79, 89, and 89%, respectively. In terms of ranking, therefore, the ResNet model provided the best estimative performance, followed by the DNN and ANN models.

**Figure 9.** Error rate distribution of three proposed approaches: (**a**) validation ANN; (**b**) validation DNN; (**c**) validation ResNet.

### *5.2. Sensitivity Analysis*

Sensitivity analysis is commonly used to evaluate how input parameters affect the output variation derived by machine learning models [37]. As the best performing model, ResNet-based architecture 6 was exclusively selected for this analysis, which involved calculating FAGP concrete compressive strength by changing one input variable at a time while keeping the other five constant at their mean values. For example, to assess the importance of the NaOH/Na2SiO3 ratio, this value was varied from 0.4 to 2.5, while the fly ash/aggregate, AL/fly ash, NaOH concentration, curing time, and temperature values were kept constant at their mean values of 0.23, 0.5, 14 M, 8 h, and 85.6 °C, respectively. Data derived from this sensitivity analysis were then passed through the trained model to estimate compressive strength. For each parameter, a corresponding sensitivity analysis factor was given by the expressions:

$$I\_i = f\_{\max}(\mathbf{x}\_i) - f\_{\min}(\mathbf{x}\_i), \tag{5}$$

$$SA\_i = \frac{I\_i}{\sum\_i I\_i} \times 100,\tag{6}$$

where *f<sub>max</sub>*(*x<sub>i</sub>*) and *f<sub>min</sub>*(*x<sub>i</sub>*) are the maximum and minimum estimated compressive strengths relating to the input variable *x<sub>i</sub>*, with all other input parameters kept constant at their mean values.
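Equations (5) and (6) translate into a short loop over the inputs; the two-input linear surrogate model below is purely hypothetical, standing in for the trained ResNet:

```python
import numpy as np

def sensitivity_scores(model, x_mean, ranges, n_grid=50):
    """Equations (5)-(6): vary one input over its range while the others
    stay at their mean values, then normalise the output spans I_i."""
    spans = []
    for i, (lo, hi) in enumerate(ranges):
        outs = []
        for v in np.linspace(lo, hi, n_grid):
            x = x_mean.copy()
            x[i] = v
            outs.append(model(x))
        spans.append(max(outs) - min(outs))  # I_i = f_max(x_i) - f_min(x_i)
    spans = np.array(spans)
    return spans / spans.sum() * 100.0       # SA_i in percent

# Hypothetical linear surrogate with two inputs (illustration only)
model = lambda x: 10.0 + 3.0 * x[0] + 1.0 * x[1]
x_mean = np.array([1.0, 1.0])
ranges = [(0.0, 2.0), (0.0, 2.0)]
print(sensitivity_scores(model, x_mean, ranges))  # → [75. 25.]
```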

Figure 10 shows the results of this sensitivity analysis, from which a pronounced influence (35.5%) of the fly ash/aggregate ratio on estimated compressive strength can be seen. A similar effect was observed in the study by Joseph and Mathew [11], and can be explained by the fact that the internal void structure formed by fly ash and aggregates has a direct effect on FAGP compressive strength. Also shown in this figure are high sensitivity scores of 16.22%, 16.18%, and 14.93% for curing time, NaOH concentration, and curing temperature, respectively. This confirms the observations from previous studies [22,23] that curing conditions play a significant role in the compressive strength of FAGP concrete. As such, various factors such as mix proportions, sodium hydroxide concentration, and curing regimes should be thoroughly considered in the prediction of FAGP mechanical properties using machine learning approaches. In particular, based on these findings, it is recommended that the fly ash/aggregate ratio be carefully determined and controlled in geopolymer manufacturing processes owing to its pronounced effect on FAGP strength.

**Figure 10.** Sensitivity analysis parameters for the estimated compressive strength of FAGP concrete.

### **6. Conclusions**

This study employed three different machine learning approaches including ANN, DNN, and ResNet to predict compressive strength of fly ash-based geopolymer concrete. Six parameters of mix design and curing conditions (including NaOH/Na2SiO3, fly ash/aggregate, AL/Fly ash, concentration of sodium hydroxide, curing time, and temperature) and corresponding 7-day compressive strength results were used to generate 263 unique input/output pairs for model training purposes.

While the results indicated that all three machine learning approaches could predict FAGP concrete compressive strength with some degree of accuracy, the ResNet model was the most promising method, with the highest R2 (0.937) and the lowest RMSE (1.987) and MAPE (6.6) values. This observation was confirmed by additional training and assessment under the K-fold cross validation scheme and a paired *t*-test with α = 0.05, where the highest R2 (0.934 ± 0.021) and the lowest RMSE (2.750 ± 0.573) and MAPE (8.552 ± 1.333) were observed in the ResNet-based model. Sensitivity analysis performed for the ResNet model confirmed that the fly ash/aggregate ratio was the most dominant factor in predicting compressive strength, with a sensitivity analysis score of 35.5%. This was followed in order of importance by curing time (16.22%), NaOH concentration (16.18%), curing temperature (14.93%), NaOH/Na2SiO3 ratio (12.90%), and AL/fly ash ratio (4.22%). This analysis indicates the importance of considering a wide range of input parameters in the prediction of FAGP concrete compressive strength and controlling them carefully during the manufacturing process.

This study provides a detailed understanding of the performance of different machine learning approaches in strength prediction for FAGP concrete. The findings highlight the potential of the proposed machine learning approaches such as ResNet and DNN as effective tools not only to precisely predict mechanical properties of FAGP but also to develop mix designs for geopolymer concrete. In the next phase of work, consideration will be given to how ResNet and DNN models can be applied in FAGP manufacturing industries to predict optimised mix designs and curing regimes based on target compressive strength. This predictive ability will also be linked to mix design evaluations in terms of potential cost and environmental benefits prior to construction stages. In addition, the proposed machine learning approaches adopted a general training scheme for neural networks with standard input and output features, indicating promising potential for application to other regression problems in upcoming research works.

**Author Contributions:** Conceptualization, Q.D.N. and K.T.N.; Data curation, Q.D.N. and Q.L.X.; Funding acquisition, T.C.; Investigation, A.T.H., Q.D.N. and K.T.N.; Methodology, Q.D.N. and Q.L.X.; Project administration, K.T.N.; Resources, K.T.N.; Software, Q.D.N. and Q.L.X.; Supervision, B.M. and T.C.; Validation, Q.L.X. and T.C.; Visualization, K.T.N.; Writing—original draft, A.T.H. and K.T.N.; Writing—review & editing, A.T.H., Q.L.X., K.T.T. and B.M. All authors have read and agreed to the published version of the manuscript.

**Funding:** This research was funded by the Basic Science Research Program through the National Research Foundation of Korea (NRF-2020R1F1A1050014).

**Conflicts of Interest:** The authors declare no conflict of interest.

### **References**


**Publisher's Note:** MDPI stays neutral with regard to jurisdictional claims in published maps and institutional affiliations.

© 2020 by the authors. Licensee MDPI, Basel, Switzerland. This article is an open access article distributed under the terms and conditions of the Creative Commons Attribution (CC BY) license (http://creativecommons.org/licenses/by/4.0/).

## *Article* **Structural Vibration Tests: Use of Artificial Neural Networks for Live Prediction of Structural Stress**

### **Laura Wilmes 1, Raymond Olympio 2, Kristin M. de Payrebrune 1,\* and Markus Schatz <sup>2</sup>**


Received: 31 October 2020; Accepted: 26 November 2020; Published: 29 November 2020

**Abstract:** One of the ongoing tasks in space structure testing is the vibration test, in which a given structure is mounted onto a shaker and excited by a certain input load on a given frequency range, in order to reproduce the rigor of launch. These vibration tests need to be conducted in order to ensure that the devised structure meets the expected loads of its future application. However, the structure must not be overtested, to avoid any risk of damage. For this, the system's response to the testing loads, i.e., stresses and forces in the structure, must be monitored and predicted live during the test. In order to solve the issues associated with existing methods of live monitoring of the structure's response, this paper investigated the use of artificial neural networks (ANNs) to predict the system's responses during the test. Hence, a framework was developed with different use cases to compare various kinds of artificial neural networks and eventually identify the most promising one. Thus, the research conducted here provides a novel method for the live prediction of stresses, allowing failure to be evaluated for different types of material via yield criteria.

**Keywords:** mass operator; machine learning; structural stress; artificial neural network; live prediction; vibration test

### **1. Introduction**

In the space industry, the launch evidently dominates structural requirements. Therefore, in order to demonstrate that a structure will survive the launch, it is analyzed using the finite element method (FEM) and tested in vibration test facilities [1]. During a vibration test, accelerations are usually monitored in order to assess the loads that the structure is experiencing. Ideally, load cells are also installed at the interface of the structure to directly monitor the interface loads and compare them against the design loads. This, however, is not always possible because the use of load cells or strain gauges has many technical, operational, and financial drawbacks [2]. Consequently, the input of the vibration test, i.e., the excitation load of the structure under test, is often adjusted based on the measured accelerations rather than on loads or stresses [3].

One specific example concerns the case where loads need to be monitored at the interface of a subsystem that is part of a larger complex system, such as the James Webb Space Telescope (JWST) (Figure 1). The JWST is composed of several subsystems, each of which was tested separately before integration on the JWST. Figure 1 illustrates this problem; in particular, the bottom of the figure depicts all the different mechanical test campaigns in which the Near-Infrared Spectrograph (NIRSpec) has been involved. One can observe the NIRSpec optical assembly stand-alone test (OA), followed by the integrated science and instrument module test (ISIM) and the optical telescope assembly test (OTE + ISIM = OTIS). The last mechanical test prior to launch, the observatory test with JWST in folded configuration, was recently conducted. Besides these mechanical tests, many more test campaigns have been conducted; the figure highlights only the path followed by NIRSpec.

In the later testing phase, it is not possible to monitor the loads at the interface of the NIRSpec using load cells because there is no space for accommodating them [4]. Therefore, alternative approaches must be used.

**Figure 1.** The James Webb Space Telescope (JWST) (top left) in deployed configuration and Near-Infrared Spectrograph (NIRSpec) (top right). On bottom, JWST is disassembled towards NIRSpec [5].

One approach is to use the coil current from the shaker, since the applied load can be correlated with the shaker current. However, this approach can only be used to estimate the load in the excitation direction [6]. Strain gauges could be used to recover strains at interfaces and thus loads; however, they require careful calibration to provide a robust indirect measurement of the interface loads. A force measurement device provides six global interface forces or moments and local load cell forces during vibration testing, allowing the measurement of the local forces in three orthogonal directions [6]. However, such devices are not available in every test facility. Moreover, they are costly, occupy space that is not accounted for in the design, and often change the system's response, so they must be accounted for in all test prediction analyses [2,5].

The mass operator is a mathematical tool used to derive loads from measured accelerations [3]. It uses measured accelerations to calculate the interface loads or stresses representative of the physical state of a structure. A simple example of a mass operator approach is the sum of weighted accelerations (SWA), which is simply the application of Newton's second law F = *m*a, where a is a vector of measured accelerations, *m* is an equivalent mass matrix, and F is the vector of loads at chosen interfaces. Generally, the mass operator is created before a vibration test based on the results of finite element analyses. The actual computation of the mass matrix can be performed using one of several techniques. With these data, it is then possible during the vibration test to calculate interface loads based on the real-life accelerations measured by the sensors, with no additional hardware [2,3,5]. The authors of [3] provide an extensive review and comparison of mass operators, among them the fitted SWA, the frequency-dependent SWA, and the artificial neural network (ANN).
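The SWA relation F = *m*a amounts to a single matrix-vector product; the mass-matrix coefficients and accelerations below are invented for illustration:

```python
import numpy as np

# Sum of weighted accelerations, F = M a: two interface loads recovered
# from three acceleration sensors. All values are invented, for
# illustration only.
M = np.array([[1.2, 0.4, 0.1],
              [0.3, 0.9, 0.5]])   # equivalent mass matrix (kg)
a = np.array([0.5, 1.0, 2.0])     # measured accelerations (m/s^2)
F = M @ a                         # interface loads (N)
print(F)                          # recovered loads: 1.2 N and 2.05 N
```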

The fitted SWA is the most straightforward method to calculate mass coefficients. It consists of defining the mass coefficients as design variables of a minimization problem or a curve fitting problem [2,3] where the error *E* between the response calculated with the finite element method and the response provided by the mass operator is minimized as follows [3]:

$$E = \min \frac{1}{2} \left\| \mathbf{F}\_{FEM}(\omega) - \mathbf{F}\_{MOP}(\omega) \right\|^2. \tag{1}$$

However, this method works well only over small frequency ranges with few modes. To solve this issue, the authors of [3] considered a frequency-dependent SWA where the frequency range is split into subranges and a fitted SWA is created independently for each subrange. This method is however not well suited to closely spaced modes. In order to generalize the definition of the mass operator, the authors of [3] presented the use of an artificial neural network (ANN) in two different approaches. The ANN can be used to calculate mass coefficients based on input frequencies and accelerations; this is then a generalization of the frequency-dependent SWA. The ANN can also be used to directly provide the force from accelerations and frequency inputs; this is the most general definition of an operator that can convert measured accelerations into quantities of interest such as forces. Both approaches showed great potential for load estimations [3], but the latter approach has shown many drawbacks especially regarding the ability to generalize the mass operator as an ANN, if the tested structure differs from the analyzed one due to uncertainties such as boundary conditions or material properties. Furthermore, a mass operator as an ANN has not been investigated for the estimation of internal structural stresses.
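Because the SWA force is linear in the mass coefficients, minimizing the error in Equation (1) over sampled frequencies reduces to an ordinary least-squares problem. A NumPy sketch on synthetic, noiseless data (all values invented for illustration):

```python
import numpy as np

# Fitted SWA sketch: find mass coefficients m minimising
# || F_FEM - A m ||^2, where each row of A holds the FEM accelerations
# at one sampled frequency (synthetic data, illustration only).
rng = np.random.default_rng(0)
n_freq, n_acc = 40, 3
A = rng.standard_normal((n_freq, n_acc))   # accelerations per frequency
m_true = np.array([2.0, -1.0, 0.5])        # "true" mass coefficients
F_fem = A @ m_true                          # corresponding FEM forces

m_fit, *_ = np.linalg.lstsq(A, F_fem, rcond=None)
print(np.round(m_fit, 6))  # recovers [ 2. -1.  0.5] exactly (no noise)
```

With noisy or frequency-dependent data this single fit degrades, which is what motivates the frequency-dependent SWA and the ANN generalizations discussed in the text.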

In the last few years, ANNs have shown many successful applications in various domains, from monitoring structural health [7] to predicting tool life [8]. In [9], convolutional ANNs are used to predict vibrations. In civil structures, vibration-based structural damage can meanwhile be detected using methods based on machine learning [10]. This paper aims to contribute to this expanding field in structural mechanical engineering by extending the work done in [3] on the use of ANNs. First and foremost, the research was performed on a large-scale structure within an industrial environment. Second, in addition to standard responses such as the acceleration response, stresses were successfully predicted. Moreover, several types of neural networks were investigated that could be used to directly convert measured accelerations into structural stresses and hence enable the live prediction of stress during the vibration test. First, the general methods and considered ANNs are presented. Then, a use case is considered in order to test the different ANNs and build understanding of and confidence in their ability to predict interface loads or stresses in a robust way. Finally, the paper concludes with a discussion of the findings and the potential operational use of the proposed approaches.

#### **2. Materials and Methods**

In practice, mass operators are built using accelerations and stresses or loads. In this case, the accelerations, loads, and stresses were computed using the finite element method [11,12]. Once the mass operators were built and verified, they were deployed during the test to compute stresses and loads based on measured accelerations. In this paper, only ANNs are considered for creating mass operators and MATLAB 2018b (Mathworks, Natick, MA, USA) was used to create and train the proposed ANN.

A prediction of the structure's response is provided by the finite element analysis (FEA) data. The FE model in this specific case needed to address two main aspects. One aspect was the accurate modeling of NIRSpec's ceramic bench, as this is the instrument being designed by AIRBUS and is one of the most sensitive parts. The other aspect was the compliance of the surrounding structure, i.e., the structural elements onto which NIRSpec was mounted. The latter aspect was addressed by conducting a coupled load analysis (CLA), where NIRSpec was considered via a standard FE model and the remaining elements, for instance, the instrument module, the optical telescope, and spacecraft elements, were represented through stiffness-representative super-elements. From this CLA, only the forces and moments acting on NIRSpec were derived. To capture the full picture, phase information was considered as well, in order to depict the dynamical compliance of the overall structure. Next, the interface load input was condensed by only considering frequency steps in the vicinity of peaks in the direct response as well as in the cross-response. This condensation reduced the input size from roughly 42,000 frequency support points down to 700 (1.7%). This considerably reduced the computational effort on our detailed FE model in terms of stress calculation and post-processing, thereby allowing detailed investigation of mechanically interesting frequency ranges to address the first aspect of our FE approach, namely the detailed stress prediction on the ceramic bench.

However, to derive predictions from FE models, one has to assume damping. This assumption is the major contributor to potential discrepancies, together with overall system nonlinearities stemming from interface mechanics, secondary structures such as harnesses, implemented damper elements, and the like.

This implies that the real-life physical state of the structure, namely the interface forces and stresses, needs to be predicted live during the vibration test based on the actual response of the system. Only then will it be possible to adequately adapt the testing level to protect the structure. Unfortunately, only a limited set of data about the state of the structure is available, such as the measured accelerations at discrete locations on the structure [2]. From these accelerations, the stresses or forces acting in the tested structure need to be derived using a dedicated method, such as mass operators or ANNs, which is the subject of this investigation. Any method must meet the following requirements:


Artificial neural networks (ANNs) mimic the human brain in its mechanisms to transfer data from one neuron to another (see Figure 2). They consist of a connection of different layers, where each layer has a defined number of neurons. A neuron is similar to a computing block defined by an activation function, a set of weights and biases, an input, and an output. For more complex problems, a number of hidden layers can be inserted. Data are propagated through the ANN, and the output of each layer represents the input of the next layer. The input to an ANN usually comprises the features and the targets. The feature data are used to predict the target data. In the case of mass operators, the features are the accelerations, while the targets are the stresses. Such an ANN architecture can be described as a feedforward neural network [13]. If $p$ is considered to be the input to a neuron and $b$ the neuron's bias, then the output of that neuron is $a = f(w \cdot p + b)$, where $f$ represents the neuron's activation function and $w$ is a weighting factor. While $f$ is chosen with regard to the problem to be solved, $w$ and $b$ are both parameters that are calculated based on a learning rule during training [13]. During training, the network's neurons are first initialized, i.e., a random set of weights and biases is attributed to each neuron, and an activation function is assigned to each neuron in the layer. Then, the training data are forward-propagated through the network; each neuron applies its initial weights, biases, and activation function to the input and produces an output, which is further propagated until the data reach the output layer. Afterward, an error function $E$ is evaluated, usually the mean squared error (MSE) between the calculated outputs $Y_i$ and the target values $T_i$:

$$E = \mathrm{MSE} = \frac{1}{n} \sum_{i=1}^{n} \left( T_i - Y_i \right)^2, \quad \text{with } n = \text{number of data points.} \tag{2}$$
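As an illustration, the neuron output $a = f(w \cdot p + b)$ and the MSE of Equation (2) can be sketched as follows. This is a minimal NumPy sketch, not the MATLAB implementation used in the study; the tanh activation is assumed here for illustration.

```python
import numpy as np

def neuron(p, w, b, f=np.tanh):
    # a = f(w . p + b): weighted input plus bias, passed through the activation f
    return f(np.dot(w, p) + b)

def mse(targets, outputs):
    # Equation (2): E = (1/n) * sum_i (T_i - Y_i)^2
    t, y = np.asarray(targets, dtype=float), np.asarray(outputs, dtype=float)
    return np.mean((t - y) ** 2)

# Example: one neuron with two inputs
p = np.array([0.5, -0.2])
w = np.array([0.8, 0.3])
a = neuron(p, w, b=0.1)      # tanh(0.8*0.5 + 0.3*(-0.2) + 0.1) = tanh(0.44)
print(a, mse([1.0, 0.0], [0.9, 0.1]))
```

During training, the MSE computed this way is the quantity that backpropagation minimizes by adjusting $w$ and $b$.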

**Figure 2.** Example of the artificial neural network (ANN) structure with 3 neurons in the input layer, 5 hidden neurons in 1 hidden layer, and 1 output layer and the connectivity between neurons.

Finally, the error is back-propagated through the network in order to identify the neurons that are responsible for the error. The latter are then adapted to minimize the error; specifically, their weights and biases are altered, while the connections of the neurons producing a low error are reinforced in this process [14].

A recurrent ANN is capable of exhibiting dynamic behavior, where the output of one layer can also be used as the input for a preceding layer. This makes it possible for the neural network to create a temporary memory and process sequences of inputs [15]. This is particularly relevant as vibration tests are performed using a frequency sweep where, for example, the frequency increases with time.

In this study, four different neural network models are compared to each other:


• A nonlinear autoregressive exogenous (NARX) model: in addition to the external feature sequence $u_t, u_{t-1}, u_{t-2}, u_{t-3}, \ldots$, the targets $y_t$ of the network are also used as features, with a delayed version of them, $y_{t-1}, y_{t-2}, y_{t-3}, \ldots$, fed back into a feedforward network, according to [16], by $y_t = f(y_{t-1}, y_{t-2}, y_{t-3}, \ldots, u_t, u_{t-1}, u_{t-2}, u_{t-3}, \ldots)$. While the benefit of such an ANN is its memory of past values, the disadvantage of the NARX model is that each time step $t$ of the sequence is treated as an independent layer. This can lead to an extremely deep ANN, resulting in an increase in computational time.

• A recurrent ANN with a bidirectional long short-term memory layer (biLSTM): a recurrent ANN with a biLSTM layer to depict the sequential nature of the input data, taking into account both the previous as well as the following time step for every prediction. The biLSTM layer is built around a cell state and three different gates, namely the input, the output, and the forget gate. This structure gives an ANN with an LSTM layer a working memory. The prefix "bi" comes from the fact that it is able to use data from prior as well as following time steps. The input gate determines how much of a new value is used as input into the cell, the forget gate determines how much of the cell state is to be forgotten, and the output gate determines how much of the cell state is exposed as the output passed to the next cell. These elements are combined through several functions as well as matrix operations. More information regarding the mechanisms of biLSTM layers can be found in [17].
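The gate mechanics described above can be sketched as a single LSTM cell step. This is a minimal NumPy illustration, not the MATLAB biLSTM layer used in the study; the stacked parameter matrices `W`, `U`, and `b` and their random initialization are illustrative assumptions.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def lstm_step(x, h, c, W, U, b):
    """One LSTM cell step; W, U, b stack the parameters of the input (i),
    forget (f), cell candidate (g), and output (o) gates."""
    n = h.size
    z = W @ x + U @ h + b                # (4n,) stacked pre-activations
    i = sigmoid(z[0:n])                  # input gate: how much new value enters the cell
    f = sigmoid(z[n:2 * n])              # forget gate: how much of the cell state is kept
    g = np.tanh(z[2 * n:3 * n])          # candidate cell update
    o = sigmoid(z[3 * n:4 * n])          # output gate: how much of the state is exposed
    c_new = f * c + i * g                # updated cell state
    h_new = o * np.tanh(c_new)           # hidden state passed on to the next step
    return h_new, c_new

rng = np.random.default_rng(0)
n_in, n_hid = 3, 4
W = rng.normal(size=(4 * n_hid, n_in))
U = rng.normal(size=(4 * n_hid, n_hid))
b = np.zeros(4 * n_hid)
h = c = np.zeros(n_hid)
# A bidirectional layer runs one such pass forward and one backward in time
# and concatenates the two hidden-state sequences.
for x in rng.normal(size=(5, n_in)):     # a five-step input sequence
    h, c = lstm_step(x, h, c, W, U, b)
print(h.shape)  # (4,)
```

The forget-gate bias being initialized with ones, as mentioned for the biLSTM in Section 3, corresponds to setting the first quarter of `b`'s forget slice to 1, which keeps the cell state by default early in training.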

**Figure 3.** Architecture of the considered ANN: (**a**) frequency-dependent ANN and (**b**) nonlinear autoregressive exogenous (NARX) model.
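The open-loop NARX regressor construction described above, in which delayed targets are fed back as features alongside the exogenous sequence, can be sketched as follows. This is a NumPy sketch with illustrative delay settings, not the MATLAB implementation used in the study.

```python
import numpy as np

def narx_features(u, y, d_u=2, d_y=2):
    """Build open-loop NARX regressors: predict y[t] from
    u[t], u[t-1], ..., u[t-d_u] and y[t-1], ..., y[t-d_y].
    u: (T, n_sensors) exogenous features; y: (T,) targets."""
    T = len(y)
    start = max(d_u, d_y)                         # first step with a full history
    X, Y = [], []
    for t in range(start, T):
        past_u = u[t - d_u:t + 1][::-1].ravel()   # u_t, u_{t-1}, ..., u_{t-d_u}
        past_y = y[t - d_y:t][::-1]               # y_{t-1}, ..., y_{t-d_y}
        X.append(np.concatenate([past_u, past_y]))
        Y.append(y[t])
    return np.array(X), np.array(Y)

u = np.random.rand(10, 3)   # 10 frequency steps, 3 acceleration sensors
y = np.random.rand(10)      # base force / stress target
X, Y = narx_features(u, y)
print(X.shape)  # (8, 11): 3 sensors x 3 lags + 2 past targets
```

The first `max(d_u, d_y)` steps have no full history and cannot be predicted, which is consistent with the inaccurate predictions at the first frequency steps reported for the NARX model in Section 4.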

### *2.1. Data Generation*

The data used to develop the proposed method represent the harmonic response of the system over the frequency range over which the structure will be tested, typically 5 Hz to 100 Hz. In this study, in order to train, evaluate, and compare the different ANNs, data had to be generated for the three different scenarios. The training, testing, and validation data were generated by conducting a finite element harmonic analysis to compute the accelerations and the stresses or forces at given nodes and elements, respectively, over a determined frequency range (5–200 Hz, step of 2 Hz).

This data set was complemented by another set of data that was generated by conducting a finite element harmonic analysis over the same frequency range, with the same structure but different material properties. The Young's modulus of the JWST's optical bench was decreased by 5% in order to shift the natural frequencies of the structure and to account for material property uncertainty. The remaining material properties were left unchanged. These artificial data helped the trained models to generalize and make better predictions when the material of the test structure was not identical to the material data considered in the finite element analysis.

The data were then divided into a training set, a testing set, and a validation set to enable the assessment of the training progress and process. In order to improve training, the data at the natural frequencies of the structure were included in the training data set, while the remaining frequencies were randomly distributed between the training and the validation data sets. Thus, it was ensured that the model learned the connections at the natural frequencies, which are the most critical since the structure experiences the highest stress amplitudes there. In general, a small random number of frequencies can also be used as a test set to evaluate the model's accuracy. However, in this investigation, the models were assessed on independently generated test data with a changed Young's modulus. In this way, uncertainties, as experienced in reality, were taken into account.
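A split of this kind can be sketched as follows. The frequency grid matches the 5–200 Hz, 2 Hz harmonic analysis described above, while the natural-frequency values and the 80/20 split ratio are illustrative assumptions.

```python
import numpy as np

def split_frequencies(freqs, natural_freqs, train_frac=0.8, seed=0):
    """Force natural-frequency steps into the training set and split the
    remaining frequency steps at random between training and validation."""
    rng = np.random.default_rng(seed)
    is_natural = np.isin(freqs, natural_freqs)
    rest = np.where(~is_natural)[0]
    rng.shuffle(rest)
    n_train = int(train_frac * len(rest))
    train_idx = np.concatenate([np.where(is_natural)[0], rest[:n_train]])
    return np.sort(train_idx), np.sort(rest[n_train:])

freqs = np.arange(5, 201, 2)            # 5-200 Hz with a 2 Hz step, as in the analysis
naturals = np.array([35, 81, 143])      # illustrative natural frequencies
train_idx, val_idx = split_frequencies(freqs, naturals)
assert np.isin(naturals, freqs[train_idx]).all()
```

Forcing the natural-frequency rows into the training set guarantees that the model sees the resonance peaks where the stress amplitudes are highest.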

### *2.2. Data Processing*

To reduce the complexity of the problem to be solved, increase accuracy, and speed up the training process, the data of the various observations should be normalized. Every observation was scaled to be in a range from minus one to one. To make usable predictions during the test, the scaling parameters should be stored in order to denormalize the predictions to real-life figures [13].
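This scaling and the stored parameters for denormalization can be sketched as follows. The column-wise min/max formulation is an assumption, as the paper does not state the exact scaling formula.

```python
import numpy as np

def fit_minmax(x):
    """Scale each observation (column) to [-1, 1] and return the scaled data
    together with the parameters needed to denormalize predictions later."""
    lo, hi = x.min(axis=0), x.max(axis=0)
    scaled = 2.0 * (x - lo) / (hi - lo) - 1.0
    return scaled, (lo, hi)

def denormalize(scaled, params):
    """Map scaled values back to real-life figures."""
    lo, hi = params
    return (scaled + 1.0) / 2.0 * (hi - lo) + lo

x = np.array([[1.0, 10.0], [2.0, 30.0], [3.0, 50.0]])
scaled, params = fit_minmax(x)
assert scaled.min() == -1.0 and scaled.max() == 1.0
assert np.allclose(denormalize(scaled, params), x)
```

During a real test, `params` would be fitted on the training data once and reused to denormalize every live prediction.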

### *2.3. Academic Use Case*

For the first scenario, the theoretical case consisted of a very simple structure. It served as a benchmark to determine whether the method would be successful. The structure used for this scenario can be seen in Figure 4. The accelerations of 68 of the 90 nodes of the structure were used to predict the base force of the structure in element 100 (highlighted in Figure 4). The use of this excessive and unrealistic number of sensors (which, in reality, is never the case) enabled the assessment of the overall feasibility of the method. In the case where the method failed to predict the structure's base force, it could be deemed impractical. Furthermore, for this first scenario, the base force and not the stress was to be predicted using the accelerations because its relation to the measurable acceleration is more straightforward.

**Figure 4.** The academic use case with six highlighted sensors and element 100.

The second scenario basically represented a variation of the first scenario, where only six sensors were used to predict the base force as highlighted in Figure 4. This reduced number of sensors reflects reality, where the number of available measuring points is highly restricted. Thus, it provides the possibility to estimate the method's performance in a more realistic case with a limited number of sensors.

### *2.4. Industrial Use Case*

Last but not least, the NIRSpec use case represented an application of the method to a real and complex structure with a reduced number of sensors while predicting the element stress. Consequently, if the models are able to make accurate predictions for all three scenarios, the method can be considered useful.

The considered use case scenario concerns an actual structure corresponding to the NIRSpec instrument's optical bench. The optical bench is equipped with ten sensors to predict the stress in one element (see Figure 5). This case makes it possible to evaluate the potential of the method for a real and complex structure with more complex eigenmodes and a limited number of sensors. In this use case, the stress is to be predicted because it represents a good indicator of the structure's physical state and enables the evaluation of the model's performance in predicting metrics other than the force, as in [2].

**Figure 5.** NIRSpec module's optical bench with the ten most stressed elements highlighted, with a perspective from below in order to depict all ribs stiffening the bench.

In order to determine the stresses at the highlighted elements in Figure 5, an FEA was conducted with MSC NASTRAN version 2018.1.0. The structure was discretized by 96,073 nodes and 104,182 elements, ranging from one-dimensional elements (i.e., rods and beams) through shell elements (triangular and quadrilateral) to solid elements (tetrahedral, hexahedral, and pentahedral). As boundary conditions, the FE model was subjected to forces and moments for each kinematic mount derived from the CLA, where phase information was provided as well. This approach is referred to as the multi-excitation method (MEM). All dynamic analyses were based on modal decomposition and are therefore modal frequency response analyses ranging from 5 Hz to 200 Hz. For each of these frequency steps, the von Mises stress was evaluated at the ten selected elements and used to train the ANN. It should be noted that this equivalent stress was used for this paper only; AIRBUS has developed a dedicated equivalent stress suited to predicting ceramic failure.

#### **3. Results**

Table 1 summarizes the number of neurons for the different models. The ideal number of neurons was determined by trial and error, aiming for the best MSE performance while keeping the number of neurons small. The number of delays of the NARX model was determined in the same way. As the objective function, the mean squared error (MSE) was used for all models. While the input differs for each model (see Table 1), the element stress or the base force was used as target data for all ANNs. Furthermore, except for the biLSTM, the Nguyen–Widrow layer initialization function [18] was used to generate the initial weights and biases of the neurons for all ANNs. For the biLSTM, the input weights were initialized with the Glorot/Xavier initializer [19], using an orthogonal initialization for the recurrent weights, while the forget gate bias was initialized with ones and the remaining biases with zeros. All models also share the same activation function, namely the hyperbolic tangent sigmoid, except for the biLSTM, which uses the sigmoid function for the gates, the hyperbolic tangent function for the cell state and hidden state, and the linear activation function for the regression layer. The training algorithm used is also indicated in Table 1.
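The Glorot/Xavier scheme mentioned above can be sketched as follows. This is a NumPy sketch of the standard uniform variant from [19], not code from the study.

```python
import numpy as np

def glorot_uniform(n_in, n_out, seed=0):
    """Glorot/Xavier uniform initializer: weights drawn from U(-limit, limit)
    with limit = sqrt(6 / (fan_in + fan_out))."""
    rng = np.random.default_rng(seed)
    limit = np.sqrt(6.0 / (n_in + n_out))
    return rng.uniform(-limit, limit, size=(n_out, n_in))

W = glorot_uniform(10, 5)   # weights for a layer with 10 inputs and 5 outputs
assert np.all(np.abs(W) <= np.sqrt(6.0 / 15))
print(W.shape)  # (5, 10)
```

The fan-based limit keeps the variance of activations roughly constant across layers, which is why it is the common default for LSTM input weights.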



The NARX model was designed in open-loop form, where the input targets were used as feedback features. The model used as many inputs as sensors and had one hidden layer, a defined number of delays, and one output layer per stress (see Table 1). The network with a biLSTM layer consisted of a sequence input layer with as many neurons as inputs, followed by a biLSTM layer. Then, there was one fully connected layer and, lastly, the regression output layer with its linear activation function and as many neurons as outputs.

After the setup of the architectures of the different models, they were trained with the settings listed in Table 1. Figure 6a,b shows an example of the learning curves for the pretrained ANN and the NARX model for the industrial use case, respectively.

**Figure 6.** Training plots of (**a**) the pretrained ANN and (**b**) the NARX model for the industrial use case.

The blue curves represent the MSE over the training epochs for the training data, the green curves represent the error for the validation data, and the red line represents the error for the test data. The green circle marks the optimal validation performance. The training curves of every model decrease (as clearly shown in Figure 6 for the pretrained ANN and the NARX model), indicating that the models are able to learn the underlying data. The remaining gap between the validation curves and the training curves can be ascribed to the generalization of the data. As the final validation error is not too large, the training can be concluded to be successful. After the training, the models were deployed on the test data. The results of their predictions can be seen in the following section.

The evaluation of the training on the theoretical cases shows that the NARX model is extremely sensitive to the resolution of the frequency range. Dividing the frequency range into 600 rather than 100 steps proves to increase the quality of training tremendously. This does not make a difference for the feedforward and the biLSTM networks, as it only increases computation time.

### **4. Discussion**

In this section, the results of the three different use cases are discussed. To this end, the different models' predictions on the test data are compared to the FEA results and evaluated with a regression analysis.

### *4.1. Theoretical Case with 68 Sensors*

The trained models were deployed to make predictions using the testing feature data. These data were generated by conducting the second FEA and reducing the Young's modulus of the academic structure's material by 5%. As can be seen in Figure 7b, the NARX model makes inaccurate predictions for the first three frequency steps. These steps were used as delays during training. The remaining frequency steps are predicted accurately. The ANN with biLSTM layer was the most delicate to train, and it makes reasonably accurate predictions. It wrongly predicts the heights of some peaks, for instance, the peaks at frequency steps 40 and 90, as can be seen in Figure 7c. The frequency-dependent ANN predicts the heights of the peaks correctly (see Figure 7a), whereas the form of the peak at frequency step 50 is poorly predicted. The pretrained ANN (Figure 7d) makes slightly inaccurate predictions of the height and the form of the peak at frequency step 50 as well as the peak at step 90, corresponding to the peaks that are shifted the most in the testing data set compared to the training data set.

**Figure 7.** Actual and predicted element force over frequency steps for the academic case with 68 sensors: (**a**) frequency-dependent ANN, (**b**) NARX with more frequency steps, (**c**) ANN with bidirectional long short-term memory layer (biLSTM) layer, and (**d**) pretrained ANN.

This evaluation can be illustrated by a regression analysis. To this end, the predicted values, called outputs in Figure 8, are plotted against the base force calculated by FEA, referred to as targets, and a regression line is computed. Figure 8 shows the resulting regression plots, where the black dots represent the data points and the blue line represents the regression line.

**Figure 8.** Regression plots of the predictions for the academic case with 68 sensors: (**a**) frequency-dependent ANN, (**b**) NARX, (**c**) ANN with biLSTM layer, and (**d**) pretrained ANN.

An overview of the models' respective regression coefficients, i.e., the slopes of the regression lines in Figure 8, and the root mean square error (RMSE) between the target and the predicted base force allows a quantitative comparison of the models. Table 2 summarizes these metrics for the predictions on the test data. The second values (R = 0.9644 and RMSE = 0.0305 N) for the NARX model are the regression coefficient and the RMSE, respectively, for the data without the first three values representing the delays. It can be determined that the pretrained model has the lowest performance, whereas the frequency-dependent ANN makes the most accurate predictions.
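The two metrics can be sketched as follows; this is a minimal NumPy sketch with illustrative data, not the values reported in Table 2.

```python
import numpy as np

def regression_metrics(targets, outputs):
    """Slope of the least-squares regression line of outputs over targets,
    plus the RMSE between them (the two metrics reported in Tables 2-4)."""
    slope, _intercept = np.polyfit(targets, outputs, 1)
    rmse = np.sqrt(np.mean((targets - outputs) ** 2))
    return slope, rmse

t = np.array([0.0, 1.0, 2.0, 3.0])
y = np.array([0.1, 0.9, 2.1, 2.9])   # near-perfect predictions
slope, rmse = regression_metrics(t, y)
print(slope, rmse)   # a slope close to 1 and a small RMSE indicate accurate predictions
```

A slope of exactly 1 with zero RMSE would mean the model reproduces the FEA results perfectly; deviations in either metric quantify the prediction error.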


**Table 2.** Regression coefficients and RMSE for the different models trained on the academic case with 68 sensors.

The evaluation of the first theoretical case leads to the conclusion that the method has proven to be successful, even though these first predictions were made with an unrealistically large number of sensors.

### *4.2. Theoretical Case with 6 Sensors*

The trained models were deployed to predict the test data with the shifted frequencies, resulting in the predictions seen in Figure 9. While, in this case, the biLSTM was not able to make adequate predictions, as can be seen in Figure 9c, the other models predicted the element force mostly accurately. It can be noted that the NARX model's predictions of the first frequency steps used as delays are not accurate, while the remaining curve is predicted correctly, as seen in Figure 9b. The pretrained ANN as well as the frequency-dependent ANN make slightly inaccurate predictions of the height and the form of some of the peaks (see Figure 9a,d).

**Figure 9.** Finite element analysis (FEA) and predicted element force over frequency steps for the academic case with 6 sensors: (**a**) frequency-dependent ANN, (**b**) NARX with increased number of frequency steps, (**c**) ANN with biLSTM layer, and (**d**) pretrained ANN.

A regression analysis reinforces the above observations, as can be seen in Figure 10. The regression coefficients and the RMSE for the predictions on the testing data are summarized in Table 3. While the ANN with biLSTM layer performs worst, resulting from its delicate training, the NARX model makes the most adequate predictions, even with the delays included in the above calculation. The frequency-dependent ANN and the pretrained ANN perform similarly.

**Figure 10.** Regression plots of the predictions for the academic case with 6 sensors: (**a**) frequency-dependent ANN, (**b**) NARX, (**c**) ANN with biLSTM layer, and (**d**) pretrained ANN.

**Table 3.** Regression coefficients and RMSE for the different models.


This second theoretical case proves that most of the ANNs are also successful in cases where the number of sensors is restricted, as is often the case in reality.

### *4.3. NIRSpec Use Case*

After the successful training of the ANN models, they were applied to the NIRSpec use case. The test data with varied stiffness were generated by reducing the Young's modulus of the optical bench plate's material by 5% of its initial value in the FE model. Figure 11 shows the predictions of the four considered models against the stress calculated with FEA with respect to the frequency steps. The frequency steps divide the considered frequency range (5–200 Hz) into equal steps, with the steps and the corresponding normalized stress values both resulting from the FEA.

**Figure 11.** Stress calculated by FEA and predicted stress by ANN over the frequency steps for the NIRSpec use case: (**a**) frequency-dependent ANN, (**b**) NARX, (**c**) ANN with biLSTM layer, and (**d**) pretrained ANN.

As can be seen from Figure 11b, the NARX model predicts the element stress curve without major deviation. It only seems to struggle slightly with the first frequency step, which can be ascribed to the use of that first step as a delay in the model's architecture. The frequency-dependent ANN in Figure 11a struggles to predict the peak stress values, for instance, around frequency step 400, and also with the shapes of a few peaks, mainly at the last frequency steps from steps 500 to 600. The pretrained ANN, as seen in Figure 11d, also seems to struggle with the shapes of a few peaks, especially at the end of the frequency range. The fact that both these models struggle at the end of the frequency range can be ascribed to the fact that this range contains the mode that was shifted the most by changing the material properties. The NARX model, however, does not face any difficulties with this. Figure 11c shows that the biLSTM also makes reasonably accurate predictions, although it was the most delicate to train. However, it also struggles with some peak stress values and completely omits the mode at frequency step 500. Figure 12 shows the resulting regression plots.

**Figure 12.** Regression plots of predictions for the NIRSpec use case with 35 sensors: (**a**) frequency-dependent ANN, (**b**) NARX, (**c**) ANN with biLSTM layer, and (**d**) pretrained ANN. The output values and the target values are normalized.

Table 4 makes clear that both recurrent ANNs (NARX and biLSTM) perform the best, with the NARX model having the best regression coefficient of 0.9936 and the smallest error of 0.0454 MPa. In contrast, the pretrained ANN makes the largest error of 0.1219 MPa and has the poorest regression coefficient of 0.9724, which still represents a high prediction accuracy.


**Table 4.** Regression coefficients and RMSE for the different ANN models on the NIRSpec data.

#### **5. Conclusions**

In this work, four different artificial neural network models were tested for their ability to predict stresses related to the excitation frequency for the launch scenario of the Near-Infrared Spectrograph. In addition, they were tested on a theoretical case with differing numbers of sensors. With correctly trained ANNs, the monitoring of real shaker tests and thus the avoidance of overstressing the test specimens are possible.

The conducted investigation allowed the comparison of all ANN models with respect to the requirements formulated in Section 2. From Tables 2–4 in Section 4, it can be clearly deduced that the NARX model is the most promising one. Figure 7, Figure 9, and Figure 11 illustrate this conclusion. Thus, a trained NARX model could be used during vibration tests and decrease the time of prediction of the given structural parameters, which is crucial for adapting and notching the input load of the shaker in time.

As could also be seen, the recurrent ANNs generally perform better than the feedforward ANNs, which handle the input as concurrent data. The ANN with biLSTM layer is able to make accurate predictions, even though its training could not be conducted thoroughly due to the lack of data for such a deep ANN. However, if this ANN were trained with more, and more varied, data, it could possibly make the most accurate predictions. In future studies, the potential of this network can be further investigated. For instance, the training data set for this model could be increased by including training data from FEAs with several varied Young's moduli or varied damping parameters or by varying other material parameters that have an impact on the natural frequencies.

While the NARX model performs the best, its performance is highly dependent on the number of available frequency steps. For example, if the frequency range to be predicted (from 5 to 200 Hz) is poorly resolved and divided into only 100 instead of 600 frequency steps, the quality of the NARX model's predictions suffers. The other networks are not as sensitive to the division of the frequency range. In particular, the ranges around eigenmodes should be resolved more finely with additional frequency steps. Each time step $t$ of the sequence is treated as a single layer, which can lead to an extremely deep ANN. On the one hand, this increases the computational time, but on the other hand, it increases the performance of the network. The performance of the NARX model can thus be maximized by training it with as many frequency steps as possible. However, in practice, this can be a hurdle, as the required higher resolution may not be available during the test. Table 5 outlines the qualitative evaluation of the different models in terms of the requirements introduced in the introduction.



All in all, it can be stated that the conducted research was able to outline a methodology capable of live predicting equivalent stresses of a structure under vibration testing, thereby allowing failure to be evaluated for different types of material via yield criteria.

**Author Contributions:** L.W. carried out the presented research within her thesis, while R.O., M.S., and K.M.d.P. supported this study as supervisors. All authors have read and agreed to the published version of the manuscript.

**Funding:** This research received no external funding.

**Acknowledgments:** The authors would like to thank the Structural Analysis Department of AIRBUS Defense & Space GmbH in Immenstaad, Germany, for having supported the research by providing the necessary budget.

**Conflicts of Interest:** The authors declare no conflict of interest.

### **References**


**Publisher's Note:** MDPI stays neutral with regard to jurisdictional claims in published maps and institutional affiliations.

© 2020 by the authors. Licensee MDPI, Basel, Switzerland. This article is an open access article distributed under the terms and conditions of the Creative Commons Attribution (CC BY) license (http://creativecommons.org/licenses/by/4.0/).

## *Article* **PreNNsem: A Heterogeneous Ensemble Learning Framework for Vulnerability Detection in Software**

### **Lu Wang 1,2,3,\*, Xin Li 1,3, Ruiheng Wang 1,3, Yang Xin 1,2,3,\*, Mingcheng Gao 1,3 and Yulin Chen <sup>2</sup>**


Received: 21 October 2020; Accepted: 6 November 2020; Published: 10 November 2020

**Abstract:** Automated vulnerability detection is one of the critical issues in the realm of software security. Existing solutions to this problem are mostly based on features that are defined by human experts and directly lead to missed potential vulnerabilities. Deep learning is an effective method for automating the extraction of vulnerability characteristics. Our paper proposes intelligent and automated vulnerability detection using deep representation learning and heterogeneous ensemble learning. Firstly, we transform sample data from source code by removing segments that are unrelated to the vulnerability in order to reduce the amount of code to analyze and improve detection efficiency in our experiments. Secondly, we represent the sample data as real vectors by pre-training on the corpus while maintaining its semantic information. Thirdly, the vectors are fed to a deep learning model to obtain the features of vulnerability. Lastly, we train a heterogeneous ensemble classifier. We analyze the effectiveness and resource consumption of different network models, pre-training methods, classifiers, and vulnerabilities separately in order to evaluate the detection method. We also compare our approach with some well-known commercial vulnerability detection tools and academic methods. The experimental results show that our proposed method provides improvements in false positive rate, false negative rate, precision, recall, and F1 score.

**Keywords:** cyber security; vulnerability detection; word embedding; deep learning

### **1. Introduction**

Software vulnerabilities are one of the root causes of cybersecurity issues. Despite improving software quality in academia and industry, new vulnerabilities continue to be exposed, causing huge losses. A large number of vulnerabilities have been recorded by Common Vulnerabilities and Exposures [1].

Vulnerability detection is an effective method for discovering software bugs. Overall, vulnerability detection methods can be categorized as static and dynamic methods. High coverage and low false positives are the advantages of static methods and dynamic methods, respectively. Many studies of source-code-based static analysis during the software development stage considered open-source tools [2–4], commercial tools [5–7], and academic research tools [8–10] to reduce dynamic runtime costs. Most of these tools are based on pattern matching. The pattern-based methods require experts to manually define vulnerability features for machine learning or rule matching. In summary, there are two significant drawbacks with the existing solutions: (1) reliance on human experts and a lack of automation; (2) a high false positive rate and low recall. Both are described below.

The existing solutions rely on human experts to define vulnerability features. It is difficult to guarantee the correctness and comprehensiveness of features because of complexity, even for experts. This is a highly subjective task, because the knowledge and experience of experts influence the results. It follows that there cannot be a unified standard for manually extracting features. Therefore, we must reduce or eliminate reliance on intense labor from human experts.

The existing solutions produce a high false positive rate and low recall. Most new tools detect all possible vulnerability patterns when matching the rules, regardless of context, structure, or semantics. As such, the detection results have low recall and a high false positive rate. Because of the fixed nature of rule detection, errors occur when detecting the same vulnerability across projects. Although machine learning has been applied to solve the above problems [11,12], the results are still unsatisfactory. These problems suggest that we must achieve a low false positive rate while maintaining a high recall rate.

For the two problems mentioned above, feature engineering should be the core of the solution. Firstly, automated feature extraction will overcome the need for human labor. Secondly, precise vulnerability features will improve the precision of the results. As an automated feature-extraction tool, deep learning [13,14] was proposed for vulnerability detection. Applicable deep learning models can automatically and precisely learn various low- and high-level features. However, there are many deep learning models, and one problem is selecting a model that achieves automation and a lower false positive rate.

In this paper, the proposed framework, which involves pre-training for vector representation, neural networks for automated feature extraction, and ensemble learning for classification (PreNNsem), focuses on improving the feature engineering of vulnerability detection. To validate PreNNsem, we applied different models of pre-training, neural networks, and ensemble learning. Word2vec continuous bag-of-words (CBOW), multiple structural convolutional neural networks (CNNs), and stacking classifiers were found to be the best combination by comparing classification results. In summary, we make four contributions:


The remainder of this paper is structured as follows: Section 2 reviews related work. Section 3 presents the PreNNsem framework. Section 4 describes our experimental evaluation of PreNNsem and the comparison results, and Section 5 discusses open problems and concludes the paper.

### **2. Related Work**

### *2.1. Prior Studies Related to Vulnerability Detection*

In terms of the degree of automation, previous vulnerability detection methods can be divided into three categories. (i) Manual methods: many static vulnerability tools, such as Flawfinder [2], RATS [16], and Checkmarx [17], are based on vulnerability patterns defined by human experts. Because pattern matching depends on the rule base, the false positive and/or false negative rates are often high. (ii) Semi-automatic methods: features are manually defined (code churn, complexity, coverage, dependency, and organizational measures [18]; code complexity, information flow, functions, and invocations [19]; missing checks [20,21]; and abstract syntax trees (ASTs) [22,23]) for traditional machine learning methods such as k-nearest neighbors and random forests. MingJian Tang et al. [24] used artificial statistical characteristics to analyze vulnerability trends and dependencies with the Cupra model in multivariate time series. (iii) Fully automatic methods: human experts do not need to define features. POSTER [25] presented a method for automatically learning high-level representations of functions. VulDeePecker [13] is a system demonstrating the feasibility of using deep learning to detect vulnerabilities. Venkatraman S et al. [26] proposed a hybrid model employing similarity mining and deep learning architectures for image analysis. Vasan D et al. [27] analyzed malware images using a CNN to extract features and a support vector machine (SVM) for multi-class classification.

PreNNsem is an automated, end-to-end vulnerability detection framework. Compared with the manual and semi-automatic methods, our method eliminates subjectivity; therefore, the features it obtains are more persuasive and comprehensive. POSTER extracts features at the function level, a coarser granularity, whereas PreNNsem extracts richer features directly at the word level. Compared with VulDeePecker, we expanded the corpus in the word embedding layer to increase the precision of semantic expression, and we used heterogeneous classifiers to improve the stability and accuracy of classification.

### *2.2. Prior Similar Studies*

Pattern-based approach. Z. Li et al. [13] generated vectors from code gadgets using Word2vec, as we do. They used a recurrent neural network (RNN)-based deep learning model and SoftMax for classification. Liu S et al. [28] also used an RNN to learn high-level representations of abstract syntax trees (ASTs). Duan X et al. [29] extracted semantic features using a code property graph (CPG), obtaining feature matrices by encoding the CPG; finally, they used attention neural networks for classification. Lin G et al. [30] proposed a deep-learning-based framework capable of leveraging multiple heterogeneous vulnerability-relevant data sources to effectively learn latent vulnerable programming patterns.

Similarity-based approach. Vinayakumar R et al. [31] used a Siamese network to identify similarity and deep learning architectures to classify domain names. Zhao G et al. [32] encoded code control flow and data flow into a semantic matrix and designed a new deep learning model that measures code functional similarity based on this representation. Xiao, Yang et al. [33] used a novel program-slicing technique to extract vulnerability and patch signatures from a vulnerable function and its patched function at the syntactic and semantic levels; a target function was then identified as potentially vulnerable if it matched the vulnerability signature but did not match the patch signature. Nair, Aravind et al. [34] examined the effectiveness of graph neural networks for estimating program similarity by analyzing the associated control flow graphs. In [35], the authors built a graph representation of programs called the flow-augmented abstract syntax tree (FA-AST) and applied two different types of graph neural networks (GNNs) to FA-AST to measure the similarity of code pairs.

Compared with pattern-based approaches, the similarity-based approach suffices for detecting the same vulnerability in target programs. However, it cannot detect vulnerabilities in some code clones, including those involving deletion, insertion, or rearrangement of statements. PreNNsem is a pattern-based approach to vulnerability detection. Existing pattern-based approaches have two problems: first, the granularity of the extracted information is coarse; second, the dataset used to learn vulnerability patterns is insufficient. In contrast to the studies reviewed above, PreNNsem has two advantages: first, it extracts features directly at code granularity to avoid the loss of information during feature abstraction; second, by expanding the corpus, it can learn common programming patterns and improve generalization capability.

### **3. Design of PreNNsem**

### *3.1. Hypothesis*

High-level programming languages, like C and Java, are designed for humans and are close to human expression. They have many similarities with natural language: for example, programming languages are probabilistic in definition and context-dependent in grammar. Hence, we can borrow concepts from natural language processing (NLP) for vulnerability detection. We map the concepts of code language onto natural language as follows:

• Concepts: a slice of code corresponds to a sentence; keywords, statements, characters, and numbers correspond to words.

In natural language processing, each word is encoded as a vector and each sentence as a sequence of vectors. Distributed representations are based on the assumption that words occurring in the same context tend to have similar meanings [36].

For vulnerability detection, we separate a code segment into tokens and represent it as a sequence of vectors, which has the same form as in NLP. Therefore, we make the following assumptions for vulnerability detection.

**Hypothesis 1.** *In a programming language, a token's context is its preceding and succeeding tokens. Tokens that occur in the same context tend to have similar semantics.*

**Hypothesis 2.** *The same types of vulnerabilities have similar semantic characteristics. These characteristics can be learned from the context of vulnerabilities.*

### *3.2. Overview of PreNNsem*

We aim to detect vulnerabilities automatically through feature engineering, using PreNNsem to achieve this goal. Figure 1 shows the process of our proposed framework, which takes sliced code as input and outputs whether a vulnerability is detected. In this paper, an extended corpus and sample data are derived from C/C++ source code using security slicing [13] and represented as sequences of numbers, a process we call "vectorization". PreNNsem then proceeds in three interdependent steps, where the intermediate data of each step serve as the input to the next layer. In the first step, pre-training uses the vectorized extended corpus to generate a distributed representation, and the output is a vector for each token; the embedding layer takes this output as its initialization parameter. In the second step, the sample data pass through the embedding layer, and the neural network and SoftMax layer obtain high-level features. In the third step, supervised learning takes these features as input to determine whether a sample is vulnerable.

**Figure 1.** Overview of the proposed PreNNsem (pre-training for vector representation, neural networks for automated feature extraction, and ensemble learning for classification) framework.

### *3.3. Source Code Pre-Processing*

Following lexical analysis of the code, we remove some semantically irrelevant symbols (e.g., }{) to improve efficiency. We divide segment code into words by spaces and symbols (e.g., +−\*/=). Deep learning models take vectors (arrays of numbers) as input, so when working with text we need a strategy to convert strings to numbers before feeding them to the model. Firstly, we index each word with a unique number: for example, we assign 1 to "i", 2 to "for", 4 to "=", 3 to "100", and so on. We then encode the sentence "for i = 100" as a dense vector [2, 1, 4, 3]. However, different sentences have different lengths. To unify the data length for model input, we define a fixed maximum length of 400 based on the sample data. There are two cases: if a sentence is shorter than 400 tokens, it is zero-padded; otherwise, the excess is truncated. Note that because pre-training requires a representation similar to the embedding's (Section 3.4), the extended corpus is only indexed in this step.
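The indexing-and-padding step above can be sketched in plain Python. This is a minimal illustration, not the authors' implementation: the helper names are ours, the index is assigned in first-occurrence order (the paper's example uses a different ordering), and MAX_LEN = 400 follows the text.

```python
MAX_LEN = 400  # fixed input length chosen from the sample data, as in the text

def build_index(tokens):
    """Assign each distinct token a unique positive integer (0 is reserved for padding)."""
    index = {}
    for tok in tokens:
        if tok not in index:
            index[tok] = len(index) + 1
    return index

def vectorize(tokens, index, max_len=MAX_LEN):
    """Map tokens to indices, then zero-pad (shorter) or truncate (longer) to max_len."""
    seq = [index.get(tok, 0) for tok in tokens]
    if len(seq) < max_len:
        return seq + [0] * (max_len - len(seq))
    return seq[:max_len]

index = build_index(["for", "i", "=", "100"])
vec = vectorize(["for", "i", "=", "100"], index)
# vec carries the four token indices followed by zero padding up to length 400.
```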

#### *3.4. Word Embedding Pre-Training*

According to Hypothesis 1, the same vulnerability pattern has similar semantics and structure in source code, and code representation is significant for pattern analysis. Word embeddings [37,38] are a type of word representation that allows words with similar meanings to have similar representations. Thus, similar representations indicate the same vulnerability pattern, allowing vulnerable and non-vulnerable code to be distinguished.

In this section, word embedding is divided into random, static, and non-static [39] embedding, according to the initialization method. Random embedding means all word vectors are randomly initialized and then modified during training. Static embedding means word vectors are pre-trained from a distributed representation and kept unchanged during training. Non-static embedding means pre-trained vectors from Word2vec are fine-tuned for each task and trained with the deep learning model. We used continuous bag-of-words (CBOW) to obtain the dense distributed representation.

How does CBOW work? As shown in Figure 2, CBOW is a three-layer network. Firstly, we convert each word into a one-hot vector as the CBOW input: given a current word *xt*, the vectors *x*1, . . . , *xC* represent its surrounding words, where *C* is the number of surrounding words and *k* is the vocabulary size; each *x* is a *k* × 1 vector. Secondly, we initialize a *k* × *d* weight matrix *W* between the input layer and the hidden layer, where *d* is the word vector size. In the hidden layer, each *x* is left-multiplied by *W*<sup>T</sup>, and the results are averaged to produce the hidden-layer output *hd*, a *d* × 1 vector:

$$h\_d = \frac{W^T \cdot x\_1 + W^T \cdot x\_2 + \dots + W^T \cdot x\_C}{C} \tag{1}$$

Next, we initialize a *d* × *k* weight matrix *U* between the hidden layer and the output layer. In the output layer, *h* is left-multiplied by *U*<sup>T</sup> and the SoftMax activation function is applied. *y* and *x* have the same dimensions, but each element of *y* represents the corresponding word's probability:

$$y = \mathrm{SoftMax}(U^T \cdot h) \tag{2}$$

The CBOW model is a learning procedure: *y* is not the result we ultimately want; rather, the intermediate product *W* serves as the word-vector matrix. In our proposed method, we set the surrounding-word window to *C* = 10 and the word vector size to *d* = 200. As Figure 2 shows, CBOW predicts the word at a target location: the location's surrounding words are used as input to obtain a probability distribution over the vocabulary, and the word with the highest probability is selected as the prediction. During this process, the weight matrix *W* is continually adjusted and becomes the final word-vector matrix.
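The CBOW forward pass of Equations (1) and (2) can be sketched in NumPy with toy dimensions (the paper uses *C* = 10 and *d* = 200); *W* and *U* below are random stand-ins for the weights that training would adjust.

```python
import numpy as np

rng = np.random.default_rng(0)
k, d, C = 6, 4, 2  # toy vocabulary size, vector size, and context size

W = rng.normal(size=(k, d))  # input->hidden weights; its rows become the word vectors
U = rng.normal(size=(d, k))  # hidden->output weights

def one_hot(i, k):
    x = np.zeros(k)
    x[i] = 1.0
    return x

context = [one_hot(1, k), one_hot(3, k)]  # one-hot vectors of the surrounding words
h = sum(W.T @ x for x in context) / C     # Equation (1): averaged projection of the context
z = U.T @ h
y = np.exp(z) / np.exp(z).sum()           # Equation (2): SoftMax over the vocabulary

predicted = int(np.argmax(y))  # index of the most probable center word
```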

**Figure 2.** Overview of continuous bag-of-words (CBOW) models.

#### *3.5. Representation Learning*

According to Hypothesis 2, common semantic characteristics can be learned from the context of vulnerabilities. Traditionally, manually defined characteristics are crucial to machine learning classification: training data are transformed and augmented with additional features to increase the efficacy of machine learning algorithms. With deep learning, however, we can start from raw data, as features are created automatically by the neural network as it learns.

In this section, we choose CNN and long short-term memory (LSTM) as the base deep learning models. Figure 3 shows the selected deep learning model used in this study. Firstly, to better learn the structure and semantics of the data, we used transfer learning to build the embedding layer of the neural networks. Secondly, we sequentially combined three concatenated CNNs and one CNN as the network model. Thirdly, we added a one-dimensional max-pooling layer and a dropout layer for dimensionality reduction.

**Figure 3.** Features of convolutional neural networks (CNNs).

What are the features learned by CNN? As shown in Figure 3, we represent a code segment of length *n* as:

$$S = [v\_0, v\_1, \dots, v\_n] \tag{3}$$

where *vi* is the *i*th word vector in the segment. A filter *wh* of size *h* extracts a new feature from *h* consecutive words; *c*<sup>h</sup><sub>i</sub> represents the feature generated from the *i*th word and the *h* − 1 words following it:

$$c\_i^h = f(w\_h v\_{i:i+h-1} + b),\tag{4}$$

where *f* is a non-linear function, *b* is a bias, and:

$$v\_{i:i+h-1} = v\_i \oplus v\_{i+1} \oplus \dots \oplus v\_{i+h-1} \tag{5}$$

where ⊕ is the concatenation operator. According to filter size, there are four types of filters: size three, size four, size five, and size six. Each filter generates a new feature; the larger the filter size, the richer the context considered. In our experiment, we applied multiple filters to obtain multiple features. CNNs are characterized by parallelism, and the filters are independent of one another, which improves execution efficiency.
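The filtering step of Equations (4) and (5) can be sketched in NumPy with toy dimensions; the filter weights are random stand-ins, *f* is taken to be ReLU, and each filter's feature map is reduced by 1-D max pooling, as in the architecture described above.

```python
import numpy as np

rng = np.random.default_rng(1)
n, d = 10, 8                 # toy segment length and word-vector size
S = rng.normal(size=(n, d))  # S = [v_0, ..., v_{n-1}], as in Equation (3)

def conv_features(S, h, w, b):
    """One filter of size h: c_i = f(w . (v_i concat ... concat v_{i+h-1}) + b), f = ReLU."""
    feats = []
    for i in range(len(S) - h + 1):
        window = S[i:i + h].ravel()  # concatenation of h consecutive word vectors
        feats.append(max(0.0, float(w @ window) + b))
    return feats

# Four filter sizes, as in the text; 1-D max pooling keeps one feature per filter.
pooled = []
for h in (3, 4, 5, 6):
    w = rng.normal(size=h * d)
    pooled.append(max(conv_features(S, h, w, b=0.1)))
```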

According to Figure 4, the LSTM processes one code segment at a time, and its loop allows information to be passed from one step of the network to the next. This chain-like nature reveals that recurrent neural networks are intimately related to sequences; they are the natural neural network architecture to use for such data.

For extracting sequence features, an RNN can obtain more comprehensive inter-sequence information than a CNN: in theory, a CNN can only consider the characteristics of consecutive words, while an RNN can consider the entire sentence. However, in our experiments, the more information stored, the longer the processing time. Even though the LSTM can choose to forget some information, time consumption remains prohibitive for long sequences.

**Figure 4.** Features of Long Short-Term Memory (LSTM).

### *3.6. Heterogeneous Ensemble Learning*

Recent experimental studies [40] showed that a classifier ensemble may improve classification performance if we combine multiple diverse classifiers that disagree with each other. Neural network models are nonlinear and have high variance, which can cause problems when preparing a final model for making predictions. A solution to the high variance of neural networks is to train multiple models and combine their predictions. Ensembling is a standard approach in applied machine learning to ensure that the most stable and best possible prediction is made. We replaced the simple SoftMax classifier with a stacking classifier to improve vulnerability classification.

According to [41], heterogeneous ensemble methods have emerged as robust, more reliable, and accurate, intelligent techniques for solving pattern recognition problems. They use different basic classifiers in order to generate several different hypotheses in the feature space and combine them to achieve the most accurate result possible.

How does the stacking framework work? Figure 5 shows the concept of the stacking ensemble. Stacking combines multiple classifiers generated by different learning algorithms *L*1, ··· , *LN* on a training dataset *S* and a testing dataset *S*′, consisting of samples *si* = (*xi*, *yi*) (*xi*: feature vectors; *yi*: class labels). Define *C* as a classifier. Thus,

$$\begin{cases} \ C\_i = L\_i(S), & i \in 1, \cdots, N \\ \ C\_{meta} = L\_{meta}(S') \end{cases} \tag{6}$$

where *Ci* are the base classifiers and *Cmeta* is the meta-classifier. In the first stage, we choose two base algorithms, *L*<sup>1</sup> = *LogisticRegression* and *L*<sup>2</sup> = *MultinomialNaiveBayes*. We divide the training data into *K* = 10 parts, one of which serves as the validation subset *Sd*, *d* ∈ 1, ··· , *K*. We train *Ci* on *S* and evaluate using 10-fold cross-validation. With the model trained at each step *d*, we predict on the validation subset (yielding *Y*<sup>d</sup><sub>i</sub>) and on the test set (yielding *P*<sup>d</sup><sub>i</sub>), ∀*i* ∈ 1, ··· , *N* and ∀*d* ∈ 1, ··· , *K*:

$$\begin{cases} C\_i^d = L\_i(S - S\_d) \\ Y\_i^d = C\_i^d(S\_d) \\ P\_i^d = C\_i^d(S') \end{cases} \tag{7}$$

**Figure 5.** Stacking classifier framework.

Subsequently, the *Y*<sup>d</sup><sub>i</sub> are stacked into a feature *Ai*, and the average of all *P*<sup>d</sup><sub>i</sub> is taken to obtain a feature *Bi*:

$$\begin{cases} A\_i = Y\_i^1 \uplus Y\_i^2 \uplus \dots \uplus Y\_i^K \\ B\_i = average(P\_i^1, P\_i^2, \dots, P\_i^K) \end{cases} \tag{8}$$

In the second stage, we concatenate the *Ai* to form new training data *A* and concatenate the *Bi* to form new testing data *B*:

$$\begin{cases} A = A\_1 \oplus A\_2 \oplus \dots \oplus A\_N \\ B = B\_1 \oplus B\_2 \oplus \dots \oplus B\_N \end{cases} \tag{9}$$

Finally, the meta-classifier is trained on *A* and predicts the result for *B*:

$$\begin{cases} C\_{meta} = L\_{meta}(A) \\ Result = C\_{meta}(B) \end{cases} \tag{10}$$
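As a sketch of Equations (6)–(10), not the authors' code, the two-stage procedure can be reproduced with scikit-learn's `StackingClassifier`; the synthetic non-negative data stand in for the learned features, the base learners and *K* = 10 follow the text, and the random forest meta-classifier anticipates the choice made in Section 4.6.

```python
import numpy as np
from sklearn.ensemble import RandomForestClassifier, StackingClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.naive_bayes import MultinomialNB

rng = np.random.default_rng(2)
# Synthetic non-negative counts standing in for extracted features (MultinomialNB needs them).
X = rng.integers(0, 10, size=(200, 6)).astype(float)
y = (X[:, 0] + X[:, 1] > 9).astype(int)  # toy labels

# Stage one: base learners L1, L2 trained with K-fold cross-validation (Equations 6-8);
# stage two: the meta-classifier learns from their out-of-fold predictions (Equations 9-10).
stack = StackingClassifier(
    estimators=[("lr", LogisticRegression(max_iter=1000)), ("nb", MultinomialNB())],
    final_estimator=RandomForestClassifier(n_estimators=50, random_state=0),
    cv=10,  # K = 10, as in the text
)
stack.fit(X, y)
acc = stack.score(X, y)
```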

### *3.7. Construct Framework*

We now build the vulnerability detection framework and propose an implementation. Our proposed framework (PreNNsem) consists of distributed representation, deep learning, and machine learning. As the implemented solution, we chose Word2vec CBOW for distributed representation, multiple structural CNNs for deep learning, and a heterogeneous ensemble classifier (stacking) for machine learning.

We tokenize the extended corpus to obtain word vectors for similar code representations. The sample data are indexed and converted to sequences as input for the deep learning model, and the word vectors are used as parameters of the embedding layer. The processed sample data are embedded and passed through the neural networks to generate an automatic feature-extraction model. The extracted features are then used to train machine learning classifiers, which predict whether samples are vulnerable.

### **4. Experiments and Results**

### *4.1. Evaluation Metrics*

Let true positive (TP) denote the number of vulnerable samples detected correctly, false positive (FP) denote the number of normal samples detected incorrectly, false negative (FN) denote the number of vulnerable samples undetected, and true negative (TN) denote the number of clean samples classified correctly. Running time and memory were also measured to assess resource consumption.

We used five metrics to measure vulnerability detection results. The *FP* rate (*FPR*) metric measures the ratio of falsely classified normal samples to all normal samples.

$$FalsePositiveRate(FPR) = FP/(FP+TN) \tag{11}$$

False negative rate (*FNR*) measures the ratio of vulnerable samples classified falsely to all vulnerable samples.

$$FalseNegativeRate(FNR) = FN/(FN + TP) \tag{12}$$

Precision measures the correctness of the detected vulnerabilities.

$$Precision(P) = TP/(TP + FP) \tag{13}$$

Recall represents the ability of a classifier to discover vulnerabilities from all vulnerable samples.

$$Recall(R) = TP/(TP + FN) \tag{14}$$

The *F*1 measure considers both precision and recall.

$$F1\text{-}Measure(F1) = 2 \ast P \ast R / (P + R) \tag{15}$$

Low *FPR* and *FNR* and high *P*, *R*, and *F*1 indicate excellent performance in the experimental results. Low resource consumption is also vital.
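The five metrics can be computed directly from the confusion counts; a small helper implementing Equations (11)–(15), applied to hypothetical counts for illustration:

```python
def detection_metrics(tp, fp, fn, tn):
    """FPR, FNR, precision, recall, and F1 from confusion-matrix counts (Equations 11-15)."""
    fpr = fp / (fp + tn)
    fnr = fn / (fn + tp)
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    f1 = 2 * precision * recall / (precision + recall)
    return fpr, fnr, precision, recall, f1

# Hypothetical counts, for illustration only.
fpr, fnr, p, r, f1 = detection_metrics(tp=90, fp=10, fn=10, tn=90)
# Here FPR = FNR = 0.1 and P = R = F1 = 0.9.
```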

### *4.2. Experimental Setup*

In terms of program collection, the Software Assurance Reference Dataset (SARD) [42] serves as a standard dataset for testing vulnerability detection tools on software security errors, and the National Vulnerability Database (NVD) [43] contains vulnerabilities in production software. In the SARD, each program case contains one or more Common Weakness Enumeration identifiers (CWE IDs). In the NVD, each vulnerability has a unique Common Vulnerabilities and Exposures identifier (CVE ID) and a CWE ID identifying the vulnerability type. We therefore collected the programs with CWE IDs that contained vulnerabilities.

We chose two types of vulnerabilities as detection objects: buffer overflow (CWE-119) and resource management error (CWE-399). We also collected other C/C++ programs from the NVD as an extended corpus for pre-training. Table 1 summarizes statistics on the training data and pre-training data. The datasets were preliminarily processed by [13]. We collected 10,440 programs related to buffer error vulnerabilities and 7285 programs related to resource management error vulnerabilities from the NVD; we also collected 420,627 programs as an extended corpus to improve code representation. The extended dataset draws on 1591 open-source C/C++ programs from the NVD and 14,000 programs from the SARD, and includes 56,395 vulnerable samples and 364,232 non-vulnerable samples.


**Table 1.** Statistics on training data and pre-training data.

Regarding training programs vs. target programs, we randomly chose 80% of the collected programs as training programs and 20% as target programs; this ratio was applied whether dealing with one or both types of vulnerabilities. We also used 10-fold cross-validation over the training set for model selection and used the test set to evaluate the obtained model.
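The split-and-validate protocol can be sketched with scikit-learn on placeholder data (the random arrays below stand in for the program samples; the 80/20 split and 10-fold cross-validation follow the text):

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_score, train_test_split

rng = np.random.default_rng(3)
X = rng.normal(size=(100, 5))     # placeholder feature vectors
y = rng.integers(0, 2, size=100)  # placeholder labels

# 80% training programs, 20% target programs.
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)

# 10-fold cross-validation over the training set for model selection;
# the held-out test set is used only for the final evaluation.
scores = cross_val_score(LogisticRegression(max_iter=1000), X_train, y_train, cv=10)
final_score = LogisticRegression(max_iter=1000).fit(X_train, y_train).score(X_test, y_test)
```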

For the deep learning model, we implemented the deep neural networks in Python with Keras [44]. We ran experiments on Google Colaboratory [45] with an Nvidia K80, T4, P4, or P100 graphics processing unit (GPU). Gensim [46] Word2vec was used to train the word embedding layer. Scikit-learn [47] provided the KNeighborsClassifier, RandomForestClassifier, MultinomialNB, and LogisticRegression algorithms as classifiers. Every experiment monitored the validation F1 score as an early-stopping condition with a patience of 10 epochs. Table 2 shows the parameters of the representation-learning phase.


**Table 2.** Tuned parameters for representation learning.

#### *4.3. Comparison of Different Embedding Methods*

We compared CBOW and Skip-gram to verify the effect of the embedding method. Different types of tokens were selected to test the methods, and their embedded results were projected into the two-dimensional diagrams shown in Figure 6. CBOW performed better. After embedding, semantically similar words lie closer to each other in the diagram, which means that word embedding extracts token semantics from the surrounding code structure; the information extracted by CBOW is more accurate than that extracted by Skip-gram.

**Figure 6.** CBOW and Skip-gram tokenization results.

#### *4.4. Comparison of Different Neural Networks*

We trained six neural network models on the CWE-119 dataset to evaluate different neural network models for representation learning. Note that we only indexed and sequenced the dataset rather than vectorizing it, so the training dataset is two-dimensional in this section. The models were: (1) three sequential CNNs with 128 filters each; (2) two long short-term memory (LSTM) layers with 128-dimensional outputs; (3) a bidirectional long short-term memory (BiLSTM) network with a 128-dimensional output; (4) a combined CNN (128) and BiLSTM (64 × 2); (5) a combined CNN, BiLSTM, and Attention model; and (6) three concatenated convolutional layers sequentially combined with one convolutional layer. To avoid vanishing gradients during RNN training, in networks that use LSTM we use sigmoid as the activation function of the last dense layer, which differs from our previous paper [15]. Table 3 shows the comparison results.

**Table 3.** Comparison of different representation learning models. CNN, convolutional neural network; BiLSTM, bidirectional long short-term memory; FPR, false positive rate; FNR, false negative rate; P, precision; R, recall; F1, F1-score.


Within the margin of error, sequential CNNs and concatenated CNNs achieved the best FPR results. The sequential LSTM showed balanced performance and achieved the best results in FNR, recall, and F1, along with excellent precision. Of the CNNs, the concatenated CNNs performed better. Therefore, we tested the embedding layer on the sequential LSTM and the concatenated CNNs.

### *4.5. Combination of Different Embedding Methods and Different Neural Networks*

According to [39], we divided pre-training into random, static, and non-static initialization, with the vector dimension set to 200. Random initialization means that all word vectors are randomly initialized and then modified during training. Static initialization means that all word vectors are pre-trained with Word2vec and kept non-trainable. Non-static initialization means that the pre-trained vectors are fine-tuned for each task. Using the CWE-119 dataset, we tested the different pre-training methods on the sequential LSTM and concatenated CNN models. For training the Word2vec embedding layer, we used CWE-119 as the corpus and SySeVR [48] data as the extended corpus. In this section, we also record the memory and training time of the models to compare their resource consumption. Table 4 shows the comparison results.

**Table 4.** Comparison of different embedding methods on different corpora. Memory, memory consumption; Time, time consumption.


As shown in Table 4, the CNN excelled in FNR and recall, while the LSTM excelled in FPR and precision. However, the LSTM's time consumption was 18 times that of the CNN. For both models, we drew the following conclusions. Regarding the corpus, the extended corpus yields better metrics, because the more words we train on, the more appropriate the obtained vectors. Regarding the false rates (FPR + FNR), P, R, and F1, trainable embedding outperforms static embedding because fine-tuning adapts the vectors to each task. The training memory is almost identical, because the input sample data and the embedding size were the same. Static embedding requires less time, because increasing the number of trainable parameters increases training time.

In conclusion, considering both results and efficiency, we chose the non-static CNN with the extended corpus as our final deep learning model.

### *4.6. Comparison of Different Classification Algorithms*

In Section 4.4, we observed that concatenated CNNs are the appropriate deep learning model for extracting features. In Table 5, the first configurations apply traditional machine learning directly after the word-embedding layer. To improve the classification results, we chose different ensemble learning models [49] to substitute for the simple sigmoid activation after the CNNs. We chose boosting and bagging as homogeneous ensemble models, namely the gradient boosting decision tree (GBDT) and random forest (RF). We used stacking to generate an ensemble of heterogeneous classifiers, with logistic regression (LR) and MultinomialNB (NB) as base classifiers and RF as the final classifier. For comparison with the ensemble classifiers, we also chose traditional classifiers, including KNeighbors (KN), NB, and LR. Table 5 shows the comparison results.

**Table 5.** Comparison of different classification algorithms. ML, machine learning; KN, KNeighbors; LR, logistic regression; NB, MultinomialNB; GBDT, gradient boosting decision tree; RF, random forest.


Table 5 shows that the first two lines, which did not use representation learning to extract features, produced poor classification results. Machine learning with CNNs performed better than traditional machine learning. We conclude that word embedding alone can only extract word-level features, whereas CNNs can capture features of code structure, not just word semantics; therefore, multi-granularity features help to improve classifier performance.

The results of the last three lines (CNN + Ensemble) were generally better than those of lines three to five. Although CNN + NB produced the best recall (93.9%), its precision was worse, at 85.4%, resulting in an F1 score, which represents comprehensive performance, of only 89.5%. Low precision means spending more effort and time on false detections. Therefore, ensemble learning can further improve vulnerability detection. Of the three ensemble methods, the stacking used in this article yielded the best results, because it combines multiple diverse algorithms to generate several different hypotheses in the feature space and achieve the most accurate result possible. Although its time consumption is higher than that of traditional machine learning methods, for vulnerability detection tasks we prioritize detection quality, so the increased time consumption is within an acceptable range.

Above all, through our experiments we selected the most appropriate implementation of PreNNsem: non-static pre-training with an extended corpus, concatenated-CNN representation learning, and a stacking classifier.

### *4.7. Ability to Detect Different Vulnerabilities*

As shown in Table 6, the proposed method was applied to six datasets. To evaluate our method's detection ability for different types of vulnerabilities, we tested our model on the buffer overflow (CWE-119) dataset and the resource management error (CWE-399) dataset. To validate our approach's generalization capabilities, we selected three further vulnerability datasets: Array Usage, API Function Call, and Arithmetic Expression, each containing multiple CWE vulnerabilities. Array Usage (87 CWE IDs) covers vulnerabilities related to arrays (e.g., improper array element access, array address arithmetic, and address transfer as a function parameter). API Function Call (106 CWE IDs) covers vulnerabilities related to library/API function calls. Arithmetic Expression (45 CWE IDs) covers vulnerabilities related to improper arithmetic expressions (e.g., integer overflow). Finally, we combined the three to form a hybrid dataset, Hybrid Vulnerabilities.


**Table 6.** Comparison of different classification algorithms.

According to the results, the method performs well when detecting specific vulnerabilities: resource management error yields the best F1 score, at 98.6%. Our approach also performs well in detecting the same type of vulnerability across datasets; API Function Call has the lowest F1 score, but it is still no less than 91.5%. The method performed better on the hybrid vulnerability dataset than on the single-type datasets, because more data improve the model's metrics. In summary, our approach performs well on a variety of datasets.

### *4.8. Comparative Analysis*

We compared our best experimental results with those of state-of-the-art methods to verify the performance of the proposed method. We chose the open-source static analysis tool Flawfinder [2], the commercial static analysis tool Checkmarx [17], the vulnerable code clone detection tool VUDDY [50], and the academic deep learning methods VulDeePecker [13], DeepSim [32], and VulSniper [29]. Our three reasons for selecting these were: (1) these tools represent the state of the art in static analysis for vulnerability detection; (2) they operate directly on source code; and (3) they were available to us. Flawfinder and Checkmarx represent manual methods based on static analysis. VUDDY is suited to detecting vulnerabilities incurred by code cloning. VulDeePecker, DeepSim, and VulSniper use deep learning to analyze source code. All of the results in Table 7 are based on the CWE-119 dataset. The results for Checkmarx and VulDeePecker were obtained from [13], and those for DeepSim and VulSniper from [29].


**Table 7.** Comparison of experimental results obtained using the proposed method and those using state-of-the-art methods.

Our method outperformed the state-of-the-art methods. Because the traditional tools depend on rule bases, they incurred high error rates (FPR and FNR) and lower precision, recall, and F1. VulDeePecker performed better than the other tools, with a precision of 91.7%; however, its recall was low, at only 82.0%, because it does not expand the corpus during the word-embedding phase. DeepSim and VulSniper extract features from intermediate code, which loses some information, so neither achieves strong precision or recall. Our method automatically extracts vulnerability features directly from the sliced source code and does not rely on a rule library, and in the word-embedding phase we expand the corpus to obtain richer semantics, which improves vulnerability detection. Compared to VulDeePecker, we improved FPR by 1.4%, FNR by 10%, precision by 3.7%, recall by 9.9%, and F1 by 7%.
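For reference, the metrics compared in Table 7 all follow from the confusion-matrix counts; a minimal sketch (the counts below are illustrative, not values from the paper):

```python
def detection_metrics(tp, fp, fn, tn):
    """Standard confusion-matrix metrics used for vulnerability detection."""
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)                  # equals 1 - FNR
    f1 = 2 * precision * recall / (precision + recall)
    fpr = fp / (fp + tn)                     # false positive rate
    fnr = fn / (fn + tp)                     # false negative rate
    return {'P': precision, 'R': recall, 'F1': f1, 'FPR': fpr, 'FNR': fnr}

# Illustrative counts only: 90 true positives, 10 false positives, etc.
metrics = detection_metrics(tp=90, fp=10, fn=10, tn=90)
```

Note that improving FNR directly improves recall, which is why corpus expansion during word embedding shows up in both columns of the comparison.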

### **5. Conclusions**

In this paper, we analyzed the core problem of existing vulnerability detection methods: the lack of proper feature extraction. We first surveyed vulnerability detection methods based on deep learning and then presented the PreNNsem framework, which detects vulnerabilities by analyzing source code. We drew insights from the collected dataset, including comparisons of word embeddings, deep learning models, and classifiers for vulnerability detection. We used six different vulnerability datasets to demonstrate our method's generalization ability. Finally, we compared our results with those of state-of-the-art tools and academic methods to validate the improvement in vulnerability detection.

In terms of practicality, we summarize as follows: (i) our method performs well on various mixed vulnerability datasets and can detect a wide range of vulnerability types; (ii) because we analyze source code as text, our framework can also be applied to source code in other high-level languages; and (iii) each part of PreNNsem supports alternative methods, demonstrating the framework's scalability.

However, our method has several limitations: (i) it focuses only on source programs and cannot be applied to executable programs; (ii) it relies on VulDeePecker's [13] code slicing, for which a replacement will be proposed and integrated into the framework in future work; (iii) although we evaluated several deep learning models, others remain to be assessed; and (iv) samples shorter than the fixed length are padded and longer samples are truncated, so future work should investigate how to handle variable-length vectors.

**Author Contributions:** Conceptualization, L.W. and X.L.; methodology, L.W. and X.L.; software, L.W.; validation, L.W.; formal analysis, L.W.; investigation, L.W. and R.W.; resources, L.W.; data curation, L.W.; writing—original draft preparation, L.W.; writing—review and editing, L.W.; visualization, L.W.; supervision, L.W.; project administration, L.W.; funding acquisition Y.X., Y.C., and M.G. All authors have read and agreed to the published version of the manuscript.

**Funding:** This work was supported by the Major Scientific and Technological Special Project of Guizhou Province (20183001), the Foundation of Guizhou Provincial Key Laboratory of Public Big Data (No. 2018BDKFJJ021), the Foundation of Guizhou Provincial Key Laboratory of Public Big Data (No. 2017BDKFJJ015), and the National statistical scientific research project of China (2018LY61, 2019LY82).

**Conflicts of Interest:** The authors declare no conflict of interest.

### **References**


**Publisher's Note:** MDPI stays neutral with regard to jurisdictional claims in published maps and institutional affiliations.

© 2020 by the authors. Licensee MDPI, Basel, Switzerland. This article is an open access article distributed under the terms and conditions of the Creative Commons Attribution (CC BY) license (http://creativecommons.org/licenses/by/4.0/).

## *Article* **Using BiLSTM Networks for Context-Aware Deep Sensitivity Labelling on Conversational Data**

### **Antreas Pogiatzis \* and Georgios Samakovitis \***

School of Computing and Mathematical Sciences, University of Greenwich, Old Royal Naval College, Park Row, Greenwich, London SE10 9LS, UK

**\*** Correspondence: a.pogiatzis@greenwich.ac.uk (A.P.); g.samakovitis@greenwich.ac.uk (G.S.)

Received: 31 October 2020; Accepted: 08 December 2020; Published: 14 December 2020

**Abstract:** Information privacy is a critical design feature for any exchange system, with privacy-preserving applications requiring, most of the time, the identification and labelling of sensitive information. However, privacy and the concept of "sensitive information" are extremely elusive terms, as they are heavily dependent upon the context they are conveyed in. To accommodate such specificity, we first introduce a taxonomy of four context classes to categorise relationships of terms with their textual surroundings by meaning, interaction, precedence, and preference. We then propose a predictive context-aware model based on a Bidirectional Long Short Term Memory network with Conditional Random Fields (BiLSTM + CRF) to identify and label sensitive information in conversational data (multi-class sensitivity labelling). We train our model on a synthetic annotated dataset of real-world conversational data categorised in 13 sensitivity classes that we derive from the P3P standard. We parameterise and run a series of experiments featuring word and character embeddings and introduce a set of auxiliary features to improve model performance. Our results demonstrate that the BiLSTM + CRF model architecture with BERT embeddings and WordShape features is the most effective (F1 score 96.73%). Evaluation of the model is conducted under both temporal and semantic contexts, achieving a 76.33% F1 score on unseen data and outperforming Google's Data Loss Prevention (DLP) system on sensitivity labelling tasks.

**Keywords:** BiLSTM; BERT; NLP; context-aware

### **1. Introduction**

Diminishing information privacy in communication ecosystems often requires consumers to manage their own preferences directly. Yet this easily becomes a tedious task given the amount of information exchanged on a daily basis. Research indicates that consumers most often lack the information to take appropriate privacy-aware decisions [1], and even where sufficient information is available, long-term privacy is traded off for short-term benefits [2] (indicatively, a Facebook-focused empirical user study reported that 20% of participants and 50% of interviewees expressed direct regret after posting sensitive information publicly [3]). Automatically identifying sensitive information is therefore critical for privacy-preserving technologies. Sensitive data most often appear as unstructured text, which is vulnerable to privacy threats due to its parseable and searchable format. This work concentrates solely on text data.

The sensitivity of a piece of information is directly shaped by the context it is provided in. For the purposes of this analysis, we offer a taxonomy of four distinct context classes, for use as sensitivity categories. We derive these respectively by the meaning, interaction, precedence, and preference associated with any piece of information:

**Semantic Context:** formed on the basis of the semantic meaning of a term. As, for instance, in the case of homonyms or homographs, the semantic meaning of a sequence affects its sensitivity.

**Agent Context:** depending upon the agents participating in the transmission of information. Here, the relationship between participating actors determines sensitivity. For instance, patient–doctor sharing of medical history is non-sensitive for this set of actors, but sensitive otherwise.

**Temporal Context:** defined by information precedence that affects the significance of a term. Here, previously introduced definitions of a term qualify its sensitivity. If, for example, a string sequence is introduced as a password, it carries higher sensitivity than if it was introduced as a username.

**Self Context:** defined by the user's personal privacy preferences. Notably, what is considered private varies across users due to cultural influences, personal experiences, professional statute, etc. For example, ethnic origin may be considered sensitive information for one individual but not for another.

The above list of context classes is not exhaustive, nor are the classes mutually exclusive: one or more contexts may simultaneously influence sensitivity in different ways. Some information is, however, sensitive regardless of context (e.g., credit-card numbers, email addresses, passwords, insurance numbers). In this work, we define sensitive information as any coherent sequence of textual data from which a third party can elicit information that falls under the categories proposed in the P3P standard (Table 1).


**Table 1.** Sensitivity classes used for multi-class classification (source: P3P standard [4]), and the count of sensitive tokens in our dataset per class.

This paper proposes a context-aware predictive deep learning model that can annotate sensitive tokens in a sequence of text data, more formally defined as a "token sensitivity labelling" task. We develop our models for the *semantic* and *temporal* contexts, as this provides an adequate proof-of-concept, and incorporating all four contexts requires extraneous methodologies that are planned for future work. We reduce the problem of sensitivity annotation to a multi-class classification problem and follow deep learning techniques proven effective in similar labelling tasks, such as Named Entity Recognition (NER) and Part-of-Speech (POS) tagging [5–7]. We develop a context-aware classifier based on the BiLSTM + CRF architecture with word embeddings and WordShape as features. We address dataset limitations by developing data generation algorithms that combine synthetic and real data, and we experimentally identify the best performing model architecture and feature combination. The models are first evaluated on a dataset generated with the assistance of synthetic sensitive data, reaching an F1 score of 96.73%. We then evaluate our model in two settings: (i) one addressing sensitive information annotation under the influence of temporal context, and (ii) one comparing against Google's DLP system under semantic context variations. In the former, our model reaches an F1 score of 76.33%; in the latter, the results highlight our system's resilience to semantic noise, outperforming Google's DLP across all sensitive information types. We summarise the key contributions of this work as:


Notably, our work aims to investigate the performance of our BiLSTM + CRF architecture for context-aware token sensitivity labelling, as opposed to discovering the optimal model for the task; this is clearly reflected in the experimental design, where variants of that model are evaluated in temporal and semantic contexts. In the absence of similar research on more complex architectures, comparisons with other similar BiLSTM-based architectures are performed to initially investigate the behaviour of simpler models. The paper is organised into eight sections: Section 2 provides the background and related work, followed by a discussion of our methodological approach and model background (Section 3); Section 4 is devoted to our dataset creation strategy, followed by our experimental design (Section 5). Implementation and results are outlined in Section 6, and the model is evaluated in Section 7. A discussion of our contributions, limitations and future work is offered in Section 8.

### **2. Related Work**

The majority of the literature on sensitivity labelling is associated with Data Loss Prevention (DLP) systems [8–12], notably focusing on classifying sensitivity at the document level. Other research uses sensitivity classification for confidential information redaction on declassified documents [13–15], where classification is often performed at finer granularity that reaches the token level. More general applications include quantifying information leakage in Open Social Networks (OSNs) for privacy-preserving technologies [16,17].

Earlier work on text sensitivity annotation focused on heuristics. Sweeney [18] introduced a template-matching approach with boolean hashtables to capture Personally Identifiable Information (PII) in medical records with 99–100% accuracy; however, this approach struggles with unstructured data and requires manual work to be extended to other domains. Gomez-Hidalgo et al. [11] used NER to pinpoint sensitive tokens in a corpus. Although their assumptions about the sensitive nature of named entities are reasonable, static NER is context-free and only captures a limited part of sensitive content. Sanchez et al. [19] presented an information-theoretic approach by introducing the concept of information content (IC), defined later in Section 3.4. They annotated as sensitive any noun phrase with an IC value higher than a threshold *β*. Again, this is a context-free approach, and using IC alone as a sensitivity measure is problematic in some cases, as it is directly related to the size and content of the corpus used.

A different approach to text sensitivity annotation uses statistical machine learning models. Hart et al. offered a DLP system that classifies sensitive enterprise documents using a Support Vector Machine (SVM) classifier trained on a WikiLeaks-based corpus [10]. Later, MacDonald et al. built a novel SVM sensitivity classifier mixing concepts from NLP and machine learning for government document declassification, using POS n-gram tags as a sensitivity-load indicator [15]. Alzhrani et al. [9] proposed another DLP system with finer granularity; their work effectively combined unsupervised and supervised methods to create a similarity-based classifier operating at the paragraph level and trained on an ad hoc annotated WikiLeaks corpus. Building on their previous work, MacDonald et al. further enhanced their SVM classifier by introducing pre-trained Word2Vec and GloVe word embeddings [20,21] and found that word embeddings can contribute significantly to a more accurate model. Our work is closest to their approach.

Research using deep learning for textual sensitivity annotation is relatively sparse, with only a few authors moving to that direction. Ong et al. built a context-aware DLP system, which follows a hierarchical structure to achieve fine-grained granularity [8], and used LSTM neural networks to achieve binary sensitivity classification at the token level. Despite the novelty of their hierarchical approach, we argue that the dataset size used for the experiments may fall short of the requirements for deep learning applications. Jiang et al. [17] used LSTM networks for identifying personal health experiences from tweets. Similarly, previous work underlines the high utility of word embeddings and LSTM/BiLSTM networks for textual data mining and classification through social media posts and other sources [22–24]. Although these are framed in other problem domains, their results highlight the advantages of deep learning methodologies against conventional machine learning models for similar tasks.

### **3. Background and Approach**

We assess the performance of specific BiLSTM variants in classifying and labelling information sensitivity in a particular context. Token relationships in this problem are sequential; we therefore focus on sequential models such as BiLSTM and CRF, with supervised training. We chose a bidirectional LSTM over a simple LSTM to better model temporal nuances in a corpus, as BiLSTMs perform forward and backward passes on sequential data and model data dependencies in both directions. Finally, we introduce auxiliary features, namely POS tags, Information Content (IC) and WordShape (WS).

### *3.1. LSTM Networks*

First, we define the BiLSTM Recurrent Neural Network (RNN) more formally. A recurrent neural network is a special type of artificial neural network (ANN) capable of modelling sequential data through recurrent connections [25]. In essence, it maintains a hidden state that can be considered a "memory" of previous inputs; the hidden state is, in effect, a function of all previous inputs.

Figure 1 illustrates the architecture of a simple RNN. The input units $\{..., x\_{t-1}, x\_t, x\_{t+1}, ...\}$, where $x = (x\_1, x\_2, x\_3, ..., x\_N)$, are connected to the hidden units $h\_t = (h\_1, h\_2, ..., h\_M)$ in the hidden layer via connections defined by the weight matrix $W\_{IH}$. Every hidden unit is connected to the next one by recurrent connections given by $W\_{HH}$. Each hidden unit is therefore given by:

$$h\_t = f\_H(o\_t) \tag{1}$$

where:

$$
o\_t = W\_{IH} x\_t + W\_{HH} h\_{t-1} + b\_h \tag{2}
$$

$f\_H$ is a non-linear function such as tanh, ReLU or sigmoid, and $b\_h$ is the bias vector. The hidden layer is also connected to the output layer with weights $W\_{HO}$. Lastly, the outputs $y\_t = (y\_1, y\_2, ..., y\_P)$ are defined by:

$$y\_t = f\_O(W\_{HO} h\_t + b\_o) \tag{3}$$

**Figure 1.** Simple Recurrent Neural Network (RNN) architecture. For the sake of simplicity biases are ignored. Source: [26].

As for the hidden layer, $f\_O$ is the activation function and $b\_o$ is the bias vector.
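Equations (1)–(3) can be sketched directly in NumPy; the dimensions, random initialisation, and choice of tanh for both activations below are illustrative assumptions, not values from the paper:

```python
import numpy as np

def rnn_step(x_t, h_prev, W_IH, W_HH, W_HO, b_h, b_o):
    """One step of the simple RNN, following Eqs. (1)-(3)."""
    o_t = W_IH @ x_t + W_HH @ h_prev + b_h      # Eq. (2)
    h_t = np.tanh(o_t)                          # Eq. (1), with f_H = tanh
    y_t = np.tanh(W_HO @ h_t + b_o)             # Eq. (3), with f_O = tanh
    return h_t, y_t

rng = np.random.default_rng(0)
N, M, P = 4, 3, 2                               # input, hidden, output sizes (illustrative)
W_IH = rng.normal(size=(M, N))
W_HH = rng.normal(size=(M, M))
W_HO = rng.normal(size=(P, M))
b_h, b_o = np.zeros(M), np.zeros(P)

h = np.zeros(M)                                 # initial hidden state
for x in rng.normal(size=(5, N)):               # a toy sequence of 5 timesteps
    h, y = rnn_step(x, h, W_IH, W_HH, W_HO, b_h, b_o)
```

The loop makes the "memory" explicit: each `h` feeds the next step, so the final state depends on the entire sequence.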

Although this model maintains a memory of previous states, in practice it suffers from the vanishing gradient problem, making it impractical for long-term dependencies [27]. A special type of RNN, Long Short Term Memory (LSTM), published in 1997, overcomes this issue [28]. LSTM cells follow a more sophisticated mechanism, introducing a complex cell that uses "forget" gates to selectively choose what to retain. An illustration of the LSTM is given in Figure 2.

**Figure 2.** Long Short Term Memory (LSTM) unit. For the sake of simplicity biases are ignored.

The state of an LSTM memory unit adopts the following mathematical formulation:

$$\begin{aligned} i\_t &= \sigma(W\_{xi} x\_t + W\_{hi} h\_{t-1} + W\_{ci} c\_{t-1} + b\_i) \\ f\_t &= \sigma(W\_{xf} x\_t + W\_{hf} h\_{t-1} + W\_{cf} c\_{t-1} + b\_f) \\ c\_t &= f\_t \otimes c\_{t-1} + i\_t \otimes \tanh(W\_{xc} x\_t + W\_{hc} h\_{t-1} + b\_c) \\ o\_t &= \sigma(W\_{xo} x\_t + W\_{ho} h\_{t-1} + W\_{co} c\_t + b\_o) \\ h\_t &= o\_t \otimes \tanh(c\_t) \end{aligned}$$

To clarify, the subscripts correspond to the initials of what each matrix represents (e.g., $W\_{hf}$ is the hidden-forget weight matrix). Additionally, $f\_t$, $i\_t$, $o\_t$ and $c\_t$ correspond to the forget, input, output and cell gate vectors. With these in mind, an LSTM network resembles the RNN architecture above, but with LSTM cells in place of simple hidden units.

However, due to its architecture, an LSTM network can only perform forward passes over sequential data, meaning that dependencies are modelled only uni-directionally. An intuitive way to overcome this limitation is to add an exact replica of the LSTM network running in reverse. Combining the two yields a Bidirectional LSTM (BiLSTM), which models dependencies in both directions.
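The LSTM equations above, and the forward/backward combination that yields a BiLSTM, can be sketched in NumPy. For brevity the two directions share one set of weights here; in a real BiLSTM each direction has its own parameters, and all sizes below are illustrative:

```python
import numpy as np

def lstm_step(x_t, h_prev, c_prev, W, b):
    """One LSTM step with peephole connections, following the equations above.
    W is a dict of weight matrices keyed by the subscripts used in the text."""
    sig = lambda z: 1.0 / (1.0 + np.exp(-z))
    i = sig(W['xi'] @ x_t + W['hi'] @ h_prev + W['ci'] @ c_prev + b['i'])
    f = sig(W['xf'] @ x_t + W['hf'] @ h_prev + W['cf'] @ c_prev + b['f'])
    c = f * c_prev + i * np.tanh(W['xc'] @ x_t + W['hc'] @ h_prev + b['c'])
    o = sig(W['xo'] @ x_t + W['ho'] @ h_prev + W['co'] @ c + b['o'])
    h = o * np.tanh(c)
    return h, c

def bilstm(xs, M, W, b):
    """Concatenate a forward and a backward pass over the sequence xs."""
    def run(seq):
        h, c, out = np.zeros(M), np.zeros(M), []
        for x in seq:
            h, c = lstm_step(x, h, c, W, b)
            out.append(h)
        return out
    fwd = run(xs)
    bwd = run(xs[::-1])[::-1]       # reverse pass, re-aligned to input order
    return [np.concatenate([f, r]) for f, r in zip(fwd, bwd)]

rng = np.random.default_rng(1)
N, M = 4, 3                          # input and hidden sizes (illustrative)
W = {k: rng.normal(scale=0.1, size=(M, N if k[0] == 'x' else M))
     for k in ['xi', 'hi', 'ci', 'xf', 'hf', 'cf', 'xc', 'hc', 'xo', 'ho', 'co']}
b = {k: np.zeros(M) for k in 'ifco'}
outputs = bilstm(list(rng.normal(size=(5, N))), M, W, b)
```

Each output vector is the concatenation of the two directions, so every token's representation reflects both its left and right context.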

### *3.2. Embeddings*

Embeddings provide an efficient mechanism for encoding semantic and temporal context information. They have proven very effective in practical neural network applications for encoding complex data structures into information-rich continuous vectors in a latent space. Examples include mappings from words to vectors [20,21,29], documents to vectors [30], graphs to vectors [31], etc. Word vectors in particular have become the norm in neural networks for Natural Language Processing (NLP). We address two main types of embeddings: context-free and contextual. Context-free embeddings are static and do not capture any context about the word. Conversely, a contextual embedding encapsulates information about the word's surrounding context, assigning an embedding vector $v\_i$ for each of the $i$ contexts in the embedding space.
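As a toy illustration of the context-free case, a static lookup table returns one fixed vector per word, with unknown words mapped to a shared out-of-vocabulary vector (the vocabulary, dimensionality and values here are illustrative):

```python
import numpy as np

rng = np.random.default_rng(3)
vocab = {"bank": rng.normal(size=8), "river": rng.normal(size=8)}
oov = np.zeros(8)  # shared vector for out-of-vocabulary tokens

def embed(token):
    """Context-free lookup: 'bank' maps to the same vector in every sentence."""
    return vocab.get(token.lower(), oov)

# A contextual embedder (e.g., BERT) would instead return different vectors
# for "bank" in "river bank" vs. "bank account", conditioned on the sentence.
```

The limitation is visible immediately: homographs collapse onto one vector, which is exactly the gap contextual embeddings close.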

### *3.3. Conditional Random Fields*

A Conditional Random Field (CRF) is a discriminative model mostly used for labelling and segmenting sequential data [7]. The underlying concept of CRFs is to model a conditional probability distribution over a label sequence given an already observed sequence. Their use in this work is inspired by state-of-the-art results of neural CRFs [32], particularly combined with BiLSTMs in fundamental sequence-labelling NLP tasks such as NER and POS tagging [6,33]. Neural CRFs use neural networks to extract high-level features that serve as inputs to the CRF for labelling. A CRF models the conditional distribution P(y|x) of a label sequence y given an observed sequence x, rather than the joint distribution P(x,y), and thus outperforms traditional machine learning models, such as Hidden Markov Models and Maximum Entropy Markov Models (HMMs, MEMMs) [7], in sequential labelling.
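For illustration, label decoding in a linear-chain CRF is commonly performed with the Viterbi algorithm over per-token emission scores (e.g., produced by the BiLSTM) and a label-transition score matrix; a minimal sketch with illustrative dimensions:

```python
import numpy as np

def viterbi_decode(emissions, transitions):
    """Most likely label sequence under a linear-chain CRF.
    emissions: (T, K) per-token label scores; transitions: (K, K) score
    for moving from label j to label k. Returns a list of T label indices."""
    T, K = emissions.shape
    score = emissions[0].copy()
    backptr = np.zeros((T, K), dtype=int)
    for t in range(1, T):
        cand = score[:, None] + transitions      # (K, K): prev label -> next label
        backptr[t] = cand.argmax(axis=0)         # best predecessor per next label
        score = cand.max(axis=0) + emissions[t]
    path = [int(score.argmax())]
    for t in range(T - 1, 0, -1):                # follow back-pointers
        path.append(int(backptr[t][path[-1]]))
    return path[::-1]

rng = np.random.default_rng(2)
em = rng.normal(size=(6, 4))                     # 6 tokens, 4 labels (illustrative)
tr = rng.normal(size=(4, 4))
labels = viterbi_decode(em, tr)
```

With zero transition scores this reduces to per-token argmax; the transition matrix is what lets the CRF veto label sequences that are individually likely but jointly implausible.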

### *3.4. Auxiliary Features (POS, IC, WS)*

Although relying solely on word embeddings for prediction can deliver acceptable performance, we propose three additional features to enrich the learning significance of the semantic and syntactic abstraction of word embeddings: POS tags, Information Content (IC) and WordShape (WS).

MacDonald et al. [15] showed that specific POS n-gram sequences can be correlated with sensitive information, which provided an incentive to use POS tags as auxiliary features in this work. We extract the POS tags of the sequences using spaCy's v2.1 (https://github.com/explosion/spaCy) POS tagger, chosen for its speed, industrial strength and convenience.

Information Content (IC) offers a quantitative metric for general-purpose sensitivity [19]. IC estimates the information carried by a specific token *t* in a given context, relying on the information-theoretic assumption that rare terms typically convey more information than general terms (e.g., "surgeon" vs. "doctor"). We therefore expect that incorporating IC in our model can better classify sensitive tokens of particular classes. An example is labelling passwords: due to their random nature, the embedding of a password will most often resolve to the Out-Of-Vocabulary (OOV) embedding; hence, apart from the surrounding context, nothing differentiates it from other OOV tokens. With this in mind, introducing IC features can be advantageous. As IC directly depends on the size and content of the corpus used to calculate it, a massive general corpus is required for general-purpose estimations. To extract the IC of tokens we used the Bing search engine API (https://azure.microsoft.com/en-gb/services/cognitive-services/bing-web-search-api/). Notably, Google could be more accurate, since it maintains the largest and most up-to-date page index; however, it was not possible to use Google's search API for this project due to its restrictions on repetitive use. As suggested by Sanchez et al. [19], only nouns are queried for IC extraction, as the other part-of-speech types have a dynamic meaning that makes search engine queries unreliable for IC calculation; those tokens are assigned an IC value of 0 by default.
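Following Sanchez et al.'s information-theoretic definition, IC can be computed from an estimated term probability, with the probability approximated by web hit counts. A minimal sketch; the hit counts and index size below are hypothetical stand-ins for values returned by the search API:

```python
import math

TOTAL_INDEXED_PAGES = 4e10  # hypothetical index size used as the normaliser

def information_content(hit_count, total=TOTAL_INDEXED_PAGES):
    """IC(t) = -log2(p(t)), with p(t) estimated as hits(t) / total.
    Rarer terms yield a higher IC; non-noun tokens receive 0 by default."""
    if hit_count <= 0:
        return 0.0
    return -math.log2(hit_count / total)

# Hypothetical hit counts: the rarer "surgeon" carries more information
ic_doctor = information_content(9e8)
ic_surgeon = information_content(1e8)
```

This makes the corpus-dependence explicit: IC values are only comparable when computed against the same index.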

Lastly, we introduced a morphological word feature, *WordShape*, in our experiments. We use the term *WordShape* for the textual representation of a word's morphology, implemented by transforming words into character-sequence templates. This can contribute to learning sensitivity correlations with structured data such as credit-card numbers, national insurance numbers, and phone numbers. To generate the WordShape features, spaCy's v2.1 parser was used (https://spacy.io/usage/linguistic-features).
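A simplified sketch of such a shape transform (an approximation in the spirit of spaCy's `shape_` feature, not its exact implementation):

```python
import re

def word_shape(token, max_run=4):
    """Map characters to a shape template: uppercase -> 'X', lowercase -> 'x',
    digits -> 'd'; other characters are kept as-is. Runs longer than
    `max_run` are truncated, so long tokens share compact templates."""
    shaped = []
    for ch in token:
        if ch.isupper():
            shaped.append('X')
        elif ch.islower():
            shaped.append('x')
        elif ch.isdigit():
            shaped.append('d')
        else:
            shaped.append(ch)
    # collapse any run longer than max_run down to max_run characters
    return re.sub(r'(.)\1{%d,}' % max_run,
                  lambda m: m.group(1) * max_run, ''.join(shaped))
```

For example, `word_shape("4111-1111-1111-1111")` yields `"dddd-dddd-dddd-dddd"`, so all 16-digit card numbers share one template regardless of their digits.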

#### **4. Dataset Creation**

Public annotated datasets for token sensitivity labelling are rare. Even where such data can be collected, the subjects' privacy is at risk through deanonymisation [34]. We therefore generate synthetic data for training purposes. Training deep learning models on synthetic data often comes with generalisability challenges due to overfitting, although recently several scholars have successfully trained such models on synthetic data in real-world settings [35]. In this paper, we developed a methodology (Section 4.2) to generate a sufficiently large synthetic annotated dataset, combining real-world conversational data with random sensitive information. At the highest level, two data generation approaches are used: (1) one for sensitivity classes that are redacted by default and (2) one for sensitivity classes that are not. For the former, consider the data generation process for the "Online contact information" sensitivity class, with "email" as the topic. We populate a list of conversational patterns relevant to that topic ('my email . . . ', 'You can contact me at', . . . ) and use that list to search Reddit threads using Google's BigQuery. We then generate a synthetic concrete value for the sensitive part (i.e., the email address), combine it with the conversational pattern used for the query, and inject it into the results as part of the conversation (comment chain). Because Google BigQuery redacts sensitive terms, we cannot ascertain their original position; hence, the modified sentence arrays are injected at a random index. In the latter case (sensitivity classes not redacted by default), we follow a slightly different approach. Again, we use the related conversational pattern to search Reddit threads, but this time the sensitive part is already included in the comments after the pattern. We therefore generate concrete values beforehand and include them in the filtering process. For instance, the "Religion" topic has a pattern "I believe in" and concrete values "Christianity", "Buddhism", etc., which are used as part of the query. We then annotate the position of the sensitive concrete value and expand it until the next verb or noun in the sentence. Lastly, we annotate the conversational pattern "I believe in" along with the next noun. This is not a fail-proof methodology, but it generated an acceptable dataset, which we test in our experiments. Our synthetic datasets are derived from real conversational data from Reddit. Where used, manual annotations are part of the training process, much like human annotation in object recognition. With this approach, we increase the size of the dataset with far less effort than gathering more real-world data would require, and we also mitigate sensitivity class imbalance [36–38]. We then test our proposed dataset on real unseen data.

### *4.1. Sensitivity Classes*

To increase the semantic significance of our sensitivity classification, we used 13 distinct sensitivity classes based on the categories specified in W3C's Platform for Privacy Preferences (P3P) [4], presented in Table 1. Note that the P3P specification originally defined 17 categories, but we intentionally omitted four as they were very open-scoped: *Navigation and Click-stream Data*, *Interactive Data*, *Content* and *Other*. This made it easier to automate the aggregation of data for unambiguously defined sensitivity classes during dataset creation. It also allowed us to examine model performance for each sensitivity class individually and potentially extract more relevant insights. We chose the W3C P3P classes to (i) leverage the existing legal and social expertise that informed the development of the platform, and (ii) support the openness and extensibility of our methods by allowing third-party user applications to be built on top of our work. Since P3P works with other standardised languages, such as APPEL (A P3P Preference Exchange Language), a third-party automated process can use such a language to trigger an action upon leakage of annotated sensitive data.

### *4.2. Data Generation Process*

For our synthetic data generation, we extracted real-world text data from a main discussion theme and then injected sensitive data at random positions. For dataset creation purposes, sensitivity classes are further divided into relevant *topics*, achieving higher granularity: for instance, the *Financial Information* sensitivity class includes topics like *Payment History*, *Credit Card Numbers*, *Account Balance*, etc. The process was repeated per topic in each sensitivity class. Because the source used for the real-world data redacts sensitive information by design, the injected text was automatically annotated as sensitive and the rest of the corpus as non-sensitive. To also cover sensitivity classes that leak secondary private information (e.g., *Preference*), we developed an alternative algorithm for annotating such less obvious sensitive data. A high-level flow chart of both algorithms is shown in Figure 3.

Note that the algorithms require three distinct datasets: (i) the sensitive text patterns of the topic (also used to query the data in the first place); (ii) the concrete values; and (iii) the conversational data associated with the topic. To aggregate real conversational data from the web, we utilised the publicly available Reddit comments dataset provided by Google's BigQuery (https://bigquery.cloud.google.com/dataset/fh-bigquery:reddit\_comments). BigQuery's rich querying capabilities allowed convenient filtering for specific topics by relevant keywords.

**Figure 3.** Synthetic data generation (i) when sensitive data is redacted by default (**left**) and (ii) when it is not (**right**).

Figure 3 (left) outlines the algorithm for augmenting our dataset with artificially created sentences that include at least one sensitive term. A sensitive term is encoded as a unique token and then replaced with a concrete value. Concrete values are the actual mock values that make the sentence sensitive, for example, passwords, email addresses and physical addresses. These were randomly generated using Mockaroo (https://mockaroo.com/), a generation engine for realistic data. Then, the Reddit comments are split into sentences and the newly created sensitive sentence is injected at a random index within the sentence array. That sentence is annotated as sensitive and the remaining sequence as non-sensitive. The algorithm in Figure 3 (right) annotates sensitive information already present in Reddit comments. We selectively built a collection of phrases corresponding to a topic in a sensitivity class and then used these to query the Reddit comments. The phrases are then matched per comment, and all tokens from the matching phrase until the next verb or noun are annotated as sensitive.
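The left-hand algorithm can be sketched as follows; the function, labels and example values are a simplified illustration (the actual pipeline uses Mockaroo-generated concrete values and BigQuery results):

```python
import random

def inject_sensitive(comment_sentences, pattern, concrete_value, rng):
    """Build a sensitive sentence from a conversational pattern and a mock
    concrete value, insert it at a random index in the comment's sentence
    array, and emit a per-sentence sensitivity label for each position."""
    sensitive = f"{pattern} {concrete_value}"
    idx = rng.randrange(len(comment_sentences) + 1)
    sentences = comment_sentences[:idx] + [sensitive] + comment_sentences[idx:]
    labels = ['NON-SENSITIVE'] * len(sentences)
    labels[idx] = 'SENSITIVE'
    return sentences, labels

rng = random.Random(42)  # fixed seed for reproducibility of the example
sents, labs = inject_sensitive(
    ["That game was great.", "I agree completely."],
    "You can contact me at", "jane.doe@example.com", rng)
```

In the real pipeline the annotation is finer-grained (token-level rather than sentence-level), but the structure of the injection step is the same.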

After data generation, we obtained an annotated dataset of multiple sensitivity classes. Table 1 shows the number of annotated tokens per sensitivity class in the dataset. The classes are notably imbalanced because we introduced an upper-bound constraint on how many records could be generated per topic. While this constraint was imposed to avoid overfitting the model to one specific topic, the observed imbalance is realistic, as it is often seen in real-world datasets [39]. Overall, 12% of the tokens in the dataset are annotated as sensitive, with the remaining 88% labelled as non-sensitive.

#### **5. Experimental Design**

Our experiments attempt to answer three questions: (i) which combination of our BiLSTM + CRF model architectures and word embeddings is better suited for the task; (ii) whether character embeddings increase model accuracy; and (iii) whether IC, POS tag and WordShape features increase model performance. For that purpose, 11 distinct BiLSTM + CRF model variations were derived for experimentation, as shown in Table 2. In addition, five simpler models were selected for benchmarking against the main BiLSTM + CRF variants.

We chose the hyperparameters of the models based on empirical results [40] and preliminary experiments. The input sequences were trimmed to 205 timesteps (the maximum sequence length in the dataset) and 128 units of BiLSTM cells were used, followed by a dropout layer of 40%. For experiments with character embeddings, a character input sequence of length 32 was used in a 1D convolutional layer with a kernel of size 3, followed by a dropout layer of 50% and a Global Max Pooling layer. Training was performed in batches of 128 for up to 100 epochs, with an early stopping callback employed to interrupt training if the validation loss did not improve for 5 consecutive epochs.

The test subset consists of 10% of the dataset and was entirely left out for use as a completely unbiased evaluation dataset. The remaining 90% was further split into 80% training data (used to fit the classifier weights) and 20% validation data for tuning the hyperparameters. For preprocessing, the text was converted to lowercase, all contractions were expanded and punctuation was removed. Unlike conventional preprocessing pipelines, which involve a stopword removal stage, we decided to keep the stopwords, as they are part of our automated annotation process when creating the dataset.
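The preprocessing steps above (lowercasing, contraction expansion, punctuation removal, stopwords retained) could be sketched as follows; the contraction map here is a tiny illustrative subset, not the full table the pipeline would use.

```python
import re

# Illustrative subset of a contraction-expansion map (assumption).
CONTRACTIONS = {"don't": "do not", "i'm": "i am", "it's": "it is"}

def preprocess(text):
    text = text.lower()
    for short, full in CONTRACTIONS.items():
        text = text.replace(short, full)       # expand contractions
    text = re.sub(r"[^\w\s]", " ", text)       # strip punctuation
    return re.sub(r"\s+", " ", text).strip()   # note: stopwords are kept

out = preprocess("I'm sure it's fine, don't worry!")
# out == "i am sure it is fine do not worry"
```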

**EXPERIMENT A:** *Model design and word embeddings:* Interestingly, there is a wide range of BiLSTM applications in NLP, often featuring state-of-the-art results [41–43]. Similarly, the integration of BiLSTM with a CRF layer is also rapidly gaining research attention and has delivered promising results in NLP tasks [6,33,44]. For the above reasons, Experiment A focuses on reviewing the performance effect of word embeddings on BiLSTM + CRF models (see EXP. A5–A7 in Table 2). Four alternative variants (EXP. A1–A4) were also used to offer a basis for comparison.

Even though the existing literature has demonstrated the advantages of using contextualised word embeddings on numerous occasions [45–47], we perform this experiment to support this hypothesis for this problem setting as well. Of the many available word embedding extraction techniques [48], we shortlisted three methods for the evaluation, as a sufficient minimum to cover all pre-training and contextualisation possibilities. The first uses initially random vectors to derive embeddings in the training process, and is here referred to as Randomised Word Embedding (RWE). The vocabulary of the RWE embeddings was built on the training split of our dataset. RWE is later used as a baseline for comparison with pre-trained word embeddings. The remaining two choices were the pre-trained word embeddings GloVe and BERT [5,21], covering context-free and contextualised embeddings, respectively. GloVe is a popular context-free word embedding model, while BERT is Google's language model that achieved state-of-the-art performance in many standard NLP tasks. In the case of BERT, the output of the last encoder layer was used as embeddings.

**EXPERIMENT B:** *Character embeddings:* Character-level embeddings were successfully combined with word embeddings to improve performance before [49,50]. As the integration of character embeddings allows for learning language-agnostic morphological features, we attempt to quantify the resulting performance improvement, if any.

It has been shown that LSTM and Convolutional Neural Network (CNN) character embeddings exhibit similar performance improvements when combined with BiLSTM models, with a slight advantage for CNN [51]. We implement character embeddings through an additional extension model based on a one-dimensional convolution layer. The output of the extension model is concatenated with the input of the best-performing embedding types.
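The extension model described above can be illustrated in plain Python: a one-dimensional convolution of kernel size 3 slides over the per-character vectors of a word and the result is globally max-pooled into a fixed-size feature vector. The dimensions and weights below are toy values for illustration, not the model's trained parameters.

```python
# Toy sketch of the CNN character-embedding extension: a 1D convolution
# (kernel width 3) over character vectors, then global max pooling.
def conv1d_global_max(char_vecs, kernels):
    """char_vecs: list of d-dim character vectors for one word.
    kernels: list of filters, each a list of 3 d-dim weight vectors.
    Returns one max-pooled activation per filter."""
    pooled = []
    for k in kernels:
        activations = []
        for i in range(len(char_vecs) - 2):        # valid convolution
            window = char_vecs[i:i + 3]
            act = sum(w * x
                      for wvec, xvec in zip(k, window)
                      for w, x in zip(wvec, xvec))
            activations.append(act)
        pooled.append(max(activations))             # global max pooling
    return pooled

# One filter that responds to the first dimension of the middle character.
kernels = [[[0.0, 0.0], [1.0, 0.0], [0.0, 0.0]]]
word = [[0.1, 0.2], [0.9, 0.1], [0.3, 0.4], [0.2, 0.2]]
features = conv1d_global_max(word, kernels)
```

The pooled output would then be concatenated with the word embedding before entering the BiLSTM, mirroring the concatenation step described above.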

**EXPERIMENT C:** *Auxiliary features experiment:* Section 3.4 provides a detailed account of the three auxiliary features (POS, IC, WS) used to enhance model performance, and also articulates their sources and selection strategies. In the experimental setting, we introduce seven feature variations on top of the best-performing BiLSTM + CRF model architecture. The aim is to practically evaluate the contribution of these features (and their combinations) to performance. The variations are illustrated in Table 2 (bottom).


**Table 2.** Model variations and performance metrics of the experiments. (Top Section): Conditional Random Fields (CRF) layer and embeddings combination results. (Middle Section): Character embeddings model extension experiment results. (Bottom Section): Auxiliary features variations experiment results.

### **6. Results**

Micro-averaged Precision, Recall and F1 metrics have been used for performance evaluation as they are widely used in similar sequence labelling tasks [52–54] and perform better on imbalanced datasets [55]. Table 2 summarises the results for the entire set of experiments carried out, separated in three sections, with the corresponding models and their performance.
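Micro-averaging pools true positives, false positives and false negatives across all sensitivity classes before computing the ratios, so every token counts equally regardless of class frequency. A from-scratch sketch (with an assumed `NON_SENSITIVE` background label, as in this dataset):

```python
# Micro-averaged precision/recall/F1 over the non-background classes.
def micro_prf(y_true, y_pred, negative_label="NON_SENSITIVE"):
    tp = fp = fn = 0
    for t, p in zip(y_true, y_pred):
        if p != negative_label:
            if p == t:
                tp += 1          # correct sensitive prediction
            else:
                fp += 1          # wrong positive prediction
                if t != negative_label:
                    fn += 1      # and the true class was missed
        elif t != negative_label:
            fn += 1              # sensitive token predicted as background
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = (2 * precision * recall / (precision + recall)
          if precision + recall else 0.0)
    return precision, recall, f1

p, r, f = micro_prf(["EMAIL", "NON_SENSITIVE", "SSN", "SSN"],
                    ["EMAIL", "EMAIL", "NON_SENSITIVE", "AGE"])
```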

**EXPERIMENT A:** *Model design and word embeddings:* As a baseline, a CRF model with casing and word morphology features was used. Table 2 shows that all of our proposed models outperform the baseline CRF. Of these, predictably, RWE performs the poorest. GloVe embeddings deliver a substantial improvement, with BERT embeddings giving the best results across all three metrics, most likely due to their contextualised nature. On aggregate, the results indicate that adding a CRF layer slightly improves the performance of all models. In summary, although the increase in the F1 metric is marginal, there is a consistent improvement throughout all variants in experiment A when introducing a CRF layer. We observe that BiLSTMBERT + CRF is the best-performing CRF-enriched model with regard to model architecture and embeddings.

**EXPERIMENT B:** *Character embeddings:* Despite our expectation to the contrary, the results revealed that character embeddings cause performance deterioration. Reasons for this may be: (a) that the supplementary trainable parameters increased the model's complexity, making it more challenging to learn the underlying correlations between the data points, and (b) that the CNN and pooling architecture is perhaps by design unsuitable for this problem setting. The remaining experiments were conducted without CNN character embeddings.

**EXPERIMENT C:** *Auxiliary features experiment:* For the third part of the experiment, we used POS tags, IC and WordShape as auxiliary features for the classification. Overall, it is observed that the combination of two or more auxiliary features offers a slight performance advantage over single-feature models. Yet, the performance metrics when using those features are not dissimilar from those of the initial BiLSTM + CRF model with BERT embeddings, except when using WordShape features exclusively. Thus, based on the Recall and F1 metrics, we identify WordShape as the better-suited auxiliary feature to use with the BiLSTM + CRF model. Accordingly, we chose to incorporate WordShape features for further evaluation experiments.

### **7. Evaluation**

Based on the experimental results presented in Section 6, we evaluate BiLSTMBERT+WS + CRF on temporal and semantic context sensitivity labelling.

For temporal context sensitivity labelling, the final model is evaluated on a dataset that was manually built and annotated. Manual annotation is time-consuming and thus the dataset is small compared to the synthetic one. It consists of 60 text sequences, specifically written in a way that token sensitivity is directly dependent on temporal context.

Due to manual annotation, there were cases where stopwords were annotated as sensitive tokens but not picked up by the model, or vice versa, causing a drop in the evaluation metrics. A confusion matrix (Figure 4) at class-level (rather than token-level) granularity offers a better evaluation that is invariant to annotation discrepancies. In effect, this demonstrates whether the model managed to identify the sensitivity class of the text. Figure 4 illustrates that the majority of sensitivity types are classified correctly. Most of the incorrect classifications are confused with the *NON\_SENSITIVE* class. Note that the hardest classes to classify are *DEMOG\_SOCECON\_INFO* and *POLITICAL\_INFO*. Additionally, it is important to highlight that the F1 scores of 76.33% and 73.07% in the token-level and class-level experiments, respectively (Figure 4), support our hypothesis that a BiLSTM model can be used for temporal context sensitivity annotation.

**Figure 4.** Confusion matrix for temporal context evaluation.

For our Semantic Context Evaluation, we performed a comparative evaluation against Google's DLP system (https://cloud.google.com/dlp/), which provides industrial-strength sensitive data annotation for over 80 sensitive data types. Google DLP was chosen for its industrial strength and was seen as a suitable benchmark for semantic evaluation since it also uses an automated methodology. To perform as accurate a comparison as possible, we test the performance of the two systems (our model and DLP) solely on the sensitive data types that are common between the two. Google DLP is provided as a platform with standard functionality, allowing user intervention solely in (i) selecting InfoTypes (equivalent to the *topics* of our Sensitivity Classes) and (ii) creating templates to support a more structured data detection approach. To that end, another dataset was created manually, containing text sequences that include sensitive tokens affected by semantic context (as an example of a sample in this dataset, consider the sentences "I am a male" and "I am a man": both reveal the same gender information but the wording is different). In the absence of sizeable overlap between the two systems on sensitivity topics, we select only those commonly appearing in both. These are: *Email, Credit Cards, Age, IP, SSN, Ethnic Group and Gender*.

The results of our comparison are provided in Figure 5. For all sensitive data types, our system outperforms Google's DLP service. In particular, it was observed that, when noise that affects the syntactic but not the semantic meaning of sensitive data was added, Google's DLP fails to annotate the sensitive tokens. On the contrary, our system exhibits resilience against such noise, with a relative accuracy advantage over Google DLP of 42.65%.

### **8. Discussion**

Sensitive information labelling is a prominent problem when designing privacy-aware decision systems. Automated sensitivity labelling is especially relevant when considering users as custodians of their own personal data. With this in mind, we developed our model to enhance sensitivity annotation by first offering a taxonomy of four context classes (semantic, agent, temporal and self), and then using these to implement context-aware labelling.

Our model does not come without limitations, and future work should: involve the full set of context classes; extend the auxiliary features experiments to evaluate tweaked models such as plain BiLSTM (without CRF); incorporate additional word, sentence and character embeddings; and perform testing and validation on more extensive datasets. Yet, the work presented in this paper can essentially serve as a framework for building similar models within alternative well-defined problem domains.

The impact of our work is manifold, although we acknowledge potential risks that typically come with advancements in understanding and extracting sensitive information. While our models contribute to more accurate context-aware sensitivity labelling, our choice to adopt P3P sensitivity classes also supports openness and extensibility to third party applications, offering a platform for others to further develop suitable methods. It furthermore demonstrates promising results from using deep learning techniques in text sensitivity annotation, an area that is sparsely addressed in the literature. We believe that further expanding our approach will offer a more concrete future direction for privacy-preserving information exchange.

**Author Contributions:** This paper was accomplished based on the collaborative work of the authors. A.P. performed the experiments and analysed the data. Experiment interpretation and paper authorship were jointly performed by A.P. and G.S. All authors have read and agreed to the published version of the manuscript.

**Funding:** This research received no external funding.

**Conflicts of Interest:** The authors declare no conflict of interest.

### **Abbreviations**

The following abbreviations are used in this manuscript:


### **References**


**Publisher's Note:** MDPI stays neutral with regard to jurisdictional claims in published maps and institutional affiliations.

© 2020 by the authors. Licensee MDPI, Basel, Switzerland. This article is an open access article distributed under the terms and conditions of the Creative Commons Attribution (CC BY) license (http://creativecommons.org/licenses/by/4.0/).

## *Article* **Machine Learning-Based Code Auto-Completion Implementation for Firmware Developers**

### **Junghyun Kim 1, Kyuman Lee 2,\* and Sanghyun Choi 3**


Received: 29 October 2020; Accepted: 19 November 2020; Published: 28 November 2020

**Abstract:** With the advent of artificial intelligence, the research paradigm in natural language processing has transitioned from statistical methods to machine learning-based approaches. One application is to develop a deep learning-based language model that helps software engineers write code faster. Although there have already been many attempts by different research groups to develop code auto-completion functionality, a need to establish in-house code has been identified for the following reasons: (1) a security-sensitive company (e.g., Samsung Electronics) may not want to utilize commercial tools given that there is a risk of leaked source codes and (2) commercial tools may not be applicable to a specific domain (e.g., SSD firmware development), especially if one needs to predict unique code patterns and styles. This research proposes a hybrid approach that harnesses the synergy between machine learning techniques and advanced design methods, aiming to develop a code auto-completion framework that helps firmware developers write code in a more efficient manner. The sensitivity analysis results show that the deterministic design reduces prediction accuracy as it generates output in some unexpected ways, while the probabilistic design provides a list of reasonable next code elements from which one can manually select to increase prediction accuracy.

**Keywords:** machine learning; code auto-completion; GPT-2 model; advanced design methods

### **1. Introduction**

### *1.1. Research Motivation*

Firmware software developers at a company are typically responsible for developing the software program that operates a product. A company always seeks to provide a streamlined work process for firmware software developers to increase productivity, as the process leads to saving money for the company. One potential barrier to increasing productivity is the considerable time spent writing code, particularly due to repetitive tasks. Another potential problem is that firmware software developers may generate similar code simultaneously, as they are separately involved in developing different hardware products. Figure 1 notionally illustrates the issue, where we noticed that two firmware software developers separately worked on code that was eventually similar. This situation could prevent them from working efficiently if they needed to handle a large volume of source code, resulting in decreased productivity. Thus, a need to develop a framework that helps firmware software developers work efficiently has been identified in this research.

**Figure 1.** Notional sketch of necessity for developing a code auto-completion framework.

### *1.2. Background*

With the advent of machine learning (ML) techniques, the research paradigm in various engineering areas has been recently transitioned from theory-based to data-driven approaches [1]. There have already been many studies asserting that ML techniques outperform traditional statistical methods. A language model is no exception to this paradigm shift. Many research groups have been committed to developing a language model using ML techniques. For example, Google developed the bidirectional encoder representations from transformers (BERT) [2] that mainly use transformer encoder blocks. OpenAI released the generative pre-trained transformer (GPT) models [3,4] such as the GPT-2 models.

The GPT models are transformer-based language models trained on a massive text dataset scraped from the web. Depending on the size of the neural network weight parameters, the GPT-2 models are classified into small (i.e., 117 M), medium (i.e., 345 M), and large (i.e., 762 M) pre-trained models, as shown in Figure 2. Here, 117 M means that there are 117 million parameters in the neural network model. The obvious upside of the GPT models is that the pre-trained models can be easily tailored to various domain-specific language modeling tasks, given that the models are fine-tuned with domain-specific training datasets. For this reason, the GPT models have been widely used for a variety of domain-specific tasks such as speech recognition and language translation.

**Figure 2.** GPT-2 models (reproduced from [5]).

In fact, the GPT models have stunned the world by demonstrating impressive capability that may exceed current language models. One example is the Allen AI GPT-2 Explorer [6], as shown in Figure 3, where it uses the GPT-2 345M model to predict the most likely next word alongside its probability score. In this example, it appears that the model generates a list of candidates for the next few words (e.g., Electronics Co., Ltd.) once some initial text (e.g., I am currently working at Samsung) is provided.


**Figure 3.** AllenNLP language modeling demonstration example.

Given the aforementioned observations, it can be hypothesized that if the pre-trained GPT-2 models are fine-tuned with SSD firmware source codes, the fine-tuned model will predict the most likely next code element. To that end, this research aims to develop a framework that deploys the GPT-2 models to help SSD firmware developers write code in a more efficient manner. The remainder of this paper consists of the following: Literature Review, Proposed Methodology, Results and Discussion, and Conclusion.

#### **2. Literature Review**

In relation to the research objective, there have already been many attempts to develop similar capabilities. This section is aimed at reviewing the advances and limitations of the previous efforts about code auto-completion functionality, which helps identify research gaps that need to be bridged.

#### *2.1. Related Work*

Code auto-completion functionality has been considered as one of the most essential functions for software engineers. The Integrated Development Environment (IDE) has provided a set of effective features that include code auto-completion capability [7]. The code auto-completion feature in the IDEs basically suggests next probable code elements; however, there are some potential issues that have been identified [8]: (1) the feature requires an exhaustive set of rules, (2) predictions do not consider the category of code, (3) predictions do not consider context such as class definition, and (4) recommendations are often lexicographical and alphabetical, which may not be very useful.

Many research groups have adopted statistical methods to resolve the potential issues of the IDEs. For example, Sebastian Proksch et al. [9] replaced an existing code auto-completion engine with an approach using Bayesian networks, named the pattern-based Bayesian network (PBN). Raychev et al. [10] proposed a state-of-the-art probabilistic model for code auto-completion functionality, mainly equipped with the n-gram model that computes the probability of the next code element given the previous n − 1 elements. The statistical approach, however, examines only a limited number of elements in the source codes when completing the code; thus, the effectiveness of this approach may not scale well to large programs [11].
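The n-gram idea can be illustrated with a minimal bigram model over tokenized code; this is a generic sketch of the statistical approach, not the PBN or Raychev et al.'s actual system.

```python
from collections import Counter, defaultdict

# Count bigram transitions over a token stream and predict the most
# frequent successor of a given token.
def train_bigram(tokens):
    ctx = defaultdict(Counter)
    for prev, nxt in zip(tokens, tokens[1:]):
        ctx[prev][nxt] += 1
    return ctx

def predict_next(model, prev):
    counts = model.get(prev)
    if not counts:
        return None                        # unseen context
    return counts.most_common(1)[0][0]

code = "if ( x ) { return x ; } else { return y ; }".split()
model = train_bigram(code)
# predict_next(model, "{") suggests "return", the only observed successor
```

The cited scalability limitation follows directly from this construction: the model conditions only on a fixed, short window of previous tokens.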

With the advent of deep learning, many research groups have been committed to developing deep learning-based code auto-completion functionality. The most common technique is to use a recurrent neural network (RNN). In fact, Karampatsis et al. [12] showed that the RNN-based language models would be much better than the traditional statistical methods. Moreover, Martin White et al. [13] illustrated how to use the RNN-based language model to facilitate the code auto-completion task. It seemed that RNN-based language models gained the most popularity at the time; however, it was identified that the models were limited by the so-called hidden state bottleneck: all the information about the current sequence is compressed into a fixed-size vector. This limitation made it hard for the RNN-based models to handle long-range dependencies [14].

A transformer-based language model has been introduced to overcome a major drawback of RNN-based language models by relying on the attention mechanism. For example, Alexey Svyatkovskiy introduced IntelliCode Compose [15], which is capable of predicting sequences of code tokens of arbitrary types. It leveraged a state-of-the-art generative transformer model trained on 1.2 billion lines of source code. In addition to IntelliCode Compose, a variety of transformer-based language models have recently achieved excellent results [2–4,16] for various natural language processing (NLP) tasks such as language modeling. There are numerous practical applications that deploy a transformer-based language model for code auto-completion functionality. TabNine [17] published a blog post mentioning the use of the GPT-2 model in their code auto-completion feature. However, they never revealed technical details about the modeling process. Kite-Pro [18], which also employs a transformer-based language model, reports an average of 18 percent greater efficiency when using the code auto-completion feature. Table 1 summarizes four different approaches with the advantages and limitations of the previous efforts on code-completion functionality.



### *2.2. Research Gap*

Although many research groups have deployed a transformer-based language model for code auto-completion functionality, it is important to note that they have trained the model with open source codes mostly coming from GitHub. For example, Deep TabNine [17] is trained on around 2 million files from GitHub. This indicates that the software may not be applicable if one needs to predict very unique code patterns and styles. Therefore, a need to establish a domain-specific language model has been identified for the following reasons: (1) firmware codes implement very specific sets of features for the hardware and (2) firmware codes typically comply with unique coding styles and patterns optimized for embedded environments.

In fact, some companies (e.g., TabNine) advertise that they offer a GPU-based cloud service that enables users to create a custom model by fine-tuning the model with their own input data. However, a security-sensitive company such as Samsung Electronics may not want to utilize these commercial tools given that there is a risk of leaked source codes. In addition to the security issue, a company would have to pay a license fee because the tools are not free. Thus, a need to develop in-house code for code auto-completion functionality is identified.

As we seek to develop a domain-specific (i.e., solid state drive (SSD) firmware development) language model by using the GPT-2 model, it naturally leads us to consider how to determine diversity parameters (i.e., *Top*\_*k*, *Top*\_*p*, and Boltzmann temperature) of the model. Given that there has not been any analysis done on the optimal diversity parameter values especially on the SSD firmware development domain, the following research question can be constructed: "How can we determine the GPT-2 diversity parameter values properly for the SSD firmware development domain"?

To answer the question, in this paper, we propose a hybrid approach that harnesses the synergy between ML techniques and advanced design methods (e.g., design of experiments, surrogate modeling, and Monte Carlo simulation) [19] to enhance the level of understanding of the relationship between the GPT-2 model diversity parameters and code auto-completion functionality in the SSD firmware development domain. Figure 4 notionally illustrates the process of the hybrid approach used for this research.



**Figure 4.** Notional sketch of the process of the hybrid approach used for this research.

### **3. Proposed Methodology**

#### *3.1. Overview of the Methodology*

This research aims not only to develop a framework that deploys the GPT-2 117M model to help firmware developers write code in a more efficient manner, but also to unravel the hidden relationships between the GPT-2 model diversity parameters and code auto-completion capability. Figure 5 depicts an overview of the proposed methodology.

**Figure 5.** Overview of the proposed methodology.

The framework is a Python-based program that consists of several modules with its primary data sources. There are three different modules: (1) the first module is designed to automate data pre-processing, such as removing all unnecessary C++ code comments, (2) the second module performs fine-tuning for the GPT-2 model with optimized hyper-parameters (i.e., batch size and learning rate), and (3) the third module employs advanced design methods for diversity parameter sensitivity analysis.

### *3.2. Text Pre-Processing*

To customize the original GPT-2 117M model to the SSD firmware development domain, it is imperative to prepare the input data properly for the fine-tuning process, because the process interprets the input data in its own way. Since the source codes include unnecessary information (e.g., C++ code comments) that may deteriorate training data quality, we developed a Python code that automatically removes all such content through pattern analysis. Once the Python code completes the removal process, it also removes white spaces as well as empty lines. It then combines all the source codes into one single text file with a delimiter in order to allow the model to learn the formatting of the training data.
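A minimal sketch of this pre-processing module is shown below. The delimiter string is an assumption (the paper does not specify it), and the regex-based comment stripping ignores edge cases such as `//` inside string literals, which a production pattern analysis would need to handle.

```python
import re

DELIM = "<|endoffile|>"  # hypothetical file delimiter

def clean_source(code):
    """Strip C++ comments, trailing whitespace and blank lines."""
    code = re.sub(r"/\*.*?\*/", "", code, flags=re.S)  # block comments
    code = re.sub(r"//[^\n]*", "", code)               # line comments
    lines = [ln.rstrip() for ln in code.splitlines()]
    return "\n".join(ln for ln in lines if ln.strip())

def build_corpus(files):
    """Join all cleaned sources into one training text with delimiters."""
    return ("\n" + DELIM + "\n").join(clean_source(f) for f in files)

corpus = build_corpus([
    "int a = 1; // init\n\n/* block */\nint b;",
    "void f() {}",
])
```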

### *3.3. GPT-2 Model Fine-Tuning*

The GPT-2 model is a transformer-based language model trained on a massive 40 GB text dataset that mainly includes web pages [18]. Users can fine-tune the GPT-2 model with new input data. During the fine-tuning process, users can either increase or decrease two hyper-parameters, namely batch size and learning rate, to optimize the model's predictive capability. We employ the grid search method as shown in Figure 6 and test all candidate cases on the NVIDIA DGX-1 machine (i.e., the Volta 32GB version) to isolate the hyper-parameters. The effective model is finally determined by minimizing the log-loss value. The choice of hyper-parameters is tabulated in Table 2. Figure 7 shows the plot of the loss curve, describing that the loss value actually converges with the chosen hyper-parameters.

**Figure 6.** Notional sketch of the grid search method for isolating hyper-parameters.

**Figure 7.** Plot of the loss history of the tailored GPT-2 model with isolated parameters.
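The grid search over batch size and learning rate can be sketched as follows. Here `evaluate` is a stand-in for a full fine-tuning run that returns the validation log-loss, and the candidate values are illustrative, not the paper's actual grid.

```python
import itertools

# Hypothetical candidate grids for the two hyper-parameters.
BATCH_SIZES = [1, 2, 4]
LEARNING_RATES = [1e-5, 1e-4, 1e-3]

def grid_search(evaluate):
    """Try every (batch size, learning rate) pair and keep the one that
    minimizes the returned log-loss."""
    return min(itertools.product(BATCH_SIZES, LEARNING_RATES),
               key=lambda cand: evaluate(*cand))

# Mock loss surface with its minimum at (2, 1e-4), for illustration only.
mock_loss = lambda bs, lr: abs(bs - 2) + abs(lr - 1e-4) * 1e4
best_bs, best_lr = grid_search(mock_loss)
```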

### *3.4. Design of Experiment*

The design of experiment (DoE) is a procedure that selects samples in the design space to maximize the amount of information with a limited set of experiments. To generate a non-linear surrogate model that represents the design space of the GPT-2 model's diversity parameters, we employ two representative DoE methods: (1) the Latin hypercube sampling (LHS) method is used to capture inner points of the design space and (2) the full factorial design with three factors is utilized to capture corner points of the design space. Figure 8 shows how samples are distributed in the design space of the GPT-2 model's diversity parameters.

**Figure 8.** Design space of the GPT-2 model's diversity parameters.
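The two DoE stages can be sketched in plain Python: a simple Latin hypercube sample for the interior of the three-dimensional diversity-parameter space, plus a 2³ full factorial design for the corner points. The bounds match those used later in the sensitivity analysis; the LHS implementation is a textbook simplification, not the paper's tooling.

```python
import itertools, random

BOUNDS = {"temperature": (0.1, 0.9), "top_k": (40, 100), "top_p": (0.1, 0.9)}

def latin_hypercube(n, bounds, rng=random):
    """One sample per stratum in each dimension, shuffled independently."""
    cols = []
    for lo, hi in bounds.values():
        strata = [lo + (hi - lo) * (i + rng.random()) / n for i in range(n)]
        rng.shuffle(strata)
        cols.append(strata)
    return list(zip(*cols))          # n interior points

def corner_points(bounds):
    """2^k full factorial design over the bound extremes."""
    return list(itertools.product(*bounds.values()))

samples = latin_hypercube(5, BOUNDS) + corner_points(BOUNDS)
```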

### *3.5. Surrogate Modeling*

The multi-layer perceptron (MLP), which is one of the most representative non-linear regression methods, is deployed as a surrogate model with respect to the GPT-2 model's diversity parameters. The MLP model used for this research entails the following fully-connected layers: (1) an input layer to receive diversity parameter values, (2) an output layer that makes a prediction in terms of the score function illustrated in Table 3, and (3) two hidden layers that are the true computational engine for the regression. Figure 9 shows a diagram of the MLP model structure used for this research.

**Figure 9.** Diagram of the MLP model structure.


**Table 3.** Score metric for the MLP-based surrogate modeling.

To evaluate the accuracy of the MLP-based surrogate model, the model representation error (MRE) is calculated with respect to additional random DoE cases. As a result, R-square, which describes how well the model predictions adhere to reality, is equal to 0.98, and the root mean square error (RMSE), which describes how spread out the residuals are, is approximately 3.12.
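The two MRE metrics quoted above can be computed from scratch as below; the sample values are illustrative, not the paper's validation data.

```python
import math

def r_square(y_true, y_pred):
    """Fraction of variance explained: 1 - SS_res / SS_tot."""
    mean = sum(y_true) / len(y_true)
    ss_res = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))
    ss_tot = sum((t - mean) ** 2 for t in y_true)
    return 1 - ss_res / ss_tot

def rmse(y_true, y_pred):
    """Typical residual magnitude: sqrt of the mean squared error."""
    return math.sqrt(sum((t - p) ** 2 for t, p in zip(y_true, y_pred))
                     / len(y_true))

y_true = [10.0, 20.0, 30.0, 40.0]
y_pred = [11.0, 19.0, 31.0, 39.0]
```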

#### *3.6. Monte Carlo Simulation*

We utilize the Monte Carlo simulation (MCS) technique to observe the trend of outcomes generated from the MLP-based surrogate model. A uniform distribution with min/max values is used for each of the GPT-2 model's diversity parameters. The MCS is then performed with 1,000,000 sample points generated by the uniform distributions of the GPT-2 model's diversity parameters, which are eventually incorporated into the MLP-based surrogate model to yield statistical distributions. Figure 10 notionally depicts the MCS process flow diagram used in this research (i.e., input and output mapping).

**Figure 10.** Notional sketch of the MCS process flow diagram.
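The MCS loop reduces to drawing diversity-parameter triples from uniform distributions and pushing them through the surrogate. In this sketch the surrogate is a stand-in function, and 10,000 samples are used instead of the paper's 1,000,000 to keep the illustration fast.

```python
import random

PARAM_BOUNDS = {"temperature": (0.1, 0.9),
                "top_k": (40, 100),
                "top_p": (0.1, 0.9)}

def monte_carlo(surrogate, n=10_000, rng=random):
    """Sample the diversity parameters uniformly and collect the
    surrogate's predicted scores."""
    scores = []
    for _ in range(n):
        sample = {k: rng.uniform(lo, hi)
                  for k, (lo, hi) in PARAM_BOUNDS.items()}
        scores.append(surrogate(**sample))
    return scores

# Stand-in surrogate: score falls as temperature rises (cf. Figure 11).
mock_surrogate = lambda temperature, top_k, top_p: 100 * (1 - temperature)
scores = monte_carlo(mock_surrogate)
```

The resulting `scores` list plays the role of the statistical output distribution that the figures in the next section visualize.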

#### **4. Results and Discussion**

### *4.1. Sensitivity Analysis*

Sensitivity analysis with respect to the GPT-2 diversity parameters is performed to enhance the level of understanding of the relationship between prediction accuracy and the diversity parameters. The GPT-2 model has three different diversity parameters implemented in the sampling process.

The Boltzmann temperature is one of the GPT-2 model diversity parameters that controls randomness in the sampling process. Figure 11 shows the MCS results with respect to the Boltzmann temperature. Lower and upper bounds are specified as 0.1 and 0.9, respectively. Each black dot represents one experiment case generated by the MLP-based surrogate model with three input variables randomly sampled from the uniform distributions of the GPT-2 model's diversity parameters. As can be seen from Figure 11, decreasing the Boltzmann temperature (i.e., the *x*-axis) keeps the model generating a high prediction score (i.e., the *y*-axis), while increasing the Boltzmann temperature causes the model to frequently produce a low prediction score. Based on these observations, one may claim that the model with a lower Boltzmann temperature value, named the deterministic design in this paper, would be the best option to predict the most likely next code element, as the deterministic design strives to minimize the degree of surprise in model output. However, it is too early to draw such a conclusion, because the model with a higher Boltzmann temperature value, named the probabilistic design in this paper, provides a list of reasonable next code elements from which one can manually select to increase prediction accuracy. Details will be discussed in the model evaluation section.

**Figure 11.** MCS results for diversity parameter 1 (i.e., Boltzmann temperature).
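As background, temperature scaling divides the logits by the temperature before the softmax, which is why low temperatures behave near-deterministically. The sketch below is the standard formulation, not the authors' code.

```python
import numpy as np

def softmax_with_temperature(logits, temperature):
    # Lower temperature sharpens the distribution (near-deterministic
    # sampling); higher temperature flattens it (more surprise).
    z = np.asarray(logits, dtype=float) / temperature
    z -= z.max()            # subtract max for numerical stability
    p = np.exp(z)
    return p / p.sum()

logits = [2.0, 1.0, 0.1]
cold = softmax_with_temperature(logits, 0.1)  # near one-hot
hot = softmax_with_temperature(logits, 0.9)   # closer to uniform
print(cold.round(3), hot.round(3))
```

With the same logits, the low-temperature distribution concentrates almost all probability mass on the top token, while the high-temperature one spreads it across candidates.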

The *Top*\_*k* is another GPT-2 diversity parameter; it controls the number of candidate words considered during sampling. For example, if *Top*\_*k* equals one, only the most likely word is considered, resulting in a deterministic design. The deterministic design successfully eliminates odd candidates; however, one may argue that better results could be achieved if the algorithm considered more than one candidate word. Figure 12 shows the MCS results with respect to the *Top*\_*k* parameter; the lower and upper bounds are 40 and 100, respectively. As can be seen, the *Top*\_*k* value does not have a significant impact on the prediction score, indicating that *Top*\_*k* may not contribute much to model output diversity.

**Figure 12.** MCS results for diversity parameter 2 (i.e., *Top*\_*k*).
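The *Top*\_*k* filtering step can be illustrated as follows; this is a generic sketch of top-k sampling, not the authors' implementation.

```python
import numpy as np

def top_k_filter(probs, k):
    # Keep only the k most likely next tokens and renormalize;
    # k = 1 reduces to the deterministic (argmax) design.
    probs = np.asarray(probs, dtype=float)
    keep = np.argsort(probs)[-k:]        # indices of the k largest
    filtered = np.zeros_like(probs)
    filtered[keep] = probs[keep]
    return filtered / filtered.sum()

probs = [0.5, 0.3, 0.15, 0.05]
print(top_k_filter(probs, 1))   # only the most likely token survives
print(top_k_filter(probs, 2))
```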

The *Top*\_*p*, the third GPT-2 diversity parameter, samples words from the smallest possible set of words whose cumulative probability exceeds a user-defined threshold. Instead of sampling only from the most likely *K* words, the *Top*\_*p* parameter dynamically controls the size of the set of candidate words. For example, if *Top*\_*p* equals 0.9, the algorithm computes the cumulative distribution function (CDF) over the sorted candidates and cuts off the words as soon as the CDF exceeds 90 percent. Figure 13 shows the MCS results with respect to the *Top*\_*p* parameter; the lower and upper bounds are 0.1 and 0.9, respectively. As the *Top*\_*p* value increases, the prediction score becomes more random; as it decreases, the score becomes less random. This implies that a model with a low *Top*\_*p* value (approximately 0.25 in this case) becomes deterministic and repetitive, while a model with a higher *Top*\_*p* value becomes a probabilistic design that may improve code suggestion quality relative to a deterministic design. Details about the difference between the deterministic and probabilistic designs are discussed in the model evaluation section.

**Figure 13.** MCS results for diversity parameter 3 (i.e., *Top*\_*p*).
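Nucleus (*Top*\_*p*) filtering, as described above, keeps candidates until the CDF passes the threshold. A generic sketch:

```python
import numpy as np

def top_p_filter(probs, p):
    # Keep the smallest set of tokens whose cumulative probability
    # exceeds p (nucleus sampling), then renormalize.
    probs = np.asarray(probs, dtype=float)
    order = np.argsort(probs)[::-1]          # most likely first
    cdf = np.cumsum(probs[order])
    cutoff = np.searchsorted(cdf, p) + 1     # first index where CDF > p
    keep = order[:cutoff]
    filtered = np.zeros_like(probs)
    filtered[keep] = probs[keep]
    return filtered / filtered.sum()

probs = [0.5, 0.3, 0.15, 0.05]
filtered = top_p_filter(probs, 0.9)   # drops only the least likely token
print(filtered)
```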

### *4.2. Model Evaluation*

The simplest way to evaluate the fine-tuned GPT-2 model is to let it ramble on its own (generating unconditional samples). Instead, we use the option called generating interactive conditional samples for model evaluation, as it makes it easy to steer customized samples. An interactive conditional sample is generated from a user-defined input code: given an initial code fragment, the model generates the most likely next code element. After the user selects a code element, it is appended to the input sequence, and the new sequence becomes the input for the next step. This process repeats until the rest of the sequence is filled. In this paper, we use open-source SSD firmware code, namely SimpleSSD [20], to evaluate the framework developed in this research, because Samsung Electronics' SSD firmware source code is strictly confidential. Figure 14 shows one sample code element [21] tested by the framework.

**Figure 14.** Sample code element to be tested.
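The interactive conditional sampling loop described above can be sketched as below. `suggest` is a hypothetical stand-in for the GPT-2 model, and its tiny token table is invented purely for illustration.

```python
# Sketch of the interactive loop: the user supplies an initial code
# fragment, the model proposes candidates for the next element, the
# user picks one, and the choice is appended to the input sequence.
def suggest(tokens):
    # Hypothetical model: a tiny lookup of likely next code elements.
    table = {"for": ["(", "each"], "(": ["int", "auto"], "int": ["i"]}
    return table.get(tokens[-1], ["<eos>"])

def interactive_complete(seed, choices):
    tokens = list(seed)
    for pick in choices:                 # user selections at each step
        candidates = suggest(tokens)
        tokens.append(candidates[pick])  # chosen element extends the input
    return tokens

print(interactive_complete(["for"], [0, 0, 0]))
```

In the probabilistic design, `suggest` would return several ranked candidates for the user to choose from; in the deterministic design it would return only the single most likely element.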

Based on the sensitivity analysis results, we specify the parameter values tabulated in Table 4 for the deterministic design. Note that we set the *Top*\_*k* parameter to 40 (a rule of thumb) [22], as this parameter does not affect randomness in the sampling process. For the probabilistic design, we specify the maximum value (i.e., upper limit) for each GPT-2 diversity parameter except *Top*\_*k*.


**Table 4.** Diversity parameter values for deterministic design.

Figure 15 shows the predictions of the deterministic and probabilistic designs for the sample code element from Figure 14. Incorrect code elements are underlined with solid straight lines. This result indicates that the probabilistic design outperforms the deterministic design with respect to similarity for this sample code element. Note that the probabilistic design is 100% correct because the user can select the correct code element after the framework suggests a list of the most likely next code elements. Furthermore, the deterministic design predicts the correct next code element in most cases; however, it sometimes produces unexpected output.


**Figure 15.** Sample code element predictions by deterministic and probabilistic design.

Figure 16 shows how the deterministic and probabilistic designs predict the next code element differently given the same user-defined input. As can be seen, the deterministic design generated only one output, the element it considered most likely (i.e., 100% probability). Unfortunately, this prediction was incorrect in this particular case, even though the deterministic design typically predicted correctly. In other words, the deterministic design failed here because it automatically selects the option with the highest prediction probability. Given the same input code element, the probabilistic design instead generated a list of the most likely next code elements, each with a certain probability. This enabled the user to manually select an option from a list that included the correct code element based on the previous input. For example, although the option with the highest probability (e.g., uint64\_t) was incorrect in this particular case, the probabilistic design increased prediction accuracy by allowing the user to select another option (e.g., uint32\_t) manually. Thus, there was no discrepancy with the probabilistic design, given that the user could manually select a candidate from the list of options (e.g., choosing an option with 20% probability instead of one with 80% probability).

**Figure 16.** Deterministic design vs. Probabilistic design.

This example shows the impact of the GPT-2 model diversity parameters on prediction accuracy. Beyond this particular example, Figures 17–19 support the argument that the deterministic design sometimes generates unexpected output, whereas there is no discrepancy with the probabilistic design, given that the user can manually select an option from a list of candidates.


**Figure 17.** Sample code (*ftl.cc*) element predictions by deterministic and probabilistic design.

The main takeaways of these case studies are as follows. First, the deterministic design is recommended for those who want to reduce prediction latency, but users should recognize that prediction accuracy is not guaranteed. Second, the probabilistic design is recommended for those who want to guarantee prediction accuracy; however, users may have to address the higher computational cost of inference. Regarding this cost, transformer-based language models keep getting bigger (e.g., the GPT-3 model), so they may require substantial computing power and memory, which can make it challenging to run them on a personal computer with reasonable latency. In this case, it is imperative to deploy the code auto-completion engine developed in this research on a cloud computing resource. The good news is that Samsung Electronics already operates internal cloud computing systems. With internal systems, the company does not have to account for external security concerns. However, one potential issue is that the cloud systems may experience bottlenecks as traffic from user requests increases. This paper does not address how to operate the framework on cloud computing systems; it focuses on proving the concept of machine learning-based code auto-completion.


**Figure 18.** Sample code (*hil.cc*) element predictions by deterministic and probabilistic design.


**Figure 19.** Sample code (*fifo.cc*) element predictions by deterministic and probabilistic design.

### **5. Conclusions**

In this research, we established a machine learning-based code auto-completion framework for SSD firmware developers at Samsung Electronics. We presented a hybrid approach that harnesses the synergy between machine learning techniques and advanced design methods to better understand the relationship between the GPT-2 model diversity parameters and prediction accuracy. The sensitivity analysis results showed that the probabilistic design outperformed the deterministic design with respect to prediction accuracy, as we observed several failure cases in which the deterministic design generated unexpected output. We also found that a balance must be struck between model prediction accuracy and prediction latency, given that users run the framework on laptops or desktops. The results of this research can be implemented in any firmware development environment at the company as needed. In conclusion, we expect the framework to save many hours of productivity by eliminating tedious parts of writing code and helping SSD firmware developers write code more efficiently. Future research will extend the framework with new functionality addressing the order of suggestions, given that users may accidentally select the first entry (i.e., an unwanted selection), which is not always the correct option, among the recommended code elements.

**Author Contributions:** Conceptualization, J.K. and S.C.; Methodology, J.K.; Software, J.K. and S.C.; Validation, J.K.; Investigation, J.K. and K.L.; Writing—original draft preparation, J.K.; Writing—review and editing, K.L. and S.C.; All authors have read and agreed to the published version of the manuscript.

**Funding:** This research was supported by Kyungpook National University Research Fund, 2020.

**Acknowledgments:** This paper is an extension of a PhD summer internship project that was done at Samsung Electronics in 2020. We would like to thank Hankyu Lee for his feedback on this research. We would also like to thank Jinbaek Song for his support on the NVIDIA DGX-1 setup process.

**Conflicts of Interest:** The authors declare no conflict of interest.

### **References**


**Publisher's Note:** MDPI stays neutral with regard to jurisdictional claims in published maps and institutional affiliations.

© 2020 by the authors. Licensee MDPI, Basel, Switzerland. This article is an open access article distributed under the terms and conditions of the Creative Commons Attribution (CC BY) license (http://creativecommons.org/licenses/by/4.0/).

## *Article* **An Improvement on Estimated Drifter Tracking through Machine Learning and Evolutionary Search**

**Yong-Wook Nam 1, Hwi-Yeon Cho 1, Do-Youn Kim 2, Seung-Hyun Moon <sup>1</sup> and Yong-Hyuk Kim 1,\***


Received: 28 September 2020; Accepted: 13 November 2020; Published: 16 November 2020

**Abstract:** In this study, we estimated drifter tracking over seawater using machine learning and evolutionary search techniques. The parameters used for the prediction are the hourly position of the drifter, the wind velocity, and the flow velocity of each drifter position. Our prediction model was constructed through cross-validation. Trajectories were affected by wind velocity and flow velocity from the starting points of drifters. Mean absolute error (MAE) and normalized cumulative Lagrangian separation (NCLS) were used to evaluate various prediction models. Radial basis function network showed the lowest MAE of 0.0556, an improvement of 35.20% over the numerical model MOHID. Long short-term memory showed the highest NCLS of 0.8762, an improvement of 6.24% over MOHID.

**Keywords:** drifter trajectory; evolutionary computation; machine learning; deep learning; NCLS

### **1. Introduction**

The worldwide increase in large ships and maritime transportation volume causes a number of accidents that are beyond the capacity of individual nations. Pollutants released from accidents may extensively contaminate the marine environment due to ocean currents, weather, and weathering. Therefore, pollutants released during an accident should be removed as soon and as completely as possible. An accurate prediction of pollutant movement can help track and address them, as a variety of studies have shown [1–5]. Prediction models for oil spills generally calculate the movement and spread of the oil using the Lagrangian particle approach [5–7]. This forecasting method uses physics with critical parameters such as flow velocity, wind velocity, water level, and temperature in the current state. Parameter optimization using evolutionary computation [8] showed better results than MOHID [6], a numerical water modelling system. Machine learning and ensemble methods [9,10] have also been used to estimate the movement of drifters.

Predicting the movement of a drifter on the ocean is an essential step in tracking the spread of an oil spill [11]. While spilled oil may sink or evaporate, most of it floats on the surface. Conventional numerical models can predict oil spills more accurately if the trajectories of drifting particles can be predicted accurately. Therefore, we aimed to predict more accurate trajectories by using various machine learning methods, including deep learning, which has attracted much recent attention, rather than by using evolutionary computation as in our previous study [8]. We integrated artificial intelligence (AI) technology to predict the future ocean state, utilizing the continuity of parameter data over time. We expanded the area covered by our previous study [8], which predicted the trajectories of drifting objects in the ocean, and systematically compared a wide range of regression functions and artificial neural networks for predicting the movement of drifters. In order to avoid look-ahead bias, we used wind and flow forecasts made by the Korea Meteorological Administration and private technical agencies instead of actual future parameters.

In sum, the contributions of this study are as follows. We applied evolutionary computation and machine learning to the prediction of drifter trajectories. We extended and improved our previous study [8], which used evolutionary computation, and predicted the trajectories with various machine learning techniques for the first time; before this study, only numerical models such as MOHID had been applied to drifter trajectories. Our methods significantly improve on the results of the representative numerical model, MOHID.

#### **2. Literature Review**

ADIOS [12], developed in the United States in the early 1990s, is one of the most widely used decision-making support systems for spilled oil. Such support systems share the characteristics of simplicity, high performance, open-source programming, and an extensive oil database. The performance of ADIOS has continued to improve [13], and it provides a basis for newly developed oil weathering models. Unfortunately, it cannot use data on the current sea state [14] or simulate the trajectory of the oil spill. The ADIOS model requires information on the spilled-oil situation, environmental conditions, and prevention strategies, and calculates optimal predictions from the minimum information obtained or expected in the field.

Prediction models for oil spill movements were developed for accurate and detailed analysis. Oil companies, consultancies, national agencies, and research centers worldwide use such models, which accept various marine weather and environmental data, account for oil weathering, and are thus suitable for planning stages and research scenarios covering different types of oil and marine weather conditions. Related models include GNOME [7], OILMAP [15], OSCAR [16], OSIS, GULFSPILL [17], and MOHID [6]. Among them, MOHID [6], which is used as the benchmark measure of performance in this study, was first developed in 1985 at the Marine and Environmental Technology Research Center (MARETEC) of the Instituto Superior Técnico (IST), affiliated with the University of Lisbon, Portugal. MOHID is a multifunctional three-dimensional numerical analysis model applicable to coastal and estuary areas, where it calculates physical processes such as tides and tsunamis. It consists of more than 60 modules that can calculate fluid properties (water temperature, salinity, etc.), Lagrangian movement, turbulence, sediment movement, erosion and sedimentation, meteorological and wave conditions, water quality and ecology, and oil diffusion.

Recent advancements in operational maritime and weather forecasting and in computing technologies have enabled automated prediction in desktop and web environments. Such prediction models include MOTHY [18], POSEIDON OSM [19], MEDSLICK [20], MEDSLICK II [5], OD3D [21], LEEWAY [22], OILTRANS [14], BSHmod.L [23], and SEATRACK Web [24]. Some models cannot process 3D ocean data or consider Stokes drift and the vertical movement of droplets. Research is underway to meet user requirements, provide convenient and comprehensive user-friendly environments, and deliver geographic information system (GIS) results to improve the model landscape.

Typical models that predict Lagrangian drifter trajectory are known to be less accurate for large structures [25]. Therefore, Aksamit et al. [25] used long short-term memory (LSTM) to accurately predict the velocity of drifters. They used the drifter data obtained from the Gulf of Mexico. Their model was much more accurate than the model using the Maxey–Riley equation [26] in terms of root mean square error (RMSE).

### **3. Data**

Our previous study [8] used data obtained from Seosan, located on the west coast of South Korea. For this work, we added new data obtained from Jeju Island. Hourly data from 2015 to 2016 at these locations were used to predict the hourly locations of drifters. Figure 1 shows the observed trajectories at the two locations in 2015, and Table 1 shows the features related to the start and end points of the sample trajectories. The wind velocity was measured at a height of 10 m above the ocean, with a spatial resolution of 900 m.

**Figure 1.** Observed trajectory data from Seosan and Jeju in 2015.


**Table 1.** Attributes of the data (examples).

### **4. Discussion**

This study added machine learning (ML) techniques to the methods of our prior work [8]. We used regression functions for numerical predictions, and also examined various artificial neural networks.

### *4.1. Numerical Model and Evolutionary Methods*

We used MOHID [6] as a numerical model using the Navier-Stokes equations [27], and we also examined various evolutionary methods from our previous study [8]. The evolutionary methods include differential evolution (DE) [28], particle swarm optimization (PSO) [29], and the covariance matrix adaptation evolutionary strategy (CMA-ES) [30].

### *4.2. Machine Learning*

We used supervised learning to find the mapping between input data (wind velocity and flow velocity) and output data (drifter location). The performance of the following two regression functions was examined.

Support vector regression (SVR): While conventional classifiers minimize error rates during the training process, support vector machines (SVMs) construct a set of hyperplanes so that the distance from it to the nearest training data point is maximized. They were considered as alternatives to artificial neural networks in the 1990s since nonlinear classification became possible through the kernel trick [31]. Support vector regression uses SVM for regression with continuous values as the output [32].

Gaussian process (GP): GP [33] is an ML model that predicts data as the average and variance of probability distributions. It predicts functions that can represent given data in the defined function distribution and estimates functions for experimental data based on the arbitrary training data.

### *4.3. Artificial Neural Networks*

Artificial neural networks represent a learning algorithm inspired by the neural networks of biology. With the recent advancement of deep learning, the use of artificial neural networks has yielded excellent performance in classification and can be used for regression when a mean-square error (MSE) is the loss function. Below are the neural network methods we used in this study.

Multi-layer perceptron (MLP) is a basic neural network structure that adds hidden layers to the perceptron structure. We built a hierarchical model with four inputs, two outputs, and one hidden layer.

Radial basis function network (RBFN) is a type of neural network that represents the proximity to the correct answer using Gaussian probability and Euclidean distance [34]. One hidden layer based on a Gaussian probability distribution is used, and the training process is extremely fast.

Deep neural network (DNN) increases training parameters by adding hidden layers in MLP. Figure 2 shows the settings of the input and output values in the basic DNN structure. It is often essential to use a rectified linear unit (ReLU) [35] and Dropout [36] to prevent the vanishing gradient problem [37].
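A forward pass with ReLU and (inverted) dropout can be sketched as follows. The layer sizes and dropout rate are illustrative assumptions; only the 4-input/2-output layout matches the description above.

```python
import numpy as np

rng = np.random.default_rng(0)

def relu(x):
    # ReLU avoids saturating activations, mitigating vanishing gradients.
    return np.maximum(0.0, x)

def dropout(x, rate, training=True):
    # Inverted dropout: randomly zero activations during training and
    # scale the survivors so the expected activation is unchanged.
    if not training or rate == 0.0:
        return x
    mask = rng.random(x.shape) >= rate
    return x * mask / (1.0 - rate)

# Forward pass of a small DNN: 4 inputs -> two hidden layers -> 2 outputs.
def forward(x, weights, training=False):
    h = x
    for w in weights[:-1]:
        h = dropout(relu(h @ w), rate=0.2, training=training)
    return h @ weights[-1]

weights = [rng.standard_normal(s) for s in [(4, 16), (16, 16), (16, 2)]]
y = forward(rng.standard_normal(4), weights)
print(y.shape)
```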

Recurrent neural network (RNN): the connections between units form a directed cycle [38], which suits continuous data. RNNs can be used to predict trajectories since the movement of a drifter is sequential data. The input data for MLP or DNN has a fixed size of 4, as depicted in Figure 2, but an RNN receives the whole sequence of drifter moves, so the data length can differ between examples. We set a maximum input sequence length and pad the remaining positions with zeros. For example, if the maximum length is 200 and a given data set contains 120 h of movement information, the remaining 80 positions are filled with zeros. Figure 3 shows the input and output data structure of the RNN model.
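The zero-padding scheme can be sketched as follows. The assumption of four features per hourly step (wind and flow components) mirrors the fixed input size of 4 mentioned above.

```python
import numpy as np

def pad_sequence(seq, max_len, n_features=4):
    # Pad a variable-length drifter sequence with zero rows so every
    # example fed to the RNN has the same fixed length.
    seq = np.asarray(seq, dtype=float).reshape(-1, n_features)
    out = np.zeros((max_len, n_features))
    out[: len(seq)] = seq[:max_len]
    return out

# 120 hourly steps of (wind_u, wind_v, flow_u, flow_v) padded to 200.
track = np.ones((120, 4))
padded = pad_sequence(track, max_len=200)
print(padded.shape, int((padded.sum(axis=1) != 0).sum()))
```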

Long short-term memory (LSTM): the RNN model creates the next unit based on product operations and suffers from the vanishing gradient problem on long data sequences. LSTM [39], which can forget past information and remember current information by adding a cell state to the hidden state of the RNN, has emerged as one of the most widely used RNN variants and effectively addresses not only the vanishing gradient problem but also long-term dependencies. We expect this model to remember the movement of a drifter according to specific short sequences of wind and flow.

**Figure 2.** Data input/output in our deep neural network (DNN) model.

**Figure 3.** Data input/output in our recurrent neural network (RNN) model.

### **5. Experiments**

### *5.1. Setting and Environments*

We implemented the evolutionary computation methods using DEAP (https://deap.readthedocs.io/en/master/), except for PSO, for which we used PySwarms (https://pyswarms.readthedocs.io/en/latest/), as it performed better than DEAP. We used WEKA 3 [40] for MLP, GP, SVR, and RBFN, and PyTorch [41] for DNN, RNN, and LSTM. Table 2 summarizes the software libraries we used.

The previous evolutionary computation methods evaluated the test data by creating a single model per method. When a single method yields several models, performance may vary depending on the random seed, and we confirmed considerable performance differences between models of the same deep learning method. We therefore used the bagging [42] approach to analyze each method, creating 10 models per method and measuring the average, variance, and standard deviation of the resulting values. Loss values may change sharply while deep learning methods escape local minima, so we calculated the mean absolute error (MAE) per epoch and ended training when the MAE value was low.
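The ten-model aggregation can be sketched as follows; the per-model MAE scores below are hypothetical placeholders, not results from the paper.

```python
import statistics

# Evaluation protocol: train 10 models per method (simulated here by
# hypothetical per-model MAE scores) and report the mean, variance,
# and standard deviation across models.
def summarize(scores):
    return {
        "mean": statistics.mean(scores),
        "variance": statistics.pvariance(scores),
        "stdev": statistics.pstdev(scores),
    }

maes = [0.056, 0.061, 0.058, 0.055, 0.060, 0.057, 0.059, 0.056, 0.062, 0.058]
summary = summarize(maes)
print(summary)
```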


**Table 2.** Methods and library resources.

### *5.2. Evaluation Measures*

We used mean absolute error (MAE) and normalized cumulative Lagrangian separation (NCLS). In prior work [8], we also used the Euclidean distance as an evaluation measure; since it is computed by a mechanism similar to MAE, we report the average error distance using MAE only. NCLS, also called the skill score, is calculated by subtracting the error from 1. Lower values therefore represent better results for MAE, whereas values close to 1 represent better results for NCLS. The MAE is calculated as the following equation.

$$\frac{1}{m}\sum\_{i=1}^{m}(|\text{pred\\_Lat}\_i - \text{observed\\_Lat}\_i| + |\text{pred\\_Lon}\_i - \text{observed\\_Lon}\_i|)\tag{1}$$

where *pred\_Lon* and *pred\_Lat* represent the longitude and latitude of the predicted data, and *observed\_Lon* and *observed\_Lat* denote the longitude and latitude of the observed data; the difference between these values is the error. The denominator *m* is the number of test data points, so the MAE is the average of the errors.
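Equation (1) translates directly into code; the coordinate values below are made up for illustration.

```python
def mae(pred, observed):
    # Equation (1): mean over the test set of the absolute latitude
    # and longitude errors, summed per point.
    assert len(pred) == len(observed)
    total = sum(
        abs(p_lat - o_lat) + abs(p_lon - o_lon)
        for (p_lat, p_lon), (o_lat, o_lon) in zip(pred, observed)
    )
    return total / len(pred)

# Hypothetical (lat, lon) pairs for two test points.
pred = [(33.50, 126.50), (33.52, 126.54)]
obs = [(33.51, 126.52), (33.52, 126.50)]
print(mae(pred, obs))   # (0.01 + 0.02 + 0.00 + 0.04) / 2
```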

NCLS is a measurement method frequently used in trajectory modeling; it was proposed to address weaknesses of the Lagrangian separation distance on the continental shelf and the adjacent deep ocean. MAE computes the error of each location independently, whereas NCLS accumulates errors over the trajectory. Figure 4 shows the calculation process of NCLS. In this process, if *s* becomes too large, the skill score may remain zero; setting the tolerance threshold *n* high can mitigate this. In this study, we set *n* to 1, since errors large enough to make *s* large did not occur frequently.

**Figure 4.** Formula to calculate skill score (*ss*) of NCLS.
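A simplified version of the cumulative skill score can be sketched as follows. This follows the common formulation (separation distance normalized by the cumulative length of the observed trajectory, with tolerance threshold *n*), which may differ in detail from the exact formula in Figure 4.

```python
import math

def skill_score(pred, observed, n=1.0):
    # s: summed prediction-observation separation distances normalized
    # by the cumulative length of the observed trajectory.
    # ss = 1 - s/n, clipped at zero once s exceeds the tolerance n.
    def dist(a, b):
        return math.hypot(a[0] - b[0], a[1] - b[1])

    separation = sum(dist(p, o) for p, o in zip(pred, observed))
    traj_len = sum(dist(observed[i], observed[i - 1])
                   for i in range(1, len(observed)))
    s = separation / traj_len
    return max(0.0, 1.0 - s / n)

# Hypothetical observed track along the x-axis and a slightly offset prediction.
observed = [(0.0, 0.0), (1.0, 0.0), (2.0, 0.0)]
pred = [(0.0, 0.1), (1.0, 0.1), (2.0, 0.1)]
print(skill_score(pred, observed))
```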

### *5.3. Results*

Previous measurements are available for the evolutionary computation methods. In Section 5.3.1, we verified the performance of the Python-based system by comparing the results on the Seosan data only. The degree of training is also an essential factor. Prior to experimenting with all the data, in Section 5.3.2, we determined epoch numbers deemed good enough to end training by measuring the MAE per epoch for each deep learning method. Lastly, in Section 5.3.3, we trained the models on all the data and describe the experimental results for predicting the trajectories of the newly added Jeju data.

### 5.3.1. Evolutionary Search on Seosan Data

Table 3 compares the results of the previous study (C language) [8] and this study (Python). CMA-ES showed particularly good performance compared to the previous study. Overall, we could improve the performance of evolutionary search by using new software libraries.

Table 4 shows the CPU time to build the prediction models for Seosan data. Since inference time is usually much shorter than training time, actual performance can be more important than training speed.


**Table 3.** Parameters of evolutionary computation methods.

The lower the mean absolute error (MAE) values, the better. The higher the NCLS values, the better.


**Table 4.** Computing time of evolutionary computation methods.

### 5.3.2. Deep Learning

The neural network methods use the loss value to calculate how well a model is trained. The loss value decreases as the model accurately predicts the training data. It is better to use cross-entropy and *softmax* in the final layer of neural network-based classifiers [43]. However, we used MSE as the loss function since we predicted continuous values, not discrete ones.

The loss function measures the difference between the correct answer in the training data and the value predicted by the model, which may not directly relate to MAE and NCLS. To identify whether loss and MAE are related, we examined several neural network models. Neural networks learn weight values that reduce the loss, and a reduction in MSE can prove worthwhile. We investigated the loss, training MAE, and test MAE per epoch for the three deep learning methods, DNN, RNN, and LSTM. From the Jeju data, Case 1 was used as the test data, and the rest were used as training data. Figure 5 shows the results for DNN.

DNN calculates the error of each data point independently. As training progresses, the loss value generally decreases. However, the MAE of the training and test data decreases sharply only at the beginning and does not significantly improve thereafter. The MAE of the test data was lowest in the 100–200 epoch range, when training was incomplete, and the MAE deviation between epochs remained large even after training. Since this deviation persists when training is complete, we looked for an additional remedy and addressed it with bagging among the ensemble techniques. After 1500 epochs, the loss value did not decrease further, so we set the final epoch of DNN to 1500.

**Figure 5.** MAE and Loss of DNN.

Figure 6 shows the MAE and loss values of RNN. The data are continuous time-series data of the movement of a drifter over time. Although the loss value continued to decrease beyond 100 epochs, the MAE of the training and test data did not. We set the final epoch of RNN to 1000. LSTM behaved similarly to RNN, but its MAE graph showed ups and downs as learning progressed. The final epoch of LSTM was set to 500. Figure 7 shows the MAE and loss values of LSTM.

**Figure 6.** MAE and loss values of RNN.

**Figure 7.** MAE and loss values of LSTM.

### 5.3.3. Results for Each Case

We describe the performance measurement results of the methods beyond the evolutionary computation discussed above. Table 5 shows the main parameter values for each method. Our prior work [8] measured the error per iteration. The deep learning methods using PyTorch (i.e., DNN, RNN, and LSTM) are conceptually similar; MAE and loss were measured per epoch. The WEKA-based methods could not measure the error per iteration since the iteration number cannot be set internally.

Table 6 shows the evaluation results of the models built with the ML and evolutionary computation methods; only the Seosan data was used for training. GP, MLP, RBFN, and SVR based on WEKA showed outstanding performance. However, the deep learning methods did not perform very well because the variation in data volume across cases is large for the Seosan data; the numbers of data in Cases 1 and 2 differ by a factor of nine, as shown in Table 1. Data length also has a significant impact on training because, in an RNN, the previous event influences the next. Table 7 shows the CPU time spent building the models for each method; the evolutionary computation methods took more time than GP, MLP, and RBFN using WEKA.

Table 8 shows the experimental results with the Jeju data. Unlike on the Seosan data, the deep learning methods showed excellent performance. LSTM performed better than the basic RNN models, and sometimes DNN performed better. In Cases 1, 2, and 3, MLP and RBFN using the WEKA software showed good performance, whereas the evolutionary computation methods showed neither good nor bad performance. The results of the RNN methods could be improved, since the variation in the number of data across cases is small for the Jeju data. Table 9 shows the CPU time for the calculation; the computation time for the Jeju data did not differ much from that of the Seosan data.


**Table 5.** Parameter values for machine learning (ML) methods.


**Table 6.** Results for Seosan data.

The lower the MAE values, the better; the higher the NCLS values, the better.


**Table 7.** Computing time for Seosan data.

**Table 8.** Results for Jeju data.


**Table 9.** Computing time for Jeju data.


### 5.3.4. Weighted Average Results

We compared the evolutionary algorithms and ML methods using the Seosan and Jeju data. However, it was not easy to determine which method was better overall, so we used weight-averaged results and trajectory plots of predicted and actual points. The weighted average is advantageous when the number of data differs across cases: experiments with more data are considered more important, so the weights are based on the number of data. The weighted average is calculated as follows:

$$\frac{\sum\_{i=1}^{n} d\_i r\_i}{\sum\_{i=1}^{n} d\_i}, \tag{2}$$

where *d<sub>i</sub>* refers to the number of data of the *i*th case and *r<sub>i</sub>* refers to the evaluation result of the *i*th case. One of the indicators covered in this section is the standard deviation. As described in Section 4.1, this experiment builds ten models, evaluates each of them, and then takes the average. The performance of each model may vary between runs, so for practical use we need to determine whether the performance deviation across models is large.
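Equation (2) reduces to a few lines of code. A minimal sketch, with made-up data counts and case results:

```python
def weighted_average(counts, results):
    """Weighted average of Equation (2): each case result r_i is
    weighted by its number of data d_i."""
    return sum(d * r for d, r in zip(counts, results)) / sum(counts)

# Hypothetical example: a case with 90 data points dominates one with 10,
# pulling the average toward its result.
avg = weighted_average([90, 10], [1.0, 2.0])  # -> 1.1
```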

Table 10 shows the overall performance of each method. In this comparison, CMA-ES showed the highest performance on the Seosan data and LSTM on the Jeju data. DNN and RNN also performed well on the Jeju data. When using neural networks, the amount of data should be equalized across cases. Rankings were calculated separately for each method in terms of MAE and NCLS: RBFN was the best for MAE, and LSTM the best for NCLS. Since NCLS is a more popular measure than MAE for predicting drifter trajectories, LSTM was the best method in this study. As expected, LSTM was superior to RNN: depending on the direction of the drifters, there is past information to be forgotten and current information to be remembered. DNN and RNN performed well for Jeju, so if more data become available in the future, their performance can be improved further. The evolutionary methods showed good performance on both the Seosan and Jeju data; especially for Seosan, where the number of data in each case was large, they outperformed the other methods.


Figure 8 shows the trajectory of a drifter predicted by CMA-ES, which showed the best performance for Seosan. Figure 9 exhibits the trajectory of a drifter predicted by LSTM, which had the best performance for Jeju. Finally, Figure 10 presents the trajectory of a drifter predicted by RBFN with the best performance for MAE. Both ML and evolutionary search optimize the parameters. There is a slight difference in accuracy, but all of them predict similar paths.

Except for the RNN series (RNN and LSTM), only the data at the current point were used to predict the trajectory of a drifter at each hourly time step. The RNN series differ in that they use hidden-layer neurons whose states carry over from previous prediction steps, and LSTM showed the best performance in terms of NCLS. We believe this is because the RNN series take the inertia of the drifter into account.

**Figure 8.** Comparison of trajectory predicted by our CMA-ES model, trajectory predicted by an existing numerical model (MOHID), and observed trajectory for four major drifters.

**Figure 9.** Comparison of trajectory predicted by our LSTM model, trajectory predicted by an existing numerical model (MOHID), and observed trajectory for four major drifters.

**Figure 10.** Comparison of trajectory predicted by our RBFN model, trajectory predicted by an existing numerical model (MOHID), and observed trajectory for four major drifters.

### **6. Conclusions**

We extended and improved our previous study [8], which predicted the trajectories of drifters using evolutionary computation, by also predicting the trajectories of drifters using various machine learning techniques [44–49]. To the best of the authors' knowledge, this was the first trial in which machine learning was applied to the prediction of drifter trajectories, and it significantly improved upon the representative numerical model, MOHID.

In terms of MAE, RBFN using the WEKA library showed the best performance, an improvement of 35.20% over the numerical model MOHID. LSTM using PyTorch showed the best performance regarding NCLS, an improvement of 6.24% over MOHID. These neural network-based methods did not take a long time to construct a model. In the future, we plan to experiment with other representative variants of RNNs such as gated recurrent units [50], and we will design more models that increase the performance of DNNs or basic RNNs by adding more training data.

**Author Contributions:** Conceptualization, D.-Y.K. and Y.-H.K.; methodology, Y.-W.N. and H.-Y.C.; software, Y.-W.N. and H.-Y.C.; validation, Y.-W.N., H.-Y.C. and D.-Y.K.; formal analysis, Y.-W.N. and Y.-H.K.; investigation, Y.-H.K.; resources, D.-Y.K.; data curation, D.-Y.K.; writing—original draft preparation, Y.-W.N., H.-Y.C., S.-H.M. and Y.-H.K.; writing—review and editing, S.-H.M. and Y.-H.K.; visualization, H.-Y.C.; supervision, Y.-H.K.; project administration, Y.-H.K.; funding acquisition, D.-Y.K. and Y.-H.K. All authors have read and agreed to the published version of the manuscript.

**Funding:** This research was a part of the project titled 'Marine Oil Spill Risk Assessment and Development of Response Support System through Big Data Analysis', funded by the Ministry of Oceans and Fisheries, Korea.

**Conflicts of Interest:** The authors declare that there is no conflict of interest regarding the publication of this article.

### **References**


**Publisher's Note:** MDPI stays neutral with regard to jurisdictional claims in published maps and institutional affiliations.

© 2020 by the authors. Licensee MDPI, Basel, Switzerland. This article is an open access article distributed under the terms and conditions of the Creative Commons Attribution (CC BY) license (http://creativecommons.org/licenses/by/4.0/).

## *Article* **Salespeople Performance Evaluation with Predictive Analytics in B2B**

**Nelito Calixto <sup>1</sup> and João Ferreira <sup>1,2,\*</sup>**


Received: 13 April 2020; Accepted: 8 June 2020; Published: 11 June 2020

**Abstract:** Performance evaluation is a process that occurs multiple times per year at a company. During this process, the manager and the salesperson evaluate how the salesperson performed on numerous Key Performance Indicators (KPIs). To prepare the evaluation meeting, managers have to gather data from the Customer Relationship Management system, financial systems, and Excel files, among others, which makes the process very time-consuming. The result of the performance evaluation is a classification followed by actions to improve performance where it is needed. Nowadays, predictive analytics technologies make it possible to produce such classifications from data. In this work, the authors applied a Naive Bayes model to a dataset composed of sales made by 594 salespeople over 3 years at a global freight forwarding company, to classify salespeople into pre-defined categories provided by the business. The classification uses 3 classes: Not Performing, Good, and Outstanding. It is based on KPIs such as growth volume and percentage, sales variability along the year, opportunities created, customer baseline, and target achievement, among others. The authors assessed the performance of the model with a confusion matrix and related measures such as true positives, true negatives, and the F1 score. The results showed an accuracy of 92.50% for the whole model.

**Keywords:** data mining; predictive analytics; sales; performance measurement; human resources

### **1. Introduction**

Salesperson performance measurement is a process that occurs multiple times per year at a company. The performance evaluation is based on various Key Performance Indicators (KPIs) extracted from multiple systems, such as Customer Relationship Management (CRM) and Enterprise Resource Planning (ERP).

Evaluating these KPIs can be time-consuming, as it requires analyzing figures with complex calculations, forming a judgment based on the values, and weighing how much each KPI contributes to the performance as a whole. The KPIs often include the amount of products/services sold by the salesperson, the number of opportunities created, the ability to sell multiple products/services, and the variability of sales along the year, among many others. When a company has dozens or hundreds of salespeople, this becomes a laborious process that may involve other departments such as Human Resources (HR) and Operations.

The result of the performance evaluation is a classification followed by actions to improve performance where it is needed. Through Data Mining (DM), technology is currently capable of making such classifications based on data. DM is the process of exploration and analysis, by automatic or semiautomatic means, of large quantities of data to discover meaningful patterns and rules [1]. DM tasks are classified into two categories: descriptive and predictive [1]. Predictive tasks perform inferences based on data to make predictions; their goal is to create a predictive model that allows the data miner to predict an unknown value of a specific variable. When the result of the prediction is a number, it is called regression, and when the result is a label, it is called classification [1].

DM classification capabilities can help improve the process of salesperson performance measurement. Companies can take advantage of Predictive Analytics (PA) classification capabilities to support the judgment of KPIs that involve complex calculations and to weigh how much each KPI contributes to the whole performance evaluation. By using classifications previously made by humans, companies can build models that classify the current sales of a salesperson and use them in the performance evaluation, automating part of the evaluation process. The gains these automated evaluations can bring to companies include, among others:


In this work, the authors propose the use of DM techniques to allow salespeople and sales leaders to make better decisions about salespeople performance measurement, by building a model in R that can classify a salesperson's performance based on metrics defined by the business. Although companies may have different evaluation processes, any B2B company that has teams of salespeople measured against metrics can take advantage of this DM process.

The dataset used for this analysis is composed of data regarding salespeople performance measurement at a global freight forwarding company. The sales were made by 594 salespeople between January 2017 and June 2019. The measurement is based on the company's internal performance measurement process, which is explained in this article at a very high level to give the reader an understanding of the data and the fields needed for the performance measurement. It is not the goal of this work to evaluate scientifically this company's salespeople performance measurement process.

### *1.1. Research Contribution*

The DM process applied in this work can be replicated by any company that has historical objective metrics and classifications applied to people based on these metrics.

The contributions of this paper are the following:


### *1.2. Paper Structure*

The paper is structured in the following way: Section 2 contains the literature review; Section 3 provides the background, where the company's performance evaluation process is explained; Section 4 presents the work methodology and all the steps needed to prepare the data for modeling and evaluation; Section 5 contains the discussion; and Section 6 presents the conclusions and proposals for future work.

#### **2. Literature Review**

#### *2.1. Salesperson Performance*

Academic studies demonstrate that the success of a salesperson normally has a direct relationship with company performance; some authors state: "When salespeople do well, the organization is likely doing well, and the contrary is normally true as well." [2]. When measuring salesperson performance, there are objective data, such as total sales increase, sales commissions, or percent of quota, and subjective measures, such as a manager's or peer's assessment of the salesperson [3]. Many companies use a combination of objective and subjective KPIs to make the assessment. A meta-analysis of objective and subjective sales indicators suggests a low correlation between the two kinds of indicator, which shows that they are not necessarily interchangeable and that choosing the most appropriate may require a trade-off [2].

The performance evaluation process varies from company to company [4]. Job activities cannot be measured by objective or subjective measures alone, as some tasks of a job require objective evaluation methods, while for others subjective measures are better. Bikrant Kesari examined the impact of objective and subjective evaluation measures in sales departments using various methods. For the company being studied, Kesari concluded that objective measures were the most relevant factor in salesperson evaluation, demonstrating a positive impact on performance [4].

Muhammad Ruhul Amin et al. evaluated the effectiveness of the weighted checklist method for appraising the performance of employees at different levels of a bank, based on self-assessment, competency and demonstration of leadership behaviours, and skill and knowledge assessment; the achievement classifications were made on 6 levels. The authors concluded that the method inevitably impacted employees and that all financial and non-financial benefits were affected by it [5].

John P. Campbell et al. defined individual job performance as the things people do and the actions people take that contribute to the organization's goals [6]. In another article, Campbell et al. mention that performance is what directly facilitates achieving the organization's goals [7].

Performance itself can be measured with judgmental measures and with nonjudgmental, outcome measures [8]. Outcome measures use objective data, which require no abstraction from whoever collects them [9]. There are three predominant methods of measuring sales performance: outcome measures, composed of sales volume and its variants; judgmental managerial ratings; and salesperson self-evaluation [10]. In the current work, only objective measures are available, as the data provided for this study contain only volume figures and other sales-related information, none of which relates to subjective measures.

### *2.2. Predictive Analytics for Sales*

Predictive analytics is an area increasingly entering the business and academic fields [11]. Companies have increasingly been using DM to improve their internal processes and automate not only repetitive but also complex tasks currently performed by humans [12,13].

Authors in the academic area note that PA has been used for several years by companies to gain a competitive advantage [14,15]: at first by companies acting in B2C, with large customer bases and the capacity to collect and store transactional customer data, and only later by companies acting in the B2B area [14].

B2B selling companies are hiring cloud-based PA providers that draw on both inside and outside data sources to identify new leads, so that they can take advantage of PA [16].

Mirzaei and Iyer conducted a comprehensive study on the application of PA to CRM data in 2014. The results cover 57 articles found in 4 databases, where the studies focused on dimensions such as customer acquisition, attraction, retention, development, and equity growth [17]. The results also show that PA techniques gained considerable popularity between 2003 and 2013 in areas such as casinos, retail, telecommunications, manufacturing, insurance, and healthcare [17].

To understand what has been studied academically in terms of predictive analytics, the authors describe below some success cases of PA applied to sales forecasting.

#### 2.2.1. Sales Forecasting of Computer Products Based on Variable Selection Scheme and Support Vector Regression (SVR)

As in many other industries, sales forecasting is a challenge for computer product retailers. Wrong forecasts can cause product backlogs or inventory shortages, misjudge customer demand, and decrease customer satisfaction [18].

Chi-Jie Lu et al. combined Multivariate Adaptive Regression Splines (MARS) with SVR to build a sales forecasting model for computer products. The main idea of the scheme was first to use MARS to select the essential forecasting variables and then to use the identified key variables as the input variables for SVR. The data used were weekly sales data of five computer products from a computer retailer in Taiwan, covering products such as notebooks, LCDs, main boards, hard drives, and display cards [18].

#### 2.2.2. Fast Fashion Sales Forecasting with Limited Data and Time

Another success case is in fast fashion, an industrial practice whose main idea is to offer a continuous stream of new merchandise to the market [19]. With this practice, some fashion companies can even take a product from conceptual design to final product in just two weeks. Companies working this way have to make their inventory decisions based on forecasts with short lead times and tight schedules, so they forecast on a near real-time basis with a minimal amount of data. T. M. Choi et al. proposed an algorithm called Fast Fashion Forecasting (3F) that gives companies the ability to make forecasts with limited data and time. The algorithm uses two artificial intelligence methods: the Extreme Learning Machine (ELM) and the Grey Model (GM). The data belonged to a knitwear fashion company using the fast-fashion concept. The algorithm was tested with real and artificial sales data, and the results revealed an acceptable forecasting accuracy [19].

#### 2.2.3. Support Vector Regression for Newspaper/Magazine Sales Forecasting

The next case is in the media area. Due to the constant transformations that information technologies bring to the world, new generations are increasingly used to browsing the internet for news and interesting stories [20]. With that in mind, the media industry also has to evolve to keep up, so it is increasingly urgent for traditional media companies to forecast accurately how many newspapers and magazines to print, to avoid excessive printing or failing to meet the expected demand [20]. The authors of the study in question used SVR at a media company with printed newspapers/magazines to create a sales forecast for estimating and preparing the printing and distribution plan. The results showed that SVR is a superior method for forecasting sales in the news/magazine industry [20].

Having covered these success cases of PA in B2C, we move next to success cases in the B2B area.

### 2.2.4. On Machine Learning towards Predictive Sales Pipeline Analytics

At companies operating in B2B, new sales are often identified as leads. These leads move into the sales opportunity pipeline management system, and later some of them are qualified into opportunities. A sales opportunity is a set of one or several products or services that the salesperson is trying to convert into a purchase. All opportunities are tracked, ideally ending in a won business that generates revenue for the company [21].

A fundamental part of pipeline quality assessment is the lead-level win-propensity score. These scores are usually entered by the salesperson, but to avoid the noise and bias the salesperson may introduce for various reasons, the authors of the article in question proposed and successfully deployed a model that calculates the win-propensity using the Hawkes process model, at a multinational Fortune 500 B2B-selling company in 2013 [21].

### 2.2.5. Prescriptive Analytics for Allocating Sales Teams to Opportunities

Still on opportunities, other authors used predictive and prescriptive analytics to increase a company's revenue by 15%. The increase was achieved by automating the allocation of sales resources to opportunities, to maximize opportunity revenue in B2B selling for the company [13].

The predictive part was achieved by mining the historical selling data, through multiple linear regression, to learn sales response functions that capture the behavioral relationship between the size and composition of a sales team, the revenue earned from the different types of customers, and the opportunities [13].

For the prescriptive part, these authors used the sales response functions to determine the allocation of salespeople's effort to customers' opportunities that maximizes the overall revenue earned by the salespeople, using a piece-wise linear approximation [13].

As the articles above show, PA is widely used in sales, and the data used for these predictions is the same type of data needed to measure salespeople. With this basis in PA for sales, the authors now move to the application of PA in HR. HR is essential in this work because of the performance evaluation processes.

### *2.3. Predictive Analytics in HR Management*

The articles studied in HR refer first to how PA is being used in HR in general and then to how PA is being used for people performance evaluation and analysis.

### 2.3.1. How PA Is Being Used for HR in General

An article published in 2017 [22] proposes the use of PA in HR for:


The authors of the article in question also propose research on appropriate recruitment profile selection, employee sentiment analysis, and employee fraud risk management [22].

Sujeet N. Mishra et al. propose the use of Human Resource Predictive Analytics (HRPA) for decision making by presenting two success cases: one at a US wind turbine maker that changed its recruitment and retention policies based on HRPA, and another at Cisco, which used IBM SPSS to transform the relationship between its HR analysis and executive leaders [23]. Kessler et al. present the categorization module of E-Gen, a modular system for treating job listings automatically; through SVM, these authors managed to rank candidate responses based on several kinds of information [24]. In another article, two authors used machine learning techniques to rank candidates in a recruiting process by analyzing each candidate's suitability for a job position based on the candidate's tweets [25]. Other authors proposed an approach to evaluating job applications in online recruitment systems to solve the candidate ranking issue. They achieved this by analyzing each candidate's LinkedIn profile and inferring personality characteristics from linguistic analysis of the candidate's blog. For that, they used training data provided by human recruiters, applied in a large-scale recruitment scenario with three different positions and 100 applicants, using regression trees and SVR [26].

### 2.3.2. How PA Is Being Used in HR for Performance Evaluation and Analysis

Zhao, in a paper in the proceedings of the 2008 International Seminar on Future Information Technology and Management Engineering, proposed a DM method for performance evaluation. They gathered information about ability, attitude, performance, harvest, and spirit in a dataset, then used the K-Expectation algorithm to cluster employees into groups. After that, a decision tree was used to train a model whose rules can be used by managers to classify and select the best employees from the applicants [27].

Jing applied a Fuzzy Data Mining Algorithm (FDMA) to the performance evaluation of human resources. The author used evaluation records with four features (innovation ability, learning level, work efficiency, and independence and workability), each with four levels corresponding to the feature's score [28]. Jing then used a maximal tree to cluster the human resources, compared the management data with each cluster, calculated the proximal values based on the FDMA, and finally determined the evaluation. The evaluation in this case was the result closest to one of the 4 clusters, named Best, Better, General, and Worse [28].

Two authors applied decision trees to the performance analysis of human resources to make a classification analysis. The results show mutual restraint and influence between performance results and working quality, tasks, skills, and attitude, concluding that if the enterprise cultivates employee working skills and quality in the future, the employees will consciously improve themselves in these areas [29].

The work above on PA for HR and performance evaluation is not based on data from sales made by salespeople. What this article proposes is the use of PA to evaluate a salesperson using the sales that the salesperson made, taking advantage of the data already available in CRM and ERP systems and in previous performance evaluations. With that groundwork, it is now time to proceed to the background section, where the company's salesperson evaluation process is described.

### **3. Background**

In this section, the process and main KPIs used to evaluate salesperson performance are described at a very high level, to provide an understanding of the data and fields used in this research. It is not the goal of this work to evaluate scientifically this company's salespeople performance measurement process.

### *3.1. Main KPIs Used for Salespeople Performance Evaluation*

According to the process of the company that provided the data for this research, the main KPIs used to evaluate a salesperson's performance are:


### *3.2. Assess Salespeople Performance*

Based on the company's performance evaluation process, there are a number of questions whose answers lead to the evaluation level. The answers to these questions are provided by the KPIs described below:

• What growth did the salesperson bring to the company?


As displayed in Figure 1, the first level to verify is growth, then whether the targets were achieved, and finally whether the targets follow the company guidelines. Other relevant KPIs that contribute to salesperson performance are also assessed, but these are the most important ones.

**Figure 1.** Company's performance evaluation stages.

### 3.2.1. What Growth Did the Salesperson Bring to the Company?

Starting with the first question: "What growth did the salesperson bring to the company?". A salesperson is assigned to an account base that has, on average, 70 customers. The basis for the analysis is growth, which is the difference between the number of twenty-foot equivalent units (TEUs) sold in the current year and in the previous year. The figure used in the analysis is the sum of the growth for each year.
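The growth computation described above can be sketched as follows; the TEU figures here are invented for illustration, not taken from the paper's dataset:

```python
def yearly_growth(teu_by_year):
    """Growth per year: difference in TEUs sold between consecutive
    years, as described in the text."""
    years = sorted(teu_by_year)
    return {year: teu_by_year[year] - teu_by_year[prev]
            for prev, year in zip(years, years[1:])}

# Hypothetical TEU totals for one salesperson's account base.
teu = {2017: 100, 2018: 130, 2019: 120}
growth = yearly_growth(teu)          # {2018: 30, 2019: -10}
total_growth = sum(growth.values())  # 20
```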

### 3.2.2. Did the Salesperson Achieve the Defined Targets?

Target definition at this company follows a top-down process. Targets are based on a roadmap defined globally by the sales controlling department; these targets are assigned to each region and then distributed by the regional managers to the countries. The process continues until it reaches the salesperson. As exemplified in Figure 2, a global roadmap of 10,000 TEUs was defined. These TEUs are shared among all the regions and end with salespeople x and y in Lisbon, with 30 TEUs each.

Although the company has implemented this process, the salesperson does not always get a reasonable target, because this depends on the strategy defined by local sales management, and at this company part of the strategy is defined locally. For instance, in Figure 2, all of Portugal's targets are assigned to Lisbon and none to Oporto: if the sales management in Portugal believes it is possible to achieve all targets with the 2 salespeople in Lisbon, they do not have to assign targets to salespeople in Oporto.

**Figure 2.** Target definition Top Down.

Beyond the number of TEUs assigned to a region/country, there is also a target definition at the product level. This is another way of strategically redirecting the sales team toward a specific product. For instance, if a country has a higher market for imports, the sales manager should set import targets to boost import sales.

### 3.2.3. Do the Targets Assigned to the Salesperson Follow the Company Guidelines?

At this company, targets are set for a salesperson based on 3 pillars:


As described previously, the account base is composed of the customers assigned to the salesperson, and it has a significant impact on the level of target that can be assigned to the person. If a salesperson has a customer base of 10 customers, and these customers may purchase 100 TEUs along the year, the targets assigned to this salesperson should not be a value too far from 100 TEUs, unless the person who defines the targets has information indicating that the customers will have a higher increase.

The sales roadmap is the document containing the plan for the company's long-term sales growth. For the company in question, it is composed of the main product categories, regions, and trade lanes, among other information. Sales managers often set targets based only on the sales roadmap, but this may lead to "unrealistic" targets if the account base does not provide the potential needed to achieve them. When this happens, CRM pipeline figures are another ally for setting targets. Usually, to improve target setting, pipeline figures are added to the sales planning process; this way, the salesperson and the manager have not only the customer baseline but also the forecast (assuming good forecasting accuracy).

Salesperson seniority also plays a significant role in how the salesperson works the customer base. A junior salesperson may not be able to manage complex accounts, so the sales manager has to know the salesperson's seniority when assigning the customer base. Seniority in the company and with the products also has consequences for managing the account base. For instance, somebody who has just joined the company and is also junior will need more time to start generating results: new company, new products, the need to build an internal network, among other relevant tasks. To mitigate this issue, sales managers often give a new/junior salesperson lower targets at the beginning and then increase the targets year by year as seniority increases.

Figure 3 displays an example of target definition for one salesperson (the name was replaced by a randomly generated one for data protection). It shows a 15% increase over the Account Base line (identified as the Full Year Actual Adjusted) of 283 TEUs. The increase amounts to 42 additional TEUs, split across the four quarters as 10 for Q1, 10 for Q2, 11 for Q3, and 11 for Q4.



Pipeline and seniority information is entirely missing from this research, so to judge the targets, a validation is made by comparing the targets directly with the Customer Base line in the dataset.

In this dataset, the evaluation is made by dividing the target by the Customer Base line, as shown in Formula (1).

$$\text{Target evaluation} = \frac{\text{Target}}{\text{Account Baseline}} \tag{1}$$
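Formula (1) is a simple ratio; a minimal sketch (in Python rather than the paper's R, with illustrative figures rather than the actual dataset) could look like this:

```python
def target_evaluation(target_teu: float, account_baseline_teu: float) -> float:
    """Formula (1): ratio of the assigned target to the Account Base line."""
    if account_baseline_teu <= 0:
        raise ValueError("Account Base line must be positive")
    return target_teu / account_baseline_teu

# Example in the spirit of Figure 3: a 283 TEU baseline with a target of
# 325 TEU corresponds to roughly a 15% increase over the baseline.
ratio = target_evaluation(325, 283)
print(round(ratio, 3))  # 1.148
```

A ratio close to 1 means the target tracks the Account Base line; values far above 1 flag potentially "unrealistic" targets.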

3.2.4. Other Relevant KPIs for the Salesperson Performance

Other KPIs also need to be validated for the salesperson to measure performance; these include:


The table in Figure 4 provides all of this information for a sample of five salespeople. Worth highlighting in the table is the number of opportunities of the first salesperson, which is remarkably high compared to the second. Another important piece of information is the average number of months with growth above 0: on average, Bella Connor (Belle) is able to grow the Customer Base for about 8 months each year, and she can also grow more than one product.


**Figure 4.** Other relevant KPIs for 5 salespeople sorted by growth.

These rules generate a dataset of 42 KPIs where, based on the salesperson's accumulated performance on each of the measures, a classification can be defined for the salesperson. The classifications are divided into the following categories: Not Performing, Good, and Outstanding.

### **4. Work Methodology**

The work methodology used in this research was the Cross Industry Standard Process for Data Mining (CRISP-DM). This methodology, as presented in Figure 5, is divided into six stages. In this article, the authors describe the steps executed from stage 1 to stage 5; the last stage is not described here, as the company requested that no information be provided on that area.

**Figure 5.** CRISP-DM Methodology adapted from [30].

The authors describe hereunder each of the CRISP-DM steps taken during this research.

### *4.1. Business Understanding*

### 4.1.1. Objectives

With the main goal of classifying salespeople and building a model that can tell whether a salesperson is successful or not, this research project has the following business objectives:


### 4.1.2. Business Success Criteria

The main success criterion for this research project is the ability to achieve the specific goals defined previously in the objectives. To evaluate these goals, the authors used the metrics provided by the algorithms that measure the accuracy of the classifications.

### *4.2. Data Understanding*

The data used in this research refers to sales between January 2017 and June 2019 from a freight forwarding company that operates worldwide in Air, Ocean, and Land. The sales were made by 594 salespeople. The data refers to shipments and sales opportunities for the customers, grouped by year. As this company does not want its sensitive data made public, all sensitive data were removed from the dataset, leaving only the figures and classifications. The names of the salespeople were all replaced with names generated on a name generator website [31].

There are 1071 rows and 45 columns. Each row represents all the sales, customer base, and sales opportunities made by one salesperson to his/her entire customer base over one year. The dataset has the following structure:


### Data Description

The dataset is publicly provided in the university's online database. The data is provided as a CSV file, and the tables below (Tables 1 and 2) describe the attributes.



For each of the six main products, the following fields with performance indicators are also part of the dataset:

$$\text{Target Action} = \frac{\text{Target}}{\text{Growth}} \tag{2}$$

A sample of the dataset is provided in Figure 6 for better understanding.





**Figure 6.** Sample of report data.

The classifications in the dataset are made in the categories Not Performing, Good, and Outstanding; these categories represent the following:


### *4.3. Data Preparation*

The dataset is composed of 45 columns and 1071 rows. Of the 45 columns, four hold categorical data: Sales\_Person\_Code, Sales\_Person\_Name, Year, and Talent. The remaining columns hold numerical data describing the salesperson's performance. A summary of the data available in the dataset is provided in Table 3 for reference. The columns Sales\_Person\_Code, Sales\_Person\_Name, and Year were removed from the dataset, leaving it with 42 columns.


**Table 3.** Table with classification statistics.

In the next sections, the authors submit the dataset to several techniques that evaluate the importance each column may have for the model and eliminate all the columns that contribute little or nothing. All the evaluations were made using RStudio; all the packages and functions used are identified.

The dataset contains:


In order to train the model, the evaluations and transformations below were applied to the 695 classified rows.

The scripts used for this research are written in the R language, using the free version of RStudio obtained from [32]. These scripts are provided in the university's public database.

#### 4.3.1. Near Zero Variance

Columns with low variance provide little or no knowledge to the models, so these columns can be eliminated to improve model performance. To identify the columns that provide little knowledge, the authors used the nearZeroVar function from the caret package in R. This function diagnoses predictors that have one unique value, or predictors that have very few unique values relative to the number of samples and for which the ratio of the frequency of the most common value to the frequency of the second most common value is large.
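The diagnostic behind caret's nearZeroVar can be sketched for a single column as follows. This is a Python stand-in (not the paper's R script), assuming caret's default cut-offs of freqCut = 95/5 and uniqueCut = 10:

```python
from collections import Counter

def near_zero_var(values, freq_cut=95 / 5, unique_cut=10.0):
    """Mimics caret::nearZeroVar for one column (default cut-offs).

    Returns (zero_var, nzv): zero_var is True when the column has a single
    distinct value; nzv is True when the ratio of the most common value's
    frequency to the second most common exceeds freq_cut AND the percentage
    of unique values is below unique_cut.
    """
    counts = Counter(values).most_common()
    zero_var = len(counts) <= 1
    if zero_var:
        return True, True          # a constant column is also near-zero variance
    freq_ratio = counts[0][1] / counts[1][1]
    pct_unique = 100.0 * len(counts) / len(values)
    return False, (freq_ratio > freq_cut and pct_unique < unique_cut)

# A column that is almost always 0 gets flagged for removal.
col = [0] * 98 + [1, 2]
print(near_zero_var(col))  # (False, True)
```

Columns for which the second value is True would be dropped, as done with the 19 columns identified in Table 4.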

Among the results provided by the function, the most important are zeroVar, which is TRUE when the column contains only one distinct value, and nzv, which is TRUE when the column in question is a near-zero-variance predictor. For reference, the results are provided in Table 4.


**Table 4.** Result of the nearZeroVar function.


**Table 4.** *Cont.*

There are 19 columns identified by the nearZeroVar function to be removed. After their removal, the dataset still has 24 columns: 23 numerical plus the Talent column.

### 4.3.2. Correlation Matrix

After the removal of the columns with low variance, a correlation matrix was applied to the remaining columns (excluding the Talent column) to find the ones that are highly correlated and remove at least one of each pair. For that, the authors used the cor function (from R's stats package), which computes the correlation of x and y; the results are correlation coefficients between the columns.

The result of the correlation matrix, presented in Figure 7, shows that six columns are highly correlated (above 0.8). The authors eliminated three of the six columns, specifically Grow\_with\_Different\_Products, Ocean\_FCL\_Export\_Target\_Achievement, and Ocean\_FCL\_Export\_Growth\_Percent. The dataset now has 21 columns: 20 numeric plus Talent. Only the columns with information specific to a product were removed because, between the columns referring to a single product and the overall ones, the overall columns provided more information to the dataset.
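The correlation-based filtering step can be sketched as follows. This is an illustrative Python version on synthetic data (the paper works in R on the confidential dataset); the greedy keep-first/drop-second rule loosely mirrors caret's findCorrelation idea:

```python
import numpy as np
import pandas as pd

def highly_correlated(df: pd.DataFrame, cutoff: float = 0.8) -> list:
    """Return column names involved in a pairwise correlation above `cutoff`.

    Only the upper triangle of the correlation matrix is inspected, so the
    first column of each correlated pair is kept and the second is flagged.
    """
    corr = df.corr().abs()
    upper = corr.where(np.triu(np.ones(corr.shape, dtype=bool), k=1))
    return [col for col in upper.columns if (upper[col] > cutoff).any()]

# Toy example: b is a near-copy of a, c is independent noise.
rng = np.random.default_rng(0)
a = rng.normal(size=200)
df = pd.DataFrame({"a": a,
                   "b": a + rng.normal(scale=0.01, size=200),
                   "c": rng.normal(size=200)})
print(highly_correlated(df))  # ['b'] -> drop b, keep a and c
```

In the paper's case, this kind of check flagged six columns above 0.8, of which three were dropped.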

**Figure 7.** Correlation matrix.

### 4.3.3. Outliers Treatment

After removing the columns that contribute less and the columns that are highly correlated, an outlier analysis was carried out on the remaining columns of the dataset. At this point, there are 21 columns in the dataset, including the Talent column, which holds the classification.

The dataset has a high number of outliers, as can be verified in Figure 8. To identify the outliers, the authors used the boxplot.stats function of the grDevices package. This function is typically called by another function to build a boxplot. With it, it was possible to identify the outliers for all 20 numeric columns.

To avoid removing data from the small dataset (695 rows in the training data), the outlier treatment consisted of replacing every outlier with the corresponding range limit, also obtained using the boxplot.stats function from the grDevices package. The lower and upper values applied are provided in Table 5 for reference; limits were applied to all columns except Nº\_Months\_with\_growth\_above\_0, which did not need them.
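The capping step can be sketched per column as clipping to the Tukey fences (Q1 − 1.5·IQR and Q3 + 1.5·IQR), a close stand-in for the whisker limits R's boxplot.stats reports. This is an illustrative Python version, not the paper's R script:

```python
import numpy as np

def cap_outliers(x, k=1.5):
    """Clip values to the Tukey fences Q1 - k*IQR and Q3 + k*IQR.

    This keeps every row (important for a small dataset) while pulling
    each outlier back to the boxplot range limit instead of dropping it.
    """
    x = np.asarray(x, dtype=float)
    q1, q3 = np.percentile(x, [25, 75])
    iqr = q3 - q1
    return np.clip(x, q1 - k * iqr, q3 + k * iqr)

col = [1, 2, 2, 3, 3, 3, 4, 4, 5, 100]   # 100 is an obvious outlier
print(cap_outliers(col).max())            # 6.625, the upper fence
```

Every value survives, but the extreme 100 is pulled down to the upper limit, which is what Table 5 records for each treated column.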

**Figure 8.** Outlier display.



After all the evaluations, the authors discussed with the business the added value of the columns referring to specific products, such as Ocean FCL Export and Ocean FCL Import (the added value of Freight Management was practically removed because the outlier treatment eliminated all its values). The fact that these two products would be the only ones in the model would bias it towards the salespeople who succeed more with these two products over the remaining ones. Since the Overall Growth is still part of the dataset, removing all product-specific columns would produce similar results with more value to the business. This led to the removal of another 10 columns, reducing the dataset to 11 columns: 10 numeric plus 1 categorical.

### 4.3.4. Normalize Data

After the completion of all the data treatment steps, and as the Naive Bayes (NB) implementation used in R requires all the numeric columns to be standardized, the authors standardized all the numeric columns using the normalize function from the BBmisc package in R.
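The standardization applied here is the usual z-score transform (zero mean, unit standard deviation), equivalent in spirit to BBmisc::normalize with method = "standardize". A minimal Python sketch:

```python
import numpy as np

def standardize(x):
    """Z-score standardization: subtract the mean, divide by the sample
    standard deviation (ddof=1, matching R's sd())."""
    x = np.asarray(x, dtype=float)
    return (x - x.mean()) / x.std(ddof=1)

col = [10.0, 20.0, 30.0, 40.0]
z = standardize(col)
print(round(z.mean(), 10), round(z.std(ddof=1), 10))  # mean ~ 0, sd ~ 1
```

Applied column by column, this puts all 10 numeric KPIs on a comparable scale before training the NB model.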

With this task completed, the data treatment phase is concluded. The next phase is the evaluation, where the results are assessed; this is described in the Discussion section.

### **5. Discussion**

### *5.1. Naive Bayes*

From the algorithms studied in this research, the authors selected NB because of the ease of its implementation. The NB algorithm is a probabilistic classifier that takes each independent variable and associates it with a conditional probability. The conditional probability is calculated with Formula (3):

$$P(C|A) = \frac{P(A|C)\,P(C)}{P(A)}\tag{3}$$

The algorithm calculates the probability of an event occurring based on another event that occurred in the past, for example, to predict whether a salesperson may achieve his targets. In the formula, C can be associated with the probability of a salesperson achieving his targets, while A corresponds to the conditions that allowed the salesperson to achieve them, for instance, a customer base composed of customers that buy high volumes of TEUs.

The data was split into two separate datasets using the sample function in R: the training dataset with 70% of the data, corresponding to 481 observations, and the test dataset with 214 observations.
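The split-and-train step can be sketched as follows. This is an illustrative Python/scikit-learn version on synthetic data (the paper uses R's sample function and the confidential 695-row dataset); the KPI matrix and labels here are made up:

```python
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.naive_bayes import GaussianNB

# Synthetic stand-in for the 695 labelled rows with 10 numeric KPIs.
rng = np.random.default_rng(42)
n = 695
X = rng.normal(size=(n, 10))
y = np.where(X[:, 0] + X[:, 1] > 0.5, "Good", "Not Performing")

# 70/30 split, mirroring the paper's proportions (481 train / 214 test).
X_tr, X_te, y_tr, y_te = train_test_split(X, y, test_size=0.3, random_state=0)

# Gaussian NB applies Formula (3) per feature under an independence assumption.
model = GaussianNB().fit(X_tr, y_tr)
acc = model.score(X_te, y_te)
print(X_tr.shape[0], X_te.shape[0], round(acc, 2))
```

The class priors P(C) and per-feature likelihoods P(A|C) are estimated from the training split; accuracy is then measured on the held-out 30%.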

### *5.2. Identify Most Important Factors for Salesperson Success*

To achieve the goal of identifying the most important factors for salesperson success, the authors built a Random Forest model with the same training dataset prepared for the NB model, but with R's randomForest package so that the varImp function could be used. The Random Forest model was created using R's defaults plus the following parameters: type of random forest: classification; number of trees: 500; number of variables tried at each split: 2. The results were an Out-of-Bag (OOB) error rate estimate of 2.91%, with the confusion matrix provided in Table 6.
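The feature-importance step can be sketched as follows. This Python/scikit-learn version mirrors the randomForest + varImp combination (500 trees, 2 variables per split); the data and KPI names are synthetic and illustrative, not the paper's:

```python
import numpy as np
from sklearn.ensemble import RandomForestClassifier

rng = np.random.default_rng(1)
X = rng.normal(size=(600, 5))
# Make features 0 and 2 drive the label; the other three are pure noise.
y = (X[:, 0] + 0.5 * X[:, 2] > 0).astype(int)

# ntree=500 and mtry=2, as in the paper's Random Forest settings.
rf = RandomForestClassifier(n_estimators=500, max_features=2, random_state=1)
rf.fit(X, y)

names = ["Growth", "Target_Achievement", "Months_Growth_Above_0",
         "Opportunities", "Baseline"]   # hypothetical KPI names
ranked = sorted(zip(names, rf.feature_importances_), key=lambda t: -t[1])
for name, imp in ranked:
    print(f"{name}: {imp:.3f}")
```

The two informative features dominate the ranking, which is the kind of output varImp produced for Table 7.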

**Table 6.** Confusion matrix from Random Forest.


The results of the varImp function are provided in Table 7.


**Table 7.** Feature importance.

The results show that the most important features are:


The remaining columns have residual importance compared to the ones mentioned above. The results are in line with the business people's opinions: to succeed, a salesperson has to focus on growing the customer base, work to achieve their targets, and sustain positive growth for as many months as possible.

#### *5.3. Run the Classification*

The authors created a 20-fold cross-validated NB model based on the trainControl function from the caret package. Based on this model, the test dataset was loaded and the predictions were requested.
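The cross-validation setup can be sketched as follows, again in Python/scikit-learn rather than the paper's caret::trainControl, and on synthetic data standing in for the confidential dataset:

```python
import numpy as np
from sklearn.model_selection import StratifiedKFold, cross_val_score
from sklearn.naive_bayes import GaussianNB

# Synthetic stand-in for the prepared dataset.
rng = np.random.default_rng(7)
X = rng.normal(size=(695, 10))
y = np.where(X[:, 0] - X[:, 1] > 0, "Good", "Not Performing")

# 20-fold CV, as in trainControl(method = "cv", number = 20).
cv = StratifiedKFold(n_splits=20, shuffle=True, random_state=7)
scores = cross_val_score(GaussianNB(), X, y, cv=cv)
print(len(scores), round(scores.mean(), 2))
```

Each of the 20 folds serves once as a held-out set, and the fold accuracies are averaged, which is how the 92.52% average accuracy reported below is obtained in the paper.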

A confusion matrix was built to evaluate the performance of the predictions made over the test dataset. The results are displayed in Table 8.

The average accuracy of the model is 92.52%. Based on the confusion matrix provided in Table 8, it is possible to verify that the model failed in only 7.5% of the cases.


**Table 8.** Cross-Validated (20 fold) Confusion Matrix.

Precision, Specificity, Sensitivity, and F1 scores were computed to evaluate the model and its results. As can be verified in Table 9, the Outstanding class has a high Specificity but a lower Sensitivity.

The F1 scores show that the precision for Not Performing is the highest, but for the Outstanding and Good classes the test accuracy is also high, which is very important considering that the results of this model are used to evaluate people's performance. Regarding the dataset size used in this analysis (695 observations), broken down by class, Good has 269 observations, Not Performing 373, and Outstanding 53. The scores obtained in the Detection Rates reflect the high number of correctly predicted evaluations for each class and, when compared to the Detection Prevalence, confirm the small number of erroneous predictions.
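The per-class scores in Table 9 all derive from the confusion matrix. A sketch of how they are computed (the matrix below is illustrative, not the paper's Table 8):

```python
import numpy as np

def per_class_metrics(cm):
    """Precision, sensitivity (recall), specificity, and F1 per class,
    from a square confusion matrix with rows = actual, columns = predicted."""
    cm = np.asarray(cm, dtype=float)
    total = cm.sum()
    out = {}
    for k in range(cm.shape[0]):
        tp = cm[k, k]
        fp = cm[:, k].sum() - tp          # predicted k, actually another class
        fn = cm[k, :].sum() - tp          # actually k, predicted another class
        tn = total - tp - fp - fn
        precision = tp / (tp + fp)
        sensitivity = tp / (tp + fn)
        specificity = tn / (tn + fp)
        f1 = 2 * precision * sensitivity / (precision + sensitivity)
        out[k] = (precision, sensitivity, specificity, f1)
    return out

# Illustrative 3-class matrix (Not Performing / Good / Outstanding).
cm = [[50, 3, 0],
      [4, 80, 2],
      [0, 5, 70]]
for k, (p, r, s, f) in per_class_metrics(cm).items():
    print(k, round(p, 2), round(r, 2), round(s, 2), round(f, 2))
```

A class can combine high specificity with low sensitivity, as Table 9 shows for Outstanding, precisely because specificity counts correct rejections of the other, larger classes.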


**Table 9.** Evaluation scores for NB model.

The limitation of this work was data size and availability: the number of observations is not high, and the number of observations differs between classes. The authors believe that with a larger dataset, from which data could be extracted for each class with a similar number of observations, the model accuracy could be improved and erroneous cases would decrease, leading to a more accurate model.

As an example, in Figures 9–11 it is possible to review the results of the assessment in a Power BI dashboard created for salesperson assessment. The dashboard presents all the metrics together with the classification produced by the predictive analytics as Not Performing, Good, or Outstanding. With this, all the objectives of the research are successfully concluded.

**Figure 9.** Dashboard for a salesperson classified as Not Performing.

**Figure 10.** Dashboard for a salesperson classified as Good.

**Figure 11.** Dashboard for a salesperson classified as Outstanding.

The steps presented above conclude the evaluation of the model's performance, which was the last task in the research. In the next sections, the authors conclude the research with a summary of the work and suggestions for future work.

#### **6. Conclusions**

In this work, the authors applied a Naive Bayes model to classify salespeople into pre-defined categories provided by the business. The classification uses three classes: Not Performing, Good, and Outstanding. It is based on KPIs such as growth volume and percentage, sales variability along the year, opportunities created, customer base line, and target achievement, among others.

The dataset comprises 594 salespeople classified into three categories:


The dataset initially had 45 columns. It was then reduced to 11 columns based on several techniques to clean the data and evaluate the relevance of the columns for classifying a salesperson's success. In this process, the authors also identified the most critical factors for evaluating a salesperson's performance based on the data, namely the growth amount on all products, target achievement on all products, growth percentage on all products, and the number of months with growth above 0.

The model was evaluated with a confusion matrix and related measures such as true positives, true negatives, and the F1 score. The results showed an average accuracy of 92.52% for the whole model. In terms of precision per class, Not Performing has 90%, Good 87%, and Outstanding 100%. The F1 scores were 94% for Not Performing, 86% for Good, and 80% for Outstanding.

The accuracy results in this work are high because of the dataset size and because the data varies in a similar way within each class. For instance, a salesperson who is not performing most of the time has low growth, a low number of opportunities, and sales above 0 for only a small number of months in a year; a good salesperson may have high growth in at least six months for one product; and an outstanding salesperson should have extremely high growth for at least one product and growth above 0 for at least eight months.

When data is available, this approach can help produce new guidelines that HR, with pre-defined rules, can use to automate part of the performance appraisal process. It can be applied to other cases and companies and, with DM, start automating the analysis of complex KPIs and the relationships between them to generate a classification.

### *Future Work*

As future work, the authors propose using an NB model to evaluate salespeople's performance with more CRM information, taking advantage of other information that is also part of the salesperson's job: the number of Leads, activities (Visits, Calls), the other opportunity states, opportunity conversion rates, and the costs involved for each salesperson. Subjective factors can also become part of the salesperson's evaluation. For instance, a more experienced salesperson may be training a junior salesperson or trying to recover several lost customers; these facts can have an impact on sales performance, so flags rating them can also be included.

All of this aims towards a detailed and precise evaluation of salespeople's performance, increasing fairness and drastically reducing the amount of work needed to make a performance evaluation for a salesperson.

**Author Contributions:** N.C. is a Master's student who performed all the development work. J.F. is the thesis supervisor and organized all work on the computer science subject. All authors have read and agreed to the published version of the manuscript.

**Funding:** This work has been partially supported by Portuguese National funds through FITEC programa Interface, with reference CIT "INOV—INESC Inovação—Financiamento Base".

**Conflicts of Interest:** The authors declare no conflict of interest.

### **Abbreviations**

The following abbreviations are used in this manuscript:

B2C Business to Consumer

B2B Business to Business


### **References**


© 2020 by the authors. Licensee MDPI, Basel, Switzerland. This article is an open access article distributed under the terms and conditions of the Creative Commons Attribution (CC BY) license (http://creativecommons.org/licenses/by/4.0/).
