-
Notifications
You must be signed in to change notification settings - Fork 354
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Indices Operator #1735
base: main
Are you sure you want to change the base?
Indices Operator #1735
Conversation
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #1735 +/- ##
=======================================
Coverage 86.43% 86.44%
=======================================
Files 753 755 +2
Lines 87602 87651 +49
=======================================
+ Hits 75723 75771 +48
- Misses 11879 11880 +1 ☔ View full report in Codecov by Sentry. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for working on this 🙂
Implementation looks good, some minor comments on form.
@@ -285,6 +285,7 @@ Those operations are only available for `Int` tensors. | |||
| ------------------------------------------------ | ------------------------------------------------------- | | |||
| `tensor.arange(5..10, device) ` | `tensor.arange(start=5, end=10, device=device)` | | |||
| `tensor.arange_step(5..10, 2, device)` | `tensor.arange(start=5, end=10, step=2, device=device)` | | |||
| `tensor.indices(shape, device)` | `torch.meshgrid(tensors)` | |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Two notes:
- Usage should be
Tensor::indices(shape, device)
likeTensor::cat
orTensor::empty
. - Given the implementation here, there is no 1-to-1 equivalent for torch since meshgrid takes tensors as input and not the shape of the desired grid. In this case this is much closer to
numpy.indices
but the indexing is in cartesian space. The actual torch equivalent for the 2D example in your unit tests would be:
yv, xv = torch.meshgrid([torch.arange(2), torch.arange(2)], indexing='xy')
grid = torch.stack((xv, yv), 2)
So in this case, I'm not sure we would have to provide the comparison in the table.
@nathanielsimard what do you think about the naming for this method? (see my comments for possible suggestions) |
/// println!("{}", result); | ||
/// } | ||
/// ``` | ||
pub fn indices<const D2: usize>(shape: Shape<D>, device: &B::Device) -> Tensor<B, D2, Int> { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we should move this method to the backend API so that backends can optimize it, since calling a lot of arange and repeat can be very expansive for big matrices. We should keep default implementation in the backend definition.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@McArthur-Alford if you're not sure what that means, see for example the narrow
op. It is defined by a default implementation but also overridden by some backends.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
cartesian_grid
sounds good to me. Ill get started on moving it to a backend op.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@McArthur-Alford if you're not sure what that means, see for example the
narrow
op. It is defined by a default implementation but also overridden by some backends.
Just looking for some clarification on this. It seems like the backend narrow op (i'm specifically looking at int_narrow
) isn't actually used. Instead the API has its own definition of narrow that doesnt utilize the backend at all? Obviously you can still call B::int_narrow
, but my assumption would be that the API should be making calls to the backend implementation, which then themselves have a default implementation?
My thought process was that I could just put a int_cartesian_grid
in the int backend (and make necessary adjustments) and then have the api simply call B::int_cartesian_grid
. Maybe that's not the case?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry for the delayed response.
int_narrow
has a default implementation that works across all backends
fn int_narrow<const D: usize>(
tensor: IntTensor<B, D>,
dim: usize,
start: usize,
length: usize,
) -> IntTensor<B, D> {
narrow::<B, D, Int>(tensor, dim, start, length)
}
And for backends that have the op, we override the method (see int_narrow
in burn-candle
or burn-tch
for example).
For this PR, you don't need to override the default implementation for any backend. This can be added later on if required.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I was more so wondering about the tensor API narrow function. That also calls the default narrow, rather than the backend int_narrow/bool_narrow etc. It doesn't seem like the tensor API gets overridden. In that case, overriding the backend op wouldn't actually change what the tensor API is doing right?
Regardless, I should have gotten it working much like narrow, with that last commit.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Crap, you're right! I didn't realize the default narrow was called at the tensor API level.. that means it probably never dispatches to the backend even if it's implemented 😅
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ah, ok. I thought I was going crazy. Glad I helped find a bug at least!
Alright, everything has been renamed and moved around. Feel free to review, let me know if there are any other changes we would like. I think I moved it into the backend ops correctly (based on narrow). |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yep! Looks good to me. Just a minor comment regarding the API doc.
Otherwise, let's fix the clippy CI.
shape: S, | ||
device: &B::Device, | ||
) -> Tensor<B, D2, Int> { | ||
Tensor::new(B::int_cartesian_grid::<S, D, D2>(shape, device)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yess 🔥
That should correctly dispatch to the backend implementations as we add them later on
Indices Operator for Int Tensors
Checklist
run-checks all
script has been executed.Related Issues/PRs
None
Changes
Added a indices function for int tensors. This is similar to pytorches
meshgrid
, or numpysindices
functions though with slightly different arrangement. For example, the output ofTensor::<B, 2, Int>::indices::<3>(Shape { dims: [2, 3] }, &device);
would be:Testing
Added a super basic but functional test to make sure indices produces some expected typical outputs.