cutechicken committed
Commit f4a0f87 • 1 Parent(s): 0b984ea

Update app.py

Files changed (1)
  1. app.py +54 -19
app.py CHANGED
@@ -12,34 +12,69 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
 HF_TOKEN = os.getenv("HF_TOKEN")
 MODEL_ID = "CohereForAI/c4ai-command-r7b-12-2024"
 
+from transformers import pipeline
+
 class ModelManager:
     def __init__(self):
-        self.model = None
-        self.tokenizer = None
-        self.setup_model()
+        self.pipe = None
+        self.setup_pipeline()
 
-    def setup_model(self):
+    def setup_pipeline(self):
         try:
-            print("Starting tokenizer loading...")
-            self.tokenizer = AutoTokenizer.from_pretrained(
-                MODEL_ID,
-                token=HF_TOKEN,
-                trust_remote_code=True
-            )
-            print("Tokenizer loading complete")
-
-            print("Starting model loading...")
-            self.model = AutoModelForCausalLM.from_pretrained(
-                MODEL_ID,
+            print("Starting pipeline initialization...")
+            self.pipe = pipeline(
+                "text-generation",
+                model=MODEL_ID,
                 token=HF_TOKEN,
                 device_map="auto",
-                trust_remote_code=True,
                 torch_dtype=torch.float16
             )
-            print("Model loading complete")
+            print("Pipeline initialization complete")
+        except Exception as e:
+            print(f"Error during pipeline initialization: {e}")
+            raise Exception(f"Pipeline initialization failed: {e}")
+
+    def generate_response(self, messages, max_tokens=4000, temperature=0.7, top_p=0.9):
+        try:
+            # Convert the message list into a single prompt string
+            prompt = ""
+            for msg in messages:
+                role = msg["role"]
+                content = msg["content"]
+                if role == "system":
+                    prompt += f"System: {content}\n"
+                elif role == "user":
+                    prompt += f"User: {content}\n"
+                elif role == "assistant":
+                    prompt += f"Assistant: {content}\n"
+
+            # Generate the response
+            response = self.pipe(
+                prompt,
+                max_new_tokens=max_tokens,
+                temperature=temperature,
+                top_p=top_p,
+                do_sample=True,
+                num_return_sequences=1,
+                pad_token_id=self.pipe.tokenizer.eos_token_id
+            )
+
+            # Extract the response text and simulate streaming
+            generated_text = response[0]['generated_text'][len(prompt):].strip()
+            words = generated_text.split()
+
+            # Stream word by word
+            partial_response = ""
+            for word in words:
+                partial_response += word + " "
+                yield type('Response', (), {
+                    'choices': [type('Choice', (), {
+                        'delta': {'content': word + " "}
+                    })()]
+                })()
+
         except Exception as e:
-            print(f"Error occurred during model loading: {e}")
-            raise Exception(f"Model loading failed: {e}")
+            raise Exception(f"Response generation failed: {e}")
 
 class ChatHistory:
     def __init__(self):
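
Note on the new generate_response: each yielded object imitates one chunk of an OpenAI-style streaming response (chunk.choices[0].delta), so caller code written against a streaming chat client can consume it unchanged; the generation itself is not streamed, only replayed word by word (true token streaming would need something like transformers' TextIteratorStreamer). A minimal consumption sketch under that assumption (the messages and max_tokens values below are illustrative, not part of this commit):

manager = ModelManager()

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Summarize what a text-generation pipeline does."},
]

for chunk in manager.generate_response(messages, max_tokens=128):
    # 'delta' is a plain dict here (not an attribute object as in the real
    # streaming APIs), so the content is read with dict indexing.
    print(chunk.choices[0].delta["content"], end="", flush=True)

Also note the prompt is assembled with plain "System:"/"User:"/"Assistant:" prefixes rather than the model's chat template; that keeps the code simple but may not match the format the model was tuned on.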