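# Flan-T5 chatbot demo: a Gradio Blocks UI that lets the user pick a Flan-T5
# model size (small / base / large) and chat with it, while LangChain's
# ConversationBufferMemory keeps the running conversation history.
# Assumed dependencies (not pinned here): gradio, transformers, torch,
# langchain, and sentencepiece (required by T5Tokenizer).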
import gradio as gr
from transformers import T5Tokenizer, T5ForConditionalGeneration
from langchain.memory import ConversationBufferMemory
import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Load all three Flan-T5 models (small, base, large)
models = {
"small": T5ForConditionalGeneration.from_pretrained("google/flan-t5-small").to(device),
"base": T5ForConditionalGeneration.from_pretrained("google/flan-t5-base").to(device),
"large": T5ForConditionalGeneration.from_pretrained("google/flan-t5-large").to(device)
}
# Load the tokenizer (same tokenizer for all models)
tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-base")
# Set up conversational memory using LangChain's ConversationBufferMemory
memory = ConversationBufferMemory()
# Define the chatbot function with memory and model size selection
def chat_with_flan(input_text, model_size):
    # Retrieve the conversation history stored by LangChain memory
    conversation_history = memory.load_memory_variables({})['history']
    # Combine the history with the current user input
    full_input = f"{conversation_history}\nUser: {input_text}\nAssistant:"
    # Tokenize the input and move it to the same device as the models
    input_ids = tokenizer.encode(full_input, return_tensors="pt").to(device)
    # Get the model for the selected size
    model = models[model_size]
    # Generate the response from the model
    outputs = model.generate(input_ids, max_length=200, num_return_sequences=1)
    # Decode the model output
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # Update the memory with the user input and model response
    memory.save_context({"input": input_text}, {"output": response})
    return conversation_history + f"\nUser: {input_text}\nAssistant: {response}"
# Set up the Gradio interface with the input box below the output box
with gr.Blocks() as interface:
    # Output box showing the conversation history
    chatbot_output = gr.Textbox(label="Conversation", lines=15, placeholder="Chat history will appear here...", interactive=False)
    # Instruction message above the input box
    gr.Markdown("**Instructions:** Press `Shift + Enter` to submit, and `Enter` for a new line.")
    # Dropdown for selecting the model size (small, base, large)
    model_selector = gr.Dropdown(choices=["small", "base", "large"], value="base", label="Select Model Size")
    # Input box for the user
    user_input = gr.Textbox(label="Your Input", placeholder="Type your message here...", lines=2, show_label=True)
    # Update the chat with the selected model, then clear the input box
    def update_chat(input_text, model_size):
        updated_history = chat_with_flan(input_text, model_size)
        return updated_history, ""
    # Wire the submit event: refresh the conversation and clear the input box
    user_input.submit(update_chat, inputs=[user_input, model_selector], outputs=[chatbot_output, user_input])
    # Components are laid out top-to-bottom in their creation order
# Launch the Gradio app
interface.launch()
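# Optionally, interface.launch(share=True) would also expose a temporary public link when running locally.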