
[Bug]: weight has no main_grad when amp_master_grad is enabled together with recompute #8365

Open

Wong4j opened this issue May 6, 2024 · 5 comments

Labels: bug Something isn't working

Wong4j (Contributor) commented May 6, 2024

Software environment

- paddlepaddle: 
- paddlepaddle-gpu: 2.6
- paddlenlp: 2.7.1.post0

Duplicate check

  • I have searched the existing issues

Bug description

Normally, with --amp_master_grad enabled, every weight has a main_grad.
But with recompute=full, the weight retrieved in the backward of a custom Python op has no main_grad.
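
(For context, a minimal sketch of how --amp_master_grad attaches main_grad, assuming the mix_precision_utils wrappers that the PaddleNLP trainer uses for this flag; the names here are illustrative and not part of the repro:)

    from paddle.distributed.fleet.utils import mix_precision_utils

    # Wrap model and optimizer so that every trainable parameter
    # accumulates its gradient into an FP32 buffer exposed as
    # `param.main_grad`, which the optimizer then consumes.
    model = mix_precision_utils.MixPrecisionLayer(model, dtype="bfloat16")
    optimizer = mix_precision_utils.MixPrecisionOptimizer(optimizer)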

Steps to reproduce & code

Take llama training as an example:

  • Enable main_grad with --amp_master_grad true.
  • Enable recompute with --recompute true --recompute_granularity full.
  • Set --enable_linear_fused_grad_add true so that llm/llama/fused_layers.py is invoked. (I ran into this problem while developing a feature similar to linear_fused_grad_add.)

Modify the code at fused_layers.py #L32-L41 to:

    @staticmethod
    def forward(ctx, x, weight, bias=None, name=None):
        y = origin_linear(x, weight, bias)

        ctx.save_for_backward(weight)
        ctx.x = x
        ctx.bias = bias
        return y

    @staticmethod
    def backward(ctx, y_grad):
        weight, = ctx.saved_tensor()  # this weight has no main_grad
        x = ctx.x
        bias = ctx.bias
        if hasattr(weight, "main_grad"):
            print("weight has main_grad")
        else:
            print("weight has no main_grad")

Run llama training, and backward reports that weight has no main_grad.

Whereas if I avoid ctx.save_for_backward / ctx.saved_tensor() and use ctx.weight = weight / weight = ctx.weight instead, the weight does have main_grad.

From debugging, this appears to happen because, with recompute enabled, save_for_backward triggers the copy at recompute.py#L340: weight is copied into a tensor named weight.name + "cpy", but main_grad is not copied along with it.
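
(For reference, a minimal sketch of the attribute-based variant described above; it mirrors the repro snippet, and the actual gradient computation is elided:)

    @staticmethod
    def forward(ctx, x, weight, bias=None, name=None):
        y = origin_linear(x, weight, bias)
        # Keep weight as a plain attribute instead of going through
        # save_for_backward, so recompute's copy path is never triggered.
        ctx.weight = weight
        ctx.x = x
        ctx.bias = bias
        return y

    @staticmethod
    def backward(ctx, y_grad):
        weight = ctx.weight  # this weight keeps its main_grad
        ...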

Wong4j added the bug label May 6, 2024
GuoxiaWang (Contributor) commented:
        ctx.save_for_backward(weight)
        ctx.x = x
        ctx.bias = bias

Why are these saved separately here? Could you try the following instead?

ctx.save_for_backward(x, weight, bias) 
x, weight, bias = ctx.saved_tensor()

Wong4j (Contributor, Author) commented May 7, 2024

@GuoxiaWang Because what I care about in this issue is that, with recompute enabled, the ctx.save_for_backward(weight) pattern leaves the weight in backward without main_grad.

The pattern you describe is the original code in fused_layers.py. I tested it too; with recompute=full it runs into the strange error below, which would need a separate issue.

    outputs = model(**inputs)
  File "/usr/local/lib/python3.10/dist-packages/paddle/nn/layer/layers.py", line 1429, in __call__
    return self.forward(*inputs, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/paddle/distributed/fleet/meta_parallel/meta_parallel_base.py", line 37, in forward
    output = self._layers(*inputs, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/paddle/nn/layer/layers.py", line 1429, in __call__
    return self.forward(*inputs, **kwargs)
  File "/workspace/PaddleNLP/paddlenlp/transformers/llama/modeling.py", line 1913, in forward
    outputs = self.llama(
  File "/usr/local/lib/python3.10/dist-packages/paddle/nn/layer/layers.py", line 1429, in __call__
    return self.forward(*inputs, **kwargs)
  File "/workspace/PaddleNLP/paddlenlp/transformers/llama/modeling.py", line 1664, in forward
    layer_outputs = self.recompute_training_full(
  File "/workspace/PaddleNLP/paddlenlp/transformers/llama/modeling.py", line 1535, in recompute_training_full
    hidden_states = self.recompute_func(
  File "/usr/local/lib/python3.10/dist-packages/paddle/distributed/fleet/utils/__init__.py", line 142, in recompute
    return fleet.recompute.recompute(function, *args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/paddle/distributed/fleet/recompute/recompute.py", line 532, in recompute
    return _recompute_without_reentrant(function, preserve, *args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/paddle/distributed/fleet/recompute/recompute.py", line 399, in _recompute_without_reentrant
    outputs = function(*args, **kwargs)
  File "/workspace/PaddleNLP/paddlenlp/transformers/llama/modeling.py", line 1531, in custom_forward
    return module(*inputs)
  File "/usr/local/lib/python3.10/dist-packages/paddle/nn/layer/layers.py", line 1429, in __call__
    return self.forward(*inputs, **kwargs)
  File "/workspace/PaddleNLP/paddlenlp/transformers/llama/modeling.py", line 1228, in forward
    outputs = self.self_attn(
  File "/usr/local/lib/python3.10/dist-packages/paddle/nn/layer/layers.py", line 1429, in __call__
    return self.forward(*inputs, **kwargs)
  File "/workspace/PaddleNLP/paddlenlp/transformers/llama/modeling.py", line 901, in forward
    query_states = self.q_proj(hidden_states)
  File "/usr/local/lib/python3.10/dist-packages/paddle/nn/layer/layers.py", line 1429, in __call__
    return self.forward(*inputs, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/paddle/distributed/fleet/layers/mpu/mp_layers.py", line 516, in forward
    output_parallel = self.linear(
  File "/workspace/PaddleNLP/llm/fused_layers.py", line 36, in forward
    ctx.save_for_backward(x, weight, bias)
  File "/usr/local/lib/python3.10/dist-packages/paddle/autograd/py_layer.py", line 91, in save_for_backward
    self.container = tensors
ValueError: (InvalidArgument) save_for_backward only support Tensor, list of Tensor, tuple of Tensor. (at /opt/paddle/paddle/paddle/fluid/pybind/eager_py_layer.cc:644)

Wong4j (Contributor, Author) commented May 9, 2024

An update: setting reentrant=True for recompute avoids this bug; it only occurs with reentrant=False.
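
(A minimal sketch of that workaround, assuming the use_reentrant keyword of Paddle's recompute; custom_forward stands in for the function the llama modeling code passes, as in the traceback above:)

    from paddle.distributed.fleet.utils import recompute

    # Reentrant recompute goes through the PyLayer-based path rather than
    # _recompute_without_reentrant, so save_for_backward never hits the
    # weight-copying code and weight keeps its main_grad in backward.
    hidden_states = recompute(custom_forward, hidden_states, use_reentrant=True)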

Wong4j (Contributor, Author) commented May 9, 2024

@Xreki Could you help find someone on the Paddle side who is familiar with recompute to take a look?

Xreki (Contributor) commented May 9, 2024

![image](https://github.com/PaddlePaddle/PaddleNLP/assets/12538138/84258d77-048e-41a2-9641-6d7a303ba6bf)

@Wong4j This is actually a known issue with reentrant=False.
