[BUG] Auto-nested tensordict bugs #106

Open
vmoens opened this issue Dec 8, 2022 · 3 comments
Assignees
Labels
bug Something isn't working

Comments

@vmoens
Contributor

vmoens commented Dec 8, 2022

Describe the bug

Auto-nesting may be a desirable feature (e.g. to build graphs), but currently it is broken for multiple functions, e.g.

from tensordict import TensorDict

tensordict = TensorDict({}, [])
tensordict["self"] = tensordict  # the tensordict now contains itself
print(tensordict)  # fails
tensordict.flatten_keys()  # fails
list(tensordict.keys(include_nested=True))  # fails

Consideration

This is something that should be included in the tests. We could design a special test case in TestTensorDictsBase with a nested self.

Solution

IMO there is not a single solution to this problem. For repr, we could find a way of representing a self-nested tensordict, something like

TensorDict(fields={
   "self": ...
}, 
batch_size=[])
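
For comparison, the Python standard library already handles this kind of re-entrant repr with `reprlib.recursive_repr`, which returns a fill value when `__repr__` is re-entered for the same object. Below is a minimal sketch of that idea; `ToyDict` is a hypothetical stand-in written for this example, not the real TensorDict class, and the output format is only illustrative.

```python
import reprlib


class ToyDict:
    """Hypothetical stand-in for TensorDict: a thin wrapper around a dict."""

    def __init__(self):
        self._data = {}

    def __setitem__(self, key, value):
        self._data[key] = value

    @reprlib.recursive_repr(fillvalue="...")
    def __repr__(self):
        # repr(value) re-enters this method for a self-reference; the
        # decorator then returns the fill value instead of recursing.
        fields = ", ".join(f"{k!r}: {v!r}" for k, v in self._data.items())
        return f"ToyDict(fields={{{fields}}})"


td = ToyDict()
td["self"] = td
td["a"] = 1
print(repr(td))  # ToyDict(fields={'self': ..., 'a': 1})
```

Whether the same decorator fits the actual TensorDict repr (which also prints batch size, device, etc.) would need to be checked against the real implementation.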

For keys, we could avoid returning a key if a key pointing to the same value has already been returned (same for values and items).
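
A rough sketch of that deduplication, using plain nested dicts as a stand-in for tensordicts. The helper name `iter_nested_keys` is made up for this example, and it assumes only leaf keys are yielded and that only container values (not leaves) are deduplicated by identity.

```python
def iter_nested_keys(d, prefix=(), _seen=None):
    """Yield tuple keys of the leaves of nested dicts, visiting each dict once."""
    if _seen is None:
        _seen = set()
    _seen.add(id(d))
    for key, value in d.items():
        full_key = prefix + (key,)
        if isinstance(value, dict):
            if id(value) in _seen:
                # Already visited (e.g. a self-reference): skip this key
                # and do not descend into the value again.
                continue
            yield from iter_nested_keys(value, full_key, _seen)
        else:
            yield full_key


d = {"a": 1, "b": {"c": 2}}
d["self"] = d          # direct self-reference
d["b"]["root"] = d     # reference back to an ancestor
print(list(iter_nested_keys(d)))  # [('a',), ('b', 'c')]
```
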
For flatten_keys, the operation should be prohibited on an auto-nested TensorDict. The options are (1) leave it as it is, since the maximum recursion depth already takes care of it, or (2) build a wrapper around flatten_keys() to detect whether the same method (i.e. the same call to the same method from the same class) is being re-entered on the same object, something like

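import functools

def detect_self_nesting(fun):
    @functools.wraps(fun)
    def new_fun(self, *args, **kwargs):
        # self._being_called: list of methods currently running on this object
        if fun in self._being_called:
            raise RuntimeError(f"{fun.__name__} re-entered on the same object")
        self._being_called.append(fun)
        try:
            return fun(self, *args, **kwargs)
        finally:
            # always unregister, even if fun raises
            self._being_called.pop()
    return new_fun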
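
To see how such a wrapper would behave, here is a self-contained toy run. `ToyDict` and its `flatten_keys` are hypothetical stand-ins written for this example (not the real TensorDict code), and the wrapper uses try/finally so the bookkeeping is reset when the call raises.

```python
import functools


def detect_self_nesting(fun):
    @functools.wraps(fun)
    def new_fun(self, *args, **kwargs):
        if fun in self._being_called:
            raise RuntimeError(f"{fun.__name__} re-entered on the same object")
        self._being_called.append(fun)
        try:
            return fun(self, *args, **kwargs)
        finally:
            self._being_called.pop()
    return new_fun


class ToyDict:
    """Hypothetical stand-in for TensorDict."""

    def __init__(self, data=None):
        self._data = dict(data or {})
        self._being_called = []  # per-instance list of running methods

    @detect_self_nesting
    def flatten_keys(self, sep="."):
        out = {}
        for key, value in self._data.items():
            if isinstance(value, ToyDict):
                for sub_key, sub_value in value.flatten_keys(sep).items():
                    out[f"{key}{sep}{sub_key}"] = sub_value
            else:
                out[key] = value
        return out


td = ToyDict({"a": 1, "b": ToyDict({"c": 2})})
print(td.flatten_keys())  # {'a': 1, 'b.c': 2}

td._data["self"] = td  # introduce the self-reference
try:
    td.flatten_keys()
except RuntimeError as err:
    print("detected:", err)
```

Ordinary nesting still works because each instance keeps its own `_being_called` list; only a call that comes back to an object already in the call chain is rejected.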

There are probably other properties that I'm missing, but I'd expect them to be covered by the tests if we can design the dedicated test pipeline mentioned earlier.

@vmoens vmoens added the bug Something isn't working label Dec 8, 2022
@vmoens vmoens self-assigned this Dec 8, 2022
Zooll pushed a commit to Zooll/tensordict that referenced this issue Dec 16, 2022
Zooll pushed a commit to Zooll/tensordict that referenced this issue Dec 30, 2022
@ruleva1983

Working to fix this issue

@vmoens
Contributor Author

vmoens commented Feb 2, 2023

I think there are essentially 3 things to consider:
(1) iteration over keys: iterating over nested keys will lead to an infinite recursion, and this is fine. The only thing that really should work is checking whether a (nested) key is in the auto-nested tensordict. For this to work, __contains__ should keep track of what is being / has already been explored, both to avoid infinite recursion and to make sure that every branch is explored. As soon as __contains__ reaches a self-nested tensordict, it should stop the query for that tensordict and move on to the next.
(2) tensordict methods that return tensordicts (or better, tensor-to-tensor methods): all methods that return a tensor out of a tensor can be wrapped under a common method IMO.

Here is an example of how to do it with to_tensordict(self), but I guess that a similar, generic solution could be found for apply, __eq__, to(device), etc.

        def to_tensordict(
            tensordict, current_key: Tuple = None, being_computed: Dict = None
        ):
            """A version of to_tensordict that supports auto-nesting."""
            out_dict = {}
            if current_key is None:
                current_key = ()
            if being_computed is None:
                being_computed = {}
            # map the key (from the root) of every tensordict seen so far to its id
            being_computed[current_key] = id(tensordict)
            for key, value in tensordict.items():
                if isinstance(value, TensorDictBase):
                    nested_key = current_key + (key,)
                    if id(value) in being_computed.values():
                        # already seen: record where it lives and write it later
                        being_computed[nested_key] = id(value)
                        continue
                    new_value = to_tensordict(
                        value, current_key=nested_key, being_computed=being_computed
                    )
                else:
                    new_value = value.clone()
                out_dict[key] = new_value
            out = TensorDict(
                out_dict,
                # use the attributes of the tensordict being converted, not the root's
                device=tensordict.device,
                batch_size=tensordict.batch_size,
                _run_checks=False,
            )
            # delayed writes: point every key that referred to this tensordict at its copy
            for other_nested_key, other_value in being_computed.items():
                if other_nested_key != current_key:
                    if other_value == id(tensordict):
                        # keys are stored from the root; make them relative to `out`
                        out[other_nested_key[len(current_key):]] = out
            return out

        return to_tensordict(self)

Again, we keep track of what is being processed. If something is being processed, we just ignore that for now and we delay the writing of that thing until completion of the operation on the nested tensordict.
This is "easy" because we know that the tree structure of the output will be similar to the input.
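
The delayed-write idea can be tried out without tensordict at all. The sketch below applies it to plain nested dicts; `copy_nested` is a hypothetical helper written for this example, leaves are reused rather than cloned, and only references back to a dict that is still being copied (i.e. cycles) are deferred.

```python
def copy_nested(d, _in_progress=None):
    """Copy nested dicts, deferring references to dicts still being copied.

    _in_progress maps id(source dict) -> list of (parent_copy, key) writes
    waiting for the copy of that source dict to be finished.
    """
    if _in_progress is None:
        _in_progress = {}
    _in_progress[id(d)] = []
    out = {}
    for key, value in d.items():
        if isinstance(value, dict):
            if id(value) in _in_progress:
                # value is d itself or one of its ancestors: delay the write.
                _in_progress[id(value)].append((out, key))
                continue
            out[key] = copy_nested(value, _in_progress)
        else:
            out[key] = value
    # d is fully copied: perform the delayed writes that point back at it.
    for parent, key in _in_progress.pop(id(d)):
        parent[key] = out
    return out


src = {"a": 1, "b": {"c": 2}}
src["self"] = src
src["b"]["root"] = src

dst = copy_nested(src)
print(dst["self"] is dst, dst["b"]["root"] is dst, dst is src)  # True True False
```

The copy has the same tree structure as the input, which is what makes the deferred keys valid in the output.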

(3) some methods do not return a tensordict of the same structure but something else, e.g. all() and any() return a boolean, while unbind, split and chunk return a tuple. torch.stack and torch.cat may also pose some challenges.

To resolve this issue, we should approach each problem independently: first the keys, second the tensor-to-tensor methods and lastly the others.

@tcbegley
Contributor

tcbegley commented Feb 2, 2023

  1. I think there is no risk of recursion in __contains__, right? For example, if the structure was

    td = TensorDict({"a": torch.rand(10)}, [10])
    td["self"] = td

    If I do something like "b" in td, we don't search for "b" in td["self"] after we fail to find it in td._tensordict right? There is no iteration over keys, you just check if the key is in the underlying dict, or if it's a tuple, if the first entry is in the dict and the remaining entries form a key contained in the value under the first key. The number of recursive calls to __contains__ is bounded by the length of the key.

    The only risk is that the user could do something weird like ("self",) * 1_000_000 + ("a",) in td, but that's on them!

  2. I think this pattern makes sense, though note that item in some_dict.values() is O(n) in the number of entries, so I think it would be better to maintain a set visited of ids along with a dictionary update which maps keys at which auto-nesting is detected to the auto-nested value. That way inside the main loop we can do

    if id(value) in visited:
        update[prefix + (key,)] = value

    then at the end, we repopulate the output with the auto-nested values

    for key, value in update.items():
        out[key] = value
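
The argument in point 1 (a lookup takes at most one step per key element, so self-nesting cannot make it loop) can be checked on plain dicts. This is only a sketch: `nested_contains` is a hypothetical helper, not the actual `__contains__` of TensorDict.

```python
def nested_contains(d, key):
    """Return True if key (a str or a tuple of str) resolves inside nested dicts."""
    if not isinstance(key, tuple):
        return key in d
    if len(key) == 0:
        return False
    first, rest = key[0], key[1:]
    if first not in d:
        return False
    if not rest:
        return True
    child = d[first]
    # One recursive call per key element, so a self-reference cannot loop forever.
    return isinstance(child, dict) and nested_contains(child, rest)


d = {"a": 1}
d["self"] = d

print(nested_contains(d, "b"))                    # False
print(nested_contains(d, ("self", "self", "a")))  # True
print(nested_contains(d, ("self", "b")))          # False
```

Because the recursion depth equals the key length, a very long key such as the one in the example above would hit Python's recursion limit; an iterative loop over the key elements would avoid that.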
