SivilTaram ChengsongHuang commited on
Commit
dba8743
·
1 Parent(s): 2a9df48

Update app.py (#5)

Browse files

- Update app.py (ab1022cbde1b0a44b497412f0742751b9ceac830)


Co-authored-by: Chengsong Huang <ChengsongHuang@users.noreply.huggingface.co>

Files changed (1) hide show
  1. app.py +33 -3
app.py CHANGED
@@ -12,6 +12,12 @@ import torch
12
  import shutil
13
  import os
14
  import uuid
 
 
 
 
 
 
15
 
16
 
17
  css = """
@@ -21,7 +27,6 @@ css = """
21
  """
22
  st.markdown(css, unsafe_allow_html=True)
23
 
24
-
25
  def main():
26
  st.title("💡 LoraHub")
27
  st.markdown("Low-rank adaptations (LoRA) are techniques for fine-tuning large language models on new tasks. We propose LoraHub, a framework that allows composing multiple LoRA modules trained on different tasks. The goal is to achieve good performance on unseen tasks using just a few examples, without needing extra parameters or training. And we want to build a marketplace where users can share their trained LoRA modules, thereby facilitating the application of these modules to new tasks.")
@@ -105,12 +110,28 @@ Infer the date from context. Q: Today is the second day of the third month of 1
105
  txt_input, txt_output, max_inference_step=max_step)
106
 
107
  st.success("Lorahub learning finished! You got the following recommendation:")
 
108
  df = {
109
  "modules": [LORA_HUB_NAMES[i] for i in st.session_state["select_names"]],
110
  "weights": recommendation.value,
111
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
112
  st.table(df)
113
-
114
  random_id = uuid.uuid4().hex
115
  os.makedirs(f"lora/{random_id}")
116
  # copy config file
@@ -126,7 +147,16 @@ Infer the date from context. Q: Today is the second day of the third month of 1
126
  file_name=f"lora_{random_id}.zip",
127
  mime="application/zip"
128
  )
129
- st.warning("The page will be refreshed once you click the download button.")
 
 
 
 
 
 
 
 
 
130
 
131
 
132
 
 
12
  import shutil
13
  import os
14
  import uuid
15
+ import json
16
+
17
+
18
+ from google.oauth2 import service_account
19
+ import gspread
20
+ from google.oauth2.service_account import Credentials
21
 
22
 
23
  css = """
 
27
  """
28
  st.markdown(css, unsafe_allow_html=True)
29
 
 
30
  def main():
31
  st.title("💡 LoraHub")
32
  st.markdown("Low-rank adaptations (LoRA) are techniques for fine-tuning large language models on new tasks. We propose LoraHub, a framework that allows composing multiple LoRA modules trained on different tasks. The goal is to achieve good performance on unseen tasks using just a few examples, without needing extra parameters or training. And we want to build a marketplace where users can share their trained LoRA modules, thereby facilitating the application of these modules to new tasks.")
 
110
  txt_input, txt_output, max_inference_step=max_step)
111
 
112
  st.success("Lorahub learning finished! You got the following recommendation:")
113
+
114
  df = {
115
  "modules": [LORA_HUB_NAMES[i] for i in st.session_state["select_names"]],
116
  "weights": recommendation.value,
117
  }
118
+
119
+
120
+
121
+ def share():
122
+ credentials = service_account.Credentials.from_service_account_info(
123
+ json.loads(st.secrets["gcp_service_account"]),
124
+ scopes=[
125
+ "https://www.googleapis.com/auth/spreadsheets",
126
+ ]
127
+ )
128
+ gsheet_url = st.secrets["private_gsheets_url"]
129
+ gc = gspread.authorize(credentials)
130
+ sh = gc.open_by_url(gsheet_url)
131
+
132
+ ws = sh.sheet1
133
+ ws.insert_rows([[LORA_HUB_NAMES[i] for i in st.session_state["select_names"]],recommendation.value.tolist(),[]])
134
  st.table(df)
 
135
  random_id = uuid.uuid4().hex
136
  os.makedirs(f"lora/{random_id}")
137
  # copy config file
 
147
  file_name=f"lora_{random_id}.zip",
148
  mime="application/zip"
149
  )
150
+ with open(f"lora_{random_id}.zip", "rb") as fp:
151
+ btn = st.download_button(
152
+ label="📥 Download the final LoRA Module and share your results",
153
+ data=fp,
154
+ file_name=f"lora_{random_id}.zip",
155
+ mime="application/zip",
156
+ on_click=share
157
+ )
158
+ st.button("📥 Share your results",on_click=share)
159
+ st.warning("The page will be refreshed once you click the download button. Share results may cost 1-2 mins.")
160
 
161
 
162