#!/usr/bin/env python
# coding: utf-8
from os import listdir
from os.path import isdir
from fastapi import FastAPI, HTTPException, Request, responses, Body
from fastapi.middleware.cors import CORSMiddleware
from llama_cpp import Llama
from pydantic import BaseModel
from enum import Enum
from typing import Optional
print("Loading model...")
SAllm = Llama(
    model_path="/models/final-gemma2b_SA-Q8_0.gguf",
    mmap=False,
    mlock=True,
    # n_gpu_layers=28,  # Uncomment to use GPU acceleration
    # seed=1337,        # Uncomment to set a specific seed
    # n_ctx=2048,       # Uncomment to increase the context window
)
FIllm = Llama(model_path="/models/final-gemma7b_FI-Q8_0.gguf", mmap=False, mlock=True)
# def ask(question, max_new_tokens=200):
# output = llm(
# question, # Prompt
# max_tokens=max_new_tokens, # Generate up to 32 tokens, set to None to generate up to the end of the context window
# stop=["\n"], # Stop generating just before the model would generate a new question
# echo=False, # Echo the prompt back in the output
# temperature=0.0,
# )
# return output
def extract_restext(response):
    """Pull the generated text out of a llama-cpp completion response."""
    return response['choices'][0]['text'].strip()
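# For reference, llama-cpp-python completion calls return a dict roughly shaped
# like the abridged, assumed-representative example below, which is why the
# helper above indexes response['choices'][0]['text']:
#
#   {
#       "id": "cmpl-...",
#       "object": "text_completion",
#       "choices": [{"index": 0, "text": " positive", "finish_reason": "stop"}],
#       "usage": {"prompt_tokens": 57, "completion_tokens": 2, "total_tokens": 59},
#   }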
def ask_fi(question, max_new_tokens=200, temperature=0.5):
    """Ask the finance-tuned Gemma model a question and return its reply."""
    prompt = f"""###User: {question}\n###Assistant:"""
    result = extract_restext(FIllm(prompt, max_tokens=max_new_tokens, temperature=temperature, stop=["###User:", "###Assistant:"], echo=False))
    return result
def check_sentiment(text):
    """Classify the sentiment of `text` with the sentiment-tuned Gemma model."""
    prompt = f'Analyze the sentiment of the tweet enclosed in square brackets, determine if it is positive or negative, and return the answer as the corresponding sentiment label "positive" or "negative" [{text}] ='
    response = SAllm(prompt, max_tokens=3, stop=["\n"], echo=False, temperature=0.5)
    # print(response)
    result = extract_restext(response)
    if "positive" in result:
        return "positive"
    elif "negative" in result:
        return "negative"
    else:
        return "unknown"
print("Testing model...")
assert "positive" in check_sentiment("ดอกไม้ร้านนี้สวยจัง")
assert ask_fi("Hello!, How are you today?")
print("Ready.")
app = FastAPI(
    title="GemmaSA_2b",
    description="A simple sentiment analysis API for the Thai language, powered by a finetuned version of Gemma-2b",
    version="1.0.0",
)
origins = ["*"]
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
class SA_Result(str, Enum):
    positive = "positive"
    negative = "negative"
    unknown = "unknown"
class SA_Response(BaseModel):
    code: int = 200
    text: Optional[str] = None
    result: Optional[SA_Result] = None
class FI_Response(BaseModel):
    code: int = 200
    question: Optional[str] = None
    answer: Optional[str] = None
    config: Optional[dict] = None
@app.get('/')
def docs():
    """Redirects the user from the main page to the docs."""
    return responses.RedirectResponse('./docs')
@app.get('/add/{a}/{b}')
def add(a: int, b: int):
    """Add two integers (simple test endpoint)."""
    return a + b
@app.post('/SA')
def perform_sentiment_analysis(prompt: str = Body(..., embed=True, example="I like eating fried chicken")) -> SA_Response:
    """Performs sentiment analysis using a finetuned version of Gemma-2b."""
    if prompt:
        try:
            print(f"Checking sentiment for {prompt}")
            result = check_sentiment(prompt)
            print(f"Result: {result}")
            return SA_Response(result=result, text=prompt)
        except Exception as e:
            raise HTTPException(status_code=500, detail=str(e))
    else:
        raise HTTPException(status_code=400, detail="Request argument 'prompt' not provided.")
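# Example request for the /SA endpoint (illustrative only; because the parameter
# uses Body(..., embed=True), the prompt must be wrapped in a JSON object).
# The host and port are assumptions; use whatever address the server binds to:
#
#   curl -X POST http://localhost:8000/SA \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "I like eating fried chicken"}'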
@app.post('/FI')
def ask_gemmaFinanceTH(
    prompt: str = Body(..., embed=True, example="What's the best way to invest my money"),
    temperature: float = Body(0.5, embed=True),
    max_new_tokens: int = Body(200, embed=True)
) -> FI_Response:
    """
    Ask a finetuned Gemma a finance-related question, just for fun.
    NOTICE: IT MAY PRODUCE RANDOM/INACCURATE ANSWERS. PLEASE SEEK PROFESSIONAL ADVICE BEFORE DOING ANYTHING SERIOUS.
    """
    if prompt:
        try:
            print(f'Asking FI with the question "{prompt}"')
            result = ask_fi(prompt, max_new_tokens=max_new_tokens, temperature=temperature)
            print(f"Result: {result}")
            return FI_Response(answer=result, question=prompt, config={"temperature": temperature, "max_new_tokens": max_new_tokens})
        except Exception as e:
            raise HTTPException(status_code=500, detail=str(e))
    else:
        raise HTTPException(status_code=400, detail="Request argument 'prompt' not provided.")
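# Minimal sketch of a local entry point, assuming the Space otherwise launches
# the app via an external uvicorn command; the host and port below are
# illustrative defaults, not values taken from this repo's configuration.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)

# Example request for the /FI endpoint (illustrative; all three fields are
# embedded in the JSON body, and temperature/max_new_tokens are optional):
#
#   curl -X POST http://localhost:8000/FI \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "How should I start saving for retirement?", "temperature": 0.5, "max_new_tokens": 200}'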