[Feature suggestion] Optionally return the inferred dimensions #321

Open
ml-w opened this issue May 14, 2024 · 0 comments

Description

Einops automatically infers some of the axis lengths. For example, in ViT, Einops is great for decomposing a tensor into stacked patches like this:

import torch
from einops import rearrange

x = torch.ones([1, 1, 256, 256])
patches = rearrange(x, 'b c (nx px) (ny py) -> b (c nx ny) px py', px=32, py=32)

This decomposes the 256 x 256 input image into a stack of 32 x 32 patches, with the patch counts nx and ny calculated automatically.

The calculated nx and ny sometimes need to be reused. When they do, we generally have to calculate them ourselves before invoking rearrange, which IMO defeats the purpose of using rearrange. It would be great to have a way to get the inferred dimensions back as a dictionary when using Einops.
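For reference, the manual calculation described above looks roughly like this. It is a minimal sketch in plain Python; `patch_grid` is a hypothetical helper name, not something Einops provides, and the divisibility check mirrors the requirement that the image size be an exact multiple of the patch size.

```python
def patch_grid(height, width, px, py):
    """Return (nx, ny), the number of patches along each image axis.

    Hypothetical helper: this duplicates by hand what rearrange infers
    internally for the pattern 'b c (nx px) (ny py) -> ...'.
    """
    if height % px or width % py:
        raise ValueError(
            f"image size ({height}, {width}) is not divisible by patch size ({px}, {py})"
        )
    return height // px, width // py


nx, ny = patch_grid(256, 256, px=32, py=32)
print(nx, ny)                  # 8 8
print((1, 1 * nx * ny, 32, 32))  # shape of `patches` above: (1, 64, 32, 32)
```

Every call site that needs nx or ny has to repeat this bookkeeping, and it must stay in sync with the pattern string.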

Use case

Sometimes we want to reuse calculated dimensions that vary between inputs. This is especially important when the input is a sequence of variable length. Usually people pad it to a standard length, but that standard length can itself be a variable (for example, the length padded up to a multiple of a fixed number such as 64). The proposed function would be useful in these use cases.
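To make the variable-length case concrete, here is a small sketch of the padding arithmetic. The helper name `pad_to_multiple` and the chunk size of 64 are illustrative assumptions. The point is that the number of chunks depends on the input length, so it cannot be hard-coded and is exactly the kind of value one would like to get back from rearrange.

```python
def pad_to_multiple(length, multiple=64):
    """Smallest multiple of `multiple` that is >= length (ceiling division)."""
    return -(-length // multiple) * multiple


for seq_len in (100, 128, 129):
    padded = pad_to_multiple(seq_len)
    n_chunks = padded // 64  # the axis a pattern like '(n chunk)' would infer
    print(seq_len, padded, n_chunks)
# 100 128 2
# 128 128 2
# 129 192 3
```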

Example implementation

I think the implementation is straightforward: either return the dimensions as a dictionary when a flag is set to True, or populate an input dictionary with the inferred axis lengths.

x = torch.ones([1, 1, 256, 256])

# Option 1 (proposed): add an optional flag `return_dims`
patches, dims = rearrange(x, 'b c (nx px) (ny py) -> b (c nx ny) px py', px=32, py=32, return_dims=True)

# Option 2 (proposed): add an optional input dictionary for rearrange to populate
dims = {}
patches = rearrange(x, 'b c (nx px) (ny py) -> b (c nx ny) px py', px=32, py=32, return_dims_to_dict=dims)

# In both cases:
# dims == {
#   'b': 1,
#   'c': 1,
#   'nx': 8,
#   'ny': 8,
#   'px': 32,
#   'py': 32
# }

Something like that.
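To illustrate what the returned dictionary would contain, here is a rough pure-Python sketch that resolves the axis lengths from a shape and the input side of a pattern. `infer_axis_lengths` is a hypothetical name and this is not how Einops is implemented; it only handles plain names and parenthesised groups (no ellipsis, no numeric axes), and assumes at most one unknown axis per group.

```python
import re
from math import prod


def infer_axis_lengths(shape, pattern, **known):
    """Resolve every named axis on the input side of an einops-style pattern.

    Hypothetical helper for illustration only; not part of einops.
    """
    left = pattern.split("->")[0]
    # Each token is either a parenthesised group "(nx px)" or a bare name "b".
    tokens = re.findall(r"\([^)]*\)|\S+", left)
    groups = [tok.strip("()").split() for tok in tokens]
    if len(groups) != len(shape):
        raise ValueError(f"pattern has {len(groups)} axes, shape has {len(shape)}")

    dims = dict(known)
    for names, size in zip(groups, shape):
        unknown = [n for n in names if n not in dims]
        if len(unknown) > 1:
            raise ValueError(f"cannot infer more than one of {unknown}")
        known_product = prod(dims[n] for n in names if n in dims)
        if unknown:
            if size % known_product:
                raise ValueError(f"axis of size {size} is not divisible by {known_product}")
            dims[unknown[0]] = size // known_product
        elif known_product != size:
            raise ValueError(f"axis of size {size} does not match {names}")
    return dims


dims = infer_axis_lengths(
    (1, 1, 256, 256), 'b c (nx px) (ny py) -> b (c nx ny) px py', px=32, py=32
)
print(dims["nx"], dims["ny"])  # 8 8
```

Since rearrange already has to resolve these lengths internally to do its work, exposing them should not require a second pass like this one.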
