-
Notifications
You must be signed in to change notification settings - Fork 2
/
main.py
98 lines (83 loc) · 3.16 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import csv
import json
import os
import box
import pandas as pd
import yaml
from langchain.callbacks import get_openai_callback
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain.tools.render import format_tool_to_openai_function
from tenacity import retry, stop_after_attempt, wait_fixed
from src.agents import create_agent_executor
from src.llm import llm
from src.tools import wikipedia_tool
from src.prompts import system_prompt, generate_input_prompt
from src.utils import default_values
# --- Configuration -----------------------------------------------------------
# Parse config/config.yaml into a Box so settings (e.g. cfg.INPUT_FILE,
# cfg.OUTPUT_FILE) can be read with attribute access.
with open("config/config.yaml", "r", encoding="utf8") as config_file:
    cfg = box.Box(yaml.safe_load(config_file))

# --- Tools -------------------------------------------------------------------
# Tools the agent may call; the LLM is bound to their OpenAI function schemas
# so it can emit function calls for them.
tools = [wikipedia_tool]
llm_with_tools = llm.bind(
    functions=list(map(format_tool_to_openai_function, tools))
)

# --- Prompt ------------------------------------------------------------------
# System instructions, then the user's request ("{input}"), then a placeholder
# named "agent_scratchpad" that the agent executor fills in (presumably with
# the agent's intermediate tool-calling steps — confirm in src.agents).
prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system_prompt),
        ("user", "{input}"),
        MessagesPlaceholder(variable_name="agent_scratchpad"),
    ]
)

# --- Agent -------------------------------------------------------------------
agent_executor = create_agent_executor(
    llm_with_tools=llm_with_tools,
    tools=tools,
    prompt=prompt,
)
# Column order of the output CSV; every appended row follows this order so it
# always lines up with the header written when the file is created.
_OUTPUT_COLUMNS = (
    "artist",
    "song",
    "genre",
    "label",
    "language",
    "llm_cost",
    "llm_tokens",
    "producers",
    "songwriters",
)


def _load_processed_keys(output_file_path: str) -> set:
    """Return the set of ``(artist, song)`` string pairs already in the output CSV.

    If the output file is missing or empty (0 bytes), it is (re)created with
    only the header row and an empty set is returned.
    """
    if (
        not os.path.exists(output_file_path)
        or os.path.getsize(output_file_path) == 0
    ):
        pd.DataFrame(columns=list(_OUTPUT_COLUMNS)).to_csv(
            output_file_path, index=False
        )
        return set()
    # Read every cell as text (and keep "nan"/"" literally) so the keys compare
    # equal to str(...) of the values coming from the input CSV, even for
    # numeric-looking titles such as "1999".
    existing = pd.read_csv(
        output_file_path, encoding="utf-8", dtype=str, keep_default_na=False
    )
    return set(zip(existing["artist"], existing["song"]))


def _scrape_song_info(song, artist) -> dict:
    """Ask the agent for metadata about one song and return an output row dict."""
    input_prompt = generate_input_prompt(song, artist)
    with get_openai_callback() as cb:
        response = agent_executor.invoke({"input": input_prompt})
        cost = cb.total_cost
        tokens = cb.total_tokens
    output = response["output"]
    print(output)
    # The agent's answer must be a JSON object; a malformed answer raises
    # json.JSONDecodeError, which propagates to the @retry on the caller.
    output_dict = json.loads(output)
    return {
        "artist": artist,
        "song": song,
        "genre": output_dict.get("genre"),
        "label": output_dict.get("label"),
        "language": output_dict.get("language"),
        "llm_cost": cost,
        "llm_tokens": tokens,
        "producers": output_dict.get("producers"),
        "songwriters": output_dict.get("songwriters"),
    }


def _append_row(output_file_path: str, row: dict) -> None:
    """Append one row to the output CSV, ordered by ``_OUTPUT_COLUMNS``."""
    with open(output_file_path, "a", newline="", encoding="utf-8") as file:
        csv.writer(file).writerow([row[column] for column in _OUTPUT_COLUMNS])


@retry(
    stop=stop_after_attempt(2), wait=wait_fixed(10), retry_error_callback=default_values
)
def execute_web_scraping(
    input_file_path: str = cfg.INPUT_FILE, output_file_path: str = cfg.OUTPUT_FILE
):
    """Collect metadata for every song in the input CSV into the output CSV.

    The input CSV must have ``song`` and ``artist`` columns. For each
    ``(artist, song)`` pair not yet present in the output CSV, the agent is
    queried and one row (genre, label, language, LLM cost/tokens, producers,
    songwriters) is appended immediately. Progress is therefore kept if a
    later song fails, and a rerun skips pairs that were already written.

    Any exception triggers the ``@retry`` decorator: up to two attempts,
    10 seconds apart. If both attempts fail, the return value of
    ``default_values`` is returned instead of raising (its value is defined
    in src.utils and is not visible here).

    Args:
        input_file_path: CSV of songs to process (default ``cfg.INPUT_FILE``).
        output_file_path: CSV to create/append results to
            (default ``cfg.OUTPUT_FILE``).
    """
    df = pd.read_csv(input_file_path)
    # Load already-processed pairs once, instead of re-reading the output
    # file on every iteration; the set is kept current as rows are written.
    processed = _load_processed_keys(output_file_path)
    for song, artist in zip(df["song"], df["artist"]):
        # Key on artist AND title: different artists can share a song title.
        key = (str(artist), str(song))
        if key in processed:
            continue
        print(f"***** Processing: {song} by {artist} *****")
        _append_row(output_file_path, _scrape_song_info(song, artist))
        processed.add(key)
# Script entry point: run the scraper with its default input/output paths,
# which come from config/config.yaml (cfg.INPUT_FILE / cfg.OUTPUT_FILE).
if __name__ == "__main__":
    execute_web_scraping()