from fastapi import FastAPI, WebSocket
from fastapi.responses import HTMLResponse
from fastapi import Form, Depends, HTTPException, status
from transformers import pipeline, set_seed, AutoConfig, AutoTokenizer, AutoModelForCausalLM
import torch
import os
import time
import re
import json

app = FastAPI()

# Minimal placeholder page for the WebSocket chat UI served at "/".
html = """
<!DOCTYPE html>
<html>
    <head>
        <title>Chat</title>
    </head>
    <body>
        <h1>WebSocket Chat</h1>
    </body>
</html>
"""


@app.get("/")
async def get():
    return HTMLResponse(html)


@app.get("/api/env")
async def env():
    # Render all environment variables as a simple HTML page, one per line.
    environment_variables = "<h3>Environment Variables</h3>"
    for name, value in os.environ.items():
        environment_variables += f"{name}: {value}<br>"
    return HTMLResponse(environment_variables)
" return HTMLResponse(environment_variables) @app.websocket("/api/ws") async def websocket_endpoint(websocket: WebSocket): await websocket.accept() while True: data = await websocket.receive_text() await websocket.send_text(f"Message text was: {data}") @app.post("/api/indochat/v1") async def indochat(**kwargs): return text_generate("indochat-tiny", kwargs) @app.post("/api/text-generator/v1") async def text_generate( model_name: str = Form(default="", description="The model name"), text: str = Form(default="", description="The Prompt"), decoding_method: str = Form(default="Sampling", description="Decoding method"), min_length: int = Form(default=50, description="Minimal length of the generated text"), max_length: int = Form(default=250, description="Maximal length of the generated text"), num_beams: int = Form(default=5, description="Beams number"), top_k: int = Form(default=30, description="The number of highest probability vocabulary tokens to keep " "for top-k-filtering"), top_p: float = Form(default=0.95, description="If set to float < 1, only the most probable tokens with " "probabilities that add up to top_p or higher are kept " "for generation"), temperature: float = Form(default=0.5, description="The Temperature of the softmax distribution"), penalty_alpha: float = Form(default=0.5, description="Penalty alpha"), repetition_penalty: float = Form(default=1.2, description="Repetition penalty"), seed: int = Form(default=-1, description="Random Seed"), max_time: float = Form(default=60.0, description="Maximal time in seconds to generate the text") ): if seed >= 0: set_seed(seed) if decoding_method == "Beam Search": do_sample = False penalty_alpha = 0 elif decoding_method == "Sampling": do_sample = True penalty_alpha = 0 num_beams = 1 else: do_sample = False num_beams = 1 if repetition_penalty == 0.0: min_penalty = 1.05 max_penalty = 1.5 repetition_penalty = max(min_penalty + (1.0 - temperature) * (max_penalty - min_penalty), 0.8) prompt = f"User: {text}\nAssistant: " input_ids = text_generator[model_name]["tokenizer"](prompt, return_tensors='pt').input_ids.to(0) text_generator[model_name]["model"].eval() print("Generating text...") print(f"max_length: {max_length}, do_sample: {do_sample}, top_k: {top_k}, top_p: {top_p}, " f"temperature: {temperature}, repetition_penalty: {repetition_penalty}, penalty_alpha: {penalty_alpha}") time_start = time.time() sample_outputs = text_generator[model_name]["model"].generate(input_ids, penalty_alpha=penalty_alpha, do_sample=do_sample, num_beams=num_beams, min_length=min_length, max_length=max_length, top_k=top_k, top_p=top_p, temperature=temperature, repetition_penalty=repetition_penalty, num_return_sequences=1, max_time=max_time ) result = text_generator[model_name]["tokenizer"].decode(sample_outputs[0], skip_special_tokens=True) time_end = time.time() time_diff = time_end - time_start print(f"result:\n{result}") generated_text = result[len(prompt)+1:] generated_text = generated_text[:generated_text.find("User:")] return {"generated_text": generated_text, "processing_time": time_diff} def get_text_generator(model_name: str, load_in_8bit: bool = False, device: str = "cpu"): hf_auth_token = os.getenv("HF_AUTH_TOKEN", False) print(f"hf_auth_token: {hf_auth_token}") print(f"Loading model with device: {device}...") tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=hf_auth_token) model = AutoModelForCausalLM.from_pretrained(model_name, pad_token_id=tokenizer.eos_token_id, load_in_8bit=load_in_8bit, device_map="auto", 
use_auth_token=hf_auth_token) # model.to(device) print("Model loaded") return model, tokenizer def get_config(): return json.load(open("config.json", "r")) config = get_config() device = "cuda" if torch.cuda.is_available() else "cpu" text_generator = {} for model_name in config["text-generator"]: model, tokenizer = get_text_generator(model_name=config["text-generator"][model_name]["name"], load_in_8bit=config["text-generator"][model_name]["load_in_8bit"], device=device) text_generator[model_name] = { "model": model, "tokenizer": tokenizer }
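
# Example "config.json" layout implied by the startup loop above. The "indochat-tiny"
# key matches what the /api/indochat/v1 wrapper expects; the Hugging Face model id is
# an illustrative assumption, not a shipped default:
#
#     {
#         "text-generator": {
#             "indochat-tiny": {
#                 "name": "cahya/indochat-tiny",
#                 "load_in_8bit": false
#             }
#         }
#     }
#
# A minimal client sketch for the generic endpoint, assuming the app is served with
# e.g. `uvicorn app:app` on localhost:8000 (module name and port are assumptions):
#
#     import requests
#
#     response = requests.post(
#         "http://localhost:8000/api/text-generator/v1",
#         data={"model_name": "indochat-tiny", "text": "Halo, apa kabar?"},
#     )
#     print(response.json()["generated_text"])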