archit11 commited on
Commit
22ff5cb
·
verified ·
1 Parent(s): f146c64

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +61 -51
app.py CHANGED
@@ -1,14 +1,10 @@
1
  import os
2
- import spaces
3
-
4
- from threading import Thread
5
- from typing import Iterator, List, Tuple
6
- import json
7
  import requests
8
-
9
  import gradio as gr
10
  import torch
11
- import transformers
 
 
12
  from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
13
 
14
  # Description for the Gradio Interface
@@ -36,17 +32,20 @@ models = {}
36
  tokenizers = {}
37
 
38
  for model_id in MODEL_OPTIONS:
39
- tokenizers[model_id] = AutoTokenizer.from_pretrained(model_id)
40
- models[model_id] = AutoModelForCausalLM.from_pretrained(
41
- model_id,
42
- device_map="auto",
43
- load_in_8bit=True,
44
- )
45
- models[model_id].eval()
46
-
47
- # Set pad_token_id to eos_token_id if it's not set
48
- if tokenizers[model_id].pad_token_id is None:
49
- tokenizers[model_id].pad_token_id = tokenizers[model_id].eos_token_id
 
 
 
50
 
51
  # Function to log comparisons
52
  def log_comparison(model1_name: str, model2_name: str, question: str, answer1: str, answer2: str, winner: str = None):
@@ -57,7 +56,6 @@ def log_comparison(model1_name: str, model2_name: str, question: str, answer1: s
57
  "winner": winner
58
  }
59
 
60
- # Send log data to remote server
61
  try:
62
  response = requests.post('http://144.24.151.32:5000/log', json=log_data, timeout=5)
63
  if response.status_code == 200:
@@ -70,14 +68,17 @@ def log_comparison(model1_name: str, model2_name: str, question: str, answer1: s
70
  # Function to prepare input
71
  def prepare_input(model_id: str, message: str, chat_history: List[Tuple[str, str]]):
72
  tokenizer = tokenizers[model_id]
73
- # Prepare inputs for the model
74
- inputs = tokenizer(
75
- [x[1] for x in chat_history] + [message],
76
- return_tensors="pt",
77
- truncation=True,
78
- padding=True,
79
- max_length=MAX_INPUT_TOKEN_LENGTH,
80
- )
 
 
 
81
  return inputs
82
 
83
  # Function to generate responses from models
@@ -101,24 +102,28 @@ def generate(
101
  gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
102
  input_ids = input_ids.to(model.device)
103
 
104
- streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
105
- generate_kwargs = dict(
106
- input_ids=input_ids,
107
- streamer=streamer,
108
- max_new_tokens=max_new_tokens,
109
- do_sample=True,
110
- top_p=top_p,
111
- temperature=temperature,
112
- num_beams=1,
113
- pad_token_id=tokenizer.eos_token_id,
114
- )
115
- t = Thread(target=model.generate, kwargs=generate_kwargs)
116
- t.start()
117
-
118
- outputs = []
119
- for text in streamer:
120
- outputs.append(text)
121
- yield "".join(outputs)
 
 
 
 
122
 
123
  # Function to compare two models
124
  def compare_models(
@@ -135,15 +140,20 @@ def compare_models(
135
  error_message = [("System", "Error: Please select two different models.")]
136
  return error_message, error_message, chat_history1, chat_history2
137
 
138
- output1 = "".join(list(generate(model1_name, message, chat_history1, max_new_tokens, temperature, top_p)))
139
- output2 = "".join(list(generate(model2_name, message, chat_history2, max_new_tokens, temperature, top_p)))
 
140
 
141
- chat_history1.append((message, output1))
142
- chat_history2.append((message, output2))
143
 
144
- log_comparison(model1_name, model2_name, message, output1, output2)
145
 
146
- return chat_history1, chat_history2, chat_history1, chat_history2
 
 
 
 
147
 
148
  # Function to log the voting result
149
  def vote_better(model1_name, model2_name, question, answer1, answer2, choice):
 
1
  import os
 
 
 
 
 
2
  import requests
 
3
  import gradio as gr
4
  import torch
5
+ import spaces
6
+ from threading import Thread
7
+ from typing import Iterator, List, Tuple
8
  from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
9
 
10
  # Description for the Gradio Interface
 
32
  tokenizers = {}
33
 
34
  for model_id in MODEL_OPTIONS:
35
+ try:
36
+ tokenizers[model_id] = AutoTokenizer.from_pretrained(model_id)
37
+ models[model_id] = AutoModelForCausalLM.from_pretrained(
38
+ model_id,
39
+ device_map="auto",
40
+ load_in_8bit=True,
41
+ )
42
+ models[model_id].eval()
43
+
44
+ # Set pad_token_id to eos_token_id if it's not set
45
+ if tokenizers[model_id].pad_token_id is None:
46
+ tokenizers[model_id].pad_token_id = tokenizers[model_id].eos_token_id
47
+ except Exception as e:
48
+ print(f"Error loading model {model_id}: {e}")
49
 
50
  # Function to log comparisons
51
  def log_comparison(model1_name: str, model2_name: str, question: str, answer1: str, answer2: str, winner: str = None):
 
56
  "winner": winner
57
  }
58
 
 
59
  try:
60
  response = requests.post('http://144.24.151.32:5000/log', json=log_data, timeout=5)
61
  if response.status_code == 200:
 
68
  # Function to prepare input
69
  def prepare_input(model_id: str, message: str, chat_history: List[Tuple[str, str]]):
70
  tokenizer = tokenizers[model_id]
71
+ try:
72
+ inputs = tokenizer(
73
+ [x[1] for x in chat_history] + [message],
74
+ return_tensors="pt",
75
+ truncation=True,
76
+ padding=True,
77
+ max_length=MAX_INPUT_TOKEN_LENGTH,
78
+ )
79
+ except Exception as e:
80
+ print(f"Error preparing input for model {model_id}: {e}")
81
+ inputs = tokenizer([message], return_tensors="pt", padding=True, max_length=MAX_INPUT_TOKEN_LENGTH)
82
  return inputs
83
 
84
  # Function to generate responses from models
 
102
  gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
103
  input_ids = input_ids.to(model.device)
104
 
105
+ try:
106
+ streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
107
+ generate_kwargs = dict(
108
+ input_ids=input_ids,
109
+ streamer=streamer,
110
+ max_new_tokens=max_new_tokens,
111
+ do_sample=True,
112
+ top_p=top_p,
113
+ temperature=temperature,
114
+ num_beams=1,
115
+ pad_token_id=tokenizer.eos_token_id,
116
+ )
117
+ t = Thread(target=model.generate, kwargs=generate_kwargs)
118
+ t.start()
119
+
120
+ outputs = []
121
+ for text in streamer:
122
+ outputs.append(text)
123
+ yield "".join(outputs)
124
+ except Exception as e:
125
+ print(f"Error generating response from model {model_id}: {e}")
126
+ yield "Error generating response."
127
 
128
  # Function to compare two models
129
  def compare_models(
 
140
  error_message = [("System", "Error: Please select two different models.")]
141
  return error_message, error_message, chat_history1, chat_history2
142
 
143
+ try:
144
+ output1 = "".join(list(generate(model1_name, message, chat_history1, max_new_tokens, temperature, top_p)))
145
+ output2 = "".join(list(generate(model2_name, message, chat_history2, max_new_tokens, temperature, top_p)))
146
 
147
+ chat_history1.append((message, output1))
148
+ chat_history2.append((message, output2))
149
 
150
+ log_comparison(model1_name, model2_name, message, output1, output2)
151
 
152
+ return chat_history1, chat_history2, chat_history1, chat_history2
153
+ except Exception as e:
154
+ print(f"Error comparing models: {e}")
155
+ error_message = [("System", "Error comparing models.")]
156
+ return error_message, error_message, chat_history1, chat_history2
157
 
158
  # Function to log the voting result
159
  def vote_better(model1_name, model2_name, question, answer1, answer2, choice):