contrastors

contrastors is a contrastive learning toolkit that enables researchers and engineers to train and evaluate contrastive models efficiently.

Features

  • Built on top of Flash Attention for fast and efficient training
  • Support for training on multiple GPUs
  • GradCache support for training with large batch sizes in constrained memory environments
  • Huggingface Support for easy loading of common models (Pythia/GPTNeoX, BERT, etc.)
  • Masked Language Modeling (MLM) Pretraining
  • Matryoshka Representation Learning for flexible embedding sizes
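
As a rough illustration of what "flexible embedding sizes" means in practice, Matryoshka-style embeddings are commonly shortened by keeping a leading slice of the vector and rescaling it to unit length. The sketch below is a minimal pure-Python version under that assumption; truncate_embedding is a hypothetical helper, not a function from this repository.

```python
import math


def truncate_embedding(embedding, dim):
    """Keep the first `dim` components and rescale them to unit L2 norm.

    Hypothetical helper for illustration only; not part of contrastors.
    """
    if not 0 < dim <= len(embedding):
        raise ValueError("dim must be between 1 and len(embedding)")
    head = list(embedding[:dim])
    norm = math.sqrt(sum(x * x for x in head))
    if norm == 0.0:
        return head  # nothing to rescale for an all-zero slice
    return [x / norm for x in head]


# Toy 4-dimensional vector; real embeddings have hundreds of dimensions.
full = [3.0, 4.0, 12.0, 0.0]
small = truncate_embedding(full, 2)
print(len(small))  # 2
```

Shorter vectors trade some retrieval quality for lower storage and faster similarity search, so the cut-off dimension is a tuning choice.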

Research

Getting Started and Requirements

The contrastors library relies on custom kernels from the Flash Attention repository. To set up your environment, follow the steps below.

Make sure that you have CUDA 11.8+. You can check this by running nvcc --version or, if you already have torch installed, by running python -c "import torch; print(torch.version.cuda)".
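
If you prefer a scripted check, the comparison against the 11.8 minimum can be sketched as follows. parse_version and cuda_is_supported are hypothetical helper names introduced here; the only torch attribute used is torch.version.cuda, which is None on CPU-only builds.

```python
def parse_version(text):
    """Turn a version string such as '12.1' into a tuple of ints, e.g. (12, 1)."""
    return tuple(int(part) for part in text.strip().split(".")[:2])


def cuda_is_supported(version, minimum="11.8"):
    """Return True if `version` is at least `minimum`.

    `version` may be None, which is what torch.version.cuda reports
    on a CPU-only build of torch.
    """
    if version is None:
        return False
    return parse_version(version) >= parse_version(minimum)


if __name__ == "__main__":
    try:
        import torch  # optional: only needed to read the installed CUDA version
        found = torch.version.cuda
    except ImportError:
        found = None
    print("CUDA version:", found, "| supported:", cuda_is_supported(found))
```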

Create a Python venv and activate it

python3 -m venv env
source env/bin/activate

Install torch. See the torch docs for instructions specific to your system (e.g. as of 12/12/2023, the default torch build targets CUDA 12.1).

pip3 install torch torchvision torchaudio

Install wheel, packaging, and ninja for Flash Attention (so the builds don't take too long)

pip install wheel packaging ninja

Install Flash Attention and the custom kernels

pip install --no-cache-dir flash-attn --no-build-isolation git+https://github.com/HazyResearch/flash-attention.git#subdirectory=csrc/rotary git+https://github.com/HazyResearch/flash-attention.git#subdirectory=csrc/layer_norm git+https://github.com/HazyResearch/flash-attention.git#subdirectory=csrc/fused_dense_lib git+https://github.com/HazyResearch/flash-attention.git#subdirectory=csrc/xentropy

Install the rest of the requirements and the package

pip install -e . 

Data Access

We provide access to the nomic-embed-text-v1 dataset via the nomic package. To access the data, you will need to create an account and log in with the nomic package. First create an account at atlas.nomic.ai, then install the nomic Python client and run the following commands:

pip install nomic
nomic login # follow prompts to login
python -c "from nomic import atlas; print(atlas._get_datastream_credentials(name='contrastors'))"

which will print out your access keys. You can then configure them by using aws configure or setting the AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY environment variables.

If you do not have the AWS CLI installed, you can install it here.

To verify your access, you can run the following commands to list the contents of the buckets:

aws s3 ls --endpoint-url=https://9fa58365a1a3d032127970d0bd9a1290.r2.cloudflarestorage.com/ s3://contrastive
aws s3 ls --endpoint-url=https://9fa58365a1a3d032127970d0bd9a1290.r2.cloudflarestorage.com/ s3://contrastive-index-filtered

You should be able to see the contents of the buckets and download the data.

If you intend to train using our data and the contrastors repo, you will need to set up fsspec support for Cloudflare R2. To do so, create a file ~/.config/fsspec/s3.json with the following contents, replacing the placeholders with your own keys:

{
  "s3": {
    "client_kwargs": {
      "endpoint_url": "https://9fa58365a1a3d032127970d0bd9a1290.r2.cloudflarestorage.com/",
      "aws_access_key_id": "<ACCESS_KEY_ID>",
      "aws_secret_access_key": "<SECRET_KEY_ID>"
    }
  }
}
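
If you would rather generate this file than edit it by hand, the following is a minimal sketch using only the standard library. build_s3_config and write_s3_config are hypothetical helpers, not part of contrastors or fsspec; the endpoint URL and key names are copied from the JSON above.

```python
import json
from pathlib import Path

# Endpoint copied from the configuration shown above.
R2_ENDPOINT = "https://9fa58365a1a3d032127970d0bd9a1290.r2.cloudflarestorage.com/"


def build_s3_config(access_key, secret_key, endpoint_url=R2_ENDPOINT):
    """Return the dictionary that is serialized into s3.json."""
    return {
        "s3": {
            "client_kwargs": {
                "endpoint_url": endpoint_url,
                "aws_access_key_id": access_key,
                "aws_secret_access_key": secret_key,
            }
        }
    }


def write_s3_config(path, access_key, secret_key):
    """Write the config to `path`, creating parent directories if needed."""
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text(json.dumps(build_s3_config(access_key, secret_key), indent=2))
    return path
```

For real use you would call write_s3_config(Path.home() / ".config" / "fsspec" / "s3.json", ...) with the keys printed by the nomic client, for example read from the AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY environment variables. Because the file holds credentials in plain text, consider restricting its permissions.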

Nomic Data Format

Our text data is stored in gzipped JSONL files, alongside which we store a counts.json file and an offsets.json.gzip file.

The counts.json file is a dictionary mapping the file name to the number of examples in the file. The offsets.json.gz file is a dictionary mapping the file name to a dictionary where each key is the index of an example and the value is a tuple of the start and end byte offsets of that example in the file. We do this so that data can be streamed from R2, especially when the data is larger than the buffer size.
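
To make the offset index concrete, here is a small in-memory sketch of how such an index could be built and used to fetch a single example by byte range. It assumes the offsets refer to uncompressed JSONL bytes, which the description above does not specify; the helper names and the shard file name are hypothetical.

```python
import io
import json


def encode_jsonl(records):
    """Serialize records as JSONL bytes, one JSON object per line."""
    return b"".join(json.dumps(r).encode("utf-8") + b"\n" for r in records)


def build_offsets(raw):
    """Map example index -> (start, end) byte offsets within `raw`."""
    offsets = {}
    start = 0
    for index, line in enumerate(io.BytesIO(raw)):  # iterates line by line
        end = start + len(line)
        offsets[index] = (start, end)
        start = end
    return offsets


def read_example(raw, offsets, index):
    """Decode one example using only its byte range."""
    start, end = offsets[index]
    return json.loads(raw[start:end])


records = [
    {"query": "what is r2", "document": "object storage"},
    {"query": "what is jsonl", "document": "one JSON object per line"},
]
raw = encode_jsonl(records)

# Hypothetical file name; the real index is keyed by the actual shard names.
counts = {"shard-00000.jsonl": len(records)}
offsets = {"shard-00000.jsonl": build_offsets(raw)}

print(read_example(raw, offsets["shard-00000.jsonl"], 1)["query"])  # what is jsonl
```

With remote storage, the slice raw[start:end] would correspond to a ranged read of just those bytes, which is what makes random access possible without downloading a whole shard.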

Here's a small example of what a dataset configuration might look like:

datasets:
  - name: "paq"
    bucket: "s3://contrastive-index-filtered/paq_full/shard-{00000..00538}.jsonl.gz"
    query_prefix: "search_query"
    document_prefix: "search_document"
    objective: 
        type: "paired"
        columns: ["query", "document"]

objective specifies whether the dataset uses a paired or a triplet objective. In both cases, the columns field lists the columns to read from each example.
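
The sketch below illustrates two ideas from the configuration above: expanding the shard range in bucket into individual paths, and picking the configured columns out of a record. It assumes the {00000..00538} pattern follows shell-style brace range expansion; expand_shards and select_columns are hypothetical helpers, not functions from this repository.

```python
import re

# Matches a single numeric brace range such as {00000..00538}.
_RANGE = re.compile(r"\{(\d+)\.\.(\d+)\}")


def expand_shards(pattern):
    """Expand one zero-padded numeric brace range into a list of paths."""
    match = _RANGE.search(pattern)
    if match is None:
        return [pattern]  # no range present: a single file
    first, last = match.group(1), match.group(2)
    width = len(first)  # preserve zero padding
    prefix, suffix = pattern[: match.start()], pattern[match.end():]
    return [
        prefix + str(i).zfill(width) + suffix
        for i in range(int(first), int(last) + 1)
    ]


def select_columns(record, objective):
    """Return the values of the configured columns, in order."""
    return tuple(record[name] for name in objective["columns"])


shards = expand_shards(
    "s3://contrastive-index-filtered/paq_full/shard-{00000..00538}.jsonl.gz"
)
print(len(shards))  # 539

paired = {"type": "paired", "columns": ["query", "document"]}
record = {"query": "who wrote it", "document": "the author", "id": 7}
print(select_columns(record, paired))  # ('who wrote it', 'the author')
```

A triplet objective would list a third column in columns (for example a negative document); the exact column names depend on the dataset.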

Training nomic-embed-text-v1

Masked Language Modeling Pretraining

To train your own BERT from scratch (with all the optimizations) run

cd src/contrastors
deepspeed --num_gpus=8 train.py --config=configs/train/mlm.yaml --deepspeed_config=configs/deepspeed/ds_config.json --dtype=bf16

Contrastive Pretraining and Finetuning

To launch an experiment run

cd src/contrastors
torchrun --nproc-per-node=8 train.py --config=configs/train/contrastive_pretrain.yaml --dtype=bf16

This will train a BERT model on all ~200M examples. To change the dataset, you can modify data_args.input_shards.

To finetune nomic-bert-embed-v1-unsupervised, change the --config argument to configs/train/contrastive_finetune.yaml.

Generating Your Own Data

To generate your own data for any step of the pipeline, you can use the provided scripts in scripts/text.

See the README in scripts/text for more information.

Pretrained Models

We provide pretrained models for nomic-embed-text-v1 at the following locations:

Join the Nomic Community

License

This project and its models are licensed under the Apache 2.0 License.

Acknowledgements

We thank Tri Dao for his work on Flash Attention and the custom kernels that make this project possible, the OpenCLIP team for their great repository, on which much of this work is based, and the Huggingface team for their great work on the transformers library.

Citation

If you find the model, dataset, or training code useful, please cite our work

@misc{nussbaum2024nomic,
      title={Nomic Embed: Training a Reproducible Long Context Text Embedder}, 
      author={Zach Nussbaum and John X. Morris and Brandon Duderstadt and Andriy Mulyar},
      year={2024},
      eprint={2402.01613},
      archivePrefix={arXiv},
      primaryClass={cs.CL}
}