#!/usr/bin/env python
# coding: utf-8
from enum import Enum
from typing import Optional

from fastapi import FastAPI, HTTPException, responses, Body
from fastapi.middleware.cors import CORSMiddleware
from llama_cpp import Llama
from pydantic import BaseModel

print("Loading model...")
SAllm = Llama(
    model_path="/models/final-gemma2b_SA-Q8_0.gguf",
    use_mmap=False,  # read the whole model into memory instead of memory-mapping it
    use_mlock=True,  # lock the model in RAM so it cannot be swapped out
    # n_gpu_layers=28,  # Uncomment to use GPU acceleration
    # seed=1337,  # Uncomment to set a specific seed
    # n_ctx=2048,  # Uncomment to increase the context window
)
FIllm = Llama(
    model_path="/models/final-gemma7b_FI-Q8_0.gguf",
    use_mmap=False,
    use_mlock=True,
)

# Earlier, generic completion helper, kept for reference:
# def ask(question, max_new_tokens=200):
#     output = llm(
#         question,  # Prompt
#         max_tokens=max_new_tokens,  # Generate up to max_new_tokens tokens; set to None to generate up to the end of the context window
#         stop=["\n"],  # Stop generating just before the model would generate a new question
#         echo=False,  # Don't echo the prompt back in the output
#         temperature=0.0,
#     )
#     return output


def extract_restext(response):
    """Extract the generated text from a llama-cpp completion response."""
    return response["choices"][0]["text"].strip()


def ask_fi(question, max_new_tokens=200, temperature=0.5):
    """Ask the finance-tuned Gemma-7b model a question using its chat template."""
    prompt = f"###User: {question}\n###Assistant:"
    result = extract_restext(
        FIllm(
            prompt,
            max_tokens=max_new_tokens,
            temperature=temperature,
            stop=["###User:", "###Assistant:"],  # stop before the model starts a new turn
            echo=False,
        )
    )
    return result


def check_sentiment(text):
    """Classify the sentiment of a Thai tweet as positive, negative, or unknown."""
    prompt = (
        "Analyze the sentiment of the tweet enclosed in square brackets, "
        "determine if it is positive or negative, and return the answer as "
        f'the corresponding sentiment label "positive" or "negative" [{text}] ='
    )
    response = SAllm(prompt, max_tokens=3, stop=["\n"], echo=False, temperature=0.5)
    result = extract_restext(response)
    if "positive" in result:
        return "positive"
    elif "negative" in result:
        return "negative"
    else:
        return "unknown"


print("Testing model...")
assert "positive" in check_sentiment("ดอกไม้ร้านนี้สวยจัง")  # "The flowers at this shop are beautiful"
assert ask_fi("Hello! How are you today?")
print("Ready.")

app = FastAPI(
    title="GemmaSA_2b",
    description="A simple sentiment analysis API for the Thai language, powered by a finetuned version of Gemma-2b",
    version="1.0.0",
)

origins = ["*"]
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


class SA_Result(str, Enum):
    positive = "positive"
    negative = "negative"
    unknown = "unknown"


class SA_Response(BaseModel):
    code: int = 200
    text: Optional[str] = None
    result: Optional[SA_Result] = None


class FI_Response(BaseModel):
    code: int = 200
    question: Optional[str] = None
    answer: Optional[str] = None
    config: Optional[dict] = None
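# For illustration only (assumed values, not captured output): with the models
# above, a successful /SA response serializes roughly as
#   SA_Response(result=SA_Result.positive, text="I like eating fried chicken").dict()
#   # -> {"code": 200, "text": "I like eating fried chicken", "result": "positive"}
# and a successful /FI response as
#   # -> {"code": 200, "question": "...", "answer": "...",
#   #     "config": {"temperature": 0.5, "max_new_tokens": 200}}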
@app.get("/")
def docs():
    """Redirect the user from the main page to the docs."""
    return responses.RedirectResponse("./docs")


@app.get("/add/{a}/{b}")
def add(a: int, b: int):
    """Trivial test endpoint: returns the sum of two integers."""
    return a + b


@app.post("/SA")
def perform_sentiment_analysis(
    prompt: str = Body(..., embed=True, example="I like eating fried chicken")
) -> SA_Response:
    """Perform sentiment analysis using a finetuned version of Gemma-2b."""
    if prompt:
        try:
            print(f"Checking sentiment for {prompt}")
            result = check_sentiment(prompt)
            print(f"Result: {result}")
            return SA_Response(result=result, text=prompt)
        except Exception as e:
            raise HTTPException(500, detail=f"Sentiment analysis failed: {e}")
    else:
        raise HTTPException(400, detail="Request argument 'prompt' not provided.")


@app.post("/FI")
def ask_gemmaFinanceTH(
    prompt: str = Body(..., embed=True, example="What's the best way to invest my money"),
    temperature: float = Body(0.5, embed=True),
    max_new_tokens: int = Body(200, embed=True),
) -> FI_Response:
    """
    Ask a finetuned Gemma a finance-related question, just for fun.
    NOTICE: IT MAY PRODUCE RANDOM/INACCURATE ANSWERS.
    PLEASE SEEK PROFESSIONAL ADVICE BEFORE DOING ANYTHING SERIOUS.
    """
    if prompt:
        try:
            print(f'Asking FI with the question "{prompt}"')
            result = ask_fi(prompt, max_new_tokens=max_new_tokens, temperature=temperature)
            print(f"Result: {result}")
            return FI_Response(
                answer=result,
                question=prompt,
                config={"temperature": temperature, "max_new_tokens": max_new_tokens},
            )
        except Exception as e:
            raise HTTPException(500, detail=f"Generation failed: {e}")
    else:
        raise HTTPException(400, detail="Request argument 'prompt' not provided.")
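# --- Local development entry point (a minimal sketch, assuming uvicorn is
# installed; the host and port are arbitrary choices, not from the original) --
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)

# A client could then exercise the endpoints with `requests` (URL matches the
# assumed host/port above):
#
#   import requests
#   requests.post("http://localhost:8000/SA",
#                 json={"prompt": "I like eating fried chicken"}).json()
#   requests.post("http://localhost:8000/FI",
#                 json={"prompt": "What's the best way to invest my money",
#                       "temperature": 0.5, "max_new_tokens": 200}).json()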