File size: 1,789 Bytes
cd4405e
5ccf2c1
 
 
cd4405e
5ccf2c1
 
 
cd4405e
5ccf2c1
a5774b4
5ccf2c1
 
 
 
cd4405e
5ccf2c1
a5774b4
5ccf2c1
 
 
a5774b4
5ccf2c1
a5774b4
5ccf2c1
 
a5774b4
5ccf2c1
a5774b4
5ccf2c1
a5774b4
5ccf2c1
 
 
 
 
a5774b4
5ccf2c1
a5774b4
 
 
 
 
5ccf2c1
 
 
a5774b4
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import gradio as gr
import os
import keras_nlp
from transformers import AutoModelForCausalLM

# Set Kaggle API credentials
# SECURITY(review): a real API key is hardcoded here and is now leaked with
# the source — rotate this key and load credentials from the environment or
# a secret store instead of committing them.
os.environ["KAGGLE_USERNAME"] = "rogerkorantenng"
os.environ["KAGGLE_KEY"] = "9a33b6e88bcb6058b1281d777fa6808d"

# Load LoRA weights if you have them
# Builds the base Gemma 2B model from the Keras preset, enables LoRA adapters,
# caps the tokenizer sequence length, then overlays the fine-tuned weights.
LoRA_weights_path = "fined-tuned-model.lora.h5"  # local fine-tuned LoRA checkpoint
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_2b_en")
gemma_lm.backbone.enable_lora(rank=4)  # Enable LoRA with rank 4 (must match the rank used at fine-tune time)
gemma_lm.preprocessor.sequence_length = 512  # Limit sequence length
gemma_lm.backbone.load_lora_weights(LoRA_weights_path)  # Load LoRA weights

# Response generation callback used by the Gradio interface.
def generate_response(message):
    """Produce a model reply for a single user message.

    Wraps *message* in the instruction/response prompt format the model was
    fine-tuned on, runs the Gemma LM, and returns only the text following
    the final "Response:" marker.
    """
    # Build the instruction-style prompt (response slot left empty for the
    # model to fill in).
    prompt = "Instruction:\n{instruction}\n\nResponse:\n{response}".format(
        instruction=message, response=""
    )
    print("Prompt:\n", prompt)

    # Run generation, then keep only what comes after the last "Response:".
    raw_output = gemma_lm.generate(prompt, max_length=256)
    reply = raw_output.split("Response:")[-1].strip()

    print("Generated Response:\n", reply)

    # Hand the cleaned reply back to the UI.
    return reply

# Create the Gradio chat interface: a single-turn text-in/text-out UI backed
# by generate_response. (Stale commented-out kwargs removed; re-add
# `description=` / `live=` deliberately if they are wanted.)
interface = gr.Interface(
    fn=generate_response,  # Function that generates responses
    inputs=gr.Textbox(placeholder="Hello, I am Sage, your mental health advisor", lines=2, scale=7),
    outputs=gr.Textbox(),
    title="Welcome to Sage, your dedicated mental health advisor.",
)

# Launch the Gradio app
# NOTE(review): `share=True` with a hardcoded `share_server_address` routes
# the public link through what is presumably a self-hosted share server at
# hopegivers.tech:7000 — confirm that host is still up before deploying.
interface.launch(share=True, share_server_address="hopegivers.tech:7000")