rogerkoranteng committed
Commit 5ccf2c1 · verified · 1 Parent(s): f34852a

Upload folder using huggingface_hub

Files changed (3):
  1. fined-tuned.lora.h5 +3 -0
  2. main.py +53 -25
  3. requirements.txt +0 -65
fined-tuned.lora.h5 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:3f50855153040325fe2a203e7ee03c3aa8c98f3f3db7cccc435fa759527bd7b5
+ size 5560280
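
The weight file itself lives in Git LFS storage, so only this three-line pointer is versioned; the oid is the SHA-256 of the real ~5.3 MB blob. After a git lfs pull, the download can be sanity-checked against the pointer's oid and size. A minimal sketch, assuming the file sits at the repo root:

import hashlib
import os

path = "fined-tuned.lora.h5"  # assumed LFS-resolved local path
expected_oid = "3f50855153040325fe2a203e7ee03c3aa8c98f3f3db7cccc435fa759527bd7b5"
expected_size = 5560280

# Compare the local file's size and SHA-256 against the LFS pointer
assert os.path.getsize(path) == expected_size, "size differs from the pointer"
h = hashlib.sha256()
with open(path, "rb") as f:
    for chunk in iter(lambda: f.read(1 << 20), b""):  # hash in 1 MiB chunks
        h.update(chunk)
assert h.hexdigest() == expected_oid, "hash differs from the pointer oid"
print("OK: local file matches the LFS pointer")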
main.py CHANGED
@@ -1,32 +1,60 @@
- from openai import OpenAI
  import gradio as gr
- from dotenv import load_dotenv
-
- load_dotenv()
-
- client = OpenAI()
-
-
- def generate_response(message, history):
-     formatted_history = []
-     for user, assistant in history:
-         formatted_history.append({"role": "user", "content": user})
-         formatted_history.append({"role": "assistant", "content": assistant})
-
-     formatted_history.append({"role": "user", "content": message})
-
-     response = client.chat.completions.create(model='gpt-3.5-turbo',
-                                               messages=formatted_history,
-                                               temperature=1.0)
-
-     return response.choices[0].message.content
-
-
- gr.ChatInterface(generate_response,
-                  chatbot=gr.Chatbot(height=300),
-                  textbox=gr.Textbox(placeholder="You can ask me anything", container=False, scale=7),
-                  title="OpenAI Chat Bot",
-                  retry_btn=None,
-                  undo_btn="Delete Previous",
-                  clear_btn="Clear").launch(share=True)
- gr.ChatInterface(generate_response).launch()
+ import os
+ import keras_nlp
+
+ # Set Kaggle API credentials (needed to download the Gemma preset)
+ os.environ["KAGGLE_USERNAME"] = "rogerkorantenng"
+ os.environ["KAGGLE_KEY"] = "9a33b6e88bcb6058b1281d777fa6808d"
+
+ # Load the base model, then apply the fine-tuned LoRA weights
+ LoRA_weights_path = "fined-tuned.lora.h5"
+ gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_2b_en")
+
+ gemma_lm.backbone.enable_lora(rank=4)  # Enable LoRA with rank 4
+ gemma_lm.preprocessor.sequence_length = 512  # Limit sequence length
+ gemma_lm.backbone.load_lora_weights(LoRA_weights_path)  # Load LoRA weights
+
+ # Define the response generation function
+ def generate_response(message, history):
+     # Prompt template matching the fine-tuning format
+     template = "Instruction:\n{instruction}\n\nResponse:\n{response}"
+
+     # Format the chat history into the prompt
+     formatted_history = ""
+     for user_msg, bot_msg in history:
+         formatted_history += template.format(instruction=user_msg, response=bot_msg)
+
+     # Add the latest message from the user
+     prompt = template.format(instruction=message, response="")
+
+     # Combine history with the latest prompt
+     final_prompt = formatted_history + prompt
+
+     # Generate a completion, then keep only the text after the last
+     # "Response:" marker, i.e. the answer to the newest message
+     response = gemma_lm.generate(final_prompt, max_length=256)
+     response = response.split("Response:")[-1].strip()
+
+     return response
+
+ # Create the Gradio chat interface
+ interface = gr.ChatInterface(
+     fn=generate_response,  # Function that generates responses
+     chatbot=gr.Chatbot(height=300),  # Chatbot UI component
+     textbox=gr.Textbox(placeholder="Hello, I'm Sage, your mental health advisor", container=False, scale=7),
+     title="Local Model Chat Bot",
+     retry_btn=None,  # Disable retry button
+     undo_btn="Delete Previous",  # Enable undo button
+     clear_btn="Clear"  # Enable clear button
+ )
+
+ # Launch the Gradio app
+ interface.launch(share=True)
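
For context, fined-tuned.lora.h5 is presumably the output of a LoRA fine-tuning run over the same preset. A minimal sketch of how such a file could be produced with the same KerasNLP API; the training data, hyperparameters, and the save_lora_weights call (inferred only as the counterpart of the load_lora_weights call above) are assumptions, not part of this commit:

import keras
import keras_nlp

# Same base model, LoRA rank, and sequence length as main.py
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_2b_en")
gemma_lm.backbone.enable_lora(rank=4)
gemma_lm.preprocessor.sequence_length = 512

# Hypothetical training examples in the same Instruction/Response
# format that generate_response() reconstructs at inference time
data = [
    "Instruction:\nI can't sleep before exams.\n\nResponse:\n"
    "That is a common stress reaction; a consistent wind-down routine can help.",
]

# Standard causal-LM fine-tuning setup (hyperparameters are assumptions)
gemma_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=keras.optimizers.Adam(learning_rate=5e-5),
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
gemma_lm.fit(data, epochs=1, batch_size=1)

# Assumed counterpart of the load_lora_weights() call in main.py
gemma_lm.backbone.save_lora_weights("fined-tuned.lora.h5")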
requirements.txt CHANGED
@@ -1,65 +0,0 @@
- aiofiles==23.2.1
- annotated-types==0.7.0
- anyio==4.4.0
- certifi==2024.7.4
- charset-normalizer==3.3.2
- click==8.1.7
- contourpy==1.3.0
- cycler==0.12.1
- distro==1.9.0
- fastapi==0.112.2
- ffmpy==0.4.0
- filelock==3.15.4
- fonttools==4.53.1
- fsspec==2024.6.1
- gradio==4.42.0
- gradio_client==1.3.0
- h11==0.14.0
- httpcore==1.0.5
- httpx==0.27.2
- huggingface-hub==0.24.6
- idna==3.8
- importlib_resources==6.4.4
- Jinja2==3.1.4
- jiter==0.5.0
- kiwisolver==1.4.5
- markdown-it-py==3.0.0
- MarkupSafe==2.1.5
- matplotlib==3.9.2
- mdurl==0.1.2
- numpy==2.1.0
- openai==1.42.0
- orjson==3.10.7
- packaging==24.1
- pandas==2.2.2
- pillow==10.4.0
- pydantic==2.8.2
- pydantic_core==2.20.1
- pydub==0.25.1
- Pygments==2.18.0
- pyparsing==3.1.4
- python-dateutil==2.9.0.post0
- python-dotenv==1.0.1
- python-multipart==0.0.9
- pytz==2024.1
- PyYAML==6.0.2
- regex==2024.7.24
- requests==2.32.3
- rich==13.8.0
- ruff==0.6.2
- safetensors==0.4.4
- semantic-version==2.10.0
- shellingham==1.5.4
- six==1.16.0
- sniffio==1.3.1
- starlette==0.38.2
- tokenizers==0.19.1
- tomlkit==0.12.0
- tqdm==4.66.5
- transformers==4.44.2
- typer==0.12.5
- typing_extensions==4.12.2
- tzdata==2024.1
- urllib3==2.2.2
- uvicorn==0.30.6
- websockets==12.0
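
Emptying requirements.txt leaves the Space without pinned dependencies, even though the new main.py still imports gradio and keras_nlp (which in turn needs Keras 3 plus a backend). A plausible replacement, inferred from those imports rather than from anything in this commit; the version pins are assumptions and untested here:

# Hypothetical pins inferred from main.py's imports
gradio==4.42.0
keras>=3.3
keras-nlp>=0.14
tensorflow>=2.16  # or another Keras 3 backend such as jax or torch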