
Question about the accuracy of the grad_norm computation when using the distributed optimizer #56

Open
chivychao opened this issue Dec 27, 2023 · 1 comment

Comments

@chivychao

# Scale grad buffers by '1 / data_parallel_world_size'.
for model in self.models:
    for dtype, gbuf in model._grad_buffers.items():
        gbuf.data /= data_parallel_world_size

# Reduce-scatter all grads.
gbuf_view_items = self.get_model_grad_buffer_dp_views()
for index, (model_index, dtype, gbuf, gbuf_views) \
        in enumerate(gbuf_view_items):
    torch.distributed._reduce_scatter_base(
        gbuf_views[data_parallel_rank],
        gbuf,
        group=data_parallel_group,
    )
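
For context, here is a minimal standalone sketch of the reduce-scatter semantics under discussion (a toy illustration, not the Megatron-LM code; the helper name and the even-sharding assumption are mine):

import torch
import torch.distributed as dist

def reduce_scatter_grad_buffer(gbuf: torch.Tensor, group=None) -> torch.Tensor:
    """Toy illustration: average a flat grad buffer across the DP group and
    return only the shard this rank maintains (hypothetical helper)."""
    world_size = dist.get_world_size(group=group)
    rank = dist.get_rank(group=group)
    # Pre-scaling by 1/world_size turns the summed reduce-scatter into an
    # average, mirroring the snippet above.
    gbuf /= world_size
    shard = gbuf.numel() // world_size  # assumes numel divides evenly
    views = [gbuf[i * shard:(i + 1) * shard] for i in range(world_size)]
    # reduce_scatter_tensor is the public successor of _reduce_scatter_base.
    dist.reduce_scatter_tensor(views[rank], gbuf, op=dist.ReduceOp.SUM, group=group)
    # After this call, views[rank] holds the DP-averaged grads for this rank's
    # slice only; the rest of gbuf still contains just the local scaled grads.
    return views[rank]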

If I understand correctly, this reduce-scatter leaves each member of the DP group with the summed gradients for only the slice of parameters it maintains, right?

But if so, isn't the grad_norm computed later in optimizer.step() inaccurate?

From what I can see, when grad_norm is computed, each member of the DP group squares and sums all the gradients of all the params in its own portion of the model, yet only part of each member's grads has actually been summed across the DP group, so the resulting grad_norm looks wrong to me.

Does this problem actually exist?

@li-yi-dong
Collaborator

li-yi-dong commented Dec 29, 2023

torch.distributed.all_reduce(total_norm,

You can take a look at this code.
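
For reference, a minimal sketch of why that all_reduce makes the norm come out right (my own illustration, assuming each DP rank holds the fully-reduced grads only for the disjoint parameter shard it owns):

import torch
import torch.distributed as dist

def sharded_global_grad_norm(shard_grads, dp_group=None) -> torch.Tensor:
    """Hypothetical helper: global L2 grad norm when each DP rank owns only a
    disjoint shard of the already-reduced gradients."""
    device = shard_grads[0].device if shard_grads else 'cpu'
    total_norm_sq = torch.zeros(1, device=device)
    # Each rank squares and sums only the gradients of the shard it maintains.
    for g in shard_grads:
        total_norm_sq += g.detach().float().pow(2).sum()
    # The shards are disjoint and together cover every parameter, so summing
    # the per-rank partial results across the DP group yields the true squared
    # norm over all parameters.
    dist.all_reduce(total_norm_sq, op=dist.ReduceOp.SUM, group=dp_group)
    return total_norm_sq.sqrt()

So the per-rank sum being partial is expected; the all_reduce of total_norm is what stitches the shards back into the global norm.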
