File size: 3,591 Bytes
d1cb6f6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import os
import streamlit as st
from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments
import torch
from torch.utils.data import Dataset

# Module-level handles to the currently loaded language model and its
# tokenizer. Both start as None and are assigned by load_model();
# generate_response() treats None as "no model loaded yet".
# NOTE(review): Streamlit re-executes this script on user interaction, so
# these assignments presumably run again on each rerun -- confirm how the
# loaded model is expected to persist between interactions.
model = None
tokenizer = None

# Maps an exact user input string to a canned reply. generate_response()
# returns the stored reply instead of calling the model when the input is a
# key here. Nothing in the visible code adds entries to this dictionary.
user_instructions = {}

# Dataset that turns (prompt, corrected response) feedback pairs into
# causal-LM training examples.
class FeedbackDataset(Dataset):
    """Feedback pairs formatted for causal language-model fine-tuning.

    Each item is the prompt tokens followed by the target tokens. The labels
    have the same length as ``input_ids``; prompt positions are set to
    ``IGNORE_INDEX`` so that only the corrected response contributes to the
    loss. A causal LM computes its loss position-by-position, so
    ``input_ids`` and ``labels`` must always have identical shapes.

    Args:
        input_texts: Sequence of prompt strings.
        target_texts: Sequence of desired responses, parallel to
            ``input_texts``.
        tokenizer: Optional tokenizer exposing ``encode(text, ...)`` that
            returns a list of token ids. When omitted, the module-level
            ``tokenizer`` is looked up at item-access time.

    Raises:
        ValueError: If the two sequences differ in length.
    """

    # Label value skipped by the cross-entropy loss used in causal LMs.
    IGNORE_INDEX = -100

    def __init__(self, input_texts, target_texts, tokenizer=None):
        if len(input_texts) != len(target_texts):
            raise ValueError(
                "input_texts and target_texts must have the same length "
                f"(got {len(input_texts)} and {len(target_texts)})"
            )
        self.input_texts = input_texts
        self.target_texts = target_texts
        self.tokenizer = tokenizer

    def __len__(self):
        """Return the number of feedback pairs."""
        return len(self.input_texts)

    def _resolve_tokenizer(self):
        """Return the explicit tokenizer, else the module-level one."""
        if self.tokenizer is not None:
            return self.tokenizer
        # Inside this method the bare name refers to the module global.
        return tokenizer

    def __getitem__(self, idx):
        """Return ``{"input_ids", "labels"}`` as equal-length 1-D tensors."""
        tok = self._resolve_tokenizer()
        prompt_ids = list(tok.encode(self.input_texts[idx]))
        # The target continues the prompt, so it must not start with a
        # second beginning-of-sequence token.
        target_ids = list(
            tok.encode(self.target_texts[idx], add_special_tokens=False)
        )

        input_ids = torch.tensor(prompt_ids + target_ids, dtype=torch.long)
        labels = torch.tensor(
            [self.IGNORE_INDEX] * len(prompt_ids) + target_ids,
            dtype=torch.long,
        )
        return {"input_ids": input_ids, "labels": labels}

def load_model(model_name_or_path):
    """Load a causal LM and its tokenizer into the module-level globals.

    The tokenizer and model are first loaded into local variables and only
    published to the globals once both have loaded, so a failed load never
    leaves a new tokenizer paired with a previously loaded model.

    Args:
        model_name_or_path: Hub model identifier or local directory path.

    Returns:
        None. Progress, success, and failure are reported through Streamlit.
    """
    global model, tokenizer

    name = (model_name_or_path or "").strip()
    if not name:
        st.error("Please enter a model name or local path.")
        return

    st.write(f"Loading model from {name}...")

    try:
        new_tokenizer = AutoTokenizer.from_pretrained(name)
        new_model = AutoModelForCausalLM.from_pretrained(name)
    except (OSError, ValueError) as exc:
        # from_pretrained raises OSError for unknown ids/paths and
        # ValueError for configurations it cannot map to a model class.
        st.error(f"Could not load model '{name}': {exc}")
        return

    # Publish both together so the pair always belongs to the same checkpoint.
    tokenizer = new_tokenizer
    model = new_model

    st.success("Model loaded successfully!")

def generate_response(input_text):
    """Return a reply for ``input_text``.

    A canned reply from ``user_instructions`` takes precedence when the input
    matches a key exactly; otherwise text is sampled from the loaded model.

    Args:
        input_text: The user's message.

    Returns:
        The decoded text (prompt plus sampled continuation), the canned
        reply, or ``""`` when no model is loaded or the input is blank.
    """
    # Ensure model and tokenizer are loaded
    if model is None or tokenizer is None:
        st.error("Model is not loaded. Please load a model first.")
        return ""

    # A user-defined response overrides the model.
    if input_text in user_instructions:
        return user_instructions[input_text]

    # A blank prompt encodes to zero tokens, which generate() cannot extend.
    if not input_text or not input_text.strip():
        st.warning("Please enter a message.")
        return ""

    # Tokenize with an attention mask and move the tensors to the model's
    # device; Trainer may have moved the model to an accelerator.
    encoded = tokenizer(input_text, return_tensors="pt").to(model.device)

    # Some tokenizers (e.g. GPT-2) define no pad token; fall back to EOS.
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id

    # Fine-tuning leaves the model in train mode; disable dropout here.
    model.eval()
    with torch.no_grad():
        outputs = model.generate(
            **encoded,
            # Budget counts only new tokens, so long prompts still get a reply.
            max_new_tokens=100,
            num_return_sequences=1,
            do_sample=True,
            top_k=50,
            top_p=0.95,
            pad_token_id=pad_token_id,
        )

    # Decode and return the response
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

def train_on_feedback(input_text, correct_response):
    """Fine-tune the loaded model on a single corrected exchange.

    Builds a one-example ``FeedbackDataset`` and runs one epoch with
    ``Trainer``. The module-level ``model`` is updated in place.

    Args:
        input_text: The prompt that produced the unhelpful reply.
        correct_response: The reply the user wanted instead.

    Returns:
        None. Problems with preconditions are reported through Streamlit.
    """
    # Mirror generate_response(): refuse to run without a loaded model
    # rather than letting Trainer fail on model=None.
    if model is None or tokenizer is None:
        st.error("Model is not loaded. Please load a model first.")
        return

    # An empty prompt or target gives the loss nothing to learn from.
    if not input_text or not correct_response:
        st.warning("Both the original message and a corrected response are required.")
        return

    # Prepare dataset
    dataset = FeedbackDataset([input_text], [correct_response])

    # Training arguments
    training_args = TrainingArguments(
        output_dir="./feedback_model",
        num_train_epochs=1,
        per_device_train_batch_size=1,
        learning_rate=1e-5,
        logging_dir='./logs',
        logging_steps=10,
        save_steps=100,
    )

    # Trainer for the feedback loop
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
    )

    try:
        trainer.train()
    finally:
        # Trainer switches the model to train mode; restore inference mode
        # so later generation runs without dropout.
        model.eval()

def chat_interface():
    """Render the Streamlit chat page: model loading, chat, and feedback.

    Streamlit re-executes the whole script on every widget interaction, so
    anything that must outlive a single run is kept in ``st.session_state``:

    * the loaded model and tokenizer, which would otherwise be reset to
      ``None`` by the module-level assignments on each rerun;
    * the last exchange, because ``st.button`` is only True for the run in
      which it was clicked, so widgets nested under the "Send" button would
      vanish as soon as the user touched them.
    """
    global model, tokenizer

    state = st.session_state
    model_key = "chat_model"
    tokenizer_key = "chat_tokenizer"
    input_key = "chat_last_input"
    response_key = "chat_last_response"

    # Restore a previously loaded model after the rerun reset the globals.
    if model is None and state.get(model_key) is not None:
        model = state[model_key]
        tokenizer = state.get(tokenizer_key)

    st.title("🤖 Chat with AI")

    # Input for model name or path
    model_name_or_path = st.text_input("Enter model name or local path:", "gpt2")

    # Button to load the model
    if st.button("Load Model"):
        load_model(model_name_or_path)
        if model is not None and tokenizer is not None:
            state[model_key] = model
            state[tokenizer_key] = tokenizer

    st.write("---")

    # Chat input
    input_text = st.text_input("You:")

    if st.button("Send"):
        # Remember the exchange so it is still shown on later reruns.
        state[input_key] = input_text
        state[response_key] = generate_response(input_text)

    if response_key not in state:
        return

    st.write("AI:", state[response_key])

    # Without a loaded model there is nothing to fine-tune.
    if model is None or tokenizer is None:
        return

    # Feedback section (top level, so it survives the rerun a click triggers)
    feedback = st.radio("Was this response helpful?", ("Yes", "No"))
    if feedback != "No":
        return

    correct_response = st.text_input("What should the AI have said?")
    if st.button("Submit Feedback"):
        if not correct_response.strip():
            st.warning("Please enter the response the AI should have given.")
            return
        # Train on the prompt that produced the displayed reply, not on
        # whatever is currently typed in the chat box.
        train_on_feedback(state[input_key], correct_response)
        st.success("Feedback recorded. AI will improve based on this feedback.")

# Entry point: build the chat page when this file is executed as the main
# script; importing the module does not render any UI.
if __name__ == "__main__":
    chat_interface()