cutechicken committed
Commit dfe75ef · verified · 1 Parent(s): ec72da8

Update app.py

Files changed (1):
  1. app.py  +35 -30
app.py CHANGED
@@ -11,67 +11,72 @@ from transformers import AutoTokenizer, AutoModelForCausalLM
 HF_TOKEN = os.getenv("HF_TOKEN")
 MODEL_ID = "CohereForAI/c4ai-command-r7b-12-2024"
 
+import spaces
+
 class ModelManager:
     def __init__(self):
         self.tokenizer = None
         self.model = None
         self.setup_model()
 
+    @spaces.GPU
     def setup_model(self):
         try:
             print("Starting tokenizer load...")
-            self.tokenizer = AutoTokenizer.from_pretrained(
-                MODEL_ID,
-                token=HF_TOKEN,
-                use_fast=True
-            )
+            self.tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
             print("Tokenizer loaded")
 
             print("Starting model load...")
-            # ZERO GPU configuration
             self.model = AutoModelForCausalLM.from_pretrained(
                 MODEL_ID,
-                token=HF_TOKEN,
-                torch_dtype=torch.float16,
-                device_map="balanced",     # balanced placement for ZERO GPU
-                max_memory={0: "8GiB"},    # ZERO GPU memory cap
-                offload_folder="offload",  # offload directory
-                low_cpu_mem_usage=True
+                torch_dtype=torch.bfloat16,
+                device_map="auto"
             )
             print("Model loaded")
         except Exception as e:
             print(f"Error while loading model: {e}")
             raise Exception(f"Model loading failed: {e}")
 
+    @spaces.GPU
     def generate_response(self, messages, max_tokens=4000, temperature=0.7, top_p=0.9):
         try:
+            conversation = []
+            for msg in messages:
+                conversation.append({"role": msg["role"], "content": msg["content"]})
+
             input_ids = self.tokenizer.apply_chat_template(
-                messages,
-                tokenize=True,
-                add_generation_prompt=True,
-                return_tensors="pt"
-            ).to(self.model.device)
-
-            # generation settings optimized for ZERO GPU
-            gen_tokens = self.model.generate(
-                input_ids,
+                conversation,
+                tokenize=False,
+                add_generation_prompt=True
+            )
+            inputs = self.tokenizer(input_ids, return_tensors="pt").to(0)
+
+            streamer = TextIteratorStreamer(
+                self.tokenizer,
+                timeout=10.,
+                skip_prompt=True,
+                skip_special_tokens=True
+            )
+
+            generate_kwargs = dict(
+                **inputs,
+                streamer=streamer,
                 max_new_tokens=max_tokens,
                 do_sample=True,
                 temperature=temperature,
                 top_p=top_p,
-                pad_token_id=self.tokenizer.eos_token_id,
-                use_cache=True,  # use KV cache for memory efficiency
-                num_beams=1      # disable beam search to save memory
+                eos_token_id=[255001]
             )
-
-            response_text = self.tokenizer.decode(gen_tokens[0][input_ids.shape[1]:], skip_special_tokens=True)
 
-            # word-by-word streaming
-            words = response_text.split()
-            for word in words:
+            thread = Thread(target=self.model.generate, kwargs=generate_kwargs)
+            thread.start()
+
+            buffer = ""
+            for new_text in streamer:
+                buffer += new_text
                 yield type('Response', (), {
                     'choices': [type('Choice', (), {
-                        'delta': {'content': word + " "}
+                        'delta': {'content': new_text}
                     })()]
                 })()
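
For context on the setup_model change: ZeroGPU Spaces have no resident GPU, so the removed device_map="balanced" / max_memory / offload_folder combination has nothing stable to pin memory against; instead, functions decorated with @spaces.GPU get a GPU attached only while they run. A minimal sketch of that pattern, assuming a Gradio Space with the spaces package available (the model ID is the one from this diff):

import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_ID = "CohereForAI/c4ai-command-r7b-12-2024"

# Load once at startup; weight placement is handled automatically.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

@spaces.GPU  # a GPU is attached only for the duration of this call
def generate(prompt: str) -> str:
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    output = model.generate(**inputs, max_new_tokens=256)
    # Decode only the newly generated tokens, not the echoed prompt.
    new_tokens = output[0][inputs["input_ids"].shape[1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)

gr.Interface(fn=generate, inputs="text", outputs="text").launch()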
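
The rewritten generate_response uses the standard transformers streaming recipe: model.generate blocks until completion, so it runs in a background Thread while TextIteratorStreamer hands decoded text to the caller as tokens arrive; this replaces the old loop that split the finished response into words and only simulated streaming. A self-contained sketch of the recipe, with gpt2 standing in as a small placeholder model:

from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

model_id = "gpt2"  # placeholder; the Space uses CohereForAI/c4ai-command-r7b-12-2024
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

inputs = tokenizer("The quick brown fox", return_tensors="pt")

# skip_prompt keeps the echoed prompt out of the stream; the timeout
# raises an error if no token arrives in time (e.g. generation crashed).
streamer = TextIteratorStreamer(
    tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
)

thread = Thread(
    target=model.generate,
    kwargs=dict(**inputs, streamer=streamer, max_new_tokens=64),
)
thread.start()

# The streamer is an iterator that yields text chunks as they decode.
for new_text in streamer:
    print(new_text, end="", flush=True)
thread.join()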
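
One brittle detail in the new code is the hard-coded eos_token_id=[255001], which presumably is the id of the Command R family's <|END_OF_TURN_TOKEN|> special token. A safer variant, sketched here as a hypothetical hardening rather than part of the commit, would resolve the id from the tokenizer:

# Hypothetical: look the end-of-turn id up instead of hard-coding it.
end_of_turn_id = tokenizer.convert_tokens_to_ids("<|END_OF_TURN_TOKEN|>")
generate_kwargs["eos_token_id"] = [end_of_turn_id]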