Memory usage can be reduced by up to 75%. Scientists from the U.S. Department of Energy have proposed a cross-channel hierarchical aggregation method called D-CHAG, enabling the operation of multi-channel datasets for extremely large-scale models.
Vision-based scientific foundation models have great potential in driving scientific discovery and innovation, mainly because they can aggregate image data from diverse sources (such as different physical observation scenarios) and learn spatio-temporal correlations using the Transformer architecture. However, the tokenization and aggregation processes of images incur high computational costs, and existing distributed methods such as Tensor Parallelism (TP), Sequence Parallelism (SP), or Data Parallelism (DP) have not fully addressed this challenge.
In this context, researchers from the Oak Ridge National Laboratory of the U.S. Department of Energy proposed a Distributed Cross-Channel Hierarchical Aggregation method (D-CHAG) for foundation models. This method distributes the tokenization process and adopts a hierarchical strategy for channel aggregation, enabling extremely large-scale models to run on multi-channel datasets. The researchers evaluated D-CHAG on hyperspectral imaging and weather prediction tasks. By combining this method with tensor parallelism and model sharding, they were able to reduce memory usage by up to 75% on the Frontier supercomputer and achieve a more than twofold increase in continuous throughput on up to 1,024 AMD GPUs.
The research, titled "Distributed Cross-Channel Hierarchical Aggregation for Foundation Models," was published in the proceedings of SC25.
Research Highlights:
* D-CHAG solves the memory bottleneck and computational efficiency problems in the training of multi-channel foundation models.
* Compared with using only TP, D-CHAG can reduce memory usage by up to 70%, thus supporting more efficient large-scale model training.
* The performance of D-CHAG was verified on two scientific workloads: weather prediction and hyperspectral plant image mask prediction.
Paper Link: https://dl.acm.org/doi/10.1145/3712285.3759870
Using Two Typical Multi-Channel Datasets
This study used two typical multi-channel datasets to verify the effectiveness of the D-CHAG method: hyperspectral images of plants and the meteorological ERA5 dataset.
The hyperspectral plant images used for self-supervised mask prediction were collected by the Advanced Plant Phenotyping Laboratory (APPL) at Oak Ridge National Laboratory (ORNL). The dataset contains 494 hyperspectral images of poplar trees, each with 500 spectral channels covering wavelengths from 400 nm to 900 nm.
This dataset is mainly used for biomass research and is an important resource for plant phenotyping analysis and bioenergy research. These images are used for mask self-supervised training, where image patches are used as tokens for masking, and the model's task is to predict the missing content to learn the underlying data distribution of the images. Notably, this dataset did not use any pre-trained weights and was trained entirely based on self-supervised learning, which also highlights the applicability of D-CHAG in high-channel self-supervised tasks.
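The masked self-supervision described above can be sketched in numpy: patchify each channel into tokens, hide most of them, and keep the hidden patches as reconstruction targets. The image size, patch size, and 75% mask ratio below are illustrative assumptions, not the paper's actual training configuration.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy stand-in for one hyperspectral image: 500 spectral channels, 64x64
# pixels (illustrative shapes, not the paper's training resolution).
channels, height, width, patch = 500, 64, 64, 16

image = rng.standard_normal((channels, height, width)).astype(np.float32)

# Patchify each channel independently: every (patch x patch) tile becomes
# a token of flattened pixel values.
n_h, n_w = height // patch, width // patch
tokens = (
    image.reshape(channels, n_h, patch, n_w, patch)
         .transpose(0, 1, 3, 2, 4)
         .reshape(channels, n_h * n_w, patch * patch)
)  # -> (500, 16, 256): per channel, 16 tokens of 256 pixels each

# Randomly mask 75% of the tokens; the model's task is to reconstruct them.
mask_ratio = 0.75
n_tokens = tokens.shape[1]
n_masked = int(n_tokens * mask_ratio)
masked_idx = rng.choice(n_tokens, size=n_masked, replace=False)

visible = np.delete(tokens, masked_idx, axis=1)   # encoder input
targets = tokens[:, masked_idx, :]                # reconstruction targets

print(tokens.shape, visible.shape, targets.shape)
```

The loss is then computed only on the masked patches, which is what pushes the model to learn the underlying data distribution rather than copy visible pixels.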
In addition, for the weather prediction experiment, the research team used the ERA5 high-resolution reanalysis dataset. The study selected five atmospheric variables (geopotential height, temperature, the u- and v-components of wind speed, and specific humidity) and three surface variables (2-meter temperature and the 10-meter u- and v-components of wind speed), covering more than 10 pressure levels, for a total of 80 input channels. To fit model training, the original 0.25° data (721 × 1440) was regridded to 5.625° (32 × 64) with the xESMF toolkit using bilinear interpolation.
The model's task is to predict meteorological variables at future time steps, such as the 500 hPa geopotential height (Z500), 850 hPa temperature (T850), and 10-meter u-component of wind speed (U10), to verify the performance of the D-CHAG method in time series prediction tasks.
D-CHAG: Combining Hierarchical Aggregation with Distributed Tokenization
In simple terms, the D-CHAG method is a combination of two independent methods:
Distributed Tokenization Method
During the forward propagation process, each TP rank only tokenizes a subset of the input channels. Before the channel aggregation step, an AllGather operation needs to be performed to enable cross-attention among all channels. Theoretically, this method can reduce the tokenization computational cost of each GPU.
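A single-process numpy sketch of this data flow follows; the channel counts and patch size are our illustrative assumptions, and the Python list of per-rank results stands in for a real collective (a PyTorch implementation would use `torch.distributed.all_gather`).

```python
import numpy as np

# Simulated distributed tokenization: each "TP rank" tokenizes only its
# slice of the input channels, then token tensors are all-gathered so that
# cross-channel aggregation can attend over every channel. This is a
# single-process sketch of the data flow, not a distributed implementation.

rng = np.random.default_rng(0)
tp_ranks, channels, height, width, patch = 4, 80, 32, 64, 8
image = rng.standard_normal((channels, height, width)).astype(np.float32)

per_rank = channels // tp_ranks                   # 20 channels per rank
n_tokens = (height // patch) * (width // patch)   # 4 * 8 = 32 tokens/channel

def tokenize(x):
    """Patchify a (c, h, w) slab into (c, n_tokens, patch*patch)."""
    c, h, w = x.shape
    return (x.reshape(c, h // patch, patch, w // patch, patch)
             .transpose(0, 1, 3, 2, 4)
             .reshape(c, n_tokens, patch * patch))

# Forward pass: every rank tokenizes only its own channel slice ...
local_tokens = [
    tokenize(image[r * per_rank:(r + 1) * per_rank]) for r in range(tp_ranks)
]

# ... then one AllGather reassembles the full channel dimension.
gathered = np.concatenate(local_tokens, axis=0)

print(local_tokens[0].shape, gathered.shape)
```

Because tokenization acts on each channel independently, concatenating the per-rank results reproduces exactly what a single GPU tokenizing all channels would have produced, at a quarter of the per-GPU compute in this four-rank example.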
Hierarchical Cross-Channel Aggregation
In hierarchical cross-channel aggregation, instead of fusing all channels in a single cross-attention layer, channels are merged over several layers, each of which fuses only a small group. The main advantage is reduced memory usage in each cross-attention layer, because each layer processes fewer channels. However, adding layers increases the overall model size and memory usage. For datasets with a large number of channels, this trade-off is favorable, because the quadratic memory overhead of standard cross-attention over all channels is higher.
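The shrinking channel count per level can be illustrated with a minimal numpy sketch in which a simple mean stands in for each learned aggregation layer; the channel count, group size, and token shapes are our assumptions for illustration.

```python
import numpy as np

# Hierarchical channel aggregation sketch: rather than one layer attending
# over all C channels at once (cross-attention memory grows quadratically
# in the channels it sees), channels are merged level by level so each
# layer only ever fuses a small group. A mean stands in for the learned
# aggregation layer here.

rng = np.random.default_rng(0)
channels, n_tokens, dim, group = 64, 16, 32, 4

tokens = rng.standard_normal((channels, n_tokens, dim))

levels = []
x = tokens
while x.shape[0] > 1:
    c = x.shape[0]
    # Each aggregation level fuses `group` adjacent channels into one.
    x = x.reshape(c // group, group, n_tokens, dim).mean(axis=1)
    levels.append(x.shape[0])

print(levels)   # channel count remaining after each level
```

Each level sees at most `group` channels at a time, which is the source of the per-layer memory saving; the cost is that every extra level adds its own parameters.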
Although these two methods have their own advantages, they also have some deficiencies. For example, the distributed tokenization method has high communication overhead among TP ranks and does not solve the problem of large memory usage in the channel dimension. The hierarchical cross-channel aggregation method increases the number of model parameters on each GPU. The D-CHAG method combines the two methods in a distributed manner, and the overall architecture is shown in the following figure:
Schematic diagram of the D-CHAG method in the basic architecture
Specifically, each TP rank tokenizes the two-dimensional images for its subset of the channels. Since each GPU holds only part of the channels, channel aggregation is first performed locally on them; this component is called the partial-channel aggregation module. After channel aggregation completes within each TP rank, the outputs are gathered and fused with a final cross-attention step. Only one AllGather operation is required during forward propagation; during backward propagation, each GPU collects only its own gradients, avoiding additional communication.
The D-CHAG method retains the advantages of distributed tokenization and hierarchical channel aggregation while alleviating their deficiencies. By distributing hierarchical channel aggregation across TP ranks, the AllGather is reduced so that each TP rank contributes only a single aggregated channel, and no communication is required during backward propagation. In addition, increasing the model depth preserves the benefit of each aggregation layer processing fewer channels, while the extra model parameters are spread across TP ranks through the partial-channel aggregation module.
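The combined forward pass can be sketched in a single process; means stand in for the learned aggregation layers, and all shapes are illustrative assumptions rather than the paper's configuration.

```python
import numpy as np

# Single-process sketch of the D-CHAG forward pass: each TP rank
# hierarchically aggregates its own channel slice first (partial-channel
# aggregation), so the single AllGather moves only one aggregated channel
# per rank instead of every raw channel.

rng = np.random.default_rng(0)
tp_ranks, channels, n_tokens, dim = 4, 64, 16, 32
per_rank = channels // tp_ranks

tokens = rng.standard_normal((channels, n_tokens, dim))

def partial_aggregate(x):
    """Hierarchically fuse a rank's channels down to one (pairwise means)."""
    while x.shape[0] > 1:
        x = x.reshape(x.shape[0] // 2, 2, n_tokens, dim).mean(axis=1)
    return x  # (1, n_tokens, dim)

# Each rank reduces its 16 local channels to a single aggregated channel ...
local = [partial_aggregate(tokens[r * per_rank:(r + 1) * per_rank])
         for r in range(tp_ranks)]

# ... so the AllGather carries only tp_ranks channels instead of all 64.
gathered = np.concatenate(local, axis=0)     # (4, n_tokens, dim)
final = gathered.mean(axis=0)                # final cross-rank aggregation

print(gathered.shape, final.shape)
```

Compare the communication volume with plain distributed tokenization: there the AllGather carries all 64 channels' tokens, here only 4 pre-aggregated ones.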
The study compared two implementation strategies:
* D-CHAG-L (Linear Layer): The hierarchical aggregation module uses a linear layer, which has low memory usage and is suitable for cases with a large number of channels.
* D-CHAG-C (Cross-Attention Layer): It uses a cross-attention layer, which has a high computational cost but significantly improves performance for extremely large models or a very high number of channels.
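The two variants can be contrasted with a minimal numpy sketch. The shapes, the single learned query, and the weight values below are our illustrative assumptions, not the paper's implementation; both toy layers map a stack of C channel-token tensors down to one fused channel.

```python
import numpy as np

rng = np.random.default_rng(0)
C, n, d = 8, 16, 32                            # channels, tokens, embed dim
x = rng.standard_normal((C, n, d))

# D-CHAG-L style: a learned linear mix over the channel axis.
# Cheap in memory: only C weights, no attention matrix.
w = rng.standard_normal(C) / C                 # (C,) channel-mixing weights
fused_linear = np.einsum("c,cnd->nd", w, x)

# D-CHAG-C style: cross-attention with one learned query attending over
# the C channels at every token position (keys = values = x here).
q = rng.standard_normal(d)
scores = np.einsum("d,cnd->cn", q, x) / np.sqrt(d)
attn = np.exp(scores - scores.max(axis=0))     # softmax over channels
attn /= attn.sum(axis=0)
fused_attn = np.einsum("cn,cnd->nd", attn, x)

print(fused_linear.shape, fused_attn.shape)
```

The sketch makes the trade-off concrete: the linear variant mixes channels with a fixed weight vector, while the attention variant computes data-dependent weights per token position at the cost of the extra score and softmax tensors.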
Results: D-CHAG Enables Training of Larger Models on High-Channel Datasets
After constructing D-CHAG, the researchers first analyzed its performance characteristics and then evaluated it on the hyperspectral imaging and weather prediction tasks:
Model Performance Analysis
The following figure shows the performance of D-CHAG under different configurations of the partial-channel aggregation module:
The figure shows the performance improvement of each GPU relative to the baseline using only TP for a 1.7B parameter model under different configurations of the partial-channel aggregation module
* Tree0 indicates only one layer of aggregation in the partial aggregation module, Tree2 indicates two layers, and so on;
* The suffixes -C and -L indicate the types of layers used: all layers in -C are cross-attention, and all layers in -L are linear
The results show:
For 512-channel data, the performance of using a single-layer cross-attention layer is slightly lower than the baseline, but for 1024-channel data, it can be improved by about 60%.
As the hierarchical structure deepens, even for 512-channel data, significant performance improvement can be obtained, while the performance of 1024-channel data remains relatively stable.
When using linear layers, even with a shallow hierarchical structure, performance improvement can be obtained on both 512 and 1024-channel images. In fact, the best performance occurs in D-CHAG-L-Tree0, which only contains one layer of channel aggregation. Increasing the number of aggregation layers will increase the model parameters and introduce additional memory overhead. Although increasing the number of layers seems beneficial for the 512-channel case, for both channel scales, the performance of using only one linear layer is better than that of deeper configurations.
D-CHAG-C-Tree0 has a slight negative impact on performance when using two GPUs, but it can achieve a 60% improvement when scaling up to eight GPUs.
Self-Supervised Mask Prediction of Hyperspectral Plant Images
The following figure compares the training losses of the baseline method and the D-CHAG method in the application of hyperspectral plant image mask autoencoders. The results show that during the training process, the training losses of the single-GPU implementation and the D-CHAG method (running on two GPUs) are highly consistent.
Training losses of the baseline method and the D-CHAG method in the application of hyperspectral plant image mask autoencoders
Larry York, a senior researcher in the Molecular and Cellular Imaging Group at the Oak Ridge National Laboratory, said that D-CHAG can help plant scientists quickly complete tasks such as directly measuring the photosynthetic activity of plants from images, thus replacing time-consuming and labor-intensive manual measurements.
Weather Prediction
The researchers conducted a 30-day meteorological prediction experiment on the ERA5 dataset. The following figure compares the training losses and the RMSE of three test variables of the baseline method and the D-CHAG method in the weather prediction application:
Training losses and RMSE of three test variables of the baseline method and the D-CHAG method in the weather prediction application
The following table shows the final comparison of the models on the 7-, 14-, and 30-day prediction tasks, reporting RMSE, MSE, and the latitude-weighted anomaly correlation coefficient (wACC).
Percentage changes (% Δ) in MSE, RMSE, and wACC of the D-CHAG method compared with single-GPU training on the 7-, 14-, and 30-day prediction tasks
Taken together, the figure and the table show that the training losses track the baseline model closely and that the deviations in all metrics are extremely small.
Performance with Model Scale Expansion
The following figure shows the performance improvement of the D-CHAG method compared with using only TP for three model scales under channel configurations that require TP:
Performance improvement of each GPU for 7B, 15B, and 26B parameter models when the D-CHAG method is combined with TP compared with using only TP
The results show that for the 7B parameter model, using linear layers in the partial-channel aggregation module achieves a 30% to 70% performance improvement, while using cross-attention layers achieves a 10% to 60% improvement; for the 15B parameter model, the improvement ranges from 20% to 50%; for the 26B parameter model, it falls between 10% and 30%.
In addition, for a fixed model scale, as the number of channels increases, the performance improvement is more obvious. This is because, under a given architecture, increasing the number of channels does not increase the computational load of the transformer block but increases the workload of the tokenization and channel-aggregation modules.
On the other hand, a 26B parameter model cannot be trained on 256-channel images using TP alone, but with the D-CHAG method a 26B parameter model can be trained on 512-channel images using less than 80% of the available memory, which indicates that the method supports training larger models on high-channel datasets.
ViT: Vision AI Transitions from Perceptual Models to General Vision Foundation Models
In the past decade, computer vision models have mainly focused on "single-task optimization" - classification, detection, segmentation, and reconstruction have developed independently. However, as the Transformer architecture has given rise to foundation models such as GPT and BERT in the field of natural language, the field of vision is also undergoing a similar paradigm shift: from task-specific models to general vision foundation models. In this trend, Vision Transformer (ViT) is regarded as a key technological cornerstone for vision foundation models.
Vision Transformer (ViT) was the first to fully introduce the Transformer architecture into computer vision tasks. Its core idea is to treat an image as a sequence of patch tokens and replace the local receptive field modeling of convolutional neural networks with a self-attention mechanism. Specifically, ViT divides the input image into fixed-size patches and maps each patch to an embedding token, and then models the global relationships between patches through the Transformer Encoder.
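The patchify-and-project stem described above can be sketched in numpy; the 224 × 224 image, 16 × 16 patches, and 128-dimensional embedding are standard ViT-style shapes assumed for illustration, and the Transformer encoder that follows is omitted.

```python
import numpy as np

# Minimal ViT front-end sketch: split an image into fixed-size patches
# and linearly project each flattened patch to an embedding token.

rng = np.random.default_rng(0)
H, W, C, P, D = 224, 224, 3, 16, 128           # image, patch, embed dims

image = rng.standard_normal((C, H, W)).astype(np.float32)

# Patchify: (C, H, W) -> (num_patches, P*P*C)
n_h, n_w = H // P, W // P
patches = (image.reshape(C, n_h, P, n_w, P)
                .transpose(1, 3, 0, 2, 4)
                .reshape(n_h * n_w, C * P * P))

# Linear patch embedding: each 768-dim flattened patch becomes a D-dim token.
W_embed = rng.standard_normal((C * P * P, D)).astype(np.float32) * 0.02
tokens = patches @ W_embed                     # (196, 128)

print(patches.shape, tokens.shape)
```

The resulting token sequence (plus position embeddings and a class token in the full model) is what the self-attention layers then relate globally, replacing the local receptive fields of a CNN.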
Compared with traditional CNNs, ViT has particular advantages for scientific data: it is suitable for high-dimensional multi-channel data (such as remote sensing, medical imaging, and