YaTharThShaRma999 committed
Commit 0dc25c1
1 Parent(s): 99cf705

Upload message.txt

Files changed (1):
  message.txt +132 -0
message.txt ADDED
@@ -0,0 +1,132 @@
+ """
+ A model worker executes the model.
+ """
+ import argparse
+ import json
+ import uuid
+
+ from fastapi import FastAPI, Request
+ from fastapi.responses import StreamingResponse
+ from transformers import AutoModel, AutoTokenizer
+ import torch
+ import uvicorn
+ import bitsandbytes as bnb
+ from transformers import BitsAndBytesConfig
+
+ from transformers.generation.streamers import BaseStreamer
+ from threading import Thread
+ from queue import Queue
+
+
+ class TokenStreamer(BaseStreamer):
+     def __init__(self, skip_prompt: bool = False, timeout=None):
+         self.skip_prompt = skip_prompt
+
+         # variables used in the streaming process
+         self.token_queue = Queue()
+         self.stop_signal = None
+         self.next_tokens_are_prompt = True
+         self.timeout = timeout
+
+     def put(self, value):
+         if len(value.shape) > 1 and value.shape[0] > 1:
+             raise ValueError("TokenStreamer only supports batch size 1")
+         elif len(value.shape) > 1:
+             value = value[0]
+
+         if self.skip_prompt and self.next_tokens_are_prompt:
+             self.next_tokens_are_prompt = False  # skip prompt tokens; stream only new ones
+             return
+
+         for token in value.tolist():
+             self.token_queue.put(token)
+
+     def end(self):
+         self.token_queue.put(self.stop_signal)
+
+     def __iter__(self):
+         return self
+
+     def __next__(self):
+         value = self.token_queue.get(timeout=self.timeout)
+         if value == self.stop_signal:
+             raise StopIteration()
+         else:
+             return value
+
+
+ class ModelWorker:
+     def __init__(self, model_path, device='cuda'):
+         self.device = device
+
+         # Configure 4-bit quantization
+         quantization_config = BitsAndBytesConfig(
+             load_in_4bit=True,
+             bnb_4bit_compute_dtype=torch.float16,
+             bnb_4bit_quant_type="nf4",
+             bnb_4bit_use_double_quant=True
+         )
+
+         self.glm_model = AutoModel.from_pretrained(
+             model_path,
+             trust_remote_code=True,
+             device_map=device,  # device_map handles placement of the quantized weights
+             quantization_config=quantization_config
+         ).eval()  # no .to(device): bitsandbytes-quantized models must not be moved after loading
+
+         self.glm_tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+
+     @torch.inference_mode()
+     def generate_stream(self, params):
+         tokenizer, model = self.glm_tokenizer, self.glm_model
+
+         prompt = params["prompt"]
+
+         temperature = float(params.get("temperature", 1.0))
+         top_p = float(params.get("top_p", 1.0))
+         max_new_tokens = int(params.get("max_new_tokens", 256))
+
+         inputs = tokenizer([prompt], return_tensors="pt")
+         inputs = inputs.to(self.device)
+         streamer = TokenStreamer(skip_prompt=True)
+         thread = Thread(target=model.generate,
+                         kwargs=dict(**inputs, max_new_tokens=max_new_tokens,
+                                     do_sample=True, temperature=temperature, top_p=top_p,
+                                     streamer=streamer))
+         thread.start()
+         for token_id in streamer:
+             yield (json.dumps({"token_id": token_id, "error_code": 0}) + "\n").encode()
+
+     def generate_stream_gate(self, params):
+         try:
+             for x in self.generate_stream(params):
+                 yield x
+         except Exception as e:
+             print("Caught unknown error:", e)
+             ret = {
+                 "text": "Server Error",
+                 "error_code": 1,
+             }
+             yield (json.dumps(ret) + "\n").encode()
+
+
+ app = FastAPI()
+
+
+ @app.post("/generate_stream")
+ async def generate_stream(request: Request):
+     params = await request.json()
+
+     generator = worker.generate_stream_gate(params)
+     return StreamingResponse(generator, media_type="application/x-ndjson")
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--host", type=str, default="localhost")
+     parser.add_argument("--port", type=int, default=10000)
+     parser.add_argument("--model-path", type=str, default="THUDM/glm-4-voice-9b")
+     args = parser.parse_args()
+
+     worker = ModelWorker(args.model_path)
+     uvicorn.run(app, host=args.host, port=args.port, log_level="info")
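
For reference, a minimal client sketch for this worker (not part of the commit). Assumptions: the server runs with the argparse defaults above, so the endpoint is http://localhost:10000/generate_stream, and the literal prompt string below is a placeholder, since glm-4-voice-9b expects its own prompt template (see the model card). The worker streams newline-delimited JSON objects, each carrying one token_id, which a client can collect and decode with the same tokenizer:

import json

import requests
from transformers import AutoTokenizer

# Same tokenizer the worker loads, so decoded ids round-trip correctly.
tokenizer = AutoTokenizer.from_pretrained("THUDM/glm-4-voice-9b", trust_remote_code=True)

payload = {
    "prompt": "<prompt in the model's expected template>",  # placeholder, not a real template
    "temperature": 0.8,
    "top_p": 0.9,
    "max_new_tokens": 64,
}
resp = requests.post("http://localhost:10000/generate_stream", json=payload, stream=True)

token_ids = []
for line in resp.iter_lines():  # one JSON object per line
    if not line:
        continue
    chunk = json.loads(line)
    if chunk.get("error_code", 0) != 0:
        raise RuntimeError(chunk.get("text", "stream error"))
    token_ids.append(chunk["token_id"])

print(tokenizer.decode(token_ids))

Start the worker first (python message.txt --model-path THUDM/glm-4-voice-9b), then run the client. Each token_id is yielded as soon as generate produces it, so a client could just as well decode incrementally instead of waiting for the full stream.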