Stochastic Weight Averaging #700

Open
WillCastle opened this issue Sep 27, 2020 · 10 comments

@WillCastle

WillCastle commented Sep 27, 2020

PyTorch recently added methods to implement Stochastic Weight Averaging (SWA):
https://pytorch.org/blog/pytorch-1.6-now-includes-stochastic-weight-averaging/

This method can improve many models' performance by creating a new model with weights that are averaged over the last few training epochs. Paper here:
https://arxiv.org/abs/1803.05407
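A minimal sketch of what the averaging amounts to, assuming PyTorch >= 1.6 (where `torch.optim.swa_utils` was added): `AveragedModel.update_parameters` keeps an equal-weight running mean of the weights it is given. The values 1, 2 and 3 below are arbitrary stand-ins for the weights at three epoch ends.

```python
import torch
from torch import nn
from torch.optim import swa_utils

# Tiny model whose parameters we overwrite with known values, so the
# average kept by AveragedModel can be checked by hand.
model = nn.Linear(3, 1)
swa_model = swa_utils.AveragedModel(model)  # holds its own deep copy of model

for value in (1.0, 2.0, 3.0):  # stand-ins for the weights at three epoch ends
    with torch.no_grad():
        for p in model.parameters():
            p.fill_(value)
    # Default averaging: avg <- avg + (w - avg) / (n + 1), an equal-weight mean
    swa_model.update_parameters(model)

print(swa_model.n_averaged.item())  # 3
print(swa_model.module.weight)      # every entry is 2.0 (up to rounding), the mean of 1, 2, 3
```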

The PyTorch implementation requires calling these methods within a training loop, but I wanted to use SWA with a skorch network, so I wrote a callback to do it. I thought it might be of use to others.

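```python
import torch
from skorch.callbacks import Callback

train_loader, skorch_model = ...  # placeholders, defined elsewhere


class StochasticWeightAveraging(Callback):

    def on_train_begin(self, skorch_model, **kwargs):
        # Create the averaged copy of the module at the start of training
        skorch_model.swa_model = torch.optim.swa_utils.AveragedModel(skorch_model.module_)

    def on_epoch_end(self, skorch_model, **kwargs):
        # Update the averaged weights once the SWA start point has been reached
        if skorch_model.history[-1, 'epoch'] >= skorch_model.module__swa_start * skorch_model.max_epochs:
            skorch_model.swa_model.update_parameters(skorch_model.module_)

    def on_train_end(self, skorch_model, **kwargs):
        # Recompute batch norm statistics for the averaged weights
        torch.optim.swa_utils.update_bn(train_loader, skorch_model.swa_model, device=skorch_model.device)
```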

@BenjaminBossan
Collaborator

I didn't know about stochastic weight averaging, thanks a lot. I looked at your code and the PyTorch example and came up with a slightly different implementation based on yours:

```python
from torch.optim import swa_utils
from skorch.callbacks import Callback

class StochasticWeightAveraging(Callback):
    def __init__(
            self,
            swa_utils,
            swa_start=10,
            verbose=0,
            sink=print,
            **kwargs  # additional arguments to swa_utils.SWALR
    ):
        self.swa_utils = swa_utils
        self.swa_start = swa_start
        self.verbose = verbose
        self.sink = sink
        vars(self).update(kwargs)

    @property
    def kwargs(self):
        # These are the parameters that are passed to SWALR.
        # Parameters that don't belong there must be excluded.
        excluded = {'swa_utils', 'swa_start', 'verbose', 'sink'}
        kwargs = {key: val for key, val in vars(self).items()
                  if not (key in excluded or key.endswith('_'))}
        return kwargs

    def on_train_begin(self, net, **kwargs):
        self.optimizer_swa_ = self.swa_utils.SWALR(net.optimizer_, **self.kwargs)
        if not hasattr(net, 'module_swa_'):
            net.module_swa_ = self.swa_utils.AveragedModel(net.module_)

    def on_epoch_begin(self, net, **kwargs):
        if self.verbose and len(net.history) == self.swa_start + 1:
            self.sink("Using SWA to update parameters")

    def on_epoch_end(self, net, **kwargs):
        if len(net.history) >= self.swa_start + 1:
            net.module_swa_.update_parameters(net.module_)
            self.optimizer_swa_.step()

    def on_train_end(self, net, X, y=None, **kwargs):
        if self.verbose:
            self.sink("Using training data to update batch norm statistics of the SWA model")

        loader = net.get_iterator(net.get_dataset(X, y))
        self.swa_utils.update_bn(loader, net.module_swa_, device=net.device)
```

Let me explain some of the changes:

  • I want to pass swa_utils as a parameter because this way the skorch code still works with PyTorch versions < 1.6 ("works" in the sense that there won't be an import error, though this callback still won't be usable). Also, a user could in theory provide their own implementations of update_bn etc.

  • I made swa_start a parameter of the callback instead of the module

  • I added verbosity to get a better feel for what's happening

  • Your code has this line: `if skorch_model.history[-1, 'epoch'] >= skorch_model.module__swa_start * skorch_model.max_epochs`. I think the logic is wrong there? Why multiply the number of epochs by swa_start? (Is it meant as a fraction?)
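For comparison, a minimal sketch of the more common alternative to passing `swa_utils` in: a guarded import at module level. Unlike the parameter approach, this does not let users swap in their own implementations of `update_bn` etc. `require_swa_utils` is a hypothetical helper name.

```python
# Alternative sketch: guard the import instead of injecting the module.
try:
    from torch.optim import swa_utils  # available since PyTorch 1.6
except ImportError:  # older PyTorch: the SWA callback cannot be used
    swa_utils = None


def require_swa_utils():
    """Return torch.optim.swa_utils, or raise a clear error if it is missing."""
    if swa_utils is None:
        raise ImportError("Stochastic Weight Averaging requires PyTorch >= 1.6")
    return swa_utils
```

The callback would then call `require_swa_utils()` in `on_train_begin`, so importing skorch still succeeds on old PyTorch versions and the error only appears when SWA is actually used.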

What your example is missing compared to the PyTorch example is the use of swa_utils.SWALR, which I tried to work into my code. However, my code differs from the PyTorch example: there, SWALR is used instead of the normal LR scheduler once SWA has started, whereas in my code, it's used in addition to it.
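To make the difference concrete, here is a minimal sketch (assuming PyTorch >= 1.6) of the learning rate schedule produced when SWALR replaces the cosine scheduler after `swa_start`, mirroring the plain PyTorch loop in this thread. No real training happens: `optimizer.step()` is called without gradients only to keep the scheduler call order valid, and `lr_trajectory` is a hypothetical helper.

```python
import torch
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.optim.swa_utils import SWALR


def lr_trajectory(max_epochs=30, swa_start=5, lr=0.01, swa_lr=0.05):
    """Learning rate after each epoch when SWALR replaces the cosine
    scheduler once SWA has started."""
    param = torch.nn.Parameter(torch.zeros(1))
    optimizer = torch.optim.SGD([param], lr=lr)
    scheduler = CosineAnnealingLR(optimizer, T_max=max_epochs)
    swa_scheduler = SWALR(optimizer, swa_lr=swa_lr)

    lrs = []
    for epoch in range(max_epochs):
        optimizer.step()  # stands in for one epoch of training (no grads, no-op)
        if epoch > swa_start:
            swa_scheduler.step()  # anneal towards swa_lr, then hold it
        else:
            scheduler.step()      # regular cosine schedule before SWA
        lrs.append(optimizer.param_groups[0]['lr'])
    return lrs


print(lr_trajectory())
```

With SWALR's defaults (`anneal_epochs=10`, cosine annealing), the learning rate moves from the decayed cosine value to `swa_lr` over ten epochs and then stays there. Stepping both schedulers on the same optimizer every epoch, as the skorch callback does, lets them interact instead.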

Below is a working example using the callback as implemented above:

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
import torch
from torch import nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.optim import swa_utils

from skorch import NeuralNetClassifier
from skorch.callbacks import Callback, LRScheduler, EpochScoring

# ClassifierModule is not shown here; it must return class probabilities,
# since the PyTorch loop below applies torch.log before NLLLoss.

X, y = make_classification(1000, 20, n_informative=10, random_state=0)
X, y = X.astype(np.float32), y.astype(np.int64)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

SWA_START = 5
MAX_EPOCHS = 100
LR = 0.01
LR_SWA = 0.05

# skorch implementation

class StochasticWeightAveraging(Callback):
    ...

torch.manual_seed(0)
net = NeuralNetClassifier(
    ClassifierModule,
    max_epochs=50,
    lr=LR,
    callbacks=[
        LRScheduler(CosineAnnealingLR, T_max=MAX_EPOCHS),
        StochasticWeightAveraging(swa_utils, swa_start=SWA_START, verbose=1, swa_lr=LR_SWA),
        EpochScoring('accuracy', lower_is_better=False, on_train=True, name='train_acc'),
    ],
    train_split=False,
)
net.fit(X_train, y_train)
test_accuracy = (net.predict(X_test) == y_test).mean()

# PyTorch implementation inspired by linked example

torch.manual_seed(0)
loader = net.get_iterator(net.get_dataset(X, y))
model = ClassifierModule()
optimizer = torch.optim.SGD(model.parameters(), LR)
loss_fn = torch.nn.NLLLoss()

swa_model = swa_utils.AveragedModel(model)
scheduler = CosineAnnealingLR(optimizer, T_max=MAX_EPOCHS)
swa_scheduler = swa_utils.SWALR(optimizer, swa_lr=LR_SWA)

for epoch in range(MAX_EPOCHS):
    losses = []
    for input, target in loader:
        optimizer.zero_grad()
        loss = loss_fn(torch.log(model(input)), target)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
    if epoch == 1 + SWA_START:
        print("starting SWA")
    if epoch > SWA_START:
        swa_model.update_parameters(model)
        swa_scheduler.step()
    else:
        scheduler.step()

    preds = swa_model(torch.as_tensor(X))
    print("epoch: {:>2} | train loss: {:.4f} | train acc: {:.2f} %".format(
        epoch, np.mean(losses), 100 * (preds.detach().numpy().argmax(-1) == y).mean()))

swa_utils.update_bn(loader, swa_model)
test_accuracy = (swa_model(torch.as_tensor(X_test)).detach().numpy().argmax(-1) == y_test).mean()
```

The skorch version gets a train loss of 0.548, train accuracy of 0.733, and test accuracy of 0.752.

The PyTorch version gets a train loss of 0.460, train accuracy of 0.727, and test accuracy of 0.736.

So there seems to be a significant difference in train loss, but I'm not sure where it's coming from. It's not due to the SWALR difference described above: introducing the same deviation into the PyTorch code doesn't change the result. Do you have any idea?

@BenjaminBossan
Collaborator

@WillCastle Did you have opportunity to test this out yet?

I believe it might not even be necessary to store swa_model (module_swa_ in my code) on the net; it can be stored on the callback instead.

@WillCastle
Author

@BenjaminBossan Hi, sorry, I have been a little caught up with job applications. I will have a look at this next week. I am not sure about the discrepancy in training loss; I'll run some test cases and try to work out where it's coming from. The changes you proposed look good. As to multiplying by swa_start, I did intend it as a fraction. I only did it this way because most of the examples I have read suggest beginning the averaging 75% of the way through training; this might also be helpful when running several models for different numbers of epochs (say, when tuning hyperparameters). It could probably do with a clearer name if used this way, though.

@WillCastle
Author

@BenjaminBossan Just had another look and noticed a couple of things. In your skorch example, you create and fit the object net, but then you check the accuracy of this same network:

```python
test_accuracy = (net.predict(X_test) == y_test).mean()
```

I believe that the net object remains unchanged by SWA and the one we want to evaluate is net.module_swa_ which is actually a new model that is a sort of ensemble of some of the training iterations of net.

The SWA model is a PyTorch module, so I follow its creation with a conversion to a skorch model, using something like NeuralNetBinaryClassifier(module=net.module_swa_).
Also, it looks like your skorch network trains for 50 epochs, since you initialize it with max_epochs=50, whereas the PyTorch network trains for MAX_EPOCHS = 100 epochs. That might be the reason for the difference in training loss.
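If the goal is only to score the averaged weights, the AveragedModel can also be called directly instead of being wrapped in a new skorch net. A minimal sketch, assuming the module maps a float32 batch to one score or probability per class; `predict_with_swa` is a hypothetical helper, not part of skorch or PyTorch.

```python
import numpy as np
import torch


def predict_with_swa(swa_model, X, batch_size=128, device='cpu'):
    """Predict class labels with an AveragedModel (or any torch module).

    Assumes the module maps a batch of shape (n, n_features) to
    per-class scores of shape (n, n_classes).
    """
    swa_model.eval()  # switch dropout/batch norm to inference behaviour
    preds = []
    with torch.no_grad():
        for start in range(0, len(X), batch_size):
            Xb = torch.as_tensor(X[start:start + batch_size], device=device)
            preds.append(swa_model(Xb).argmax(-1).cpu().numpy())
    return np.concatenate(preds)
```

For example, `(predict_with_swa(net.module_swa_, X_test, device=net.device) == y_test).mean()` would give the test accuracy of the averaged model, provided `update_bn` has already been run. Note that the helper leaves the module in eval mode.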

@BenjaminBossan
Collaborator

Thanks for taking another look @WillCastle

> as to multiplying by swa_start, I did intend it as a fraction

Okay, this makes sense. I would probably allow both possibilities: if swa_start is an int, treat it as an absolute number of epochs; if it's a float, treat it as a fraction of max_epochs. This is consistent with how sklearn works in some places, e.g. the train_size argument in train_test_split.
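A minimal sketch of that rule as a hypothetical helper, `resolve_swa_start`, which the callback could call in `on_epoch_end`, e.g. `len(net.history) > resolve_swa_start(self.swa_start, net.max_epochs)`. Assumptions: a float must lie strictly between 0 and 1 (as with `train_size`), and fractional epochs are rounded down.

```python
def resolve_swa_start(swa_start, max_epochs):
    """Translate swa_start into the epoch after which averaging begins.

    An int is an absolute epoch count; a float in (0, 1) is a fraction
    of max_epochs, similar to train_size in train_test_split.
    """
    if isinstance(swa_start, bool):  # bool is a subclass of int; reject it
        raise TypeError("swa_start must be an int or a float")
    if isinstance(swa_start, int):
        if not 0 <= swa_start <= max_epochs:
            raise ValueError(f"swa_start={swa_start} must lie in [0, {max_epochs}]")
        return swa_start
    if isinstance(swa_start, float):
        if not 0.0 < swa_start < 1.0:
            raise ValueError(f"a float swa_start must lie in (0, 1), got {swa_start}")
        return int(swa_start * max_epochs)  # round down to a whole epoch
    raise TypeError("swa_start must be an int or a float")


print(resolve_swa_start(10, 100))    # 10
print(resolve_swa_start(0.75, 100))  # 75
```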

> I believe that the net object remains unchanged by SWA and the one we want to evaluate is net.module_swa_ which is actually a new model that is a sort of ensemble of some of the training iterations of net.

Yes, you're right; to be more precise, it's not the net object, but the net.module_ (which is the PyTorch module itself).

Unfortunately, that's not the reason for the discrepancy. I tested both the original module_ and the new module_swa_ and they still give different results for the skorch and the pure PyTorch implementation (I made sure to fix the seeds and use the exact same data loader):

|  | PyTorch | skorch |
|---|---|---|
| test accuracy (module_) | 0.752 | 0.772 |
| test accuracy (module_swa_) | 0.704 | 0.760 |

@useric

useric commented Oct 10, 2020

> |  | PyTorch | skorch |
> |---|---|---|
> | test accuracy (module_) | 0.752 | 0.772 |
> | test accuracy (module_swa_) | 0.704 | 0.760 |

Quite different validation results between PyTorch and skorch. Does skorch perform some weight initialization automatically by default?

@BenjaminBossan
Collaborator

> Quite different validation results between PyTorch and skorch.

Yes, I need to investigate further, or perhaps someone else can spot a mistake.

> Does skorch perform some weight initialization automatically by default?

No, this is left completely to the user. The module and the data loader are initialized in exactly the same way in both versions.
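One way to double-check that both versions really start from the same weights is to seed, build the module, and compare state dicts. A minimal sketch with hypothetical helper names; in the real comparison, `make_module` would be `ClassifierModule`, and anything that draws from torch's global RNG between `torch.manual_seed` and module creation would break the equivalence.

```python
import torch
from torch import nn


def initial_state(make_module, seed=0):
    """Seed torch's global RNG, build a fresh module, return a copy of its weights."""
    torch.manual_seed(seed)
    module = make_module()
    return {name: t.detach().clone() for name, t in module.state_dict().items()}


def same_initialization(make_module, seed=0):
    """True if two modules built right after the same seed have identical weights."""
    a = initial_state(make_module, seed)
    b = initial_state(make_module, seed)
    return a.keys() == b.keys() and all(torch.equal(a[k], b[k]) for k in a)


# Example with a stand-in module; in the thread this would be ClassifierModule.
print(same_initialization(lambda: nn.Linear(20, 2)))  # True
```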

@MohammadJavadD

MohammadJavadD commented Mar 27, 2023

I'm getting this error: `SkorchAttributeError: Trying to set torch compoment 'module_swa_' outside of an initialize method. Consider defining it inside 'initialize_module'`. Any update on this issue? @BenjaminBossan

@BenjaminBossan
Collaborator

I assume you have used the code I posted above and now encountered this error. In that case, could you please replace the line:

```python
net.module_swa_ = self.swa_utils.AveragedModel(net.module_)
```

by

```python
with net._current_init_context('module'):
    net.module_swa_ = self.swa_utils.AveragedModel(net.module_)
```

and see if that fixes the issue?

@MohammadJavadD

@BenjaminBossan Yes it fixed the issue.
