import gradio as gr
from transformers import T5Tokenizer, T5ForConditionalGeneration
from langchain.memory import ConversationBufferMemory
import torch
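# Requires: gradio, transformers, torch, langchain, and sentencepiece
# (sentencepiece is needed by the slow T5Tokenizer); for example:
#   pip install gradio transformers torch langchain sentencepiece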

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load all three Flan-T5 models (small, base, large)
models = {
    "small": T5ForConditionalGeneration.from_pretrained("google/flan-t5-small").to(device),
    "base": T5ForConditionalGeneration.from_pretrained("google/flan-t5-base").to(device),
    "large": T5ForConditionalGeneration.from_pretrained("google/flan-t5-large").to(device)
}

# Load the tokenizer (same tokenizer for all models)
tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-base")

# Set up conversational memory using LangChain's ConversationBufferMemory
# (prefixes set to match the "User:"/"Assistant:" format used in the prompt below)
memory = ConversationBufferMemory(human_prefix="User", ai_prefix="Assistant")

# Define the chatbot function with memory and model size selection
def chat_with_flan(input_text, model_size):
    # Retrieve conversation history and append the current user input
    conversation_history = memory.load_memory_variables({})['history']
    
    # Combine the history with the current user input
    full_input = f"{conversation_history}\nUser: {input_text}\nAssistant:"
    
    # Tokenize the input and move it to the same device as the model
    input_ids = tokenizer.encode(full_input, return_tensors="pt").to(device)
    
    # Get the model based on the selected size
    model = models[model_size]
    
    # Generate the response from the model
    outputs = model.generate(input_ids, max_length=200, num_return_sequences=1)
    
    # Decode the model output
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    
    # Update the memory with the user input and model response
    memory.save_context({"input": input_text}, {"output": response})
    
    # Return the full transcript; lstrip() drops the leading newline on the first turn
    return (conversation_history + f"\nUser: {input_text}\nAssistant: {response}").lstrip()
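
# Example (illustrative only): calling chat_with_flan directly for a quick smoke
# test outside the Gradio UI; the prompt below is just a placeholder.
#   print(chat_with_flan("What is the capital of France?", "small"))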

# Set up the Gradio interface with the input box below the output box
with gr.Blocks() as interface:
    chatbot_output = gr.Textbox(label="Conversation", lines=15, placeholder="Chat history will appear here...", interactive=False)
    
    # Add the instruction message above the input box
    gr.Markdown("**Instructions:** Press `Enter` to submit, and `Shift + Enter` for a new line.")
    
    # Add a dropdown for selecting the model size (small, base, large)
    model_selector = gr.Dropdown(choices=["small", "base", "large"], value="base", label="Select Model Size")

    # Input box for the user
    user_input = gr.Textbox(label="Your Input", placeholder="Type your message here...", lines=2, show_label=True)
    
    # Define the function to update the chat based on selected model
    def update_chat(input_text, model_size):
        updated_history = chat_with_flan(input_text, model_size)
        return updated_history, ""

    # Submit when pressing Enter
    user_input.submit(update_chat, inputs=[user_input, model_selector], outputs=[chatbot_output, user_input])

# Launch the Gradio app
interface.launch()