llava 1.5 invalid output after first inference (llamacpp server) #7060

Closed
CaptainOfHacks opened this issue May 3, 2024 · 13 comments · Fixed by #7163

@CaptainOfHacks

I use this server config:

    "host": "0.0.0.0",
    "port": 8085,
    "api_key": "api_key",
    "models": [
        {
            "model": "models/phi3_mini_model/phi3_mini_model.gguf",
            "model_alias": "gpt-3.5-turbo",
            "chat_format": "chatml",
            "n_gpu_layers": 35,
            "offload_kqv": true,
            "n_threads": 12,
            "n_batch": 512,
            "n_ctx": 2048
        },
        {
            "model": "models/phi3_mini_model/phi3_mini_model.gguf",
            "model_alias": "gpt-4",
            "chat_format": "chatml",
            "n_gpu_layers": 35,
            "offload_kqv": true,
            "n_threads": 12,
            "n_batch": 512,
            "n_ctx": 4096
        },
        {
            "model": "models/llava15_vision_model/ggml-model-q4_k.gguf",
            "model_alias": "gpt-4-vision-preview",
            "chat_format": "llava-1-5",
            "clip_model_path": "models/llava15_vision_model/mmproj-model-f16.gguf",
            "n_gpu_layers": 35,
            "offload_kqv": true,
            "n_threads": 12,
            "n_batch": 512,
            "n_ctx": 2048,
            "flash_attn": true
        }
    ]
}

I start the server with this command:

python3 -m llama_cpp.server --config_file server_config.json

Everything works fine in text-only mode. But for llava 1.5, only the first run works; after that, the response for any image is invalid.

I execute the following notebook cells:

from openai import OpenAI

client = OpenAI(base_url="http://localtest.me:8085/v1", api_key="api_key")
import base64
import io
from PIL import Image
import requests

def load_image_and_convert_to_base64(url):
    image = Image.open(requests.get(url, stream=True).raw)
    image = image.resize((336, 336))
    buffered = io.BytesIO()
    image.save(buffered, format="PNG")
    img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
    return img_str


url_1 = "https://www.princeton.edu/sites/default/files/styles/1x_full_2x_half_crop/public/images/2022/02/KOA_Nassau_2697x1517.jpg?itok=Bg2K7j7J"
url_2 = "https://images.pexels.com/photos/106399/pexels-photo-106399.jpeg?auto=compress&cs=tinysrgb&w=1260&h=750&dpr=2"


first_image_b64 = load_image_and_convert_to_base64(url_1)
second_image_b64 = load_image_and_convert_to_base64(url_2)
def generate_caption(image_b64):
    response = client.chat.completions.create(
        model="gpt-4-vision-preview",
        max_tokens=1000,
        stop=["<|end|>"],
        temperature=0.1,
        messages=[
            {
                "role": "system",
                "content": "You are an assistant who perfectly describes images."
            },
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "What's in this image?"},
                    {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{image_b64}"}}
                ]
            }
        ]
    )
    return response.choices[0].message.content
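
For reference, reproducing the issue is just a matter of calling the helper above twice with different images (the same calls are shown with their outputs further down in the thread):

# Reproduction sequence: the first call returns a sensible caption,
# any call after it returns invalid output.
print(generate_caption(first_image_b64))   # first inference: correct caption
print(generate_caption(second_image_b64))  # second inference: invalid output
print(generate_caption(first_image_b64))   # first image again, see the third screenshot below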

The first run works correctly:

(screenshot: CleanShot 2024-05-03 at 17 35 45@2x)

A second run with another image doesn't work:

(screenshot: CleanShot 2024-05-03 at 17 36 06@2x)

And again with the first image:

(screenshot: CleanShot 2024-05-03 at 17 42 01@2x)

Here are the logs for model loading:

clip_model_load: loaded meta data with 18 key-value pairs and 377 tensors from models/llava15_vision_model/mmproj-model-f16.gguf
clip_model_load: Dumping metadata keys/values. Note: KV overrides do not apply in this output.
clip_model_load: - kv   0:                       general.architecture str              = clip
clip_model_load: - kv   1:                      clip.has_text_encoder bool             = false
clip_model_load: - kv   2:                    clip.has_vision_encoder bool             = true
clip_model_load: - kv   3:                   clip.has_llava_projector bool             = true
clip_model_load: - kv   4:                          general.file_type u32              = 1
clip_model_load: - kv   5:                               general.name str              = openai/clip-vit-large-patch14-336
clip_model_load: - kv   6:                        general.description str              = image encoder for LLaVA
clip_model_load: - kv   7:                     clip.vision.image_size u32              = 336
clip_model_load: - kv   8:                     clip.vision.patch_size u32              = 14
clip_model_load: - kv   9:               clip.vision.embedding_length u32              = 1024
clip_model_load: - kv  10:            clip.vision.feed_forward_length u32              = 4096
clip_model_load: - kv  11:                 clip.vision.projection_dim u32              = 768
clip_model_load: - kv  12:           clip.vision.attention.head_count u32              = 16
clip_model_load: - kv  13:   clip.vision.attention.layer_norm_epsilon f32              = 0.000010
clip_model_load: - kv  14:                    clip.vision.block_count u32              = 23
clip_model_load: - kv  15:                     clip.vision.image_mean arr[f32,3]       = [0.481455, 0.457828, 0.408211]
clip_model_load: - kv  16:                      clip.vision.image_std arr[f32,3]       = [0.268630, 0.261303, 0.275777]
clip_model_load: - kv  17:                              clip.use_gelu bool             = false
clip_model_load: - type  f32:  235 tensors
clip_model_load: - type  f16:  142 tensors
ggml_metal_init: allocating
ggml_metal_init: found device: Apple M3 Max
ggml_metal_init: picking default device: Apple M3 Max
ggml_metal_init: using embedded metal library
ggml_metal_init: GPU name:   Apple M3 Max
ggml_metal_init: GPU family: MTLGPUFamilyApple9  (1009)
ggml_metal_init: GPU family: MTLGPUFamilyCommon3 (3003)
ggml_metal_init: GPU family: MTLGPUFamilyMetal3  (5001)
ggml_metal_init: simdgroup reduction support   = true
ggml_metal_init: simdgroup matrix mul. support = true
ggml_metal_init: hasUnifiedMemory              = true
ggml_metal_init: recommendedMaxWorkingSetSize  = 28991.03 MB
clip_model_load: CLIP using Metal backend
clip_model_load: params backend buffer size =  595.49 MB (377 tensors)
key clip.vision.image_grid_pinpoints not found in file
key clip.vision.mm_patch_merge_type not found in file
key clip.vision.image_crop_resolution not found in file
clip_model_load: compute allocated memory: 32.89 MB
llama_model_loader: loaded meta data with 19 key-value pairs and 291 tensors from models/llava15_vision_model/ggml-model-q4_k.gguf (version GGUF V2)
llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output.
llama_model_loader: - kv   0:                       general.architecture str              = llama
llama_model_loader: - kv   1:                               general.name str              = LLaMA v2
llama_model_loader: - kv   2:                       llama.context_length u32              = 4096
llama_model_loader: - kv   3:                     llama.embedding_length u32              = 4096
llama_model_loader: - kv   4:                          llama.block_count u32              = 32
llama_model_loader: - kv   5:                  llama.feed_forward_length u32              = 11008
llama_model_loader: - kv   6:                 llama.rope.dimension_count u32              = 128
llama_model_loader: - kv   7:                 llama.attention.head_count u32              = 32
llama_model_loader: - kv   8:              llama.attention.head_count_kv u32              = 32
llama_model_loader: - kv   9:     llama.attention.layer_norm_rms_epsilon f32              = 0.000010
llama_model_loader: - kv  10:                          general.file_type u32              = 15
llama_model_loader: - kv  11:                       tokenizer.ggml.model str              = llama
llama_model_loader: - kv  12:                      tokenizer.ggml.tokens arr[str,32000]   = ["<unk>", "<s>", "</s>", "<0x00>", "<...
llama_model_loader: - kv  13:                      tokenizer.ggml.scores arr[f32,32000]   = [0.000000, 0.000000, 0.000000, 0.0000...
llama_model_loader: - kv  14:                  tokenizer.ggml.token_type arr[i32,32000]   = [2, 3, 3, 6, 6, 6, 6, 6, 6, 6, 6, 6, ...
llama_model_loader: - kv  15:                tokenizer.ggml.bos_token_id u32              = 1
llama_model_loader: - kv  16:                tokenizer.ggml.eos_token_id u32              = 2
llama_model_loader: - kv  17:            tokenizer.ggml.padding_token_id u32              = 0
llama_model_loader: - kv  18:               general.quantization_version u32              = 2
llama_model_loader: - type  f32:   65 tensors
llama_model_loader: - type q4_K:  193 tensors
llama_model_loader: - type q6_K:   33 tensors
llm_load_vocab: special tokens definition check successful ( 259/32000 ).
llm_load_print_meta: format           = GGUF V2
llm_load_print_meta: arch             = llama
llm_load_print_meta: vocab type       = SPM
llm_load_print_meta: n_vocab          = 32000
llm_load_print_meta: n_merges         = 0
llm_load_print_meta: n_ctx_train      = 4096
llm_load_print_meta: n_embd           = 4096
llm_load_print_meta: n_head           = 32
llm_load_print_meta: n_head_kv        = 32
llm_load_print_meta: n_layer          = 32
llm_load_print_meta: n_rot            = 128
llm_load_print_meta: n_embd_head_k    = 128
llm_load_print_meta: n_embd_head_v    = 128
llm_load_print_meta: n_gqa            = 1
llm_load_print_meta: n_embd_k_gqa     = 4096
llm_load_print_meta: n_embd_v_gqa     = 4096
llm_load_print_meta: f_norm_eps       = 0.0e+00
llm_load_print_meta: f_norm_rms_eps   = 1.0e-05
llm_load_print_meta: f_clamp_kqv      = 0.0e+00
llm_load_print_meta: f_max_alibi_bias = 0.0e+00
llm_load_print_meta: f_logit_scale    = 0.0e+00
llm_load_print_meta: n_ff             = 11008
llm_load_print_meta: n_expert         = 0
llm_load_print_meta: n_expert_used    = 0
llm_load_print_meta: causal attn      = 1
llm_load_print_meta: pooling type     = 0
llm_load_print_meta: rope type        = 0
llm_load_print_meta: rope scaling     = linear
llm_load_print_meta: freq_base_train  = 10000.0
llm_load_print_meta: freq_scale_train = 1
llm_load_print_meta: n_yarn_orig_ctx  = 4096
llm_load_print_meta: rope_finetuned   = unknown
llm_load_print_meta: ssm_d_conv       = 0
llm_load_print_meta: ssm_d_inner      = 0
llm_load_print_meta: ssm_d_state      = 0
llm_load_print_meta: ssm_dt_rank      = 0
llm_load_print_meta: model type       = 7B
llm_load_print_meta: model ftype      = Q4_K - Medium
llm_load_print_meta: model params     = 6.74 B
llm_load_print_meta: model size       = 3.80 GiB (4.84 BPW) 
llm_load_print_meta: general.name     = LLaMA v2
llm_load_print_meta: BOS token        = 1 '<s>'
llm_load_print_meta: EOS token        = 2 '</s>'
llm_load_print_meta: UNK token        = 0 '<unk>'
llm_load_print_meta: PAD token        = 0 '<unk>'
llm_load_print_meta: LF token         = 13 '<0x0A>'
llm_load_tensors: ggml ctx size =    0.30 MiB
ggml_backend_metal_log_allocated_size: allocated buffer, size =  3820.94 MiB, ( 4460.03 / 27648.00)
llm_load_tensors: offloading 32 repeating layers to GPU
llm_load_tensors: offloading non-repeating layers to GPU
llm_load_tensors: offloaded 33/33 layers to GPU
llm_load_tensors:        CPU buffer size =    70.31 MiB
llm_load_tensors:      Metal buffer size =  3820.93 MiB
..................................................................................................
llama_new_context_with_model: n_ctx      = 2048
llama_new_context_with_model: n_batch    = 512
llama_new_context_with_model: n_ubatch   = 512
llama_new_context_with_model: flash_attn = 0
llama_new_context_with_model: freq_base  = 10000.0
llama_new_context_with_model: freq_scale = 1
ggml_metal_init: allocating
ggml_metal_init: found device: Apple M3 Max
ggml_metal_init: picking default device: Apple M3 Max
ggml_metal_init: using embedded metal library
ggml_metal_init: GPU name:   Apple M3 Max
ggml_metal_init: GPU family: MTLGPUFamilyApple9  (1009)
ggml_metal_init: GPU family: MTLGPUFamilyCommon3 (3003)
ggml_metal_init: GPU family: MTLGPUFamilyMetal3  (5001)
ggml_metal_init: simdgroup reduction support   = true
ggml_metal_init: simdgroup matrix mul. support = true
ggml_metal_init: hasUnifiedMemory              = true
ggml_metal_init: recommendedMaxWorkingSetSize  = 28991.03 MB
llama_kv_cache_init:      Metal KV buffer size =  1024.00 MiB
llama_new_context_with_model: KV self size  = 1024.00 MiB, K (f16):  512.00 MiB, V (f16):  512.00 MiB
llama_new_context_with_model:        CPU  output buffer size =     0.14 MiB
llama_new_context_with_model:      Metal compute buffer size =   164.00 MiB
llama_new_context_with_model:        CPU compute buffer size =    12.01 MiB
llama_new_context_with_model: graph nodes  = 1030
llama_new_context_with_model: graph splits = 2
AVX = 0 | AVX_VNNI = 0 | AVX2 = 0 | AVX512 = 0 | AVX512_VBMI = 0 | AVX512_VNNI = 0 | FMA = 0 | NEON = 1 | ARM_FMA = 1 | F16C = 0 | FP16_VA = 1 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 0 | SSSE3 = 0 | VSX = 0 | MATMUL_INT8 = 0 | LLAMAFILE = 1 | 
Model metadata: {'general.quantization_version': '2', 'tokenizer.ggml.padding_token_id': '0', 'tokenizer.ggml.eos_token_id': '2', 'tokenizer.ggml.bos_token_id': '1', 'tokenizer.ggml.model': 'llama', 'llama.attention.head_count_kv': '32', 'llama.context_length': '4096', 'llama.attention.head_count': '32', 'llama.rope.dimension_count': '128', 'general.file_type': '15', 'llama.feed_forward_length': '11008', 'llama.embedding_length': '4096', 'llama.block_count': '32', 'general.architecture': 'llama', 'llama.attention.layer_norm_rms_epsilon': '0.000010', 'general.name': 'LLaMA v2'}
encode_image_with_clip: image embedding created: 576 tokens
@henk717

henk717 commented May 3, 2024

We have been noticing this on Koboldcpp 1.64 as well, so it may not be specific to the server. On our side it seems to work for the first image, but as soon as the image gets swapped the output becomes gibberish.

@LostRuins
Collaborator

I think I found the offending commit that causes the issue: #6899
Reverting it seems to fix the issue.
Pinging @vikhyat

@xBelladonna

I'm also experiencing this after commit #6899. I have run multiple different tests and looked into the features of the outputs, and it seems there is some kind of repeatable ablation going on in the vision encoder that results in psychedelic-like imagery being projected into the language model. Frankly it's fascinating, and it would be interesting to see both what the problem is and what the fix turns out to be. For now I have no conclusions, unfortunately.

jart added a commit to Mozilla-Ocho/llamafile that referenced this issue May 7, 2024
This broke the server's LLaVA support in a non-obvious way.

See ggerganov/llama.cpp#6899
See ggerganov/llama.cpp#7060
@CaptainOfHacks
Author

@ggerganov take a look. Can we revert the moondream changes?

@ggerganov
Owner

Yes, just reverted it: 9da243b

@vikhyat PTAL

@abetlen
Collaborator

abetlen commented May 9, 2024

Narrowed down the bug to this point in the original PR, where the embedding tensor name is registered and it is set as an input: https://github.com/vikhyat/llama.cpp/blob/3d771207b7166286baef8f9d90b960418e163f55/examples/llava/clip.cpp#L575

    struct ggml_tensor * embeddings = inp;
    if (ctx->has_class_embedding) {
        embeddings = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, hidden_size, num_positions, batch_size);
        // llava 1.5 fix
        ggml_set_name(embeddings, "embeddings");
        ggml_set_input(embeddings);
        //
        embeddings = ggml_acc(ctx0, embeddings, model.class_embedding,
                embeddings->nb[1], embeddings->nb[2], embeddings->nb[3], 0);
        embeddings = ggml_acc(ctx0, embeddings, inp,
                embeddings->nb[1], embeddings->nb[2], embeddings->nb[3], model.class_embedding->nb[1]);
    }

Changing the location of the ggml_set_name and ggml_set_input calls so that they are instead called on the newly created tensor solves the issue for llava 1.5; however, I'm not sure what the appropriate fix is for moondream and other clip models without class embeddings.

@CaptainOfHacks
Author

CaptainOfHacks commented May 10, 2024

Same error after @abetlen's fix. It works on the older version without moondream support, and after the moondream revert it works well, but now the same error is back again. Take a look @ggerganov @abetlen.

@CaptainOfHacks
Author

CaptainOfHacks commented May 10, 2024

moondream generates something, but context is shared between different inferences.

Generate for first image:
generate_caption(first_image_b64)
Output:
A golden retriever puppy sits on the gray sidewalk, facing the camera with its tongue out and wearing a purple collar. The background features lush green grass and trees lining the street.
Generate for second image:
generate_caption(second_image_b64)
Output:
" A golden retriever puppy sits on the sidewalk, surrounded by people who appear to be enjoying their time outdoors. The puppy is wearing a purple collar and has a gold tag attached to its ear. In the background, there's a fountain with water flowing down it, adding a touch of tranquility to the scene."

@xBelladonna

abetlen commented here that if the KV cache fails to be cleared before the next inference and the image embedding comes first, the contexts merge and produce odd results. It was also said that this bug is separate in llama-cpp-python. I don't know whether the KV cache must be cleared as an explicit step when performing inference according to your process above. If we could get some detail about how the KV cache is cleared in this instance and replicate that step, we could see whether the results are still messed up or whether the problem is indeed fixed. Or perhaps an easier test would be to place the image embedding last, so that, according to the comment, the bug is not triggered.
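
A minimal sketch of such an explicit-reset test, run directly against llama-cpp-python rather than through the server, could look like this. It assumes the Llava15ChatHandler API and uses Llama.reset() between images; whether reset() clears enough cached state to isolate the KV-cache question is an assumption, not a confirmed behavior. It reuses first_image_b64 and second_image_b64 from the notebook cells above.

from llama_cpp import Llama
from llama_cpp.llama_chat_format import Llava15ChatHandler

# Same model files as in the server config above.
chat_handler = Llava15ChatHandler(clip_model_path="models/llava15_vision_model/mmproj-model-f16.gguf")
llm = Llama(
    model_path="models/llava15_vision_model/ggml-model-q4_k.gguf",
    chat_handler=chat_handler,
    n_ctx=2048,
    n_gpu_layers=35,
    logits_all=True,  # required by the llava chat handler
)

def caption(image_b64):
    out = llm.create_chat_completion(
        max_tokens=1000,
        temperature=0.1,
        messages=[
            {"role": "system", "content": "You are an assistant who perfectly describes images."},
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "What's in this image?"},
                    {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{image_b64}"}},
                ],
            },
        ],
    )
    return out["choices"][0]["message"]["content"]

caption(first_image_b64)
llm.reset()  # explicit reset between images (assumption: this is the relevant clearing step)
caption(second_image_b64)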

@abetlen
Collaborator

abetlen commented May 10, 2024

@CaptainOfHacks can you install via pip install --upgrade --no-cache-dir --force-reinstall git+https://github.com/abetlen/llama-cpp-python.git and check again? I haven't published a new release since the merge yet, so I'm not sure whether that's the reason it didn't appear as fixed.

@CaptainOfHacks
Author

CaptainOfHacks commented May 10, 2024

@abetlen it works now. Can you please add support for the Phi-3 chat format? And can you create a new release so that this fix can be installed from PyPI?

@abetlen
Collaborator

abetlen commented May 10, 2024

@CaptainOfHacks thanks, I'll run the release now. Can you link to the phi3 llava chat format in a new issue on llama-cpp-python and I'll take a look?

@CaptainOfHacks
Author

CaptainOfHacks commented May 10, 2024

llava-phi-3-mini uses the Phi-3-instruct chat template. I think it's similar to the current llava-1-5 format, but with the Phi-3 instruct template instead of the llama 2 one.
abetlen/llama-cpp-python#1443
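
For reference, a sketch of the Phi-3-instruct prompt layout, written as a plain template string. The exact tokens are an assumption and should be checked against the llava-phi-3-mini model card / tokenizer_config, though the <|end|> stop token already used in the client code above matches it:

# Assumed Phi-3-instruct prompt layout (verify against the model card):
PHI3_INSTRUCT_TEMPLATE = (
    "<|system|>\n{system_message}<|end|>\n"
    "<|user|>\n{user_message}<|end|>\n"
    "<|assistant|>\n"
)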
