|
import gradio as gr |
|
import torch |
|
from transformers import AutoModelForCausalLM, AutoTokenizer |
|
|
|
|
|
# Hugging Face Hub id of the chat model served by this app.
model_name = "HuggingFaceH4/zephyr-7b-beta"

# Tokenizer and weights both come from the same checkpoint. Weights are loaded
# in bfloat16 and spread over the available devices via device_map="auto".
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
|
|
|
|
|
# Persona instructions the model is conditioned on for every chat turn.
_SYSTEM_PROMPT = (
    "You are an experienced Fashion designer who starts conversation with proper greeting, "
    "giving valuable and catchy fashion advice and suggestions, stays to the point and precise."
)

# Conversation history (chat-template message dicts) read and mutated by
# chat() and reset_chat(), seeded with the system persona.
# NOTE(review): this is one process-wide list, so concurrent visitors of the
# Gradio app appear to share a single conversation; consider gr.State.
messages = [{"role": "system", "content": _SYSTEM_PROMPT}]
|
|
|
|
|
def reset_chat():
    """Start a new conversation while keeping the assistant's persona.

    Drops every user/assistant turn from the module-level ``messages``
    history but keeps the system message(s). This way the model is still
    instructed to act as a fashion designer after the reset; the original
    ``messages = []`` silently discarded that prompt.

    Returns:
        A pair ``([], "New Chat")``: an empty history for the chat window and
        a status string.
    """
    global messages
    messages = [turn for turn in messages if turn.get("role") == "system"]
    return [], "New Chat"
|
|
|
|
|
def submit_questionnaire(name, age, location, gender, ethnicity, height, weight,
                         style_preference, color_palette, everyday_style):
    """Acknowledge a completed style questionnaire.

    Receives the ten answers wired from the UI (profile fields plus three
    style questions). Returns the confirmation text shown in the
    "Output Message" box.

    NOTE(review): none of the answers are stored or forwarded to the model,
    so they do not currently influence the chat. Wire them into the
    conversation (e.g. as extra system context) if personalised advice is
    intended.
    """
    confirmation = "Thank you for completing the questionnaire!"
    return confirmation
|
|
|
|
|
def _visible_history(history):
    """Return the turns of *history* that belong in the chat window.

    System messages are sent to the model but hidden from the user. A new list
    is returned; *history* itself is not modified.
    """
    return [turn for turn in history if turn.get("role") != "system"]


def _generate_reply(history):
    """Generate the assistant's next turn for *history* and return it as text.

    Uses the module-level ``tokenizer`` and ``model``. Only the newly generated
    tokens are decoded. For a causal LM, ``generate`` returns the prompt ids
    followed by the continuation, so decoding the whole sequence would echo
    the system prompt and every earlier turn back into the reply.
    """
    prompt = tokenizer.apply_chat_template(history, tokenize=False, add_generation_prompt=True)
    model_inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    generated_ids = model.generate(
        **model_inputs,
        max_new_tokens=256,
        do_sample=True,
        temperature=0.7,
        top_k=50,
        top_p=0.95,
    )
    prompt_length = model_inputs["input_ids"].shape[-1]
    continuation = generated_ids[:, prompt_length:]
    return tokenizer.batch_decode(continuation, skip_special_tokens=True)[0].strip()


def chat(user_input):
    """Handle one user message: record it, query the model, record the reply.

    Args:
        user_input: Text typed by the user. Empty or whitespace-only input is
            ignored.

    Returns:
        ``(history, reply)``, where:
        - ``history`` is the conversation as role/content dicts without the
          system prompt, ready for a ``gr.Chatbot(type="messages")``.
        - ``reply`` is the assistant's text, ``"Error: ..."`` if generation
          failed, or ``""`` when the input was ignored.
    """
    # ``messages`` is the module-level history; it is only mutated here, never
    # rebound, so no ``global`` declaration is needed.
    if not user_input or not user_input.strip():
        return _visible_history(messages), ""

    messages.append({"role": "user", "content": user_input})
    try:
        response = _generate_reply(messages)
    except Exception as exc:  # UI boundary: report the failure instead of raising
        # Roll back the unanswered user turn so the next prompt does not carry
        # two consecutive user turns. Keep the error text out of the history
        # the model is conditioned on; show it only in the returned copy.
        messages.pop()
        response = f"Error: {exc}"
        shown = _visible_history(messages) + [
            {"role": "user", "content": user_input},
            {"role": "assistant", "content": response},
        ]
        return shown, response

    messages.append({"role": "assistant", "content": response})
    return _visible_history(messages), response
|
|
|
|
|
with gr.Blocks() as demo:
    gr.Markdown("## Fashion Assistant Chatbot")

    # NOTE(review): the original indentation was lost; the three style radios
    # are placed in the second column here. Confirm the intended layout.
    with gr.Row():
        with gr.Column():
            name = gr.Textbox(label="Name")
            age = gr.Number(label="Age", value=25, minimum=1, maximum=100)
            location = gr.Textbox(label="Location")
            gender = gr.Radio(label="Gender", choices=["Male", "Female", "Other"])
            ethnicity = gr.Radio(label="Ethnicity", choices=["Asian", "Black", "Hispanic", "White", "Other"])
            height = gr.Number(label="Height (cm)", value=170, minimum=50, maximum=250)
            weight = gr.Number(label="Weight (kg)", value=70, minimum=20, maximum=200)

        with gr.Column():
            submit_btn = gr.Button("Submit Questionnaire")
            reset_btn = gr.Button("Reset Chat")
            style_preference = gr.Radio(label="Which style do you prefer the most?", choices=["Casual", "Formal", "Streetwear", "Athleisure", "Baggy"])
            color_palette = gr.Radio(label="What color palette do you wear often?", choices=["Neutrals", "Bright Colors", "Pastels", "Dark Shades"])
            everyday_style = gr.Radio(label="How would you describe your everyday style?", choices=["Relaxed", "Trendy", "Elegant", "Bold"])

    chatbox = gr.Chatbot(type='messages')
    user_input = gr.Textbox(label="Your Message", placeholder="Type your message here...")
    output_message = gr.Textbox(label="Output Message")

    submit_btn.click(submit_questionnaire, inputs=[name, age, location, gender, ethnicity, height, weight,
                                                   style_preference, color_palette, everyday_style], outputs=output_message)

    # reset_chat returns (history, status). Wire both values; with a single
    # output the whole tuple would be handed to the Chatbot.
    reset_btn.click(reset_chat, outputs=[chatbox, output_message])

    # chat returns (history, reply). Send the reply to the status box rather
    # than back into the input box, where pressing Enter would resubmit the
    # bot's own text. Then clear the input for the next message.
    user_input.submit(chat, inputs=user_input, outputs=[chatbox, output_message]).then(
        lambda: "", outputs=user_input
    )

demo.launch()
|
|