
Official implementation of "Align and Distill: Unifying and Improving Domain Adaptive Object Detection"


justinkay/aldi


Align and Distill: Unifying and Improving Domain Adaptive Object Detection

Documentation

Align and Distill (ALDI) is a state-of-the-art framework for domain adaptive object detection.

ALDI is built on top of the Detectron2 object detection library and follows its design patterns where possible: training settings are managed by config files, datasets are managed by a dataset registry, training is handled by a custom Trainer class that extends Detectron2's DefaultTrainer, and the training script in tools/train_net.py offers all the same functionality as the Detectron2 script of the same name.
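
The dataset registry lets config files refer to datasets by name only. As a rough, library-free illustration of that pattern (the class and dataset names below are hypothetical; this is not Detectron2's DatasetCatalog API), a registry can map each name to a loader function that runs only when the dataset is requested:

```python
from typing import Callable, Dict, List

Record = dict  # one image's metadata, e.g. {"file_name": ..., "annotations": [...]}


class ToyDatasetRegistry:
    """Hypothetical name -> loader registry that illustrates the pattern only."""

    def __init__(self) -> None:
        self._loaders: Dict[str, Callable[[], List[Record]]] = {}

    def register(self, name: str, loader: Callable[[], List[Record]]) -> None:
        # Refuse duplicate names so two datasets cannot silently collide.
        if name in self._loaders:
            raise KeyError(f"dataset {name!r} is already registered")
        self._loaders[name] = loader

    def get(self, name: str) -> List[Record]:
        # Loaders run lazily, only when a dataset is actually requested.
        return self._loaders[name]()

    def names(self) -> List[str]:
        return sorted(self._loaders)


registry = ToyDatasetRegistry()
registry.register("my_source_train", lambda: [{"file_name": "a.jpg", "annotations": []}])
registry.register("my_target_unlabeled", lambda: [{"file_name": "b.jpg"}])

# A config then only needs names, e.g. DATASETS.TRAIN = ("my_source_train",).
train_records = [r for name in ("my_source_train",) for r in registry.get(name)]
```

The design choice is that registration is cheap and decoupled from loading, so many datasets can be registered up front while only the ones named in a config are ever read.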

Install

1. Install PyTorch and torchvision. Follow the official installation guide to install versions that match your CUDA version.

2. Install ALDI. Clone this repository and run:

pip install -e .

Optionally, add the --no-cache-dir flag if you run into out-of-memory (OOM) issues during installation.
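
To confirm the installation, a quick check along these lines reports which packages are importable without actually importing them. The package list is an assumption: torch and torchvision from step 1, detectron2 as the library ALDI builds on, and aldi as the assumed import name of this repository's package.

```python
import importlib.util
from typing import Dict, Iterable

# Assumed package names; adjust if your environment differs.
REQUIRED = ("torch", "torchvision", "detectron2", "aldi")


def check_packages(names: Iterable[str] = REQUIRED) -> Dict[str, bool]:
    """Return {package_name: is_importable} using find_spec (no import side effects)."""
    return {name: importlib.util.find_spec(name) is not None for name in names}


if __name__ == "__main__":
    for name, ok in check_packages().items():
        print(f"{name:12s} {'OK' if ok else 'MISSING'}")
```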

Data setup

There are three kinds of "datasets" in domain adaptive object detection:

Train (source): Labeled source-domain images. Used for source-only baseline training, supervised burn-in, and domain-adaptive training.

Unlabeled (target): Target-domain images that are optionally labeled. If unlabeled, they are used for domain-adaptive training only. If labeled, they can also be used to train "oracle" methods (see paper). Labels, if provided, are ignored during domain-adaptive training.

Test (source or target): A labeled source- or target-domain validation set. In most DAOD papers this set comes from the target domain, even though this breaks the constraints of unsupervised domain adaptation.

Custom data

The easiest way to use your own data is to create COCO-formatted JSON files and register them with Detectron2. Register each dataset separately:

# add this to the top of tools/train_net.py or aldi/datasets.py
from detectron2.data.datasets import register_coco_instances
register_coco_instances("your_train_dataset_name", {}, "path/to/your_train_coco_labels.json", "path/to/your/train/images/")
register_coco_instances("your_unlabeled_dataset_name", {}, "path/to/your_unlabeled_coco_labels.json", "path/to/your/unlabeled/images/")
register_coco_instances("your_test_dataset_name", {}, "path/to/your_test_coco_labels.json", "path/to/your/test/images/")
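
If your labels are not already in COCO format, the following stdlib-only sketch builds the kind of JSON file the register_coco_instances calls above read. The helper names and tuple layouts are hypothetical; boxes follow COCO's [x_min, y_min, width, height] convention in absolute pixels, and file_name is relative to the image directory you pass as the last argument. For an unlabeled target set, passing no boxes yields an images-only file; whether ALDI accepts such a file unchanged is an assumption worth checking against the training instructions.

```python
import json
from typing import Dict, Optional, Sequence, Tuple

# (file_name, width, height); file_name is relative to the registered image root.
Image = Tuple[str, int, int]
# (image_index, class_name, x_min, y_min, width, height) in absolute pixels.
Box = Tuple[int, str, float, float, float, float]


def make_coco_dict(images: Sequence[Image], class_names: Sequence[str],
                   boxes: Optional[Sequence[Box]] = None) -> Dict:
    """Build a minimal COCO-style detection dict (images, annotations, categories)."""
    # Category and image ids start at 1, a common COCO convention.
    cat_ids = {name: i + 1 for i, name in enumerate(class_names)}
    coco = {
        "images": [
            {"id": i + 1, "file_name": f, "width": w, "height": h}
            for i, (f, w, h) in enumerate(images)
        ],
        "annotations": [],
        "categories": [{"id": cid, "name": name} for name, cid in cat_ids.items()],
    }
    for ann_id, (img_idx, name, x, y, bw, bh) in enumerate(boxes or [], start=1):
        coco["annotations"].append({
            "id": ann_id,
            "image_id": img_idx + 1,
            "category_id": cat_ids[name],
            "bbox": [x, y, bw, bh],
            "area": bw * bh,
            "iscrowd": 0,
        })
    return coco


def write_coco_json(coco: Dict, path: str) -> None:
    with open(path, "w") as f:
        json.dump(coco, f)
```

For example, make_coco_dict([("img_0001.jpg", 1024, 768)], ["car", "person"], boxes=[(0, "car", 10, 20, 100, 50)]) describes one image with one car box; keeping the same class_names order for every split keeps category ids consistent across datasets.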

Note that by default Detectron2 assumes all paths are relative to ./datasets in your current working directory. You can change this location with the DETECTRON2_DATASETS environment variable, e.g. export DETECTRON2_DATASETS=/path/to/datasets.
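
If you want your own registration code to follow the same convention, a small sketch of that lookup is below: use DETECTRON2_DATASETS if it is set, otherwise fall back to ./datasets. This mirrors the convention used by Detectron2's builtin dataset registration; the helper names are hypothetical.

```python
import os


def dataset_root() -> str:
    """$DETECTRON2_DATASETS if set, otherwise the relative directory "datasets"."""
    return os.path.expanduser(os.getenv("DETECTRON2_DATASETS", "datasets"))


def resolve(relative_path: str) -> str:
    # Join a dataset-relative path (e.g. a COCO JSON file) onto the root.
    return os.path.join(dataset_root(), relative_path)
```

Passing resolve("your_train_coco_labels.json") instead of a hard-coded path lets the same registration code work on machines that keep data in different places.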


Set up DAOD benchmarks (Cityscapes, Sim10k, CFC)

Follow these instructions to set up data and reproduce benchmark results on the datasets in our paper: Cityscapes → Foggy Cityscapes, Sim10k → Cityscapes, and CFC Kenai → Channel.

Training

See our detailed training instructions. The TL;DR is:

Config setup

Training is managed through config files. We provide example configs for burn-in/baseline models, oracle models, and ALDI++.

You will need to modify (at least) the following values for any custom data:

DATASETS:
  TRAIN: ("your_train_dataset_name",) # must be a tuple; can list multiple datasets
  UNLABELED: ("your_unlabeled_dataset_name",) # must be a tuple; can list multiple datasets
  TEST: ("your_test_dataset_name",) # must be a tuple; can list multiple datasets
MODEL:
  ROI_HEADS:
    NUM_CLASSES: 9 # change to match your number of classes
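
Two values above are easy to get wrong. First, the trailing comma matters: yacs, which Detectron2's config system builds on, decodes string values such as ("name",) with ast.literal_eval, and without the comma the result is a plain string rather than a tuple. Second, NUM_CLASSES should equal the number of categories in your COCO files, which (assuming, as in the benchmarks above, that all splits share one label set) should be identical across the train, unlabeled, and test JSON files. A hedged stdlib-only sketch of both checks, with hypothetical helper names:

```python
import json
from ast import literal_eval
from typing import Dict, List, Tuple


def parse_dataset_names(value: str) -> Tuple[str, ...]:
    """Decode a DATASETS.* value with literal_eval and require a tuple of strings."""
    decoded = literal_eval(value)
    if not (isinstance(decoded, tuple) and all(isinstance(n, str) for n in decoded)):
        raise TypeError(f"expected a tuple of dataset names, got {decoded!r} "
                        "(missing trailing comma?)")
    return decoded


def class_names(coco: Dict) -> List[str]:
    """Category names from a COCO dict, ordered by category id."""
    return [c["name"] for c in sorted(coco["categories"], key=lambda c: c["id"])]


def num_classes(cocos: Dict[str, Dict]) -> int:
    """Check every split shares one category list and return its length."""
    if not cocos:
        raise ValueError("no datasets given")
    lists = {name: class_names(c) for name, c in cocos.items()}
    ref_name, ref = next(iter(lists.items()))
    for name, names in lists.items():
        if names != ref:
            raise ValueError(f"{name} classes {names} differ from {ref_name} classes {ref}")
    return len(ref)


def load_coco(path: str) -> Dict:
    with open(path) as f:
        return json.load(f)
```

Usage: num_classes({"train": load_coco("path/to/your_train_coco_labels.json"), "test": load_coco("path/to/your_test_coco_labels.json")}) returns the count to put in MODEL.ROI_HEADS.NUM_CLASSES, or raises if the splits disagree.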

Run training

ALDI involves two training phases: (1) burn-in and (2) domain adaptation. Again, please refer to the detailed training instructions. Training involves running tools/train_net.py for each training phase:

python tools/train_net.py --config-file path/to/your/config.yaml

The script is compatible with all Detectron2 training options (--num-gpus, in-line config modifications, etc.).
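
The two phases can be scripted. The sketch below assembles tools/train_net.py command lines from the flags used in this README (--config-file, --num-gpus, --eval-only) plus trailing KEY VALUE config overrides, and runs them in order. The config paths are placeholders, and how the adaptation config picks up the burn-in checkpoint is covered by the detailed training instructions, not by this sketch.

```python
import shlex
import subprocess
from typing import Dict, List, Optional, Sequence


def train_net_command(config_file: str, num_gpus: int = 1, eval_only: bool = False,
                      overrides: Optional[Dict[str, str]] = None) -> List[str]:
    """Assemble a tools/train_net.py invocation as an argv list."""
    cmd = ["python", "tools/train_net.py", "--config-file", config_file]
    if num_gpus != 1:
        cmd += ["--num-gpus", str(num_gpus)]
    if eval_only:
        cmd.append("--eval-only")
    # In-line config modifications go last as KEY VALUE pairs.
    for key, value in (overrides or {}).items():
        cmd += [key, str(value)]
    return cmd


def run_in_order(commands: Sequence[List[str]], dry_run: bool = True) -> None:
    """Print each command and, unless dry_run, run it, stopping at the first failure."""
    for cmd in commands:
        print(shlex.join(cmd))
        if not dry_run:
            # check=True raises CalledProcessError, so adaptation never
            # starts if burn-in fails.
            subprocess.run(cmd, check=True)


if __name__ == "__main__":
    # Placeholder config paths for the burn-in and domain adaptation phases.
    run_in_order([
        train_net_command("path/to/your/burn_in_config.yaml", num_gpus=4),
        train_net_command("path/to/your/aldi_config.yaml", num_gpus=4),
    ])
```

The same helper covers evaluation: train_net_command("path/to/your/aldi_config.yaml", eval_only=True, overrides={"MODEL.WEIGHTS": "path/to/your/model_best.pth"}) produces the evaluation command shown in the next section.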

Evaluation

After training, to run evaluation with your model:

python tools/train_net.py --eval-only --config-file path/to/your/aldi_config.yaml MODEL.WEIGHTS path/to/your/model_best.pth

We welcome any PRs to add DefaultPredictor inference functionality!

Model zoo

We provide burn-in checkpoints and final models for DAOD benchmarks (Cityscapes → Foggy Cityscapes, Sim10k → Cityscapes, and CFC Kenai → Channel) in the model zoo.

You can download the required model weights for any config file we provide using:

python tools/download_model_for_config.py --config-file path/to/config.yaml

Extras

The main branch contains everything you need to run ALDI and is a good starting point for most users.

Additional code and configuration files to reproduce all experiments in our paper can be found on the extras branch.

Reference

Justin Kay, Timm Haucke, Suzanne Stathatos, Siqi Deng, Erik Young, Pietro Perona, Sara Beery, and Grant Van Horn.

Object detectors often perform poorly on data that differs from their training set. Domain adaptive object detection (DAOD) methods have recently demonstrated strong results on addressing this challenge. Unfortunately, we identify systemic benchmarking pitfalls that call past results into question and hamper further progress: (a) Overestimation of performance due to underpowered baselines, (b) Inconsistent implementation practices preventing transparent comparisons of methods, and (c) Lack of generality due to outdated backbones and lack of diversity in benchmarks. We address these problems by introducing: (1) A unified benchmarking and implementation framework, Align and Distill (ALDI), enabling comparison of DAOD methods and supporting future development, (2) A fair and modern training and evaluation protocol for DAOD that addresses benchmarking pitfalls, (3) A new DAOD benchmark dataset, CFC-DAOD, enabling evaluation on diverse real-world data, and (4) A new method, ALDI++, that achieves state-of-the-art results by a large margin. ALDI++ outperforms the previous state-of-the-art by +3.5 AP50 on Cityscapes → Foggy Cityscapes, +5.7 AP50 on Sim10k → Cityscapes (where ours is the only method to outperform a fair baseline), and +2.0 AP50 on CFC Kenai → Channel. Our framework, dataset, and state-of-the-art method offer a critical reset for DAOD and provide a strong foundation for future research.

If you find our work useful in your research please consider citing our paper:

@misc{kay2024align,
      title={Align and Distill: Unifying and Improving Domain Adaptive Object Detection}, 
      author={Justin Kay and Timm Haucke and Suzanne Stathatos and Siqi Deng and Erik Young and Pietro Perona and Sara Beery and Grant Van Horn},
      year={2024},
      eprint={2403.12029},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}