Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[NewFeature] Add Minigpt4 Train Code #119

Open
wants to merge 5 commits into
base: develop
Choose a base branch
from

Conversation

1649759610
Copy link
Collaborator

No description provided.

"CosineDecayWithWarmup",
"FilterParamsName",
]

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这部分现在放在optimization
里,如果是定制化较强的opt,再新建xxx_optimizer文件

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

在optimization.py修改了之前的CosineDecayWithWarmup,在clip_optimizer.py中,删掉了

# limitations under the License.

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "7"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

删除debug代码

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已删除

# if you wanna train from scratch, you can set del_keys = ["language_projection.weight", "language_projection.bias"]
del_keys = []
logger.info("Try to load the specified model.")
load_pretrained_model(model, training_args.pretrained_model_path, del_keys=del_keys)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里可以统一用model.load_pretrained接口,参考blip2实现
另外需要统一trainer

@@ -40,41 +40,43 @@ class CosineDecayWithWarmup(LRScheduler):
def __init__(
self,
learning_rate,
epochs,
total_steps,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

blip2需要适配

@@ -150,20 +188,14 @@ def process_texts(
assemble_texts.append(prompt.replace(self.text_tag, text))
else:
assemble_texts.append(text)

# breakpoint()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

删除debug代码

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已删除

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants