1. Introduction
In recent years, deep learning has contributed to remarkable progress in the field of medical image segmentation, offering substantial support for diagnosis and treatment. However, while existing image segmentation networks offer precise results, their structures are intricate, and their computational complexity is high. This poses challenges for practical applications, particularly when computational resources are constrained. It is thus necessary to find a balance between segmentation accuracy and computational efficiency.
Since interest in convolutional neural networks (CNNs) first surged, fundamental network architectures such as LeNet [1], AlexNet [2], VGGNet [3], GoogLeNet [4], and ResNet [5] have been introduced. CNNs ushered in a revolutionary era in medical image processing and analysis, primarily owing to their ability to handle high-dimensional data and their outstanding performance in image recognition through hierarchical learning [6]. In the domain of medical image segmentation, the U-Net [7] network model, which combines encoder and decoder structures and employs skip connections to merge low-level and high-level features, holds a prominent position; this approach excels at preserving intricate details. V-Net [8], an offspring of U-Net, closely resembles its parent architecture, but the addition of residual operations and 3D convolutional kernels equips it for 3D target segmentation. Attention U-Net [9], another CNN built upon the U-Net architecture, introduced the noteworthy innovation of the Attention Gate (AG) module. This module leverages soft attention and seamlessly integrates attention mechanisms into the skip connections and up-sampling modules of U-Net, thereby achieving spatial attention refinement. The improved U-Net (mU-Net) [10] proposed by Seo et al. emphasizes the critical role of local features in CT image segmentation of the liver and liver tumors by incorporating object-dependent high-level features. This approach effectively enhances the extraction of local features, increases sensitivity to subtle changes, and significantly boosts segmentation performance, further demonstrating the necessity of leveraging local information in medical image analysis.
The Transformer architecture [11] has had a profound impact on the field of natural language processing. This model leverages self-attention mechanisms to capture long-term dependencies in sequential data, and Transformers excel in tasks such as machine translation, language generation, and text categorization. The success of this architecture has given rise to models such as BERT, GPT, and T5 [12,13,14,15]. Vision Transformer (ViT) [16] demonstrated that the Transformer architecture is applicable not only to natural language processing but also to computer vision. In March 2021, Microsoft Research Asia introduced the Swin Transformer, which employs sliding windows and a hierarchical structure. The Swin Transformer became a backbone for machine vision with its core design element, the “shifted window” (a shift of the window partition) executed between two consecutive self-attention layers. This shift operation facilitates interactions between previously independent windows, significantly enhancing the model’s ability to capture complex relationships [17]. SwinUNet [18] then leveraged the multi-scale features of the Swin Transformer to achieve superior segmentation performance while retaining the advantages of the U-Net model, such as skip connections and upsampling modules.
Each architecture has its own strengths and limitations. The Transformer excels at capturing global information and modeling long-range dependencies through its self-attention mechanism; however, this reliance on self-attention comes at the cost of higher computational demands, which limits its applicability to large image sizes or real-time applications. CNNs are more efficient but struggle to accurately model global information, potentially degrading segmentation quality [18,19,20]. Thus, many researchers have opted to combine CNNs with the Transformer architecture, most notably in TransUNet [19] and UTNet [20]. The former encodes labeled image patches from CNN feature maps into an input sequence, extracting global context, and then feeds it into the Transformer; the latter directly integrates self-attention into a CNN to enhance segmentation. While these methods yield accurate results, they still feature a substantial number of parameters and heavy computational demands, thereby limiting their applicability.
This paper leverages the advantages of previous research to propose an innovative deep learning network architecture. Our primary objectives are to reduce network complexity, enhance feature representation, and maintain precision. We first reduce the computational complexity of the network by replacing the traditional attention mechanism with a lightweight shift operation within the ViT architecture. This shift operation retains the spatial structure of image information to a considerable extent while simultaneously reducing the network’s parameter count and computational load. Second, to maximize the utilization of feature information across different scales, we introduce the strategy of full-scale progressive skip connections, effectively fusing multi-scale features. This architecture, while introducing additional contextual information, brings about higher computational complexity. To mitigate this, we also introduce depthwise separable convolution. Experiments conducted on multiple datasets demonstrated the efficacy of the proposed model. Compared with the traditional TransUNet, the optimized model is superior in terms of medical image segmentation tasks and offers markedly lower computational complexity and a reduced parameter count. This network is thus more suitable for resource-constrained scenarios such as mobile devices or edge computing environments. This represents a valuable practical contribution to the field of medical image segmentation.
The remainder of this paper is structured as follows: Section 2 provides an overview of related work, while Section 3 offers a brief introduction to the proposed model. Section 4 delves into network details, and Section 5 outlines the implementation and presents experimental results. The paper is summarized and concluded in Section 6.
4. Compact Deep Learning Model Using ShiftViT Framework and Optimized Skip Connections
This paper proposes an innovative network model based on the TransUNet framework and the pre-trained ResNet50 feature extraction network. In this section, we outline three key structures: ShiftViT, full-scale progressive skip connections, and depthwise separable convolution. The overall structure of the model is explained in Section 4.4.
4.1. Simplification of ViT Structure: Introducing ShiftViT
Suppose there is an input sequence X containing N elements, each with embedding dimension d; the attention weight matrix A computed by the self-attention mechanism then has size N × N, and the complexity of the structure can be expressed as O(N² × H × d). Here, N² reflects the fact that each element must compute attention weights with the other N − 1 elements; H denotes the number of attention heads (i.e., the number of sets of attention weights computed for each element); and d denotes the embedding dimension, which represents the feature dimension of each element. This complexity represents the amount of computation required by the self-attention mechanism to compute the attention weights. For large input sequences, the N² factor causes a rapid increase in computational cost, and a large number of attention heads H increases it further. The overall complexity can thus be written as:

$$\Omega(\mathrm{MSA}) = \mathcal{O}(N^{2} \times H \times d)$$
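For instance, a 224 × 224 input divided into 16 × 16 patches yields N = (224/16)² = 196 tokens; doubling the image side length quadruples N and therefore increases the attention cost sixteen-fold.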
The computational complexity of the multi-head self-attention mechanism in the encoder structure of the ViT (see Figure 1a) is therefore high, and simplifying this module is key to obtaining a lightweight network structure.
Wang et al. [27] pointed out that it is not the attention mechanism itself but the overall framework structure that allows such networks to achieve precise segmentation. Replacing this mechanism with a parameter-free shift operation (it extracts features without learning any weights or biases) therefore simplifies the network while preserving, and even improving, segmentation precision. The improved ShiftBlock structure is shown in Figure 1b and defined as follows:
$$F(x) = s(x) + M\big(N(s(x))\big)$$

where x is the input to the structure, F(x) is the output of the structure, s(x) is the shift operation, N(·) represents the layer normalization operation, and M(·) is the multilayer perceptron (MLP) block.
The shift operation is a local pixel translation that captures the contextual information of an image by establishing relationships between neighboring pixels. The importance of local features in image segmentation tasks has been emphasized in previous studies, which demonstrated that the effective utilization of local information leads to improved segmentation performance [10]. By enhancing the retention of local features, this operation improves the model’s sensitivity to subtle changes and boundaries. Additionally, the shift operation significantly lowers computational complexity, simplifies the model structure to reduce the risk of overfitting, and improves the flow of information between different sub-regions. Its procedure is as follows: (1) select a region of the input feature map, (2) divide it into four equal parts along the channel dimension, and (3) translate these four parts in the left, right, up, and down directions, respectively, while keeping the remaining channels unchanged. After shifting, out-of-range pixels are discarded, and vacated pixels are filled with zeros. With a step size of 1 pixel, the shift operation is defined as follows:
$$y_{b,i,j,c}=\begin{cases} x_{b,i,j+1,c}, & 0 \le c < C/4 \\ x_{b,i,j-1,c}, & C/4 \le c < C/2 \\ x_{b,i+1,j,c}, & C/2 \le c < 3C/4 \\ x_{b,i-1,j,c}, & 3C/4 \le c < C \end{cases}$$

where x represents the input feature, y represents the output feature, b is the batch index, i and j are the row and column indices of the feature map, c is the channel index, and C is the number of channels.
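As a concrete illustration, the following PyTorch sketch implements the shift just defined. The function name shift_features and the step-size handling are our own assumptions; this is a minimal sketch of the four-way channel-group shift with zero filling, not the authors’ released code.

```python
import torch

def shift_features(x: torch.Tensor, step: int = 1) -> torch.Tensor:
    """Shift four equal channel groups of x (B, C, H, W) left/right/up/down.

    Out-of-range pixels are discarded and vacated positions are zero-filled,
    matching the procedure described above.
    """
    out = torch.zeros_like(x)
    c = x.size(1)
    g = c // 4  # size of each channel group
    # Group 0: shift left along the width axis.
    out[:, :g, :, :-step] = x[:, :g, :, step:]
    # Group 1: shift right.
    out[:, g:2*g, :, step:] = x[:, g:2*g, :, :-step]
    # Group 2: shift up along the height axis.
    out[:, 2*g:3*g, :-step, :] = x[:, 2*g:3*g, step:, :]
    # Group 3: shift down.
    out[:, 3*g:4*g, step:, :] = x[:, 3*g:4*g, :-step, :]
    # Any channels beyond the four groups (C not divisible by 4) stay unchanged.
    out[:, 4*g:] = x[:, 4*g:]
    return out
```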
The ShiftBlock structure is composed of three parts: the shift operation, layer normalization, and the MLP block, as shown in Figure 1b. Following feature extraction by the convolutional layers, the feature map is cut into fixed-size patches, and each patch is flattened into a one-dimensional tensor. A linear transformation then maps each patch tensor to the specified embedding dimension (embed_dim), and each tensor is input to the ShiftBlock structure for shifting. The output of the final MLP is the final output of the ShiftViT module.
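A minimal sketch of the ShiftBlock itself, reusing the shift_features helper above; the MLP expansion ratio of 4 and the GELU activation are assumptions borrowed from common Transformer practice rather than details stated in the paper.

```python
import torch
import torch.nn as nn

class ShiftBlock(nn.Module):
    """F(x) = s(x) + M(N(s(x))): shift, then a LayerNorm + MLP residual branch."""

    def __init__(self, dim: int, mlp_ratio: int = 4):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, dim * mlp_ratio),
            nn.GELU(),
            nn.Linear(dim * mlp_ratio, dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, C, H, W) feature map of embedded patches.
        s = shift_features(x)                   # parameter-free shift (see above)
        t = s.flatten(2).transpose(1, 2)        # (B, H*W, C) token view
        t = t + self.mlp(self.norm(t))          # LayerNorm + MLP residual branch
        return t.transpose(1, 2).reshape_as(s)  # back to (B, C, H, W)
```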
4.2. Full-Scale Progressive Skip Connections
To optimize segmentation, we developed the full-scale progressive skip connection module, which uses progressive up-sampling to achieve full-scale skip connections.
4.2.1. Progressive Upsampling
In traditional up-sampling methods, low-resolution feature maps are usually restored to the original image size by interpolation operations (e.g., bilinear interpolation). However, restoring the resolution in this way can blur the image, because a single large upsampling step loses information. Progressive upsampling overcomes this problem by dividing the operation into multiple stages that alternate between a convolution and a 2-fold upsampling. This gradually increases the resolution of the feature map while reintroducing detailed information at each stage, helping to preserve the details and contextual information of the image and thereby improving the quality of the generated image.
Suppose we have input feature map X, and we want to upsample it to Y by alternating between convolution and two-fold upsampling as follows:
$$Y = C\big(U(X)\big)$$

where U(X) denotes one two-fold upsampling of X and C(·) denotes a convolution of the upsampled feature map. This process can be repeated n times to achieve the desired up-sampling factor. This gradual increase in the resolution of the feature map is a useful strategy for tasks such as image generation, super-resolution, and image restoration, where high-quality outputs are required.
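A minimal sketch of such a progressive upsampler, assuming bilinear interpolation, 3 × 3 convolutions, and BatchNorm/ReLU between stages (the paper does not specify these details):

```python
import torch.nn as nn

def progressive_upsampler(channels: int, n_stages: int) -> nn.Sequential:
    """Alternate 2x bilinear upsampling with a 3x3 convolution, n_stages times.

    Each stage doubles the resolution (Y = C(U(X))) while the convolution
    reintroduces local detail lost by interpolation.
    """
    layers = []
    for _ in range(n_stages):
        layers += [
            nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False),
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(channels),
            nn.ReLU(inplace=True),
        ]
    return nn.Sequential(*layers)
```

For example, progressive_upsampler(64, 2) restores a 64-channel map to four times its input resolution in two 2-fold stages.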
4.2.2. Full-Scale Skip Connections
This module is the core of the full-scale progressive skip connections. A skip connection is a typical cross-layer information transfer method that allows each layer of the decoder to connect to different layers of the encoder. This connection not only helps the model better integrate features from different scales and layers, thus improving performance, but also mitigates gradient vanishing and gradient explosion while speeding up the training process. With full-scale skip connections, we can introduce feature information from multiple scales simultaneously, thus better preserving image structure and details. We mathematically define the full-scale skip connection as follows:

$$Y = \mathrm{Concat}\big(X,\, X_1,\, X_2,\, \ldots,\, X_N\big)$$

where Y denotes the output feature map after the full-scale skip connection, X denotes the original input feature map, N is the number of layers in the network, and X_i is the feature map from the i-th layer.
To implement this approach, we select the appropriate scale feature maps from the encoder, transform them to the desired size by convolution, up-sampling, or down-sampling, and then fuse the processed feature maps using concat feature fusion to obtain the corresponding decoder layer. The specific transformation and construction of decoder layer D_i are performed as follows:

$$D_i = \mathcal{A}\Big(\mathrm{Concat}\big(C(\mathcal{D}(E_1)),\, \ldots,\, C(\mathcal{D}(E_{i-1})),\; C(E_i),\; C(\mathcal{U}(D_{i+1})),\, \ldots,\, C(\mathcal{U}(D_N))\big)\Big)$$

where i represents the index of a generic layer in the model, N is the total number of layers of the model, E_k and D_k denote the k-th encoder and decoder feature maps, C(·) denotes the convolution operation, 𝒟(·) denotes downsampling, 𝒰(·) denotes progressive upsampling, 𝒜(·) denotes the feature aggregation mechanism implemented by means of convolution, batch normalization, and the ReLU activation function, and Concat(·) denotes feature fusion by concatenation. As can be seen, in the construction of the i-th decoder layer, the encoder layers from the 1st to the (i − 1)-th undergo downsampling and convolution, while the i-th encoder layer is subject only to convolution; layers from the (i + 1)-th to the N-th utilize the previously constructed decoder layers and undergo progressive upsampling and convolution. These processed feature maps are then fused through concatenation, followed by further feature aggregation through convolution, normalization, and activation. This connection strategy allows the decoder to attend to multi-scale information, facilitating the recovery of details while preserving global context, as sketched below.
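The sketch below illustrates one way to realize such a decoder layer in PyTorch. For brevity, it resizes every branch with bilinear interpolation rather than distinguishing max-pooling downsampling from progressive upsampling, and the branch channel width of 64 is an assumption; it is a sketch of the connection pattern, not the authors’ implementation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class FullScaleDecoderLayer(nn.Module):
    """Build decoder layer D_i from all encoder scales and deeper decoder layers.

    Each incoming map (E_1..E_i and D_{i+1}..D_N) is resized to the target
    scale, convolved, concatenated, and aggregated by Conv-BN-ReLU.
    """

    def __init__(self, in_channels_list, out_channels: int, branch_channels: int = 64):
        super().__init__()
        # One 3x3 convolution per incoming branch, mapping it to branch_channels.
        self.branch_convs = nn.ModuleList(
            [nn.Conv2d(c, branch_channels, 3, padding=1) for c in in_channels_list]
        )
        self.aggregate = nn.Sequential(
            nn.Conv2d(branch_channels * len(in_channels_list), out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, features, target_size):
        # features: list of encoder/decoder maps at arbitrary spatial sizes;
        # each is resized to the scale of D_i before its branch convolution.
        branches = [
            conv(F.interpolate(f, size=target_size, mode="bilinear", align_corners=False))
            for conv, f in zip(self.branch_convs, features)
        ]
        return self.aggregate(torch.cat(branches, dim=1))
```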
4.3. Depthwise Separable Convolution
Depthwise separable convolution is a convolutional technique of CNNs that reduces both computational complexity and the number of parameters while maintaining performance. The underlying principle is the decomposition of standard convolution into two steps: depthwise convolution and point-by-point convolution. Depthwise convolution is applied independently to each channel of the input data, each with its own convolution kernel. This step captures spatial features in the input data. In point-by-point convolution, the output of the depthwise convolution is convolved using a 1 × 1 convolution kernel, and the feature maps of each channel are linearly combined to generate the final output feature map.
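A minimal PyTorch sketch of this decomposition (the 3 × 3 kernel size is an assumption; setting groups equal to the input channel count is what makes the first convolution depthwise):

```python
import torch.nn as nn

class DepthwiseSeparableConv(nn.Module):
    """3x3 depthwise convolution followed by a 1x1 point-by-point convolution."""

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        # Depthwise: one kernel per input channel (groups = in_channels).
        self.depthwise = nn.Conv2d(in_channels, in_channels, kernel_size=3,
                                   padding=1, groups=in_channels)
        # Point-by-point: 1x1 convolution linearly combining the channels.
        self.pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return self.pointwise(self.depthwise(x))
```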
Assume that the size of the input feature map is [H, W, C], that we apply a K × K convolution kernel, and that the size of the output feature map is [H, W, D]. The computational consumption of the standard convolution is then as follows:

$$\Omega_{\mathrm{std}} = H \times W \times K \times K \times C \times D$$

The computational consumptions of the depthwise convolution and the point-by-point convolution are as follows:

$$\Omega_{\mathrm{dw}} = H \times W \times K \times K \times C, \qquad \Omega_{\mathrm{pw}} = H \times W \times C \times D$$

The total computational consumption of the depthwise separable convolution is therefore:

$$\Omega_{\mathrm{dsc}} = H \times W \times C \times (K \times K + D)$$

These equations show that standard convolution involves a large number of computations, while the total consumption of the depthwise separable convolution is lower by a factor of $\frac{\Omega_{\mathrm{dsc}}}{\Omega_{\mathrm{std}}} = \frac{1}{D} + \frac{1}{K^2}$. The numbers of parameters of the two convolutions follow the same trend:

$$P_{\mathrm{std}} = K \times K \times C \times D, \qquad P_{\mathrm{dsc}} = K \times K \times C + C \times D$$
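For example, with a 3 × 3 kernel (K = 3) and C = D = 256 channels, the ratio of the two costs is 1/D + 1/K² = 1/256 + 1/9 ≈ 0.115, so the depthwise separable form requires roughly 8.7 times fewer multiply-accumulate operations than the standard convolution.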
Key features of depthwise separable convolution include parameter sharing, computational efficiency, and channel count maintenance. Parameter sharing reduces the number of parameters in the model, computational efficiency makes it suitable for resource-constrained environments, and channel count maintenance ensures information integrity. Overall, depthwise separable convolution is a powerful technique for convolutional operations that reduces the computational burden and number of parameters while maintaining high performance, making it particularly suitable for resource-constrained environments such as mobile devices and embedded systems.
4.4. Overarching Framework
The proposed network model is a hybrid coding network based on CNN and ShiftViT for medical image segmentation. Its network structure is shown in Figure 2 and is roughly divided into three parts: encoder, decoder, and skip connections. The encoder consists of a ResNet50 CNN and ShiftViT, and the decoder is constructed layer by layer using full-scale progressive skip connections. Depthwise separable convolution is applied throughout to reduce the parameter count and computation.
The encoder proceeds as follows: the input image first passes through a convolutional layer and a pooling layer to obtain a downsampled feature map. Downsampling helps improve computational efficiency and reduces redundancy in feature representation. To mitigate the potential loss of local context information during downsampling, the feature map is then divided into multiple non-overlapping subregions (patches), each corresponding to a vector, which serves as the input to the ShiftBlock, a structure that contains a shift operation and a feed forward network (FFN). The shift operation enables the exchange of information between different subregions without increasing parameters or computation, thereby alleviating the local feature loss caused by downsampling, while the FFN enhances the nonlinear representation. The output of ShiftBlock is reshaped into a feature map, which is used as an input to the decoder of this network.
Notably, the shift operation is designed to be particularly effective for local feature extraction, especially in medical image processing. The shift operation can pass information between neighboring subregions, smoothing out minor errors at the pixel level. Even in the presence of a few pixel shifts in the labels, the shift operation still shows high robustness and maintains good segmentation results. This property enables the model to achieve accurate segmentation results in the face of complex organ boundaries or labeling errors.
The decoder is constructed from multiple decoder layers, each built through full-scale progressive skip connections. These connections compensate for the limitations of the ShiftViT structure in capturing global context and enhance the flow of information between levels, improving the model’s ability to capture global features. They also increase the diversity and richness of features, improve the accuracy and robustness of segmentation, reduce information loss, and improve the resolution and quality of features. In addition, the use of depthwise separable convolution reduces the number of parameters and computations, improving the efficiency and speed of the model. The final output of the decoder passes through a convolutional layer and a softmax activation function to obtain the final segmentation result.
5. Experimental Results and Analysis
We confirmed the effectiveness of the proposed segmentation method as well as its practical value for medical applications using experiments. Our experimental design, selection of datasets, evaluation metrics, experimental setup, comparison tests, and ablation tests are detailed in this section.
5.1. Datasets
Synapse Multi-Organ Segmentation Dataset (Synapse) [31]: This dataset consists of 30 cases with a total of 3779 axial clinical computed tomography (CT) images of the abdomen for medical image segmentation tasks. Each CT volume consists of 85 to 198 slices of 512 × 512 pixels with varying voxel spatial resolutions. The dataset covers the labeling of eight abdominal organs: the aorta, gallbladder, left kidney, right kidney, liver, pancreas, spleen, and stomach. It is divided into 18 training cases and 12 test cases for model training and performance testing.
Automated Cardiac Diagnosis Challenge dataset (ACDC) [32]: This dataset comprises magnetic resonance imaging (MRI) data from different patients for cardiac image segmentation. The MRI scans of each patient include labels for the left ventricle (LV), right ventricle (RV), and myocardium (MYO). These cine-MR images were acquired under breath-hold and contain a series of short-axis slices covering the cardiac region from the base of the left ventricle to the apex, with slice thicknesses ranging from 5 to 8 mm and an in-plane spatial resolution of 0.83 to 1.75 mm²/pixel. Following previous research methods such as TransUNet and SwinUNet, we divided the dataset into 70 training samples, 10 validation samples, and 20 test samples [18,19].
5.2. Evaluation Metrics
The average Dice similarity coefficient (DSC) and the average Hausdorff distance (HD) were used as performance evaluation metrics [18,19,33]. The DSC, also known as the Dice coefficient or F1 score, is used to measure the similarity between two sets and is commonly applied in medical image segmentation. It calculates the ratio of the intersection of two sets to their average size as follows:
$$\mathrm{DSC}(A, B) = \frac{2\,|A \cap B|}{|A| + |B|}$$

where |A ∩ B| denotes the size of the intersection of two sets A and B, |A| denotes the size of set A, and |B| denotes the size of set B. The value of DSC ranges from 0 to 1; the closer the value is to 1, the more similar the two sets are and the more accurate the segmentation result is. A DSC equal to 1 indicates a perfect match with no error. DSC is often used to evaluate the performance of medical image segmentation algorithms, especially when comparing the agreement between automatic segmentation and manual labeling; higher DSC values indicate more accurate segmentation.
The HD is a distance metric used to measure the similarity between two sets and is commonly applied to medical image segmentation. It measures the maximum dissimilarity between two sets, i.e., the maximum distance from a point in one set to the nearest point in the other set. The computation of the HD involves two sets A and B: for each point in set A, find the closest point in set B; for each point in set B, find the closest point in set A; and then compute the maximum of these two closest distances, as follows:
$$\mathrm{HD}(A, B) = \max\Big\{\max_{a \in A}\,\min_{b \in B}\,\|a - b\|,\; \max_{b \in B}\,\min_{a \in A}\,\|a - b\|\Big\}$$

where ||·|| is the distance norm between points of sets A and B. The smaller the value of HD, the more similar the two sets are and the closer the segmentation result is to the ground truth. An HD equal to 0 indicates that the two sets match perfectly with no mismatched points. HD is often used to evaluate the performance of medical image segmentation algorithms, especially when comparing the consistency between automatic segmentation and manual labeling.
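Both metrics can be computed directly from binary masks; the following NumPy/SciPy sketch illustrates the definitions above (SciPy’s directed_hausdorff operates on point coordinates, so the masks are first converted to lists of foreground-pixel coordinates; non-empty masks are assumed):

```python
import numpy as np
from scipy.spatial.distance import directed_hausdorff

def dice_coefficient(pred: np.ndarray, target: np.ndarray) -> float:
    """DSC = 2|A ∩ B| / (|A| + |B|) for binary masks."""
    pred, target = pred.astype(bool), target.astype(bool)
    intersection = np.logical_and(pred, target).sum()
    return 2.0 * intersection / (pred.sum() + target.sum())

def hausdorff_distance(pred: np.ndarray, target: np.ndarray) -> float:
    """Symmetric HD: the maximum of the two directed Hausdorff distances,
    computed over the coordinates of the foreground pixels."""
    a = np.argwhere(pred.astype(bool))
    b = np.argwhere(target.astype(bool))
    return max(directed_hausdorff(a, b)[0], directed_hausdorff(b, a)[0])
```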
5.3. Experimental Setup
This experiment was performed on an Ubuntu server with a V100-SXM2-32GB graphics card; for data augmentation, we used simple random augmentation and random flipping. The ResNet-50 model used in the hybrid encoder was pretrained on ImageNet [5]. On the Synapse dataset, the experimental parameters were set as follows: a learning rate of 0.01 and a batch size of 24. For the SGD optimizer, momentum was set to 0.9, weight decay to 0.0001, and the random seed to 1234. We chose SGD because it is effective for training deep learning models, especially for segmentation tasks, and usually provides better stability and generalization than Adam [34] when data are limited. On the ACDC dataset, due to the smaller dataset size, the batch size was reduced to 8 to update the model more frequently, and the weight decay was adjusted to 0.01 to optimize performance; the remaining parameters were unchanged. We applied a hybrid loss function combining cross-entropy and Dice losses, sketched below.
5.4. Comparison Experiment
We conducted comparative experiments on the Synapse and ACDC datasets to evaluate the performances of the proposed network models and mainstream network models.
We present the average Dice similarity coefficient (DSC) and Hausdorff distance (HD) scores of various mainstream models on the Synapse dataset in Table 1. These values reflect the models’ segmentation performance across multiple anatomical structures. By analyzing the per-class DSC values, we can clearly assess the performance of different models on specific anatomical structures, effectively highlighting the strengths and contributions of the proposed model. In our comparative analysis, we explore the structural characteristics of the different network models and their impact on performance on the Synapse dataset. We find that V-Net and DARR, which use convolution-based encoder-decoder architectures, are capable of basic segmentation but perform poorly in capturing complex anatomical structures, with DSC values of 68.81 and 69.77, respectively, well below our model’s 79.46. U-Net and its variants (e.g., R50 U-Net and Att-UNet) enhance feature transfer, but the DSC of U-Net, at 76.85, is still lower than that of our model. TransNorm, which combines the Transformer’s self-attention mechanism with spatial normalization, achieves a DSC of 78.40 but falls short of our model in capturing fine details. MT-UNet, which combines multiple Transformer modules with U-Net, achieves a DSC of 78.59 and likewise fails to outperform our model. Although SwinUNet shows strong overall performance with a DSC of 79.13, ShiftTransUNet performs more prominently in specific categories, especially those that require strong global dependencies, such as the liver and pancreas, achieving high Dice values of 84.07 and 94.83, respectively, which validates its potential for application in medical image segmentation.
On the ACDC dataset, we analyzed the average DSCs of different medical image segmentation models, as well as the per-class DSCs, which are shown in Table 2. As can be seen in the table, R50 U-Net and R50 Att-UNet use the classical encoder-decoder architecture and, despite their good performance (DSC values of 87.55 and 86.75, respectively), are not as good as our model (90.28) at capturing complex anatomical structures. TransUNet utilizes the self-attention mechanism to enhance feature extraction and reaches a DSC of 89.71, but it still does not outperform our model in capturing fine details. SwinUNet, with a DSC of 90, performs well, but our model performs better in several categories, showing an advantage in feature fusion and detail capture. ViT-CUP and R50-ViT-CUP, which are based on the Vision Transformer structure, also remain below our model, although R50-ViT-CUP reaches a DSC of 87.57. In specific segmentation tasks, ShiftTransUNet performs particularly well, with Dice values of 90 and 87.49 for the right ventricle (RV) and myocardium (Myo), respectively, and 93.35 for the left ventricle (LV), showing its effectiveness in segmenting complex structures.
Our analyses of the Synapse and ACDC datasets show that ShiftTransUNet has significant potential for high-precision medical image segmentation across different modalities (CT and MRI). The model not only outperforms many mainstream models in overall performance but also demonstrates strong capabilities across the various segmentation categories. In addition, by introducing the ShiftViT structure and depthwise separable convolution, ShiftTransUNet effectively reduces computational complexity while improving performance, further highlighting its wide applicability in medical image segmentation.
5.5. Analytical Study
The ablation study in this research was conducted on the Synapse dataset, with detailed results presented in Table 3. This experiment aims to evaluate the impact of different structural modifications on model performance.
As shown in the table, the second row indicates that when using TransUNet alone, the DSC value is 77.48, the HD value is 31.69, the computational cost is 24.66 GMac, and the number of parameters is 105.28 M. In the third row, when ShiftViT replaces the ViT structure in TransUNet, the DSC value slightly increases to 78.04, while both the computational cost and the parameter count drop significantly, to 8.94 GMac and 24.95 M, respectively, indicating that ShiftViT effectively reduces computational complexity while maintaining performance. In the fourth row, full-scale progressive skip connections replace the same-layer skip connections of TransUNet, resulting in a significant increase in DSC to 81.06 and a decrease in HD to 28.22. Although the computational cost and parameter count rise to 35.36 GMac and 119.48 M, respectively, the improved segmentation performance validates the effectiveness of this connection method in handling complex anatomical structures. In the fifth row, depthwise separable convolution replaces the convolution operations in the decoder, resulting in a DSC of 78.80, an HD of 32.31, a computational cost of 22.68 GMac, and a parameter count of 101.45 M. This modification has minimal impact on segmentation accuracy but effectively reduces computational complexity, further enhancing the model’s efficiency. Finally, in the sixth row, the combination of these three innovations yields a DSC of 79.46 and an HD of 28.29, with the computational cost and parameter count reduced to 9.46 GMac and 28.97 M, respectively. These results provide strong evidence for the effectiveness of the proposed network model, highlighting the contribution of each innovation to enhancing performance and reducing complexity. In addition, Figure 3 provides a visual comparison of the segmentation results, further demonstrating the effectiveness of the proposed network structure.
6. Conclusions
In this study, we propose a method for medical image segmentation that is both efficient and accurate by combining the ShiftViT structure, full-scale progressive skip connections, and depthwise separable convolution. Experimental results on multiple medical image datasets demonstrate that the proposed method not only achieves excellent segmentation performance but also significantly reduces the number of parameters and the computational complexity of the model.
Although the performance of the proposed method is high, there remains room for improvement. First, we plan to extend our research to the field of 3D image segmentation to cope with more complex medical data, including processing images with complex structures and smaller targets. Second, we plan to continue to explore optimization strategies to further improve the computational efficiency of the model. In addition, we will focus on data processing and preprocessing to improve the quality of the input images and to overcome labeling issues to enhance the feasibility of practical applications.
Notably, our model shows good adaptability in medical imaging tasks with different modalities, demonstrating its potential for application in multimodal medical imaging. In the future, we plan to validate the model on more medical imaging tasks and different datasets to demonstrate its wide range of applications.
In summary, although the proposed method represents a valuable contribution to the goal of balancing the trade-off between performance and complexity, further research is needed to ensure the ongoing development of medical image segmentation methods and their real-world applicability in particular.