cutechicken committed on
Commit
bacca03
•
1 Parent(s): 63b4531

Update app.py

Files changed (1)
  1. app.py +31 -37
app.py CHANGED
@@ -22,6 +22,7 @@ class ModelManager:
         if self.model is None or self.tokenizer is None:
             self.setup_model()

+    @spaces.GPU
     def setup_model(self):
         try:
             print("Starting tokenizer load...")
@@ -36,12 +37,11 @@ class ModelManager:
             print("Tokenizer loading complete")

             print("Starting model load...")
-            # Settings to avoid CUDA initialization
             self.model = AutoModelForCausalLM.from_pretrained(
                 MODEL_ID,
                 token=HF_TOKEN,
                 torch_dtype=torch.float16,
-                device_map=None,  # do not set device_map at load time
+                device_map="auto",
                 trust_remote_code=True,
                 low_cpu_mem_usage=True
             )
@@ -52,41 +52,11 @@ class ModelManager:
             raise Exception(f"Model loading failed: {e}")

     @spaces.GPU
-    def generate_text(self, prompt, max_tokens, temperature, top_p):
-        try:
-            # Set the device inside the GPU context
-            self.model = self.model.to("cuda")
-            input_ids = self.tokenizer.encode(
-                prompt,
-                return_tensors="pt",
-                add_special_tokens=True
-            ).to("cuda")
-
-            with torch.no_grad():
-                output_ids = self.model.generate(
-                    input_ids,
-                    max_new_tokens=max_tokens,
-                    do_sample=True,
-                    temperature=temperature,
-                    top_p=top_p,
-                    pad_token_id=self.tokenizer.pad_token_id,
-                    eos_token_id=self.tokenizer.eos_token_id,
-                    num_return_sequences=1
-                )
-
-            # Move back to the CPU
-            self.model = self.model.to("cpu")
-            return self.tokenizer.decode(
-                output_ids[0][input_ids.shape[1]:],
-                skip_special_tokens=True
-            )
-        except Exception as e:
-            if self.model.device.type == "cuda":
-                self.model = self.model.to("cpu")
-            raise Exception(f"Text generation failed: {e}")
-
     def generate_response(self, messages, max_tokens=4000, temperature=0.7, top_p=0.9):
         try:
+            # Make sure the model is loaded
+            self.ensure_model_loaded()
+
             # Prepare the input text
             prompt = ""
             for msg in messages:
@@ -100,8 +70,32 @@ class ModelManager:
                 prompt += f"Assistant: {content}\n"
             prompt += "Assistant: "

-            # Call the method wrapped with the spaces.GPU decorator
-            generated_text = self.generate_text(prompt, max_tokens, temperature, top_p)
+            # Tokenize
+            input_ids = self.tokenizer(
+                prompt,
+                return_tensors="pt",
+                padding=True,
+                truncation=True,
+                max_length=4096
+            ).input_ids
+
+            # Generate
+            outputs = self.model.generate(
+                input_ids,
+                max_new_tokens=max_tokens,
+                do_sample=True,
+                temperature=temperature,
+                top_p=top_p,
+                pad_token_id=self.tokenizer.pad_token_id,
+                eos_token_id=self.tokenizer.eos_token_id,
+                num_return_sequences=1
+            )
+
+            # Decode
+            generated_text = self.tokenizer.decode(
+                outputs[0][input_ids.shape[1]:],
+                skip_special_tokens=True
+            )

             # Stream word by word
             words = generated_text.split()
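
For context, a minimal standalone sketch of the pattern the updated app.py follows on a ZeroGPU Space: load the model once with device_map="auto", then tokenize, generate, and decode inside a function decorated with @spaces.GPU, which attaches the GPU only while that function runs. The MODEL_ID placeholder and the bare generate() wrapper are illustrative assumptions, not code from this Space.

import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_ID = "org/model-name"  # hypothetical placeholder; the real MODEL_ID is defined elsewhere in app.py

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float16,
    device_map="auto",        # let accelerate place the weights, as in the new setup_model
    low_cpu_mem_usage=True,
)

@spaces.GPU
def generate(prompt, max_tokens=512, temperature=0.7, top_p=0.9):
    # Tokenize and move the inputs to wherever the model was placed.
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
    outputs = model.generate(
        input_ids,
        max_new_tokens=max_tokens,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )
    # Return only the newly generated tokens, as generate_response does.
    return tokenizer.decode(outputs[0][input_ids.shape[1]:], skip_special_tokens=True)

Because @spaces.GPU supplies the GPU for the duration of the call, the manual model.to("cuda") / model.to("cpu") shuffling from the removed generate_text is no longer needed.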