
Convert Orbax ckpt to HuggingFace #581

Open · wants to merge 1 commit into main

Conversation

@A9isha (Collaborator) commented Apr 9, 2024

No description provided.

@rwitten (Collaborator) left a comment:

Change LGTMs, but without a test I'm skeptical. I don't think this needs to be tested exhaustively, but how could we test it somewhat?

@rwitten removed their assignment Apr 11, 2024
@A9isha (Collaborator, Author) commented Apr 11, 2024

> Change LGTMs, but without a test I'm skeptical. I don't think this needs to be tested exhaustively, but how could we test it somewhat?

@rwitten I have tested locally, but were you thinking of running this at a nightly cadence?

@thiagolaitz commented May 6, 2024

@A9isha Hello, do you by any chance have a script that does the opposite, converting HF to Orbax?

@A9isha (Collaborator, Author) commented May 6, 2024

> @A9isha Hello, do you by any chance have a script that does the opposite, converting HF to Orbax?

We have the script llama_or_mistral_ckpt.py to convert the original PyTorch Llama2 checkpoint that Meta provides into a MaxText checkpoint.

You can see the usage for Llama2-7b here, for example.
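
For reference, a typical invocation looks roughly like the following; the flag names follow the MaxText getting-started docs (verify them against the script's argument parser), and the paths are placeholders:

python3 MaxText/llama_or_mistral_ckpt.py \
    --base-model-path <path-to-meta-llama2-7b-weights> \
    --maxtext-model-path gs://<your-bucket>/llama2-7b \
    --model-size llama2-7b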

@hxssgaa commented May 10, 2024

Hi @A9isha, I found two bugs in your conversion code. I have fixed them and validated the weights converted from the MaxText version of Llama3-8b against the HF ones.

The first bug is that the unpermute function is wrong: the original MaxText ckpt script used a step size of 2, so the correct inverse is to split the last dimension in half, stack the even and odd halves, and reshape:

import jax

def unpermute_from_match_maxtext_rope(arr):
  """Restore the RoPE ordering by interleaving the two halves of the last dim."""
  split_size = arr.shape[-1] // 2  # first half holds the evens, second half the odds
  evens = arr[..., :split_size]
  odds = arr[..., split_size:]
  # Stack along a new trailing axis, then flatten it back to interleave even/odd.
  return jax.numpy.stack([evens, odds], axis=len(arr.shape)).reshape(arr.shape)
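
As a quick sanity check of the interleaving, here is a toy example with a hypothetical head_dim of 8:

import jax.numpy as jnp

x = jnp.arange(8.0)  # toy vector: first half "evens", second half "odds"
print(unpermute_from_match_maxtext_rope(x))
# [0. 4. 1. 5. 2. 6. 3. 7.]  -- the two halves interleaved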

The second bug is related to Q and K. I understand it's easy to make mistakes here because the original LLaMA, LLaMA-HF, and MaxText all store these tensors differently. The correct way is to first reverse to the original LLaMA weights and then convert to the HF layout:

# Query projection: undo MaxText's query scaling, restore the original RoPE
# ordering, then re-permute into the HF layout.
q_key = f"model.layers.{layer_int}.self_attn.q_proj.weight"
q_kernel = training_state.params["params"]["decoder"]["layers"]["self_attention"]["query"]["kernel"][:, layer_int, :, :]
hf_model_params[q_key] = torch.tensor(
    np.asarray(unpermute_from_match_maxtext_rope(reverse_scale(q_kernel, head_dim))),
    dtype=torch.bfloat16,
)
hf_model_params[q_key] = (
    hf_model_params[q_key]
    .view(base_num_query_heads * head_dim, base_num_query_heads * head_dim).T
    .view(base_num_query_heads, head_dim // 2, 2, base_num_query_heads * head_dim)
    .transpose(1, 2)
    .reshape(-1, base_num_query_heads * head_dim)
)

# Key projection: same idea, but keys are not scaled and (with GQA) there are
# base_num_kv_heads heads instead of base_num_query_heads.
k_key = f"model.layers.{layer_int}.self_attn.k_proj.weight"
k_kernel = training_state.params["params"]["decoder"]["layers"]["self_attention"]["key"]["kernel"][:, layer_int, :, :]
hf_model_params[k_key] = torch.tensor(
    np.asarray(unpermute_from_match_maxtext_rope(k_kernel)),
    dtype=torch.bfloat16,
)
hf_model_params[k_key] = (
    hf_model_params[k_key]
    .view(base_num_query_heads * head_dim, base_num_kv_heads * head_dim).T
    .reshape(base_num_kv_heads, head_dim // 2, 2, base_num_query_heads * head_dim)
    .transpose(1, 2)
    .reshape(-1, base_num_query_heads * head_dim)
)
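
Note that reverse_scale is not defined in the snippet above. A minimal sketch, assuming the MaxText checkpoint stores the query kernel pre-scaled by head_dim ** -0.5 (verify this against llama_or_mistral_ckpt.py):

def reverse_scale(arr, head_dim):
    # Assumption: MaxText stores query weights divided by sqrt(head_dim),
    # so multiply the factor back out before exporting to HF.
    return arr * (head_dim ** 0.5)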
