Spaces:
Runtime error
rogerkoranteng
committed on
Upload folder using huggingface_hub
Browse files
- fined-tuned.lora.h5 +3 -0
- main.py +53 -25
- requirements.txt +0 -65
fined-tuned.lora.h5 ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3f50855153040325fe2a203e7ee03c3aa8c98f3f3db7cccc435fa759527bd7b5
+size 5560280
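The three lines above are a Git LFS pointer, not the LoRA weights themselves: the actual 5,560,280-byte file lives in LFS storage and is resolved through the sha256 oid at download time. A minimal sketch of verifying a downloaded copy against this pointer, assuming the file sits in the working directory (the path and chunk size are illustrative):

import hashlib

# Hash the downloaded file in 1 MiB chunks and compare against the pointer's oid.
EXPECTED_OID = "3f50855153040325fe2a203e7ee03c3aa8c98f3f3db7cccc435fa759527bd7b5"

digest = hashlib.sha256()
with open("fined-tuned.lora.h5", "rb") as f:
    for chunk in iter(lambda: f.read(1 << 20), b""):
        digest.update(chunk)

if digest.hexdigest() != EXPECTED_OID:
    raise ValueError("File is corrupted or is still an unresolved LFS pointer")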
main.py CHANGED
@@ -1,32 +1,60 @@
-from openai import OpenAI
 import gradio as gr
-
 
-load_dotenv()
 
-
 
 
-
-
-
-        formatted_history.append({"role": "user", "content": user})
-        formatted_history.append({"role": "assistant", "content": assistant})
-
-    formatted_history.append({"role": "user", "content": message})
-
-    response = client.chat.completions.create(model='gpt-3.5-turbo',
-                                              messages=formatted_history,
-                                              temperature=1.0)
 
-
 
-
-
-
-
-
-
-
-
-
 import gradio as gr
+import os
+import keras_nlp
+from transformers import AutoModelForCausalLM
 
+# Set Kaggle API credentials
 
+os.environ["KAGGLE_USERNAME"] = "rogerkorantenng"
+os.environ["KAGGLE_KEY"] = "9a33b6e88bcb6058b1281d777fa6808d"
 
+# Load LoRA weights if you have them
+LoRA_weights_path = "fined-tuned.lora.h5"
+gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_2b_en")
 
+gemma_lm.backbone.enable_lora(rank=4)  # Enable LoRA with rank 4
+gemma_lm.preprocessor.sequence_length = 512  # Limit sequence length
+gemma_lm.backbone.load_lora_weights(LoRA_weights_path)  # Load LoRA weights
 
+# Define the response generation function
+def generate_response(message, history):
+    # Create a prompt template
+    template = "Instruction:\n{instruction}\n\nResponse:\n{response}"
+
+    # Format the history and the current message into the prompt
+    formatted_history = ""
+    for user_msg, bot_msg in history:
+        formatted_history += template.format(instruction=user_msg, response=bot_msg)
+
+    # Add the latest message from the user
+    prompt = template.format(instruction=message, response="")
+    print(prompt)
+
+    # Combine history with the latest prompt
+    final_prompt = formatted_history + prompt
+    print(final_prompt)
+
+    # Generate response from the model
+    response = gemma_lm.generate(final_prompt, max_length=256)
+    # Keep only the newly generated response (the text after the last
+    # "Response:" marker; index 1 would grab an old response from history)
+    response = response.split("Response:")[-1].strip()
+
+    print(response)
+
+    # Return the generated response text
+    return response  # Adjust this if your model's output structure differs
+
+# Create the Gradio chat interface
+interface = gr.ChatInterface(
+    fn=generate_response,  # Function that generates responses
+    chatbot=gr.Chatbot(height=300),  # Chatbot UI component
+    textbox=gr.Textbox(placeholder="Hello, I'm Sage, your mental health advisor", container=False, scale=7),
+    title="Local Model Chat Bot",
+    retry_btn=None,  # Disable retry button
+    undo_btn="Delete Previous",  # Enable undo button
+    clear_btn="Clear"  # Enable clear button
+)
+
+# Launch the Gradio app
+interface.launch(share=True)
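For context on what was replaced: the removed lines follow the standard Gradio-plus-OpenAI chat pattern. A sketch of that pattern is below; only the captured "-" lines above come from the commit, while the function name, client setup, and return statement are assumptions filled in for readability:

import os
from dotenv import load_dotenv
from openai import OpenAI
import gradio as gr

load_dotenv()
client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))  # assumed client setup

def predict(message, history):  # function name is an assumption
    # Rebuild the OpenAI-style message list from Gradio's (user, assistant) history
    formatted_history = []
    for user, assistant in history:
        formatted_history.append({"role": "user", "content": user})
        formatted_history.append({"role": "assistant", "content": assistant})
    formatted_history.append({"role": "user", "content": message})

    response = client.chat.completions.create(model='gpt-3.5-turbo',
                                              messages=formatted_history,
                                              temperature=1.0)
    return response.choices[0].message.content  # assumed return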
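One thing worth flagging in the new main.py: the Kaggle username and API key are committed in plain text, so anyone who can read this Space can reuse them. A minimal sketch of the usual alternative, assuming the credentials are stored as Space secrets (KAGGLE_USERNAME and KAGGLE_KEY are Kaggle's standard variable names; the error message is illustrative):

import os

# Read Kaggle credentials from the environment (e.g. Hugging Face Space
# secrets) instead of hardcoding them in the source file.
kaggle_username = os.environ.get("KAGGLE_USERNAME")
kaggle_key = os.environ.get("KAGGLE_KEY")

if not kaggle_username or not kaggle_key:
    raise RuntimeError("Set KAGGLE_USERNAME and KAGGLE_KEY as Space secrets before launching.")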
requirements.txt CHANGED
@@ -1,65 +0,0 @@
-aiofiles==23.2.1
-annotated-types==0.7.0
-anyio==4.4.0
-certifi==2024.7.4
-charset-normalizer==3.3.2
-click==8.1.7
-contourpy==1.3.0
-cycler==0.12.1
-distro==1.9.0
-fastapi==0.112.2
-ffmpy==0.4.0
-filelock==3.15.4
-fonttools==4.53.1
-fsspec==2024.6.1
-gradio==4.42.0
-gradio_client==1.3.0
-h11==0.14.0
-httpcore==1.0.5
-httpx==0.27.2
-huggingface-hub==0.24.6
-idna==3.8
-importlib_resources==6.4.4
-Jinja2==3.1.4
-jiter==0.5.0
-kiwisolver==1.4.5
-markdown-it-py==3.0.0
-MarkupSafe==2.1.5
-matplotlib==3.9.2
-mdurl==0.1.2
-numpy==2.1.0
-openai==1.42.0
-orjson==3.10.7
-packaging==24.1
-pandas==2.2.2
-pillow==10.4.0
-pydantic==2.8.2
-pydantic_core==2.20.1
-pydub==0.25.1
-Pygments==2.18.0
-pyparsing==3.1.4
-python-dateutil==2.9.0.post0
-python-dotenv==1.0.1
-python-multipart==0.0.9
-pytz==2024.1
-PyYAML==6.0.2
-regex==2024.7.24
-requests==2.32.3
-rich==13.8.0
-ruff==0.6.2
-safetensors==0.4.4
-semantic-version==2.10.0
-shellingham==1.5.4
-six==1.16.0
-sniffio==1.3.1
-starlette==0.38.2
-tokenizers==0.19.1
-tomlkit==0.12.0
-tqdm==4.66.5
-transformers==4.44.2
-typer==0.12.5
-typing_extensions==4.12.2
-tzdata==2024.1
-urllib3==2.2.2
-uvicorn==0.30.6
-websockets==12.0
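Deleting requirements.txt entirely leaves the Space with no pinned dependencies, which is a plausible cause of the "Runtime error" status shown at the top. A sketch of what the new main.py appears to need, with package names inferred from its imports; the version pins and backend choice are assumptions, not part of this commit:

# Hypothetical replacement requirements.txt for the Gemma-based main.py
gradio==4.42.0        # same version the old file pinned
keras-nlp             # provides keras_nlp.models.GemmaCausalLM
keras>=3              # Keras 3 API used by keras-nlp
tensorflow            # or another Keras 3 backend (jax, torch)
transformers==4.44.2  # for AutoModelForCausalLM (imported but unused)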