To test whether our proposed VBA framework can effectively address the impact of distributional shifts caused by confounders on model generalization, we conducted experiments in this section on simulated, semi-synthetic, and real-world datasets, respectively. We compare our proposed VBA framework with Empirical Risk Minimization (ERM), Distributionally Robust Optimization (DRO) [27], Kernelized Heterogeneous Risk Minimization (KerHRM) [10], Environment Inference for Invariant Learning (EIIL) [30], and Invariant Risk Minimization (IRM) [1]. IRM, representing invariant causal learning approaches, requires environment labels; we therefore provided environment labels to IRM during training, while the other algorithms did not require them. We implemented the VBA framework and the baselines with PyTorch 2.0.1 on a computer with an NVIDIA RTX 3070 Ti. The implementation of the VBA framework (accessed on 19 December 2023) can be downloaded at https://github.com/hangsuuuu/VBA. The parameter settings of the VBA framework for each dataset are shown in Table 1; the settings for the other methods are described in Appendix E.
4.1. Linear Simulated Data
In this subsection, we assess the performance of the VBA framework on linear simulated data. We employed three different data generation methods; detailed descriptions of the data generation procedures can be found in Appendix D.
To test the performance of the VBA framework under different conditions, we designed three different OOD cases. The causal graphs for the three cases are illustrated in Figure 4. In each case, the label Y is generated by a linear combination of the input X and the confounder C. From Figure 4 and Table A1, it can be observed that C is a common unobservable cause of both X and Y. Hence, C serves as a confounder for X and Y. By altering the distribution of C, we can change the joint distribution of the data, thereby simulating an OOD dataset.
We set continuous values for y to test the model’s regression ability. In each case, we generated 800 data points as the training set and 200 data points as the test set. To test the model’s performance under a distributional shift in the confounder, we use different confounder distributions for the training and test sets. In Case 1 and Case 2, we simulated the distributional shift by changing the parameters of the confounder distribution given in Table A1, choosing one set of parameter values for the training set and a different set for the test set. In Case 3, to evaluate the performance of the VBA framework when the prior of the confounder is not a Gaussian distribution, we generated C using a uniform distribution and simulated the distributional shift by altering the range of values of the uniform distribution between the training and test sets.
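The data generation scheme above can be sketched as follows; the linear coefficients, noise scales, and confounder parameters below are illustrative placeholders, not the actual settings of Table A1:

```python
import numpy as np

def generate_linear_ood(n, c_mean, c_std, rng):
    """Sketch: sample an unobserved confounder C, then generate the
    observed input X and label Y as linear functions of C plus noise.
    The coefficients (1.5, 2.0, 0.8) are illustrative only."""
    c = rng.normal(c_mean, c_std, size=n)                  # confounder C
    x = 1.5 * c + rng.normal(0.0, 0.1, size=n)             # X depends on C
    y = 2.0 * x + 0.8 * c + rng.normal(0.0, 0.1, size=n)   # Y depends on X and C
    return x, y, c

rng = np.random.default_rng(0)
# shifting the confounder's distribution between splits creates the OOD setting
x_tr, y_tr, _ = generate_linear_ood(800, c_mean=0.0, c_std=1.0, rng=rng)  # training
x_te, y_te, _ = generate_linear_ood(200, c_mean=3.0, c_std=1.0, rng=rng)  # shifted test
```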
The experimental results are reported in
Table 2. Each experimental result is the average obtained after running the experiments 10 times.
In
Table 2, we evaluate the performance using Mean Squared Error (MSE) and the variance of the MSE. The results indicate that all methods perform suboptimally in Case 3. On the one hand, this is due to using a uniform distribution to generate confounders in Case 3. On the other hand, the distributional shift is most pronounced in Case 3. Overall, ERM performs the worst as it has the highest MSE and variance. DRO, KerHRM, and EIIL show varying performances across the three cases, but, overall, KerHRM exhibits the best performance among these three environment inference algorithms. IRM performs significantly better than DRO, KerHRM, and EIIL, as it utilizes environment labels directly during optimization. In contrast, the performance of the latter three algorithms depends on the effectiveness of environment inference. Our proposed VBA framework performs better than IRM in most cases (Case 1 and Case 2), even without environment labels. However, the MSE of the VBA framework in Case 3 is significantly higher than in Case 1 and Case 2. This is because the distribution of the confounder in the test set differs significantly from a Gaussian distribution, resulting in the model having difficulty accurately inferring the confounder. Although the performance of the VBA framework decreases in Case 3, its MSE remains close to mainstream methods that do not require environment labels (such as KerHRM), and there is no significant increase in variance.
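The evaluation protocol used for Table 2 (mean MSE and its variance over 10 repeated runs) can be sketched as follows; `one_run` is a hypothetical stand-in for a full train-and-evaluate pipeline:

```python
import numpy as np

def mse(y_true, y_pred):
    return float(np.mean((y_true - y_pred) ** 2))

def repeated_mse(run_fn, n_runs=10):
    """Run the pipeline n_runs times with different seeds and report
    the mean MSE and the variance of the MSE across runs."""
    scores = [run_fn(seed) for seed in range(n_runs)]
    return float(np.mean(scores)), float(np.var(scores))

def one_run(seed):
    # toy stand-in: a near-perfect model on 200 test points
    rng = np.random.default_rng(seed)
    y_true = rng.normal(size=200)
    y_pred = y_true + rng.normal(scale=0.1, size=200)
    return mse(y_true, y_pred)

mean_mse, var_mse = repeated_mse(one_run)
```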
Table 3 presents each method’s time consumption, including the convergence time during the training phase and the inference time during the prediction phase, where the inference time is the time required for the model to predict the 200 data points of the test set. Because the ERM model has the lowest computational and parameter complexity, it exhibits the fastest inference speed. However, ERM has a longer convergence time because it needs a large number of iterations to converge on OOD datasets, especially those with complex variable relationships, as in Case 3. Since EIIL performs only environment inference without outputting predicted values for y, for comparison with the other models we use the environment classification output by EIIL as the environment label for IRM and then employ IRM to obtain EIIL’s final prediction results. Therefore, to ensure fairness, the recorded convergence time for EIIL is the total convergence time of both its environment inference module and the IRM inference module, while the recorded inference time is the time required for EIIL’s environment inference module to perform environment classification. Consequently, EIIL has the longest convergence time among all methods. DRO and KerHRM do not separate the training of their environment inference and prediction processes; therefore, their convergence times are shorter than that of EIIL. Although the VBA framework comprises three modules (encoding, sampling, and inference), the simultaneous training of these modules and the backdoor adjustment allow the model to converge after fewer iterations. Consequently, the training time of VBA is comparable to that of IRM, and in some complex scenarios it even exhibits the shortest convergence time. However, the inference of VBA is slower than that of IRM and ERM, because the input features must pass through multiple modules to obtain the predicted values.
Since the VBA framework uses the input X and the estimated confounder C to predict the values of Y, the ability to generate correct confounders is a crucial metric for assessing the interpretability of the VBA framework. To evaluate whether the VBA framework can accurately estimate the confounders, Table 4 presents the mean values of the estimated confounders obtained by the sampler of the VBA framework. Since different distributions are used to generate the confounders in different environments, we can evaluate the VBA framework’s ability to estimate the confounders by examining the mean values of the estimated confounders in each environment. Table 4 shows that the mean values of the confounders estimated by the VBA framework are very close to the true values, demonstrating that the VBA framework provides nearly unbiased estimates of the confounders. Accurate estimation of the confounders is not only a necessary condition for OOD prediction but also indicative of the interpretability of the VBA framework.
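The comparison underlying Table 4 amounts to checking, per environment, the difference between the mean estimated confounder and the mean true confounder. A minimal sketch with illustrative toy values:

```python
import numpy as np

def confounder_mean_bias(true_c_by_env, est_c_by_env):
    """Per-environment difference between the mean estimated confounder
    and the mean true confounder; values near zero correspond to a
    nearly unbiased estimate."""
    return {env: float(np.mean(est_c_by_env[env])) - float(np.mean(true_c_by_env[env]))
            for env in true_c_by_env}

# toy values (hypothetical, not the paper's measurements)
bias = confounder_mean_bias(
    {"train": np.array([0.0, 2.0]), "test": np.array([3.0, 5.0])},
    {"train": np.array([0.9, 1.2]), "test": np.array([3.9, 4.0])},
)
```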
4.2. Colored MNIST
To further assess the performance of our method on high-dimensional data, we utilized Colored MNIST [1], a semi-synthetic dataset, in this experiment. Colored MNIST is derived from MNIST and is designed for binary classification. In Colored MNIST, the handwritten digits 0–4 are labeled as one class and the digits 5–9 as the other. To simulate distributional shifts in the data, digits of one class were colored green with an environment-dependent probability, digits of the other class were colored green with the complementary probability, and the remaining uncolored digits were colored red. We set different values of this coloring probability for the training and test sets. The data for this example can be generated using the code available at https://github.com/facebookresearch/InvariantRiskMinimization (accessed on 29 September 2020). For the IRM method, we divided the environments based on different values of the coloring probability and enabled IRM to utilize the environment labels. For the other methods, we mixed the generated training data from different environments and used them as the input to the models.
In Colored MNIST, the original pixel values and the arrangement of pixels in the images constitute the model’s raw input X. The true numerical value of the handwritten digit serves as an unobservable confounder that influences both X and Y. The proposed VBA framework aims to eliminate the impact of environmental factors (digit color) on model training through backdoor adjustment.
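A minimal Monte Carlo sketch of the backdoor adjustment idea, P(Y | do(X)) = E_{c ~ p(C)}[P(Y | X, c)]: average the prediction over confounder samples drawn from the prior rather than conditioning on a confounder inferred from X. The `predictor` and `sample_confounder` callables are hypothetical stand-ins for the VBA inference and sampling modules:

```python
import numpy as np

def backdoor_predict(x, predictor, sample_confounder, n_samples=100, rng=None):
    """Estimate the interventional prediction for input x by averaging
    predictor(x, c) over confounder samples c drawn from the prior."""
    rng = rng or np.random.default_rng(0)
    preds = [predictor(x, sample_confounder(rng)) for _ in range(n_samples)]
    return float(np.mean(preds))

# toy usage: y = 2x + c with c ~ N(0, 1); averaging over the prior of c
# removes the confounder's contribution, leaving approximately 2x
y_do = backdoor_predict(3.0,
                        predictor=lambda x, c: 2.0 * x + c,
                        sample_confounder=lambda rng: rng.normal())
```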
Table 5 presents the experimental results of VBA and the other methods, including the accuracy and time consumption of all methods. As in the previous experiment, each result is the average obtained after running the experiments 10 times. Inference time refers to the time required for the model to predict the 10,000 images of the test set. We employed three metrics to evaluate the methods’ performance: training accuracy, testing accuracy, and generalization gap (the difference between training and testing accuracy). To assess the impact of the number of environments in the training set on the methods’ performance, we tested each method’s performance as the number of environments varied from two to seven, with the coloring probability taking a different value in each training environment. The testing accuracy of each method with respect to the number of training environments is shown in Figure 5.
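The generalization gap metric can be computed as the drop from training accuracy to testing accuracy; a minimal sketch:

```python
import numpy as np

def accuracy(y_true, y_pred):
    return float(np.mean(np.asarray(y_true) == np.asarray(y_pred)))

def generalization_gap(train_acc, test_acc):
    # gap = training accuracy minus testing accuracy; a small gap means
    # performance transfers to the shifted test distribution
    return train_acc - test_acc

# toy labels: a perfect training fit and one test error out of four
gap = generalization_gap(accuracy([0, 1, 1, 0], [0, 1, 1, 0]),
                         accuracy([0, 1, 1, 0], [0, 1, 0, 0]))
```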
By benefiting from environment labels, IRM achieves decent testing accuracy. However, in scenarios with fewer environments, the performance of IRM significantly lags behind the VBA framework, indicating that IRM’s performance is highly dependent on the heterogeneity of the environments. KerHRM and EIIL have similar test accuracies in various situations because both methods are based on environment inference and obtain similar environment partitions. However, KerHRM and EIIL fall significantly behind the VBA framework in terms of test accuracy and generalization gap. ERM and DRO obtain test accuracies close to random guessing (50% for this binary task), and, as the number of environments increases, the performance of DRO even falls behind that of ERM. As the number of environments increases, the distributional shift in the training data weakens; this helps ERM learn the relationship between features and labels but does not significantly assist the robust optimization algorithm DRO. Overall, VBA exhibits the best testing performance in terms of testing accuracy and generalization gap. The VBA framework maintains high prediction accuracy even when the number of training environments is only two, whereas the other methods, due to their high requirements for environmental heterogeneity, exhibit much lower prediction accuracies. This indicates that the VBA framework has a more significant advantage when the training set exhibits low environmental heterogeneity.
Regarding time consumption, the VBA framework benefits from the rapid convergence facilitated by the variational backdoor adjustment intervention, resulting in the least convergence time. However, the inference speed of the VBA framework ranks only at a moderate level among all of the methods. This is attributed to its complex inference process. DRO struggles to converge on the Colored MNIST dataset, with its training iterations far exceeding those of other methods. Consequently, the convergence time of DRO is significantly longer than those of other methods. This indicates that robust optimization methods struggle to converge stably to solutions with robustness.
4.3. Real-World Data
In this subsection, we utilize two real-world datasets to assess the practical applicability of the VBA framework in real scenarios. The two datasets are Non-I.I.D. Image Dataset with Contexts (NICO) [
47] and house sales prices from King County, USA [
10], respectively. We evaluate the methods’ classification and regression capabilities in real-world scenarios through these two datasets.
NICO contains wildlife and vehicle images captured in different environments, where the environments are divided based on the collection context of the images. This dataset is available at https://nico.thumedialab.com (accessed on 18 April 2022). In this experiment, we use NICO to construct a binary classification dataset containing images of cows and bears from three environments (forest, river, and snowfield). The different collection environments lead to a distributional shift in the data. Since our goal is to classify the species of the animals, the label Y represents the species, and the confounder C may include high-level features such as the outline and color of the animals, as well as the background color. We choose data from the forest and river environments to form the training set, while data from the snowfield environment serve as the test set.
The house sales prices dataset comprises sale prices and 17 different house attributes, such as the number of rooms, the built year, etc. This dataset is available at https://www.kaggle.com/c/house-prices-advanced-regression-techniques/data (accessed on 11 January 2020) and contains data for a total of 3000 houses. We use the house sale prices as the target variable Y, with the other information about the houses serving as the predictive variables X. The relationship between the predictive variables and house prices may vary with the construction year of the houses, because the criteria for assessing the value of a house may change over different periods. Thus, the dataset experiences a distributional shift as the construction year changes, and we can partition the data into different environments based on the built year to construct an OOD dataset. The built years of houses in the dataset range from 1872 to 2010; however, due to the scarcity of houses built between 1872 and 1900, we only select data with built years between 1900 and 2010. We divide the dataset into five periods, where the first three decades serve as the training set and each subsequent period spans two decades. We test each method on the subsequent four periods and display the MSE of each method in Figure 6. Since training IRM requires environment labels, we divide the training data into two environments based on the built year and use these as the training environments for IRM.
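The period split described above can be sketched as follows; the exact boundary years are assumptions consistent with the 1900–2010 range and the three-decade training window:

```python
def split_by_built_year(records):
    """Sketch of the environment split: houses built in the first three
    decades form the training set; each later two-decade span is one
    test environment. `records` is a list of (built_year, features, price)."""
    periods = [(1900, 1929), (1930, 1949), (1950, 1969),
               (1970, 1989), (1990, 2010)]  # assumed boundaries
    envs = {p: [] for p in periods}
    for rec in records:
        year = rec[0]
        for lo, hi in periods:
            if lo <= year <= hi:
                envs[(lo, hi)].append(rec)
                break  # years outside 1900-2010 are dropped
    train = envs[periods[0]]
    tests = [envs[p] for p in periods[1:]]
    return train, tests

train, tests = split_by_built_year(
    [(1905, None, 100), (1940, None, 200), (1995, None, 300), (1880, None, 50)]
)
```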
The results of the experiment on NICO are provided in
Table 6. Inference time refers to the time required for the model to predict the 250 images of the test set. The testing accuracy of ERM and DRO is very low (close to chance level), indicating that the dataset exhibits a distributional shift. KerHRM and EIIL still exhibit similar performance in terms of both testing accuracy and generalization gap. This indicates that environment-based invariant learning methods have limited capability in handling complex data, as their testing accuracy is only slightly higher than that of ERM. The testing accuracy of IRM shows a significant improvement over ERM, indicating that environment labels provide valuable information for addressing out-of-distribution generalization problems. The proposed VBA framework outperforms all compared methods, exhibiting the best testing accuracy and the lowest generalization gap. This indicates that our backdoor adjustment method can effectively address complex OOD generalization problems.
The VBA framework exhibits the shortest convergence time among all methods, indicating its ability to rapidly converge to optimal solutions even on complex datasets. This is attributed to the VBA framework having three modules with different functionalities (encoding, sampling, and inference), each of which can rapidly accomplish its specific task during training. Despite incorporating environment labels, IRM experiences a longer convergence time than VBA. This is attributed to the challenge faced by IRM’s single model structure in effectively capturing the complex variable relationships present in the OOD dataset. While KerHRM also adopts a training approach with multiple modules, it initially requires clustering to partition data environments, a process that typically converges slowly on complex datasets. Consequently, KerHRM has a longer convergence time than VBA. Due to the complexity of the model parameters, the VBA framework requires a longer inference time. However, even on complex datasets, the inference time of the VBA framework is only slightly longer than those of DRO and EIIL and shorter than that of KerHRM.
Figure 6 shows that all methods perform well on the test sets close to the training environment, while the MSE increases significantly on the test sets far from the training environment. This aligns with our intuition, as greater time intervals make the distributional shift more pronounced. The MSEs of ERM and DRO increase sharply, and the MSE of DRO is even higher than that of ERM on the test sets, indicating that DRO cannot achieve OOD generalization on this real-world dataset. IRM shows slower MSE growth than KerHRM and EIIL, and its MSE is lower than theirs across all test sets, suggesting that environment labels are beneficial for enhancing OOD generalization. The MSE curve of the VBA framework exhibits a very gradual increase, and, in E4 and E5, its MSE is even lower than that of IRM. This indicates that the VBA framework maintains excellent OOD generalization even under strong distributional shifts.
In this subsection, we evaluated the performance of the VBA framework in two real-world scenarios. The datasets include both high-dimensional image data and low-dimensional tabular data, so the experiments in this subsection validate the practical applicability of the method. The experimental results demonstrate that the VBA framework performs excellently in various scenarios, even when significant distributional shifts occur in the data, indicating that the VBA framework has excellent applicability in real-world scenarios.
In the experimental section, we combined several mainstream backbone networks, such as MLP, CNN, and ResNet, with the VBA framework and achieved good performance. Since the VBA framework does not specify the type of network used, it can theoretically be combined with any backbone network. Therefore, when applying the VBA framework in practice, we can choose an appropriate backbone network based on the data type, enabling the VBA framework to apply to a wide range of data types.