inflaton committed on
Commit
f850f3b
1 Parent(s): 01f4bd7

completed gradio app

Browse files
Files changed (2) hide show
  1. app.py +20 -7
  2. requirements.txt +0 -1
app.py CHANGED
@@ -8,19 +8,33 @@ from transformers import (
8
  )
9
  import os
10
  from threading import Thread
11
- import spaces
12
  import subprocess
13
 
 
 
 
 
 
 
 
 
 
14
  subprocess.run(
15
  "pip install flash-attn --no-build-isolation",
16
  env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
17
  shell=True,
18
  )
19
 
20
- token = os.getenv("HF_TOKEN")
21
- model_name = (
22
- os.getenv("MODEL_NAME") or "google/gemma-1.1-2b-it"
 
23
  ) # "microsoft/Phi-3-mini-128k-instruct"
 
 
 
 
 
24
 
25
  questions_file_path = (
26
  os.getenv("QUESTIONS_FILE_PATH") or "./data/datasets/ms_macro.json"
@@ -71,7 +85,6 @@ else:
71
  model = model.to(device)
72
 
73
 
74
- @spaces.GPU(duration=60)
75
  def chat(message, history, temperature, repetition_penalty, do_sample, max_tokens):
76
  print("repetition_penalty:", repetition_penalty)
77
  chat = []
@@ -123,13 +136,13 @@ demo = gr.ChatInterface(
123
  ),
124
  additional_inputs=[
125
  gr.Slider(
126
- minimum=0, maximum=1, step=0.1, value=0.9, label="Temperature", render=False
127
  ),
128
  gr.Slider(
129
  minimum=1.0,
130
  maximum=1.5,
131
  step=0.1,
132
- value=1.2,
133
  label="Repetition Penalty",
134
  render=False,
135
  ),
 
8
  )
9
  import os
10
  from threading import Thread
 
11
  import subprocess
12
 
13
+ from dotenv import find_dotenv, load_dotenv
14
+
15
+ found_dotenv = find_dotenv(".env")
16
+
17
+ if len(found_dotenv) == 0:
18
+ found_dotenv = find_dotenv(".env.example")
19
+ print(f"loading env vars from: {found_dotenv}")
20
+ load_dotenv(found_dotenv, override=False)
21
+
22
  subprocess.run(
23
  "pip install flash-attn --no-build-isolation",
24
  env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
25
  shell=True,
26
  )
27
 
28
+ token = os.getenv("HUGGINGFACE_AUTH_TOKEN")
29
+
30
+ model_name = os.getenv(
31
+ "HUGGINGFACE_MODEL_NAME_OR_PATH", "google/gemma-1.1-2b-it"
32
  ) # "microsoft/Phi-3-mini-128k-instruct"
33
+ print(f" model_name: {model_name}")
34
+
35
+ HF_RP = os.getenv("HF_RP", "1.2")
36
+ repetition_penalty = float(HF_RP)
37
+ print(f" repetition_penalty: {repetition_penalty}")
38
 
39
  questions_file_path = (
40
  os.getenv("QUESTIONS_FILE_PATH") or "./data/datasets/ms_macro.json"
 
85
  model = model.to(device)
86
 
87
 
 
88
  def chat(message, history, temperature, repetition_penalty, do_sample, max_tokens):
89
  print("repetition_penalty:", repetition_penalty)
90
  chat = []
 
136
  ),
137
  additional_inputs=[
138
  gr.Slider(
139
+ minimum=0, maximum=1, step=0.1, value=0, label="Temperature", render=False
140
  ),
141
  gr.Slider(
142
  minimum=1.0,
143
  maximum=1.5,
144
  step=0.1,
145
+ value=repetition_penalty,
146
  label="Repetition Penalty",
147
  render=False,
148
  ),
requirements.txt CHANGED
@@ -6,7 +6,6 @@ transformers==4.40.1
6
  accelerate==0.29.3
7
  python-dotenv==1.0.1
8
  gradio==4.26.0
9
- spaces==0.27.1
10
  black==24.4.0
11
  chardet==5.2.0
12
  sentencepiece==0.2.0
 
6
  accelerate==0.29.3
7
  python-dotenv==1.0.1
8
  gradio==4.26.0
 
9
  black==24.4.0
10
  chardet==5.2.0
11
  sentencepiece==0.2.0