
[support Qwen models] AssertionError: Cannot intersect further patterns if '*' has already been handled. #296

Open
justairr opened this issue Dec 13, 2023 · 3 comments

justairr commented Dec 13, 2023

I'm trying to load a Qwen model with LMQL, but I keep encountering the assertion error below. I've also tried Qwen-14B-Chat and Qwen-14B and hit the same error.
My code is as follows:

import lmql

@lmql.query(
    model=lmql.model(
        "Qwen/Qwen-72B-Chat", 
        tokenizer="Qwen/Qwen-72B-Chat",
        trust_remote_code=True
    )
)
def prompt():
    '''lmql
    argmax
        "What is the capital of France? [RESPONSE]"
    where
        len(TOKENS(RESPONSE)) < 20
    '''

if __name__ == '__main__':
    print(prompt())

Error:

Traceback (most recent call last):
  File "/home/name/user1/lmql/lmql_test.py", line 21, in <module>
    print(prompt())
  File "/home/name/miniconda3/envs/lmql/lib/python3.10/site-packages/lmql/api/queries.py", line 148, in lmql_query_wrapper
    return module.query(*args, **kwargs)
  File "/home/name/miniconda3/envs/lmql/lib/python3.10/site-packages/lmql/runtime/lmql_runtime.py", line 204, in __call__
    return call_sync(self, *args, **kwargs)
  File "/home/name/miniconda3/envs/lmql/lib/python3.10/site-packages/lmql/runtime/loop.py", line 37, in call_sync
    res = loop.run_until_complete(task)
  File "/home/name/miniconda3/envs/lmql/lib/python3.10/asyncio/base_events.py", line 649, in run_until_complete
    return future.result()
  File "/home/name/miniconda3/envs/lmql/lib/python3.10/site-packages/lmql/runtime/lmql_runtime.py", line 230, in __acall__
    results = await interpreter.run(self.fct, **query_kwargs)
  File "/home/name/miniconda3/envs/lmql/lib/python3.10/site-packages/lmql/runtime/tracing/tracer.py", line 240, in wrapper
    return await fct(*args, **kwargs)
  File "/home/name/miniconda3/envs/lmql/lib/python3.10/site-packages/lmql/runtime/interpreter.py", line 1070, in run
    async for _ in decoder_fct(prompt, **decoder_args):
  File "/home/name/miniconda3/envs/lmql/lib/python3.10/site-packages/lmql/runtime/dclib/decoders.py", line 21, in argmax
    h = h.extend(await model.argmax(h, noscore=True))
  File "/home/name/miniconda3/envs/lmql/lib/python3.10/site-packages/lmql/runtime/dclib/dclib_cache.py", line 277, in argmax
    return await arr.aelement_wise(op_argmax)
  File "/home/name/miniconda3/envs/lmql/lib/python3.10/site-packages/lmql/runtime/dclib/dclib_array.py", line 318, in aelement_wise
    result_items = await asyncio.gather(*[op_with_path(path, seqs, *args, **kwargs) for path, seqs in self.sequences.items()])
  File "/home/name/miniconda3/envs/lmql/lib/python3.10/site-packages/lmql/runtime/dclib/dclib_array.py", line 317, in op_with_path
    return path, await op(element, *args, **kwargs)
  File "/home/name/miniconda3/envs/lmql/lib/python3.10/site-packages/lmql/runtime/dclib/dclib_cache.py", line 249, in op_argmax
    cache_entries = [await self.get_cache(s, 'top-1', user_data=True, **kwargs) for s in seqs]
  File "/home/name/miniconda3/envs/lmql/lib/python3.10/site-packages/lmql/runtime/dclib/dclib_cache.py", line 249, in <listcomp>
    cache_entries = [await self.get_cache(s, 'top-1', user_data=True, **kwargs) for s in seqs]
  File "/home/name/miniconda3/envs/lmql/lib/python3.10/site-packages/lmql/runtime/dclib/dclib_cache.py", line 196, in get_cache
    keys = await self.get_keys(s, edge_type, **kwargs)
  File "/home/name/miniconda3/envs/lmql/lib/python3.10/site-packages/lmql/runtime/dclib/dclib_cache.py", line 171, in get_keys
    mask = (await self.get_mask(s, **kwargs)).logits_mask[0]
  File "/home/name/miniconda3/envs/lmql/lib/python3.10/site-packages/lmql/runtime/dclib/dclib_cache.py", line 142, in get_mask
    logits_mask_result = await self.delegate.compute_logits_mask(s.input_ids.reshape(1, -1), [s.user_data], constrained_seqs, [s], **kwargs, required=True)
  File "/home/name/miniconda3/envs/lmql/lib/python3.10/site-packages/lmql/runtime/dclib/dclib_model.py", line 87, in compute_logits_mask
    mask = await processor(seqs, additional_logits_processor_mask=is_constrained, **kwargs)
  File "/home/name/miniconda3/envs/lmql/lib/python3.10/site-packages/lmql/runtime/interpreter.py", line 671, in where_processor
    results = [(mask, user_data, max_tokens_hint) for mask, user_data, max_tokens_hint in await asyncio.gather(*token_mask_tasks)]        
  File "/home/name/miniconda3/envs/lmql/lib/python3.10/site-packages/lmql/runtime/interpreter.py", line 487, in where_for_sequence
    mask, logit_mask, state, max_tokens_hint = await self.where_step_for_sequence(s, needs_masking, seqidx, return_follow_map=return_follow_map, **kwargs)
  File "/home/name/miniconda3/envs/lmql/lib/python3.10/site-packages/lmql/runtime/interpreter.py", line 561, in where_step_for_sequence
    valid, is_final, trace, follow_trace = ops.digest(where,
  File "/home/name/miniconda3/envs/lmql/lib/python3.10/site-packages/lmql/ops/node.py", line 236, in digest
    op_follow_map = follow_apply(intm, op, value, context=follow_context)
  File "/home/name/miniconda3/envs/lmql/lib/python3.10/site-packages/lmql/ops/follow_map.py", line 226, in follow_apply
    result_map = result_map.intersect(pattern)
  File "/home/name/miniconda3/envs/lmql/lib/python3.10/site-packages/lmql/ops/follow_map.py", line 71, in intersect
    assert handled != "*", "Cannot intersect further patterns if '*' has already been handled."
AssertionError: Cannot intersect further patterns if '*' has already been handled.
lbeurerkellner (Collaborator) commented Dec 17, 2023

Hi there, we have not tested LMQL with Qwen models yet, so this may be an issue with supporting its tokenizer. I will have to investigate a bit further.

@lbeurerkellner lbeurerkellner added the enhancement New feature or request label Dec 17, 2023
@lbeurerkellner lbeurerkellner changed the title AssertionError: Cannot intersect further patterns if '*' has already been handled. [support Qwen models] AssertionError: Cannot intersect further patterns if '*' has already been handled. Dec 17, 2023
lbeurerkellner (Collaborator) commented Dec 17, 2023

You can fix the assertion by adding `if p2 == "*": break` after this line:

value_follow.add_all(result_map)

However, I could not get the model to do inference on my machine, since it never seems to finish a forward pass. Maybe you can try running with the change above, and report back with further results?
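
For concreteness, a rough sketch of the patched loop, assuming follow_apply iterates over the follow-map patterns as p2 (the name `patterns` and the loop structure are assumptions based on the traceback, not verbatim repository code):

# hypothetical sketch of the loop in follow_apply (lmql/ops/follow_map.py);
# p2, result_map, and value_follow are taken from the comment above
for p2 in patterns:
    result_map = result_map.intersect(p2)
    value_follow.add_all(result_map)
    if p2 == "*":
        # once the wildcard '*' has been handled, any further intersect()
        # call would trip the assertion, so stop iterating here
        break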

justairr (Author) commented

> You can fix the assertion by adding `if p2 == "*": break` after this line: `value_follow.add_all(result_map)`. However, I could not get the model to do inference on my machine [...]

Thank you for your reply! I still couldn't fix the error by adding the code at the location you mentioned, but adding the same guard directly in front of the assertion that throws the error does let the program continue running.
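A minimal sketch of that workaround, assuming the assertion sits inside a pattern loop in intersect() (lmql/ops/follow_map.py, near line 71 per the traceback; the surrounding control flow is an assumption, not verbatim repository code):

# hypothetical guard placed directly in front of the failing assertion
if handled == "*":
    break  # stop processing further patterns instead of asserting
assert handled != "*", "Cannot intersect further patterns if '*' has already been handled."

With that change, my new error message is as follows: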

[Loading Qwen/Qwen-72B-Chat with AutoModelForCausalLM.from_pretrained("Qwen/Qwen-72B-Chat", trust_remote_code=True)]
The model is automatically converting to bf16 for faster inference. If you want to disable the automatic precision, please manually add bf16/fp16/fp32=True to "AutoModelForCausalLM.from_pretrained".
Try importing flash-attention for faster inference...
Warning: import flash_attn rotary fail, please install FlashAttention rotary to get higher efficiency https://github.com/Dao-AILab/flash-attention/tree/main/csrc/rotary
Warning: import flash_attn rms_norm fail, please install FlashAttention layer_norm to get higher efficiency https://github.com/Dao-AILab/flash-attention/tree/main/csrc/layer_norm
Warning: import flash_attn fail, please install FlashAttention to get higher efficiency https://github.com/Dao-AILab/flash-attention
Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 82/82 [00:11<00:00,  7.04it/s]
[Qwen/Qwen-72B-Chat ready on device cpu]
/home/name/miniconda3/envs/lmql/lib/python3.10/site-packages/transformers/generation/configuration_utils.py:394: UserWarning: `do_sample` is set to `False`. However, `top_p` is set to `0.8` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `top_p`.
  warnings.warn(
/home/name/miniconda3/envs/lmql/lib/python3.10/site-packages/transformers/generation/configuration_utils.py:404: UserWarning: `do_sample` is set to `False`. However, `top_k` is set to `0` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `top_k`.
  warnings.warn(
[Error during generate()] expected scalar type c10::BFloat16 but found double
Traceback (most recent call last):
  File "/home/name/user1/lmql/lmql_test.py", line 19, in <module>
    print(prompt())
  File "/home/name/miniconda3/envs/lmql/lib/python3.10/site-packages/lmql/api/queries.py", line 148, in lmql_query_wrapper
    return module.query(*args, **kwargs)
  File "/home/name/miniconda3/envs/lmql/lib/python3.10/site-packages/lmql/runtime/lmql_runtime.py", line 204, in __call__
    return call_sync(self, *args, **kwargs)
  File "/home/name/miniconda3/envs/lmql/lib/python3.10/site-packages/lmql/runtime/loop.py", line 37, in call_sync
    res = loop.run_until_complete(task)
  File "/home/name/miniconda3/envs/lmql/lib/python3.10/asyncio/base_events.py", line 649, in run_until_complete
    return future.result()
  File "/home/name/miniconda3/envs/lmql/lib/python3.10/site-packages/lmql/runtime/lmql_runtime.py", line 230, in __acall__
    results = await interpreter.run(self.fct, **query_kwargs)
  File "/home/name/miniconda3/envs/lmql/lib/python3.10/site-packages/lmql/runtime/tracing/tracer.py", line 240, in wrapper
    return await fct(*args, **kwargs)
  File "/home/name/miniconda3/envs/lmql/lib/python3.10/site-packages/lmql/runtime/interpreter.py", line 1070, in run
    async for _ in decoder_fct(prompt, **decoder_args):
  File "/home/name/miniconda3/envs/lmql/lib/python3.10/site-packages/lmql/runtime/dclib/decoders.py", line 21, in argmax
    h = h.extend(await model.argmax(h, noscore=True))
  File "/home/name/miniconda3/envs/lmql/lib/python3.10/site-packages/lmql/runtime/dclib/dclib_cache.py", line 277, in argmax
    return await arr.aelement_wise(op_argmax)
  File "/home/name/miniconda3/envs/lmql/lib/python3.10/site-packages/lmql/runtime/dclib/dclib_array.py", line 318, in aelement_wise
    result_items = await asyncio.gather(*[op_with_path(path, seqs, *args, **kwargs) for path, seqs in self.sequences.items()])
  File "/home/name/miniconda3/envs/lmql/lib/python3.10/site-packages/lmql/runtime/dclib/dclib_array.py", line 317, in op_with_path
    return path, await op(element, *args, **kwargs)
  File "/home/name/miniconda3/envs/lmql/lib/python3.10/site-packages/lmql/runtime/dclib/dclib_cache.py", line 256, in op_argmax
    non_cached_argmax = iter((await self.delegate.argmax(DataArray(non_cached), **kwargs)).items())                
  File "/home/name/miniconda3/envs/lmql/lib/python3.10/site-packages/lmql/models/lmtp/lmtp_dcmodel.py", line 307, in argmax
    return await self.sample(sequences, temperature=0.0, **kwargs)
  File "/home/name/miniconda3/envs/lmql/lib/python3.10/site-packages/lmql/models/lmtp/lmtp_dcmodel.py", line 350, in sample
    return await sequences.aelement_wise(op_sample)
  File "/home/name/miniconda3/envs/lmql/lib/python3.10/site-packages/lmql/runtime/dclib/dclib_array.py", line 318, in aelement_wise
    result_items = await asyncio.gather(*[op_with_path(path, seqs, *args, **kwargs) for path, seqs in self.sequences.items()])
  File "/home/name/miniconda3/envs/lmql/lib/python3.10/site-packages/lmql/runtime/dclib/dclib_array.py", line 317, in op_with_path
    return path, await op(element, *args, **kwargs)
  File "/home/name/miniconda3/envs/lmql/lib/python3.10/site-packages/lmql/models/lmtp/lmtp_dcmodel.py", line 340, in op_sample
    tokens = await asyncio.gather(*[self.stream_and_return_first(s, await self.generate(s, temperature=temperature, **kwargs), mode) for s,mode in zip(seqs, unique_sampling_mode)])
  File "/home/name/miniconda3/envs/lmql/lib/python3.10/site-packages/lmql/models/lmtp/lmtp_dcmodel.py", line 147, in stream_and_return_first
    buffer += [await anext(iterator)]
  File "/home/name/miniconda3/envs/lmql/lib/python3.10/site-packages/lmql/models/lmtp/lmtp_multiprocessing.py", line 188, in generate
    async for token in self.stream_iterator(self.stream_id):
  File "/home/name/miniconda3/envs/lmql/lib/python3.10/site-packages/lmql/models/lmtp/lmtp_multiprocessing.py", line 217, in stream_iterator
    raise LMTPStreamError(item["error"])
lmql.models.lmtp.errors.LMTPStreamError: failed to generate tokens 'expected scalar type c10::BFloat16 but found double'
Task was destroyed but it is pending!
task: <Task cancelling name='lmtp_inprocess_client_loop' coro=<LMTPDcModel.inprocess_client_loop() running at /home/name/miniconda3/envs/lmql/lib/python3.10/site-packages/lmql/models/lmtp/lmtp_dcmodel.py:76> wait_for=<Future finished result=True>>
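
Note: the final dtype error ("expected scalar type c10::BFloat16 but found double") may stem from the automatic bf16 conversion reported during loading. Per the Qwen loader warning above, an explicit precision flag can be passed to from_pretrained; a sketch (untested, and bf16/fp16/fp32 are Qwen-specific kwargs consumed by its remote code, not standard transformers arguments):

from transformers import AutoModelForCausalLM

# pin the precision explicitly to disable Qwen's automatic bf16 conversion,
# as suggested by the loading warning above (untested)
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen-72B-Chat",
    trust_remote_code=True,
    fp32=True,  # or bf16=True / fp16=True
)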
