1. Introduction
Diffuse large B-cell lymphoma (DLBCL) is the most common subtype of lymphoma, where tumors develop from lymphocytes, and comprises approximately one-third of non-Hodgkin’s lymphomas, which account for 90% of all lymphomas [
1]. In addition, DLBCL is more likely to be diagnosed at an advanced stage and in older individuals compared to Hodgkin’s lymphomas [
2]. Accurate prognosis prediction remains challenging regardless of the advances in treatment, with patients exhibiting diverse outcomes even within the same risk group. Despite standard therapy, 30–40% of DLBCL patients eventually relapse or are refractory to the initial immunochemotherapy [
3]. Recently, novel therapeutic agents such as chimeric antigen receptor T cells and bispecific antibodies have been actively investigated for more effective and safer treatments in non-Hodgkin’s lymphoma patients [
4].
Accurate prediction of prognosis and treatment outcomes can guide treatment decisions and enhance clinical trial designs [
5]. Traditional survival analysis aims to identify the key covariates contributing to event occurrences such as death or relapse. The international prognostic index (IPI) is a well-established prognostic tool developed in 1993 [
6] using pretreatment clinical risk factors including age, stage, lactate dehydrogenase (LDH), performance status, and extranodal involvement. However, individual patient’s treatment outcomes and prognoses have been revealed to be heterogeneous even in the same IPI risk group [
7].
The Deauville score (DS) is a strong prognostic factor that is used to interpret F-18 fluorodeoxyglucose (FDG) positron-emission tomography (PET)/computed tomography (CT) imaging. Staging FDG PET can play a critical role in staging and risk stratification as it can identify the disease’s extent and location, including the involvement of extranodal sites and bone marrow. In addition, FDG PET/CT has played a crucial role in the prognostication of DLBCL patients. Interim FDG PET scans performed during treatment are used to assess chemosensitivity and predict prognoses [
8], and the DS has been revealed to be predictive of patient outcomes [
7,
8,
9,
10].
Artificial intelligence has recently emerged as a promising tool to improve prognostic accuracy by leveraging large-scale clinical data and incorporating complex interactions among clinical, molecular, and imaging features. In medicine, the Cox proportional hazard (CPH) model [
11], which is a semiparametric approach for calculating the hazard risk of the occurrence of an event, is the traditional standard method for survival analysis [
11,
12,
13]. The CPH model assumes linearity among covariates, and several learning-based methods have been proposed to find non-linear relationships between various features. Machine learning methods such as random survival forests (RSF) [
14], oblique random survival forests (ORSF) [
15], and hazard boosting [
16] have been successfully implemented in survival analysis. Researchers have employed Bayesian networks with the CPH model to improve its prediction performance and interpretability [
17,
18]. Faraggi et al. [
19] extended the CPH model to include non-linearity using a multi-layer perceptron (MLP); however, although a non-linear approach was employed, this model failed to outperform the CPH model [
20,
21]. Artificial intelligence has allowed deep neural networks to efficiently learn key features from clinical data for survival prediction. Recent deep-learning approaches such as DeepSurv [
22], CoxTime [
23], and CoxCC [
23] could replace the linear predictor with deep feed-forward neural networks to enable rich feature representation. DeepSurv demonstrated better performance than CPH on the concordance index (C-index) to model the interactions among covariates for treatment recommendation. The CoxTime model lifted the proportionality constraint by allowing time-dependent effects. Similarly, CoxCC is a proportional version of the CoxTime model. Introducing non-linearity enables the handling of more complex relationships between the clinical covariates and the survival times.
This study aims to develop a robust survival prediction strategy to integrate various risk factors effectively, including clinical risk factors and the DS, at different treatment stages. We conducted the experiments separately for the covariates obtained before and during treatment. Clinical information such as age can be treated as a continuous value, whereas variables such as stage and the DS are categorical. Despite being categorical, these classes cannot be considered purely independent because their order indicates the disease severity. For example, cancer stages (I, II, III, and IV) are assigned based on severity. Similarly, a DS of 1, 2, or 3 is less severe than a score of 4 or 5. As these values are neither continuous nor purely categorical, we consider them categorical variables and use transformer-based neural networks to capture the inter-class relationship within the categories.
Herein, motivated by the success of transformers in capturing critical features in various areas including natural language processing [
24] and vision-related tasks [
25], we design a transformer-based time-dependent survival model (TTSurv) for predicting the overall survival time. We follow the time-dependent approach of CoxTime and enhance the survival prediction using robust features learned through transformers. We compare the effectiveness of various clinical features obtained at two stages in the treatment process using various survival analysis methods. Based on the features’ time of availability, we divide them into before- and during-treatment groups. We then conduct experiments with the features in each group and evaluate the survival models using the C-index and the mean absolute error (MAE); the key contributions of this work can be summarized as follows:
We design a systematic analysis of the clinical covariates based on their occurrence during various stages of the treatment.
We propose a deep-learning-based method for predicting survival time in patients with lymphoma that leverages categorical embedding to represent the disease severity information in categorical data.
This paper is organized as follows.
Section 2 describes the proposed method for survival prediction and
Section 3 details the dataset, experiment design, and evaluation criteria used in this study.
Section 4 provides the survival analysis results and a comparison with existing survival prediction methods.
Section 5 presents the observations and discusses the further directions and challenges in survival analysis research. Finally,
Section 6 concludes this paper.
3. Experiments
3.1. Datasets
We conducted experiments on two clinical datasets collected at Chonnam National University and Hwasun Hospital (CNUHH,
n = 604) and Jeonbuk National University Hospital (JBUH,
n = 220) in 2011–2018. This study was approved by the Institutional Review Boards of CNUHH (CNUHH-2022-095) and JBUH (CUH 2022-11-013). The log-rank test of the patient covariates in the datasets was statistically significant (
p < 0.005), suggesting the importance of the individual features.
Figure 2 depicts the Kaplan–Meier plots of patient properties.
The CNUHH and JBUH datasets have similar percentages of censored cases (70.86% and 73.18%, respectively). The clinical data consist of clinical information including patient age, sex, performance score, lactose dehydrogenase (LDH) level, stage, number of extranodal sites, presence of B-symptoms, and the IPI score. Additionally, for on-treatment evaluation, we included the DS calculated by experienced nuclear medicine physicians at CNUHH and JBUH through observations of interim PET scans.
The IPI prognostic tool was developed in 1993 using five significant risk factors (age, stage, LDH, performance status, and extranodal involvement).
Table 1 shows the dataset’s characteristics before and during treatment. Only age and LDH (IU/L) in the dataset were used as continuous variables; all other covariates were considered categorical. The age distribution of the patients included in this study exhibits notable variance between the CNUHH and JBUH datasets. The former is composed of individuals aged 36–81 years, while the latter encompasses a wider age range of 15–87 years. This discrepancy can be attributed to the distinct patient populations treated at each institution during the period under investigation.
Table 1 presents the age, sex, LDH, performance score, number of extranodal sites involved, bone marrow involvement, B symptom, Ann Arbor stage, the IPI score, and the DS before and during treatment.
3.2. Experiment Design
We conducted survival analysis experiments separately for pretreatment and on-treatment analysis based on the clinical information available before and during treatment. Only the DS was included as an additional feature for on-treatment analysis, which allowed us to compare the impact of the DS on survival prediction. We conducted experiments using the proposed method and compared it with existing state-of-the-art deep-learning survival models: DeepSurv, CoxCC, and CoxTime.
We evaluated the proposed model with five-fold cross-validation. To evaluate using cross-institutional data, we used clinical data from CNUHH and JBUH. First, the CNUHH dataset was split into training and validation sets using stratified k-fold sampling. The same set of training and validation data was used for different models to ensure a fair comparison. After the models were trained on the CNUHH dataset, we evaluated the trained models on the JBUH dataset, which was kept separate from the training process. The best weights saved at each fold were used for inference with the testing set and the average of each fold was reported. We implemented the proposed model in Python using PyTorch, and we used the implementation from the pycox library available on GitHub as the baseline model. The models were trained using an Nvidia GeForce RTX 3080Ti GPU with 12 GB of memory.
3.3. Evaluation Metrics
We evaluated the survival models using two performance metrics, namely the C-index and the MAE.
3.3.1. The C-Index
The C-index is the most common evaluation method for survival analysis and is a measure of ranking for the predicted time. It estimates the probability that the predicted times for individuals and their true survival times have the same order, and is calculated as follows:
where
represents the risk score of a unit. In addition,
is 1 when
and 0 otherwise, and
is 1 when
and 0 otherwise. A C-index of 1.0 indicates perfect concordance and 0.5 represents poor prediction.
3.3.2. The MAE
Although the C-index measures the accuracy of ranking the survival times, it may not provide a fair assessment of a model’s overall performance. For example, the C-index does not consider the magnitude of the difference between predicted and actual survival times. We address this limitation by using the MAE as an additional evaluation metric for survival time prediction models. The MAE involves converting predicted hazards into survival functions and calculating the average difference between the predicted residual life and the true survival time. However, it cannot be used for all samples due to the presence of censored data. We evaluated the MAE for patients with observed events based on each patient’s median life. The MAE, as a complementary metric, provides additional comparison criteria when true survival times are known, and is calculated as follows:
where
represents the number of samples that observed the event and
and
represent the true and predicted survival times, respectively, for the
sample.
4. Results
We compared the performance of the proposed model with DeepSurv, CoxTime, and CoxCC; similar performance was observed for all models.
Table 2 and
Table 3 present the experimental results on the CNUHH and JBUH datasets. The CNUHH column lists the average values of five-fold cross-validation. The model saved at each fold was used to evaluate the unseen data from the JBUH dataset. The proposed model outperformed the existing best-performing model on the CNUHH dataset while achieving a comparable C-index with existing survival analysis methods on the JBUH dataset. We evaluated the models using the C-index and the MAE.
The proposed model was found to outperform existing models in terms of the MAE. Moreover, it performed better on unseen data, which suggests that the transformer architecture used in this model was effective at extracting robust features and enabling generalizability. Additionally, we observed that the models exhibited even higher performance when an additional feature obtained during treatment was included. This finding highlights the high prognostic importance of the DS in patients with DLBCL.
Figure 3 presents the survival curves of the test set as obtained using various survival models. The estimated survival function was similar for various methods.
Figure 4 provides the estimated survival plots for five patients with survival times in the range of 321–2127 days;
Table 4 lists corresponding survival time predictions. The results revealed that the model could estimate the survival times with little error in terms of days. However, Patient JBUH_DLB106 had a low predicted survival time despite their actual survival time being similar to JBUH_DLB029. This outcome could be a special case where the patient lived longer despite having severe symptoms. Patient JBUH_DLB106 was at stage III with a low LDH of 756 and a Deauville score of 3.
As indicated in
Figure 5, we compared features used both before and during treatment analysis for all uncensored patients in the test set. The estimated survival times were generally close to the ground-truth values. However, the survival times were poorly estimated for some cases, as mentioned in
Table 4. In summary, we demonstrated that TTSurv could estimate the survival times with a relatively small MAE of approximately 559 days.
5. Discussion
The proposed model, TTSurv, outperformed the existing state-of-the-art survival prediction model for patients with DLBCL using transformer-based deep-learning models regarding the C-index and the MAE in the dataset. Therefore, we have demonstrated the potential of deep-learning models to reliably predict survival times based on clinical features and the DS. In addition, we have illustrated the importance of the DS obtained during treatment, which significantly improved the model performance, indicating this feature’s high prognostic value. We conducted survival analyses based on two stages: before and during treatment. Although most prognostic clinical features were available at the beginning of treatment, the DS was only available after the interim-PET scan. Our results add to the growing body of evidence supporting the high prognostic value of the DS [
7,
8,
9,
10].
Deep-learning models have demonstrated usefulness in interpreting clinical data and providing better prognosis predictions with less manual feature engineering. Manual feature selection is a traditional method of analyzing clinical data that may not capture all pertinent information. The feature selection process requires expert knowledge and may result in a limited view of the data, resulting in incomplete or inaccurate conclusions. In addition, clinical data often contain categorical features that are not purely independent classes; these features may have complex interrelationships with other variables and their treatment as independent classes may result in the loss of valuable prognostic information. Therefore, a more sophisticated approach is required to ensure that all relevant information is considered and that categorical features are appropriately processed to provide more accurate predictions. We used transformer-based categorical data encoding on clinical datasets to address this problem and developed a deep-learning network for survival prediction.
Moreover, clinical data features are typically grouped into numeric and categorical data. However, the categories featured in clinical data do not have purely independent classes. For instance, patients with cancer are usually categorized into four stages upon diagnosis: I, II, III, and IV. As such, data are not continuous and they are often treated as categorical even though they carry information related to the disease severity on an ordinal scale, and this may result in a loss of valuable information. As transformers have been widely accepted in various domains, including natural language processing and vision-related tasks, we adopted transformers to encode categorical features using transformers. TabTransformer has demonstrated high performance in handling tabular data. Therefore, we adopted transformer-based categorical data encoding in clinical datasets and developed a deep learning network for survival analysis.
Survival analysis is considered more challenging than standard regression tasks due to the presence of data censoring. Censoring occurs when the event of interest is not observed during the study period for various reasons, such as subjects leaving before this study is complete or this study finishing before the event of interest occurs. For example, in survival analysis where the event of interest is death, some patients who survive until the end of the study duration may opt out or move to a different hospital during this study. Survival models manage information censoring by including an event indicator, which is a binary variable that indicates whether the event occurred during the study period. However, data censoring has the potential to reduce model accuracy. In future research, we plan to use larger clinical datasets that include numerous non-censored data to increase our model’s performance.
While the proposed transformer-based survival prediction model has shown promising results in predicting patient outcomes, it is important to acknowledge its limitations regarding its practicability in clinical contexts. One major limitation is that the model relies solely on clinical information and the DS calculated by the experts, which may not capture all relevant clinical information such as imaging information in PET/CT. In our future studies, we plan to overcome this limitation by incorporating radiological images to improve the model accuracy. Additionally, we aim to validate the model on a larger and more diverse patient population to ensure its generalizability. Despite these limitations, our transformer-based model represents a significant step forward in the field of predictive analytics in healthcare and holds great potential for improving patient outcomes.