# coding=utf-8
# Implements API for fine-tuned models.
# Usage: python api_demo.py --model_name_or_path path_to_model --checkpoint_dir path_to_checkpoint

# Request:
# curl http://127.0.0.1:8000 --header 'Content-Type: application/json' --data '{"prompt": "Hello there!", "history": []}'

# Response:
# {
#   "response": "'Hi there!'",
#   "history": "[('Hello there!', 'Hi there!')]",
#   "status": 200,
#   "time": "2000-00-00 00:00:00"
# }
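#
# Optional fields "max_new_tokens", "top_p" and "temperature" may also be sent
# in the request body to override the default generation arguments, e.g.:
# curl http://127.0.0.1:8000 --header 'Content-Type: application/json' --data '{"prompt": "Hello there!", "history": [], "temperature": 0.95}'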


import datetime

import torch
import uvicorn
from fastapi import FastAPI, Request

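# Template, load_pretrained, prepare_infer_args and get_logits_processor are
# helper utilities provided by this repository's local `utils` module.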
from utils import (
    Template,
    load_pretrained,
    prepare_infer_args,
    get_logits_processor
)


def torch_gc():
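    """Release cached, unused CUDA memory on every visible GPU."""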
    if torch.cuda.is_available():
        num_gpus = torch.cuda.device_count()
        for device_id in range(num_gpus):
            with torch.cuda.device(device_id):
                torch.cuda.empty_cache()
                torch.cuda.ipc_collect()


app = FastAPI()


@app.post("/")
async def create_item(request: Request):
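    """Generate a model response for the given prompt and chat history."""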
    global model, tokenizer, prompt_template, generating_args

    # Parse the request JSON
    json_post = await request.json()
    prompt = json_post.get("prompt")
    history = json_post.get("history", [])
    max_new_tokens = json_post.get("max_new_tokens", None)
    top_p = json_post.get("top_p", None)
    temperature = json_post.get("temperature", None)

    # Tokenize the input prompt
    input_ids = tokenizer([prompt_template.get_prompt(prompt, history)], return_tensors="pt")["input_ids"]
    input_ids = input_ids.to(model.device)

    # Generation arguments
    gen_kwargs = generating_args.to_dict()
    gen_kwargs["input_ids"] = input_ids
    gen_kwargs["logits_processor"] = get_logits_processor()
    gen_kwargs["max_new_tokens"] = max_new_tokens if max_new_tokens else gen_kwargs["max_new_tokens"]
    gen_kwargs["top_p"] = top_p if top_p else gen_kwargs["top_p"]
    gen_kwargs["temperature"] = temperature if temperature else gen_kwargs["temperature"]

    # Generate response
    with torch.no_grad():
        generation_output = model.generate(**gen_kwargs)
    outputs = generation_output.tolist()[0][len(input_ids[0]):]
    response = tokenizer.decode(outputs, skip_special_tokens=True)

    # Update history
    history = history + [(prompt, response)]

    # Prepare response
    now = datetime.datetime.now()
    time = now.strftime("%Y-%m-%d %H:%M:%S")
    answer = {
        "response": repr(response),
        "history": repr(history),
        "status": 200,
        "time": time
    }

    # Log and clean up
    log = "[" + time + "] " + "\", prompt:\"" + prompt + "\", response:\"" + repr(response) + "\""
    print(log)
    torch_gc()

    return answer


if __name__ == "__main__":
    model_args, data_args, finetuning_args, generating_args = prepare_infer_args()
    model, tokenizer = load_pretrained(model_args, finetuning_args)
    prompt_template = Template(data_args.prompt_template)

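    # A single uvicorn worker keeps exactly one copy of the model in GPU memory.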
    uvicorn.run(app, host='0.0.0.0', port=8000, workers=1)