Describe the bug
Tensorclasses (in contrast to tensordicts) don't list non-tensor data when iterating through them with .keys() or .items(). This also affects e.g. apply().
Expected behavior
For keys() and items() I'd expect iteration to include non-tensor data.
For apply() probably as well, but whether or not it iterates over non-tensor data could also depend on an input flag (as one might explicitly want to apply something only to tensors).
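To make the expectation concrete, here is a minimal pure-Python sketch. It does not use tensordict at all: FakeTensor, ToyRecord and the include_non_tensor flag are hypothetical names invented for illustration, not part of the library's API.

```python
from dataclasses import dataclass, fields
from typing import Any, Callable, List


class FakeTensor:
    """Stand-in for a tensor so the sketch runs without torch."""

    def __init__(self, values):
        self.values = list(values)


@dataclass
class ToyRecord:
    obs: FakeTensor  # tensor-like field
    label: str       # non-tensor field

    def keys(self) -> List[str]:
        # Expected behavior: every field is listed, tensor or not.
        return [f.name for f in fields(self)]

    def apply(self, fn: Callable[[Any], Any], include_non_tensor: bool = False) -> "ToyRecord":
        # Hypothetical flag: by default only tensor-like fields are transformed,
        # non-tensor fields are carried over unchanged.
        out = {}
        for name in self.keys():
            value = getattr(self, name)
            if isinstance(value, FakeTensor) or include_non_tensor:
                value = fn(value)
            out[name] = value
        return ToyRecord(**out)


def double_or_upper(value):
    # Doubles tensor-like values, upper-cases strings.
    if isinstance(value, FakeTensor):
        return FakeTensor(v * 2 for v in value.values)
    return value.upper()


rec = ToyRecord(obs=FakeTensor([1.0, 2.0]), label="cat")
tensors_only = rec.apply(double_or_upper)
everything = rec.apply(double_or_upper, include_non_tensor=True)
```

With the default flag, label is left untouched while obs is transformed; passing include_non_tensor=True routes both fields through the function.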
System info
tensordict version: '0.4.0+b4c91e8'
Checklist
I have checked that there is no similar issue in the repo (required)
This used to be a design choice before we introduced NonTensorData
Context
We initially thought it could be interesting to let tensorclass carry non-tensor data, but thought it was better to exclude it from the keys since things like apply or even any other op that iterates through the keys (reshape, gather, ...) would have been meaningless with non-tensor data.
Then we introduced NonTensorData, which is a simple subclass of a tensorclass that can only carry non-tensor data. If you call tensordict.keys(include_nested=True), a NonTensorData node will appear as if it had no leaves, but in reality it is a leaf. Calling apply over NonTensorData is fine because it does not access the data field (it's not part of the keys). But because we hack through __getitem__ and __setitem__ to access NonTensorData.data, the situation is less clear (one would imagine that the non-tensor data is now part of the keys, since we use the key to set or access the value).
Solution
We introduced an is_leaf function in keys, items, values and apply to quickly check whether a node is a leaf. This keeps ops like reshape from accessing the non-tensor data as if it were a leaf, while still presenting non-tensor data as leaves to the user.
We should now apply something similar to all tensorclasses: by default, keys will return all the data (tensor and non-tensor), but internally operations will only be applied to tensor data.
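The idea of swapping the leaf predicate can be sketched in plain Python. This is only an illustration under stated assumptions, not tensordict's actual implementation: FakeTensor, NonTensor, iter_leaves and both predicates are hypothetical names, and a plain dict stands in for a nested tensordict.

```python
from typing import Any, Callable, Iterator, Tuple


class FakeTensor:
    """Stand-in for a tensor so the sketch runs without torch."""


class NonTensor:
    """Stand-in wrapper around arbitrary non-tensor data."""

    def __init__(self, data: Any):
        self.data = data


def user_is_leaf(value: Any) -> bool:
    # User-facing view: tensors and wrapped non-tensor data are both leaves.
    return isinstance(value, (FakeTensor, NonTensor))


def tensor_only_is_leaf(value: Any) -> bool:
    # Internal view for ops such as reshape: only tensors count as leaves.
    return isinstance(value, FakeTensor)


def iter_leaves(
    tree: dict,
    is_leaf: Callable[[Any], bool],
    prefix: Tuple[str, ...] = (),
) -> Iterator[Tuple[str, ...]]:
    # Walk a nested dict and yield the key path of every node the predicate accepts.
    for key, value in tree.items():
        path = prefix + (key,)
        if is_leaf(value):
            yield path
        elif isinstance(value, dict):
            yield from iter_leaves(value, is_leaf, path)
        # Anything else is neither a leaf nor a container and is skipped.


tree = {
    "obs": FakeTensor(),
    "meta": NonTensor("hello"),
    "nested": {"reward": FakeTensor(), "tag": NonTensor(3)},
}
user_keys = list(iter_leaves(tree, user_is_leaf))
internal_keys = list(iter_leaves(tree, tensor_only_is_leaf))
```

The same traversal serves both purposes: user_keys includes the non-tensor entries, while internal_keys contains only the tensor paths that an op like reshape should touch.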
Sounds good to me!
If I understand correctly, 'real' leaves are tensors, and non-tensors are not leaves because they're wrapped in NonTensorData, which is a dataclass subclass?
As a user, I'd find that slightly confusing, since to me non-tensor data does appear as leaves. But I guess is_leaf is only an internal implementation detail?