nroggendorff committed
Commit 6196dea
Parent: c4a8e34

Update app.py

Files changed (1):
  1. app.py (+5, -3)

app.py CHANGED
@@ -1,8 +1,9 @@
 import gradio as gr
 from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
 import torch
+from spaces import GPU
 
-torch.set_default_device("cuda")
+GPU = lambda: GPU(duration=70)
 
 bnb_config = BitsAndBytesConfig(
     load_in_4bit=True,
@@ -11,11 +12,12 @@ bnb_config = BitsAndBytesConfig(
     bnb_4bit_compute_dtype=torch.bfloat16
 )
 
-model_id = "cognitivecomputations/dolphin-2.9.3-mistral-7B-32k"
+model_id = "cognitivecomputations/dolphin-2.5-mixtral-8x7b"
 
 tokenizer = AutoTokenizer.from_pretrained(model_id)
 model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=bnb_config)
 
+@GPU
 def predict(input_text, history):
     chat = []
     for item in history:
@@ -26,7 +28,7 @@ def predict(input_text, history):
 
     conv = tokenizer.apply_chat_template(chat, tokenize=False)
     inputs = tokenizer(conv, return_tensors="pt").to("cuda")
-    outputs = model.generate(**inputs, max_new_tokens=512)
+    outputs = model.generate(**inputs, max_new_tokens=2048)
 
     generated_text = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
     return generated_text.split("<|assistant|>")[-1]
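
For reference, a minimal sketch of the documented `spaces.GPU` decorator usage on a ZeroGPU Space, which passes the duration directly to the decorator instead of rebinding the imported `GPU` name as this commit does; the function body here is a hypothetical placeholder, and `duration=70` simply mirrors the value in the diff:

import spaces

@spaces.GPU(duration=70)  # request a ZeroGPU slot for up to 70 seconds per call
def predict(input_text, history):
    # placeholder body; real generation logic would go here
    return input_text

Used this way, the decorator factory is applied once at definition time and the decorated function is what Gradio calls on each request.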