dev/check-gpu-for-rinna

#10
opened by terapyon
Files changed (1)
  1. app.py +29 -18
app.py CHANGED
@@ -25,14 +25,17 @@ E5_EMBEDDINGS = HuggingFaceEmbeddings(
     encode_kwargs=E5_ENCODE_KWARGS,
 )
 
-RINNA_MODEL_NAME = "rinna/bilingual-gpt-neox-4b-instruction-ppo"
-RINNA_TOKENIZER = AutoTokenizer.from_pretrained(RINNA_MODEL_NAME, use_fast=False)
-RINNA_MODEL = AutoModelForCausalLM.from_pretrained(
-    RINNA_MODEL_NAME,
-    load_in_8bit=True,
-    torch_dtype=torch.float16,
-    device_map="auto",
-)
+if torch.cuda.is_available():
+    RINNA_MODEL_NAME = "rinna/bilingual-gpt-neox-4b-instruction-ppo"
+    RINNA_TOKENIZER = AutoTokenizer.from_pretrained(RINNA_MODEL_NAME, use_fast=False)
+    RINNA_MODEL = AutoModelForCausalLM.from_pretrained(
+        RINNA_MODEL_NAME,
+        load_in_8bit=True,
+        torch_dtype=torch.float16,
+        device_map="auto",
+    )
+else:
+    RINNA_MODEL = None
 
 
 def _get_config_and_embeddings(collection_name: str | None) -> tuple:
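Why the guard: load_in_8bit relies on bitsandbytes, which needs a CUDA device, so on a CPU-only Space this module-level load would fail at import time. The else branch defines only RINNA_MODEL, which is safe here because _get_rinna_llm (next hunk) checks RINNA_MODEL before touching RINNA_TOKENIZER. A minimal sketch of the same gating pattern as a helper; load_rinna is hypothetical, not part of this PR, and it additionally keeps the tokenizer defined in both branches:

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    MODEL_NAME = "rinna/bilingual-gpt-neox-4b-instruction-ppo"

    def load_rinna():
        """Return (tokenizer, model), or (None, None) on a CPU-only host."""
        if not torch.cuda.is_available():
            return None, None
        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=False)
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME,
            load_in_8bit=True,         # requires bitsandbytes and a CUDA GPU
            torch_dtype=torch.float16,
            device_map="auto",
        )
        return tokenizer, model

    RINNA_TOKENIZER, RINNA_MODEL = load_rinna()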
@@ -48,14 +51,17 @@ def _get_config_and_embeddings(collection_name: str | None) -> tuple:
 
 
 def _get_rinna_llm(temperature: float):
-    pipe = pipeline(
-        "text-generation",
-        model=RINNA_MODEL,
-        tokenizer=RINNA_TOKENIZER,
-        max_new_tokens=1024,
-        temperature=temperature,
-    )
-    llm = HuggingFacePipeline(pipeline=pipe)
+    if RINNA_MODEL is not None:
+        pipe = pipeline(
+            "text-generation",
+            model=RINNA_MODEL,
+            tokenizer=RINNA_TOKENIZER,
+            max_new_tokens=1024,
+            temperature=temperature,
+        )
+        llm = HuggingFacePipeline(pipeline=pipe)
+    else:
+        llm = None
    return llm
 
 
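With this change _get_rinna_llm returns None on hosts where the model was never loaded, instead of building a pipeline around a missing model, so callers have to handle the None. A hedged sketch of one way a caller could fall back; get_llm and the ChatOpenAI fallback are assumptions, not code from this PR, and the import assumes the same classic langchain package that HuggingFacePipeline comes from:

    from langchain.chat_models import ChatOpenAI  # classic langchain API, assumed available

    def get_llm(temperature: float):
        # Prefer the local rinna pipeline when the GPU path loaded it;
        # otherwise fall back to an OpenAI chat model.
        llm = _get_rinna_llm(temperature)
        if llm is None:
            llm = ChatOpenAI(model_name="gpt-3.5-turbo", temperature=temperature)
        return llm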
@@ -64,7 +70,7 @@ def _get_llm_model(
     temperature: float,
 ):
     if model_name is None:
-        model = "rinna"
+        model = "gpt-3.5-turbo"
     elif model_name == "rinna":
         model = "rinna"
     elif model_name == "GPT-3.5":
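The default flips from rinna to gpt-3.5-turbo: when the user makes no radio selection, model_name arrives as None, and the old default would route to a model that no longer exists on CPU-only hosts. The dispatch, condensed as a sketch; the "gpt-4" value is an assumption, since the diff shows only the None and "rinna" branches:

    _MODEL_BY_CHOICE = {
        "rinna": "rinna",
        "GPT-3.5": "gpt-3.5-turbo",
        "GPT-4": "gpt-4",  # assumed; this branch is outside the hunk
    }

    def resolve_model(model_name: str | None) -> str:
        # None (no radio selection) now falls back to the CPU-safe default.
        return _MODEL_BY_CHOICE.get(model_name, "gpt-3.5-turbo")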
@@ -157,12 +163,17 @@ def main(
     return result["result"], html
 
 
+AVAILABLE_LLMS = ["GPT-3.5", "GPT-4"]
+
+if RINNA_MODEL is not None:
+    AVAILABLE_LLMS.append("rinna")
+
 nvdajp_book_qa = gr.Interface(
     fn=main,
     inputs=[
         gr.Textbox(label="query"),
         gr.Radio(["E5", "OpenAI"], label="Embedding", info="選択なしで「E5」を使用"),
-        gr.Radio(["rinna", "GPT-3.5", "GPT-4"], label="Model", info="選択なしで「rinna」を使用"),
+        gr.Radio(AVAILABLE_LLMS, label="Model", info="選択なしで「GPT-3.5」を使用"),
         gr.Radio(
             ["All", "ja-book", "ja-nvda-user-guide", "en-nvda-user-guide"],
             label="絞り込み",
 