File size: 4,407 Bytes
4926347
 
 
 
0c228e3
4926347
 
7afebb8
 
 
 
4926347
 
33d6214
4926347
 
 
33d6214
4926347
0c228e3
4926347
33d6214
 
 
 
 
 
 
 
 
4926347
befe899
 
 
33d6214
 
 
 
befe899
7afebb8
 
 
 
 
 
33d6214
7afebb8
4926347
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7afebb8
 
 
 
 
 
 
33d6214
7afebb8
 
33d6214
 
 
 
 
7afebb8
4926347
 
 
 
 
 
 
 
 
7afebb8
 
 
 
 
 
33d6214
7afebb8
 
 
 
 
 
33d6214
 
 
 
 
0c228e3
 
33d6214
 
 
 
 
 
 
 
0c228e3
befe899
33d6214
0c228e3
33d6214
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
#!/usr/bin/env python
# coding: utf-8
from os import listdir
from os.path import isdir
from fastapi import FastAPI, HTTPException, Request, responses, Body
from fastapi.middleware.cors import CORSMiddleware
from llama_cpp import Llama

from pydantic import BaseModel
from enum import Enum
from typing import Optional

print("Loading model...")
# Sentiment-analysis model. Optional llama-cpp settings are kept below as
# comments; move them inside the call to enable them.
SAllm = Llama(model_path="/models/final-gemma2b_SA-Q8_0.gguf")
#   n_gpu_layers=28,  # Uncomment to use GPU acceleration
#   seed=1337,        # Uncomment to set a specific seed
#   n_ctx=2048,       # Uncomment to increase the context window

# NOTE: the /FI endpoint below calls `FIllm`; it stays unavailable while this
# load is commented out.
# FIllm = Llama(model_path="/models/final-gemma2b_FI-Q8_0.gguf")

# def ask(question, max_new_tokens=200):
#   output = llm(
#     question, # Prompt
#     max_tokens=max_new_tokens, # Generate up to max_new_tokens tokens; set to None to generate up to the end of the context window
#     stop=["\n"], # Stop generating at the first newline
#     echo=False, # Do not echo the prompt back in the output
#     temperature=0.0,
#   )
#   return output

def extract_restext(response):
    """Return the whitespace-stripped text of the first choice in a completion response."""
    first_choice = response['choices'][0]
    return first_choice['text'].strip()

def check_sentiment(text, *, temperature=0.5):
    """Classify the sentiment of *text* with the SA model.

    Args:
        text: The text to classify (the prompt frames it as a tweet; the
            service targets Thai text).
        temperature: Sampling temperature forwarded to the model. Defaults to
            0.5, the previously hard-coded value; lower values give less
            random output.

    Returns:
        "positive", "negative", or "unknown" when the completion contains
        neither label.
    """
    prompt = f'Analyze the sentiment of the tweet enclosed in square brackets, determine if it is positive or negative, and return the answer as the corresponding sentiment label "positive" or  "negative" [{text}] ='
    # Only a short label is expected: cap at 3 tokens and stop at the first newline.
    response = SAllm(prompt, max_tokens=3, stop=["\n"], echo=False, temperature=temperature)
    # Match case-insensitively so a completion such as "Positive" is not
    # misreported as "unknown".
    completion = extract_restext(response).lower()
    if "positive" in completion:
        return "positive"
    if "negative" in completion:
        return "negative"
    return "unknown"

  
print("Testing model...")
# Smoke-test the model on a known-positive Thai sentence ("the flowers at this
# shop are so pretty"). Use an explicit check rather than `assert`, which is
# stripped when Python runs with -O.
# NOTE(review): check_sentiment samples at a non-zero temperature, so this
# check may occasionally fail even on a healthy model.
_self_test_result = check_sentiment("ดอกไม้ร้านนี้สวยจัง")
if "positive" not in _self_test_result:
    raise RuntimeError(
        f"Model self-test failed: expected 'positive', got {_self_test_result!r}"
    )
print("Ready.")

# Metadata shown in the generated OpenAPI docs.
_api_metadata = {
    "title": "GemmaSA_2b",
    "description": "A simple sentiment analysis API for the Thai language, powered by a finetuned version of Gemma-2b",
    "version": "1.0.0",
}
app = FastAPI(**_api_metadata)

# Accept cross-origin requests from any origin, with any method and header.
# NOTE(review): a wildcard origin together with allow_credentials=True is a
# risky CORS combination; no endpoint in this file reads cookies or auth
# headers, so allow_credentials=False is likely sufficient — confirm with the
# frontend before changing.
origins = ["*"]
_cors_options = {
    "allow_origins": origins,
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)

class SA_Result(str, Enum):
    """Closed set of sentiment labels reported by the /SA endpoint.

    Mixing in ``str`` makes every member compare equal to its plain-string value.
    """

    positive = "positive"
    negative = "negative"
    unknown = "unknown"

class SA_Response(BaseModel):
    """Response body of the /SA sentiment-analysis endpoint."""

    code: int = 200  # status code mirrored in the body (200 on success)
    text: Optional[str] = None  # the text that was analysed
    # Optional because the default is None: a bare `SA_Result` annotation
    # misdescribes the schema and rejects an explicit None under pydantic v2.
    result: Optional[SA_Result] = None

class FI_Response(BaseModel):
    """Response body of the /FI finance question-answering endpoint."""

    code: int = 200  # status code mirrored in the body (200 on success)
    question: Optional[str] = None  # the question as submitted by the caller
    # Optional because the default is None; a bare `str` annotation
    # misdescribes the schema and rejects an explicit None under pydantic v2.
    answer: Optional[str] = None
    config: Optional[dict] = None  # generation settings used for this answer

@app.get('/')
def docs():
    """Send visitors of the root URL on to the interactive API docs."""
    docs_location = './docs'
    return responses.RedirectResponse(docs_location)

@app.get('/add/{a}/{b}')
def add(a: int, b: int):
    """Return the sum of the two integer path parameters."""
    total = a + b
    return total

@app.post('/SA')
def perform_sentiment_analysis(prompt: str = Body(..., embed=True, example="I like eating fried chicken")) -> SA_Response:
    """Performs a sentiment analysis using a finetuned version of Gemma-2b.

    Responds 400 when `prompt` is empty and 500 (with the error message as
    `detail`) when the model call fails.
    """
    # Raise (not return) HTTPException so FastAPI sends a real error status;
    # `detail` is a plain string so it is JSON-serializable.
    if not prompt:
        raise HTTPException(status_code=400, detail="Request argument 'prompt' not provided.")
    print(f"Checking sentiment for {prompt}")
    try:
        result = check_sentiment(prompt)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    print(f"Result: {result}")
    return SA_Response(result=result, text=prompt)


@app.post('/FI')
def ask_gemmaFinanceTH(
    prompt: str = Body(..., embed=True, example="What's the best way to invest my money"),
    temperature: float = Body(0.5, embed=True),
    max_new_tokens: int = Body(200, embed=True),
) -> FI_Response:
    """
    Ask a finetuned Gemma a finance-related question, just for fun.
    NOTICE: IT MAY PRODUCE RANDOM/INACCURATE ANSWERS. PLEASE SEEK PROFESSIONAL ADVICE BEFORE DOING ANYTHING SERIOUS.
    """
    # Raise (not return) HTTPException so FastAPI sends a real error status.
    if not prompt:
        raise HTTPException(status_code=400, detail="Request argument 'prompt' not provided.")
    # The finance model is expected under the module-level name `FIllm`; when
    # its load is disabled, report that instead of failing with a NameError.
    fi_llm = globals().get("FIllm")
    if fi_llm is None:
        raise HTTPException(status_code=503, detail="The finance (FI) model is not loaded.")
    print(f'Asking FI with the question "{prompt}"')
    # Keep the caller's question separate from the templated model input so
    # the response echoes the question, not the chat template.
    templated_prompt = f"""###User: {prompt}\n###Assistant:"""
    try:
        completion = fi_llm(
            templated_prompt,
            max_tokens=max_new_tokens,
            temperature=temperature,
            stop=["###User:", "###Assistant:"],
            echo=False,
        )
        result = extract_restext(completion)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    print(f"Result: {result}")
    return FI_Response(
        answer=result,
        question=prompt,
        config={"temperature": temperature, "max_new_tokens": max_new_tokens},
    )