This repository has been archived by the owner on Jun 3, 2022. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 0
/
dlrm_main.py
450 lines (402 loc) · 15.8 KB
/
dlrm_main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import itertools
import os
import sys
from dataclasses import dataclass, field
from typing import cast, Iterator, List, Optional, Tuple
import torch
import torchmetrics as metrics
from pyre_extensions import none_throws
from torch import nn, distributed as dist
from torch.utils.data import DataLoader
from torchrec import EmbeddingBagCollection
from torchrec.datasets.criteo import DEFAULT_CAT_NAMES, DEFAULT_INT_NAMES
from torchrec.datasets.utils import Batch
from torchrec.distributed import TrainPipelineSparseDist
from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
from torchrec.distributed.model_parallel import DistributedModelParallel
from torchrec.distributed.types import ModuleSharder
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.optim.keyed import CombinedOptimizer, KeyedOptimizerWrapper
from tqdm import tqdm
# OSS import
try:
# pyre-ignore[21]
# @manual=//torchrec/github/examples/dlrm/data:dlrm_dataloader
from data.dlrm_dataloader import get_dataloader, STAGES
# pyre-ignore[21]
# @manual=//torchrec/github/examples/dlrm/modules:dlrm_train
from modules.dlrm_train import DLRMTrain
except ImportError:
pass
# internal import
try:
from .data.dlrm_dataloader import ( # noqa F811
get_dataloader,
STAGES,
)
from .modules.dlrm_train import DLRMTrain # noqa F811
except ImportError:
pass
# Number of stages in TrainPipelineSparseDist. The train/eval helpers in this file
# assume the pipeline holds TRAIN_PIPELINE_STAGES - 1 batches in its internal buffer
# between phases, and adjust their batch limits accordingly.
TRAIN_PIPELINE_STAGES = 3
def parse_args(argv: List[str]) -> argparse.Namespace:
    """
    Parse command line arguments for the DLRM example trainer.

    Args:
        argv (List[str]): command line arguments (without the program name).

    Returns:
        argparse.Namespace: parsed arguments. ``pin_memory`` is ``None`` unless
        ``--pin_memory`` is passed, because of the ``set_defaults`` call below.
    """

    def _str_to_bool(value: str) -> bool:
        # argparse's ``type=bool`` calls ``bool(value)``, which is True for every
        # non-empty string -- including "False". Parse the text explicitly instead.
        lowered = value.strip().lower()
        if lowered in ("true", "t", "yes", "y", "1"):
            return True
        if lowered in ("false", "f", "no", "n", "0"):
            return False
        raise argparse.ArgumentTypeError(
            f"expected a boolean value (true/false), got {value!r}"
        )

    parser = argparse.ArgumentParser(description="torchrec dlrm example trainer")
    parser.add_argument(
        "--epochs", type=int, default=1, help="number of epochs to train"
    )
    parser.add_argument(
        "--batch_size", type=int, default=32, help="batch size to use for training"
    )
    parser.add_argument(
        "--limit_train_batches",
        type=int,
        default=None,
        help="number of train batches",
    )
    parser.add_argument(
        "--limit_val_batches",
        type=int,
        default=None,
        help="number of validation batches",
    )
    parser.add_argument(
        "--limit_test_batches",
        type=int,
        default=None,
        help="number of test batches",
    )
    parser.add_argument(
        "--dataset_name",
        type=str,
        default="criteo_1t",
        # The help text previously advertised "criteo_1tb", which does not match
        # the default value actually used ("criteo_1t").
        help="dataset for experiment, current support criteo_1t, criteo_kaggle",
    )
    parser.add_argument(
        "--num_workers",
        type=int,
        default=2,
        help="number of dataloader workers",
    )
    parser.add_argument(
        "--num_embeddings",
        type=int,
        default=100_000,
        help="max_ind_size. The number of embeddings in each embedding table. Defaults"
        " to 100_000 if num_embeddings_per_feature is not supplied.",
    )
    parser.add_argument(
        "--num_embeddings_per_feature",
        type=str,
        default=None,
        help="Comma separated max_ind_size per sparse feature. The number of embeddings"
        " in each embedding table. 26 values are expected for the Criteo dataset.",
    )
    parser.add_argument(
        "--dense_arch_layer_sizes",
        type=str,
        default="512,256,64",
        help="Comma separated layer sizes for dense arch.",
    )
    parser.add_argument(
        "--over_arch_layer_sizes",
        type=str,
        default="512,512,256,1",
        help="Comma separated layer sizes for over arch.",
    )
    parser.add_argument(
        "--embedding_dim",
        type=int,
        default=64,
        help="Size of each embedding.",
    )
    parser.add_argument(
        "--undersampling_rate",
        type=float,
        help="Desired proportion of zero-labeled samples to retain (i.e. undersampling zero-labeled rows)."
        " Ex. 0.3 indicates only 30pct of the rows with label 0 will be kept."
        " All rows with label 1 will be kept. Value should be between 0 and 1."
        " When not supplied, no undersampling occurs.",
    )
    # NOTE(review): seeds are usually integers; type=float is kept for compatibility
    # with existing invocations -- confirm how the dataloader consumes it.
    parser.add_argument(
        "--seed",
        type=float,
        help="Random seed for reproducibility.",
    )
    parser.add_argument(
        "--pin_memory",
        dest="pin_memory",
        action="store_true",
        help="Use pinned memory when loading data.",
    )
    parser.add_argument(
        "--in_memory_binary_criteo_path",
        type=str,
        default=None,
        help="Path to a folder containing the binary (npy) files for the Criteo dataset."
        " When supplied, InMemoryBinaryCriteoIterDataPipe is used.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=15.0,
        help="Learning rate.",
    )
    parser.add_argument(
        "--shuffle_batches",
        type=_str_to_bool,
        default=False,
        help="Shuffle each batch during training.",
    )
    # Leave pin_memory as None (rather than store_true's implicit False) when the
    # flag is absent.
    parser.set_defaults(pin_memory=None)
    return parser.parse_args(argv)
def _evaluate(
    args: argparse.Namespace,
    train_pipeline: TrainPipelineSparseDist,
    iterator: Iterator[Batch],
    next_iterator: Iterator[Batch],
    stage: str,
) -> Tuple[float, float]:
    """
    Evaluate model. Computes and prints metrics including AUROC and Accuracy. Helper
    function for train_val_test.

    Args:
        args (argparse.Namespace): parsed command line args.
        train_pipeline (TrainPipelineSparseDist): pipelined model.
        iterator (Iterator[Batch]): Iterator used for val/test batches.
        next_iterator (Iterator[Batch]): Iterator used for the next phase (either train
            if there are more epochs to train on or test if all epochs are complete).
            Used to queue up the next TRAIN_PIPELINE_STAGES - 1 batches before
            train_val_test switches to the next phase. This is done so that when the
            next phase starts, the first output train_pipeline generates an output for
            is the 1st batch for that phase.
        stage (str): "val" or "test".

    Returns:
        Tuple[float, float]: ``(auroc, accuracy)`` over the evaluated batches, in
        that order.
    """
    model = train_pipeline._model
    model.eval()
    device = train_pipeline._device

    limit_batches = (
        args.limit_val_batches if stage == "val" else args.limit_test_batches
    )
    if limit_batches is not None:
        # TRAIN_PIPELINE_STAGES - 1 batches of this stage were already pulled into the
        # pipeline's buffer by the previous phase, so only the remainder is read here.
        # Clamp at 0: a limit smaller than the buffer depth would otherwise yield a
        # negative islice stop, which raises ValueError.
        limit_batches = max(0, limit_batches - (TRAIN_PIPELINE_STAGES - 1))

    # Because TrainPipelineSparseDist buffer batches internally, we load in
    # TRAIN_PIPELINE_STAGES - 1 batches from the next_iterator into the buffers so that
    # when train_val_test switches to the next phase, train_pipeline will start
    # producing results for the TRAIN_PIPELINE_STAGES - 1 buffered batches (as opposed
    # to the last TRAIN_PIPELINE_STAGES - 1 batches from iterator).
    combined_iterator = itertools.chain(
        iterator
        if limit_batches is None
        else itertools.islice(iterator, limit_batches),
        itertools.islice(next_iterator, TRAIN_PIPELINE_STAGES - 1),
    )

    auroc = metrics.AUROC(compute_on_step=False).to(device)
    accuracy = metrics.Accuracy(compute_on_step=False).to(device)

    # Infinite iterator instead of while-loop to leverage tqdm progress bar.
    for _ in tqdm(iter(int, 1), desc=f"Evaluating {stage} set"):
        try:
            _loss, logits, labels = train_pipeline.progress(combined_iterator)
            auroc(logits, labels)
            accuracy(logits, labels)
        except StopIteration:
            break

    auroc_result = auroc.compute().item()
    accuracy_result = accuracy.compute().item()
    if dist.get_rank() == 0:
        print(f"AUROC over {stage} set: {auroc_result}.")
        print(f"Accuracy over {stage} set: {accuracy_result}.")
    return auroc_result, accuracy_result
def _train(
    args: argparse.Namespace,
    train_pipeline: TrainPipelineSparseDist,
    iterator: Iterator[Batch],
    next_iterator: Iterator[Batch],
    epoch: int,
) -> None:
    """
    Train model for 1 epoch. Helper function for train_val_test.

    Args:
        args (argparse.Namespace): parsed command line args.
        train_pipeline (TrainPipelineSparseDist): pipelined model.
        iterator (Iterator[Batch]): Iterator used for training batches.
        next_iterator (Iterator[Batch]): Iterator used for validation batches. Used to
            queue up the next TRAIN_PIPELINE_STAGES - 1 batches before train_val_test
            switches to validation mode. This is done so that when validation starts,
            the first output train_pipeline generates an output for is the 1st
            validation batch (as opposed to a buffered train batch).
        epoch (int): Which epoch the model is being trained on.

    Returns:
        None.
    """
    train_pipeline._model.train()

    limit_batches = args.limit_train_batches
    # For the first epoch, train_pipeline has no buffered batches, but for all other
    # epochs, train_pipeline will have TRAIN_PIPELINE_STAGES - 1 from iterator already
    # present in its buffer. Clamp at 0 so a small limit cannot produce a negative
    # islice stop (which raises ValueError).
    if limit_batches is not None and epoch > 0:
        limit_batches = max(0, limit_batches - (TRAIN_PIPELINE_STAGES - 1))

    # Because TrainPipelineSparseDist buffer batches internally, we load in
    # TRAIN_PIPELINE_STAGES - 1 batches from the next_iterator into the buffers so that
    # when train_val_test switches to the next phase, train_pipeline will start
    # producing results for the TRAIN_PIPELINE_STAGES - 1 buffered batches (as opposed
    # to the last TRAIN_PIPELINE_STAGES - 1 batches from iterator).
    combined_iterator = itertools.chain(
        iterator
        if limit_batches is None
        else itertools.islice(iterator, limit_batches),
        itertools.islice(next_iterator, TRAIN_PIPELINE_STAGES - 1),
    )

    # Infinite iterator instead of while-loop to leverage tqdm progress bar.
    for _ in tqdm(iter(int, 1), desc=f"Epoch {epoch}"):
        try:
            train_pipeline.progress(combined_iterator)
        except StopIteration:
            break
@dataclass
class TrainValTestResults:
    """
    Metrics collected by train_val_test.

    The validation lists receive one entry per training epoch; the test fields stay
    ``None`` until the test set has been evaluated.
    """

    # Validation accuracy, one entry per epoch.
    val_accuracies: List[float] = field(default_factory=lambda: [])
    # Validation AUROC, one entry per epoch.
    val_aurocs: List[float] = field(default_factory=lambda: [])
    # Test-set accuracy; None until computed.
    test_accuracy: Optional[float] = None
    # Test-set AUROC; None until computed.
    test_auroc: Optional[float] = None
def train_val_test(
    args: argparse.Namespace,
    train_pipeline: TrainPipelineSparseDist,
    train_dataloader: DataLoader,
    val_dataloader: DataLoader,
    test_dataloader: DataLoader,
) -> TrainValTestResults:
    """
    Train/validation/test loop. Contains customized logic to ensure each dataloader's
    batches are used for the correct designated purpose (train, val, test). This logic
    is necessary because TrainPipelineSparseDist buffers batches internally (so we
    avoid batches designated for one purpose like training getting buffered and used for
    another purpose like validation).

    Args:
        args (argparse.Namespace): parsed command line args.
        train_pipeline (TrainPipelineSparseDist): pipelined model.
        train_dataloader (DataLoader): DataLoader used for training.
        val_dataloader (DataLoader): DataLoader used for validation.
        test_dataloader (DataLoader): DataLoader used for testing.

    Returns:
        TrainValTestResults: per-epoch validation metrics and final test metrics.
    """
    train_val_test_results = TrainValTestResults()

    train_iterator = iter(train_dataloader)
    test_iterator = iter(test_dataloader)
    for epoch in range(args.epochs):
        val_iterator = iter(val_dataloader)
        _train(args, train_pipeline, train_iterator, val_iterator, epoch)
        # A fresh train iterator for the next epoch. Validation pre-loads its first
        # batches into the pipeline (or the test iterator's, after the last epoch).
        train_iterator = iter(train_dataloader)
        val_next_iterator = (
            test_iterator if epoch == args.epochs - 1 else train_iterator
        )
        # _evaluate returns (auroc, accuracy): unpack in that order so the metrics
        # land in the correct result fields.
        val_auroc, val_accuracy = _evaluate(
            args, train_pipeline, val_iterator, val_next_iterator, "val"
        )
        train_val_test_results.val_accuracies.append(val_accuracy)
        train_val_test_results.val_aurocs.append(val_auroc)

    test_auroc, test_accuracy = _evaluate(
        args, train_pipeline, test_iterator, iter(test_dataloader), "test"
    )
    train_val_test_results.test_accuracy = test_accuracy
    train_val_test_results.test_auroc = test_auroc

    return train_val_test_results
def main(argv: List[str]) -> None:
    """
    Trains, validates, and tests a Deep Learning Recommendation Model (DLRM)
    (https://arxiv.org/abs/1906.00091). The DLRM model contains both data parallel
    components (e.g. multi-layer perceptrons & interaction arch) and model parallel
    components (e.g. embedding tables). The DLRM model is pipelined so that dataloading,
    data-parallel to model-parallel comms, and forward/backward are overlapped. Can be
    run with either a random dataloader or an in-memory Criteo 1 TB click logs dataset
    (https://ailab.criteo.com/download-criteo-1tb-click-logs-dataset/).

    Args:
        argv (List[str]): command line args.

    Returns:
        None.

    Raises:
        KeyError: if the LOCAL_RANK environment variable is not set.
        ValueError: if --num_embeddings_per_feature does not supply exactly one
            value per categorical feature.
    """
    args = parse_args(argv)

    rank = int(os.environ["LOCAL_RANK"])
    device: torch.device
    if torch.cuda.is_available():
        device = torch.device(f"cuda:{rank}")
        backend = "nccl"
        torch.cuda.set_device(device)
    else:
        device = torch.device("cpu")
        backend = "gloo"

    if not dist.is_initialized():
        dist.init_process_group(backend=backend)

    if args.num_embeddings_per_feature is not None:
        args.num_embeddings_per_feature = list(
            map(int, args.num_embeddings_per_feature.split(","))
        )
        # One table size is indexed per categorical feature below; fail early with a
        # clear message instead of an IndexError deep inside the config list.
        if len(args.num_embeddings_per_feature) != len(DEFAULT_CAT_NAMES):
            raise ValueError(
                f"--num_embeddings_per_feature expects {len(DEFAULT_CAT_NAMES)} "
                f"comma separated values, got {len(args.num_embeddings_per_feature)}"
            )
        args.num_embeddings = None

    # TODO add CriteoIterDataPipe support and add random_dataloader arg
    train_dataloader = get_dataloader(args, backend, "train")
    val_dataloader = get_dataloader(args, backend, "val")
    test_dataloader = get_dataloader(args, backend, "test")

    # Sets default limits for random dataloader iterations when left unspecified.
    if args.in_memory_binary_criteo_path is None:
        for stage in STAGES:
            attr = f"limit_{stage}_batches"
            if getattr(args, attr) is None:
                setattr(args, attr, 10)

    eb_configs = [
        EmbeddingBagConfig(
            name=f"t_{feature_name}",
            embedding_dim=args.embedding_dim,
            num_embeddings=none_throws(args.num_embeddings_per_feature)[feature_idx]
            if args.num_embeddings is None
            else args.num_embeddings,
            feature_names=[feature_name],
        )
        for feature_idx, feature_name in enumerate(DEFAULT_CAT_NAMES)
    ]

    # Embedding tables are created on the meta device; DistributedModelParallel
    # places/shards them, while the dense part lives on `device`.
    train_model = DLRMTrain(
        embedding_bag_collection=EmbeddingBagCollection(
            tables=eb_configs, device=torch.device("meta")
        ),
        dense_in_features=len(DEFAULT_INT_NAMES),
        dense_arch_layer_sizes=list(map(int, args.dense_arch_layer_sizes.split(","))),
        over_arch_layer_sizes=list(map(int, args.over_arch_layer_sizes.split(","))),
        dense_device=device,
    )

    fused_params = {
        "learning_rate": args.learning_rate,
    }
    sharders = [
        EmbeddingBagCollectionSharder(fused_params=fused_params),
    ]
    model = DistributedModelParallel(
        module=train_model,
        device=device,
        sharders=cast(List[ModuleSharder[nn.Module]], sharders),
    )

    # Sharded embeddings use the fused optimizer; remaining parameters use SGD.
    dense_optimizer = KeyedOptimizerWrapper(
        dict(model.named_parameters()),
        lambda params: torch.optim.SGD(params, lr=args.learning_rate),
    )
    optimizer = CombinedOptimizer([model.fused_optimizer, dense_optimizer])

    train_pipeline = TrainPipelineSparseDist(
        model,
        optimizer,
        device,
    )
    train_val_test(
        args, train_pipeline, train_dataloader, val_dataloader, test_dataloader
    )
if __name__ == "__main__":
    # Strip the program name; main/parse_args receive only the user-supplied args.
    main(argv=sys.argv[1:])