[NewFeature] Add Minigpt4 Train Code #119
base: develop
    "CosineDecayWithWarmup",
    "FilterParamsName",
]
This part now lives in optimization; only create a separate xxx_optimizer file when the optimizer is heavily customized.
I modified the existing CosineDecayWithWarmup in optimization.py and removed it from clip_optimizer.py.
# limitations under the License.

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "7"
Please remove the debug code.
Removed.
# if you wanna train from scratch, you can set del_keys = ["language_projection.weight", "language_projection.bias"]
del_keys = []
logger.info("Try to load the specified model.")
load_pretrained_model(model, training_args.pretrained_model_path, del_keys=del_keys)
This can use the unified model.load_pretrained interface; see the blip2 implementation for reference.
The trainer also needs to be unified.
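For readers unfamiliar with the del_keys argument in the snippet above, here is a minimal sketch of the idea, assuming del_keys means "drop these entries from the checkpoint before loading so the matching layers keep their fresh initialization". The helper name drop_keys and the toy checkpoint are hypothetical; this is not the repository's load_pretrained_model or model.load_pretrained implementation.

```python
def drop_keys(state_dict, del_keys):
    """Return a copy of state_dict without the entries named in del_keys.

    Hypothetical helper for illustration only. Keys listed in del_keys
    that are absent from state_dict are silently ignored.
    """
    del_set = set(del_keys)
    return {name: value for name, value in state_dict.items() if name not in del_set}


# Toy checkpoint; a real one would map parameter names to tensors.
checkpoint = {
    "language_projection.weight": [0.1, 0.2],
    "language_projection.bias": [0.0],
    "qformer.query_tokens": [1.0],
}

# Fine-tuning: keep everything.
kept_all = drop_keys(checkpoint, [])

# Training the projection from scratch: drop its weights before loading.
from_scratch = drop_keys(
    checkpoint, ["language_projection.weight", "language_projection.bias"]
)
```

With an empty del_keys, as in the PR's default, nothing is filtered and the whole checkpoint is loaded.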
@@ -40,41 +40,43 @@ class CosineDecayWithWarmup(LRScheduler):
    def __init__(
        self,
        learning_rate,
        epochs,
        total_steps,
blip2 needs to be adapted to this change.
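To make the effect of a step-based signature concrete, here is a minimal sketch of a cosine decay schedule with linear warmup driven by total_steps. It is a plain function under stated assumptions, not the PR's CosineDecayWithWarmup class; the parameters warmup_steps, warmup_start_lr, and eta_min are hypothetical names chosen for illustration.

```python
import math


def cosine_decay_with_warmup(
    step, total_steps, warmup_steps, base_lr, warmup_start_lr=0.0, eta_min=0.0
):
    """Learning rate at a given global step (illustrative sketch only).

    Linear warmup from warmup_start_lr to base_lr over warmup_steps,
    then cosine decay from base_lr down to eta_min at total_steps.
    """
    if step < warmup_steps:
        # Linear ramp; max(1, ...) guards against warmup_steps == 0.
        fraction = step / max(1, warmup_steps)
        return warmup_start_lr + (base_lr - warmup_start_lr) * fraction
    # Fraction of the decay phase completed, clamped to [0, 1].
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    progress = min(progress, 1.0)
    return eta_min + 0.5 * (base_lr - eta_min) * (1.0 + math.cos(math.pi * progress))


# Example: 10 warmup steps, 110 steps in total, peak learning rate 1e-4.
lr_start = cosine_decay_with_warmup(0, 110, 10, 1e-4)
lr_peak = cosine_decay_with_warmup(10, 110, 10, 1e-4)
lr_mid = cosine_decay_with_warmup(60, 110, 10, 1e-4)
lr_end = cosine_decay_with_warmup(110, 110, 10, 1e-4)
```

A caller that previously passed epochs would need to convert to a step count first, for example epochs multiplied by the number of steps per epoch, which is presumably the kind of adaptation the comment asks of blip2.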
@@ -150,20 +188,14 @@ def process_texts(
                assemble_texts.append(prompt.replace(self.text_tag, text))
            else:
                assemble_texts.append(text)

        # breakpoint()
Please remove the debug code.
Removed.
No description provided.