Allow applying on all modules, not just immediate children #10

Open
wants to merge 6 commits into
base: main

Conversation


@hub-bla hub-bla commented Apr 26, 2024

I've implemented nested module selection modeled on the way the CSS child selector works. By using '>', we can now select nested modules.

Example:

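import argparse
from collections import OrderedDict
from torch import nn

from corenet.modeling.misc.common import freeze_modules_based_on_opts

opts = argparse.Namespace(**{"model.freeze_modules": "model1>ins_model"})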

inside_model2 = nn.Sequential(
     OrderedDict([
          ('conv1', nn.Conv2d(1,20,5)),
          ('relu1', nn.ReLU()),
          ('conv2', nn.Conv2d(20,64,5)),
          ('relu2', nn.ReLU())
        ])
)

inside_model = nn.Sequential(
     OrderedDict([
          ('ins_model', inside_model2),
          ('conv1', nn.Conv2d(1,20,5))
        ])
)

model = nn.Sequential(
     OrderedDict([
          ('model1', inside_model),
          ('conv1', nn.Conv2d(20,64,5)),
          ('conv2', nn.Conv2d(20,64,5))
        ])
)

print(freeze_modules_based_on_opts(opts, model))

returns:
example_result
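
The '>' selection described above can be sketched without PyTorch. This is a minimal illustration, not the code from this PR: nested dicts stand in for what `named_children()` would expose, `select_nested` is a hypothetical helper, and it assumes each '>'-separated chunk is a regex that must fully match an immediate child name at that depth.

```python
import re
from typing import Dict, List

# Nested dicts stand in for a module tree: keys are child names,
# values are the children's own sub-trees (empty dict = leaf layer).
Tree = Dict[str, "Tree"]

model: Tree = {
    "model1": {
        "ins_model": {"conv1": {}, "relu1": {}, "conv2": {}, "relu2": {}},
        "conv1": {},
    },
    "conv1": {},
    "conv2": {},
}


def select_nested(tree: Tree, selector: str) -> List[str]:
    """Return dotted paths of modules matched by a '>'-separated selector.

    Hypothetical helper: each chunk is a regex that must fully match the
    name of an immediate child at that depth.
    """
    frontier = [("", tree)]  # (dotted path so far, sub-tree)
    for chunk in selector.split(">"):
        pattern = re.compile(chunk.strip())
        next_frontier = []
        for path, node in frontier:
            for name, child in node.items():
                if pattern.fullmatch(name):
                    child_path = f"{path}.{name}" if path else name
                    next_frontier.append((child_path, child))
        frontier = next_frontier
    return [path for path, _ in frontier]


print(select_nested(model, "model1>ins_model"))  # ['model1.ins_model']
print(select_nested(model, "model1>conv.*"))     # ['model1.conv1']
```

Because every chunk is matched only against immediate children, `model1>conv.*` selects `model1.conv1` but not `model1.ins_model.conv1`, which mirrors the CSS child-selector behaviour.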

@hub-bla hub-bla changed the title Allow applying on all modules, not just immediate chidren Allow applying on all modules, not just immediate children Apr 26, 2024
@mohammad7t
Collaborator

Hi @hub-bla. Thank you for your contribution. Since freeze_modules accepts a regex, I'm wondering whether a more flexible regex could select nested modules with the existing code. For example, the following regex seems to work for model1>conv1:

import argparse
from collections import OrderedDict
from torch import nn

from corenet.modeling.misc.common import freeze_modules_based_on_opts


opts = argparse.Namespace(**{"model.freeze_modules": r"model1(.*)\.conv1"})

inside_model2 = nn.Sequential(
     OrderedDict([
          ('conv1', nn.Conv2d(1,20,5)),
          ('relu1', nn.ReLU()),
          ('conv2', nn.Conv2d(20,64,5)),
          ('relu2', nn.ReLU())
        ])
)

inside_model = nn.Sequential(
     OrderedDict([
          ('ins_model', inside_model2),
          ('conv1', nn.Conv2d(1,20,5))
        ])
)

model = nn.Sequential(
     OrderedDict([
          ('model1', inside_model),
          ('conv1', nn.Conv2d(20,64,5)),
          ('conv2', nn.Conv2d(20,64,5))
        ])
)

print(freeze_modules_based_on_opts(opts, model))
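
To see which parameters this regex would pick up, here is a small standalone check that needs only the standard library. The names listed are the ones `named_parameters()` produces for the model above (the ReLU layers have no parameters). How `freeze_modules_based_on_opts` applies the pattern internally is an assumption here: the sketch uses `re.match`, which anchors at the start of the name.

```python
import re

# Parameter names of the example model, as named_parameters() reports them.
param_names = [
    "model1.ins_model.conv1.weight",
    "model1.ins_model.conv1.bias",
    "model1.ins_model.conv2.weight",
    "model1.ins_model.conv2.bias",
    "model1.conv1.weight",
    "model1.conv1.bias",
    "conv1.weight",
    "conv1.bias",
    "conv2.weight",
    "conv2.bias",
]

pattern = re.compile(r"model1(.*)\.conv1")

# Assumption: names are tested with re.match (anchored at the start).
matched = [name for name in param_names if pattern.match(name)]
for name in matched:
    print(name)
# model1.ins_model.conv1.weight
# model1.ins_model.conv1.bias
# model1.conv1.weight
# model1.conv1.bias
```

Under that assumption the pattern behaves like a descendant selector rather than a strict child selector: it matches `conv1` directly under `model1` and also `conv1` nested deeper inside `model1.ins_model`, while leaving the top-level `conv1` untouched.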

@mohammad7t mohammad7t added the enhancement (New feature or request) and awaiting-response labels May 2, 2024
@hub-bla
Author

hub-bla commented May 3, 2024

Hi @mohammad7t, I agree that the existing code can support those operations using the loop over named_parameters. To be honest, I didn't check whether it works before I started implementing the enhancement. I followed the comment above the loop with named_children: # TODO: allow applying on all modules, not just immediate chidren? How?
What my code does additionally is reduce the number of logs. For example, if you want to freeze a deeply nested module, it won't produce a log line for every parameter that is going to be frozen. Instead, it will only show that the whole module is now frozen.

To be honest, I wonder if that loop with named_modules is even necessary, because everything could be done using a regex as you provided. It would also remove this issue

@mohammad7t mohammad7t self-assigned this May 3, 2024
@mohammad7t
Collaborator

I see where you are coming from! I agree that the # TODO: allow applying on all modules, not just immediate chidren? How? is confusing. I think the TODO is related to applying force_eval on nested modules:

if force_eval:

    def _force_train_in_eval(
        self: torch.nn.Module, mode: bool = True
    ) -> torch.nn.Module:
        # ignore train/eval calls: perpetually stays in eval
        return self

    module.train = MethodType(_force_train_in_eval, module)

I wonder if that loop with named_modules is even necessary, because everything could be done using a regex as you provided.

That's a good question. I think the only reason we need the loop with named_modules is to apply force_eval as mentioned above.

I'm not entirely sure what the best solution is right now. Let me think a bit more and get back to you. Thinking out loud, I guess we don't need to support the ">" CSS operator, but the BFS is probably a good idea to address the TODO. What do you think?
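
A breadth-first walk that applies the force_eval override to nested modules might look like the sketch below. It is only an illustration under stated assumptions: `Node` is a hypothetical stand-in for `torch.nn.Module` (so the example runs without PyTorch), `force_eval_bfs` is not a function from this repository, and modules are selected by exact name rather than by regex.

```python
from collections import deque
from types import MethodType
from typing import Dict, Iterator, List, Tuple


class Node:
    """Hypothetical stand-in for torch.nn.Module with just what is needed here."""

    def __init__(self, **children: "Node") -> None:
        self._children: Dict[str, "Node"] = dict(children)
        self.training = True

    def named_children(self) -> Iterator[Tuple[str, "Node"]]:
        return iter(self._children.items())

    def train(self, mode: bool = True) -> "Node":
        # Mirrors nn.Module.train: set the flag here and on all descendants.
        self.training = mode
        for child in self._children.values():
            child.train(mode)
        return self


def _stay_in_eval(self: Node, mode: bool = True) -> Node:
    # Ignore train/eval calls so the module perpetually stays in eval.
    return self


def force_eval_bfs(root: Node, target: str) -> List[str]:
    """Breadth-first walk; pin every module named `target` to eval mode."""
    pinned = []
    queue = deque([("", root)])
    while queue:
        path, node = queue.popleft()
        for name, child in node.named_children():
            child_path = f"{path}.{name}" if path else name
            if name == target:
                child.train(False)  # switch the whole subtree to eval first
                child.train = MethodType(_stay_in_eval, child)
                pinned.append(child_path)
            else:
                queue.append((child_path, child))
    return pinned


model = Node(
    model1=Node(ins_model=Node(conv1=Node(), conv2=Node()), conv1=Node()),
    conv1=Node(),
)
print(force_eval_bfs(model, "ins_model"))  # ['model1.ins_model']
model.train()  # the parent propagates train(True) to its children
print(model._children["model1"]._children["ins_model"].training)  # False
```

Since the override never recurses, the pinned module also shields its own children from later `train()` calls made on a parent.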

Thanks again!

@hub-bla
Author

hub-bla commented May 3, 2024

Thank you for the clarification! Now I get it, and I agree with everything you said.

Regarding the ">" selector, I'm not sure it becomes unnecessary when applying BFS. The way it works now is that the input string is split on this symbol, and the resulting chunks, each of which may or may not be a regex, are then passed to the BFS. Without it, we would have to split the string on the dot symbol, which is a special character in regex, and this might cause some problems. I'll try to think about that too.
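
The concern about splitting on the dot can be shown with the regex from earlier in this thread. This is a standalone sketch using only the standard library; `css_like_selector` is a made-up example string and `compiles` is a hypothetical helper.

```python
import re


def compiles(chunk: str) -> bool:
    """Return True if `chunk` is a valid regular expression on its own."""
    try:
        re.compile(chunk)
        return True
    except re.error:
        return False


regex_selector = r"model1(.*)\.conv1"
css_like_selector = r"model1>ins_model>conv\d"

# Splitting on "." cuts through the group and the escaped dot.
dot_chunks = regex_selector.split(".")
print(dot_chunks)                         # ['model1(', '*)\\', 'conv1']
print([compiles(c) for c in dot_chunks])  # [False, False, True]

# Splitting on ">" leaves each per-level regex intact.
gt_chunks = css_like_selector.split(">")
print(gt_chunks)                          # ['model1', 'ins_model', 'conv\\d']
print([compiles(c) for c in gt_chunks])   # [True, True, True]
```

One caveat: ">" is not entirely free of regex meaning either, since it appears in named groups such as `(?P<name>...)`, so chunks using that syntax would need separate handling.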
