ReFL implementation details #57

Open
hkunzhe opened this issue Sep 11, 2023 · 2 comments

Comments


hkunzhe commented Sep 11, 2023

As mentioned in #24 and #34, in the current ReFL code only the ReFL loss is implemented; the pre-training loss is not included. In addition, the two losses are optimized alternately.

I want to add pre-training data myself. If we don't use gradient accumulation, the pseudo code would look like this:

# Given an optimizer and lr_scheduler built for unet.
# Within a single training step, the two losses are optimized alternately.

# 1) Compute the pre-training loss `train_loss` with unet and update unet.
train_loss.backward()
optimizer.step()
lr_scheduler.step()  # is it necessary to step the scheduler after this first update?
optimizer.zero_grad()

# 2) Compute the ReFL loss `refl_loss` with unet and update unet.
refl_loss.backward()
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()

However, after reading this post, I'm still confused about how to wrap the two updates with accelerator.accumulate(unet) for gradient accumulation. I also raised issue huggingface/accelerate#1870 and a discussion in the Hugging Face accelerate GitHub repo and forum, but I didn't get a clear answer. Could you give me some pseudo code or hints? Thank you very much! @xujz18 @tongyx361
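
For reference, a rough sketch of the structure I have in mind is below. compute_pretrain_loss, compute_refl_loss, and the paired dataloaders are placeholders of mine, and I'm assuming (perhaps wrongly) that Accelerate skips optimizer.step() and optimizer.zero_grad() on non-sync steps, so the gradients of the two losses would effectively be summed during accumulation, which is exactly the part I'm unsure about:

from accelerate import Accelerator

accelerator = Accelerator(gradient_accumulation_steps=4)  # example value

# All of these objects are assumed to exist already; pairing the pre-training
# dataloader and the ReFL prompt dataloader is only for illustration.
unet, optimizer, train_dataloader, refl_dataloader, lr_scheduler = accelerator.prepare(
    unet, optimizer, train_dataloader, refl_dataloader, lr_scheduler
)

for pretrain_batch, refl_batch in zip(train_dataloader, refl_dataloader):
    with accelerator.accumulate(unet):
        # 1) Pre-training (denoising) loss; compute_pretrain_loss is a hypothetical helper.
        train_loss = compute_pretrain_loss(unet, pretrain_batch)
        accelerator.backward(train_loss)
        optimizer.step()       # skipped by Accelerate on accumulation (non-sync) steps
        lr_scheduler.step()
        optimizer.zero_grad()  # also skipped on non-sync steps, so gradients keep accumulating

        # 2) ReFL (reward) loss; compute_refl_loss is a hypothetical helper.
        refl_loss = compute_refl_loss(unet, refl_batch)
        accelerator.backward(refl_loss)
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()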


hkunzhe commented Sep 11, 2023

This may duplicate #34 (comment). I was afraid it wouldn't be seen in a closed issue, so I opened this new one.


xujz18 commented Sep 17, 2023

Your understanding is correct and I appreciate the discussion.
