pvanand commited on
Commit
1fc729a
1 Parent(s): bdd570c

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +30 -23
main.py CHANGED
@@ -54,45 +54,47 @@ sys_prompts = {
54
  },
55
  }
56
 
57
- class ReportParams(BaseModel):
58
- topic: str = Query(default="market research", description="input query to generate report")
59
  description: str = Query(default="", description="additional context for report")
60
  user_id: str = Query(default="", description="unique user id")
61
  user_name: str = Query(default="", description="user name")
62
  internet: bool = Query(default=True, description="Enable Internet search")
63
- output_format: str = Query(default="Tabular Report", description="Output format for the report", enum=["Chat", "Full Text Report", "Tabular Report", "Tables only"])
64
- data_format: str = Query(default="Structured data", description="Type of data to extract from the internet", enum=["No presets", "Structured data", "Quantitative data"])
 
 
65
 
66
- @app.post("/generate_report")
67
- @cache(expire=604800) # Cache expiration set to 7 days
68
- async def generate_report(params: ReportParams):
69
- query_str = params.topic
70
- internet_status = "online" if params.internet else "offline"
71
- sys_prompt_output_format = sys_prompts[internet_status][params.output_format]
 
 
72
  optimized_search_query = ""
73
- all_text_with_urls = [("","")]
74
 
75
- # Combine query with user keywords
76
- if params.internet:
77
- search_query = params.description
78
  try:
79
  urls, optimized_search_query = search_brave(search_query, num_results=4)
80
- all_text_with_urls = fetch_and_extract_content(params.data_format, urls, query_str)
81
  additional_context = limit_tokens(str(all_text_with_urls))
82
- prompt = f"#### COMPLETE THE TASK: {params.description} #### IN THE CONTEXT OF ### CONTEXT: {query_str} USING THE #### SCRAPED DATA:{additional_context}"
83
  except Exception as e:
84
- params.internet = False
85
  print("failed to search/scrape results, falling back to LLM response")
86
 
87
- if not params.internet:
88
- prompt = f"#### COMPLETE THE TASK: {params.description} #### IN THE CONTEXT OF ### CONTEXT: {query_str}"
89
 
90
  md_report = together_response(prompt, model=llm_default_medium, SysPrompt=sys_prompt_output_format)
91
 
92
- if params.user_id != "test":
93
- insert_data(params.user_id, query_str, params.description, str(all_text_with_urls), md_report)
94
-
95
- references_html = dict()
96
  for text, url in all_text_with_urls:
97
  references_html[url] = str(md_to_html(text))
98
 
@@ -101,6 +103,11 @@ async def generate_report(params: ReportParams):
101
  "references": references_html,
102
  "search_query": optimized_search_query
103
  }
 
 
 
 
 
104
  app.add_middleware(
105
  CORSMiddleware,
106
  allow_origins=["*"],
 
54
  },
55
  }
56
 
57
+ class QueryModel(BaseModel):
58
+ topic: str = Query(default="market research", description="input query to generate Report")
59
  description: str = Query(default="", description="additional context for report")
60
  user_id: str = Query(default="", description="unique user id")
61
  user_name: str = Query(default="", description="user name")
62
  internet: bool = Query(default=True, description="Enable Internet search")
63
+ output_format: str = Query(default="Tabular Report", description="Output format for the report",
64
+ enum=["Chat", "Full Text Report", "Tabular Report", "Tables only"])
65
+ data_format: str = Query(default="Structured data", description="Type of data to extract from the internet",
66
+ enum=["No presets", "Structured data", "Quantitative data"])
67
 
68
+ @cache(expire=604800)
69
+ async def generate_report(query: QueryModel):
70
+ query_str = query.topic
71
+ description = query.description
72
+ user_id = query.user_id
73
+ internet = "online" if query.internet else "offline"
74
+ sys_prompt_output_format = sys_prompts[internet][query.output_format]
75
+ data_format = query.data_format
76
  optimized_search_query = ""
77
+ all_text_with_urls = [("", "")]
78
 
79
+ if query.internet:
80
+ search_query = description
 
81
  try:
82
  urls, optimized_search_query = search_brave(search_query, num_results=4)
83
+ all_text_with_urls = fetch_and_extract_content(data_format, urls, query_str)
84
  additional_context = limit_tokens(str(all_text_with_urls))
85
+ prompt = f"#### COMPLETE THE TASK: {description} #### IN THE CONTEXT OF ### CONTEXT: {query_str} USING THE #### SCRAPED DATA:{additional_context}"
86
  except Exception as e:
87
+ query.internet = False
88
  print("failed to search/scrape results, falling back to LLM response")
89
 
90
+ if not query.internet:
91
+ prompt = f"#### COMPLETE THE TASK: {description} #### IN THE CONTEXT OF ### CONTEXT: {query_str}"
92
 
93
  md_report = together_response(prompt, model=llm_default_medium, SysPrompt=sys_prompt_output_format)
94
 
95
+ if user_id != "test":
96
+ insert_data(user_id, query_str, description, str(all_text_with_urls), md_report)
97
+ references_html = {}
 
98
  for text, url in all_text_with_urls:
99
  references_html[url] = str(md_to_html(text))
100
 
 
103
  "references": references_html,
104
  "search_query": optimized_search_query
105
  }
106
+
107
+ @app.post("/generate_report")
108
+ async def api_generate_report(request: Request, query: QueryModel):
109
+ return await generate_report(query)
110
+
111
  app.add_middleware(
112
  CORSMiddleware,
113
  allow_origins=["*"],