BaseCallbackHandler works fine with ChatOpenAI but raises an error with the ChatGoogleGenerativeAI LLM #277

Open
Adil-Ashraf opened this issue Feb 26, 2024 · 0 comments


This is my code for the OpenAI LLM, and it works fine:

from fastapi import FastAPI, Request, APIRouter
import pdb
import cohere
from fastapi.responses import StreamingResponse
from langchain.schema.messages import HumanMessage, AIMessage
from collections.abc import Generator
from queue import Queue, Empty
from langchain.chat_models import ChatAnthropic, ChatOpenAI
from chatbot.app.utils.chain import QueueCallback, create_chain
from chatbot.app.utils.retriever import get_retriever, create_retriever_chain
from threading import Thread
from chatbot.app.settings import COHERE_API_KEY

router = APIRouter()

@router.post("/chat")
async def chat_endpoint(request: Request):
    data = await request.json()
    question = data.get("message")
    chat_history = data.get("history", [])
    converted_chat_history = []
    for message in chat_history:
        if message.get("human") is not None:
            converted_chat_history.append(HumanMessage(content=message["human"]))
        if message.get("ai") is not None:
            converted_chat_history.append(AIMessage(content=message["ai"]))
    data.get("conversation_id")

    cohere_client = cohere.Client(COHERE_API_KEY)

    def stream() -> Generator:
        q = Queue()
        job_done = object()

        llm = ChatOpenAI(
            model="gpt-3.5-turbo-16k",
            streaming=True,
            temperature=0,
            callbacks=[QueueCallback(q)],
        )

        llm_without_callback = ChatOpenAI(
            model="gpt-3.5-turbo-16k",
            streaming=True,
            temperature=0,
        )

        def task():
            retriever_chain = create_retriever_chain(
                chat_history, llm_without_callback, get_retriever()
            )
            chain = create_chain(llm, retriever_chain)
            docs = retriever_chain.invoke(
                {"question": question, "chat_history": chat_history},
            )
            documents_for_rerank = [{"text": doc.page_content} for doc in docs]

            # Perform reranking using Cohere
            results = cohere_client.rerank(
                query=question,
                documents=documents_for_rerank,
                top_n=3,
                model='rerank-english-v2.0'
            )

            # Extract texts from the rerank results with a relevance score greater than 0.70
            high_score_texts = {
                result.document['text'] for result in results
                if result.relevance_score > 0.70
            }

            # Construct the filtered_docs list by checking the page_content against high_score_texts
            filtered_docs = [doc for doc in docs if doc.page_content in high_score_texts]

            url_set = set()
            if filtered_docs:
                for doc in filtered_docs:
                    if doc.metadata["source"] in url_set:
                        continue
                    q.put(doc.metadata["title"] + ":" + doc.metadata["source"] + "\n")
                    url_set.add(doc.metadata["source"])

            q.put("SOURCES:----------------------------")

            chain.invoke(
                {
                    "question": question,
                    "chat_history": converted_chat_history,
                    "context": docs,
                },
            )
            q.put(job_done)

        t = Thread(target=task)
        t.start()

        content = ""

        while True:
            try:
                next_token = q.get(True, timeout=1)
                if next_token is job_done:
                    break
                content += next_token
                yield next_token
            except Empty:
                continue

    return StreamingResponse(stream())
This is the QueueCallback class:
from typing import Any

from langchain.callbacks.base import BaseCallbackHandler

class QueueCallback(BaseCallbackHandler):
    def __init__(self, q):
        self.q = q

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        self.q.put(token)

    def on_llm_end(self, *args, **kwargs: Any) -> None:
        return self.q.empty()

But when I use ChatGoogleGenerativeAI, it raises an error. Below is the ChatGoogleGenerativeAI code:

from fastapi import FastAPI, Request, APIRouter
import pdb
import cohere
from fastapi.responses import StreamingResponse
from langchain.schema.messages import HumanMessage, AIMessage
from collections.abc import Generator
from queue import Queue, Empty
from typing import Any
from langchain.chat_models import ChatAnthropic, ChatOpenAI
from chatbot.app.utils.chain import create_chain
from chatbot.app.utils.retriever import get_retriever, create_retriever_chain
from threading import Thread

from langchain_google_genai import ChatGoogleGenerativeAI
from chatbot.app.settings import COHERE_API_KEY, GOOGLE_API_KEY
from langchain.callbacks.manager import BaseCallbackManager

router = APIRouter()

class QueueCallback(BaseCallbackManager):
    def __init__(self, q):
        self.q = q

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        self.q.put(token)

    def on_llm_end(self, *args, **kwargs: Any) -> None:
        return self.q.empty()

@router.post("/chat")
async def chat_endpoint(request: Request):
    data = await request.json()
    question = data.get("message")
    chat_history = data.get("history", [])
    converted_chat_history = []
    for message in chat_history:
        if message.get("human") is not None:
            converted_chat_history.append(HumanMessage(content=message["human"]))
        if message.get("ai") is not None:
            converted_chat_history.append(AIMessage(content=message["ai"]))
    data.get("conversation_id")

    cohere_client = cohere.Client(COHERE_API_KEY)

    def stream() -> Generator:
        q = Queue()
        job_done = object()
        # callback_manager = BaseCallbackManager([])
        # callback_manager.add_callback(QueueCallback(q))

        llm = ChatGoogleGenerativeAI(
            model="gemini-pro",
            streaming=True,
            temperature=0,
            google_api_key=GOOGLE_API_KEY,
            callbacks=[QueueCallback(q)],
        )

        llm_without_callback = ChatGoogleGenerativeAI(
            model="gemini-pro",
            streaming=True,
            temperature=0,
            google_api_key=GOOGLE_API_KEY,
        )

        def task():
            retriever_chain = create_retriever_chain(
                chat_history, llm_without_callback, get_retriever()
            )
            chain = create_chain(llm, retriever_chain)
            docs = retriever_chain.invoke(
                {"question": question, "chat_history": chat_history},
            )
            documents_for_rerank = [{"text": doc.page_content} for doc in docs]

            # Perform reranking using Cohere
            results = cohere_client.rerank(
                query=question,
                documents=documents_for_rerank,
                top_n=3,
                model='rerank-english-v2.0'
            )

            # Extract texts from the rerank results with a relevance score greater than 0.70
            high_score_texts = {
                result.document['text'] for result in results
                if result.relevance_score > 0.70
            }

            # Construct the filtered_docs list by checking the page_content against high_score_texts
            filtered_docs = [doc for doc in docs if doc.page_content in high_score_texts]

            url_set = set()
            if filtered_docs:
                for doc in filtered_docs:
                    if doc.metadata["source"] in url_set:
                        continue
                    q.put(doc.metadata["title"] + ":" + doc.metadata["source"] + "\n")
                    url_set.add(doc.metadata["source"])

            q.put("SOURCES:----------------------------")

            chain.invoke(
                {
                    "question": question,
                    "chat_history": converted_chat_history,
                    "context": docs,
                },
            )
            q.put(job_done)

        t = Thread(target=task)
        t.start()

        content = ""

        while True:
            try:
                next_token = q.get(True, timeout=1)
                if next_token is job_done:
                    break
                content += next_token
                yield next_token
            except Empty:
                continue

    return StreamingResponse(stream())
Below is the error that I get:
Exception in ASGI application
Traceback (most recent call last):
  File "/home/adil/.cache/pypoetry/virtualenvs/chatbot-L7SgJIPD-py3.10/lib/python3.10/site-packages/uvicorn/protocols/http/h11_impl.py", line 408, in run_asgi
    result = await app(  # type: ignore[func-returns-value]
  File "/home/adil/.cache/pypoetry/virtualenvs/chatbot-L7SgJIPD-py3.10/lib/python3.10/site-packages/uvicorn/middleware/proxy_headers.py", line 84, in __call__
    return await self.app(scope, receive, send)
  File "/home/adil/.cache/pypoetry/virtualenvs/chatbot-L7SgJIPD-py3.10/lib/python3.10/site-packages/fastapi/applications.py", line 292, in __call__
    await super().__call__(scope, receive, send)
  File "/home/adil/.cache/pypoetry/virtualenvs/chatbot-L7SgJIPD-py3.10/lib/python3.10/site-packages/starlette/applications.py", line 122, in __call__
    await self.middleware_stack(scope, receive, send)
  File "/home/adil/.cache/pypoetry/virtualenvs/chatbot-L7SgJIPD-py3.10/lib/python3.10/site-packages/starlette/middleware/errors.py", line 184, in __call__
    raise exc
  File "/home/adil/.cache/pypoetry/virtualenvs/chatbot-L7SgJIPD-py3.10/lib/python3.10/site-packages/starlette/middleware/errors.py", line 162, in __call__
    await self.app(scope, receive, _send)
  File "/home/adil/.cache/pypoetry/virtualenvs/chatbot-L7SgJIPD-py3.10/lib/python3.10/site-packages/starlette/middleware/cors.py", line 91, in __call__
    await self.simple_response(scope, receive, send, request_headers=headers)
  File "/home/adil/.cache/pypoetry/virtualenvs/chatbot-L7SgJIPD-py3.10/lib/python3.10/site-packages/starlette/middleware/cors.py", line 146, in simple_response
    await self.app(scope, receive, send)
  File "/home/adil/.cache/pypoetry/virtualenvs/chatbot-L7SgJIPD-py3.10/lib/python3.10/site-packages/starlette/middleware/exceptions.py", line 79, in __call__
    raise exc
  File "/home/adil/.cache/pypoetry/virtualenvs/chatbot-L7SgJIPD-py3.10/lib/python3.10/site-packages/starlette/middleware/exceptions.py", line 68, in __call__
    await self.app(scope, receive, sender)
  File "/home/adil/.cache/pypoetry/virtualenvs/chatbot-L7SgJIPD-py3.10/lib/python3.10/site-packages/fastapi/middleware/asyncexitstack.py", line 20, in __call__
    raise e
  File "/home/adil/.cache/pypoetry/virtualenvs/chatbot-L7SgJIPD-py3.10/lib/python3.10/site-packages/fastapi/middleware/asyncexitstack.py", line 17, in __call__
    await self.app(scope, receive, send)
  File "/home/adil/.cache/pypoetry/virtualenvs/chatbot-L7SgJIPD-py3.10/lib/python3.10/site-packages/starlette/routing.py", line 718, in __call__
    await route.handle(scope, receive, send)
  File "/home/adil/.cache/pypoetry/virtualenvs/chatbot-L7SgJIPD-py3.10/lib/python3.10/site-packages/starlette/routing.py", line 276, in handle
    await self.app(scope, receive, send)
  File "/home/adil/.cache/pypoetry/virtualenvs/chatbot-L7SgJIPD-py3.10/lib/python3.10/site-packages/starlette/routing.py", line 69, in app
    await response(scope, receive, send)
  File "/home/adil/.cache/pypoetry/virtualenvs/chatbot-L7SgJIPD-py3.10/lib/python3.10/site-packages/starlette/responses.py", line 270, in __call__
    async with anyio.create_task_group() as task_group:
  File "/home/adil/.cache/pypoetry/virtualenvs/chatbot-L7SgJIPD-py3.10/lib/python3.10/site-packages/anyio/_backends/_asyncio.py", line 597, in __aexit__
    raise exceptions[0]
  File "/home/adil/.cache/pypoetry/virtualenvs/chatbot-L7SgJIPD-py3.10/lib/python3.10/site-packages/starlette/responses.py", line 273, in wrap
    await func()
  File "/home/adil/.cache/pypoetry/virtualenvs/chatbot-L7SgJIPD-py3.10/lib/python3.10/site-packages/starlette/responses.py", line 262, in stream_response
    async for chunk in self.body_iterator:
  File "/home/adil/.cache/pypoetry/virtualenvs/chatbot-L7SgJIPD-py3.10/lib/python3.10/site-packages/starlette/concurrency.py", line 63, in iterate_in_threadpool
    yield await anyio.to_thread.run_sync(_next, iterator)
  File "/home/adil/.cache/pypoetry/virtualenvs/chatbot-L7SgJIPD-py3.10/lib/python3.10/site-packages/anyio/to_thread.py", line 33, in run_sync
    return await get_asynclib().run_sync_in_worker_thread(
  File "/home/adil/.cache/pypoetry/virtualenvs/chatbot-L7SgJIPD-py3.10/lib/python3.10/site-packages/anyio/_backends/_asyncio.py", line 877, in run_sync_in_worker_thread
    return await future
  File "/home/adil/.cache/pypoetry/virtualenvs/chatbot-L7SgJIPD-py3.10/lib/python3.10/site-packages/anyio/_backends/_asyncio.py", line 807, in run
    result = context.run(func, *args)
  File "/home/adil/.cache/pypoetry/virtualenvs/chatbot-L7SgJIPD-py3.10/lib/python3.10/site-packages/starlette/concurrency.py", line 53, in _next
    return next(iterator)
  File "/home/adil/Documents/Devbox/chat-langchain/chatbot/app/routers/chat.py", line 53, in stream
    llm = ChatGoogleGenerativeAI(
  File "/home/adil/.cache/pypoetry/virtualenvs/chatbot-L7SgJIPD-py3.10/lib/python3.10/site-packages/langchain_core/load/serializable.py", line 107, in __init__
    super().__init__(**kwargs)
  File "/home/adil/.cache/pypoetry/virtualenvs/chatbot-L7SgJIPD-py3.10/lib/python3.10/site-packages/pydantic/v1/main.py", line 341, in __init__
    raise validation_error
pydantic.v1.error_wrappers.ValidationError: 2 validation errors for ChatGoogleGenerativeAI
callbacks -> 0
  instance of BaseCallbackHandler expected (type=type_error.arbitrary_type; expected_arbitrary_type=BaseCallbackHandler)
callbacks
  instance of BaseCallbackManager expected (type=type_error.arbitrary_type; expected_arbitrary_type=BaseCallbackManager)
I have tried to use
from langchain.callbacks.base import BaseCallbackHandler
but that didn't help. How can I fix it?
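Reading the validation error, it seems each element of the callbacks list must be an instance of BaseCallbackHandler; a BaseCallbackManager is only accepted as the whole callbacks value, not as a list item. So a handler like the following might be what the validator expects; this is just a minimal sketch, assuming the rest of the endpoint stays unchanged (note the class subclasses BaseCallbackHandler, not BaseCallbackManager, and defines __init__ with double underscores):

from queue import Queue
from typing import Any

from langchain.callbacks.base import BaseCallbackHandler


class QueueCallback(BaseCallbackHandler):
    """Pushes streamed tokens onto a queue that the response generator drains."""

    def __init__(self, q: Queue) -> None:
        # Store the shared queue; the stream() generator reads from it.
        self.q = q

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        # Called once per streamed token; forward it to the queue.
        self.q.put(token)

    def on_llm_end(self, *args: Any, **kwargs: Any) -> None:
        return self.q.empty()

With this, callbacks=[QueueCallback(q)] would pass a BaseCallbackHandler instance as the list element, which appears to be what the pydantic check on ChatGoogleGenerativeAI is asking for.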
