1. Introduction
Predicting patient outcomes [1,2], especially mortality in critical care settings, has long been a priority in medical research. Numerous clinical parameters have been identified as significant predictors. In addition to biochemical indicators, the length of stay in the intensive care unit (ICU) has been associated with severe circumstances such as mechanical ventilation (MV) and psychiatric medication poisoning. Furthermore, early prediction of mechanical ventilation duration for patients with acute respiratory distress syndrome (ARDS) can enhance risk stratification and improve care strategies [3].
Predictive models for assessing risks in critically ill patients, such as hypoglycemia in septic patients or post-surgical outcomes following coronary artery bypass grafting (CABG), are increasingly being explored [4,5]. While CABG remains a critical intervention for patients with coronary atherosclerosis, the long-term prognosis remains uncertain, making the development of predictive models essential for improving patient survival.
Moreover, the complexity of disorders such as pulmonary hypertension (PH) emphasises the need for comprehensive prediction models. Despite advances in therapy, the pathophysiology [6] of PH involves a mix of musculoskeletal, cardiovascular, and respiratory problems that lead to increasing exercise intolerance and reduced quality of life. Because these comorbidities are multi-dimensional, predictive models that leverage electronic health record (EHR) data are crucial for accurately predicting patient mortality and guiding clinical decision-making.
These studies demonstrate the importance of using EHR data to forecast patient mortality. EHRs capture a wide range of clinical information, including biochemical markers, comorbidities, and intervention outcomes, which is key to developing robust mortality risk prediction models that improve patient care and outcomes.
Hospitals often record patient data as EHRs, which include data on tests, symptoms, diagnoses, and prescriptions. These EHRs contain structured patient information, lab report details, and unstructured data in free text comments, such as medical notes. This rich information within EHRs is crucial to integrating knowledge about illnesses, treatments, and proteomics into clinical knowledge graphs, all within a real-time patient care system.
Figure 1 shows some example tasks involved in EHR information extraction using a Clinical Data Warehouse (CDW) approach.
Integrating structured and unstructured data in EHRs allows for a comprehensive view of patient health, which is vital for personalised medicine. Structured data, such as lab results and medication lists, provide clear, quantitative insights. In contrast, unstructured data, such as medical notes, offer contextual information and detailed narratives about the patient’s condition and treatment responses. Although several deep learning methods provide patient-specific mortality predictions from unstructured EHR data, they often fail to fully extract the latent, complex information essential for thorough analysis. Knowledge graphs offer a robust solution by organising unstructured data into interconnected, semantic relationships. They capture complex associations between medical entities, such as symptoms, treatments, and diagnoses, allowing for a more holistic understanding of patient history. By transforming fragmented narratives into a structured form, knowledge graphs enhance the interpretability of unstructured data and enable more accurate predictions, better-informed decision-making, and personalised care. This approach bridges the gap between raw clinical narratives and actionable insights, improving the precision of healthcare analytics.
Knowledge graphs are essential for describing the complex connections and meanings inherent in the data domain. They encapsulate a wide range of biomedical entities and their interrelations, such as diseases, symptoms, drugs, and genes. Using a patient graph network, the framework enables the extraction of meaningful embeddings from the knowledge graph, facilitating the identification of subtle patterns and associations within biomedical information. For instance, a knowledge graph can reveal how specific genetic markers correlate with disease susceptibility or how drug combinations impact patient outcomes.
One effective strategy is to apply graph convolutional networks (GCNs) to the knowledge graph. These networks leverage data statistics to guide structural learning, presenting a promising approach to uncovering the underlying structure inherent in EHR data [7]. GCNs can capture the dependencies and interactions between different features in the data, which traditional flat models might overlook.
Creating a knowledge graph from EHR data is a collaborative, multidisciplinary effort involving experts in healthcare, data engineering, natural language processing, machine learning, and graph databases. The resulting knowledge graph becomes a powerful tool for healthcare professionals to improve patient care, conduct research, and make more informed clinical decisions. The most common technique for applying neural networks to EHR data has been to treat each case as an unordered set of characteristics, essentially representing it as a “bag of features”. Unfortunately, this method disregards the vital geometric structure that reflects the physician’s assessment process. For instance, when we analyse the encounter in Figure 2 as a bag of features, we lose the crucial information that the combination of Decadron, Revlimid, and Velcade prescribed to patient ‘111791005’ was the suspected cause of anaemia, resulting in severe medical conditions.
Problem Statement: To predict patient mortality from EHRs by constructing a patient knowledge graph that extracts relationships between entities, such as diagnoses and treatments, from unstructured medical notes, providing better interpretability and decision support.
A patient network is effectively modelled as a graph, where each node represents an individual patient’s hospital stay, encoded using graph representations derived from their medical notes and medical data. The edges between nodes indicate a connection between two hospital stays based on a similarity measure, such as shared diagnoses, treatment responses, or other medical characteristics. The objective of this model is to mirror the decision-making process of healthcare professionals. In clinical practice, doctors rely on a patient’s medical history and draw on their experience with patients who have exhibited similar conditions or treatment responses. This graph-based approach simulates that process computationally, allowing the model to inform decisions about medication, treatment plans, or interventions by identifying patterns and outcomes from similar patients. The knowledge graph captures the implicit knowledge gained from prior cases, supporting personalised and evidence-based care recommendations. It serves as the input to the GCN, which learns patient embeddings from the patient graph and encourages a regularised latent space for these embeddings.
Thus, to address the problem that EHR data do not always provide complete structural information, we propose PKGNN, an ensemble approach for concurrently learning hidden structural information for different prediction tasks. We make the following contributions in this paper:
This study uses the MIMIC-IV benchmark dataset to compare the performance of the proposed framework with that of state-of-the-art (SOTA) deep learning models in predicting critical patient outcomes. The models’ performance has been evaluated on mortality prediction and 30-day hospital readmission prediction.
3. Materials and Methods
This section describes the proposed PKGNN, focusing on clinical risk prediction problems with EHR data. The proposed ensemble GCN architecture utilises medical notes [27,28] with feature extraction using pretrained BERT variant models.
3.1. Datasets
We validate the proposed PKGNN on a real-world, openly accessible EHR database, the Medical Information Mart for Intensive Care (MIMIC-IV) [29]. We selected the following two prediction tasks to evaluate the performance of the proposed models.
30-day hospital readmission is a binary classification task that aims to predict whether a patient, at time t, will be readmitted to the hospital within the next 30 days. Performance is evaluated using AUROC and AUPRC.
Mortality prediction is a binary classification task that aims to predict whether a patient, at time t, will die within the next 24 h. Performance is evaluated using AUROC and AUPRC.
The MIMIC-IV [12] and MIMIC-IV discharge summary notes [27] databases undergo a selection process to identify a subset of data records for our patient cohort, omitting irrelevant and redundant features. The cohort comprises individuals aged 18 years or older who spent at least one day in the ICU, with an average daily duration exceeding six hours. Only patients who are not organ donors and who were not transferred from another hospital are included. To minimise ambiguity, we exclude individuals with conditions such as neuromuscular diseases, malignant tumours, and severe burns, which typically require extended hospital stays. Every ICU stay record includes both time-series and static characteristics (e.g., age, gender).
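The inclusion and exclusion criteria above can be sketched as a simple record filter. This is a minimal illustration, not the authors’ extraction pipeline: the `IcuStay` record, its field names, and the condition labels in `EXCLUDED_CONDITIONS` are hypothetical. It also assumes that each stay has already been joined with the relevant MIMIC-IV tables, and that “average daily duration” means mean ICU hours per calendar day.

```python
from dataclasses import dataclass

# Hypothetical condition labels; the text names these categories but not how they are coded.
EXCLUDED_CONDITIONS = {"neuromuscular_disease", "malignant_tumour", "severe_burn"}


@dataclass
class IcuStay:
    subject_id: int
    age: int
    icu_days: float          # length of the ICU stay in days
    mean_daily_hours: float  # assumed meaning: average ICU hours per calendar day
    organ_donor: bool
    transferred_in: bool     # transferred from another hospital
    conditions: frozenset    # hypothetical condition labels attached to the stay


def in_cohort(stay: IcuStay) -> bool:
    """Apply the inclusion/exclusion criteria described in Section 3.1."""
    return (
        stay.age >= 18
        and stay.icu_days >= 1
        and stay.mean_daily_hours > 6
        and not stay.organ_donor
        and not stay.transferred_in
        and not (stay.conditions & EXCLUDED_CONDITIONS)  # empty intersection is falsy
    )


def select_cohort(stays):
    """Keep only the ICU stays that satisfy every cohort criterion."""
    return [s for s in stays if in_cohort(s)]
```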
Figure 3 summarises the cohort data that were taken from the tables in the MIMIC-IV database. The 35 unique tables that make up the MIMIC-IV relational database are divided into four different modules that correspond to the core, hospital, intensive care unit, and derived tables. We extract information from the admissions, patients, and icu_stay tables according to the cohort requirements. Further, we transfer the ICD diagnostic codes to the cohort selection schema by mapping them from the diagnosis_icd table. Accessing the derived tables, ICUstay_hourly and vitalsign, is necessary to retrieve the hourly details of patients and their routines. Then, the discharge table’s discharge summary text field is concatenated to create the final complete cohort.
3.2. Problem Formulation
Consider a set of patients, denoted by $P = \{p_1, p_2, \dots, p_N\}$, where each patient is denoted as $p_i$ for $i = 1, 2, \dots, N$, with $N$ being the total number of patients. Each patient $p_i$ has an associated set of hospital stays, represented by $H_i = \{h_{i1}, h_{i2}, \dots, h_{in_i}\}$. Here, $n_i$ denotes the total number of hospital stays for patient $p_i$. Each hospital stay $h_{ij}$, where $j$ indexes the individual hospital stay for patient $p_i$, contains a set of medical notes. A sample patient knowledge graph is shown in Figure 4.
To analyse each patient’s medical record comprehensively per hospital stay, we aggregate all the medical notes of each hospital stay for that patient. This aggregated set of medical notes for patient $p_i$ for the $j$th stay is denoted by $\mathcal{D}_{ij}$, which is defined in Equation (1) as the union of the medical notes from that hospital stay, where $d_{ij}^{(z)}$ denotes the $z$th medical note and $k$ denotes the total number of medical notes during the $j$th hospital stay:

$$\mathcal{D}_{ij} = \bigcup_{z=1}^{k} \left\{ d_{ij}^{(z)} \right\} \qquad (1)$$
Let $\mathcal{D}_i$ represent the set of all medical notes for patient $p_i$, combining data from all hospital stays as in Equation (2):

$$\mathcal{D}_i = \bigcup_{j=1}^{n_i} \mathcal{D}_{ij} \qquad (2)$$
The primary purpose is to learn the model’s prediction function $f_{\theta}: \mathcal{D}_{ij} \rightarrow \hat{y}_{ij}$, where $\hat{y}_{ij}$ represents the predicted likelihood for the target label $y_{ij}$. The learnable model parameters are denoted by $\theta$.
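Equations (1) and (2) amount to grouping note records by hospital stay and by patient. The sketch below assumes that notes arrive as hypothetical `(patient_id, stay_id, text)` tuples and uses ordered lists instead of mathematical sets, so that duplicate notes are kept and note order is preserved for the later concatenation step.

```python
from collections import defaultdict


def aggregate_notes(notes):
    """Group (patient_id, stay_id, text) records into per-stay and per-patient collections.

    per_stay[(i, j)] plays the role of D_ij in Eq. (1);
    per_patient[i] plays the role of D_i in Eq. (2).
    """
    per_stay = defaultdict(list)     # (patient, stay) -> notes z = 1..k of that stay
    per_patient = defaultdict(list)  # patient -> notes from every stay of that patient
    for patient_id, stay_id, text in notes:
        per_stay[(patient_id, stay_id)].append(text)
        per_patient[patient_id].append(text)
    return dict(per_stay), dict(per_patient)
```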
Figure 5 depicts the overall architecture of the proposed PKGNN.
3.3. Medical Notes Representation and Knowledge Graph
The observed medical notes $\mathcal{D}_i$ for each patient are pre-processed to extract relevant information and create vectors that are used as graph node feature embeddings. Algorithm 1 describes the latent representation of medical notes, based on feature aggregation over each patient’s hospital stay data. We utilise a pre-trained BERT variant model to tokenise the medical notes $\mathcal{D}_{ij}$ and generate the feature embedding $\mathbf{E}_{ij}$, which represents the medical notes embedding vector of size 768 for hospital stay $j$, as explained in Algorithm 1.
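Algorithm 1 Medical Notes Latent Representation
Input: Medical notes $\mathcal{D}_i$ for patient $p_i$, $i = 1, \dots, N$.
Output: Medical notes latent representation $\mathbf{E}_{ij}$ for each hospital stay.
1: for each patient $p_i$ do
2:   for each hospital stay $h_{ij}$, $j = 1, \dots, n_i$ do
3:     Concatenate all the medical notes of the $j$th hospital stay using Equation (1) to obtain $\mathcal{D}_{ij}$
4:     Divide $\mathcal{D}_{ij}$ into 512-token chunks, since the BERT model can only process 512 input tokens at once
5:     for each chunk $c_r$, where $r = 1, \dots, R$ do
6:       Tokenise $c_r$ using the BERT tokeniser
7:       Extract the chunk feature vector $\mathbf{e}_r \in \mathbb{R}^{768}$ using the BERT variant
8:     end for
9:     Use the average feature aggregator to obtain $\mathbf{E}_{ij} = \frac{1}{R} \sum_{r=1}^{R} \mathbf{e}_r$
10:  end for
11: end for
Note: The stacked chunk embeddings $[\mathbf{e}_1; \dots; \mathbf{e}_R]$ form a matrix of dimension $R \times 768$ (where 768 is the dimensionality of the BERT embeddings for each chunk), while $\mathbf{E}_{ij}$ is of dimension $1 \times 768$ (the averaged feature-aggregated vector).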
For each hospital stay, the medical notes are concatenated and divided into 512-token chunks, as the BERT model can only process sequences of this length. Each chunk is tokenised using the BERT tokeniser, and a feature vector is extracted using a BERT variant.
To obtain a single feature vector for each hospital stay, the algorithm averages the feature vectors of all chunks. The resulting vector has a dimensionality of $1 \times 768$, representing the averaged feature-aggregated vector. This process is repeated for all hospital stays, transforming each set of medical notes into a compact representation: a compressed latent encoding of the textual information that serves as input to downstream predictive modelling tasks.
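A minimal sketch of Algorithm 1 follows, with the BERT components abstracted away. `tokenize` and `embed_chunk` are hypothetical callables that stand in for a BERT variant’s tokeniser and encoder, for example a function returning a 768-dimensional pooled vector per chunk. The chunk size is a parameter. Because BERT’s special tokens also count towards the 512-token limit, a real implementation would leave room for them.

```python
from typing import Callable, List, Sequence

MAX_TOKENS = 512  # BERT's maximum input length (special tokens included)


def chunk_tokens(tokens: Sequence[str], size: int = MAX_TOKENS) -> List[List[str]]:
    """Split a token sequence into consecutive chunks of at most `size` tokens."""
    return [list(tokens[i:i + size]) for i in range(0, len(tokens), size)]


def stay_embedding(notes: Sequence[str],
                   tokenize: Callable[[str], List[str]],
                   embed_chunk: Callable[[List[str]], List[float]],
                   size: int = MAX_TOKENS) -> List[float]:
    """Concatenate a stay's notes, embed each chunk, and average the chunk vectors."""
    text = " ".join(notes)                      # step 3: concatenate all notes of the stay
    chunks = chunk_tokens(tokenize(text), size)  # step 4: split into fixed-size chunks
    if not chunks:
        raise ValueError("hospital stay has no note text")
    vectors = [embed_chunk(c) for c in chunks]   # steps 5-8: R x d matrix (d = 768 for BERT)
    dim = len(vectors[0])
    # Step 9: element-wise mean over chunks gives the 1 x d stay representation.
    return [sum(v[k] for v in vectors) / len(vectors) for k in range(dim)]
```

Mean pooling keeps the output dimension fixed regardless of how many chunks a stay produces, which is what allows every hospital stay to become a node with a 768-dimensional feature vector.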
An undirected, unweighted knowledge graph $G = (V, E)$ is constructed, where $V$ is the set of nodes and $E$ is the set of edges. The set of nodes $V$ is defined as follows:

$$V = \left\{ v_{ij} \mid i = 1, \dots, N;\ j = 1, \dots, n_i \right\} \qquad (3)$$

The hospital stay of patient $p_i$ during their $j$th visit is represented by $v_{ij}$, where $j$ ranges from 1 to $n_i$. Here, the total number of hospital stays for all patients is denoted by $M = \sum_{i=1}^{N} n_i$, where $n_i$ represents the number of hospital stays for patient $p_i$.
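Two nodes $v_p$ and $v_q$, corresponding to two hospital stays $p$ and $q$, are connected by an edge if their feature similarity is above a threshold $\tau$. The similarity score $s_{pq}$ between the vertices is calculated using Equation (4):

$$e_{pq} = \begin{cases} 1, & \text{if } s_{pq} = \dfrac{\mathbf{E}_i^{p} \cdot \mathbf{E}_j^{q}}{\lVert \mathbf{E}_i^{p} \rVert \, \lVert \mathbf{E}_j^{q} \rVert} > \tau \\ 0, & \text{otherwise} \end{cases} \qquad (4)$$

where $e_{pq}$ represents an edge between hospital stay nodes, with $p$ and $q$ corresponding to hospital stays, while $i$ and $j$ denote patients. Following hyper-parameter tuning, we set $\tau = 0.95$, because the average node similarity is high. This threshold links only nodes with substantial similarity and helps reduce false positive predictions.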
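The edge rule of Equation (4) with cosine similarity can be sketched as an all-pairs comparison. The sketch is illustrative only, and `cosine` and `build_edges` are hypothetical helper names. It compares every pair of nodes in O(M²) time, which would be impractical for graphs with tens of thousands of nodes unless the comparison is vectorised or done in blocks.

```python
import math
from itertools import combinations
from typing import List, Sequence, Set, Tuple


def cosine(u: Sequence[float], v: Sequence[float]) -> float:
    """Cosine similarity of two vectors (0.0 if either is all zeros)."""
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / norm if norm else 0.0


def build_edges(features: List[Sequence[float]], tau: float = 0.95) -> Set[Tuple[int, int]]:
    """Undirected edges (p, q), p < q, between hospital-stay nodes with similarity > tau."""
    return {
        (p, q)
        for p, q in combinations(range(len(features)), 2)
        if cosine(features[p], features[q]) > tau
    }
```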
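Here, a two-layer GCN model is trained on the constructed knowledge graph. Let the symmetric adjacency matrix of the graph be $L \in \mathbb{R}^{M \times M}$, where $M$ is the size of the node set $V$. The corresponding degree matrix is represented by $T$, where $T_{ii} = \sum_{j} L_{ij}$. The adjacency matrix $L$ is augmented with self-loops to form $\tilde{L} = L + I_M$, where $I_M$ is the $M \times M$ identity matrix.
The normalised adjacency matrix $\hat{L}$ is computed as follows:

$$\hat{L} = \tilde{T}^{-1/2} \, \tilde{L} \, \tilde{T}^{-1/2} \qquad (5)$$

where $\tilde{L}$ is the adjacency matrix augmented with self-loops, and $\tilde{T}$ is the degree matrix of $\tilde{L}$, with $\tilde{T}_{ii} = \sum_{j} \tilde{L}_{ij}$.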
The matrix $\hat{L}$ from Equation (5) is used in the GCN to aggregate information from node $i$ and its neighbouring nodes, with normalisation based on the degrees of the nodes. Specifically, the feature representation of each node is updated by combining its own features with those of its neighbours, weighted by the normalised adjacency matrix. For an undirected and unweighted graph, this weighting depends only on the node degrees, which balances the contributions of neighbours. This process is formalised in the update rule for the $g$-th layer:

$$H^{(g+1)} = \sigma\left(\hat{L} \, H^{(g)} \, W^{(g)}\right) \qquad (6)$$

where $W^{(g)}$ is the trainable weight matrix of the $g$-th layer, $\sigma(\cdot)$ denotes the ReLU activation function, which is applied element-wise to induce non-linearity, and $H^{(g)}$ is the matrix of activations in the $g$-th layer, with $H^{(0)} = X$, the matrix of node feature embeddings.
To perform classification, the softmax function is applied row-wise to the output of the forward model in Equation (6) to obtain class probabilities:

$$Z = \mathrm{softmax}\left(\hat{L} \, \mathrm{ReLU}\left(\hat{L} X W^{(0)}\right) W^{(1)}\right) \qquad (7)$$

where $X$ is the matrix of node feature embeddings, and $W^{(0)}$ and $W^{(1)}$ are the input-to-hidden and hidden-to-output weight matrices for the two-layer GCN, respectively. A GCN model for classification is illustrated in Figure 6.
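Given the normalised adjacency matrix, the two-layer forward pass of Equation (7) reduces to two propagation steps with a ReLU in between and a row-wise softmax at the end. The sketch below takes a precomputed normalised adjacency matrix and the weight matrices as nested lists. It omits the dropout used during training and is an illustration, not the authors’ code.

```python
import math
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Dense (m x k) @ (k x n) product."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def relu(x: Matrix) -> Matrix:
    return [[max(0.0, v) for v in row] for row in x]


def softmax_rows(x: Matrix) -> Matrix:
    out = []
    for row in x:
        m = max(row)  # subtract the row maximum for numerical stability
        exps = [math.exp(v - m) for v in row]
        s = sum(exps)
        out.append([e / s for e in exps])
    return out


def gcn_forward(a_hat: Matrix, x: Matrix, w0: Matrix, w1: Matrix) -> Matrix:
    """Z = softmax(A_hat ReLU(A_hat X W0) W1), the two-layer GCN of Eq. (7)."""
    hidden = relu(matmul(matmul(a_hat, x), w0))             # H^(1), shape M x hidden
    return softmax_rows(matmul(matmul(a_hat, hidden), w1))  # M x Q class probabilities
```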
3.4. Loss Function
Cross-entropy loss is commonly used for categorical outcomes, such as clinical risk classification. The cross-entropy loss function $\mathcal{L}$ over all labelled examples is expressed in Equation (8):

$$\mathcal{L} = -\sum_{l \in \mathcal{M}} \sum_{f=1}^{Q} Y_{lf} \ln Z_{lf} \qquad (8)$$

where $\mathcal{M}$ is the set of indices of labelled vertices in the graph, $Q$ is the output feature dimension, equal to the number of classes, and $Y$ is the label indicator matrix.
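Equation (8) restricts the loss to labelled vertices. A minimal sketch follows, assuming that `z` holds the softmax outputs of Equation (7) and `y` holds the one-hot label indicator rows. It sums over labelled nodes as written in Equation (8), whereas many library losses average instead, for example PyTorch’s `CrossEntropyLoss` with its default `reduction='mean'`. The function name `masked_cross_entropy` is hypothetical.

```python
import math
from typing import Iterable, List


def masked_cross_entropy(z: List[List[float]], y: List[List[int]],
                         labelled: Iterable[int]) -> float:
    """L = -sum_{l in M} sum_{f=1}^{Q} Y_lf * ln Z_lf, over labelled nodes only."""
    eps = 1e-12  # guard against log(0)
    return -sum(y[l][f] * math.log(z[l][f] + eps)
                for l in labelled for f in range(len(z[l])))
```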
3.5. Ensemble Graph Learning
The proposed ensemble model uses three BERT variants: Clinical BERT [9], BioBERT [10], and BlueBERT [11]. Each model independently extracts feature representations of the medical notes and creates feature vectors for the patients’ hospital stays. On top of the BERT variants, the ensemble model uses an aggregator to generate a fixed-length feature vector of the medical notes. This technique captures long-range dependencies between words, which are essential in clinical settings.
Algorithm 2 describes the whole working procedure of the proposed PKGNN framework.
Algorithm 2 Proposed PKGNN framework
1: Initialisation: set Learning_rate, Batch_size, seed, max_grad_norm; GCN: Input_size, hidden_size, out_size, num_layers, and threshold
2: Obtain the feature-aggregated embedding $\mathbf{E}_{ij}$ from Algorithm 1
3: for each classifier $C_k$, $k = 1, 2, 3$ do
4:   for each epoch do
5:     Build a graph G with node features $\mathbf{E}_{ij}$ and edge connections based on the cosine similarity of node embeddings using Equation (4)
6:     Train the GCN model with the node features
7:     Calculate the binary cross-entropy loss using Equation (8)
8:     Update parameters using the Adam optimiser
9:   end for
10: end for
11: Use the ensemble approach to obtain the predictions by majority voting using Equation (9)
12: Test and validate the trained model by predicting the probability scores
Here, we consider three classifiers based on Clinical BERT, BioBERT, and BlueBERT, respectively. Let $C_1$, $C_2$, and $C_3$ denote the three classifiers. The ensemble voting classifier $C$ predicts the class $\hat{y}$ from the predictions of the individual classifiers. The trained GCN models are integrated into the ensemble model, and majority voting determines its final output by aggregating the predictions of the three classifiers, as described in Equation (9):

$$\hat{y} = \arg\max_{c} \sum_{k=1}^{3} \mathbb{I}\left(C_k(S) = c\right) \qquad (9)$$

In this equation, $C_k(S)$ represents the prediction of the $k$-th classifier for input $S$. The indicator function is $\mathbb{I}(\cdot)$, which returns 1 if the argument is true and 0 otherwise. The term $\sum_{k=1}^{3} \mathbb{I}(C_k(S) = c)$ counts the number of classifiers that predicted class $c$.
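The majority vote in Equation (9) can be sketched with `collections.Counter`. Each classifier is modelled as a hypothetical callable that maps an input to a class label. With three voters and two classes a tie cannot occur. In other settings, `Counter.most_common` orders equal counts by first occurrence, which acts as an implicit tie-break.

```python
from collections import Counter
from typing import Callable, Hashable, Sequence


def majority_vote(classifiers: Sequence[Callable[[object], Hashable]], s: object) -> Hashable:
    """y_hat = argmax_c sum_k I(C_k(S) = c): the class predicted by the most classifiers."""
    votes = Counter(clf(s) for clf in classifiers)  # votes[c] = number of C_k predicting c
    return votes.most_common(1)[0][0]
```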
This study demonstrates that the proposed ensemble graph-based learning approach (PKGNN) is a valuable technique for enhancing the performance of clinical prediction models, in contrast to most previous attempts to construct .
4. Results
We implemented the code with Python 3.12.4 and pytorch-cuda 11.7, and trained all models on a workstation with an Intel® Xeon™ processor (Intel, Santa Clara, CA, USA), an NVIDIA Quadro P5000 graphics card (NVIDIA, Santa Clara, CA, USA), and 64 GB RAM.
4.1. Evaluation Metrics
The outcome of this classification must be assessed and quantified to determine whether the samples are correctly categorised. Accuracy, precision, recall, AUROC, AUPRC, and recall at 80% precision (R@P80) are used as evaluation metrics.
True Positive (TP): Instances of deceased patients that were correctly identified as deceased.
False Positive (FP): Instances of survived patients that were misclassified as deceased.
True Negative (TN): Instances of survived patients that were correctly identified as survived.
False Negative (FN): Instances of deceased patients that were misclassified as survived.
Precision: Precision measures the proportion of positive predictions that are correct. The precision of a model is 1.0 if it generates no false positives. The formula is as follows:
$$\mathrm{Precision} = \frac{TP}{TP + FP}$$
Recall: Recall is the ability to identify every relevant (positive) instance in the dataset:
$$\mathrm{Recall} = \frac{TP}{TP + FN}$$
Accuracy: Accuracy is the proportion of correct predictions among all predictions:
$$\mathrm{Accuracy} = \frac{TP + TN}{TP + TN + FP + FN}$$
AUROC: The AUROC is the area under the receiver operating characteristic (ROC) curve. An AUROC of 1 indicates perfect discrimination, while 0.5 corresponds to random guessing. The ROC curve displays the trade-off between the true positive rate and the false positive rate at decision thresholds ranging from 1 to 0. For imbalanced data, this measure provides additional information.
AUPRC: The Area Under the Precision-Recall Curve (AUPRC) is a metric used particularly in scenarios with imbalanced datasets. It summarises the trade-off between precision and recall across different classification thresholds. A higher AUPRC indicates better model performance, with a maximum value of 1 representing perfect precision-recall balance.
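R@P80: Recall at 80% precision, i.e., the recall achieved when precision is fixed at 80%. The formula can be expressed as follows:
$$\mathrm{R@P80} = \max_{t} \left\{ \mathrm{Recall}(t) \mid \mathrm{Precision}(t) \geq 0.8 \right\}$$
where $t$ is the decision threshold.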
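The metric definitions above can be sketched in plain Python. The sketch computes AUROC through its probabilistic interpretation: the chance that a random positive scores higher than a random negative, with ties counted as one half. It approximates AUPRC by average precision, and it finds R@P80 by scanning every distinct score as a threshold. These are common estimators assumed for illustration; the paper does not state which implementations it used.

```python
from typing import Sequence, Tuple


def confusion(y_true: Sequence[int], y_pred: Sequence[int]) -> Tuple[int, int, int, int]:
    pairs = list(zip(y_true, y_pred))
    tp = sum(1 for t, p in pairs if t == 1 and p == 1)
    fp = sum(1 for t, p in pairs if t == 0 and p == 1)
    tn = sum(1 for t, p in pairs if t == 0 and p == 0)
    fn = sum(1 for t, p in pairs if t == 1 and p == 0)
    return tp, fp, tn, fn


def precision_recall_accuracy(y_true, y_pred):
    tp, fp, tn, fn = confusion(y_true, y_pred)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    accuracy = (tp + tn) / len(y_true)
    return precision, recall, accuracy


def auroc(y_true: Sequence[int], scores: Sequence[float]) -> float:
    """Probability that a random positive is scored above a random negative."""
    pos = [s for s, t in zip(scores, y_true) if t == 1]
    neg = [s for s, t in zip(scores, y_true) if t == 0]
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


def average_precision(y_true, scores) -> float:
    """Step-wise AUPRC estimate: mean precision at the rank of each positive."""
    ranked = sorted(zip(scores, y_true), key=lambda st: -st[0])
    hits, total = 0, 0.0
    for rank, (_, t) in enumerate(ranked, start=1):
        if t == 1:
            hits += 1
            total += hits / rank
    return total / hits if hits else 0.0


def recall_at_precision(y_true, scores, min_precision: float = 0.8) -> float:
    """R@P80: highest recall over thresholds whose precision is at least min_precision."""
    best = 0.0
    for thr in sorted(set(scores)):
        preds = [1 if s >= thr else 0 for s in scores]
        p, r, _ = precision_recall_accuracy(y_true, preds)
        if p >= min_precision:
            best = max(best, r)
    return best
```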
4.2. Patient Knowledge Graph Framework
We comprehensively evaluate and compare the proposed method against six state-of-the-art (SOTA) methods.
Table 1 shows that the PKGNN model achieves better performance than state-of-the-art results, where ensemble learning for a global patient graph with a feature aggregation method improves performance for 30-day hospital readmission prediction and mortality prediction. For the hospital readmission task, the proposed model achieves an AUROC of 0.951 and an AUPRC of 0.754, surpassing all competing models. Likewise, it attains an AUROC of 0.934 and an AUPRC of 0.652 for mortality prediction, consistently outperforming all competing models.
The model was trained to minimise loss using the ensemble learning method with a binary cross-entropy loss function. The random seed is set to 42 to ensure the reproducibility of results. The model was trained for 100 epochs, with logging every 1000 iterations, validation after each epoch, and model checkpoints saved every 10 epochs. To prevent exploding gradients, gradient clipping is applied with a maximum gradient norm of 100. The training batch size is 32. The optimiser is Adam, with a learning rate of 0.01, weight decay of 0.0005, and beta values of 0.9 and 0.999 for the first- and second-moment estimates, respectively. A step learning rate scheduler is employed, which multiplies the learning rate by gamma (1.0) every 100 steps.
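The training settings above can be collected into a small configuration object. The sketch also computes the step schedule in closed form, lr × gamma^⌊step / step_size⌋, which matches the documented behaviour of PyTorch’s `torch.optim.lr_scheduler.StepLR`. With gamma = 1.0, as reported, the learning rate therefore stays constant at 0.01. In PyTorch, these values would typically be passed to `torch.optim.Adam(params, lr=..., betas=..., weight_decay=...)` and `torch.nn.utils.clip_grad_norm_(params, max_norm)`. The `TrainConfig` class itself is hypothetical.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class TrainConfig:
    """Hypothetical container for the reported training hyper-parameters."""
    seed: int = 42
    epochs: int = 100
    batch_size: int = 32
    lr: float = 0.01
    weight_decay: float = 5e-4
    betas: tuple = (0.9, 0.999)
    max_grad_norm: float = 100.0
    step_size: int = 100  # scheduler period, in scheduler steps
    gamma: float = 1.0    # multiplicative decay factor applied every step_size steps


def step_lr(cfg: TrainConfig, step: int) -> float:
    """Learning rate after `step` scheduler steps: lr * gamma ** (step // step_size)."""
    return cfg.lr * cfg.gamma ** (step // cfg.step_size)
```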
The dataset configuration specifies a graph-based dataset stored at the root path and uses a threshold of 0.99 for data processing. The GCN model with ensemble learning has an input feature size of 768, a hidden layer size of 16, and an output size of 2, indicating a binary classification task. The GCN has two layers and includes dropout with a probability of 0.5 to prevent over-fitting.
Figure 7 and Figure 8 show the comparative AUROC plots for mortality prediction and 30-day hospital readmission.
4.3. Ablation Study
Table 2 and Table 3 report ablation tests that analyse the efficacy of the prediction-task and relationship ensemble module and of the global patient graph module. We compare these experiments on the prediction tasks of the MIMIC-IV dataset.
For the hospital readmission task, Set 1 includes 42,671 nodes and 472,435,459 edges, achieving the highest performance with an AUROC of 0.955, an AUPRC of 0.754, and a recall at 80% precision (R@P80) of 0.641. Set 2, with 6,162 nodes and 7,456,808 edges, shows a slight decrease in performance with an AUROC of 0.934, an AUPRC of 0.652, and an R@P80 of 0.575. Set 3, the smallest set, comprises 3,700 nodes and 3,591,982 edges, resulting in an AUROC of 0.903, an AUPRC of 0.604, and an R@P80 of 0.455.
For the mortality prediction task, we again evaluated three sets. Set 1, the largest, includes 42,671 nodes and 472,435,459 edges, achieving an AUROC of 0.934, an AUPRC of 0.652, and an R@P80 of 0.575. Set 2, with 15,292 nodes and 59,643,760 edges, shows an AUROC of 0.917, an AUPRC of 0.544, and an R@P80 of 0.515. Set 3, which contains 6,162 nodes and 7,456,808 edges, has the lowest performance, with an AUROC of 0.899, an AUPRC of 0.541, and an R@P80 of 0.415.
These results indicate that larger sets of nodes and edges generally improve the predictive performance of the proposed models for both hospital readmission and mortality prediction tasks. The marked performance drop in smaller sets highlights the importance of comprehensive data inclusion when constructing patient graphs.
5. Discussion
Deep learning algorithms for analysing raw health data in ICUs have tremendous potential for improving patient outcomes. These advanced methods enable real-time analysis of complex and unstructured data, facilitating the rapid identification of essential patterns, predicting patient deterioration, and supporting clinicians in their decision-making. Thus, we propose an ensemble patient graph framework built on three BERT variants, Clinical BERT, BioBERT, and BlueBERT, which are cutting-edge natural language processing models pre-trained on healthcare-specific datasets. These models provide context-specific word representations from medical notes, enhancing generalisation capability and capturing the long-range dependencies between words, which are crucial in clinical settings.
We developed the PKGNN framework, a promising ensemble GCN-based approach to address the complexity of clinical and biomedical information. The framework provides a structured and meaningful representation of clinical and biomedical data by constructing knowledge graphs and applying an ensemble approach. The ensemble model aims to leverage the strengths of its component models to improve overall performance in predicting patient mortality. The ensemble approach employed in the framework excels at uncovering latent patterns and associations within the data, which can reveal critical insights that might otherwise remain hidden. The performance evaluation on the MIMIC-IV dataset demonstrates that PKGNN outperforms the state-of-the-art baselines across two different tasks: mortality prediction and 30-day hospital readmission prediction.
The current study focuses on mortality prediction and 30-day hospital readmission, but there are other critical clinical outcomes that could benefit from similar predictive modelling. Future work could expand the framework to predict disease progression, treatment response, and other relevant clinical outcomes, providing a more comprehensive approach to patient risk assessment.
As technology and data collection methods in healthcare continue to evolve, ongoing research should also investigate the integration of additional data sources, such as genomic data or real-time sensor data from wearable devices, to further enhance the model’s predictive capabilities.