vilarin committed
Commit e6367a7
1 Parent(s): b48b00e

Update app.py

Files changed (1)
  1. app.py +28 -38
app.py CHANGED
@@ -1,17 +1,29 @@
  import torch
- from PIL import Image
+ import copy
  import gradio as gr
  import spaces
- from transformers import AutoModelForCausalLM, GemmaTokenizerFast, TextIteratorStreamer,BitsAndBytesConfig
+ from llama_cpp import Llama
  import os
+ from huggingface_hub import hf_hub_download
  from threading import Thread
  
  
  HF_TOKEN = os.environ.get("HF_TOKEN", None)
  MODEL_ID = "google/gemma-2-27b-it"
- MODELS = os.environ.get("MODELS")
- MODEL_NAME = MODELS.split("/")[-1]
- MAX_INPUT_TOKEN_LENGTH = int(os.environ.get("MAX_INPUT_TOKEN_LENGTH", "4096"))
+ MODEL_NAME = MODEL_ID.split("/")[-1]
+ MODEL_FILE = "gemma-2-27b-it-Q4_K_M.gguf"
+ 
+ os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
+ 
+ llm = Llama(
+     model_path=hf_hub_download(
+         repo_id=os.environ.get(MODEL_ID),
+         filename=os.environ.get(MODEL_FILE),
+     ),
+     n_ctx=4096,
+     n_gpu_layers=-1,
+     chat_format="gemma",
+ )
  
  TITLE = "<h1><center>Chatbox</center></h1>"
  
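For context, a minimal standalone sketch of the GGUF loading path this hunk introduces; it is not the committed code. Two hedged assumptions: `os.environ.get(MODEL_ID)` and `os.environ.get(MODEL_FILE)` as committed look up environment variables literally named after the repo id and file name (returning `None` when unset), so the sketch passes the strings directly to `hf_hub_download`; and the repo id below is a hypothetical placeholder, since a `*.gguf` quantization normally lives in a separate GGUF repo rather than in `google/gemma-2-27b-it`.

```python
import os
from huggingface_hub import hf_hub_download
from llama_cpp import Llama

# Hypothetical GGUF repo; the commit itself points MODEL_ID at google/gemma-2-27b-it.
GGUF_REPO = "your-org/gemma-2-27b-it-GGUF"
MODEL_FILE = "gemma-2-27b-it-Q4_K_M.gguf"

# Optional: faster downloads when the hf_transfer package is installed.
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

# Download the quantized weights once and cache them locally.
model_path = hf_hub_download(
    repo_id=GGUF_REPO,                        # pass the string itself, not os.environ.get(...)
    filename=MODEL_FILE,
    token=os.environ.get("HF_TOKEN", None),   # needed for gated or private repos
)

# Load the model with llama-cpp-python, mirroring the committed settings.
llm = Llama(
    model_path=model_path,
    n_ctx=4096,          # context window, same size as the removed MAX_INPUT_TOKEN_LENGTH
    n_gpu_layers=-1,     # offload all layers to GPU when built with CUDA support
    chat_format="gemma", # use the Gemma chat template for create_chat_completion()
)
```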
@@ -36,15 +48,6 @@ h3 {
  text-align: center;
  }
  """
- if torch.cuda.is_available():
-     model = AutoModelForCausalLM.from_pretrained(
-         MODELS,
-         device_map="auto",
-         quantization_config=BitsAndBytesConfig(load_in_4bit=True)
-     )
-     tokenizer = GemmaTokenizerFast.from_pretrained(MODELS)
-     model.config.sliding_window = 4096
-     model.eval()
  
  
  @spaces.GPU(duration=90)
@@ -58,33 +61,20 @@ def stream_chat(message: str, history: list, temperature: float, max_new_tokens:
  
      print(f"Conversation is -\n{conversation}")
  
-     input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
-     if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
-         input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
-         gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
-     input_ids = input_ids.to(0)
- 
-     streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
- 
-     generate_kwargs = dict(
-         {"input_ids": input_ids},
-         streamer=streamer,
+     output = llm.create_chat_completion(
+         messages=conversation,
          top_k=top_k,
          top_p=top_p,
-         repetition_penalty=penalty,
-         max_new_tokens=max_new_tokens,
-         do_sample=True,
-         temperature=temperature,
-         num_beams=1,
+         repeat_penalty=penalty,
+         max_tokens=max_new_tokens,
+         stream =True,
+         temperature=temperature,
      )
  
-     thread = Thread(target=model.generate, kwargs=generate_kwargs)
-     thread.start()
- 
-     buffer = ""
-     for new_text in streamer:
-         buffer += new_text
-         yield buffer
+     for out in output:
+         stream = copy.deepcopy(out)
+         temp += stream["choices"][0]["text"]
+         yield temp
  
  
  
@@ -113,7 +103,7 @@ with gr.Blocks(css=CSS, theme="soft") as demo:
      maximum=2048,
      step=1,
      value=1024,
-     label="Max new tokens",
+     label="Max Tokens",
      render=False,
      ),
      gr.Slider(
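The streaming loop in the third hunk accumulates into `temp` without initializing it, and reads `stream["choices"][0]["text"]`; with `create_chat_completion(stream=True)`, llama-cpp-python yields OpenAI-style chunks whose incremental text normally sits under `choices[0]["delta"]["content"]` (the `"text"` key is what plain `create_completion` returns). Below is a hedged sketch of the consumer under those assumptions, reusing the `llm` instance from the loading sketch above and assuming Gradio-style `(user, assistant)` history tuples and the parameter order implied by the hunk.

```python
def stream_chat(message: str, history: list, temperature: float, max_new_tokens: int,
                top_p: float, top_k: int, penalty: float):
    # Rebuild the conversation as OpenAI-style message dicts, which is the
    # format create_chat_completion() expects.
    conversation = []
    for user, assistant in history:
        conversation.append({"role": "user", "content": user})
        conversation.append({"role": "assistant", "content": assistant})
    conversation.append({"role": "user", "content": message})

    output = llm.create_chat_completion(
        messages=conversation,
        top_k=top_k,
        top_p=top_p,
        repeat_penalty=penalty,
        max_tokens=max_new_tokens,
        temperature=temperature,
        stream=True,
    )

    buffer = ""  # initialize the accumulator (the committed hunk starts from an undefined `temp`)
    for chunk in output:
        # Streamed chat chunks use the OpenAI delta format; the first chunk may
        # carry only the role, so fall back to an empty string.
        delta = chunk["choices"][0].get("delta", {})
        buffer += delta.get("content", "")
        yield buffer
```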