An Example of Integrating DYAD into a PyTorch DataLoader

Unet3D DataLoader in DLIO
  1"""
  2   Copyright (c) 2025, UChicago Argonne, LLC
  3   All Rights Reserved
  4
  5   Licensed under the Apache License, Version 2.0 (the "License");
  6   you may not use this file except in compliance with the License.
  7   You may obtain a copy of the License at
  8
  9       http://www.apache.org/licenses/LICENSE-2.0
 10
 11   Unless required by applicable law or agreed to in writing, software
 12   distributed under the License is distributed on an "AS IS" BASIS,
 13   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 14   See the License for the specific language governing permissions and
 15   limitations under the License.
 16"""
 17import math
 18import pickle
 19import torch
 20from torch.utils.data import Dataset, DataLoader
 21from torch.utils.data.sampler import Sampler
 22
 23from dlio_benchmark.common.constants import MODULE_DATA_LOADER
 24from dlio_benchmark.common.enumerations import DatasetType, DataLoaderType
 25from dlio_benchmark.data_loader.base_data_loader import BaseDataLoader
 26from dlio_benchmark.reader.reader_factory import ReaderFactory
 27from dlio_benchmark.utils.utility import utcnow, DLIOMPI, Profile, dft_ai
 28from dlio_benchmark.utils.config import ConfigArguments
 29from pydyad import Dyad, dyad_open
 30from pydyad.bindings import DTLMode, DTLCommMode
 31import numpy as np
 32#import flux
 33import os
 34
# Profiler handle for the data-loader module. Its ``log`` / ``log_init``
# decorators and ``update`` calls are used throughout this file.
dlp = Profile(MODULE_DATA_LOADER)
 36
 37
 38class DYADTorchDataset(Dataset):
 39    """
 40    Currently, we only support loading one sample per file
 41    TODO: support multiple samples per file
 42    """
 43
 44    @dlp.log_init
 45    def __init__(self, format_type, dataset_type, epoch, num_samples, num_workers, batch_size):
 46        self.format_type = format_type
 47        self.dataset_type = dataset_type
 48        self.epoch_number = epoch
 49        self.num_samples = num_samples
 50        self.reader = None
 51        self.num_images_read = 0
 52        self.batch_size = batch_size
 53        args = ConfigArguments.get_instance()
 54        self.serial_args = pickle.dumps(args)
 55        self.logger = args.logger
 56        self.dlp_logger = None
 57        if num_workers == 0:
 58            self.worker_init(-1)
 59
 60    @dlp.log
 61    def worker_init(self, worker_id):
 62        pickle.loads(self.serial_args)
 63        _args = ConfigArguments.get_instance()
 64        _args.configure_dlio_logging(is_child=True)
 65        self.dlp_logger = _args.configure_dftracer(is_child=True, use_pid=True)
 66        self.logger.debug(f"{utcnow()} worker initialized {worker_id} with format {self.format_type}")
 67        self.reader = ReaderFactory.get_reader(type=self.format_type,
 68                                               dataset_type=self.dataset_type,
 69                                               thread_index=worker_id,
 70                                               epoch_number=self.epoch_number)
 71        self.dyad_io = Dyad()
 72        self.namespace = os.getenv("DYAD_KVS_NAMESPACE")
 73        #f = flux.Flux()
 74        self.my_node_index = 0 #f.get_rank()
 75        self.dyad_managed_directory = os.getenv("DYAD_PATH_PRODUCER")
 76        #self.dyad_managed_directory = os.getenv("DYAD_PATH_CONSUMER")
 77        mode = DTLMode.DYAD_DTL_MARGO
 78        self.dyad_io.init(debug=False, check=False, shared_storage=False, reinit=False,
 79                          async_publish=True, fsync_write=False, key_depth=3,
 80                          service_mux=1,
 81                          key_bins=1024, kvs_namespace=self.namespace,
 82                          prod_managed_path=self.dyad_managed_directory,
 83                          cons_managed_path=self.dyad_managed_directory,
 84                          dtl_mode=mode, dtl_comm_mode=DTLCommMode.DYAD_COMM_RECV)
 85        if self.dataset_type is DatasetType.TRAIN:
 86            self.global_index_map = _args.train_global_index_map
 87            self.file_map = _args.train_file_map
 88        else:
 89            self.file_map = _args.val_file_map
 90            self.global_index_map = _args.val_global_index_map
 91
 92    def __del__(self):
 93        if self.dlp_logger:
 94            self.dlp_logger.finalize()
 95
 96    @dlp.log
 97    def __len__(self):
 98        return self.num_samples
 99
100    @dlp.log
101    def __getitem__(self, image_idx):
102        self.num_images_read += 1
103        step = int(math.ceil(self.num_images_read / self.batch_size))
104        self.logger.debug(f"{utcnow()} Rank {DLIOMPI.get_instance().rank()} reading {image_idx} sample")
105        filename, sample_index = self.global_index_map[image_idx]
106        is_present = False
107        file_obj = None
108        base_fname = filename
109        dlp.update(args={"fname":filename})
110        dlp.update(args={"image_idx":image_idx})
111        if self.dyad_managed_directory != "":
112            self.logger.debug(f"{utcnow()} Rank {DLIOMPI.get_instance().rank()} reading metadata")
113            base_fname = os.path.join(self.dyad_managed_directory, os.path.basename(filename))
114            file_obj = self.dyad_io.get_metadata(fname=base_fname, should_wait=False, raw=True)
115            self.logger.debug(f"Using managed directory {self.dyad_managed_directory} {base_fname} {file_obj}")
116            is_present = True
117        if file_obj:
118            access_mode = "remote"
119            if self.my_node_index == file_obj.contents.owner_rank:
120                access_mode = "local"
121            dlp.update(args={"owner_rank":str(file_obj.contents.owner_rank)})
122            dlp.update(args={"mode":"dyad"})
123            dlp.update(args={"access":access_mode})
124            self.logger.debug(f"Reading from managed directory {base_fname}")
125            with dyad_open(base_fname, "rb", dyad_ctx=self.dyad_io, metadata_wrapper=file_obj) as f:
126                try:
127                    data = np.load(f, allow_pickle=True)["x"]
128                except:
129                    data = self._args.resized_image
130            self.dyad_io.free_metadata(file_obj)
131        else:
132            dlp.update(args={"mode":"pfs"})
133            dlp.update(args={"access":"remote"})
134            self.logger.debug(f"Reading from pfs {base_fname}")
135            data = self.reader.read_index(image_idx, step)
136            if is_present:
137                self.logger.debug(f"Writing to managed_directory {base_fname}")
138                with dyad_open(base_fname, "wb", dyad_ctx=self.dyad_io) as f:
139                    np.savez(f, x=data)
140            self.logger.debug(f"Read from pfs {base_fname}")
141
142        dlp.update(step=step)
143        dft_ai.update(step=step)
144        return self.reader.read_index(image_idx, step)
145
146
class dlio_sampler(Sampler):
    """
    Sampler that yields one contiguous slice of sample indices per rank.

    The ``num_samples`` indices are split into ``size`` consecutive chunks of
    ``ceil(num_samples / size)`` entries; rank ``rank`` iterates over its own
    chunk, truncated at ``num_samples``.  ``__len__`` reports the total
    ``num_samples``, not the length of this rank's chunk.
    """

    def __init__(self, rank, size, num_samples, epochs):
        self.size = size
        self.rank = rank
        self.num_samples = num_samples
        self.epochs = epochs
        chunk = int(math.ceil(num_samples / size))
        first = self.rank * chunk
        # Exclusive upper bound, clipped so the last rank(s) never run past
        # the end of the dataset.
        stop = min(first + chunk, num_samples)
        self.indices = list(range(first, stop))

    def __len__(self):
        return self.num_samples

    def __iter__(self):
        yield from self.indices
167
168
class DyadTorchDataLoader(BaseDataLoader):
    """DLIO data loader that wraps ``DYADTorchDataset`` in a Torch ``DataLoader``."""

    @dlp.log_init
    def __init__(self, format_type, dataset_type, epoch_number):
        super().__init__(format_type, dataset_type, epoch_number, DataLoaderType.PYTORCH)

    @dlp.log
    def read(self):
        """Build the dataset, the per-rank sampler and the Torch ``DataLoader``."""
        cfg = self._args
        dataset = DYADTorchDataset(self.format_type, self.dataset_type, self.epoch_number, self.num_samples,
                                   cfg.read_threads, self.batch_size)
        sampler = dlio_sampler(cfg.my_rank, cfg.comm_size, self.num_samples, cfg.epochs)

        # Spread the requested prefetch size over the workers, if there are any.
        prefetch_factor = cfg.prefetch_size
        if cfg.read_threads >= 1:
            prefetch_factor = math.ceil(cfg.prefetch_size / cfg.read_threads)

        if prefetch_factor <= 0:
            prefetch_factor = 2
            if cfg.my_rank == 0:
                self.logger.debug(
                    f"{utcnow()} Prefetch size is 0; a default prefetch factor of 2 will be set to Torch DataLoader.")
        elif cfg.my_rank == 0:
            self.logger.debug(
                f"{utcnow()} Prefetch size is {cfg.prefetch_size}; prefetch factor of {prefetch_factor} will be set to Torch DataLoader.")
        self.logger.debug(f"{utcnow()} Setup dataloader with {cfg.read_threads} workers {torch.__version__}")

        # Worker-only options; prefetch_factor and persistent_workers are not
        # passed when running on torch 1.3.1.
        legacy_torch = torch.__version__ == '1.3.1'
        loader_options = {}
        if cfg.read_threads != 0:
            loader_options['multiprocessing_context'] = cfg.multiprocessing_context
            if not legacy_torch:
                loader_options['prefetch_factor'] = prefetch_factor
                loader_options['persistent_workers'] = True

        self._dataset = DataLoader(dataset,
                                   batch_size=self.batch_size,
                                   sampler=sampler,
                                   num_workers=cfg.read_threads,
                                   pin_memory=cfg.pin_memory,
                                   drop_last=True,
                                   worker_init_fn=dataset.worker_init,
                                   **loader_options)
        self.logger.debug(f"{utcnow()} Rank {cfg.my_rank} will read {len(self._dataset) * self.batch_size} files")

        # self._dataset.sampler.set_epoch(epoch_number)

    @dlp.log
    def next(self):
        """Yield the batches of one epoch, then advance the epoch counter."""
        super().next()
        if self.dataset_type is DatasetType.TRAIN:
            total = self._args.training_steps
        else:
            total = self._args.eval_steps
        self.logger.debug(f"{utcnow()} Rank {self._args.my_rank} should read {total} batches")
        for step, batch in enumerate(dft_ai.dataloader.fetch.iter(self._dataset), start=1):
            dlp.update(step=step)
            dft_ai.update(step=step)
            yield batch
        self.epoch_number += 1
        dlp.update(epoch=self.epoch_number)
        dft_ai.update(epoch=self.epoch_number)

    @dlp.log
    def finalize(self):
        """Nothing to clean up."""
        pass

The lines relevant to DYAD are:

  • Lines 29–30, which import PyDYAD (`pydyad` and `pydyad.bindings`)

  • Lines 71–84 (in `worker_init()`), which create and initialize the `Dyad` object

  • The lines in `__getitem__()`

  • Lines that begin with `dlp.` are for the profiler and are not directly relevant to DYAD