[Feature] Expose something like custom VJP in Python #1090
Comments
Oops, I thought I had exposed that. It might be worth having a high-level API for that as well (which is probably why I didn't expose it), maybe implemented in Python.

Option a:

```python
import mlx.core as mx

@mx.custom_function
def my_addition(a, b):
    return a + b

@my_addition.vjp
def my_addition_vjp(primals, outputs, cotangents):
    a, b = primals
    cotan = cotangents[0]
    # VJP of a + b: the cotangent flows through unchanged to both inputs.
    return cotan, cotan
```

Option b:

```python
import mlx.core as mx

@mx.custom_function
def my_addition(a, b):
    def my_vjp(primals, outputs, cotangents):
        a, b = primals
        cotan = cotangents[0]
        return cotan, cotan
    return a + b, my_vjp
```

I lean towards the first one. Obviously names and details need some changing.
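The decorator mechanics of option a can be sketched in plain Python without MLX. The `custom_function` class below is a hypothetical stand-in operating on floats, not the real MLX API; it only illustrates how `@fn.vjp` registration could work:

```python
# Minimal sketch of the option-a decorator pattern. `custom_function`
# is a hypothetical stand-in operating on plain floats, not mx.array.
class custom_function:
    def __init__(self, fun):
        self.fun = fun
        self.vjp_fun = None

    def __call__(self, *args):
        return self.fun(*args)

    def vjp(self, fun):
        # Register a custom VJP; return the function so this method
        # can be used as a decorator.
        self.vjp_fun = fun
        return fun

@custom_function
def my_mul(a, b):
    return a * b

@my_mul.vjp
def my_mul_vjp(primals, outputs, cotangents):
    a, b = primals
    cotan = cotangents[0]
    # VJP of a * b.
    return cotan * b, cotan * a

out = my_mul(3.0, 4.0)                            # 12.0
grads = my_mul.vjp_fun((3.0, 4.0), out, (1.0,))   # (4.0, 3.0)
```

A real implementation would additionally hook `vjp_fun` into the autodiff machinery so that `mx.grad` picks it up; the sketch only shows the registration surface.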
I think the interface you propose is good, but if it becomes PyTorch compatible, it might be useful when migrating. Here is an example of a torch.autograd.Function-compatible interface:

```python
import mlx.core as mx

class MxCustomVjp:
    @classmethod
    def apply(cls, *args, **kwargs):
        ctx = cls()
        ctx.save_for_backward(*args)
        r = ctx.forward(ctx, *args, **kwargs)
        return (r if type(r) == tuple else (r,)), ctx

    def save_for_backward(self, *args):
        self.saved_tensors = args

    def vjp(self, grad_output):
        return self.__class__.backward(self, *grad_output)

    @staticmethod
    def forward(ctx, *args, **kwargs):
        raise NotImplementedError

    @staticmethod
    def backward(ctx, *args, **kwargs):
        raise NotImplementedError

class CustomLinear(MxCustomVjp):
    @staticmethod
    def forward(ctx, x, weight):
        ctx.save_for_backward(x, weight)
        return x @ weight

    @staticmethod
    def backward(ctx, grad_output):
        x, weight = ctx.saved_tensors
        grad_x = grad_output @ weight.T
        grad_weight = x.T @ grad_output
        return grad_x, grad_weight

pair2x = lambda a, b: (a * 2, b * 2)
mxMean = lambda a: a.mean()

def test_forward(x, weight):
    r_pair2x = pair2x(x, weight)
    r_lin, ctx_lin = CustomLinear.apply(*r_pair2x)
    r_mean = mxMean(*r_lin)
    # r_pair2x is not used in backward
    backward_bucket = (r_lin, ctx_lin)
    return r_mean, backward_bucket

# Replicate the same computation that autograd performs:
def manual_value_and_grad(x, weight):
    # forward
    r, backward_bucket = test_forward(x, weight)
    # backward
    _, grad_mean = mx.vjp(mxMean, backward_bucket[0], (mx.ones_like(r),))
    grad_lin = backward_bucket[1].vjp(grad_mean)
    _, grad = mx.vjp(pair2x, (x, weight), grad_lin)
    return r, grad

x = mx.random.normal((1, 3))
weight = mx.random.normal((x.shape[-1], 3))
(r, _), grad = mx.value_and_grad(test_forward, [0, 1])(x, weight)
r_m, grad_m = manual_value_and_grad(x, weight)
assert mx.allclose(r, r_m).item()
assert mx.allclose(grad[0], grad_m[0]).item()
assert mx.allclose(grad[1], grad_m[1]).item()
```
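The core of the ctx/`save_for_backward` pattern above can be exercised on its own, with plain floats and no MLX. This is a self-contained, hypothetical reduction of that interface, just to show the control flow between `apply`, `forward`, and `backward`:

```python
# Float-only reduction of the torch.autograd.Function-style pattern.
# All names here are illustrative; no MLX is required.
class CustomVjpBase:
    @classmethod
    def apply(cls, *args):
        # Build a context, run forward, and hand both back so the
        # caller can later drive backward manually.
        ctx = cls()
        out = cls.forward(ctx, *args)
        return out, ctx

    def save_for_backward(self, *args):
        self.saved = args

class Square(CustomVjpBase):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x * x

    @staticmethod
    def backward(ctx, grad_out):
        # d(x^2)/dx = 2x, scaled by the incoming cotangent.
        (x,) = ctx.saved
        return 2.0 * x * grad_out

y, ctx = Square.apply(3.0)     # y == 9.0
g = Square.backward(ctx, 1.0)  # g == 6.0
```

The design choice worth noting is that the context object, not the function, carries the saved state, which is what makes the interface a drop-in match for PyTorch's `torch.autograd.Function`.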
First, I want to thank you for all the incredible work you have done on MLX. It has been an invaluable tool for my projects. I am currently working on a project that heavily relies on custom gradients, and the custom VJP functionality discussed here would be extremely beneficial for my use case. Are there any plans to implement this feature in the near future? Or could you please let me know if there are any corresponding interfaces in the underlying C++ layer that I could use to implement a simple version of custom VJP? I am using mlx-swift, and I am considering directly utilizing the C++ interfaces to meet my needs. Thank you once again for your hard work and dedication.
@kemchenj we have a `custom_vjp` transformation in C++.
@angeloskath regarding the Python interface: I think option A is nicer for enabling any transformation to be customized, which I think is preferable. I'm not so crazy about the name.

```python
my_extendable_fun = mx.extension(my_fun)
my_extendable_fun.vjp = ...
my_extendable_fun.eval_cpu = ...
my_extendable_fun.vmap = ...
```

Presumably it will fall back to the default VJP / other transforms if they are not implemented?
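One way that fallback could work, sketched in plain Python. `extension`, `run_vjp`, and `default_vjp` are all hypothetical names, not an existing MLX API; the point is only the dispatch: use the registered transform if present, otherwise a default:

```python
# Sketch of fallback dispatch for customizable transforms.
# All names are hypothetical; no MLX required.
class extension:
    def __init__(self, fun):
        self.fun = fun
        self.vjp = None  # optional user override

    def __call__(self, *args):
        return self.fun(*args)

    def run_vjp(self, primals, cotangents, default_vjp):
        # Prefer the user-registered VJP; otherwise fall back.
        if self.vjp is not None:
            return self.vjp(primals, cotangents)
        return default_vjp(primals, cotangents)

def default_vjp(primals, cotangents):
    # Stand-in for the framework's autodiff-derived VJP of addition:
    # the cotangent passes through to every input.
    return tuple(cotangents[0] for _ in primals)

f = extension(lambda a, b: a + b)
# No custom VJP registered -> falls back to the default.
grads = f.run_vjp((1.0, 2.0), (1.0,), default_vjp)   # (1.0, 1.0)
# Register an override -> it takes precedence.
f.vjp = lambda primals, cotangents: (0.0, 0.0)
grads2 = f.run_vjp((1.0, 2.0), (1.0,), default_vjp)  # (0.0, 0.0)
```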
Consider allowing a function to have a custom VJP function attached to it.