I am new to Horovod and I am trying to implement multi-GPU training on a single machine. I can run the example "pytorch_mnist.py" from the PyTorch examples folder, and I followed what that file does to implement the code for my own task. However, my code needs some help: judging from the output, only one local rank actually gets data from its dataloader, while the others do not. What did I miss?
my code is here:
# imports used by the snippets below
import sys
import time
import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable
import horovod.torch as hvd

rnd = 42
# Horovod: initialize library.
hvd.init()
torch.manual_seed(rnd) # attention
print("hvd initial done")
# Horovod: pin GPU to local rank.
torch.cuda.set_device(hvd.local_rank())
torch.cuda.manual_seed(rnd)
print("hvd pin GPU done")
# Horovod: limit # of CPU threads to be used per worker.
torch.set_num_threads(args.num_workers)
print("hvd limit num of CPU threads done")
all_vals = []
data1, data2 = load_fn(args)
i = 0
train_dataset, val_dataset, max_num_nodes, input_dim, assign_input_dim, train_sampler = \
    prepare_val_data(data1, data2)
print("data prepare finished")
# must use multiple GPUs
device = hvd.local_rank()
model = my_model.to(device)
print("model building finished")
lr_scaler = hvd.local_size()
optimizer = torch.optim.Adam(filter(lambda p : p.requires_grad, model.parameters()), lr=0.001*lr_scaler)
# Horovod: broadcast parameters & optimizer state.
hvd.broadcast_parameters(model.state_dict(), root_rank=0)
hvd.broadcast_optimizer_state(optimizer, root_rank=0)
print("hvd broadcast set done")
# Horovod: (optional) compression algorithm.
compression = hvd.Compression.fp16
# Horovod: wrap optimizer with DistributedOptimizer.
optimizer = hvd.DistributedOptimizer(optimizer,
                                     named_parameters=model.named_parameters(),
                                     op=hvd.Average,
                                     compression=compression)
print("hvd wrap optimizer done")
_, val_accs = model_train(train_dataset, model, args, device, train_sampler, optimizer, val_dataset=val_dataset)
print(sys._getframe(1).f_lineno)
# hvd.barrier()
all_vals.append(np.array(val_accs))
all_vals = np.vstack(all_vals)
all_vals = np.mean(all_vals, axis=0)
print(all_vals)
print(np.max(all_vals))
print(np.argmax(all_vals))
print(sys._getframe(1).f_lineno)
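For context, prepare_val_data is meant to build the sampler and loader the same way pytorch_mnist.py does. A minimal sketch of that pattern, with illustrative names rather than my exact code (in my script, the object returned as train_dataset and iterated in train_phase plays the role of train_loader here):

# Sketch of the Horovod sampler/loader pattern from pytorch_mnist.py (illustrative names)
train_sampler = torch.utils.data.distributed.DistributedSampler(
    train_dataset, num_replicas=hvd.size(), rank=hvd.rank())
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=args.batch_size, sampler=train_sampler)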
The function "model_train" calls two helpers, train_phase and eval_phase. The following is the body of "train_phase":
iter = 0
avg_loss = 0.0
model.train()
batch_idx = 0
begin_time = time.time()
# Horovod: set epoch to sampler for shuffling.
sampler.set_epoch(epoch)
print("hvd sampler for shuffling done")
for data in train_dataset:
    print("iteration")
    x_t1, l_t1, x_t2, l_t2 = data
    if iter % 200 == 0:
        print('got data done')
    optimizer.zero_grad()
    x_t1 = Variable(x_t1.long()).to(device)
    x_t2 = Variable(x_t2.long()).to(device)
    y_t1, y_t2 = model(x_t1, x_t2)
    loss1 = model.loss1(y_t1, l_t1.float())
    loss2 = model.loss2(y_t2, l_t2.float())
    loss = loss1 + loss2
    loss.backward()
    if iter % 200 == 0:
        print("got local loss and backward done")
    # added following line based on the doc of hvd.DistributedOptimizer()
    optimizer.synchronize()
    nn.utils.clip_grad_norm_(model.parameters(), args.clip)
    with optimizer.skip_synchronize():
        optimizer.step()
    iter += 1
    avg_loss += loss
    if iter % 200 == 0:
        print('Iter: ', iter, ', loss: ', loss)
    batch_idx += 1
print("local training phase done")
avg_loss /= batch_idx
elapsed = time.time() - begin_time
tot_avg_loss = metric_average(avg_loss, 'train_loss')
print('Epoch: ', epoch, '; Avg loss: ', tot_avg_loss, '; epoch time: ', elapsed)
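For reference, metric_average follows the helper from pytorch_mnist.py; a rough sketch of what mine looks like:

# metric_average roughly as in the Horovod pytorch_mnist.py example
def metric_average(val, name):
    tensor = torch.tensor(val)
    avg_tensor = hvd.allreduce(tensor, name=name)
    return avg_tensor.item()

I assume this hvd.allreduce call is where the allreduce.train_loss tensor in the warnings below comes from.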
The following is the output from stdout and stderr:
[0] hvd initial done
[0] hvd pin GPU done
[0] hvd limit num of CPU threads done
[0] data prepare finished
[0] model building finished
[0] hvd broadcast set done
[0] hvd wrap optimizer done
[0] hvd sampler for shuffling done
[1] hvd initial done
[1] hvd pin GPU done
[1] hvd limit num of CPU threads done
[1] data prepare finished
[1] model building finished
[1] hvd broadcast set done
[1] hvd wrap optimizer done
[1] hvd sampler for shuffling done
[2] hvd initial done
[2] hvd pin GPU done
[2] hvd limit num of CPU threads done
[2] data prepare finished
[2] model building finished
[2] hvd broadcast set done
[2] hvd wrap optimizer done
[2] hvd sampler for shuffling done
[3] hvd initial done
[3] hvd pin GPU done
[3] hvd limit num of CPU threads done
[3] data prepare finished
[3] model building finished
[3] hvd broadcast set done
[3] hvd wrap optimizer done
[3] hvd sampler for shuffling done
[1] printing in __getitem__
...
[1] printing in __getitem__
[0] [2023-05-30 06:50:26.125711: W /tmp/pip-install-i_hsfg4g/horovod_c04a28a1943946a99c6b8d5051636d7b/horovod/common/stall_inspector.cc:107] One or more tensors were submitted to be reduced, gathered or broadcasted by subset of ranks and are waiting for remainder of ranks for more than 60 seconds. This may indicate that different ranks are trying to submit different tensors or that only subset of ranks is submitting tensors, which will cause deadlock.
[0] Missing ranks:
[0] 0: [allreduce.train_loss]
[0] 2: [allreduce.train_loss]
[0] 3: [allreduce.train_loss]
[0] [2023-05-30 06:51:26.126851: W /tmp/pip-install-i_hsfg4g/horovod_c04a28a1943946a99c6b8d5051636d7b/horovod/common/stall_inspector.cc:107] One or more tensors were submitted to be reduced, gathered or broadcasted by subset of ranks and are waiting for remainder of ranks for more than 60 seconds. This may indicate that different ranks are trying to submit different tensors or that only subset of ranks is submitting tensors, which will cause deadlock.
[0] Missing ranks:
[0] 0: [allreduce.train_loss]
[0] 1: [allreduce.assign_conv_block_modules.0.0.bias, allreduce.assign_conv_block_modules.0.0.weight, allreduce.assign_conv_first_modules.0.bias, allreduce.assign_conv_first_modules.0.weight, allreduce.assign_conv_last_modules.0.bias, allreduce.assign_conv_last_modules.0.weight ...]
[0] 2: [allreduce.train_loss]
[0] 3: [allreduce.train_loss]
[0] [2023-05-30 06:52:26.127902: W /tmp/pip-install-i_hsfg4g/horovod_c04a28a1943946a99c6b8d5051636d7b/horovod/common/stall_inspector.cc:107] One or more tensors were submitted to be reduced, gathered or broadcasted by subset of ranks and are waiting for remainder of ranks for more than 60 seconds. This may indicate that different ranks are trying to submit different tensors or that only subset of ranks is submitting tensors, which will cause deadlock.
[0] Missing ranks:
[0] 0: [allreduce.train_loss]
[0] 1: [allreduce.assign_conv_block_modules.0.0.bias, allreduce.assign_conv_block_modules.0.0.weight, allreduce.assign_conv_first_modules.0.bias, allreduce.assign_conv_first_modules.0.weight, allreduce.assign_conv_last_modules.0.bias, allreduce.assign_conv_last_modules.0.weight ...]
[0] 2: [allreduce.train_loss]
[0] 3: [allreduce.train_loss]
[0] [2023-05-30 06:53:26.128115: W /tmp/pip-install-i_hsfg4g/horovod_c04a28a1943946a99c6b8d5051636d7b/horovod/common/stall_inspector.cc:107] One or more tensors were submitted to be reduced, gathered or broadcasted by subset of ranks and are waiting for remainder of ranks for more than 60 seconds. This may indicate that different ranks are trying to submit different tensors or that only subset of ranks is submitting tensors, which will cause deadlock.
[0] Missing ranks:
[0] 0: [allreduce.train_loss]
[0] 1: [allreduce.assign_conv_block_modules.0.0.bias, allreduce.assign_conv_block_modules.0.0.weight, allreduce.assign_conv_first_modules.0.bias, allreduce.assign_conv_first_modules.0.weight, allreduce.assign_conv_last_modules.0.bias, allreduce.assign_conv_last_modules.0.weight ...]
[0] 2: [allreduce.train_loss]
[0] 3: [allreduce.train_loss]
[0] [2023-05-30 06:54:26.128618: W /tmp/pip-install-i_hsfg4g/horovod_c04a28a1943946a99c6b8d5051636d7b/horovod/common/stall_inspector.cc:107] One or more tensors were submitted to be reduced, gathered or broadcasted by subset of ranks and are waiting for remainder of ranks for more than 60 seconds. This may indicate that different ranks are trying to submit different tensors or that only subset of ranks is submitting tensors, which will cause deadlock.
[0] Missing ranks:
[0] 0: [allreduce.train_loss]
[0] 1: [allreduce.assign_conv_block_modules.0.0.bias, allreduce.assign_conv_block_modules.0.0.weight, allreduce.assign_conv_first_modules.0.bias, allreduce.assign_conv_first_modules.0.weight, allreduce.assign_conv_last_modules.0.bias, allreduce.assign_conv_last_modules.0.weight ...]
[0] 2: [allreduce.train_loss]
[0] 3: [allreduce.train_loss]
[0] [2023-05-30 06:55:26.129367: W /tmp/pip-install-i_hsfg4g/horovod_c04a28a1943946a99c6b8d5051636d7b/horovod/common/stall_inspector.cc:107] One or more tensors were submitted to be reduced, gathered or broadcasted by subset of ranks and are waiting for remainder of ranks for more than 60 seconds. This may indicate that different ranks are trying to submit different tensors or that only subset of ranks is submitting tensors, which will cause deadlock.
[0] Missing ranks:
[0] 0: [allreduce.train_loss]
[0] 1: [allreduce.assign_conv_block_modules.0.0.bias, allreduce.assign_conv_block_modules.0.0.weight, allreduce.assign_conv_first_modules.0.bias, allreduce.assign_conv_first_modules.0.weight, allreduce.assign_conv_last_modules.0.bias, allreduce.assign_conv_last_modules.0.weight ...]
[0] 2: [allreduce.train_loss]
[0] 3: [allreduce.train_loss]
[0] [2023-05-30 06:56:26.130076: W /tmp/pip-install-i_hsfg4g/horovod_c04a28a1943946a99c6b8d5051636d7b/horovod/common/stall_inspector.cc:107] One or more tensors were submitted to be reduced, gathered or broadcasted by subset of ranks and are waiting for remainder of ranks for more than 60 seconds. This may indicate that different ranks are trying to submit different tensors or that only subset of ranks is submitting tensors, which will cause deadlock.
[0] Missing ranks:
[0] 0: [allreduce.train_loss]
[0] 1: [allreduce.assign_conv_block_modules.0.0.bias, allreduce.assign_conv_block_modules.0.0.weight, allreduce.assign_conv_first_modules.0.bias, allreduce.assign_conv_first_modules.0.weight, allreduce.assign_conv_last_modules.0.bias, allreduce.assign_conv_last_modules.0.weight ...]
[0] 2: [allreduce.train_loss]
[0] 3: [allreduce.train_loss]
[0] [2023-05-30 06:57:26.131194: W /tmp/pip-install-i_hsfg4g/horovod_c04a28a1943946a99c6b8d5051636d7b/horovod/common/stall_inspector.cc:107] One or more tensors were submitted to be reduced, gathered or broadcasted by subset of ranks and are waiting for remainder of ranks for more than 60 seconds. This may indicate that different ranks are trying to submit different tensors or that only subset of ranks is submitting tensors, which will cause deadlock.
[0] Missing ranks:
[0] 0: [allreduce.train_loss]
[0] 1: [allreduce.assign_conv_block_modules.0.0.bias, allreduce.assign_conv_block_modules.0.0.weight, allreduce.assign_conv_first_modules.0.bias, allreduce.assign_conv_first_modules.0.weight, allreduce.assign_conv_last_modules.0.bias, allreduce.assign_conv_last_modules.0.weight ...]
[0] 2: [allreduce.train_loss]
[0] 3: [allreduce.train_loss]
[0] [2023-05-30 06:58:26.132300: W /tmp/pip-install-i_hsfg4g/horovod_c04a28a1943946a99c6b8d5051636d7b/horovod/common/stall_inspector.cc:107] One or more tensors were submitted to be reduced, gathered or broadcasted by subset of ranks and are waiting for remainder of ranks for more than 60 seconds. This may indicate that different ranks are trying to submit different tensors or that only subset of ranks is submitting tensors, which will cause deadlock.
[0] Missing ranks:
[0] 0: [allreduce.train_loss]
[0] 1: [allreduce.assign_conv_block_modules.0.0.bias, allreduce.assign_conv_block_modules.0.0.weight, allreduce.assign_conv_first_modules.0.bias, allreduce.assign_conv_first_modules.0.weight, allreduce.assign_conv_last_modules.0.bias, allreduce.assign_conv_last_modules.0.weight ...]
[0] 2: [allreduce.train_loss]
[0] 3: [allreduce.train_loss]
Another odd issue is that "hvd initial done" sometimes does not get printed at all.