Default rag-conversation package does not generate follow-up answers in "chat" playground_type. #592

jo3p commented Apr 3, 2024

Hi all,

I've taken the default rag-conversation example (see source code here) and modified the retriever slightly to use Azure AI Search. The vector store contains a synthetic dataset describing disturbances in a production factory.

The code is shown below.

server.py:

from fastapi import FastAPI
from fastapi.responses import RedirectResponse
from langserve import add_routes
from rag_conversation import chain as rag_conversation_chain

app = FastAPI()


@app.get("/")
async def redirect_root_to_docs():
    return RedirectResponse("/docs")


add_routes(app, rag_conversation_chain, path="/rag-conversation", playground_type="chat")


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)

chain.py:

import os
from operator import itemgetter
from typing import List, Tuple

from langchain_community.vectorstores.azuresearch import AzureSearch
from langchain_openai import AzureChatOpenAI, AzureOpenAIEmbeddings
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import (
    ChatPromptTemplate,
    MessagesPlaceholder,
    format_document,
)
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.runnables import (
    RunnableBranch,
    RunnableLambda,
    RunnableParallel,
    RunnablePassthrough,
)

from src.utils import AzureAISearchConfig, AzureOpenAIConfig

# Load configurations
azure_ais_conf = AzureAISearchConfig.from_yaml("/some/config/location")
azure_oai_conf = AzureOpenAIConfig.from_yaml("/some/config/location")

embeddings = AzureOpenAIEmbeddings(
    azure_deployment=azure_oai_conf.embedding_model,
    openai_api_version=azure_oai_conf.api_version,
    azure_endpoint=azure_oai_conf.endpoint,
)
llm = AzureChatOpenAI(
    azure_deployment=azure_oai_conf.chat_model,
    openai_api_version=azure_oai_conf.api_version,
    azure_endpoint=azure_oai_conf.endpoint,
)
vectorstore = AzureSearch(
    azure_search_endpoint=azure_ais_conf.endpoint,
    azure_search_key=os.environ["AZURE_SEARCH_KEY"],
    index_name="langchain-vector-dummy",
    embedding_function=embeddings.embed_query,
)
retriever = vectorstore.as_retriever()

# Condense a chat history and follow-up question into a standalone question
_template = """Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question, in its original language.
Chat History:
{chat_history}
Follow Up Input: {question}
Standalone question:"""  # noqa: E501
CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(_template)

# RAG answer synthesis prompt
template = """Answer the question based only on the following context:
<context>
{context}
</context>"""
ANSWER_PROMPT = ChatPromptTemplate.from_messages(
    [
        ("system", template),
        MessagesPlaceholder(variable_name="chat_history"),
        ("user", "{question}"),
    ]
)

# Conversational Retrieval Chain
DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template(template="{page_content}")


def _combine_documents(docs, document_prompt=DEFAULT_DOCUMENT_PROMPT, document_separator="\n\n"):
    doc_strings = [format_document(doc, document_prompt) for doc in docs]
    return document_separator.join(doc_strings)


def _format_chat_history(chat_history: List[Tuple[str, str]]) -> List:
    buffer = []
    for human, ai in chat_history:
        buffer.append(HumanMessage(content=human))
        buffer.append(AIMessage(content=ai))
    return buffer


# User input
class ChatHistory(BaseModel):
    chat_history: List[Tuple[str, str]] = Field(..., extra={"widget": {"type": "chat"}})
    question: str


_search_query = RunnableBranch(
    # If input includes chat_history, we condense it with the follow-up question
    (
        RunnableLambda(lambda x: bool(x.get("chat_history"))).with_config(
            run_name="HasChatHistoryCheck"
        ),  # Condense follow-up question and chat into a standalone_question
        RunnablePassthrough.assign(chat_history=lambda x: _format_chat_history(x["chat_history"]))
        | CONDENSE_QUESTION_PROMPT
        | llm
        | StrOutputParser(),
    ),
    # Else, we have no chat history, so just pass through the question
    RunnableLambda(itemgetter("question")),
)

_inputs = RunnableParallel(
    {
        "question": lambda x: x["question"],
        "chat_history": lambda x: _format_chat_history(x["chat_history"]),
        "context": _search_query | retriever | _combine_documents,
    }
).with_types(input_type=ChatHistory)

chain = _inputs | ANSWER_PROMPT | llm | StrOutputParser()

When I run the code with playground_type="default" and define the chat history as follows, I get the following output:
[screenshot of the default playground output]
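For reference, the chat history input follows the ChatHistory schema defined in chain.py, i.e. roughly this shape (the values here are just placeholders):

example_input = {
    "question": "What caused the latest disturbance on line 3?",  # placeholder question
    "chat_history": [
        # (human, ai) pairs, matching List[Tuple[str, str]]
        ("Were there any disturbances yesterday?", "Yes, a conveyor belt stoppage was reported."),
    ],
}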

When I run the code with playground_type="chat", I get no output after the first question.
[screenshot of the chat playground after the first question]

I did some further analysis and inspected the network calls. They look slightly different.
The network calls for the default playground look like this:
[screenshot of the default playground network calls]
The network calls for the chat playground look like this:
[screenshot of the chat playground network calls]
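For completeness, here is a minimal sketch of how I can call the chain directly through the /rag-conversation/invoke endpoint that add_routes exposes (assuming the server is running locally on port 8000; the question and history values are placeholders):

import requests

# langserve's add_routes also registers /invoke, /batch and /stream under the given path
resp = requests.post(
    "http://localhost:8000/rag-conversation/invoke",
    json={
        "input": {
            "question": "What caused the latest disturbance on line 3?",  # placeholder
            "chat_history": [
                ["Were there any disturbances yesterday?", "Yes, a conveyor belt stoppage was reported."],
            ],
        }
    },
)
print(resp.json()["output"])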

What could be going on here?
