PawinC committed
Commit 33d6214
1 Parent(s): 2cbc039

Upload main.py

Files changed (1)
  app/main.py +47 -20
app/main.py CHANGED
@@ -12,35 +12,36 @@ from enum import Enum
 from typing import Optional
 
 print("Loading model...")
-llm = Llama(
-    model_path="/models/final-gemma2b_SA-Q8_0.gguf",
+SAllm = Llama(model_path="/models/final-gemma2b_SA-Q8_0.gguf")#,
     # n_gpu_layers=28, # Uncomment to use GPU acceleration
     # seed=1337, # Uncomment to set a specific seed
     # n_ctx=2048, # Uncomment to increase the context window
-)
+#)
 
-def ask(question, max_new_tokens=200):
-    output = llm(
-        question, # Prompt
-        max_tokens=max_new_tokens, # Generate up to max_new_tokens tokens; set to None to generate up to the end of the context window
-        stop=["\n"], # Stop generating just before the model would generate a new question
-        echo=False, # Don't echo the prompt back in the output
-        temperature=0.0,
-    )
-    return output
+FIllm = Llama(model_path="/models/final-gemma2b_FI-Q8_0.gguf")
 
-def check_sentiment(text):
-    result = ask(f'Analyze the sentiment of the tweet enclosed in square brackets, determine if it is positive or negative, and return the answer as the corresponding sentiment label "positive" or "negative" [{text}] =', max_new_tokens=3)
-    return result['choices'][0]['text'].strip()
+# def ask(question, max_new_tokens=200):
+#     output = llm(
+#         question, # Prompt
+#         max_tokens=max_new_tokens, # Generate up to max_new_tokens tokens; set to None to generate up to the end of the context window
+#         stop=["\n"], # Stop generating just before the model would generate a new question
+#         echo=False, # Don't echo the prompt back in the output
+#         temperature=0.0,
+#     )
+#     return output
 
-def clean_sentiment_response(response_text):
-    result = response_text.strip()
+def check_sentiment(text):
+    prompt = f'Analyze the sentiment of the tweet enclosed in square brackets, determine if it is positive or negative, and return the answer as the corresponding sentiment label "positive" or "negative" [{text}] ='
+    response = SAllm(prompt, max_tokens=3, stop=["\n"], echo=False, temperature=0.5)
+    # print(response)
+    result = response['choices'][0]['text'].strip()
     if "positive" in result:
         return "positive"
     elif "negative" in result:
         return "negative"
     else:
         return "unknown"
+
 
 print("Testing model...")
 assert "positive" in check_sentiment("ดอกไม้ร้านนี้สวยจัง")
@@ -67,11 +68,15 @@ class SA_Result(str, Enum):
     unknown = "unknown"
 
 class SA_Response(BaseModel):
-    text: Optional[str] = None
     code: int = 200
+    text: Optional[str] = None
     result: SA_Result = None
 
-
+class FI_Response(BaseModel):
+    code: int = 200
+    question: Optional[str] = None
+    answer: str = None
+    config: Optional[dict] = None
 
 @app.get('/')
 def docs():
@@ -88,10 +93,32 @@ def perform_sentiment_analysis(prompt: str = Body(..., embed=True, example="I li
     if prompt:
         try:
             print(f"Checking sentiment for {prompt}")
-            result = clean_sentiment_response(check_sentiment(prompt))
+            result = check_sentiment(prompt)
             print(f"Result: {result}")
             return SA_Response(result=result, text=prompt)
         except Exception as e:
             return HTTPException(500, SA_Response(code=500, result=str(e), text=prompt))
     else:
         return HTTPException(400, SA_Response(code=400, result="Request argument 'prompt' not provided."))
+
+
+@app.post('/FI')
+def ask_gemmaFinanceTH(
+    prompt: str = Body(..., embed=True, example="What's the best way to invest my money"),
+    temperature: float = 0.5,
+    max_new_tokens: int = 200
+) -> FI_Response:
+    """
+    Ask a finetuned Gemma a finance-related question, just for fun.
+    NOTICE: IT MAY PRODUCE RANDOM/INACCURATE ANSWERS. PLEASE SEEK PROFESSIONAL ADVICE BEFORE DOING ANYTHING SERIOUS.
+    """
+    if prompt:
+        try:
+            print(f'Asking FI with the question "{prompt}"')
+            result = FIllm(prompt, max_tokens=max_new_tokens, temperature=temperature, stop=["###User:", "###Assistant:"], echo=False)
+            print(f"Result: {result}")
+            return FI_Response(answer=result, question=prompt)
+        except Exception as e:
+            return HTTPException(500, FI_Response(code=500, answer=str(e), question=prompt))
+    else:
+        return HTTPException(400, FI_Response(code=400, answer="Request argument 'prompt' not provided."))
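Two things a reviewer might flag in the new /FI handler. First, FIllm(...) returns a completion dict rather than a string, so FI_Response(answer=result, question=prompt) will likely fail validation of the answer: str field. Unwrapping the generated text first would look roughly like this (a sketch, not part of the commit; _unwrap_answer is a hypothetical helper):

def _unwrap_answer(completion: dict) -> str:
    # Hypothetical helper: pull the generated string out of a llama-cpp completion dict
    return completion["choices"][0]["text"].strip()

# inside ask_gemmaFinanceTH, instead of FI_Response(answer=result, ...):
#     return FI_Response(answer=_unwrap_answer(result), question=prompt)

Second, both handlers return HTTPException instead of raising it; FastAPI only converts a raised HTTPException into an error response, so as written the 400/500 branches will be serialized through the response model rather than producing the intended status codes.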
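For anyone trying the new endpoint, a hypothetical client call (assuming the app is served locally on port 8000; the path, the query parameters, and the embedded prompt field all come from the signature above):

import requests

r = requests.post(
    "http://localhost:8000/FI",
    # temperature and max_new_tokens are plain parameters, so FastAPI reads them from the query string
    params={"temperature": 0.5, "max_new_tokens": 200},
    # Body(..., embed=True) wraps the prompt in a JSON object under a "prompt" key
    json={"prompt": "What's the best way to invest my money"},
)
print(r.json())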