
Commit

fix #1742
hiyouga committed Dec 16, 2023
1 parent 7ae6919 commit 870426f
Showing 3 changed files with 22 additions and 17 deletions.
9 changes: 8 additions & 1 deletion src/llmtuner/model/loader.py
@@ -81,9 +81,16 @@ def load_model_and_tokenizer(
     if add_valuehead:
         model: "AutoModelForCausalLMWithValueHead" = AutoModelForCausalLMWithValueHead.from_pretrained(model)
         patcher.patch_valuehead_model(model)
-        vhead_params = load_valuehead_params(model_args)
+
+        if model_args.adapter_name_or_path is not None:
+            vhead_path = model_args.adapter_name_or_path[-1]
+        else:
+            vhead_path = model_args.model_name_or_path
+
+        vhead_params = load_valuehead_params(vhead_path, model_args)
         if vhead_params is not None:
             model.load_state_dict(vhead_params, strict=False)
+            logger.info("Loaded valuehead from checkpoint: {}".format(vhead_path))
 
     if not is_trainable:
         model.requires_grad_(False) # fix all model params
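
Note: the load_state_dict(..., strict=False) call above works because the value-head checkpoint is a partial state dict containing only the v_head tensors. A minimal sketch of that pattern, with a stub module and illustrative shapes (not the project's actual classes):

import torch
import torch.nn as nn

class ValueHeadStub(nn.Module):
    # stand-in for a causal LM with a value head; only the parameter names matter here
    def __init__(self, hidden_size=8):
        super().__init__()
        self.backbone = nn.Linear(hidden_size, hidden_size)  # placeholder for the language model
        self.v_head = nn.ModuleDict({"summary": nn.Linear(hidden_size, 1)})

model = ValueHeadStub()
vhead_params = {
    "v_head.summary.weight": torch.zeros(1, 8),
    "v_head.summary.bias": torch.zeros(1),
}
# strict=False tolerates the missing backbone keys, so only the two
# value-head tensors are overwritten; everything else keeps its current values.
missing, unexpected = model.load_state_dict(vhead_params, strict=False)
print(missing)     # ['backbone.weight', 'backbone.bias']
print(unexpected)  # []
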
21 changes: 7 additions & 14 deletions src/llmtuner/model/utils.py
@@ -85,34 +85,21 @@ def get_modelcard_args(
     }
 
 
-def load_valuehead_params(model_args: "ModelArguments") -> Dict[str, torch.Tensor]:
+def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") -> Dict[str, torch.Tensor]:
     r"""
     Loads value head parameters from Hugging Face Hub or local disk.
     Returns: dict with keys `v_head.summary.weight` and `v_head.summary.bias`.
     """
-    if model_args.adapter_name_or_path is not None:
-        path_or_repo_id = model_args.adapter_name_or_path[-1]
-    else:
-        path_or_repo_id = model_args.model_name_or_path
-
     kwargs = {
         "path_or_repo_id": path_or_repo_id,
         "cache_dir": model_args.cache_dir,
         "token": model_args.hf_hub_token
     }
 
-    try:
-        vhead_file = cached_file(filename=WEIGHTS_NAME, **kwargs)
-        logger.info("Loaded valuehead from {}".format(path_or_repo_id))
-        return torch.load(vhead_file, map_location="cpu")
-    except Exception as err:
-        logger.info("Failed to load {}: {}".format(WEIGHTS_NAME, str(err)))
-
     try:
         from safetensors import safe_open
         vhead_file = cached_file(filename=SAFE_WEIGHTS_NAME, **kwargs)
-        logger.info("Loaded valuehead from {}".format(path_or_repo_id))
         with safe_open(vhead_file, framework="pt", device="cpu") as f:
             return {
                 "v_head.summary.weight": f.get_tensor("v_head.summary.weight"),

@@ -121,6 +108,12 @@ def load_valuehead_params(model_args: "ModelArguments") -> Dict[str, torch.Tensor]:
     except Exception as err:
         logger.info("Failed to load {}: {}".format(SAFE_WEIGHTS_NAME, str(err)))
 
+    try:
+        vhead_file = cached_file(filename=WEIGHTS_NAME, **kwargs)
+        return torch.load(vhead_file, map_location="cpu")
+    except Exception as err:
+        logger.info("Failed to load {}: {}".format(WEIGHTS_NAME, str(err)))
+
     logger.warning("Provided path ({}) does not contain valuehead weights.".format(path_or_repo_id))
     return None
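
Note: with this change the function takes the checkpoint path explicitly and prefers the safetensors checkpoint, falling back to the legacy PyTorch file. A minimal local-disk sketch of that order (the file names follow transformers' SAFE_WEIGHTS_NAME / WEIGHTS_NAME conventions; the helper below is illustrative, not part of the repository):

import os
import torch
from safetensors import safe_open

def read_valuehead(checkpoint_dir):
    safe_path = os.path.join(checkpoint_dir, "model.safetensors")  # SAFE_WEIGHTS_NAME
    bin_path = os.path.join(checkpoint_dir, "pytorch_model.bin")   # WEIGHTS_NAME
    if os.path.isfile(safe_path):
        # preferred path: read only the two value-head tensors
        with safe_open(safe_path, framework="pt", device="cpu") as f:
            return {
                "v_head.summary.weight": f.get_tensor("v_head.summary.weight"),
                "v_head.summary.bias": f.get_tensor("v_head.summary.bias"),
            }
    if os.path.isfile(bin_path):
        # legacy fallback: return the whole .bin state dict as-is
        return torch.load(bin_path, map_location="cpu")
    return None
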

9 changes: 7 additions & 2 deletions src/llmtuner/train/ppo/trainer.py
@@ -206,6 +206,11 @@ def get_inputs(self, batch: Dict[str, torch.Tensor]) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
         if self.finetuning_args.upcast_layernorm:
             layernorm_params = dump_layernorm(self.model)
 
+        if batch["input_ids"].size(0) == 1: # handle llama2 ppo with gradient accumulation > 1
+            start_index = (batch["input_ids"][0] != self.tokenizer.pad_token_id).nonzero()[0].item()
+            for k, v in batch.items():
+                batch[k] = v[:, start_index:]
+
         unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
         generate_output: torch.Tensor = unwrapped_model.generate(
             generation_config=self.generation_config,
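
Note: a minimal sketch of the left-padding removal added above, assuming pad_token_id is 0 and a batch holding a single left-padded query (the size(0) == 1 branch):

import torch

pad_token_id = 0
batch = {"input_ids": torch.tensor([[0, 0, 0, 5, 6, 7]]),       # one left-padded query
         "attention_mask": torch.tensor([[0, 0, 0, 1, 1, 1]])}

# index of the first non-pad token in the only row
start_index = (batch["input_ids"][0] != pad_token_id).nonzero()[0].item()
for k, v in batch.items():
    batch[k] = v[:, start_index:]

print(start_index)          # -> 3
print(batch["input_ids"])   # -> tensor([[5, 6, 7]])
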
@@ -220,15 +225,15 @@ def get_inputs(self, batch: Dict[str, torch.Tensor]) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
         response = generate_output[:, batch["input_ids"].size(-1):].detach().cpu()
         queries, responses = [], []
         for i in range(len(query)):
-            query_length = (query[i] != self.tokenizer.pad_token_id).nonzero()[0].item()
+            query_start_index = (query[i] != self.tokenizer.pad_token_id).nonzero()[0].item()
             response_index = (response[i] != self.tokenizer.pad_token_id).nonzero()
 
             if len(response_index) == 0:
                 response_length = 1 # allow empty response
             else:
                 response_length = response_index[-1].item() + 1
 
-            queries.append(query[i, query_length:]) # remove padding from left
+            queries.append(query[i, query_start_index:]) # remove padding from left
             responses.append(response[i, :response_length]) # remove padding from right
 
         return queries, responses
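
Note: a matching sketch of the per-sample trimming in the loop above; the first non-pad index strips the query's left padding and the last non-pad index bounds the response's right padding (pad_token_id = 0 assumed for illustration):

import torch

pad_token_id = 0
query = torch.tensor([0, 0, 11, 12, 13])      # left-padded prompt
response = torch.tensor([9, 8, 7, 0, 0])      # right-padded generation

query_start_index = (query != pad_token_id).nonzero()[0].item()
response_index = (response != pad_token_id).nonzero()
response_length = 1 if len(response_index) == 0 else response_index[-1].item() + 1

print(query[query_start_index:])    # -> tensor([11, 12, 13])
print(response[:response_length])   # -> tensor([9, 8, 7])
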
