File size: 3,274 Bytes
a686568
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import json
from typing import List

import torch
from fastapi import FastAPI, Request, status, HTTPException
from pydantic import BaseModel
from torch.cuda import get_device_properties
from transformers import AutoModel, AutoTokenizer
from sse_starlette.sse import EventSourceResponse
from fastapi.middleware.cors import CORSMiddleware
import uvicorn
from pyngrok import ngrok, conf

import os

# NOTE(review): transformers is already imported above, and it presumably reads
# TRANSFORMERS_CACHE at import time, so this assignment may have no effect.
# Loading in init() passes cache_dir explicitly, which does not depend on it.
os.environ['TRANSFORMERS_CACHE'] = ".cache"

# Bit width passed to model.quantize() in init().
bits = 4
# NOTE(review): appears unused in this file; perhaps meant for quantize(kernel_file=...) — confirm.
kernel_path = "models/models--silver--chatglm-6b-int4-slim/quantization_kernels.so"
# Local snapshot directory the tokenizer and model are loaded from.
model_path = "./models/models--silver--chatglm-6b-int4-slim/snapshots/02e096b3805c579caf5741a6d8eddd5ba7a74e0d"
# cache_dir passed to from_pretrained() for both tokenizer and model.
cache_dir = './models'
# Value the request body's `model` field must match in /chat/completions.
model_name = 'chatglm-6b-int4'
# Minimum total GPU memory (GiB) on device 0 required to run on the GPU.
min_memory = 5.5
# Populated by init() when the application starts.
tokenizer = None
model = None

app = FastAPI()

# Permissive CORS: every origin, method and header, with credentials allowed.
# NOTE(review): wildcard origins combined with allow_credentials=True lets any
# site issue credentialed requests to this API — confirm that is intended.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.on_event('startup')
def init():
    """Load the tokenizer and model once, when the FastAPI app starts.

    The model runs in half precision on CUDA device 0 when CUDA is available
    and that device reports more than ``min_memory`` GiB of total memory.
    Otherwise it is kept in float32 without being moved to CUDA. In both cases
    it is quantized to ``bits`` bits and switched to eval mode. The results are
    stored in the module-level ``tokenizer`` and ``model`` globals.
    """
    global tokenizer, model
    load_kwargs = dict(trust_remote_code=True, cache_dir=cache_dir)
    tokenizer = AutoTokenizer.from_pretrained(model_path, **load_kwargs)
    model = AutoModel.from_pretrained(model_path, **load_kwargs)

    cuda_ok = torch.cuda.is_available()
    # Only query device properties when CUDA exists; the value is in GiB.
    gpu_gib = get_device_properties(0).total_memory / 1024 ** 3 if cuda_ok else None

    if cuda_ok and gpu_gib > min_memory:
        model = model.half().quantize(bits=bits).cuda()
        print("Using GPU")
    else:
        model = model.float().quantize(bits=bits)
        if cuda_ok:
            # A GPU exists but is too small; report how much memory it has.
            print("Total Memory: ", gpu_gib)
        else:
            print("No GPU available")
        print("Using CPU")
    model = model.eval()


class Message(BaseModel):
    """One chat message from the request's ``messages`` list."""
    # Sender role; completions() acts on 'user', 'system' and 'assistant'.
    role: str
    # Message text.
    content: str


class Body(BaseModel):
    """Request body accepted by ``POST /chat/completions``."""
    # Conversation so far; the last entry must be the user's question.
    messages: List[Message]
    # Must equal model_name, otherwise the request is rejected with 400.
    model: str
    # Only streaming requests are served; False is rejected with 400.
    stream: bool
    # Generation budget; completions() uses max(2048, max_tokens) as ChatGLM's
    # max_length. Optional (as in the OpenAI API) so clients that omit it are
    # accepted; the default of 2048 yields the same max_length as any value
    # <= 2048 would.
    max_tokens: int = 2048


@app.get("/")
def read_root():
    """Return the fixed ``{"Hello": "World!"}`` payload at the service root."""
    payload = dict(Hello="World!")
    return payload


@app.post("/chat/completions")
async def completions(body: Body, request: Request):
    """Stream a ChatGLM answer to a chat request as server-sent events.

    Only streaming requests that name ``model_name`` are served. The last
    message must have role ``'user'``; its content is the prompt. Every
    ``'system'`` or ``'assistant'`` message is paired with the most recent
    ``'user'`` message before it (``''`` if there is none). These pairs form the
    ``(query, response)`` history passed to ``model.stream_chat``.

    Each SSE event's data is ``{"response": ...}`` JSON built from the first
    element of each item ``stream_chat`` yields. The stream ends with a
    ``[DONE]`` event, or stops early once the client disconnects.

    Raises:
        HTTPException: 400 ``"Not Implemented"`` for non-streaming requests or
            a different model name. 400 ``"No Question Found"`` when the
            message list is empty or does not end with a user message.
    """
    if not body.stream or body.model != model_name:
        raise HTTPException(status.HTTP_400_BAD_REQUEST, "Not Implemented")

    # Guard the empty list explicitly. Indexing messages[-1] on [] used to raise
    # IndexError, which surfaced as an HTTP 500 instead of a client error.
    if not body.messages or body.messages[-1].role != 'user':
        raise HTTPException(status.HTTP_400_BAD_REQUEST, "No Question Found")
    question = body.messages[-1].content

    # The trailing user message never gets an answer paired with it, so it is
    # not added to history; it is sent separately as `question`.
    user_question = ''
    history = []
    for message in body.messages:
        if message.role == 'user':
            user_question = message.content
        elif message.role in ('system', 'assistant'):
            history.append((user_question, message.content))

    # max_length is never set below 2048, whatever max_tokens the client sends.
    max_length = max(2048, body.max_tokens)

    async def event_generator():
        # NOTE(review): stream_chat is a synchronous generator iterated inside
        # this coroutine. Each generation step therefore runs on the event-loop
        # thread and stalls other requests while it computes. Consider
        # offloading it to a worker if concurrent clients matter.
        for response in model.stream_chat(tokenizer, question, history, max_length=max_length):
            if await request.is_disconnected():
                return
            yield json.dumps({"response": response[0]})
        yield "[DONE]"

    return EventSourceResponse(event_generator())


def ngrok_connect(port: int = 8000) -> str:
    """Open a public ngrok HTTP tunnel to a local port and print its URL.

    Uses the ngrok binary at ``./ngrok`` and the auth token from the
    ``ngrok_token`` environment variable.

    Args:
        port: Local port to expose. Defaults to 8000, uvicorn's default port,
            which is what the ``__main__`` block serves on.

    Returns:
        The tunnel's public URL, which is also printed to stdout.

    Raises:
        KeyError: if the ``ngrok_token`` environment variable is not set.
    """
    conf.set_default(conf.PyngrokConfig(ngrok_path="./ngrok"))
    ngrok.set_auth_token(os.environ["ngrok_token"])
    http_tunnel = ngrok.connect(port)
    print(http_tunnel.public_url)
    return http_tunnel.public_url


if __name__ == "__main__":
    # Expose the local port through ngrok first, then start uvicorn with
    # auto-reload. With reload enabled uvicorn needs an import string, and
    # "main:app" plus app_dir="." means this file must be importable as the
    # module `main` from the working directory.
    ngrok_connect()
    uvicorn.run("main:app", app_dir=".", reload=True)