mPLUG-Owl2: Revolutionizing Multi-modal Large Language Model with Modality Collaboration

Qinghao Ye*, Haiyang Xu*, Jiabo Ye*, Ming Yan†, Anwen Hu, Haowei Liu, Qi Qian, Ji Zhang, Fei Huang, Jingren Zhou
DAMO Academy, Alibaba Group
*Equal Contribution; † Corresponding Author

mPLUG-Owl2 is a multi-modal large language model (MLLM) proposed by DAMO Academy, and it is the first MLLM to achieve state-of-the-art results on both pure-text and multi-modal benchmarks with remarkable improvements. Compared to models of similar size, mPLUG-Owl2 surpasses the strong baseline LLaVA-1.5 in many aspects. Moreover, even with a smaller vision backbone (ViT-L, 0.3B vs. ViT-G, 1.9B), mPLUG-Owl2 outperforms Qwen-VL by a large margin, especially on the low-level perception task (Q-Bench).

News and Updates

  • 2024.02.01 🔥🔥🔥 We release mPLUG-Owl2.1, a Chinese-enhanced version of mPLUG-Owl2. The weights are available at HuggingFace.
  • 2023.11.08 🔥🔥🔥 We release mPLUG-Owl2 on both ModelScope and HuggingFace. The paper will be released soon with more details about the model, including training details and performance.

Performance

General Vision-Language Benchmark Performance

| Method | #Params | COCO (Caption) | Flickr30K (Zero-shot Caption) | VQAv2 | OKVQA | GQA | VizWizQA | TextVQA (Zero-shot) | SciQA-IMG (Zero-shot) |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| BLIP-2 | 8.2B | - | 74.9 | 65.0 | 45.9 | 41.0 | 19.6 | 42.5 | 61.0 |
| InstructBLIP | 8.2B | 102.2 | 82.4 | - | - | 49.2 | 34.5 | 50.1* | 60.5 |
| Unified-IO-XL | 2.9B | 122.3 | - | 77.9 | 54.0 | - | 57.4** | - | - |
| PaLM-E-12B | 12B | 135.0 | - | 76.2 | 55.5 | - | - | - | - |
| Shikra | 7.2B | 117.5 | 73.9 | 77.4 | 47.2 | - | - | - | - |
| LLaVA-1.5 | 7.2B | - | - | 78.5 | - | 62.0 | 50.0 | 46.1/58.2* | 66.8 |
| Qwen-VL-Chat | 9.6B | 131.9 | 81.0 | 78.2 | 56.6 | 57.5 | 38.9 | 61.5** | 68.2 |
| mPLUG-Owl2 | 8.2B | 137.3 | 85.1 | 79.4 | 57.7 | 56.1 | 54.5 | 54.3/58.2* | 68.7 |
| mPLUG-Owl2.1 | 9.8B | 135.3 | 78.5 | 79.9 | 58.1 | 60.3 | 61.82 | 57.4 | 72.3 |
  • * indicates OCR pipeline input is used.
  • ** denotes the model was trained on the dataset rather than evaluated in a zero-shot setting.
  • For zero-shot image captioning, mPLUG-Owl2 achieves the SOTA on Flickr30K.
  • For general VQA, mPLUG-Owl2 achieves the SOTA under the same generalist LVLM scale setting. Notably, without OCR pipeline input or fine-tuning on TextVQA, mPLUG-Owl2 achieves remarkable performance and surpasses LLaVA-1.5 by 8.2 points.

MLLM Benchmark (Zero-shot)

| Method | Vision Encoder | Language Model | MME | MMBench | MM-Vet | SEED-Bench | Q-Bench |
| --- | --- | --- | --- | --- | --- | --- | --- |
| BLIP-2 | ViT-g (1.3B) | Vicuna (7B) | 1293.84 | - | 22.4 | 46.4 | - |
| MiniGPT-4 | ViT-g (1.3B) | Vicuna (7B) | 581.67 | 23.0 | 22.1 | 42.8 | - |
| LLaVA | ViT-L (0.3B) | Vicuna (7B) | 502.82 | 36.2 | 28.1 | 33.5 | 54.7 |
| mPLUG-Owl | ViT-L (0.3B) | LLaMA (7B) | 967.34 | 46.6 | - | 34.0 | 58.9 |
| InstructBLIP | ViT-g (1.3B) | Vicuna (7B) | 1212.82 | 36.0 | 26.2 | 53.4 | 55.8 |
| LLaMA-Adapter-v2 | ViT-L (0.3B) | LLaMA (7B) | 1328.40 | 39.5 | 31.4 | 32.7 | 58.1 |
| Otter | ViT-L (0.3B) | LLaMA (7B) | 1292.26 | 48.3 | 24.6 | 32.9 | 47.2 |
| Qwen-VL-Chat | ViT-G (1.9B) | Qwen (7B) | 1487.58 | 60.6 | - | 58.2 | 61.6 |
| LLaVA-1.5 | ViT-L (0.3B) | Vicuna (7B) | 1510.70 | 73.7 | 30.5 | 58.6 | 60.7 |
| mPLUG-Owl2 | ViT-L (0.3B) | LLaMA (7B) | 1450.19 | 64.5 | 36.2 | 57.8 | 62.9 |
| mPLUG-Owl2.1 | ViT-G (1.9B) | Qwen (7B) | 1545 | 73.7 | 39.0 | 60.8 | 64.7 |

Text Benchmarks

| Method | MMLU | BBH | AGIEval | ARC-c | ARC-e |
| --- | --- | --- | --- | --- | --- |
| LLaMA-2 | 46.8 | 38.2 | 21.8 | 40.3 | 56.1 |
| WizardLM | 38.1 | 34.7 | 23.2 | 47.5 | 59.6 |
| LLaMA-2-Chat | 46.2 | 35.6 | 28.5 | 54.9 | 71.6 |
| Vicuna-v1.5 | 51.1 | 41.2 | 21.2 | 56.6 | 72.8 |
| mPLUG-Owl2 | 53.4 | 45.0 | 32.7 | 65.8 | 79.9 |

Checkpoints

Huggingface Model Hub

| Model | Phase | Download link |
| --- | --- | --- |
| mPLUG-Owl2 | Pre-training | - |
| mPLUG-Owl2 | Instruction tuning | Download link |
| mPLUG-Owl2.1 | Instruction tuning | Download link |

Modelscope Model Hub

| Model | Phase | Download link |
| --- | --- | --- |
| mPLUG-Owl2 | Pre-training | - |
| mPLUG-Owl2 | Instruction tuning | Download link |

Note: performance may vary slightly due to the conversion of the checkpoint. We train our models with the Megatron framework using model parallelism (MP=2), parallelizing the vision transformer, visual abstractor, and LLM, which is more efficient than using DeepSpeed ZeRO-3.

Usage

Install

  1. Clone this repository and navigate to the mPLUG-Owl2 folder
git clone https://github.com/X-PLUG/mPLUG-Owl.git
cd mPLUG-Owl/mPLUG-Owl2
  2. Install the package
conda create -n mplug_owl2 python=3.10 -y
conda activate mplug_owl2
pip install --upgrade pip
pip install -e .
  3. Install additional packages for training
pip install -e ".[train]"
pip install flash-attn --no-build-isolation
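
To quickly check that the editable install succeeded, you can try importing the model builder used in the Quick Start code below (an optional sanity check, not part of the official instructions):

python -c "from mplug_owl2.model.builder import load_pretrained_model; print('mplug_owl2 import OK')"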

Quick Start Code

import torch
from PIL import Image
from transformers import TextStreamer

from mplug_owl2.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
from mplug_owl2.conversation import conv_templates, SeparatorStyle
from mplug_owl2.model.builder import load_pretrained_model
from mplug_owl2.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria

image_file = '' # Image Path
model_path = 'MAGAer13/mplug-owl2-llama2-7b'
query = "Describe the image."

model_name = get_model_name_from_path(model_path)
tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, None, model_name, load_8bit=False, load_4bit=False, device="cuda")

conv = conv_templates["mplug_owl2"].copy()
roles = conv.roles

image = Image.open(image_file).convert('RGB')
max_edge = max(image.size) # We recommend resizing to a square image for best performance.
image = image.resize((max_edge, max_edge))

image_tensor = process_images([image], image_processor)
image_tensor = image_tensor.to(model.device, dtype=torch.float16)

inp = DEFAULT_IMAGE_TOKEN + query
conv.append_message(conv.roles[0], inp)
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()

input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)
stop_str = conv.sep2
keywords = [stop_str]
stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

temperature = 0.7
max_new_tokens = 512

with torch.inference_mode():
    output_ids = model.generate(
        input_ids,
        images=image_tensor,
        do_sample=True,
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        streamer=streamer,
        use_cache=True,
        stopping_criteria=[stopping_criteria])

outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
print(outputs)

Gradio Web UI Setup Guide

To utilize the Gradio demo locally, follow the instructions below. If you need to compare different checkpoints with multiple model workers, remember that you only need to initiate the controller and the web server once.

How to Launch a Controller

Use the following command to start a controller:

python -m mplug_owl2.serve.controller --host 0.0.0.0 --port 10000

How to Launch a Gradio Web Server

The next step is to launch a gradio web server using the command below:

python -m mplug_owl2.serve.gradio_web_server --controller http://localhost:10000 --model-list-mode reload

This command launches the Gradio web interface. You can access the web interface using the URL displayed on your screen. Note that there will be no models listed initially, as we have not launched any model workers. The list will update automatically when a model worker is launched.

How to Launch a Model Worker

A model worker performs the inference on the GPU. To launch it, use the following command:

python -m mplug_owl2.serve.model_worker --host 0.0.0.0 --controller http://localhost:10000 --port 40000 --worker http://localhost:40000 --model-path MAGAer13/mplug-owl2-llama2-7b

Wait until the model loading process is complete and the message "Uvicorn running on ..." appears. Refresh your Gradio web UI to see the newly launched model in the model list.

You can launch multiple workers to compare different model checkpoints within the same Gradio interface. Keep the --controller identical, but change the --port and --worker to different port numbers for each worker.
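
For instance, a second worker could be launched as follows; the GPU selection via CUDA_VISIBLE_DEVICES and the LoRA checkpoint path are illustrative assumptions, so substitute whichever checkpoint you want to compare:

# Hypothetical second worker on port 40001, serving a different checkpoint on another GPU.
CUDA_VISIBLE_DEVICES=1 python -m mplug_owl2.serve.model_worker --host 0.0.0.0 --controller http://localhost:10000 --port 40001 --worker http://localhost:40001 --model-path MAGAer13/mplug-owl2-llama2-7b-lora-sft --model-base MAGAer13/mplug-owl2-llama2-7b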

If you are using an Apple device with an M1 or M2 chip, you can specify the mps device by using the --device flag: --device mps.

How to Use Quantized Inference (4-Bit & 8-Bit)

To reduce the GPU memory footprint, you can run the inference with quantized bits (4-bit or 8-bit) by simply appending --load-4bit or --load-8bit to the model worker command. Here is an example of running with 4-bit quantization:

python -m mplug_owl2.serve.model_worker --host 0.0.0.0 --controller http://localhost:10000 --port 40000 --worker http://localhost:40000 --model-path MAGAer13/mplug-owl2-llama2-7b --load-4bit

How to Launch a Model Worker with Unmerged LoRA Weights

You can launch the model worker with unmerged LoRA weights to save disk space. Here is an example:

python -m mplug_owl2.serve.model_worker --host 0.0.0.0 --controller http://localhost:10000 --port 40000 --worker http://localhost:40000 --model-path MAGAer13/mplug-owl2-llama2-7b-lora-sft --model-base MAGAer13/mplug-owl2-llama2-7b

What if I want to run the demo locally?

You can use our modified local_serve demo and start it with a single command as follows:

python -m mplug_owl2.local_serve.local_web_server \
    --model-path MAGAer13/mplug-owl2-llama2-7b \
    --port 56789

You can also append --load-4bit or --load-8bit to the command if you would like to launch the demo with 4-bit or 8-bit quantization, as shown below.
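
For example, the same local demo launched with 4-bit quantization would be:

python -m mplug_owl2.local_serve.local_web_server \
    --model-path MAGAer13/mplug-owl2-llama2-7b \
    --port 56789 \
    --load-4bit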

CLI Inference Guide

You can chat about images using mPLUG-Owl2 without the Gradio interface. It also supports multiple GPUs, as well as 4-bit and 8-bit quantized inference. Here is an example command:

python -m mplug_owl2.serve.cli \
    --model-path MAGAer13/mplug-owl2-llama2-7b \
    --image-file "mplug_owl2/serve/examples/extreme_ironing.jpg" \
    --load-4bit

Training

We take fine-tuning on the LLaVA dataset as an example.

Prepare Training Data

Please refer to LLaVA for data preparation. Note that we do not use <image> as the image token, since it would conflict with some code tags; instead, we use <|image|> to avoid such conflicts. In addition, we add the formatting prompts used in LLaVA-1.5 for VQA-style and multiple-choice data, as illustrated below:

question = "What's the weather like today?"
# VQA like
prompt = f"<|image|>{question}\nAnswer the question using a single word or phrase."

# Multiple Choice like
options = "A. OPTION 1\nB. OPTION 2\nC. OPTION 3"
prompt = f"<|image|>{question}\n{options}\nAnswer with the option’s letter from the given choices directly."

Prepare model checkpoint

Download the model checkpoints listed in the Checkpoints section above.

Training scripts

Training script with DeepSpeed ZeRO-3: scripts/finetune.sh.

If you do not have enough GPU memory:

  • Use LoRA: scripts/finetune_lora.sh. Make sure per_device_train_batch_size*gradient_accumulation_steps matches the provided script for best reproducibility (see the sketch after this list).
  • Replace zero3.json with zero3_offload.json which offloads some parameters to CPU RAM. This slows down the training speed.
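
As a concrete illustration of keeping the effective batch size fixed; the numbers below are assumptions for illustration only, not the values in scripts/finetune.sh:

# If the provided script used, say, --per_device_train_batch_size 16 --gradient_accumulation_steps 1,
# a low-memory run keeps the per-GPU product (16) constant:
#   --per_device_train_batch_size 4 --gradient_accumulation_steps 4
# Halving the per-device batch size again means doubling the accumulation steps:
#   --per_device_train_batch_size 2 --gradient_accumulation_steps 8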

New options to note:

  • --freeze_vision_model True: we freeze the vision transformer by default. If you want to train the vision transformer, set this option to False (see the fragment below).
  • --tune_visual_abstractor True: we train the visual abstractor by default. If you want to freeze the abstractor, set this option to False.
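
A minimal sketch of how these two flags could be overridden inside scripts/finetune.sh; every other argument of the training command is elided here:

# Illustrative fragment only; see scripts/finetune.sh for the full command.
    ... \
    --freeze_vision_model False \
    --tune_visual_abstractor True \
    ...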

Evaluation

See Evaluation Instruction Here.

Citation

If you find mPLUG-Owl2 useful for your research and applications, please cite using this BibTeX:

@misc{ye2023mplugowl2,
      title={mPLUG-Owl2: Revolutionizing Multi-modal Large Language Model with Modality Collaboration}, 
      author={Qinghao Ye and Haiyang Xu and Jiabo Ye and Ming Yan and Anwen Hu and Haowei Liu and Qi Qian and Ji Zhang and Fei Huang and Jingren Zhou},
      year={2023},
      eprint={2311.04257},
      archivePrefix={arXiv},
      primaryClass={cs.CL}
}

@misc{ye2023mplugowl,
      title={mPLUG-Owl: Modularization Empowers Large Language Models with Multimodality}, 
      author={Qinghao Ye and Haiyang Xu and Guohai Xu and Jiabo Ye and Ming Yan and Yiyang Zhou and Junyang Wang and Anwen Hu and Pengcheng Shi and Yaya Shi and Chaoya Jiang and Chenliang Li and Yuanhong Xu and Hehong Chen and Junfeng Tian and Qi Qian and Ji Zhang and Fei Huang},
      year={2023},
      eprint={2304.14178},
      archivePrefix={arXiv},
      primaryClass={cs.CL}
}

Acknowledgement

  • LLaVA: the codebase we built upon. Thanks to the authors of LLaVA for providing the framework.

Related Projects

  • LLaMA. An open-source collection of state-of-the-art large pre-trained language models.
  • LLaVA. A visual instruction tuned vision language model which achieves GPT4 level capabilities.
  • mPLUG. A vision-language foundation model for both cross-modal understanding and generation.
  • mPLUG-2. A multimodal model with a modular design, which inspired our project.