
[BUG] tensordict.pad_sequence silently ignores non-tensor attributes in tensorclasses or TensorDicts #783

Open
egaznep opened this issue May 18, 2024 · 4 comments · Fixed by #784

egaznep commented May 18, 2024

Describe the bug

I have some tensorclasses that store an audio file along with metadata such as speaker ID and utterance ID. I would like to collate these tensorclasses into a batch; however, when I do so, the metadata is silently lost: the metadata from the first tensordict is repeated for every item in the batch, and the user is not warned about this.

To Reproduce

Steps to reproduce the behavior.

from tensordict import pad_sequence, TensorDict
import torch

d1 = TensorDict({'a': torch.tensor([0]), 'b': ['asd']})
d2 = TensorDict({'a': torch.tensor([0]), 'b': ['efg']})

pad_sequence([d1, d2])
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.int64, is_shared=False),
        b: NonTensorData(data=asd, batch_size=torch.Size([2]), device=None)},
    batch_size=torch.Size([2]),
    device=None,
    is_shared=False)

Expected behavior

I should either get a properly joined tensordict, e.g.,

TensorDict(
    fields={
        a: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.int64, is_shared=False),
        b: NonTensorData(data=['asd', 'efg'], batch_size=torch.Size([2]), device=None)},  # a list with the same shape as the batch_size
    batch_size=torch.Size([2]),
    device=None,
    is_shared=False)

or tensordict.pad_sequence should warn the user that the metadata is being discarded.
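
In the meantime, a minimal workaround sketch (plain PyTorch, not part of the tensordict API; the name torch_pad_sequence is my own alias): pad the tensor fields with torch.nn.utils.rnn.pad_sequence and carry the non-tensor metadata alongside the batch as an ordinary Python list.

import torch
from torch.nn.utils.rnn import pad_sequence as torch_pad_sequence

samples = [
    {'a': torch.tensor([1, 1]), 'b': 'asd'},
    {'a': torch.tensor([2]), 'b': 'efg'},
]

# Pad the variable-length tensors into one [batch, max_len] tensor.
batch_a = torch_pad_sequence([s['a'] for s in samples], batch_first=True)
# Keep the metadata as a plain list, one entry per batch item.
batch_b = [s['b'] for s in samples]

print(batch_a)  # tensor([[1, 1], [2, 0]])
print(batch_b)  # ['asd', 'efg']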


System info

tensordict-nightly              2024.5.18


Checklist

  • I have checked that there is no similar issue in the repo (required)
  • I have read the documentation (required)
  • I have provided a minimal working example to reproduce the bug (required)
egaznep added the bug (Something isn't working) label May 18, 2024

vmoens commented May 20, 2024

If you pass a list, it will be cast to a numpy ndarray (this is something we can reconsider in the future).
But if you use a plain string, the following code will do what you want, I think (given #784):

from tensordict import pad_sequence, TensorDict
import torch

d1 = TensorDict({'a': torch.tensor([1, 1]), 'b': 'asd'})
d2 = TensorDict({'a': torch.tensor([2]), 'b': 'efg'})

print(d1['b'])                      # 'asd'
print(pad_sequence([d1, d2]))       # 'a' is padded to shape [2, 2]
print(pad_sequence([d1, d2])['b'])  # with #784, both strings are kept instead of only 'asd'

vmoens linked a pull request May 20, 2024 that will close this issue

egaznep commented May 20, 2024


Tested this and indeed, it works! Thank you for the quick and neat fix 🙂 Would this change make its way into the next nightly release?


vmoens commented May 22, 2024

Sorry, I dropped the ball on this :(
The PR is almost ready, but there's a non-trivial issue with Persistent (H5) tensordicts that needs to be solved before merging. I'll do my best to do it today!


egaznep commented May 23, 2024

Hi again,

I noticed that this doesn't work for tensorclasses.

MWE:

import torch
from tensordict import pad_sequence, tensorclass

@tensorclass
class Sample:
    a: torch.Tensor
    b: str

d1 = Sample(a=torch.tensor([1, 1]), b='asd', batch_size=[])
d2 = Sample(a=torch.tensor([2]), b='efg', batch_size=[])
print(pad_sequence([d1, d2])[1].b)  # gives you 'asd' and not 'efg'
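
Until then, a possible workaround sketch (my own suggestion, not an official API): gather the non-tensor field from the inputs by hand and keep it alongside the padded batch.

batch = pad_sequence([d1, d2])     # the tensor field 'a' is padded as expected
batch_b = [s.b for s in (d1, d2)]  # ['asd', 'efg'], collected manually
print(batch_b[1])                  # 'efg'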

vmoens reopened this May 25, 2024