
AWS Trainium fails device-count validation when using more than 1 accelerator on the instance #19826

Open
BrianF-tessera opened this issue Apr 29, 2024 · 0 comments
Labels
bug Something isn't working needs triage Waiting to be triaged by maintainers ver: 2.0.x ver: 2.1.x


@BrianF-tessera

Bug description

In lightning/fabric/accelerators/tpu.py, the _parse_tpu_devices function hard-codes 8 as the largest allowed value: it accepts only 1, 8, or a single device index from 1 to 8. In torch_xla/distributed/xla_multiprocessing.py, a separate validator accepts only 1 device or all of the devices visible on the host (dev_count in the traceback below).

When working with AWS Trainium, the large trn1.32xlarge instances come with 16 accelerators of 2 cores each, for a total of 32 devices. Setting devices to either 8 or 32 causes a validation error before training starts: 8 is rejected by torch_xla and 32 is rejected by Lightning. The problem is not seen when only 1 accelerator is in use, since its device count falls under the <=8 threshold.
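To make the conflict concrete, here is a minimal sketch that models the two checks as plain Python predicates. It is a simplification reconstructed from the two error messages quoted under "Error messages and logs", not the actual Lightning or torch_xla source; the function and constant names are hypothetical.

```python
# Hypothetical model of the two device-count checks described in this issue.
# Simplified from the error messages; this is NOT the real library code.

ACCELERATORS = 16          # trn1.32xlarge
CORES_PER_ACCELERATOR = 2
DEV_COUNT = ACCELERATORS * CORES_PER_ACCELERATOR  # 32 devices seen by torch_xla


def lightning_accepts(devices):
    # Lightning: `devices` can only be 1, 8 or [<1-8>] for TPUs.
    if isinstance(devices, int):
        return devices in (1, 8)
    if isinstance(devices, (list, tuple)):
        # A single device index in the range 1..8
        return len(devices) == 1 and 1 <= devices[0] <= 8
    return False


def torch_xla_accepts(num_devices, dev_count):
    # torch_xla: the number of devices must be either 1 or dev_count.
    return num_devices in (1, dev_count)


def rejected_by(devices, dev_count=DEV_COUNT):
    # Names of the checks that reject an integer `devices` value.
    failures = []
    if not lightning_accepts(devices):
        failures.append("lightning")
    if not torch_xla_accepts(devices, dev_count):
        failures.append("torch_xla")
    return failures


if __name__ == "__main__":
    for n in (1, 8, 32):
        print(n, rejected_by(n))
    # Integer device counts that pass both checks on a 32-device host
    print([n for n in range(1, DEV_COUNT + 1) if not rejected_by(n)])
```

Under this model, 1 is the only integer that passes both checks on a trn1.32xlarge, which is consistent with the two tracebacks: devices=8 fails in torch_xla and devices=32 fails in Lightning.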

What version are you seeing the problem on?

v2.0, v2.1, v2.2

How to reproduce the bug

import lightning as pl
import torch
import torch_xla.core.xla_model as xm  # unused here; fails fast if torch_xla is missing
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchmetrics import Accuracy
from torchvision import transforms
from torchvision.datasets import FashionMNIST


class MNISTClassifier(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.accuracy = Accuracy(task="multiclass", num_classes=10, top_k=1)
        self.layer_1_size = 128
        self.layer_2_size = 256
        self.lr = .01

        # mnist images are (1, 28, 28) (channels, width, height)
        self.layer_1 = torch.nn.Linear(28 * 28, self.layer_1_size)
        self.layer_2 = torch.nn.Linear(self.layer_1_size, self.layer_2_size)
        self.layer_3 = torch.nn.Linear(self.layer_2_size, 10)
        self.eval_loss = []
        self.eval_accuracy = []

    def cross_entropy_loss(self, logits, labels):
        # forward() returns log-probabilities, so NLL loss here equals cross-entropy
        return F.nll_loss(logits, labels)

    def forward(self, x):
        batch_size, channels, width, height = x.size()
        x = x.view(batch_size, -1)

        x = self.layer_1(x)
        x = torch.relu(x)

        x = self.layer_2(x)
        x = torch.relu(x)

        x = self.layer_3(x)
        x = torch.log_softmax(x, dim=1)

        return x

    def training_step(self, train_batch, batch_idx):
        x, y = train_batch
        logits = self.forward(x)
        loss = self.cross_entropy_loss(logits, y)
        accuracy = self.accuracy(logits, y)

        self.log("train_loss", loss)
        self.log("train_accuracy", accuracy)
        return loss

    def validation_step(self, val_batch, batch_idx):
        x, y = val_batch
        logits = self.forward(x)
        loss = self.cross_entropy_loss(logits, y)
        accuracy = self.accuracy(logits, y)
        self.eval_loss.append(loss)
        self.eval_accuracy.append(accuracy)
        return {"val_loss": loss, "val_accuracy": accuracy}

    def on_validation_epoch_end(self):
        avg_loss = torch.stack(self.eval_loss).mean()
        avg_acc = torch.stack(self.eval_accuracy).mean()
        self.log("val_loss", avg_loss, sync_dist=True)
        self.log("val_accuracy", avg_acc, sync_dist=True)
        self.eval_loss.clear()
        self.eval_accuracy.clear()

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        return optimizer


def train_func():

    mnist_train = FashionMNIST('~/torchdata/', train=True, download=True, transform=transforms.ToTensor())
    mnist_train = DataLoader(mnist_train, batch_size=512, num_workers=4)
    # train=False selects the held-out test split for validation
    mnist_val = FashionMNIST('~/torchdata/', train=False, download=True, transform=transforms.ToTensor())
    mnist_val = DataLoader(mnist_val, batch_size=512, num_workers=4)

    # model
    mnist_model = MNISTClassifier()

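    # Set to 8 or 32; both fail validation on a trn1.32xlarge (16 accelerators x 2 cores)
    num_devices = 32
    trainer = pl.Trainer(devices=num_devices, precision='bf16-true', max_epochs=2)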
    trainer.fit(mnist_model, mnist_train, mnist_val)

if __name__ == "__main__":
    train_func()

Error messages and logs

With devices=8, torch_xla raises:

File ~/miniconda3/lib/python3.10/site-packages/torch_xla/distributed/xla_multiprocessing.py:201, in _pre_fork_setup(num_devices)
    199   num_devices = dev_count
    200 elif num_devices not in [1, dev_count]:
--> 201   raise ValueError(
    202       'The number of devices must be either 1 or {}, got {} instead'.format(
    203           dev_count, num_devices))
    204 total_devices = _get_world_size() * num_devices
    205 if total_devices > 1 and not os.environ.get(xenv.SERVICE_ADDRESS, None):
    206   # In multi-processing mode, even if there is only one XLA host, we still
    207   # bring up the mesh service.

ValueError: The number of devices must be either 1 or 32, got 8 instead

With devices=32, Lightning raises:

File ~/miniconda3/lib/python3.10/site-packages/lightning/fabric/accelerators/tpu.py:158, in _parse_tpu_devices(devices)
    155     devices = _parse_tpu_devices_str(devices.strip())
    157 if not _tpu_devices_valid(devices):
--> 158     raise TypeError("`devices` can only be 1, 8 or [<1-8>] for TPUs.")
    160 return devices

TypeError: `devices` can only be 1, 8 or [<1-8>] for TPUs.

Environment

  • CUDA:
    - GPU: None
    - available: False
    - version: 11.7
  • Lightning:
    - lightning: 2.2.3
    - lightning-cloud: 0.5.68
    - lightning-utilities: 0.11.2
    - pytorch-lightning: 2.2.3
    - torch: 1.13.0
    - torch-neuronx: 1.13.1.1.14.0
    - torch-xla: 1.13.1+torchneurone
    - torchmetrics: 1.3.2
    - torchvision: 0.14.0
  • Packages:
    - absl-py: 2.1.0
    - aiohttp: 3.9.5
    - aiohttp-cors: 0.7.0
    - aiosignal: 1.3.1
    - alembic: 1.13.1
    - anaconda-anon-usage: 0.4.4
    - aniso8601: 9.0.1
    - annotated-types: 0.6.0
    - anyio: 4.3.0
    - archspec: 0.2.3
    - argon2-cffi: 23.1.0
    - argon2-cffi-bindings: 21.2.0
    - arrow: 1.3.0
    - asttokens: 2.4.1
    - async-lru: 2.0.4
    - async-timeout: 4.0.3
    - attrs: 23.2.0
    - aws-neuronx-runtime-discovery: 2.9
    - babel: 2.14.0
    - beautifulsoup4: 4.12.3
    - bio: 1.7.0
    - biopython: 1.83
    - biothings-client: 0.3.1
    - bleach: 6.1.0
    - blessed: 1.20.0
    - blinker: 1.8.1
    - boltons: 23.0.0
    - boto3: 1.34.93
    - botocore: 1.34.93
    - brotli: 1.0.9
    - cachetools: 5.3.3
    - certifi: 2024.2.2
    - cffi: 1.16.0
    - charset-normalizer: 2.0.4
    - click: 8.1.7
    - cloud-tpu-client: 0.10
    - cloudpickle: 3.0.0
    - colorful: 0.5.6
    - comm: 0.2.2
    - conda: 24.4.0
    - conda-content-trust: 0.2.0
    - conda-libmamba-solver: 24.1.0
    - conda-package-handling: 2.2.0
    - conda-package-streaming: 0.9.0
    - contourpy: 1.2.1
    - croniter: 1.3.15
    - cryptography: 42.0.5
    - cycler: 0.12.1
    - datasets: 2.19.0
    - dateutils: 0.6.12
    - debugpy: 1.8.1
    - decorator: 5.1.1
    - deepdiff: 7.0.1
    - defusedxml: 0.7.1
    - deprecated: 1.2.14
    - dill: 0.3.8
    - distlib: 0.3.8
    - distro: 1.8.0
    - dm-tree: 0.1.8
    - docker: 7.0.0
    - docutils: 0.21.2
    - ec2-metadata: 2.10.0
    - editor: 1.6.6
    - entrypoints: 0.4
    - exceptiongroup: 1.2.1
    - executing: 2.0.1
    - farama-notifications: 0.0.4
    - fastapi: 0.88.0
    - fastjsonschema: 2.19.1
    - filelock: 3.14.0
    - flask: 3.0.3
    - fonttools: 4.51.0
    - fqdn: 1.5.1
    - frozenlist: 1.4.1
    - fsspec: 2023.12.2
    - gitdb: 4.0.11
    - gitpython: 3.1.43
    - google-api-core: 1.34.1
    - google-api-python-client: 1.8.0
    - google-auth: 2.29.0
    - google-auth-httplib2: 0.2.0
    - googleapis-common-protos: 1.63.0
    - gprofiler-official: 1.0.0
    - graphene: 3.3
    - graphql-core: 3.2.3
    - graphql-relay: 3.2.0
    - greenlet: 3.0.3
    - grpcio: 1.62.2
    - gunicorn: 21.2.0
    - gymnasium: 0.28.1
    - h11: 0.14.0
    - httpcore: 1.0.5
    - httplib2: 0.22.0
    - httptools: 0.6.1
    - httpx: 0.27.0
    - huggingface-hub: 0.22.2
    - idna: 3.7
    - imageio: 2.34.1
    - importlib-metadata: 7.0.0
    - inquirer: 3.2.4
    - ipykernel: 6.29.4
    - ipython: 8.24.0
    - ipywidgets: 8.1.2
    - islpy: 2023.1
    - isoduration: 20.11.0
    - itsdangerous: 2.2.0
    - jax-jumpy: 1.0.0
    - jedi: 0.19.1
    - jinja2: 3.1.3
    - jmespath: 1.0.1
    - joblib: 1.4.0
    - json5: 0.9.25
    - jsonpatch: 1.33
    - jsonpointer: 2.1
    - jsonschema: 4.21.1
    - jsonschema-specifications: 2023.12.1
    - jupyter: 1.0.0
    - jupyter-client: 8.6.1
    - jupyter-console: 6.6.3
    - jupyter-core: 5.7.2
    - jupyter-events: 0.10.0
    - jupyter-lsp: 2.2.5
    - jupyter-server: 2.14.0
    - jupyter-server-terminals: 0.5.3
    - jupyterlab: 4.1.8
    - jupyterlab-pygments: 0.3.0
    - jupyterlab-server: 2.27.1
    - jupyterlab-widgets: 3.0.10
    - kiwisolver: 1.4.5
    - lazy-loader: 0.4
    - libmambapy: 1.5.8
    - libneuronxla: 0.5.971
    - lightning: 2.2.3
    - lightning-cloud: 0.5.68
    - lightning-utilities: 0.11.2
    - linkify-it-py: 2.0.3
    - lockfile: 0.12.2
    - lz4: 4.3.3
    - mako: 1.3.3
    - markdown: 3.6
    - markdown-it-py: 3.0.0
    - markupsafe: 2.1.5
    - matplotlib: 3.8.4
    - matplotlib-inline: 0.1.7
    - mdit-py-plugins: 0.4.0
    - mdurl: 0.1.2
    - memray: 1.12.0
    - menuinst: 2.0.2
    - mistune: 3.0.2
    - mlflow: 2.12.1
    - mpmath: 1.3.0
    - msgpack: 1.0.8
    - multidict: 6.0.5
    - multiprocess: 0.70.16
    - mygene: 3.2.2
    - nbclient: 0.10.0
    - nbconvert: 7.16.4
    - nbformat: 5.10.4
    - nest-asyncio: 1.6.0
    - networkx: 2.6.3
    - neuronx-cc: 2.13.72.0+78a426937
    - notebook: 7.1.3
    - notebook-shim: 0.2.4
    - numpy: 1.25.2
    - nvidia-cublas-cu11: 11.10.3.66
    - nvidia-cublas-cu12: 12.1.3.1
    - nvidia-cuda-cupti-cu12: 12.1.105
    - nvidia-cuda-nvrtc-cu11: 11.7.99
    - nvidia-cuda-nvrtc-cu12: 12.1.105
    - nvidia-cuda-runtime-cu11: 11.7.99
    - nvidia-cuda-runtime-cu12: 12.1.105
    - nvidia-cudnn-cu11: 8.5.0.96
    - nvidia-cudnn-cu12: 8.9.2.26
    - nvidia-cufft-cu12: 11.0.2.54
    - nvidia-curand-cu12: 10.3.2.106
    - nvidia-cusolver-cu12: 11.4.5.107
    - nvidia-cusparse-cu12: 12.1.0.106
    - nvidia-nccl-cu12: 2.20.5
    - nvidia-nvjitlink-cu12: 12.4.127
    - nvidia-nvtx-cu12: 12.1.105
    - oauth2client: 4.1.3
    - opencensus: 0.11.4
    - opencensus-context: 0.1.3
    - opentelemetry-api: 1.24.0
    - opentelemetry-exporter-otlp: 1.24.0
    - opentelemetry-exporter-otlp-proto-common: 1.24.0
    - opentelemetry-exporter-otlp-proto-grpc: 1.24.0
    - opentelemetry-exporter-otlp-proto-http: 1.24.0
    - opentelemetry-proto: 1.24.0
    - opentelemetry-sdk: 1.24.0
    - opentelemetry-semantic-conventions: 0.45b0
    - ordered-set: 4.1.0
    - overrides: 7.7.0
    - packaging: 23.2
    - pandas: 2.2.2
    - pandocfilters: 1.5.1
    - parso: 0.8.4
    - pexpect: 4.9.0
    - pgzip: 0.3.5
    - pillow: 10.3.0
    - pip: 23.3.1
    - platformdirs: 3.10.0
    - pluggy: 1.0.0
    - polars: 0.20.23
    - pooch: 1.8.1
    - prometheus-client: 0.20.0
    - prompt-toolkit: 3.0.43
    - proto-plus: 1.23.0
    - protobuf: 3.19.6
    - psutil: 5.9.8
    - ptyprocess: 0.7.0
    - pure-eval: 0.2.2
    - py-spy: 0.3.14
    - pyarrow: 15.0.2
    - pyarrow-hotfix: 0.6
    - pyasn1: 0.6.0
    - pyasn1-modules: 0.4.0
    - pycosat: 0.6.6
    - pycparser: 2.21
    - pydantic: 1.10.15
    - pydantic-core: 2.18.2
    - pygments: 2.17.2
    - pyjwt: 2.8.0
    - pyparsing: 3.1.2
    - pysocks: 1.7.1
    - python-daemon: 3.0.1
    - python-dateutil: 2.9.0.post0
    - python-dotenv: 1.0.1
    - python-json-logger: 2.0.7
    - python-multipart: 0.0.9
    - pytorch-lightning: 2.2.3
    - pytz: 2024.1
    - pyyaml: 6.0.1
    - pyzmq: 26.0.2
    - qtconsole: 5.5.1
    - qtpy: 2.4.1
    - querystring-parser: 1.2.4
    - ray: 2.12.0
    - ray-cpp: 2.12.0
    - readchar: 4.0.6
    - referencing: 0.35.0
    - regex: 2024.4.28
    - requests: 2.31.0
    - requests-unixsocket: 0.3.0
    - rfc3339-validator: 0.1.4
    - rfc3986-validator: 0.1.1
    - rich: 13.7.1
    - rpds-py: 0.18.0
    - rsa: 4.9
    - ruamel.yaml: 0.17.21
    - ruamel.yaml.clib: 0.2.6
    - runs: 1.2.2
    - s3transfer: 0.10.1
    - safetensors: 0.4.3
    - scikit-image: 0.23.2
    - scikit-learn: 1.4.2
    - scipy: 1.11.2
    - send2trash: 1.8.3
    - setuptools: 68.2.2
    - shellingham: 1.5.4
    - six: 1.16.0
    - smart-open: 7.0.4
    - smmap: 5.0.1
    - sniffio: 1.3.1
    - soupsieve: 2.5
    - sqlalchemy: 2.0.29
    - sqlparse: 0.5.0
    - stack-data: 0.6.3
    - starlette: 0.22.0
    - starsessions: 1.3.0
    - sympy: 1.12
    - tensorboardx: 2.6.2.2
    - terminado: 0.18.1
    - textual: 0.58.0
    - threadpoolctl: 3.5.0
    - tifffile: 2024.4.24
    - tinycss2: 1.3.0
    - tokenizers: 0.19.1
    - tomli: 2.0.1
    - torch: 1.13.0
    - torch-neuronx: 1.13.1.1.14.0
    - torch-xla: 1.13.1+torchneurone
    - torchmetrics: 1.3.2
    - torchvision: 0.14.0
    - tornado: 6.4
    - tqdm: 4.65.0
    - traitlets: 5.14.3
    - transformers: 4.40.1
    - triton: 2.3.0
    - truststore: 0.8.0
    - typer: 0.12.3
    - types-python-dateutil: 2.9.0.20240316
    - typing-extensions: 4.11.0
    - tzdata: 2024.1
    - uc-micro-py: 1.0.3
    - uri-template: 1.3.0
    - uritemplate: 3.0.1
    - urllib3: 2.1.0
    - uvicorn: 0.29.0
    - uvloop: 0.19.0
    - virtualenv: 20.26.1
    - watchfiles: 0.21.0
    - wcwidth: 0.2.13
    - webcolors: 1.13
    - webencodings: 0.5.1
    - websocket-client: 1.8.0
    - websockets: 11.0.3
    - werkzeug: 3.0.2
    - wheel: 0.41.2
    - widgetsnbextension: 4.0.10
    - wrapt: 1.16.0
    - xmod: 1.8.1
    - xxhash: 3.4.1
    - yarl: 1.9.4
    - zipp: 3.18.1
    - zstandard: 0.19.0
  • System:
    - OS: Linux
    - architecture:
    - 64bit
    - ELF
    - processor: x86_64
    - python: 3.10.14
    - release: 5.10.214-202.855.amzn2.x86_64
    - version: #1 SMP Tue Apr 9 06:57:12 UTC 2024

More info

No response

@BrianF-tessera BrianF-tessera added bug Something isn't working needs triage Waiting to be triaged by maintainers labels Apr 29, 2024