vilarin committed on
Commit
0278a97
1 Parent(s): 0fc329a

Update app.py

Files changed (1)
app.py (+18 -22)
app.py CHANGED
@@ -9,11 +9,9 @@ import torch
 from PIL import Image
 import gradio as gr
 import spaces
-from transformers import AutoModelForCausalLM, AutoProcessor, TextIteratorStreamer
+from transformers import AutoModel, AutoTokenizer, TextIteratorStreamer
 import os
 import time
-from huggingface_hub import hf_hub_download
-
 
 
 os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
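
The import swap above pairs with the loader change below: with trust_remote_code=True, AutoModel resolves to the modeling class shipped inside the checkpoint repo, which is what provides the non-standard .chat() method used later in this diff. A minimal sketch of that dispatch (the MiniCPM-V-style repo id is a stand-in assumption, not taken from this commit):

    from transformers import AutoModel

    # Hypothetical repo id; the app's real MODEL_ID is defined outside this diff.
    m = AutoModel.from_pretrained("openbmb/MiniCPM-Llama3-V-2_5", trust_remote_code=True)
    print(type(m).__name__)    # the repo's own modeling class, not a stock transformers class
    print(hasattr(m, "chat"))  # True for this model family; the model.chat() call below relies on it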
@@ -24,7 +22,7 @@ MODEL_NAME = MODEL_ID.split("/")[-1]
 
 TITLE = "<h1><center>VL-Chatbox</center></h1>"
 
-DESCRIPTION = "<h3><center>MODEL: " + MODEL_NAME + "</center></h3>"
+DESCRIPTION = "<h3><center>MODEL: " + f'[{MODEL_NAME}](https://hf.co/models/{MODEL_NAME})' + "</center></h3>"
 
 CSS = """
 .duplicate-button {
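
One caveat on the new DESCRIPTION: the string is emitted as raw HTML, so a Markdown-style [name](url) link will most likely render literally, and hf.co/models/<name> is not the canonical model-page URL (model pages live at hf.co/<owner>/<name>). An alternative sketch using an HTML anchor built from the full MODEL_ID:

    # Alternative sketch (assumes DESCRIPTION is rendered as HTML by the Gradio front end):
    DESCRIPTION = f"<h3><center>MODEL: <a href='https://hf.co/{MODEL_ID}'>{MODEL_NAME}</a></center></h3>"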
@@ -35,15 +33,13 @@ CSS = """
 }
 """
 
-model = AutoModelForCausalLM.from_pretrained(
+model = AutoModel.from_pretrained(
     MODEL_ID,
     torch_dtype=torch.float16,
-    low_cpu_mem_usage=True,
     trust_remote_code=True
 ).to(0)
-processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
-eos_token_id = processor.tokenizer.eos_token_id
-
+tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
+model.eval()
 
 
 
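With the model loaded in fp16 on cuda:0 and switched to eval mode, a one-off, non-streaming call is the quickest sanity check of the loaded model/tokenizer pair. A hypothetical smoke test, assuming the MiniCPM-V-style .chat API that the rest of this diff targets:

    from PIL import Image

    # Illustrative image path; sampling=False asks the remote-code chat() for deterministic decoding.
    img = Image.open("example.jpg").convert("RGB")
    msgs = [{"role": "user", "content": "Describe this image."}]
    answer = model.chat(image=img, msgs=msgs, tokenizer=tokenizer, sampling=False)
    print(answer)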
 
@@ -53,8 +49,8 @@ def stream_chat(message, history: list, temperature: float, max_new_tokens: int)
     print(f'history is - {history}')
     conversation = []
     if message["files"]:
-        image = Image.open(message["files"][-1])
-        conversation.append({"role": "user", "content": f"<|image_1|>\n{message['text']}"})
+        image = Image.open(message["files"][-1]).convert('RGB')
+        conversation.append({"role": "user", "content": message['text']})
     else:
         if len(history) == 0:
             raise gr.Error("Please upload an image first.")
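
For context on message["files"]: a multimodal Gradio ChatInterface hands the handler a dict rather than a plain string, which is why the code indexes ["files"][-1] for the latest upload and why .convert('RGB') now guards against RGBA or palette images. Roughly (paths illustrative):

    from PIL import Image

    # Approximate shape of `message` from a multimodal gr.ChatInterface:
    message = {
        "text": "What is in this picture?",
        "files": ["/tmp/gradio/abc123/photo.png"],  # empty list when nothing was attached
    }
    image = Image.open(message["files"][-1]).convert("RGB")  # last upload wins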
@@ -62,29 +58,29 @@ def stream_chat(message, history: list, temperature: float, max_new_tokens: int)
         else:
             image = Image.open(history[0][0][0])
         for prompt, answer in history:
-            if answer is None:
-                conversation.extend([{"role": "user", "content": "<|image_1|>"}, {"role": "assistant", "content": ""}])
-            else:
+            # if answer is None:
+            #     conversation.extend([{"role": "user", "content": "<|image_1|>"}, {"role": "assistant", "content": ""}])
+            # else:
             conversation.extend([{"role": "user", "content": prompt}, {"role": "assistant", "content": answer}])
     conversation.append({"role": "user", "content": message['text']})
     print(f"Conversation is -\n{conversation}")
-    inputs = processor.tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
-    inputs_ids = processor(inputs, image, return_tensors="pt").to(0)
-    streamer = TextIteratorStreamer(processor, **{"skip_special_tokens": True, "skip_prompt": True, "clean_up_tokenization_spaces": False})
+
+    streamer = TextIteratorStreamer(tokenizer, **{"skip_special_tokens": True, "skip_prompt": True, "clean_up_tokenization_spaces": False})
 
     generate_kwargs = dict(
+        image=image,
+        msgs=conversation,
         streamer=streamer,
         max_new_tokens=max_new_tokens,
         temperature=temperature,
-        do_sample=True,
-        eos_token_id=eos_token_id,
+        sampling=True,
+        tokenizer=tokenizer,
     )
     if temperature == 0:
-        generate_kwargs["do_sample"] = False
-    generate_kwargs = {**inputs_ids, **generate_kwargs}
+        generate_kwargs["sampling"] = False
 
 
-    thread = Thread(target=model.generate, kwargs=generate_kwargs)
+    thread = Thread(target=model.chat, kwargs=generate_kwargs)
     thread.start()
 
     buffer = ""
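
The hunk is truncated at buffer = ""; the consumer half of this producer/consumer pairing typically iterates the streamer on the main thread while model.chat runs in the background Thread. A sketch of the loop that usually follows:

    # TextIteratorStreamer is an iterator over decoded text fragments, so the Gradio
    # generator can yield partial output while the worker thread keeps generating.
    buffer = ""
    for new_text in streamer:
        buffer += new_text
        yield buffer  # stream the growing reply back to the chat UI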
 
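Worth noting as a possible simplification: checkpoints in this family generally also accept stream=True on .chat(), returning a plain generator and removing the need for the Thread/TextIteratorStreamer pairing. This is an assumption about the remote-code API, not something shown in the commit:

    # Hypothetical alternative, assuming a MiniCPM-V-style chat(stream=True) API:
    res = model.chat(image=image, msgs=conversation, tokenizer=tokenizer,
                     sampling=True, temperature=0.7, stream=True)
    buffer = ""
    for new_text in res:
        buffer += new_text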