
[Feature] Expose something like custom VJP in Python #1090

Open
awni opened this issue May 8, 2024 · 5 comments
Labels
enhancement New feature or request

Comments

@awni
Member

awni commented May 8, 2024

Consider allowing a function to have a custom VJP function attached to it.

@awni added the enhancement label on May 9, 2024
@angeloskath
Member

Oops, I thought I had exposed that. It might be worth having a high-level API for that as well (which is probably why I didn't expose it), maybe implemented in Python using the custom_vjp transform? Either of the following two feels more natural to use imho:

Option a:

import mlx.core as mx

@mx.custom_function
def my_addition(a, b):
    return a + b

@my_addition.vjp
def my_addition_vjp(primals, outputs, cotangents):
    a, b = primals
    cotan = cotangents[0]
    return cotan * b, cotan * a

Option b:

import mlx.core as mx

@mx.custom_function
def my_addition(a, b):
    def my_vjp(primals, outputs, cotangents):
        a, b = primals
        cotan = cotangents[0]
        return cotan * b, cotan * a
    return a + b, my_vjp

I lean towards the first one. Obviously names and details need some changing.
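For concreteness, usage of option (a) with the existing transforms would presumably look something like this (just a sketch, the decorator above doesn't exist yet; also note the attached vjp deliberately returns the product-rule cotangents rather than the true gradients of an addition, so you can tell the override takes effect):

import mlx.core as mx

a = mx.array(2.0)
b = mx.array(3.0)

# With the custom vjp attached, grad would route through it, so the
# gradient w.r.t. a would come out as b (3.0) instead of the usual 1.0.
print(mx.grad(my_addition)(a, b))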

@kaeru-shigure

I think the interface you propose is good, but if it were PyTorch-compatible it would be useful when migrating.
Anyway, I am eagerly waiting for this feature to be implemented.

Here is an example of a torch.autograd.Function-compatible interface:

import mlx.core as mx

class MxCustomVjp:
    @classmethod
    def apply(cls, *args, **kwargs):
        # Run the forward pass and return (outputs, ctx); the ctx is used
        # later to run the backward pass manually.
        ctx = cls()
        ctx.save_for_backward(*args)
        r = ctx.forward(ctx, *args, **kwargs)
        return r if isinstance(r, tuple) else (r,), ctx
    def save_for_backward(self, *args):
        self.saved_tensors = args
    def vjp(self, grad_output):
        # Apply the user-defined backward to the incoming cotangents
        return self.__class__.backward(self, *grad_output)
    @staticmethod
    def forward(self, *args, **kwargs):
        raise NotImplementedError
    @staticmethod
    def backward(self, *args, **kwargs):
        raise NotImplementedError

class CustomLinear(MxCustomVjp):
    @staticmethod
    def forward(ctx, x, weight):
        ctx.save_for_backward(x, weight)
        return x @ weight
    @staticmethod
    def backward(ctx, grad_output):
        x, weight = ctx.saved_tensors
        grad_x = grad_output @ weight.T
        grad_weight = x.T @ grad_output
        return grad_x, grad_weight

pair2x = lambda a, b: (a * 2, b * 2,)
mxMean = lambda a: a.mean()
def test_forward(x, weight):
    r_pair2x = pair2x(x, weight)
    r_lin, ctx_lin = CustomLinear.apply(*r_pair2x)
    r_mean = mxMean(*r_lin)
    # r_pair2x is not used in backward
    backward_bucket = (r_lin, ctx_lin)
    return r_mean, backward_bucket

# Reproduce the same process manually with mx.vjp (should match autograd)
def manual_value_and_grad(x, weight):
    # forward
    r, backward_bucket = test_forward(x, weight)
    # backward
    _, grad_mean = mx.vjp(mxMean, backward_bucket[0], (mx.ones_like(r),))
    grad_lin = backward_bucket[1].vjp(grad_mean)
    _, grad = mx.vjp(pair2x, (x, weight,), grad_lin)
    return r, grad

x = mx.random.normal((1,3))
weight = mx.random.normal((x.shape[-1],3))
(r, _), grad = mx.value_and_grad(test_forward, [0, 1])(x, weight)
r_m, grad_m = manual_value_and_grad(x, weight)
assert mx.allclose(r, r_m).item()
assert mx.allclose(grad[0], grad_m[0]).item()
assert mx.allclose(grad[1], grad_m[1]).item()

@kemchenj

kemchenj commented May 28, 2024

First, I want to thank you for all the incredible work you have done on MLX. It has been an invaluable tool for my projects.

I am currently working on a project that heavily relies on custom gradients, and the custom vjp functionality discussed here would be extremely beneficial for my use case. Are there any plans to implement this feature in the near future?

Or could you please let me know if there are any corresponding interfaces in the underlying C++ layer that I could use to implement a simple version of custom vjp? I am using mlx-swift, and I am considering directly utilizing the C++ interfaces to meet my needs.

Thank you once again for your hard work and dedication.

@awni
Member Author

awni commented May 28, 2024

@kemchenj we have a custom_vjp transformation in C++.

@awni
Member Author

awni commented May 28, 2024

@angeloskath regarding the Python interface: I think option A is nicer since it enables any transformation to be customized, which is preferable.

I'm not so crazy about the name custom_function. Maybe customizable or even extension. Does this make sense:

my_extendable_fun = mx.extension(my_fun)
my_extendable_fun.vjp = ...
my_extendable_fun.eval_cpu = ...
my_extendable_fun.vmap = ...

Presumably it will fall back to the default VJP / other transforms if they are not implemented?
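To make the fallback concrete, here is a rough pure-Python sketch of what I mean (extension, vjp, and compute_vjp are just placeholder names, and this doesn't hook into grad / value_and_grad at all, it only shows the dispatch):

import mlx.core as mx

class extension:
    # Placeholder sketch: wrap a function and optionally attach a custom vjp.
    def __init__(self, fun):
        self.fun = fun
        self._vjp = None

    def __call__(self, *args):
        return self.fun(*args)

    def vjp(self, fun_vjp):
        # Register a custom vjp (usable as a decorator)
        self._vjp = fun_vjp
        return fun_vjp

    def compute_vjp(self, primals, cotangents):
        outputs = self.fun(*primals)
        if self._vjp is not None:
            # A custom vjp was attached, so use it
            return self._vjp(primals, outputs, cotangents)
        # Otherwise fall back to the default vjp transform
        _, vjps = mx.vjp(self.fun, list(primals), list(cotangents))
        return tuple(vjps)

f = extension(lambda x: 2 * x)

# No custom vjp attached -> falls back to mx.vjp, gradient of 2 * x is 2
print(f.compute_vjp((mx.array(3.0),), (mx.array(1.0),)))

@f.vjp
def f_vjp(primals, outputs, cotangents):
    # Deliberately not the true gradient, to show the override taking effect
    return (10 * cotangents[0],)

print(f.compute_vjp((mx.array(3.0),), (mx.array(1.0),)))

The real thing would of course do this dispatch inside the primitive so that grad / value_and_grad pick it up automatically.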
