Skip to content

Commit

Permalink
Tools: fix citations (cohere-ai#108)
Browse files Browse the repository at this point in the history
  • Loading branch information
lusmoura authored and lakshyaag committed May 8, 2024
1 parent 0afe494 commit 6ce0e5e
Show file tree
Hide file tree
Showing 6 changed files with 90 additions and 4 deletions.
32 changes: 32 additions & 0 deletions src/backend/alembic/versions/6553b76de6ca_.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
"""empty message
Revision ID: 6553b76de6ca
Revises: 2853273872ca
Create Date: 2024-05-07 11:23:19.581035
"""

from typing import Sequence, Union

import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision: str = "6553b76de6ca"
down_revision: Union[str, None] = "2853273872ca"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    """Add the nullable ``fields`` (JSON) and ``tool_name`` (string) columns to ``documents``."""
    # Originally auto-generated by Alembic; columns are added in the listed order.
    new_columns = (
        sa.Column("fields", sa.JSON(), nullable=True),
        sa.Column("tool_name", sa.String(), nullable=True),
    )
    for column in new_columns:
        op.add_column("documents", column)


def downgrade() -> None:
    """Drop the ``tool_name`` and ``fields`` columns from ``documents``."""
    # Originally auto-generated by Alembic; columns are dropped in the
    # reverse of the order in which upgrade() adds them.
    for column_name in ("tool_name", "fields"):
        op.drop_column("documents", column_name)
7 changes: 6 additions & 1 deletion src/backend/chat/custom/custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,11 @@ def get_tool_results(
outputs = tool.implementation().call(
parameters=tool_call.parameters,
)
tool_results.append({"call": tool_call, "outputs": [outputs]})

# If the tool returns a list of outputs, append each output to the tool_results list
# Otherwise, append the single output to the tool_results list
outputs = outputs if isinstance(outputs, list) else [outputs]
for output in outputs:
tool_results.append({"call": tool_call, "outputs": [output]})

return tool_results
4 changes: 3 additions & 1 deletion src/backend/models/document.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from sqlalchemy import ForeignKey, Index, String
from sqlalchemy import JSON, ForeignKey, Index, String
from sqlalchemy.orm import Mapped, mapped_column

from backend.models.base import Base
Expand All @@ -12,6 +12,8 @@ class Document(Base):
user_id: Mapped[str] = mapped_column(String)
title: Mapped[str] = mapped_column(String, nullable=True)
url: Mapped[str] = mapped_column(String, nullable=True)
fields: Mapped[dict] = mapped_column(JSON, nullable=True)
tool_name: Mapped[str] = mapped_column(String, nullable=True)

conversation_id: Mapped[str] = mapped_column(
ForeignKey("conversations.id", ondelete="CASCADE")
Expand Down
7 changes: 7 additions & 0 deletions src/backend/routers/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -512,6 +512,13 @@ def generate_chat_stream(
text=document.get("text", ""),
title=document.get("title", ""),
url=document.get("url", ""),
tool_name=document.get("tool_name", ""),
# all document fields except for id, tool_name and text
fields={
k: v
for k, v in document.items()
if k not in ["id", "tool_name", "text"]
},
user_id=response_message.user_id,
conversation_id=response_message.conversation_id,
message_id=response_message.id,
Expand Down
2 changes: 2 additions & 0 deletions src/backend/schemas/document.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ class Document(BaseModel):

title: Union[str, None]
url: Union[str, None]
fields: Union[dict, None]
tool_name: Union[str, None]

class Config:
from_attributes = True
42 changes: 40 additions & 2 deletions src/backend/tools/function_tools/python_interpreter.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import json
import os
from typing import Any
from typing import Any, Dict, Mapping

import requests
from langchain_core.tools import Tool as LangchainTool
Expand Down Expand Up @@ -30,8 +31,45 @@ def call(self, parameters: dict, **kwargs: Any):

code = parameters.get("code", "")
res = requests.post(self.interpreter_url, json={"code": code})
clean_res = self._clean_response(res.json())

return res.json()
return clean_res

def _clean_response(self, result: Any) -> "list[Dict[str, str]]":
    """Split an interpreter response into individually citable, string-valued items.

    The first returned item is the response itself (without ``output_files``);
    each entry of ``output_files`` becomes its own ``{"output_file": ...}``
    item so that it can be cited separately.

    For every item:
      * a misspelled ``"sucess"`` key is renamed to ``"success"``;
      * a ``"text"`` key is filled in if absent: ``std_out`` when ``success``
        is ``True``, the error message when ``success`` is ``False``, or
        ``"Created output file <filename>"`` for an output-file item that has
        a non-empty filename;
      * every value is converted to ``str``; mappings are serialised with
        ``json.dumps`` so they keep double-quoted JSON syntax.

    Args:
        result: Mapping decoded from the interpreter's JSON response.

    Returns:
        A list of new dictionaries whose values are all strings. ``result``
        itself is left unmodified.
    """
    # Shallow copy: the caller's payload must not lose keys as a side effect.
    main_item = dict(result)
    if "final_expression" in main_item:
        main_item["final_expression"] = str(main_item["final_expression"])

    # `or []` tolerates an explicit null for output_files.
    output_files = main_item.pop("output_files", None) or []
    items = [main_item]
    items.extend({"output_file": f} for f in output_files)

    cleaned = []
    for item in items:
        # The upstream payload may spell the key "sucess"; normalise it.
        if item.get("sucess") is not None:
            item["success"] = item.pop("sucess")

        success = item.get("success")
        output_file = item.get("output_file")
        if success is True:
            std_out = item.get("std_out")
            # Avoid the literal text "None" when nothing was printed.
            item.setdefault("text", "" if std_out is None else std_out)
        elif success is False:
            error = item.get("error")
            if isinstance(error, Mapping):
                error_message = error.get("message", "")
            elif error is None:
                error_message = ""
            else:
                # A bare (non-mapping) error is used as the message itself.
                error_message = str(error)
            item.setdefault("text", error_message)
        elif isinstance(output_file, Mapping) and output_file.get("filename"):
            item.setdefault(
                "text", f"Created output file {output_file['filename']}"
            )

        # Build a fresh dict instead of rewriting the one being iterated.
        cleaned.append(
            {
                key: json.dumps(value) if isinstance(value, Mapping) else str(value)
                for key, value in item.items()
            }
        )

    return cleaned

# langchain does not return a dict as a parameter, only a code string
def langchain_call(self, code: str):
Expand Down

0 comments on commit 6ce0e5e

Please sign in to comment.