Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow applying on all modules, not just immediate children #10

Closed
wants to merge 6 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
39 changes: 34 additions & 5 deletions corenet/modeling/misc/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import argparse
import os
import re
from collections import deque
from types import MethodType
from typing import Dict, List, Optional, Tuple, Union

Expand Down Expand Up @@ -209,19 +210,35 @@ def _force_train_in_eval(

return module

def _module_bfs(
    name: str, module: torch.nn.Module, p: List[str], idx: int = 1
) -> None:
    """Breadth-first search over ``module``'s descendants, freezing every
    submodule whose path matches the nested pattern path ``p``.

    Args:
        name: Name of the root module (used only to build the logged path).
        module: Root module whose children are searched.
        p: Pattern path obtained by splitting a freeze pattern on ``'>'``;
            ``p[0]`` is assumed to have already matched ``name`` by the caller.
        idx: Index into ``p`` at which matching starts. Defaults to 1, i.e.
            the first nesting level below the root.
    """
    # FIFO queue of (pattern index, human-readable path, module) — popleft
    # makes this a breadth-first traversal.
    queue = deque()
    queue.append((idx, name, module))

    while queue:
        idx, path, module = queue.popleft()
        # Stop descending once the pattern path is exhausted.
        if idx < len(p):
            for submodule_name, submodule in module.named_children():
                if re.match(p[idx], submodule_name):
                    if idx == len(p) - 1:
                        # Last pattern component: freeze the matching submodule.
                        freeze_module(submodule)
                        logger.info(
                            "Freezing module: {} Inside: {}".format(
                                submodule_name, path
                            )
                        )
                    else:
                        # Descend one level, extending the logged path.
                        queue.append(
                            (idx + 1, path + f">{submodule_name}", submodule)
                        )


def freeze_modules_based_on_opts(
opts: argparse.Namespace, model: torch.nn.Module, verbose: bool = True
) -> torch.nn.Module:
"""
Allows for freezing immediate modules and parameters of the model using --model.freeze-modules.
Allows for freezing immediate modules and parameters as well as nested modules of the model using --model.freeze-modules.

--model.freeze-modules should be a list of strings or a comma-separated list of regex expressions.
    --model.freeze-modules should be a list of strings, a comma-separated list of regex expressions, or a list of strings containing '>' between module names to freeze particular nested layers inside an immediate module of the model.

Examples of --model.freeze-modules:
"conv.*" # see example below: can freeze all (top-level) conv layers
"^((?!classifier).)*$" # freezes everything except for "classifier": useful for linear probing
"conv1,layer1,layer2,layer3" # freeze all layers up to layer3
"transformer>decoder" # freeze decoder block inside transformer

>>> model = nn.Sequential(OrderedDict([
('conv1', nn.Conv2d(1, 20, 5)),
Expand All @@ -244,15 +261,27 @@ def freeze_modules_based_on_opts(
verbose = verbose and is_master(opts)

if freeze_patterns:
        # TODO: allow applying on all modules, not just immediate children? How?
immediate_children_patterns = []
nested_modules_patterns = []
# separate nested expressions from the rest
for p in freeze_patterns:
if ">" in p:
nested_modules_patterns.append([part for part in re.split(r'\s*>\s*', p) if part.strip()])
else:
immediate_children_patterns.append(p)

for name, module in model.named_children():
if any([re.match(p, name) for p in freeze_patterns]):
for p in nested_modules_patterns:
if re.match(p[0], name):
_module_bfs(name, module, p, 1)

if any([re.match(p, name) for p in immediate_children_patterns]):
freeze_module(module)
if verbose:
logger.info("Freezing module: {}".format(name))

for name, param in model.named_parameters():
if any([re.match(p, name) for p in freeze_patterns]):
if any([re.match(p, name) for p in immediate_children_patterns]):
param.requires_grad = False
if verbose:
logger.info("Freezing parameter: {}".format(name))
Expand Down