An Example of Integrating DYAD into PyTorch DataLoader
Unet3D DataLoader in DLIO
1"""
2 Copyright (c) 2025, UChicago Argonne, LLC
3 All Rights Reserved
4
5 Licensed under the Apache License, Version 2.0 (the "License");
6 you may not use this file except in compliance with the License.
7 You may obtain a copy of the License at
8
9 http://www.apache.org/licenses/LICENSE-2.0
10
11 Unless required by applicable law or agreed to in writing, software
12 distributed under the License is distributed on an "AS IS" BASIS,
13 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 See the License for the specific language governing permissions and
15 limitations under the License.
16"""
17import math
18import pickle
19import torch
20from torch.utils.data import Dataset, DataLoader
21from torch.utils.data.sampler import Sampler
22
23from dlio_benchmark.common.constants import MODULE_DATA_LOADER
24from dlio_benchmark.common.enumerations import DatasetType, DataLoaderType
25from dlio_benchmark.data_loader.base_data_loader import BaseDataLoader
26from dlio_benchmark.reader.reader_factory import ReaderFactory
27from dlio_benchmark.utils.utility import utcnow, DLIOMPI, Profile, dft_ai
28from dlio_benchmark.utils.config import ConfigArguments
29from pydyad import Dyad, dyad_open
30from pydyad.bindings import DTLMode, DTLCommMode
31import numpy as np
32#import flux
33import os
34
35dlp = Profile(MODULE_DATA_LOADER)
36
37
class DYADTorchDataset(Dataset):
    """
    Map-style Torch dataset that looks samples up in a DYAD-managed directory
    first and falls back to the regular DLIO reader on a miss. A sample read
    through the fallback path is written into the DYAD-managed directory.

    Currently, we only support loading one sample per file
    TODO: support multiple samples per file
    """

    @dlp.log_init
    def __init__(self, format_type, dataset_type, epoch, num_samples, num_workers, batch_size):
        """
        Store the dataset configuration.

        :param format_type: data format handed to ``ReaderFactory.get_reader``
        :param dataset_type: ``DatasetType`` value selecting the train or validation maps
        :param epoch: epoch number handed to the reader
        :param num_samples: value reported by ``__len__``
        :param num_workers: number of DataLoader workers; when 0 the dataset
            initializes itself in-process with worker id -1
        :param batch_size: batch size, used to derive the step number
        """
        self.format_type = format_type
        self.dataset_type = dataset_type
        self.epoch_number = epoch
        self.num_samples = num_samples
        self.reader = None
        self.num_images_read = 0
        self.batch_size = batch_size
        args = ConfigArguments.get_instance()
        # Serialized copy of the configuration, re-loaded inside worker_init().
        self.serial_args = pickle.dumps(args)
        self.logger = args.logger
        self.dlp_logger = None
        if num_workers == 0:
            # No worker processes will call worker_init(), so do it here.
            self.worker_init(-1)

    @dlp.log
    def worker_init(self, worker_id):
        """
        Per-worker setup: logging/tracing, the DLIO reader, and the DYAD context.

        :param worker_id: worker index (-1 when called in-process from ``__init__``)
        """
        # The return value is intentionally unused.
        # NOTE(review): presumably unpickling restores the ConfigArguments
        # singleton inside the worker process - confirm against ConfigArguments.
        pickle.loads(self.serial_args)
        _args = ConfigArguments.get_instance()
        _args.configure_dlio_logging(is_child=True)
        self.dlp_logger = _args.configure_dftracer(is_child=True, use_pid=True)
        self.logger.debug(f"{utcnow()} worker initialized {worker_id} with format {self.format_type}")
        self.reader = ReaderFactory.get_reader(type=self.format_type,
                                               dataset_type=self.dataset_type,
                                               thread_index=worker_id,
                                               epoch_number=self.epoch_number)
        self.dyad_io = Dyad()
        self.namespace = os.getenv("DYAD_KVS_NAMESPACE")
        #f = flux.Flux()
        self.my_node_index = 0 #f.get_rank()
        # The same directory is used as both producer and consumer path below.
        self.dyad_managed_directory = os.getenv("DYAD_PATH_PRODUCER")
        #self.dyad_managed_directory = os.getenv("DYAD_PATH_CONSUMER")
        mode = DTLMode.DYAD_DTL_MARGO
        self.dyad_io.init(debug=False, check=False, shared_storage=False, reinit=False,
                          async_publish=True, fsync_write=False, key_depth=3,
                          service_mux=1,
                          key_bins=1024, kvs_namespace=self.namespace,
                          prod_managed_path=self.dyad_managed_directory,
                          cons_managed_path=self.dyad_managed_directory,
                          dtl_mode=mode, dtl_comm_mode=DTLCommMode.DYAD_COMM_RECV)
        if self.dataset_type is DatasetType.TRAIN:
            self.global_index_map = _args.train_global_index_map
            self.file_map = _args.train_file_map
        else:
            self.file_map = _args.val_file_map
            self.global_index_map = _args.val_global_index_map

    def __del__(self):
        """Finalize the tracer created in ``worker_init``, if any."""
        # getattr: __init__ may have failed before dlp_logger was assigned.
        dlp_logger = getattr(self, "dlp_logger", None)
        if dlp_logger:
            dlp_logger.finalize()

    @dlp.log
    def __len__(self):
        """Return the number of samples given at construction time."""
        return self.num_samples

    @dlp.log
    def __getitem__(self, image_idx):
        """
        Return the sample at ``image_idx``.

        When a DYAD-managed directory is configured and DYAD has metadata for
        the sample's file, the sample is loaded through DYAD. Otherwise it is
        read with the DLIO reader and, if a managed directory is configured,
        written there under the key ``x``. The sample obtained on either path
        is returned directly; it is not read a second time.

        :param image_idx: index into ``global_index_map``
        :return: the sample data
        """
        self.num_images_read += 1
        step = int(math.ceil(self.num_images_read / self.batch_size))
        self.logger.debug(f"{utcnow()} Rank {DLIOMPI.get_instance().rank()} reading {image_idx} sample")
        filename, _ = self.global_index_map[image_idx]
        # Truthiness check: os.getenv() yields None when the variable is unset,
        # and os.path.join(None, ...) would raise TypeError.
        use_dyad = bool(self.dyad_managed_directory)
        file_obj = None
        base_fname = filename
        dlp.update(args={"fname":filename})
        dlp.update(args={"image_idx":image_idx})
        if use_dyad:
            self.logger.debug(f"{utcnow()} Rank {DLIOMPI.get_instance().rank()} reading metadata")
            base_fname = os.path.join(self.dyad_managed_directory, os.path.basename(filename))
            file_obj = self.dyad_io.get_metadata(fname=base_fname, should_wait=False, raw=True)
            self.logger.debug(f"Using managed directory {self.dyad_managed_directory} {base_fname} {file_obj}")
        if file_obj:
            access_mode = "remote"
            if self.my_node_index == file_obj.contents.owner_rank:
                access_mode = "local"
            dlp.update(args={"owner_rank":str(file_obj.contents.owner_rank)})
            dlp.update(args={"mode":"dyad"})
            dlp.update(args={"access":access_mode})
            self.logger.debug(f"Reading from managed directory {base_fname}")
            try:
                with dyad_open(base_fname, "rb", dyad_ctx=self.dyad_io, metadata_wrapper=file_obj) as f:
                    try:
                        data = np.load(f, allow_pickle=True)["x"]
                    except Exception as err:
                        # Best-effort fallback kept from the original code; the
                        # dataset has no `_args` attribute, so the configuration
                        # singleton is queried directly.
                        self.logger.debug(f"Failed to load {base_fname} from managed directory: {err}")
                        data = ConfigArguments.get_instance().resized_image
            finally:
                # Release the metadata even if opening or reading failed.
                self.dyad_io.free_metadata(file_obj)
        else:
            dlp.update(args={"mode":"pfs"})
            dlp.update(args={"access":"remote"})
            self.logger.debug(f"Reading from pfs {base_fname}")
            data = self.reader.read_index(image_idx, step)
            if use_dyad:
                self.logger.debug(f"Writing to managed_directory {base_fname}")
                with dyad_open(base_fname, "wb", dyad_ctx=self.dyad_io) as f:
                    np.savez(f, x=data)
            self.logger.debug(f"Read from pfs {base_fname}")

        dlp.update(step=step)
        dft_ai.update(step=step)
        # Return what was just obtained instead of calling read_index() again,
        # which would repeat the PFS read for every sample.
        return data
145
146
class dlio_sampler(Sampler):
    """
    Sampler that gives each rank one contiguous shard of the sample indices.

    The ``num_samples`` indices are split into ``size`` consecutive shards of
    ``ceil(num_samples / size)`` indices each; the last shards are truncated
    (possibly to nothing) when the split is uneven.
    """

    def __init__(self, rank, size, num_samples, epochs):
        """
        :param rank: index of this process, selects the shard
        :param size: total number of processes
        :param num_samples: total number of samples across all processes
        :param epochs: number of epochs (stored, not used by the sampler)
        """
        self.size = size
        self.rank = rank
        self.num_samples = num_samples
        self.epochs = epochs
        samples_per_proc = int(math.ceil(num_samples / size))
        start_sample = self.rank * samples_per_proc
        # Exclusive end, clamped so trailing ranks never index past the data.
        stop_sample = min(start_sample + samples_per_proc, num_samples)
        self.indices = list(range(start_sample, stop_sample))

    def __len__(self):
        """Return the number of indices this rank's iterator yields."""
        # The Sampler length must match what __iter__ produces; returning the
        # global num_samples over-reports by a factor of `size`.
        return len(self.indices)

    def __iter__(self):
        """Yield this rank's sample indices in ascending order."""
        yield from self.indices
167
168
class DyadTorchDataLoader(BaseDataLoader):
    """DLIO data loader that drives a Torch ``DataLoader`` over ``DYADTorchDataset``."""

    @dlp.log_init
    def __init__(self, format_type, dataset_type, epoch_number):
        """Initialize the base loader, tagging it as a PYTORCH data loader."""
        super().__init__(format_type, dataset_type, epoch_number, DataLoaderType.PYTORCH)

    @dlp.log
    def read(self):
        """
        Build the dataset, the per-rank sampler and the Torch ``DataLoader``,
        storing the latter in ``self._dataset``.
        """
        args = self._args
        workers = args.read_threads
        dataset = DYADTorchDataset(self.format_type, self.dataset_type, self.epoch_number,
                                   self.num_samples, workers, self.batch_size)
        sampler = dlio_sampler(args.my_rank, args.comm_size, self.num_samples, args.epochs)

        # Spread the configured prefetch size across the workers, if there are any.
        prefetch_factor = math.ceil(args.prefetch_size / workers) if workers >= 1 else args.prefetch_size
        is_root = args.my_rank == 0
        if prefetch_factor > 0:
            if is_root:
                self.logger.debug(
                    f"{utcnow()} Prefetch size is {self._args.prefetch_size}; prefetch factor of {prefetch_factor} will be set to Torch DataLoader.")
        else:
            prefetch_factor = 2  # 2 is the default value
            if is_root:
                self.logger.debug(
                    f"{utcnow()} Prefetch size is 0; a default prefetch factor of 2 will be set to Torch DataLoader.")
        self.logger.debug(f"{utcnow()} Setup dataloader with {self._args.read_threads} workers {torch.__version__}")

        # Worker-related options are only passed when worker processes exist;
        # torch 1.3.1 gets neither prefetch_factor nor persistent_workers.
        legacy_torch = torch.__version__ == '1.3.1'
        loader_kwargs = {}
        if workers != 0:
            loader_kwargs['multiprocessing_context'] = args.multiprocessing_context
            if not legacy_torch:
                loader_kwargs['prefetch_factor'] = prefetch_factor
                loader_kwargs['persistent_workers'] = True

        self._dataset = DataLoader(dataset,
                                   batch_size=self.batch_size,
                                   sampler=sampler,
                                   num_workers=workers,
                                   pin_memory=args.pin_memory,
                                   drop_last=True,
                                   worker_init_fn=dataset.worker_init,
                                   **loader_kwargs)
        self.logger.debug(f"{utcnow()} Rank {self._args.my_rank} will read {len(self._dataset) * self.batch_size} files")

        # self._dataset.sampler.set_epoch(epoch_number)

    @dlp.log
    def next(self):
        """
        Yield the batches of one epoch, tagging the profilers with the 1-based
        step number before each batch, then advance the epoch counter.
        """
        super().next()
        if self.dataset_type is DatasetType.TRAIN:
            total = self._args.training_steps
        else:
            total = self._args.eval_steps
        self.logger.debug(f"{utcnow()} Rank {self._args.my_rank} should read {total} batches")
        for step, batch in enumerate(dft_ai.dataloader.fetch.iter(self._dataset), start=1):
            dlp.update(step=step)
            dft_ai.update(step=step)
            yield batch
        self.epoch_number += 1
        dlp.update(epoch=self.epoch_number)
        dft_ai.update(epoch=self.epoch_number)

    @dlp.log
    def finalize(self):
        """Nothing to release."""
        pass
The lines relevant to DYAD are:
Line 29, which imports PyDYAD
Lines 71-84, which initialize DYAD
The lines in __getitem__()
Note: the dlp. lines are for the profiler and are not directly relevant to DYAD.