Spaces:
Sleeping
Sleeping
import json | |
from typing import List | |
import torch | |
from fastapi import FastAPI, Request, status, HTTPException | |
from pydantic import BaseModel | |
from torch.cuda import get_device_properties | |
from transformers import AutoModel, AutoTokenizer | |
from sse_starlette.sse import EventSourceResponse | |
from fastapi.middleware.cors import CORSMiddleware | |
import uvicorn | |
import os | |
# NOTE(review): this assignment runs after `transformers` has already been
# imported at the top of the file. transformers is believed to read
# TRANSFORMERS_CACHE when it is imported, so this line likely has no effect on
# its default cache. init() passes cache_dir explicitly anyway — confirm
# before relying on this variable.
os.environ['TRANSFORMERS_CACHE'] = ".cache"
# Quantization bit width passed to model.quantize(bits=...) in init().
bits = 4
# NOTE(review): not referenced anywhere in the visible code — possibly intended
# for the model's quantization kernel; confirm and wire it up or remove it.
kernel_path = "models/models--silver--chatglm-6b-int4-slim/quantization_kernels.so"
# Local snapshot directory that init() loads the tokenizer and model from.
model_path = "./models/models--silver--chatglm-6b-int4-slim/snapshots/02e096b3805c579caf5741a6d8eddd5ba7a74e0d"
# Passed as cache_dir to both from_pretrained calls in init().
cache_dir = './models'
# The only value of Body.model that completions() accepts; others get HTTP 400.
model_name = 'chatglm-6b-int4'
# GPU is used only if CUDA device 0's total memory, in GiB, is strictly
# greater than this value; otherwise init() falls back to CPU.
min_memory = 5.5
# Populated by init(); both remain None until it runs.
tokenizer = None
model = None
app = FastAPI()

# Accept cross-origin requests from any origin, with any method and any header.
# NOTE(review): FastAPI's CORS documentation says wildcard origins/methods/
# headers should not be combined with allow_credentials=True. Nothing visible
# in this file reads cookies, so credentials look unnecessary — confirm
# before tightening this.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)
def init():
    """Load the tokenizer and model into the module-level globals.

    The model runs on GPU (half precision, quantized to ``bits``, moved with
    ``.cuda()``) when CUDA is available and device 0 reports more than
    ``min_memory`` GiB of total memory. Otherwise it runs on CPU (float
    precision, quantized to ``bits``). The model is then put in eval mode.
    If the ``ngrok_token`` environment variable is set, a tunnel is also
    opened via ``ngrok_connect()``.

    NOTE(review): nothing in the visible code calls init(). Presumably it is
    registered as a startup hook elsewhere — confirm, otherwise ``model``
    stays None.
    """
    global tokenizer, model
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, cache_dir=cache_dir)
    model = AutoModel.from_pretrained(model_path, trust_remote_code=True, cache_dir=cache_dir)

    # Probe device 0 once; None means no CUDA device is available.
    gpu_total_gib = (
        get_device_properties(0).total_memory / 1024 ** 3
        if torch.cuda.is_available()
        else None
    )
    gpu_is_usable = gpu_total_gib is not None and gpu_total_gib > min_memory

    if gpu_is_usable:
        model = model.half().quantize(bits=bits).cuda()
        print("Using GPU")
    else:
        model = model.float().quantize(bits=bits)
        if gpu_total_gib is None:
            print("No GPU available")
        else:
            # A GPU exists but is too small; report its size for diagnosis.
            print("Total Memory: ", gpu_total_gib)
        print("Using CPU")

    model = model.eval()

    # Environment values are always strings, so membership is equivalent to
    # checking that .get() returned something other than None.
    if "ngrok_token" in os.environ:
        ngrok_connect()
class Message(BaseModel):
    """A single chat message inside a request's ``messages`` list."""

    # Sender role. completions() treats 'user' messages as questions and
    # 'system'/'assistant' messages as answers; other roles are ignored.
    role: str
    # The message text.
    content: str
class Body(BaseModel):
    """Request body accepted by completions().

    NOTE(review): the field names appear modeled on an OpenAI-style
    chat-completions request — confirm against the clients in use.
    """

    # Conversation so far; completions() requires the last entry to have
    # role 'user'.
    messages: List[Message]
    # Must equal the module-level model_name, or completions() returns HTTP 400.
    model: str
    # Must be True; non-streaming requests are rejected with HTTP 400.
    stream: bool
    # completions() passes max(2048, max_tokens) as max_length to stream_chat.
    max_tokens: int
def read_root():
    """Return a constant greeting payload.

    NOTE(review): no route decorator is visible here; presumably this is
    served at "/" as a simple liveness check — confirm.
    """
    return dict(Hello="World!")
async def completions(body: Body, request: Request):
    """Stream a chat reply to the last user message as Server-Sent Events.

    Only streaming requests whose ``model`` equals ``model_name`` are
    supported. The final message must come from the user. All earlier
    assistant/system messages are turned into ChatGLM-style
    ``(question, answer)`` history pairs by ``_build_history``.

    Each event carries ``json.dumps({"response": item[0]})`` for every item
    yielded by ``model.stream_chat``. A final ``"[DONE]"`` event follows.
    Streaming stops early if the client disconnects.

    NOTE(review): no route decorator is visible here; presumably this is
    registered as a POST endpoint elsewhere — confirm.

    Raises:
        HTTPException(400): "Not Implemented" for non-streaming requests or
            an unknown model. "No Question Found" when ``messages`` is empty
            or its last entry is not a user message.
    """
    if not body.stream or body.model != model_name:
        raise HTTPException(status.HTTP_400_BAD_REQUEST, "Not Implemented")
    # An empty list used to hit body.messages[-1] and surface as an
    # IndexError (HTTP 500); report it as a bad request instead.
    if not body.messages or body.messages[-1].role != 'user':
        raise HTTPException(status.HTTP_400_BAD_REQUEST, "No Question Found")
    question = body.messages[-1].content
    history = _build_history(body.messages)

    async def event_generator():
        # NOTE(review): max() means max_length is never below 2048, even when
        # the client asks for fewer tokens — confirm this is intended.
        # stream_chat is a plain (synchronous) iterator, so each step runs on
        # the event-loop thread and blocks other requests while it computes.
        for response in model.stream_chat(tokenizer, question, history, max_length=max(2048, body.max_tokens)):
            if await request.is_disconnected():
                return
            yield json.dumps({"response": response[0]})
        yield "[DONE]"

    return EventSourceResponse(event_generator())


def _build_history(messages):
    """Pair each system/assistant message with the latest preceding user message.

    Returns a list of ``(user_question, answer)`` tuples. The user question
    is '' if no user message has appeared yet, e.g. for a leading system
    prompt. Messages with any other role are skipped. The trailing user
    message updates the pending question but produces no pair, so it is not
    duplicated into the history.
    """
    history = []
    user_question = ''
    for message in messages:
        if message.role == 'user':
            user_question = message.content
        elif message.role in ('system', 'assistant'):
            history.append((user_question, message.content))
    return history
def ngrok_connect(port=8000):
    """Open an ngrok HTTP tunnel to a local port and print its public URL.

    Uses the ngrok binary at ``./ngrok`` and authenticates with the
    ``ngrok_token`` environment variable.

    Args:
        port: Local port to expose. It defaults to 8000, the previously
            hard-coded value. Presumably this matches uvicorn's default port,
            since the ``__main__`` block passes none — confirm if that changes.

    Returns:
        The tunnel's public URL. It is also printed.

    Raises:
        KeyError: if the ``ngrok_token`` environment variable is not set.
    """
    # Imported lazily: init() only calls this when ngrok_token is set, so
    # pyngrok is not needed otherwise.
    from pyngrok import ngrok, conf
    conf.set_default(conf.PyngrokConfig(ngrok_path="./ngrok"))
    ngrok.set_auth_token(os.environ["ngrok_token"])
    http_tunnel = ngrok.connect(port)
    print(http_tunnel.public_url)
    return http_tunnel.public_url
if __name__ == "__main__":
    # With reload=True uvicorn needs an import string. Derive the module name
    # from this file instead of hard-coding "main", so renaming the file
    # does not break the entry point. For main.py this is still "main:app".
    module_name = os.path.splitext(os.path.basename(__file__))[0]
    uvicorn.run(f"{module_name}:app", reload=True, app_dir=".")