RoPE implementation differences #335

Closed · rlrs opened this issue May 15, 2024 · 7 comments

rlrs commented May 15, 2024

I've been working with the pretrained Llama 3 weights and found that the RoPE implementation here does not match the one used elsewhere. The difference is whether consecutive pairs of embedding dimensions are treated as (real, imaginary), or the first half of the head dimension is treated as real and the second half as imaginary.

The current torchtitan implementation uses the former, while both Transformers and llama.cpp, for example, use the latter.
This also seems to mean that loading weights from https://huggingface.co/meta-llama/Meta-Llama-3-8B does not work. I've verified numerically that you need the latter RoPE implementation to get correct results with the existing weights. I'm slightly worried that I'm doing something wrong, but perhaps someone else can verify? I can post some code if that helps.

Here's a small change to apply_rotary_emb which can be used to make it match the cos/sin implementation numerically.

def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Apply rotary embeddings to input tensors using the given frequency tensor.

    This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided
    frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor
    is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are
    returned as real tensors.

    Args:
        xq (torch.Tensor): Query tensor to apply rotary embeddings.
        xk (torch.Tensor): Key tensor to apply rotary embeddings.
        freqs_cis (torch.Tensor): Precomputed frequency tensor for complex exponentials.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
    """
    # first half is real, second half is imaginary
    xq_ = torch.complex(xq[..., :xq.shape[-1] // 2].float(), xq[..., xq.shape[-1] // 2:].float())
    xk_ = torch.complex(xk[..., :xk.shape[-1] // 2].float(), xk[..., xk.shape[-1] // 2:].float())
    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    
    # added: de-interleave (re0, im0, re1, im1, ...) -> (re0, re1, ..., im0, im1, ...)
    # so the output matches the half-split layout expected by the loaded weights
    xq_out = torch.cat([xq_out[..., ::2], xq_out[..., 1::2]], dim=-1)
    xk_out = torch.cat([xk_out[..., ::2], xk_out[..., 1::2]], dim=-1)
    
    return xq_out.type_as(xq), xk_out.type_as(xk)
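
For anyone who wants a quick sanity check: here's a minimal, self-contained sketch (toy shapes, not torchtitan code) showing that the complex/interleaved formulation, once its output is de-interleaved as above, numerically matches the half-split cos/sin (rotate_half) formulation:

import torch

# Toy shapes for illustration: batch=1, seq=4, heads=1, head_dim=8
dim, seq = 8, 4
torch.manual_seed(0)
x = torch.randn(1, seq, 1, dim)

# RoPE angles: one angle per dimension pair, theta_i = 10000^(-2i/dim)
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, dim, 2).float() / dim))
angles = torch.outer(torch.arange(seq).float(), inv_freq)   # (seq, dim/2)
freqs_cis = torch.polar(torch.ones_like(angles), angles)    # complex, (seq, dim/2)

# Half-split cos/sin formulation (rotate_half), as used by HF transformers / llama.cpp
def rotate_half(t):
    t1, t2 = t.chunk(2, dim=-1)
    return torch.cat((-t2, t1), dim=-1)

cos = torch.cat([angles.cos(), angles.cos()], dim=-1)[None, :, None, :]
sin = torch.cat([angles.sin(), angles.sin()], dim=-1)[None, :, None, :]
ref = x * cos + rotate_half(x) * sin

# Complex formulation on a half-split input, then de-interleave the output
x_ = torch.complex(x[..., :dim // 2], x[..., dim // 2:])
out = torch.view_as_real(x_ * freqs_cis[None, :, None, :]).flatten(3)
out = torch.cat([out[..., ::2], out[..., 1::2]], dim=-1)

print(torch.allclose(out, ref, atol=1e-6))  # True

If the de-interleaving lines are removed, the allclose check should fail, which is exactly the mismatch described above.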
TJ-Solergibert commented

Hi @rlrs ! Could you share the script to transform the weights from HF to dcp? Thanks!

tianyu-l (Contributor) commented

Hi @rlrs, thanks for bringing up the concern!

We are using the same definition as in llama3 code https://github.com/meta-llama/llama3/blob/main/llama/model.py#L65
Would you provide more details on how you verified the loaded weights to be wrong / correct?

rlrs (Author) commented May 16, 2024

> Hi @rlrs ! Could you share the script to transform the weights from HF to dcp? Thanks!

I'm using a modified script based on gpt-fast; I'll paste it here.

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import json
import re
import sys
from pathlib import Path
from safetensors import safe_open
import torch.distributed.checkpoint as DCP

import torch

# support running without installing as a package
wd = Path(__file__).parent.parent.resolve()
sys.path.append(str(wd))

from maester.models import models_config


@torch.inference_mode()
def convert_hf_checkpoint(
    *,
    model_name: str,
    variant: str,
    checkpoint_dir: Path,
    output_dir: Path,
) -> None:
    if model_name is None:
        model_name = checkpoint_dir.name

    config = models_config[model_name][variant]
    print(f"Model config {config.__dict__}")

    # Load the json file containing weight mapping
    model_map_json = checkpoint_dir / "model.safetensors.index.json"

    assert model_map_json.is_file()

    with open(model_map_json) as json_map:
        bin_index = json.load(json_map)

    # map HF parameter names to the torchtitan/llama parameter names
    weight_map = {
        "model.embed_tokens.weight": "tok_embeddings.weight",
        "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight",
        "model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight",
        "model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight",
        "model.layers.{}.self_attn.o_proj.weight": "layers.{}.attention.wo.weight",
        'model.layers.{}.self_attn.rotary_emb.inv_freq': None,
        'model.layers.{}.mlp.gate_proj.weight': 'layers.{}.feed_forward.w1.weight',
        "model.layers.{}.mlp.up_proj.weight": "layers.{}.feed_forward.w3.weight",
        "model.layers.{}.mlp.down_proj.weight": "layers.{}.feed_forward.w2.weight",
        "model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight",
        "model.layers.{}.post_attention_layernorm.weight": "layers.{}.ffn_norm.weight",
        "model.norm.weight": "norm.weight",
        "lm_head.weight": "output.weight",
    }
    bin_files = {checkpoint_dir / bin for bin in bin_index["weight_map"].values()}

    # merge all safetensors shards into a single state dict on CPU
    merged_result = {}
    for file in sorted(bin_files):
        with safe_open(file, framework="pt", device="cpu") as f:
            for k in f.keys():
                merged_result[k] = f.get_tensor(k)
    # rename keys to the target convention, expanding per-layer templates
    final_result = {}

    for key, value in merged_result.items():
        if "layers" in key:
            abstract_key = re.sub(r'(\d+)', '{}', key)
            layer_num = re.search(r'\d+', key).group(0)
            new_key = weight_map[abstract_key]
            if new_key is None:
                continue
            new_key = new_key.format(layer_num)
        else:
            new_key = weight_map[key]

        final_result[new_key] = value

    output_dir.mkdir(parents=True, exist_ok=True)
    storage_writer = DCP.filesystem.FileSystemWriter(output_dir)
    DCP.save({"model": final_result}, 
             storage_writer=storage_writer)

if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser(description='Convert HuggingFace checkpoint.')
    parser.add_argument('--checkpoint', type=Path, required=True)
    parser.add_argument('--output', type=Path, required=True)
    parser.add_argument('--model', type=str, required=True)
    parser.add_argument('--variant', type=str, required=True)

    args = parser.parse_args()
    convert_hf_checkpoint(
        checkpoint_dir=args.checkpoint,
        output_dir=args.output,
        model_name=args.model,
        variant=args.variant,
    )

> Hi @rlrs, thanks for bringing up the concern!
>
> We are using the same definition as in llama3 code https://github.com/meta-llama/llama3/blob/main/llama/model.py#L65 Would you provide more details on how you verified the loaded weights to be wrong / correct?

After the conversion, I verified layer by layer that all loaded weights match HF transformers. I also verified that layer inputs/outputs match HF transformers (the only difference is in RoPE), and I manually checked that inference outputs match as well. Here's a snippet that compares attention layers after setting up and loading weights for both models:

    # Compare attention layers (assumes model and hf_model are built and their weights loaded)
    input_tensor = torch.randn(cfg.batch_size, cfg.seq_len, model_config.dim)
    freqs_cis = model.freqs_cis[0:cfg.seq_len]
    
    for i, (layer, hf_layer) in enumerate(zip(model.layers, hf_model.model.layers)):
        attention_output = layer.attention(layer.attention_norm(input_tensor), freqs_cis)
        hf_attention_output, _, _ = hf_layer.self_attn(
            hf_layer.input_layernorm(input_tensor), 
            position_ids=torch.arange(cfg.seq_len, dtype=torch.long).unsqueeze(0).expand(cfg.batch_size, -1)
        )

        assert torch.allclose(attention_output, hf_attention_output, atol=1e-5), f"Attention layer {i} outputs do not match"

tianyu-l (Contributor) commented May 16, 2024

@rlrs
It seems HF's llama implementation is different from the official llama's. We'll need to understand why that's the case.

HF: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L184
meta-llama: https://github.com/meta-llama/llama3/blob/main/llama/model.py#L65

asked here: huggingface/transformers#30872

rlrs (Author) commented May 17, 2024

Thanks for asking over there. I haven't tried downloading the weights from anywhere other than HF, but I would be a bit surprised if there were a simple transformation you could apply to the weights to switch RoPE implementations.
Anyway, let's await some information from their side.

rlrs (Author) commented May 20, 2024

As discussed in the HF issue, there is indeed a permutation of the weights that causes the two implementations to be equivalent. I don't believe anything needs to be done in the torchtitan repo, and if you agree, feel free to close this issue.
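
For reference, here is a sketch of what that permutation looks like for the wq/wk weights. My assumption is that it mirrors the permute helper in transformers' convert_llama_weights_to_hf.py; the head count and sizes below are just illustrative:

import torch

# Sketch of the weight permutation discussed in the HF issue. Assumption: it
# mirrors the `permute` helper in transformers' convert_llama_weights_to_hf.py.
# For GQA models, use the number of heads of that projection (KV heads for wk).

def permute(w, n_heads, dim1, dim2):
    # interleaved (meta-llama / torchtitan RoPE) -> half-split (HF transformers RoPE)
    return w.view(n_heads, dim1 // n_heads // 2, 2, dim2).transpose(1, 2).reshape(dim1, dim2)

def unpermute(w, n_heads, dim1, dim2):
    # half-split (HF transformers RoPE) -> interleaved (meta-llama / torchtitan RoPE)
    return w.view(n_heads, 2, dim1 // n_heads // 2, dim2).transpose(1, 2).reshape(dim1, dim2)

# round-trip sanity check on a toy wq-shaped tensor (32 heads, head_dim 128, hidden 4096)
w = torch.randn(32 * 128, 4096)
assert torch.equal(unpermute(permute(w, 32, 32 * 128, 4096), 32, 32 * 128, 4096), w)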

tianyu-l (Contributor) commented

@rlrs

> This also seems to mean that loading weights from https://huggingface.co/meta-llama/Meta-Llama-3-8B does not work.

If I understand correctly, there are two ways you can download weights from HF. The first is from the original folder, which gives the same weights as those downloaded from the Meta Llama website; the second is through the HF API, e.g. transformers.pipeline, which presumably does the conversion.

I think torchtitan should at least have code & a tutorial for loading the original weights. For the second, HF should support converting HF transformers weights back to llama weights.
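
For the first way, something like this should work (a sketch using huggingface_hub; the repo id and the original/ folder name are assumptions based on how the Meta-Llama-3 repos are laid out on the Hub):

from huggingface_hub import snapshot_download

# Sketch: fetch only the `original/` folder (consolidated meta-llama format)
# rather than the HF-transformers-format safetensors. The repo is gated, so
# this assumes you are logged in and have accepted the license.
snapshot_download(
    repo_id="meta-llama/Meta-Llama-3-8B",
    allow_patterns=["original/*"],
    local_dir="Meta-Llama-3-8B",
)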
