PawinC committed on
Commit
7a0ef1d
1 Parent(s): ce36f28

Update app/main.py

Files changed (1)
  1. app/main.py +39 -15
app/main.py CHANGED
@@ -15,6 +15,7 @@ from typing import Optional
 print("Loading model...")
 SAllm = Llama(model_path="/models/final-gemma2b_SA-Q8_0.gguf", mmap=False, mlock=True)
 FIllm = Llama(model_path="/models/final-gemma7b_FI-Q8_0.gguf", mmap=False, mlock=True)
+WIllm = Llama(model_path="/models/final-GemmaWild7b-Q8_0.gguf", mmap=False, mlock=True)
 # n_gpu_layers=28, # Uncomment to use GPU acceleration
 # seed=1337, # Uncomment to set a specific seed
 # n_ctx=2048, # Uncomment to increase the context window
@@ -23,9 +24,9 @@ FIllm = Llama(model_path="/models/final-gemma7b_FI-Q8_0.gguf", mmap=False, mlock
 def extract_restext(response):
     return response['choices'][0]['text'].strip()
 
-def ask_fi(question, max_new_tokens=200, temperature=0.5):
+def ask_llm(llm, question, max_new_tokens=200, temperature=0.5):
     prompt = f"""###User: {question}\n###Assistant:"""
-    result = extract_restext(FIllm(prompt, max_tokens=max_new_tokens, temperature=temperature, stop=["###User:", "###Assistant:"], echo=False))
+    result = extract_restext(llm(prompt, max_tokens=max_new_tokens, temperature=temperature, stop=["###User:", "###Assistant:"], echo=False))
     return result
 
 def check_sentiment(text):
@@ -43,7 +44,8 @@ def check_sentiment(text):
 # TESTING THE MODEL
 print("Testing model...")
 assert "positive" in check_sentiment("ดอกไม้ร้านนี้สวยจัง")
-assert ask_fi("Hello!, How are you today?")
+assert ask_llm(FIllm, "Hello!, How are you today?", max_new_tokens=1) #Just checking that it can run
+assert ask_llm(WIllm, "Hello!, How are you today?", max_new_tokens=1) #Just checking that it can run
 print("Ready.")
 
 
@@ -70,12 +72,12 @@ class SA_Result(str, Enum):
     negative = "negative"
     unknown = "unknown"
 
-class SA_Response(BaseModel):
+class SAResponse(BaseModel):
     code: int = 200
     text: Optional[str] = None
     result: SA_Result = None
 
-class FI_Response(BaseModel):
+class QuestionResponse(BaseModel):
     code: int = 200
     question: Optional[str] = None
     answer: str = None
@@ -89,18 +91,18 @@ def docs():
     return responses.RedirectResponse('./docs')
 
 @app.post('/classifications/sentiment')
-async def perform_sentiment_analysis(prompt: str = Body(..., embed=True, example="I like eating fried chicken")) -> SA_Response:
+async def perform_sentiment_analysis(prompt: str = Body(..., embed=True, example="I like eating fried chicken")) -> SAResponse:
     """Performs a sentiment analysis using a finetuned version of Gemma-7b"""
     if prompt:
         try:
             print(f"Checking sentiment for {prompt}")
             result = check_sentiment(prompt)
             print(f"Result: {result}")
-            return SA_Response(result=result, text=prompt)
+            return SAResponse(result=result, text=prompt)
         except Exception as e:
-            return HTTPException(500, SA_Response(code=500, result=str(e), text=prompt))
+            return HTTPException(500, SAResponse(code=500, result=str(e), text=prompt))
     else:
-        return HTTPException(400, SA_Response(code=400, result="Request argument 'prompt' not provided."))
+        return HTTPException(400, SAResponse(code=400, result="Request argument 'prompt' not provided."))
 
 
 @app.post('/questions/finance')
@@ -108,18 +110,40 @@ async def ask_gemmaFinanceTH(
     prompt: str = Body(..., embed=True, example="What's the best way to invest my money"),
     temperature: float = Body(0.5, embed=True),
     max_new_tokens: int = Body(200, embed=True)
-) -> FI_Response:
+) -> QuestionResponse:
     """
     Ask a finetuned Gemma a finance-related question, just for fun.
     NOTICE: IT MAY PRODUCE RANDOM/INACCURATE ANSWERS. PLEASE SEEK PROFESSIONAL ADVICE BEFORE DOING ANYTHING SERIOUS.
     """
     if prompt:
         try:
-            print(f'Asking FI with the question "{prompt}"')
-            result = ask_fi(prompt, max_new_tokens=max_new_tokens, temperature=temperature)
+            print(f'Asking GemmaFinance with the question "{prompt}"')
+            result = ask_llm(FIllm, prompt, max_new_tokens=max_new_tokens, temperature=temperature)
             print(f"Result: {result}")
-            return FI_Response(answer=result, question=prompt, config={"temperature": temperature, "max_new_tokens": max_new_tokens})
+            return QuestionResponse(answer=result, question=prompt, config={"temperature": temperature, "max_new_tokens": max_new_tokens})
         except Exception as e:
-            return HTTPException(500, FI_Response(code=500, answer=str(e), question=prompt))
+            return HTTPException(500, QuestionResponse(code=500, answer=str(e), question=prompt))
     else:
-        return HTTPException(400, FI_Response(code=400, answer="Request argument 'prompt' not provided."))
+        return HTTPException(400, QuestionResponse(code=400, answer="Request argument 'prompt' not provided."))
+
+
+@app.post('/questions/finance')
+async def ask_gemmaFinanceTH(
+    prompt: str = Body(..., embed=True, example="Why is ice cream so delicious?"),
+    temperature: float = Body(0.5, embed=True),
+    max_new_tokens: int = Body(200, embed=True)
+) -> QuestionResponse:
+    """
+    Ask a finetuned Gemma an open-ended question.
+    NOTICE: IT MAY PRODUCE RANDOM/INACCURATE ANSWERS. PLEASE SEEK PROFESSIONAL ADVICE BEFORE DOING ANYTHING SERIOUS.
+    """
+    if prompt:
+        try:
+            print(f'Asking GemmaWild with the question "{prompt}"')
+            result = ask_llm(WIllm, prompt, max_new_tokens=max_new_tokens, temperature=temperature)
+            print(f"Result: {result}")
+            return QuestionResponse(answer=result, question=prompt, config={"temperature": temperature, "max_new_tokens": max_new_tokens})
+        except Exception as e:
+            return HTTPException(500, QuestionResponse(code=500, answer=str(e), question=prompt))
+    else:
+        return HTTPException(400, QuestionResponse(code=400, answer="Request argument 'prompt' not provided."))
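
For context, the endpoints touched by this commit all take a JSON body with an embedded "prompt" field (Body(..., embed=True)), plus optional "temperature" and "max_new_tokens". Below is a minimal client sketch, assuming the app is served at http://localhost:8000; the host and port are not set in app/main.py and are assumptions here. Note that in this revision both question handlers are registered on the same '/questions/finance' path, so requests to that path are matched by the first registered handler (the GemmaFinance one).

# Hypothetical client sketch for the endpoints in this commit.
# The base URL is an assumption, not part of app/main.py.
import requests

BASE = "http://localhost:8000"

# Sentiment classification: the body is {"prompt": ...} because of Body(..., embed=True).
r = requests.post(f"{BASE}/classifications/sentiment",
                  json={"prompt": "I like eating fried chicken"})
print(r.json())  # e.g. {"code": 200, "text": "...", "result": "positive"}

# Finance Q&A: temperature and max_new_tokens are optional (defaults 0.5 and 200).
r = requests.post(f"{BASE}/questions/finance",
                  json={"prompt": "What's the best way to invest my money",
                        "temperature": 0.5,
                        "max_new_tokens": 200})
print(r.json())  # e.g. {"code": 200, "question": "...", "answer": "..."}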