"""Streamlit chat app: talk to a causal LM and fine-tune it on user corrections."""

import streamlit as st
import torch
from torch.utils.data import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments

model = None
tokenizer = None

# Corrections supplied by users, keyed by the exact prompt text.
user_instructions = {}


class FeedbackDataset(Dataset):
    """Wraps (prompt, correction) pairs for causal-LM fine-tuning."""

    def __init__(self, input_texts, target_texts):
        self.input_texts = input_texts
        self.target_texts = target_texts

    def __len__(self):
        return len(self.input_texts)

    def __getitem__(self, idx):
        # For causal-LM training, labels must align token-for-token with
        # input_ids, so the prompt and target are concatenated into a single
        # sequence instead of being encoded separately.
        text = self.input_texts[idx] + tokenizer.eos_token + self.target_texts[idx]
        input_ids = tokenizer.encode(text, return_tensors="pt").squeeze(0)
        return {"input_ids": input_ids, "labels": input_ids.clone()}


def load_model(model_name_or_path):
    global model, tokenizer

    st.write(f"Loading model from {model_name_or_path}...")

    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
    model = AutoModelForCausalLM.from_pretrained(model_name_or_path)

    # GPT-2-style tokenizers ship without a pad token; reuse EOS so generation
    # and batching do not complain about a missing pad_token_id.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    st.success("Model loaded successfully!")


def generate_response(input_text):
    if model is None or tokenizer is None:
        st.error("Model is not loaded. Please load a model first.")
        return ""

    # Serve a stored correction verbatim if the user has already supplied a
    # better answer for this exact prompt.
    if input_text in user_instructions:
        return user_instructions[input_text]

    inputs = tokenizer.encode(input_text, return_tensors="pt")

    with torch.no_grad():
        outputs = model.generate(
            inputs,
            max_length=100,
            num_return_sequences=1,
            do_sample=True,
            top_k=50,
            top_p=0.95,
            pad_token_id=tokenizer.pad_token_id,
        )

    return tokenizer.decode(outputs[0], skip_special_tokens=True)


def train_on_feedback(input_text, correct_response):
    # Cache the correction so it is served immediately on the next ask, even
    # before a single fine-tuning step has any visible effect.
    user_instructions[input_text] = correct_response

    dataset = FeedbackDataset([input_text], [correct_response])

    training_args = TrainingArguments(
        output_dir="./feedback_model",
        num_train_epochs=1,
        per_device_train_batch_size=1,
        learning_rate=1e-5,
        logging_dir="./logs",
        logging_steps=10,
        save_steps=100,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
    )

    trainer.train()


def chat_interface():
    st.title("🤖 Chat with AI")

    model_name_or_path = st.text_input("Enter model name or local path:", "gpt2")

    if st.button("Load Model"):
        load_model(model_name_or_path)

    st.write("---")

    input_text = st.text_input("You:")

    # Streamlit reruns the whole script on every widget interaction, so the
    # last exchange is kept in session_state; otherwise clicking the feedback
    # widgets would discard the response and "Submit Feedback" could never fire.
    if st.button("Send"):
        st.session_state["last_input"] = input_text
        st.session_state["last_response"] = generate_response(input_text)

    if "last_response" in st.session_state:
        st.write("AI:", st.session_state["last_response"])

        feedback = st.radio("Was this response helpful?", ("Yes", "No"))

        if feedback == "No":
            correct_response = st.text_input("What should the AI have said?")
            if st.button("Submit Feedback"):
                train_on_feedback(st.session_state["last_input"], correct_response)
                st.success("Feedback recorded. AI will improve based on this feedback.")


if __name__ == "__main__":
    chat_interface()
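
# Usage (assuming this file is saved as app.py; the filename is not given in
# the original):
#   streamlit run app.py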