*Article* **Deep Neural Network for Gender-Based Violence Detection on Twitter Messages**

**Carlos M. Castorena 1,†,‡, Itzel M. Abundez 1,†, Roberto Alejo 1,\*,†,‡, Everardo E. Granda-Gutiérrez 2,†, Eréndira Rendón 1,† and Octavio Villegas 1,†**


**Abstract:** The problem of gender-based violence in Mexico has increased considerably. Many social associations and governmental institutions have addressed this problem in different ways. In computer science, efforts have been made to deal with this problem through machine learning approaches that strengthen strategic decision making. In this work, a deep learning neural network application to identify gender-based violence in Twitter messages is presented. A total of 1,857,450 messages (generated in Mexico) were downloaded from Twitter; 61,604 of them were manually tagged by human volunteers as negative, positive or neutral messages, to serve as training and test data sets. The results presented in this paper show the effectiveness of the deep neural network (about 80% of the area under the receiver operating characteristic curve) in detecting gender violence in Twitter messages. The main contribution of this investigation is that the data set was only minimally pre-processed (in contrast to most state-of-the-art approaches). Thus, the original messages were converted into numerical vectors according to the frequency of word appearance, and only adverbs, conjunctions and prepositions were deleted (these words occur very frequently in text and, we think, do not contribute to identifying discriminatory messages on Twitter). Finally, this work contributes to dealing with gender violence in Mexico, an issue that needs to be faced immediately.

**Keywords:** gender-based violence in Mexico; twitter messages; deep neural networks; class imbalance

#### **1. Introduction**

Gender-based violence (GBV) is a major concern around the globe [1]. The United Nations (UN) recognizes GBV as a problem involving health and development [2]. A UN declaration on GBV, specifically violence against women, describes it as any act of violence that results in, or could potentially lead to, physical, psychological or sexual damage or suffering; it also includes the threat of such acts, coercion to perform them and arbitrary deprivation of liberty, whether occurring in public or private circumstances [3].

Mexico has shown an escalation in the number of victims of GBV due to its social, economic and political context [4,5]. Moreover, crises such as the recent novel coronavirus disease (COVID-19) outbreak have exposed critical inequalities in the social and economic environments, as well as in the health system, which have negatively contributed to the GBV problem [6].

Efforts of scholars and activists have increasingly turned society and government attention to this problem, warning about how certain conditions of power or privilege tend to reproduce broader relations of inequality, domination, exploitation, victimization and,

**Citation:** Castorena, C.M.; Abundez, I.M.; Alejo, R.; Granda-Gutiérrez, E.E.; Rendón, E.; Villegas, O. Deep Neural Network for Gender-Based Violence Detection on Twitter Messages. *Mathematics* **2021**, *9*, 807. https:// doi.org/10.3390/math9080807

Academic Editors: Florin Leon, Mircea Hulea and Marius Gavrilescu

Received: 26 February 2021 Accepted: 6 April 2021 Published: 8 April 2021

**Publisher's Note:** MDPI stays neutral with regard to jurisdictional claims in published maps and institutional affiliations.

**Copyright:** © 2021 by the authors. Licensee MDPI, Basel, Switzerland. This article is an open access article distributed under the terms and conditions of the Creative Commons Attribution (CC BY) license (https:// creativecommons.org/licenses/by/ 4.0/).

finally, loss of humanity [1]. In this respect, computer science researchers have developed algorithms and methodologies based on machine learning to address the GBV problem. For example, Ref. [7] presents a camouflaged electronic device to help potential victims of GBV; it allows the victim to send a voice command and Global Positioning System (GPS) location via smartphone to a Control Center, which analyzes the message to properly assist the victim. A similar but more sophisticated work is presented in [8]; it uses two psychological sensors to identify GBV through a robust speaker identification system, based on the evaluation of speech stress conditions and data augmentation techniques. Rodríguez-Rodríguez et al. [9] used historic open-access data to model and forecast GBV through machine learning methods; their methodology produced successful results in three specific Spanish territories with different populations.

GBV has affected many women around the world in online social network environments [10], and several works have been developed to tackle this problem. Ref. [11] presented a classification of cyber-bullying detection methods in online social networks; it surveys techniques to automatically identify cyber-bullying through machine learning algorithms. Another interesting approach is MANDOLA [12], a big-data processing system intended to evaluate the proliferation and effect of online hate-related speech, which is generally inspired by religious beliefs, ethnicity or gender. Gutiérrez-Esparza et al. [13] studied two machine learning algorithms, together with the variable importance measures (VIMs) method to select the best features from the data set, in order to classify situations of cyber-aggression on Facebook among Spanish-language users from Mexico. They collected 2000 Facebook comments, which were manually labeled as racism, violence based on sexual orientation or violence against women by a group of three machine learning teachers, who supported psychologists specialized in the evaluation of, and intervention in, bullying situations in high schools. Experimental results of these works showed a classification performance greater than 90% accuracy.

Twitter has been a scenario where violence against women, indigenous people, minorities and migrants is frequent. Consequently, much work has been focused on this problem, and the potential use of machine learning as a methodology has been demonstrated in Ref. [14]. In addition, data mining [15] has been used to detect domestic violence. Other works have addressed the automatic detection of sexual violence [16], cyber-bullying [17], hate expressions [18] and offensive or aggressive content [13,19,20] in Twitter messages, in which the feature extraction method, including the appropriate collection of expressions (words), is essential.

With specific attention to GBV, much research has been performed. Ref. [21] exhibited the use of machine learning methods on Twitter to study the circumstances involved in the #MeToo movement (an initiative to denounce GBV), mainly those related to business and marketing activities.

Ref. [22] presented the automatic detection and categorization of misogynous language on Twitter by using different supervised classifiers. Techniques such as N-grams and linguistic, syntactic and embedding features were used to build the feature space of the training data set. One of the main contributions of this work was to make available to the research community a data set corpus of misogynistic tweets. Ref. [22], and similarly [23], which collected Twitter data from words frequent in domestic violence, highlighted the importance of building a corpus of misogynistic tweets and of considering language regionalization, i.e., the data corpus should be in accordance with the regional context [13].

Xue et al. [15] evidenced the viability of employing topic-modeling methods for data-mining on Twitter to identify GBV. An unsupervised algorithm to discover hidden topics in the tweets was used. Twitter messages were converted into a document-term matrix by applying the CountVectorizer method [24], in order to collect words that appear more frequently in domestic violence, which are related to GBV.

In Ref. [16], a deep neural network was applied to identify the risk factor associated with sexual violence on Twitter; however, it did not explain how the messages were pre-processed.

Mohammed et al. [17] recommended an array of unique features obtained from Twitter (based on the network, activity type, user and the content of the tweet) for the detection of cyber-bullying (which has a direct relationship with GBV). Results showed an AUC of 0.943, indicating that this set of features provides an effective approach for detecting cyber-bullying.

In Ref. [18], an approach to the automatic detection of hate expressions on Twitter was shown. The authors collected offensive or hateful expressions for hate speech detection. The pre-processing stage consisted of cleaning up the tweets, tokenization, generation of negation expressions (e.g., "not", "never", etc.) and detection of the scope of these words. In addition, a feature selection process was performed.

Ref. [23] exhibited a technique for detection of xenophobia and misogyny in tweets by using computing methodologies. Authors created a suitable language resource for hate speech recognition in Spanish (Spain), highlighting the importance of language regionalization, i.e., whether it is Spanish from Spain or Mexico.

In [19], an Arabic offensive tweet detector was built. The work shows an inherent complexity in classifying tweets, which depends on the particular language.

In the Mexican Spanish context on Twitter, few works have addressed the automatic identification of GBV; most of them have focused on the detection of aggressiveness. Alvarez-Carmona et al. [25] presented an overview of results from the MEX-A3T competition (2018), which addressed the automatic identification of aggressiveness in Mexican Spanish tweets. The competition included two tracks: in the first, author profiling, the aim was to identify the place of residence and occupation of the users; in the second, the goal was the detection of aggressiveness in the message. Results showed 76.4% accuracy in the aggressiveness identification task, while the deep learning methods used in MEX-A3T did not exceed 68% accuracy [20]. Ref. [20] analyzed the performance of two deep learning models for the automatic classification of aggressive Mexican Spanish tweets. It highlighted the low performance of the studied deep learning neural networks in identifying aggression in Mexican Spanish tweets, i.e., there are still open issues that should be addressed to better understand this topic.

Based on the previous works, two essential components were identified in the analysis of the content of Twitter messages: (a) the suitable collection of expressions (words) related to the topic under study in accordance with the regional context, and (b) the feature extraction stage, performed either with simple techniques such as the CountVectorizer method [24], which transforms tweet content into vectors by counting occurrences of each word in each tweet, or with sophisticated methodologies like those presented in [18] or [25].

In relation to the pre-processing stage, it was noted that most of the works require complex pre-processing or a specialized group to manually tag the comments (or use small data sets).

As a relevant concern, it was observed that most recent advances have been developed for the English language [25], but the few works performed for other languages agree on the importance of the regional context of the messages in their original tongue [19,23].

In this paper, a simple methodology to identify GBV in Mexican Spanish Twitter messages is studied, which includes three common feature extraction methods: CountVectorizer, TfidfVectorizer and HashingVectorizer. In contrast with other state-of-the-art works, our proposal does not employ a stage to collect expressions related to GBV; instead, it only gives the classifier enough samples, previously labeled by human volunteers, of tweets containing evidence of GBV or not. Thus, the significance of this work can be highlighted as follows:

1. This research contributes to the automatic detection of GBV in Mexican Spanish tweets (specifically contextualized to Mexican language jargon), an issue that has received little attention, with the potential use of this work in the early detection of dangerous behaviors among users.


#### **2. Deep Learning Multilayer Perceptron**

Deep learning neural networks are characterized by increased network depth, i.e., the number of hidden layers; thus, the multilayer perceptron is a general and intuitive architecture that can be transformed into a deep learning multilayer perceptron (DL-MLP) with two or more hidden layers [26].

The DL-MLP tries to find a relation between a set of input vectors *x* and labels *id* by modifying the parameters linking those sets. The output *yj* is a function of *x* and the weights *w*, so that if *w* is modified, the difference *z* between the system output and the target *id* can be minimized. The DL-MLP uses two or more hidden layers constituted of nodes or neurons. Each neuron is connected to the neurons of the previous layer, and its output signal is calculated by combining all the inputs from the preceding layer [27]. The connections between nodes use a neuronal weight (*w*) to modify the signal before it enters the neuron; this transformation corresponds to multiplying the respective signal (*xi*) by the weight (*wi*).

The use of multiple layers generates a more complex optimization problem, but gains a reduction in the number of nodes per layer inside the architecture [28]. However, the increased computational effort can be overcome by the availability of advanced frameworks like Spark [29] and Tensorflow [30], which provide tools to optimize the cost function of the perceptron. The use of such tools makes it possible for the DL-MLP to be used increasingly in big-data problems [31,32], and also increases the capability of abstraction of the DL-MLP in complex problems [28].

Usually, DL-MLPs are trained by means of the back-propagation algorithm (based on stochastic gradient descent) [33–35], with initial weights randomly assigned. One of the most common gradient descent optimization algorithms is Adam [36], which is based on adaptive estimation of first-order and second-order moments [37]. This algorithm reduces the error between the network output *f*(*x*, *w*) and the target function *f̂*(*x*).

Typically, the DL-MLP includes different activation functions that map the linear space of the samples *x* in each hidden layer to a nonlinear space, namely: the Rectified Linear Unit (ReLU), *f*(*z*) = max(0, *z*); the hyperbolic tangent, *f*(*z*) = tanh(*z*); the Exponential Linear Unit (ELU), *f*(*z*) = *z* if *z* ≥ 0 and *f*(*z*) = *e*^*z* − 1 if *z* < 0; and the sigmoid function, *f*(*z*) = 1/(1 + *e*^(−*z*)).
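As an illustration, the activation functions listed above can be sketched in a few lines of Python (a minimal sketch for reference; any deep learning framework provides optimized versions):

```python
import math

# Common DL-MLP activation functions, as listed above.
def relu(z: float) -> float:
    """Rectified Linear Unit: max(0, z)."""
    return max(0.0, z)

def elu(z: float) -> float:
    """Exponential Linear Unit: z if z >= 0, else e^z - 1."""
    return z if z >= 0 else math.exp(z) - 1.0

def sigmoid(z: float) -> float:
    """Logistic sigmoid: 1 / (1 + e^(-z))."""
    return 1.0 / (1.0 + math.exp(-z))

# The tangent function is available directly as math.tanh(z).
```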

#### **3. Deep Learning for Natural Language Processing and Sentiment Analysis**

The advent of the world wide web and search engines brought with it the emergence of natural language processing (NLP) [38], which allows a machine to process a natural human language and translate it into a format that is processable and understandable by a computer [39]. This field has received a lot of attention due to its efficiency in language modeling. NLP models have been applied in various areas, as they provide great mechanisms to analyze text in real time, in addition to the reliability they demonstrate in different tasks [40].

Due to the rapid growth of the Internet, the use of social networks, forums, blogs and other platforms where people from all over the world share their ideas, opinions and comments on multiple topics (politics, cinema, sports, music, among others) has increased, giving rise to a great deal of unstructured information [41]. For this reason, sentiment analysis has become one of the main challenges addressed by NLP. Its main objective is to extract feelings, opinions, attitudes and emotions from users [42] through a series of methods, techniques and tools for the detection and extraction of subjective information, in order to detect the polarity of a text, that is, to determine whether the given text is positive, negative or neutral [43].

Sentiment analysis has been positioned as one of the essential tools to transform the emotions and attitudes of a text into actionable and understandable information for a machine [44]. It is so important within NLP that this area has been addressed at three different levels [42]: (1) the document level, focused on determining whether an opinion document expresses a positive or negative sentiment; (2) the sentence level, whose task is to check whether each sentence expresses a positive, negative or neutral opinion; and (3) the aspect level, responsible for looking directly at the opinion itself.

To address the problems of sentiment analysis, approaches based on machine learning algorithms and sentiment lexicons have previously been used. However, these methods have limitations, such as limited data, word-order handling and the need for a large number of tagged texts, that make them ineffective for NLP tasks [45]. For some of these problems, models based on deep learning have been the solution; these methods have been gaining popularity, proving to be a better option to face the problem of sentiment analysis, which is attributed to the high performance they show in different NLP tasks [46].

For years, the implementation of a deep learning or pattern recognition system in NLP has required careful engineering and extensive experience to design a feature extraction system that can transform raw data into an appropriate internal representation or vector of characteristics that a learning subsystem, generally a classifier, can use to detect patterns [47]. Feature extraction, as a data pre-processing step for the learning algorithm, contributes to performance improvement. The extraction methods used for this task range from simple approaches, such as those based on the bag-of-words model (like CountVectorizer [24], TfidfVectorizer [48] or HashingVectorizer [49]), to more sophisticated approaches, such as transformers [50–53].

#### *3.1. Text Feature Extraction*

The CountVectorizer method converts a document *d* into a numeric vector *d* = {*u*1, *u*2, ... , *ui*, ... , *uT*}, where *ui* is the weight of word *i* in the document *d*. Feature *i* of the document is the number of times that word *i* appears in it; in other words, *ui* is the frequency of appearance of word *i* in the document *d* [48].
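The idea behind CountVectorizer can be illustrated with a minimal, hand-rolled sketch (a toy example; the actual scikit-learn implementation adds configurable tokenization, sparse storage and more):

```python
from collections import Counter

def count_vectorize(documents):
    """Toy CountVectorizer: each document d becomes a vector u = (u_1, ..., u_T),
    where u_i counts how many times word i appears in d."""
    # Build the vocabulary from all documents (sorted for a stable order).
    vocab = sorted({word for doc in documents for word in doc.lower().split()})
    vectors = []
    for doc in documents:
        counts = Counter(doc.lower().split())
        vectors.append([counts.get(word, 0) for word in vocab])
    return vocab, vectors

vocab, vectors = count_vectorize(["good good tweet", "bad tweet"])
# vocab -> ['bad', 'good', 'tweet']; "good good tweet" -> [0, 2, 1]
```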

The TfidfVectorizer method uses the CountVectorizer matrix and applies a term frequency-inverse document frequency (TFIDF) transformation, which takes the frequency of word *i* and the inverse frequency of its occurrences in the documents (Equation (1)), instead of the raw frequencies of occurrence of a token [54].

$$
u\_i = TF\_i \times IDF\_i, \tag{1}
$$

where the weight (*ui*) is a function of *TFi* (term frequency), i.e., the appearance frequency of the word *i* in a document *d*, and *IDFi* (inverse document frequency) which is:

$$IDF\_i = \log(\text{Total of documents}/DF\_i),\tag{2}$$

where *DFi* (document frequency) is the number of documents in which the word *i* appears at least once.

By using IDF, the weight of high-frequency words that are not significant (like conjunctions, prepositions or common words) is reduced, because these kinds of words appear in many documents; this makes it possible to identify words with specific relevance to certain documents.
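Equations (1) and (2) can be illustrated with a small, self-contained sketch (a toy example; scikit-learn's TfidfVectorizer additionally applies smoothing and normalization by default):

```python
import math

def tfidf(documents):
    """Toy TF-IDF following Equations (1) and (2): u_i = TF_i * IDF_i,
    with IDF_i = log(N / DF_i), where N is the total number of documents."""
    tokenized = [doc.lower().split() for doc in documents]
    vocab = sorted({w for doc in tokenized for w in doc})
    n_docs = len(tokenized)
    # DF_i: number of documents containing word i at least once.
    df = {w: sum(1 for doc in tokenized if w in doc) for w in vocab}
    weights = []
    for doc in tokenized:
        # TF_i is the raw count of word i in this document.
        weights.append([doc.count(w) * math.log(n_docs / df[w]) for w in vocab])
    return vocab, weights

vocab, w = tfidf(["she said stop", "she left"])
# "she" appears in both documents, so IDF = log(2/2) = 0 and its weight vanishes.
```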

The HashingVectorizer implementation works in a similar way to CountVectorizer, but it employs the hashing trick to map token strings to integer feature indices, normalized as token frequencies. Thus, there is no way to compute the inverse transformation, i.e., it does not consider the inverse document frequency. However, it is very efficient for large data sets [49].
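The hashing trick can be sketched as follows (a simplified illustration; the actual HashingVectorizer also uses a signed hash and optional normalization, and `hashlib.md5` here is only a stand-in for its internal hash function):

```python
import hashlib

def hashing_vectorize(document, n_features=8):
    """Toy hashing trick: each token is mapped to a fixed-size index via a
    hash, so no vocabulary is stored and the mapping cannot be inverted."""
    vector = [0] * n_features
    for token in document.lower().split():
        # Stable hash (Python's built-in hash() is salted per process).
        digest = hashlib.md5(token.encode("utf-8")).hexdigest()
        vector[int(digest, 16) % n_features] += 1
    return vector

v = hashing_vectorize("good good tweet")
```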

The CountVectorizer, HashingVectorizer and TfidfVectorizer methods can use different numbers of words per token (this parameter is *Ngram*). In the present work, tokens with 1, 2 or 3 words were used, which can capture more relations between the patterns in the data.
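For illustration, extracting tokens of 1, 2 or 3 words from a tweet might look like this (a minimal sketch of word n-gram extraction):

```python
def ngrams(document, n_values=(1, 2, 3)):
    """Extract word n-grams with 1, 2 or 3 words, corresponding to the
    Ngram parameter of the vectorizers described above."""
    tokens = document.lower().split()
    grams = []
    for n in n_values:
        grams.extend(" ".join(tokens[i:i + n]) for i in range(len(tokens) - n + 1))
    return grams

# ngrams("not a threat") -> ['not', 'a', 'threat', 'not a', 'a threat', 'not a threat']
```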

#### **4. Methodology**

The methodological aspects of the work are exhibited in this section. Details about data collection, pre-processing, classifier parameters and assessment test are explained in order to allow the replication of the experiments. The source code for this work is accessible through https://github.com/ccastore/GenderViolence (accessed on 1 January 2021).

#### *4.1. Data Collection*

Data were collected by using the *twlets* (http://twlets.com) tool. Twitter messages were collected from 18–19 May 2019, taking tweets in the Spanish language and located in Mexico (coordinates −118.599, 14.388 to −86.493, 32.718). In order to select tweets related to GBV, messages from individual users, companies and organizations that contained words or phrases related to diverse forms of possible GBV were selected. In addition, news pages and political figures were considered.

A total of 1,857,450 messages were retrieved from Twitter; 61,604 of them were manually tagged by human volunteers as messages referring to GBV (those containing a possible intention of GBV) or messages not referring to GBV, resulting in 1604 positive and 60,000 negative tweets.

#### *4.2. Data Pre-Processing*

Once the messages were retrieved from the Twitter stream, they were pre-processed to transform the input text to a normalized, comprehensible model of numbers sequence, proceeding as follows:


Finally, the matrices obtained by the CountVectorizer, TfidfVectorizer and HashingVectorizer methods were used to build and test the classifier. For this, the hold-out method [27] was applied; it randomly splits the original matrix into training (TDS, 70%) and testing (TS, 30%) data sets, where TDS ∩ TS = ∅.
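The hold-out split described above can be sketched as follows (a minimal illustration; the function name and fixed seed are our own choices, not from the paper):

```python
import random

def hold_out_split(samples, train_fraction=0.7, seed=42):
    """Randomly split the data into TDS (70%) and TS (30%) with TDS ∩ TS = ∅."""
    indices = list(range(len(samples)))
    random.Random(seed).shuffle(indices)
    cut = int(len(samples) * train_fraction)
    tds = [samples[i] for i in indices[:cut]]
    ts = [samples[i] for i in indices[cut:]]
    return tds, ts

tds, ts = hold_out_split(list(range(100)))
```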

#### *4.3. Sampling Methods*

Oversampling methods are popular and successful techniques to deal with class imbalance [55]. The most common algorithms are: (a) Random Over-Sampling (ROS), which randomly duplicates samples from the minority class to mitigate the class imbalance, and (b) SMOTE, which produces artificial samples of the minority class by interpolation of nearby instances [56]. Specifically, for each minority class sample, SMOTE finds the *k* intra-class nearest neighbors and generates synthetic samples in the direction of those nearest neighbors. In this work, *k* was set to five in SMOTE (as in Ref. [57]), and ROS and SMOTE were applied to the data set to achieve a relatively balanced class distribution.
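The two oversampling strategies can be sketched in a few lines (a simplified illustration on small numeric samples; production implementations such as those in imbalanced-learn are more elaborate):

```python
import random

def random_over_sample(minority, target_size, seed=0):
    """ROS sketch: duplicate randomly chosen minority samples until target_size."""
    rng = random.Random(seed)
    return minority + [rng.choice(minority) for _ in range(target_size - len(minority))]

def smote_like(minority, target_size, k=5, seed=0):
    """SMOTE-style sketch: create synthetic minority samples by interpolating
    between an instance and one of its k intra-class nearest neighbors."""
    rng = random.Random(seed)
    synthetic = list(minority)
    while len(synthetic) < target_size:
        x = rng.choice(minority)
        # k nearest intra-class neighbors by squared Euclidean distance.
        neighbors = sorted((m for m in minority if m is not x),
                           key=lambda m: sum((a - b) ** 2 for a, b in zip(x, m)))[:k]
        n = rng.choice(neighbors)
        gap = rng.random()  # interpolation factor in [0, 1)
        synthetic.append([a + gap * (b - a) for a, b in zip(x, n)])
    return synthetic

minority = [[0.0, 0.0], [1.0, 1.0], [0.0, 1.0]]
ros = random_over_sample(minority, 10)
smo = smote_like(minority, 10)
```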

In particular, for this work, the TDS obtained from the CountVectorizer, TfidfVectorizer and HashingVectorizer methods contains 1122 GBV and 42,000 non-GBV samples (see Sections 4.1 and 4.2); thus, the resulting over-sampled TDS produced by SMOTE and ROS is composed of approximately 42,000 GBV and 42,000 non-GBV samples, i.e., those methods balance the class distribution.

#### *4.4. Neural Network Set-Up*

The DL-MLP was developed on Tensorflow 2.0 and Keras 2.3.1, and the Adam algorithm [36] was employed to train it. The Adam algorithm calculates an adapted learning rate for each parameter, storing an exponentially decaying average of past gradients [30]. The learning rate (*η*) was set to 0.0006, while the stopping criterion was 20 epochs with a batch size of 150.

The DL-MLP was set up through the trial-and-error method, which is usual in neural network environments. For this, we randomly took from the TDS a subset ST (about 20%), which was split into ST*train* and ST*test*, where ST ⊆ TDS and ST*train* ∩ ST*test* = ∅. In this process, we used ST*train* and ST*test* to assess different configurations of the number of hidden layers and neurons per layer, and the topology that produced the best classification result was selected. The final architecture was a DL-MLP with six hidden layers and sigmoid activation functions, where the number of hidden nodes per layer was set to 6, 6, 5, 5, 4 and 3, respectively.
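For reference, a forward pass through the final 6-6-5-5-4-3 topology with sigmoid activations can be sketched as follows (random placeholder weights; in the paper the weights are trained with Adam on Tensorflow/Keras, and the output layer size is our assumption for a binary task):

```python
import math
import random

def forward(x, layer_sizes=(6, 6, 5, 5, 4, 3, 1), seed=0):
    """Forward pass through a DL-MLP shaped like the paper's final topology:
    six sigmoid hidden layers of 6, 6, 5, 5, 4 and 3 nodes, plus one output node."""
    rng = random.Random(seed)
    sigmoid = lambda z: 1.0 / (1.0 + math.exp(-z))
    activations = list(x)
    n_in = len(activations)
    for n_out in layer_sizes:
        # One weight row and one bias per neuron in the current layer.
        weights = [[rng.uniform(-1, 1) for _ in range(n_in)] for _ in range(n_out)]
        biases = [rng.uniform(-1, 1) for _ in range(n_out)]
        activations = [sigmoid(sum(w * a for w, a in zip(row, activations)) + b)
                       for row, b in zip(weights, biases)]
        n_in = n_out
    return activations[0]  # probability-like score for the GBV class

score = forward([0.2, 0.5, 0.1, 0.9])
```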

#### *4.5. Classifier Performance*

Classification accuracy and error rate are widely used to assess the performance of learning models. Nevertheless, in class-imbalanced scenarios these measures are biased toward the majority (more represented) classes; for example, in this work there are many more non-GBV tweets than GBV tweets. Thus, other metrics should be used.

The receiver operating characteristic (ROC) curve is an appropriate instrument to evaluate classifier performance in imbalanced scenarios, according to the trade-off between benefits (true positives) and costs (false positives). A quantitative depiction of the ROC is the area under the curve (AUC), calculated as *AUC* = (*sensitivity* + *specificity*)/2, where *sensitivity* is the percentage of correctly predicted positive samples and *specificity* is the percentage of correctly predicted negative samples [58] (see Table 1). In this work, *sensitivity*, *specificity* and the *AUC* were used to measure the effectiveness of the deep learning neural network in identifying GBV in Mexican tweets.
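The AUC formula above can be computed directly from the confusion matrix counts, for example:

```python
def auc_from_confusion(tp, fn, tn, fp):
    """AUC = (sensitivity + specificity) / 2, computed from confusion counts."""
    sensitivity = tp / (tp + fn)  # fraction of positive (GBV) samples predicted correctly
    specificity = tn / (tn + fp)  # fraction of negative samples predicted correctly
    return sensitivity, specificity, (sensitivity + specificity) / 2

sens, spec, auc = auc_from_confusion(tp=80, fn=20, tn=90, fp=10)
# sensitivity = 0.8, specificity = 0.9, AUC = 0.85
```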


**Table 1.** Confusion matrix for binary classification.

#### **5. Experimental Results and Discussion**

The main experimental results in identifying GBV in Mexican tweets are presented in this section. Table 2 summarizes the results in terms of the number of features obtained by the extraction methods and the classification performance measures *sensitivity*, *specificity* and *AUC*.

The number of features for the HashingVectorizer method was determined by trial and error for this work; several values were tested and the best value was found to be 350 features. For the CountVectorizer and TfidfVectorizer methods the default parameters were used; thus, these algorithms settled on the number of features themselves (see Section 3.1).

Table 2 shows that the class imbalance severely affects the overall classifier performance. Results obtained without using any sampling method indicate that the classifier does not learn the minority class (GBV tweets). Thus, this approach is not appropriate to identify GBV in Mexican tweets.



Results obtained by employing the sampling methods (ROS and SMOTE) indicate that the DL-MLP is effective in learning GBV tweets. However, Table 2 shows that when the minority class performance improves, the majority class performance is reduced, as can be observed from the *sensitivity* and *specificity* values. For example, with ROS, HashingVectorizer and *Ngram* = 1, the highest *sensitivity* is obtained simultaneously with the worst *specificity* value. A similar behavior is observed with SMOTE.

The *AUC* gives a better understanding of the classifier performance over both classes than the *sensitivity* and *specificity* measures alone. High *AUC* values imply a better trade-off between benefits (GBV tweets correctly classified) and costs (non-GBV tweets incorrectly classified). In this respect, Table 2 shows that CountVectorizer with *Ngram* = 1 presents the best AUC value; thus, the simplest method obtains the highest score.

A trend across the studied feature extraction methods is that better values of *specificity* and *AUC* are obtained with *Ngram* = 1 than with other values. In other words, the experimental results of this work indicate that, to identify GBV in Mexican tweets, employing only single words (unigrams) is an effective approach.

Table 2 shows that the worst AUC values correspond to the HashingVectorizer method. However, this method was developed to work with very large data sets, which could explain this behavior, because the data set used in this research contains only 61,604 samples.

Finally, with respect to the number of features obtained by the extraction methods (CountVectorizer, HashingVectorizer and TfidfVectorizer), the obtained results show no evidence of a relationship between the number of features used and the classifier performance.

#### **6. Conclusions**

GBV is a problem that exists on the social network Twitter. Many works have been performed to deal with it, along with related issues like hate speech, xenophobia, misogyny and domestic violence, among others. A main stage of that research is the collection of a corpus of words related to particular situations and language. In the Mexican Spanish context, few works have been developed to deal with GBV in Twitter messages, and language regionalization has been recognized as critical. In addition, the results of most of those works need to be improved.

Thus, in this paper, a study to identify GBV in Twitter messages in Mexico is presented. Three common feature extraction methods were used (CountVectorizer, TfidfVectorizer and HashingVectorizer), together with a deep learning multilayer perceptron as the classifier. A data set containing 1604 GBV tweets and 60,000 non-GBV tweets, selected from a total of 1,857,450 messages retrieved from the Twitter social network and labeled by human volunteers as GBV or non-GBV messages, was used to train and test the proposed scheme.

Experimental results showed that the class imbalance problem significantly affects the classification of GBV messages. In this sense, oversampling methods, mainly ROS and SMOTE, are effective in overcoming this problem. Thus, it was observed that the CountVectorizer method (combined with a sampling method) allows the DL-MLP to identify GBV in Mexican tweets with about 80% *AUC*. Remarkably, only minimal data set pre-processing was applied to obtain these results. The TfidfVectorizer and HashingVectorizer methods show competitive results, but CountVectorizer tended to obtain the best results.

The results of this research give evidence that, given enough labeled samples obtained from Mexican Spanish Twitter messages and transformed by a simple feature extraction method like CountVectorizer, the DL-MLP can produce improved classification results.

GBV is an issue that must be immediately addressed. In this sense, this study could potentially contribute to dealing with gender violence in Mexico because it provides an analysis of useful tools to identify GBV in online social networks despite the language jargon. However, the classification results should be improved, because the rate of GBV tweets that are predicted correctly (*sensitivity*) is still low. The analysis of certain tools for the detection of GBV in specific variants of Spanish could help to push the further research needed to improve the studied strategies for the identification of GBV in Twitter messages in Mexican Spanish.

Thus, future work should mainly address reducing the human effort needed to label the GBV texts and testing advanced deep learning models in order to increase the classifier performance, including more sophisticated natural language processing techniques. Currently, we are working on a streaming application to identify GBV, which uses a DL-MLP with a rejection option, i.e., when the classifier has doubts about a tweet's content, the tweet is rejected and sent to a human volunteer to be tagged and included in the training data set. We consider that this procedure will improve the classifier performance.

**Author Contributions:** C.M.C., R.A.: conceptualization, methodology and experiment; I.M.A.: conceptualization and review; E.R.: supervision; R.A., E.E.G.-G.: writing—review and editing. O.V.: Experiment. All authors have read and agreed to the published version of the manuscript.

**Funding:** This research did not receive external funding.

**Institutional Review Board Statement:** Not applicable.

**Informed Consent Statement:** Not applicable.

**Data Availability Statement:** Not applicable.

**Acknowledgments:** This work has been partially supported under grants of project 5046/2020CIC from UAEMex.

**Conflicts of Interest:** The authors declare no conflict of interest.

#### **References**


*Article* **_k_-Nearest Neighbor Learning with Graph Neural Networks**

**Seokho Kang**

Department of Industrial Engineering, Sungkyunkwan University, 2066 Seobu-ro, Jangan-gu, Suwon 16419, Korea; s.kang@skku.edu; Tel.: +82-31-290-7596

**Abstract:** *k*-nearest neighbor (*k*NN) is a widely used learning algorithm for supervised learning tasks. In practice, the main challenge when using *k*NN is its high sensitivity to its hyperparameter setting, including the number of nearest neighbors *k*, the distance function, and the weighting function. To improve the robustness to hyperparameters, this study presents a novel *k*NN learning method based on a graph neural network, named *k*NNGNN. Given training data, the method learns a task-specific *k*NN rule in an end-to-end fashion by means of a graph neural network that takes the *k*NN graph of an instance to predict the label of the instance. The distance and weighting functions are implicitly embedded within the graph neural network. For a query instance, the prediction is obtained by performing a *k*NN search from the training data to create a *k*NN graph and passing it through the graph neural network. The effectiveness of the proposed method is demonstrated using various benchmark datasets for classification and regression tasks.

**Keywords:** *k*-nearest neighbor; instance-based learning; graph neural network; deep learning

#### **1. Introduction**

The *k*-nearest neighbor (*k*NN) algorithm is one of the most widely used learning algorithms in machine learning research [1,2]. The main concept of *k*NN is to predict the label of a query instance based on the labels of the *k* closest instances in the stored data, assuming that the label of an instance is similar to that of its *k*NN instances. *k*NN is simple and easy to implement, but is very effective in terms of prediction performance. *k*NN makes no specific assumptions about the distribution of the data. Because it is an instance-based learning algorithm that requires no training before making predictions, incremental learning can be easily adopted. For these reasons, *k*NN has been actively applied to a variety of supervised learning tasks, including both classification and regression tasks.

The procedure for *k*NN learning is as follows. Suppose a training dataset $\mathcal{D} = \{(\mathbf{x}_t, \mathbf{y}_t)\}_{t=1}^{N}$ is given for a supervised learning task, where $\mathbf{x}_t$ and $\mathbf{y}_t$ are the input vector and the corresponding label vector of the *t*-th instance. $\mathbf{y}_t$ is assumed to be a one-hot vector in the case of a classification task and a scalar value in the case of a regression task. In the training phase, the dataset $\mathcal{D}$ is just stored without any explicit learning from it. In the inference phase, for each query instance $\mathbf{x}$, a *k*NN search is performed to retrieve the *k*NN instances $\mathcal{N}(\mathbf{x}) = \{(\mathbf{x}^{(i)}, \mathbf{y}^{(i)})\}_{i=1}^{k}$ that are closest to $\mathbf{x}$ based on a distance function *d*. Then, the predicted label $\hat{\mathbf{y}}$ is obtained as a weighted combination of the labels $\mathbf{y}^{(1)}, \dots, \mathbf{y}^{(k)}$ based on a weighting function *w* along with the distance function *d* as follows:

$$\hat{\mathbf{y}} = f(\mathbf{x}; \mathcal{D}) = \frac{\sum_{i=1}^{k} w(d(\mathbf{x}, \mathbf{x}^{(i)})) \cdot \mathbf{y}^{(i)}}{\sum_{i=1}^{k} w(d(\mathbf{x}, \mathbf{x}^{(i)}))} \tag{1}$$
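Equation (1) can be sketched directly in code. The distance and weighting functions below (Euclidean distance and inverse-distance weighting, with a small epsilon added as an assumption to guard against division by zero) are illustrative defaults, not choices prescribed by the paper:

```python
import numpy as np

def knn_predict(x, X_train, Y_train, k=3,
                d=lambda a, b: np.linalg.norm(a - b),
                w=lambda dist: 1.0 / (dist + 1e-12)):
    """Weighted kNN prediction following Equation (1):
    a w-weighted average of the labels of the k nearest instances."""
    dists = np.array([d(x, xt) for xt in X_train])
    idx = np.argsort(dists)[:k]                     # indices of the k nearest
    weights = np.array([w(dists[i]) for i in idx])
    return weights @ Y_train[idx] / weights.sum()

# Toy usage: two classes encoded as one-hot label vectors
X_train = np.array([[0.0], [0.1], [1.0], [1.1]])
Y_train = np.array([[1, 0], [1, 0], [0, 1], [0, 1]])
y_hat = knn_predict(np.array([0.05]), X_train, Y_train, k=3)
print(y_hat.argmax())  # the nearest neighbors are mostly class 0
```

Because the labels are one-hot vectors, the weighted average yields a normalized class-probability-like vector, and `argmax` recovers the predicted class.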

The difficulty in using *k*NN is determining the hyperparameters. The three main hyperparameters are the number of neighbors *k*, the distance function *d*, and the weighting function *w* [3]. Firstly, in terms of *k*, a small *k* captures a specific local structure in the data, and thus, the outcome can be sensitive to noise, whereas a large *k* concentrates more on the global structure of the data and suppresses the effect of noise. Secondly,

**Citation:** Kang, S. *k*-Nearest Neighbor Learning with Graph Neural Networks. *Mathematics* **2021**, *9*, 830. https://doi.org/10.3390/ math9080830

Academic Editor: Florin Leon

Received: 24 March 2021 Accepted: 9 April 2021 Published: 10 April 2021

**Publisher's Note:** MDPI stays neutral with regard to jurisdictional claims in published maps and institutional affiliations.

**Copyright:** © 2021 by the author. Licensee MDPI, Basel, Switzerland. This article is an open access article distributed under the terms and conditions of the Creative Commons Attribution (CC BY) license (https:// creativecommons.org/licenses/by/ 4.0/).

the distance function *d* determines how to calculate the distance between the input vectors of a pair of instances, with nearby instances considered more relevant. Popular examples of this function for *k*NN are the Manhattan, Euclidean, and Mahalanobis distances. Thirdly, the weighting function *w* determines how much each *k*NN instance contributes to the prediction. The standard *k*NN assigns the same weight to each *k*NN instance (i.e., *w*(*d*) = 1/*k*). It is known to be better to assign larger/smaller weights to closer/farther *k*NN instances based on their distances to the query instance **x** using a non-uniform weighting function (e.g., *w*(*d*) = 1/*d*). Thus, a *k*NN instance with a larger weight will contribute more to the prediction.
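The uniform and inverse-distance weighting schemes above correspond directly to the `weights` parameter of scikit-learn's *k*NN implementation. A minimal sketch on synthetic data (the data and settings here are illustrative, not those of the experiments in this paper):

```python
import numpy as np
from sklearn.neighbors import KNeighborsClassifier

# Synthetic two-dimensional data with a linear class boundary
rng = np.random.default_rng(0)
X = rng.normal(size=(200, 2))
y = (X[:, 0] + X[:, 1] > 0).astype(int)

for weights in ("uniform", "distance"):   # w(d) = 1/k  vs.  w(d) = 1/d
    clf = KNeighborsClassifier(n_neighbors=5, weights=weights).fit(X, y)
    print(weights, clf.score(X, y))
```

With `weights="distance"`, a training point is its own zero-distance neighbor and therefore dominates the weighted vote, so the training accuracy is perfect; with `weights="uniform"`, points near the class boundary can be outvoted by neighbors of the other class.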

The performance of *k*NN is known to be highly sensitive to hyperparameters, the best setting of which depends on the characteristics of the data [3,4]. Thus, the hyperparameters must be chosen appropriately to improve the prediction performance. Since this is a challenging issue, considerable research efforts have been devoted to hyperparameter optimization for *k*NN, which are introduced briefly in Section 2. Compared to related work, the main aim of this study is end-to-end *k*NN learning toward improved robustness to the hyperparameter setting and to make predictions for new data without additional optimization procedures.

This study presents a novel end-to-end *k*NN learning method, named *k*NN graph neural network (*k*NNGNN), which learns a task-specific *k*NN rule from the training dataset in an end-to-end fashion based on a graph neural network. For each instance in the training dataset and its *k*NN instances, a *k*NN graph is constructed with nodes representing the label information of the instances and edges representing the distance information between the instances. Then, a graph neural network is built to consider the *k*NN graph of an instance to predict the label for the instance. The graph neural network can be regarded as a data-driven implementation of implicit weight and distance functions. By doing so, the prediction performance of *k*NN can be improved without careful consideration of its hyperparameter setting. The proposed method is applicable to any type of supervised learning task, including classification and regression. Furthermore, the proposed method does not require any additional optimization procedure when making predictions for new data, which is advantageous in terms of computational efficiency. To investigate the effectiveness of the proposed method, experiments are conducted using various benchmark datasets for classification and regression tasks.

#### **2. Related Work**

This section discusses related work on hyperparameter optimization for the *k*NN algorithm, which has been actively studied by many researchers. As previously mentioned, *k*NN learning involves three main hyperparameters: the number of neighbors *k*, the distance function *d*, and the weighting function *w*. A different dataset requires a different hyperparameter setting, and no specific setting can universally be the best for every application, as indicated by the no-free-lunch theorem [5]. Thus, the proper choice of these hyperparameters is critical for obtaining high prediction performance. In practice, the best hyperparameter setting for a given dataset is usually determined by performing a cross-validation procedure that searches over possible hyperparameter candidates. Various search strategies are applicable, such as grid search, random search [6], and Bayesian optimization [7]. These procedures are time-consuming and costly, especially for large-scale datasets. Previous research efforts have focused on choosing the hyperparameters of *k*NN in more intelligent ways based on heuristics or extra optimization procedures for each query instance.

There are two main research approaches regarding the number of neighbors *k*. The first approach is to assign different *k* values to different query instances based on their local neighborhood information instead of a fixed *k* value [8–12]. The second approach is to employ non-uniform weighting functions to reduce the effect of *k* on the prediction performance.

For the distance function *d*, one research approach is to learn task-specific distance functions directly from data to improve the prediction performance, which is referred to as distance metric learning [13,14]. Many methods for this approach were developed for use in the classification settings [15–19], while some were developed for use in the regression settings [20–22]. Another approach is to adjust the distance function in an adaptive manner for each query instance [23–27]. This requires an extra optimization procedure, as well as a *k*NN search when making a prediction for each query instance.

For the weighting function *w*, existing methods have focused on designing nonuniform weighting functions that decay smoothly as the distance increases [4]. One main research approach is to assign adaptive weights to the *k*NN instances of each query instance by performing an extra optimization procedure [23,25–28], which also helps to reduce the effect of *k*. Another approach is to develop fuzzy versions of the *k*NN algorithm [29–31].

The three hyperparameters affect each other, which means that the optimal choice of one hyperparameter is dependent on the other hyperparameters. Therefore, they must be considered simultaneously rather than independently. Moreover, the methods involving costly extra optimization procedures when making predictions for query instances are computationally expensive, which is undesirable in practice. In addition, the majority of existing methods focus on specific settings, primarily classification tasks. Developing a universal method that is efficient and applicable to various tasks is beneficial. To address these concerns, this study proposes to jointly learn a distance function and a weighting function using a graph neural network in an end-to-end manner, which aims to make it robust to the choice of *k* in the prediction performance and is applicable to both classification and regression tasks.

#### **3. Method**

#### *3.1. Graph Representation of Data*

Suppose that a training set $\mathcal{D} = \{(\mathbf{x}_t, \mathbf{y}_t)\}_{t=1}^{N}$ is given, where $\mathbf{x}_t \in \mathbb{R}^p$ is the *t*-th input vector for the input variables and $\mathbf{y}_t$ is the corresponding label vector for the output variable. For a classification task with regard to *c* classes, $\mathbf{y}_t$ is a *c*-dimensional one-hot vector where the element corresponding to the target class is set to 1 and all the remaining elements are set to 0. For a regression task with a single output, $\mathbf{y}_t$ is a scalar representing the target value.

The proposed method uses a transformation function *g* that transforms each input vector $\mathbf{x}_t$ into a graph $\mathcal{G}_t$ such that $\mathcal{G}_t = g(\mathbf{x}_t; \mathcal{D})$. Two hyperparameters need to be determined: the number of nearest neighbors *k* and the distance function *d*. They are used only to operate the transformation function *g* for the *k*NN search from $\mathcal{D}$; however, they are not used explicitly in the learning procedure in Section 3.2. For each $\mathbf{x}_t$, its *k*NN instances are searched from $\mathcal{D} \setminus \{(\mathbf{x}_t, \mathbf{y}_t)\}$ based on the distance function *d*, denoted by $\mathcal{N}(\mathbf{x}_t) = \{(\mathbf{x}_t^{(i)}, \mathbf{y}_t^{(i)})\}_{i=1}^{k}$. Then, the *k*NN graph $\mathcal{G}_t = (\mathcal{V}_t, \mathcal{E}_t)$ is constructed as a fully connected undirected graph with $k+1$ nodes and $k(k+1)/2$ edges as follows:

$$\begin{aligned} \mathcal{V}_{t} &= \{\mathbf{v}_{t}^{i} \mid i \in \{0, \dots, k\}\}; \\ \mathcal{E}_{t} &= \{\mathbf{e}_{t}^{i,j} \mid i \in \{0, \dots, k\},\ j \in \{0, \dots, k\},\ i \neq j\}, \end{aligned} \tag{2}$$

where each node feature vector $\mathbf{v}_t^i \in \mathbb{R}^{c+1}$ and edge feature vector $\mathbf{e}_t^{i,j} \in \mathbb{R}^p$ are represented as:

$$\begin{aligned} \mathbf{v}_t^i &= \begin{cases} (\mathbf{0}, 1), & \text{if } i = 0\\ (\mathbf{y}_t^{(i)}, 0), & \text{otherwise} \end{cases};\\ \mathbf{e}_t^{i,j} &= |\mathbf{x}_t^{(i)} - \mathbf{x}_t^{(j)}| \end{aligned} \tag{3}$$

where the *t*-th input vector $\mathbf{x}_t$ is denoted by $\mathbf{x}_t^{(0)}$ for simplicity of description. The number *c* is set to the number of classes in the case of classification and to 1 in the case of regression.

In the graph $\mathcal{G}_t$, the 0-th node corresponds to $\mathbf{x}_t$, and the other nodes correspond to the *k*NN instances of $\mathbf{x}_t$. Each node feature vector $\mathbf{v}_t^i$ represents the label information with the last element set to zero, except that $\mathbf{v}_t^0$ does not contain the label information and has the last element set to one. Each edge feature vector $\mathbf{e}_t^{i,j}$ consists of the absolute difference between each of the input variables of $\mathbf{x}_t^{(i)}$ and $\mathbf{x}_t^{(j)}$. Thus, $\mathcal{G}_t$ represents the labels of the *k*NN instances and the pairwise distances between the instances. It should be noted that $\mathcal{G}_t$ does not contain $\mathbf{y}_t$ because it needs to be unknown when making a prediction in a supervised learning setting.
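The transformation *g* of Equations (2) and (3) can be sketched as follows. The *k*NN search uses scikit-learn's `NearestNeighbors`; the function and variable names are illustrative, not the paper's implementation, and for simplicity both directions of each undirected edge are stored in a dictionary:

```python
import numpy as np
from sklearn.neighbors import NearestNeighbors

def build_knn_graph(x, X_train, Y_train, k=3):
    """Build the kNN graph of Equations (2)-(3) for an input vector x.

    Returns (V, E): V holds k+1 node feature vectors in R^{c+1};
    E maps each ordered pair (i, j) to |x^{(i)} - x^{(j)}| in R^p.
    """
    nn = NearestNeighbors(n_neighbors=k).fit(X_train)
    _, idx = nn.kneighbors(x[None, :])                  # kNN search for x
    inputs = np.vstack([x[None, :], X_train[idx[0]]])   # x^{(0)}, ..., x^{(k)}
    c = Y_train.shape[1]

    V = np.zeros((k + 1, c + 1))
    V[0, -1] = 1.0                # query node: (0, 1), label unknown
    V[1:, :c] = Y_train[idx[0]]   # neighbor nodes: (y^{(i)}, 0)

    E = {(i, j): np.abs(inputs[i] - inputs[j])
         for i in range(k + 1) for j in range(k + 1) if i != j}
    return V, E

# Toy usage with hypothetical data (one-hot labels for two classes)
X_train = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
Y_train = np.eye(2)[[0, 0, 1, 1]]
V, E = build_knn_graph(np.array([0.1, 0.1]), X_train, Y_train, k=3)
print(V.shape, len(E))  # (4, 3) 12  -> k+1 nodes, both directions of 6 edges
```

Note that the query node carries no label, matching the requirement that $\mathbf{y}_t$ stays unknown at prediction time.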

#### *3.2. k-Nearest Neighbor Graph Neural Network*

Here, the proposed method, named *k*NNGNN, is introduced, which implements *k*NN learning in an end-to-end manner. It adapts the message-passing neural network architecture [32], which can handle general node and edge features with invariance to graph isomorphism, to build a graph neural network for *k*NN learning. To learn a *k*NN rule from the training dataset $\mathcal{D}$, it builds a graph neural network that operates on the graph representation $\mathcal{G} = g(\mathbf{x}; \mathcal{D})$ for an input vector $\mathbf{x}$ given the training dataset $\mathcal{D}$ to predict the corresponding label vector $\mathbf{y}$ as $\hat{\mathbf{y}} = f(\mathcal{G}) = f(g(\mathbf{x}; \mathcal{D}))$.

The model architecture used in this study is as follows. It first embeds each $\mathbf{v}^i$ into a *p*-dimensional initial node representation vector using an embedding function $\phi$ as $\mathbf{h}^{(0),i} = \phi(\mathbf{v}^i)$, $i = 0, \dots, k$. A message-passing step for the graph $\mathcal{G}$ is then performed using two main functions: the message function *M* and the update function *U*. The node representation vectors $\mathbf{h}^{(l),i}$ are updated as below:

$$\begin{aligned} \mathbf{m}^{(l),i} &= \sum_{j \mid \mathbf{v}^{j} \in \mathcal{V} \setminus \{\mathbf{v}^{i}\}} M(\mathbf{e}^{i,j})\, \mathbf{h}^{(l-1),j}, \quad \forall i;\\ \mathbf{h}^{(l),i} &= U(\mathbf{h}^{(l-1),i}, \mathbf{m}^{(l),i}), \quad \forall i. \end{aligned} \tag{4}$$

After *L* steps of message passing, a set of node representation vectors $\{\mathbf{h}^{(l),i}\}_{l=0}^{L}$ per node is obtained. The set for the 0-th node, $\{\mathbf{h}^{(l),0}\}_{l=0}^{L}$, is then processed with the readout function *r* to obtain the final prediction of the label $\mathbf{y}$ as:

$$\hat{\mathbf{y}} = r(\{\mathbf{h}^{(l),0}\}_{l=0}^{L}). \tag{5}$$

The component functions $\phi$, *M*, *U*, and *r* are parameterized as neural networks, mostly based on the idea presented in Gilmer et al. [32]. The function $\phi$ is a two-layer fully connected neural network with *p* tanh units in each layer. The function *M* is a two-layer fully connected neural network whose first layer consists of $2m$ tanh units and whose second layer outputs an $m \times m$ matrix. The function *U* is modeled as a recurrent neural network with gated recurrent units (GRUs) [33], which passes the previous hidden state $\mathbf{h}^{(l-1),i}$ and the current input $\mathbf{m}^{(l),i}$ to derive the current hidden state $\mathbf{h}^{(l),i}$ at each time step *l*. The function *r* is a two-layer fully connected neural network whose first layer consists of *p* tanh units and whose second layer outputs $\hat{\mathbf{y}}$ through softmax units in the case of classification and linear units in the case of regression. Different types of supervised learning tasks can be addressed using different types of units in the last layer of *r*.

The model defined above is denoted as the function *f*. The model makes a prediction from the input vector $\mathbf{x}$ and its *k*NN instances in $\mathcal{D}$, i.e., $\hat{\mathbf{y}} = f(g(\mathbf{x}; \mathcal{D}))$. The model differs from conventional neural networks in that it does not directly learn the relationship between input and output variables. In terms of *k*NN learning, the weight and distance functions are embedded implicitly into the function *f*. Therefore, the function *f* can be regarded as an implicit representation of a *k*NN rule, in which the functions *M* and *U* work as implicit distance and weighting functions, respectively.
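The message-passing mechanics of Equation (4) can be illustrated with a stripped-down NumPy sketch. This is only a shape-level illustration under stated simplifications: the learned components are replaced by random linear maps with tanh nonlinearities ($\phi$ and *M* are MLPs and *U* is a GRU in the paper), all dimensions are arbitrary, and no training occurs:

```python
import numpy as np

rng = np.random.default_rng(0)
p, m, L, k = 4, 4, 3, 3     # dims chosen purely for illustration (c = 2)

# Random stand-ins for the learned components.
W_phi = rng.normal(scale=0.1, size=(3, m))       # embeds node features in R^{c+1}
W_msg = rng.normal(scale=0.1, size=(p, m * m))   # edge feature -> m x m matrix
W_upd = rng.normal(scale=0.1, size=(2 * m, m))   # state update

def phi(v):          return np.tanh(v @ W_phi)               # h^{(0),i}
def M(e):            return (e @ W_msg).reshape(m, m)        # message matrix
def U(h_prev, msg):  return np.tanh(np.concatenate([h_prev, msg]) @ W_upd)

# A random kNN graph: k+1 node feature vectors and symmetric edge features.
V = rng.normal(size=(k + 1, 3))
E = {}
for i in range(k + 1):
    for j in range(i + 1, k + 1):
        e = np.abs(rng.normal(size=p))
        E[(i, j)] = E[(j, i)] = e

h = np.array([phi(v) for v in V])     # initial node representations
trace = [h[0]]                        # collect h^{(l),0} for the readout
for _ in range(L):                    # L message-passing steps, Equation (4)
    msgs = np.array([sum(M(E[(i, j)]) @ h[j]
                         for j in range(k + 1) if j != i)
                     for i in range(k + 1)])
    h = np.array([U(h[i], msgs[i]) for i in range(k + 1)])
    trace.append(h[0])

print(np.concatenate(trace).shape)    # (16,): the L+1 vectors the readout r sees
```

The readout *r* of Equation (5) would consume the `L + 1` collected representations of the 0-th (query) node to emit the final prediction.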

#### *3.3. Learning from Training Data*

Given the training dataset $\mathcal{D} = \{(\mathbf{x}_t, \mathbf{y}_t)\}_{t=1}^{N}$, the proposed method learns a task-specific *k*NN rule from $\mathcal{D}$ in the form of $\hat{\mathbf{y}} = f(g(\mathbf{x}; \mathcal{D}))$. The prediction model *f* is trained based on the graph representation *g* using the following objective function $\mathcal{J}$:

$$\mathcal{J} = \frac{1}{N} \sum_{(\mathbf{x}_t, \mathbf{y}_t) \in \mathcal{D}} \mathcal{L}(\mathbf{y}_t, \hat{\mathbf{y}}_t) = \frac{1}{N} \sum_{(\mathbf{x}_t, \mathbf{y}_t) \in \mathcal{D}} \mathcal{L}(\mathbf{y}_t, f(g(\mathbf{x}_t; \mathcal{D}))), \tag{6}$$

where $\mathcal{L}$ is the loss function, the choice of which depends on the target task. The typical choices of the loss function are cross-entropy for classification tasks and squared error for regression tasks.

#### *3.4. Prediction for New Data*

Once the prediction model *f* is trained, it can be used to predict unknown labels for new data. The prediction procedure is illustrated in Figure 1. Given a query instance $\mathbf{x}_*$ whose label $\mathbf{y}_*$ is unknown, its *k*NN instances $\mathcal{N}(\mathbf{x}_*) = \{(\mathbf{x}_*^{(i)}, \mathbf{y}_*^{(i)})\}_{i=1}^{k}$ are searched from the training dataset $\mathcal{D}$ based on the distance function *d*. Then, the corresponding graph $\mathcal{G}_* = g(\mathbf{x}_*; \mathcal{D})$ is generated. The prediction of $\mathbf{y}_*$, denoted by $\hat{\mathbf{y}}_*$, is computed using the model *f* as:

$$\hat{\mathbf{y}}_* = f(\mathcal{G}_*) = f(g(\mathbf{x}_*; \mathcal{D})). \tag{7}$$

**Figure 1.** Schematic of the *k*NN graph neural network (*k*NNGNN) prediction procedure.

The proposed method does not require additional optimization procedures when making predictions. The prediction for a query instance is simply conducted by performing a *k*NN search to identify the *k*NN instances and then processing these instances with the model. This is advantageous in terms of computational efficiency.

As the proposed method learns the *k*NN rule, incremental learning can be implemented efficiently. This is a main advantage of the *k*NN algorithm over other learning algorithms, especially when additional training data are collected over time after the model is trained. When new labeled data are added to the training dataset $\mathcal{D}$, the prediction performance can improve without retraining the model.

#### **4. Experimental Investigation**

#### *4.1. Datasets*

The effectiveness of the proposed method was investigated through experiments on various benchmark datasets. Twenty classification datasets and twenty regression datasets were collected from the UCI machine learning repository (http://archive.ics.uci.edu/ml/ (accessed on 10 January 2021)) and the StatLib datasets archive (http://lib.stat.cmu.edu/datasets/ (accessed on 10 January 2021)). The datasets used for classification tasks were *annealing*, *balance*, *breastcancer*, *carevaluation*, *ecoli*, *glass*, *heart*, *ionosphere*, *iris*, *landcover*, *movement*, *parkinsons*, *seed*, *segment*, *sonar*, *vehicle*, *vowel*, *wine*, *yeast*, and *zoo*. The datasets used for regression tasks were *abalone*, *airfoil*, *appliances*, *autompg*, *bikesharing*, *bodyfat*, *cadata*, *concretecs*, *cpusmall*, *efficiency*, *housing*, *mg*, *motorcycle*, *newspopularity*, *skillcraft*, *spacega*, *superconductivity*, *telemonitoring*, *wine-red*, and *wine-white*. Each dataset had a different number of instances with a different dimensionality. For each dataset, one thousand instances were randomly sampled if the size of the dataset was greater than 1000. All numeric variables were normalized into the range of [−1, 1]. The details of the datasets used are listed in Tables 1 and 2.



**Table 1.** Summary statistics of the error rate over different hyperparameter settings on the classification datasets.

The lowest values for each dataset are presented in bold.

**Table 2.** Summary statistics of the RMSE over different hyperparameter settings on the regression datasets.


The lowest values for each dataset are presented in bold.

#### *4.2. Compared Methods*

Three *k*NN methods that use different weighting schemes *w* were compared in the experiments: uniform *k*NN, weighted *k*NN, and the proposed *k*NNGNN. The uniform *k*NN and weighted *k*NN respectively used the following weighting functions:

$$\begin{aligned} w_{\text{U}}(d(\mathbf{x}, \mathbf{x}')) &= 1/k; \\ w_{\text{W}}(d(\mathbf{x}, \mathbf{x}')) &= 1/d(\mathbf{x}, \mathbf{x}'). \end{aligned} \tag{8}$$

For *k*NNGNN, the weighting function is embedded implicitly.

For each method, the hyperparameter settings were varied to examine their effects. The candidates for the distance function *d* were as follows:

$$\begin{aligned} \text{Manhattan: } d_{\text{L1}}(\mathbf{x}, \mathbf{x}') &= \|\mathbf{x} - \mathbf{x}'\|_1; \\ \text{Euclidean: } d_{\text{L2}}(\mathbf{x}, \mathbf{x}') &= \|\mathbf{x} - \mathbf{x}'\|_2 = \sqrt{(\mathbf{x} - \mathbf{x}')^T(\mathbf{x} - \mathbf{x}')}; \\ \text{Mahalanobis: } d_{\text{M}}(\mathbf{x}, \mathbf{x}') &= \sqrt{(\mathbf{x} - \mathbf{x}')^T \mathbf{S}^{-1} (\mathbf{x} - \mathbf{x}')}, \end{aligned} \tag{9}$$

where *S* is the covariance matrix of the input variables calculated from the training dataset.
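The three candidate distance functions of Equation (9) can be computed directly with NumPy; the data below are random stand-ins for a training set, used only to estimate the covariance matrix *S*:

```python
import numpy as np

# Hypothetical training data; S is estimated from it as in Equation (9).
rng = np.random.default_rng(0)
X_train = rng.normal(size=(100, 3))
x, x2 = X_train[0], X_train[1]

d_l1 = np.sum(np.abs(x - x2))                         # Manhattan
d_l2 = np.sqrt((x - x2) @ (x - x2))                   # Euclidean
S_inv = np.linalg.inv(np.cov(X_train, rowvar=False))  # inverse covariance
d_m = np.sqrt((x - x2) @ S_inv @ (x - x2))            # Mahalanobis

print(d_l1, d_l2, d_m)
```

Unlike the Manhattan and Euclidean distances, the Mahalanobis distance depends on the training data through *S*, which rescales and decorrelates the input variables.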

Accordingly, there were a total of nine combinations of distance and weighting functions compared in the experiments, as summarized in Table 3. None of the methods used any additional optimization procedures when making predictions. For *k*NNGNN, the distance function was only explicitly used for the *k*NN search to generate graph representations of the data. For each combination, the effect of *k* on the prediction performance was investigated by varying its value over 1, 3, 5, 7, 10, 15, 20, and 30.



#### *4.3. Experimental Settings*

In the experiments, the performance of each method was evaluated using a two-fold cross-validation procedure: the original dataset was divided into two disjoint subsets, and two iterations were conducted, each of which used one subset as the training set and the other as the test set. As performance measures, the misclassification error rate and the root mean squared error (RMSE) were used for the classification and regression tasks, respectively. Given a test set denoted by $\mathcal{D}' = \{(\mathbf{x}_t, \mathbf{y}_t)\}_{t=1}^{N'}$, the performance measures are calculated as:

$$\begin{aligned} \text{ErrorRate} &= \frac{1}{N'} \sum_{(\mathbf{x}_t, \mathbf{y}_t) \in \mathcal{D}'} I(\operatorname{argmax}(\mathbf{y}_t) \neq \operatorname{argmax}(\hat{\mathbf{y}}_t));\\ \text{RMSE} &= \sqrt{\frac{1}{N'} \sum_{(\mathbf{x}_t, \mathbf{y}_t) \in \mathcal{D}'} (\mathbf{y}_t - \hat{\mathbf{y}}_t)^2}. \end{aligned} \tag{10}$$
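Both performance measures are straightforward to implement; a small sketch with hypothetical predictions (one-hot labels for the error rate, scalar targets for the RMSE):

```python
import numpy as np

def error_rate(Y, Y_hat):
    """Misclassification rate for one-hot labels, per Equation (10)."""
    return np.mean(Y.argmax(axis=1) != Y_hat.argmax(axis=1))

def rmse(y, y_hat):
    """Root mean squared error for scalar targets, per Equation (10)."""
    return np.sqrt(np.mean((y - y_hat) ** 2))

# Hypothetical test-set labels and soft predictions
Y = np.array([[1, 0], [0, 1], [1, 0], [0, 1]])
Y_hat = np.array([[0.9, 0.1], [0.2, 0.8], [0.4, 0.6], [0.3, 0.7]])
print(error_rate(Y, Y_hat))   # 0.25: one of four predictions is wrong

print(rmse(np.array([1.0, 2.0, 3.0]), np.array([1.0, 2.0, 4.0])))
```

Note that the `argmax` comparison makes the error rate insensitive to how confident the soft predictions are, while the RMSE penalizes the magnitude of each residual.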

For the proposed method, each prediction model was built based on the following configurations. In the objective function $\mathcal{J}$, the loss function $\mathcal{L}$ was set to cross-entropy for the classification tasks and to squared error for the regression tasks. The hyperparameter *L* was set to 3, as Gilmer et al. [32] demonstrated that any $L \geq 3$ would work. The hyperparameter *p* was explored over {10, 20, 50} by holdout validation. In the training phase, dropout was applied to the function *r* with a dropout rate of 0.1 for regularization [34]. During the training, 80% and 20% of the training set were used to train and validate the model, respectively. The model parameters were updated using the Adam optimizer with a batch size of 20. The learning rate was set to $10^{-3}$ at the first training epoch and was reduced by a factor of 0.1 if no improvement in the validation loss was observed for 10 consecutive epochs. The training was terminated when the learning rate was decreased to $10^{-7}$ or the number of epochs reached 500. In the inference phase, for each query instance, thirty different outputs were obtained by performing stochastic forward passes through the trained model with dropout turned on [35]. The average of these outputs was then used to obtain the predicted label for the instance.

All baseline methods were implemented using the scikit-learn package in Python. The proposed method was implemented based on GPU-accelerated TensorFlow in Python. All experiments were performed 10 times independently with different random seeds. For the results, the average performance over the repetitions was compared. Then, for each of the three weighting functions *w*, the summary statistics of the performance over different settings of distance functions *d* and the number of neighbors *k* are reported.

#### *4.4. Results and Discussion*

Figure 2 shows the error rate comparison of the baseline and proposed methods with varying hyperparameter settings on the 20 classification datasets. Compared to the baseline methods, *k*NNGNN overall yielded lower error rates at various values of *k* for most datasets. For the results with different hyperparameters, the average, standard deviation, and best error rate for each dataset are summarized in Table 1. *k*NNGNN yielded the lowest average and standard deviation of the error rate over different hyperparameters on most datasets, which indicates that the performance of *k*NNGNN was less sensitive to its hyperparameter settings. In particular, *k*NNGNN was superior to the baseline methods when the hyperparameter *k* was larger.

Figure 3 compares the baseline and proposed methods in terms of the RMSE with varying hyperparameter settings on the 20 regression datasets. As shown in this figure, the performance curves of *k*NNGNN flattened as *k* increased on most datasets, whereas the RMSE of the baseline methods tended to increase at large *k* for some datasets. Table 2 shows the average, standard deviation, and best RMSE over different hyperparameter settings for each dataset. The behavior of *k*NNGNN was similar to that observed for the classification tasks: *k*NNGNN showed stable performance against changes in the hyperparameter settings and yielded the lowest average and standard deviation of the RMSE for the majority of datasets.

In summary, the experimental results successfully demonstrated the effectiveness of *k*NNGNN in improving the prediction performance for both classification and regression tasks. Although *k*NNGNN failed to yield the lowest error for some datasets, *k*NNGNN yielded high robustness to its hyperparameters. This indicated that *k*NNGNN would provide comparable performance without carefully tuning its hyperparameters; thus, it can be preferred in practice considering the difficulty of choosing the optimal hyperparameter setting. Because the performance curve of *k*NNGNN flattened at large *k* values on most datasets, setting a moderate *k* value around 15∼20 would be reasonable considering the trade-off between the performance and computational cost.

#### **5. Conclusions**

This study presented *k*NNGNN, which learns a task-specific *k*NN rule from data in an end-to-end fashion. The proposed method constructed the *k*NN rule in the form of a graph neural network, in which the distance and weighting functions were embedded implicitly. The graph neural network considered the *k*NN graph of an instance as the input to predict the label of the instance. Owing to the flexibility of neural networks, the method can be applied to any form of supervised learning tasks including classification and regression. It does not require any extra optimization procedure when making predictions for new data, which is beneficial in terms of computational efficiency. Moreover, as the method learns the *k*NN rule instead of the explicit relationship between the input and output variables, incremental learning can be implemented efficiently.

The effectiveness of the proposed method was demonstrated through experiments on benchmark classification and regression datasets. The results showed that the proposed method can yield comparable prediction performance with less sensitivity to the choice of its hyperparameters. The proposed method allows more robust *k*NN learning without carefully tuning the hyperparameters. The use of a graph neural network for *k*NN learning may still have room for improvement and thus merits further investigation. One practical concern is the high complexity of a graph neural network in terms of time and space, which increases with *k*. A graph neural network cannot be trained in a reasonable amount of time without using a GPU. Alleviation of complexity to improve learning efficiency will be an avenue for future work.

**Funding:** This work was supported by the National Research Foundation of Korea (NRF) grant funded by the Korea government (MSIT; Ministry of Science and ICT) (Nos. NRF-2019R1A4A1024732 and NRF-2020R1C1C1003232).

**Institutional Review Board Statement:** Not applicable.

**Informed Consent Statement:** Not applicable.

**Data Availability Statement:** Not applicable.

**Conflicts of Interest:** The authors declare no conflict of interest.

#### **References**

