Spaces:
Running
Running
Update main.py
Browse files
main.py
CHANGED
@@ -54,45 +54,47 @@ sys_prompts = {
|
|
54 |
},
|
55 |
}
|
56 |
|
57 |
-
class
|
58 |
-
topic: str = Query(default="market research", description="input query to generate
|
59 |
description: str = Query(default="", description="additional context for report")
|
60 |
user_id: str = Query(default="", description="unique user id")
|
61 |
user_name: str = Query(default="", description="user name")
|
62 |
internet: bool = Query(default=True, description="Enable Internet search")
|
63 |
-
output_format: str = Query(default="Tabular Report", description="Output format for the report",
|
64 |
-
|
|
|
|
|
65 |
|
66 |
-
@
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
|
|
|
|
72 |
optimized_search_query = ""
|
73 |
-
all_text_with_urls = [("","")]
|
74 |
|
75 |
-
|
76 |
-
|
77 |
-
search_query = params.description
|
78 |
try:
|
79 |
urls, optimized_search_query = search_brave(search_query, num_results=4)
|
80 |
-
all_text_with_urls = fetch_and_extract_content(
|
81 |
additional_context = limit_tokens(str(all_text_with_urls))
|
82 |
-
prompt = f"#### COMPLETE THE TASK: {
|
83 |
except Exception as e:
|
84 |
-
|
85 |
print("failed to search/scrape results, falling back to LLM response")
|
86 |
|
87 |
-
if not
|
88 |
-
prompt = f"#### COMPLETE THE TASK: {
|
89 |
|
90 |
md_report = together_response(prompt, model=llm_default_medium, SysPrompt=sys_prompt_output_format)
|
91 |
|
92 |
-
if
|
93 |
-
insert_data(
|
94 |
-
|
95 |
-
references_html = dict()
|
96 |
for text, url in all_text_with_urls:
|
97 |
references_html[url] = str(md_to_html(text))
|
98 |
|
@@ -101,6 +103,11 @@ async def generate_report(params: ReportParams):
|
|
101 |
"references": references_html,
|
102 |
"search_query": optimized_search_query
|
103 |
}
|
|
|
|
|
|
|
|
|
|
|
104 |
app.add_middleware(
|
105 |
CORSMiddleware,
|
106 |
allow_origins=["*"],
|
|
|
54 |
},
|
55 |
}
|
56 |
|
57 |
+
class QueryModel(BaseModel):
|
58 |
+
topic: str = Query(default="market research", description="input query to generate Report")
|
59 |
description: str = Query(default="", description="additional context for report")
|
60 |
user_id: str = Query(default="", description="unique user id")
|
61 |
user_name: str = Query(default="", description="user name")
|
62 |
internet: bool = Query(default=True, description="Enable Internet search")
|
63 |
+
output_format: str = Query(default="Tabular Report", description="Output format for the report",
|
64 |
+
enum=["Chat", "Full Text Report", "Tabular Report", "Tables only"])
|
65 |
+
data_format: str = Query(default="Structured data", description="Type of data to extract from the internet",
|
66 |
+
enum=["No presets", "Structured data", "Quantitative data"])
|
67 |
|
68 |
+
@cache(expire=604800)
|
69 |
+
async def generate_report(query: QueryModel):
|
70 |
+
query_str = query.topic
|
71 |
+
description = query.description
|
72 |
+
user_id = query.user_id
|
73 |
+
internet = "online" if query.internet else "offline"
|
74 |
+
sys_prompt_output_format = sys_prompts[internet][query.output_format]
|
75 |
+
data_format = query.data_format
|
76 |
optimized_search_query = ""
|
77 |
+
all_text_with_urls = [("", "")]
|
78 |
|
79 |
+
if query.internet:
|
80 |
+
search_query = description
|
|
|
81 |
try:
|
82 |
urls, optimized_search_query = search_brave(search_query, num_results=4)
|
83 |
+
all_text_with_urls = fetch_and_extract_content(data_format, urls, query_str)
|
84 |
additional_context = limit_tokens(str(all_text_with_urls))
|
85 |
+
prompt = f"#### COMPLETE THE TASK: {description} #### IN THE CONTEXT OF ### CONTEXT: {query_str} USING THE #### SCRAPED DATA:{additional_context}"
|
86 |
except Exception as e:
|
87 |
+
query.internet = False
|
88 |
print("failed to search/scrape results, falling back to LLM response")
|
89 |
|
90 |
+
if not query.internet:
|
91 |
+
prompt = f"#### COMPLETE THE TASK: {description} #### IN THE CONTEXT OF ### CONTEXT: {query_str}"
|
92 |
|
93 |
md_report = together_response(prompt, model=llm_default_medium, SysPrompt=sys_prompt_output_format)
|
94 |
|
95 |
+
if user_id != "test":
|
96 |
+
insert_data(user_id, query_str, description, str(all_text_with_urls), md_report)
|
97 |
+
references_html = {}
|
|
|
98 |
for text, url in all_text_with_urls:
|
99 |
references_html[url] = str(md_to_html(text))
|
100 |
|
|
|
103 |
"references": references_html,
|
104 |
"search_query": optimized_search_query
|
105 |
}
|
106 |
+
|
107 |
+
@app.post("/generate_report")
|
108 |
+
async def api_generate_report(request: Request, query: QueryModel):
|
109 |
+
return await generate_report(query)
|
110 |
+
|
111 |
app.add_middleware(
|
112 |
CORSMiddleware,
|
113 |
allow_origins=["*"],
|