AWS Trainium fails device-count validation when using more than 1 accelerator on an instance #19826

BrianF-tessera opened this issue Apr 29, 2024 · 0 comments
bug Something isn't working needs triage Waiting to be triaged by maintainers ver: 2.0.x ver: 2.1.x


Bug description

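In `lightning/fabric/accelerators/`, the `_parse_tpu_devices` function hard-codes a maximum of 8 devices. Meanwhile, the validator in `torch_xla/distributed/` only allows either 1 device or every available device (`dev_count` in the traceback below, which is 32 on this instance; the report originally described this as `$WORLD_SIZE`).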

When working with AWS Trainium, the large trn1.32xlarge instances come equipped with 16 accelerators with 2 cores each, for a total of 32 devices. Setting `devices` to either 8 or 32 causes a validation error before training starts: torch_xla rejects 8 (it requires either 1 or 32), and Lightning rejects 32 (it only accepts 1, 8, or `[<1-8>]`). The problem is not seen when only 1 accelerator is in use, since that falls under the <= 8 threshold.

What version are you seeing the problem on?

v2.0, v2.1, v2.2

How to reproduce the bug

import os
import uuid

import lightning as pl
import mlflow
import torch
import torch_xla.core.xla_model as xm
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchmetrics import Accuracy
from torchvision import transforms
from torchvision.datasets import FashionMNIST

class MNISTClassifier(pl.LightningModule):
    """Small fully connected classifier for 28x28 single-channel images.

    The network ends in ``log_softmax``, so it is trained with NLL loss.
    Validation loss/accuracy are collected per batch and averaged at the end
    of each validation epoch.
    """

    def __init__(self):
        super().__init__()
        self.accuracy = Accuracy(task="multiclass", num_classes=10, top_k=1)
        self.layer_1_size = 128
        self.layer_2_size = 256
        # Adam learning rate (0.01, as in the original snippet).
        self.lr = 0.01

        # mnist images are (1, 28, 28) (channels, height, width)
        self.layer_1 = torch.nn.Linear(28 * 28, self.layer_1_size)
        self.layer_2 = torch.nn.Linear(self.layer_1_size, self.layer_2_size)
        self.layer_3 = torch.nn.Linear(self.layer_2_size, 10)
        # Per-batch validation results; filled by validation_step and
        # averaged + cleared in on_validation_epoch_end.
        self.eval_loss = []
        self.eval_accuracy = []

    def cross_entropy_loss(self, logits, labels):
        """Return the NLL loss of log-probabilities ``logits`` vs ``labels``.

        ``forward`` already applies ``log_softmax``, so NLL loss on its output
        is the cross-entropy of the underlying scores.
        """
        return F.nll_loss(logits, labels)

    def forward(self, x):
        """Map a batch of images to per-class log-probabilities (10 classes).

        Each sample is flattened to a vector; layer_1 expects 28 * 28 inputs.
        """
        x = x.flatten(start_dim=1)
        x = torch.relu(self.layer_1(x))
        x = torch.relu(self.layer_2(x))
        return torch.log_softmax(self.layer_3(x), dim=1)

    def training_step(self, train_batch, batch_idx):
        """Compute and log training loss/accuracy for one batch; return the loss."""
        x, y = train_batch
        logits = self(x)
        loss = self.cross_entropy_loss(logits, y)
        accuracy = self.accuracy(logits, y)

        self.log("train_loss", loss)
        self.log("train_accuracy", accuracy)
        return loss

    def validation_step(self, val_batch, batch_idx):
        """Compute validation loss/accuracy for one batch and record them."""
        x, y = val_batch
        logits = self(x)
        loss = self.cross_entropy_loss(logits, y)
        accuracy = self.accuracy(logits, y)
        # Without recording these, the lists stay empty and torch.stack in
        # on_validation_epoch_end raises on an empty sequence.
        self.eval_loss.append(loss.detach())
        self.eval_accuracy.append(accuracy.detach())
        return {"val_loss": loss, "val_accuracy": accuracy}

    def on_validation_epoch_end(self):
        """Log the mean validation loss/accuracy, then reset the per-batch buffers."""
        avg_loss = torch.stack(self.eval_loss).mean()
        avg_acc = torch.stack(self.eval_accuracy).mean()
        self.log("val_loss", avg_loss, sync_dist=True)
        self.log("val_accuracy", avg_acc, sync_dist=True)
        # Clear so each epoch (including the sanity check) averages only its own batches.
        self.eval_loss.clear()
        self.eval_accuracy.clear()

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters using ``self.lr``."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        return optimizer

def train_func(num_devices=None):
    """Train ``MNISTClassifier`` on FashionMNIST (reproduces the XLA device-count error).

    Args:
        num_devices: Value passed to ``pl.Trainer(devices=...)``. When ``None``,
            it is read from the ``NUM_DEVICES`` environment variable (the
            ``$NUM_DEVICES`` placeholder of the original report), defaulting to 1.

    Returns:
        The ``pl.Trainer`` after ``fit`` has completed.
    """
    if num_devices is None:
        num_devices = int(os.environ.get("NUM_DEVICES", "1"))

    train_set = FashionMNIST('~/torchdata/', train=True, download=True, transform=transforms.ToTensor())
    train_loader = DataLoader(train_set, batch_size=512, shuffle=True, num_workers=4)
    # FashionMNIST defaults to train=True; pass train=False so validation
    # uses the held-out split instead of the training data.
    val_set = FashionMNIST('~/torchdata/', train=False, download=True, transform=transforms.ToTensor())
    val_loader = DataLoader(val_set, batch_size=512, num_workers=4)

    # model
    mnist_model = MNISTClassifier()

    trainer = pl.Trainer(devices=num_devices, precision='bf16-true', max_epochs=2)
    trainer.fit(mnist_model, train_loader, val_loader)
    return trainer


Error messages and logs

File ~/miniconda3/lib/python3.10/site-packages/torch_xla/distributed/, in _pre_fork_setup(num_devices)
    199   num_devices = dev_count
    200 elif num_devices not in [1, dev_count]:
--> 201   raise ValueError(
    202       'The number of devices must be either 1 or {}, got {} instead'.format(
    203           dev_count, num_devices))
    204 total_devices = _get_world_size() * num_devices
    205 if total_devices > 1 and not os.environ.get(xenv.SERVICE_ADDRESS, None):
    206   # In multi-processing mode, even if there is only one XLA host, we still
    207   # bring up the mesh service.

ValueError: The number of devices must be either 1 or 32, got 8 instead


File ~/miniconda3/lib/python3.10/site-packages/lightning/fabric/accelerators/, in _parse_tpu_devices(devices)
    155     devices = _parse_tpu_devices_str(devices.strip())
    157 if not _tpu_devices_valid(devices):
--> 158     raise TypeError("`devices` can only be 1, 8 or [<1-8>] for TPUs.")
    160 return devices

TypeError: `devices` can only be 1, 8 or [<1-8>] for TPUs.


More info

No response

