Reward model loss does not decrease when training with baichuan-13b #175

Open
zhangzuizui opened this issue Sep 5, 2023 · 0 comments

Comments

@zhangzuizui

baichuan-7b trains normally with the same setup; the environment is CUDA 11.7 + torch 2.0.1.

The code is as follows:

from typing import Optional

import torch
from torch import nn
from transformers.utils import ModelOutput

# Baichuan13BPreTrainedModel and Baichuan13BModel are assumed to come from the
# Baichuan-13B custom modeling code (not shown in this issue).
class Baichuan13BForRewardModel(Baichuan13BPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.model = Baichuan13BModel(config)
        self.classifier = nn.Linear(config.hidden_size, 1)

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs
    ):
        outputs = self.model(input_ids, attention_mask=attention_mask)
        # (batch_size, seq_len, hidden_dim)
        last_hidden_state = outputs.last_hidden_state

        # score every position; padding tokens are excluded below when pooling
        logits = self.classifier(last_hidden_state)

        batch_size = input_ids.shape[0]
        # index of the last non-padding token in each (right-padded) sequence
        sequence_lengths = (
            torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1
        ).to(logits.device)

        pooled_logits = logits[
            torch.arange(batch_size, device=logits.device), sequence_lengths
        ]
        return ModelOutput(logits=pooled_logits)
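
For reference, here is a minimal sketch (with a made-up pad_token_id of 0 and right padding, not taken from the original issue) of which position the last-non-padding-token indexing above selects:

import torch

pad_token_id = 0  # illustrative value only; use the model config's actual pad_token_id
input_ids = torch.tensor([
    [5, 6, 7, 0, 0],   # 3 real tokens -> last real index 2
    [5, 6, 7, 8, 9],   # 5 real tokens -> last real index 4
])
sequence_lengths = torch.ne(input_ids, pad_token_id).sum(-1) - 1
print(sequence_lengths)  # tensor([2, 4])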

# Loss computation: pairwise ranking loss, where text1 is the preferred response
from transformers import AutoModelForSequenceClassification

def get_rank_loss(
    model: AutoModelForSequenceClassification,
    text1_input_ids: torch.LongTensor,
    text2_input_ids: torch.LongTensor,
    text1_attention_mask: Optional[torch.Tensor] = None,
    text2_attention_mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
    logits1 = model(
        text1_input_ids, attention_mask=text1_attention_mask
    ).logits
    logits2 = model(
        text2_input_ids, attention_mask=text2_attention_mask
    ).logits
    loss = torch.nn.functional.logsigmoid(logits1 - logits2).mean()
    return -loss

During training, the loss just oscillates around 0.69 (when the two logits are equal, their difference is 0, and logsigmoid(0) = -0.69).
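
As a quick numeric check (a standalone sketch, not part of the original training code), a loss stuck at ~0.69 is exactly what you get when the two pooled logits are always nearly identical, since -logsigmoid(0) = ln(2):

import math
import torch

# hypothetical rewards: the model scores both texts identically
logits1 = torch.tensor([0.37])
logits2 = torch.tensor([0.37])

loss = -torch.nn.functional.logsigmoid(logits1 - logits2).mean()
print(loss.item())   # 0.6931...
print(math.log(2))   # 0.6931..., the observed plateau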

Is there anything else I should be paying particular attention to?
