Streaming Support for Nvidia's Triton Integration #13135

Merged (8 commits) on Apr 30, 2024
28 changes: 28 additions & 0 deletions docs/docs/examples/llm/nvidia_triton.ipynb
@@ -144,6 +144,34 @@
"```\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Call `stream_complete` with a prompt"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"```python\n",
"resp = NvidiaTriton(server_url=triton_url, model_name=model_name, tokens=32).stream_complete(\"The tallest mountain in North America is \")\n",
"for delta in resp:\n",
" print(delta.delta, end=\" \")\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"You should expect the following response as a stream\n",
"```\n",
"the Great Pyramid of Giza, which is about 1,000 feet high. The Great Pyramid of Giza is the tallest mountain in North America.\n",
"```\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
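The notebook cells above only demonstrate `stream_complete`. This PR also wires up `stream_chat` (see the implementation changes below); a minimal, hypothetical sketch of chat-style streaming against the same server, assuming `triton_url` and `model_name` are defined as in the earlier cells, could look like this (not part of the notebook diff):

```python
from llama_index.core.llms import ChatMessage
from llama_index.llms.nvidia_triton import NvidiaTriton

# Assumes triton_url and model_name are set as in the cells above.
llm = NvidiaTriton(server_url=triton_url, model_name=model_name, tokens=32)

messages = [ChatMessage(role="user", content="What is the tallest mountain in North America?")]
for chunk in llm.stream_chat(messages):
    # Each chunk carries the newly streamed text in .delta and the
    # accumulated assistant message so far in .message.content.
    print(chunk.delta, end="", flush=True)
```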
llama-index-llms-nvidia-triton: the NvidiaTriton LLM implementation
@@ -47,6 +47,8 @@
 from llama_index.core.llms.callbacks import llm_chat_callback
 from llama_index.core.base.llms.generic_utils import (
     completion_to_chat_decorator,
+    stream_completion_to_chat_decorator,
 )
 from llama_index.core.llms.llm import LLM
 from llama_index.llms.nvidia_triton.utils import GrpcTritonClient
@@ -235,11 +237,13 @@ def metadata(self) -> LLMMetadata:
     def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
         chat_fn = completion_to_chat_decorator(self.complete)
         return chat_fn(messages, **kwargs)

     @llm_chat_callback()
     def stream_chat(
         self, messages: Sequence[ChatMessage], **kwargs: Any
     ) -> ChatResponseGen:
-        raise NotImplementedError
+        chat_stream_fn = stream_completion_to_chat_decorator(self.stream_complete)
+        return chat_stream_fn(messages, **kwargs)

     def complete(
         self, prompt: str, formatted: bool = False, **kwargs: Any
@@ -266,7 +270,7 @@ def complete(
             if isinstance(token, InferenceServerException):
                 client.stop_stream(model_params["model_name"], request_id)
                 raise token
-            response = response + token
+            response += token

         return CompletionResponse(
             text=response,
@@ -275,7 +279,34 @@ def stream_complete(
     def stream_complete(
         self, prompt: str, formatted: bool = False, **kwargs: Any
     ) -> CompletionResponseGen:
-        raise NotImplementedError
+        from tritonclient.utils import InferenceServerException
+
+        client = self._get_client()
+
+        invocation_params = self._get_model_default_parameters
+        invocation_params.update(kwargs)
+        invocation_params["prompt"] = [[prompt]]
+        model_params = self._identifying_params
+        model_params.update(kwargs)
+        request_id = str(random.randint(1, 9999999))  # nosec
+
+        if self.triton_load_model_call:
+            client.load_model(model_params["model_name"])
+
+        result_queue = client.request_streaming(
+            model_params["model_name"], request_id, **invocation_params
+        )
+
+        def gen() -> CompletionResponseGen:
+            text = ""
+            for token in result_queue:
+                if isinstance(token, InferenceServerException):
+                    client.stop_stream(model_params["model_name"], request_id)
+                    raise token
+                text += token
+                yield CompletionResponse(text=text, delta=token)
+
+        return gen()

     async def achat(
         self, messages: Sequence[ChatMessage], **kwargs: Any
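The new `stream_chat` body relies on `stream_completion_to_chat_decorator` from `llama_index.core.base.llms.generic_utils`, which adapts a prompt-based streaming completion function into a chat-style one. A simplified sketch of what such a helper does (this is not the llama-index source; the real helper uses llama-index's own message-to-prompt conversion) is:

```python
from typing import Any, Callable, Sequence

from llama_index.core.base.llms.types import (
    ChatMessage,
    ChatResponse,
    ChatResponseGen,
    CompletionResponseGen,
    MessageRole,
)


def to_chat_stream(
    stream_complete_fn: Callable[..., CompletionResponseGen],
) -> Callable[..., ChatResponseGen]:
    """Simplified stand-in for stream_completion_to_chat_decorator."""

    def wrapped(messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponseGen:
        # Collapse the chat history into a single prompt string
        # (llama-index uses a messages_to_prompt-style helper here).
        prompt = "\n".join(str(m.content) for m in messages)
        for completion in stream_complete_fn(prompt, **kwargs):
            # Re-wrap each streamed completion chunk as a ChatResponse,
            # keeping the per-chunk delta intact.
            yield ChatResponse(
                message=ChatMessage(role=MessageRole.ASSISTANT, content=completion.text),
                delta=completion.delta,
            )

    return wrapped
```

Because of this adapter, streaming chat support falls out of `stream_complete` almost for free: each completion chunk is re-emitted as a `ChatResponse` whose `delta` is the newly generated text.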
llama-index-llms-nvidia-triton: pyproject.toml
@@ -21,13 +21,13 @@ ignore_missing_imports = true
 python_version = "3.8"

 [tool.poetry]
-authors = ["Your Name <you@example.com>"]
+authors = ["Rohith Ramakrishnan <rrohith2001@gmail.com>"]
 description = "llama-index llms nvidia triton integration"
 exclude = ["**/BUILD"]
 license = "MIT"
 name = "llama-index-llms-nvidia-triton"
 readme = "README.md"
-version = "0.1.4"
+version = "0.1.5"

 [tool.poetry.dependencies]
 python = ">=3.8.1,<4.0"
New BUILD file for the package's tests
@@ -0,0 +1 @@
python_tests()
New test file for the NvidiaTriton integration
@@ -0,0 +1,7 @@
from llama_index.core.base.llms.base import BaseLLM
from llama_index.llms.nvidia_triton import NvidiaTriton


def test_text_inference_embedding_class():
    names_of_base_classes = [b.__name__ for b in NvidiaTriton.__mro__]
    assert BaseLLM.__name__ in names_of_base_classes
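
The added test only checks the class hierarchy. A hypothetical follow-up test (not included in this PR) could exercise the streaming accumulation pattern used by `stream_complete` in isolation, with a plain Python list standing in for the gRPC result queue:

```python
from typing import Sequence

from llama_index.core.base.llms.types import CompletionResponse, CompletionResponseGen


def fake_stream(result_queue: Sequence[str]) -> CompletionResponseGen:
    # Mirrors the gen() helper in stream_complete: accumulate the full text
    # and yield one CompletionResponse per streamed token.
    text = ""
    for token in result_queue:
        text += token
        yield CompletionResponse(text=text, delta=token)


def test_stream_accumulates_text_and_deltas():
    tokens = ["The", " tallest", " mountain"]
    responses = list(fake_stream(tokens))
    assert [r.delta for r in responses] == tokens
    assert responses[-1].text == "The tallest mountain"
```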