import streamlit as st
from transformers import (
    T5ForConditionalGeneration,
    T5Tokenizer,
    pipeline,
    AutoTokenizer,
    AutoModelForCausalLM
)
import torch

# ----- Streamlit page config -----
st.set_page_config(page_title="Chat", layout="wide")

# ----- Sidebar: Model controls -----
st.sidebar.title("Model Controls")
model_options = {
    "1": "karthikeyan-r/calculation_model_11k",
    "2": "karthikeyan-r/slm-custom-model_6k"
}
model_choice = st.sidebar.selectbox(
    "Select Model",
    options=list(model_options.values())
)
load_model_button = st.sidebar.button("Load Model")
clear_conversation_button = st.sidebar.button("Clear Conversation")
clear_model_button = st.sidebar.button("Clear Model")

# ----- Session states -----
if "model" not in st.session_state:
    st.session_state["model"] = None
if "tokenizer" not in st.session_state:
    st.session_state["tokenizer"] = None
if "qa_pipeline" not in st.session_state:
    st.session_state["qa_pipeline"] = None
if "conversation" not in st.session_state:
    st.session_state["conversation"] = []
# Track which model is currently loaded so switching the selectbox
# and clicking "Load Model" actually triggers a reload.
if "model_name" not in st.session_state:
    st.session_state["model_name"] = None

# ----- Load Model -----
def load_model():
    # Reload when nothing is cached yet or a different model was selected.
    if st.session_state["model"] is None or st.session_state["model_name"] != model_choice:
        with st.spinner("Loading model..."):
            try:
                if model_choice == model_options["1"]:
                    # Load the calculation model (causal LM)
                    tokenizer = AutoTokenizer.from_pretrained(model_choice, cache_dir="./model_cache")
                    model = AutoModelForCausalLM.from_pretrained(model_choice, cache_dir="./model_cache")

                    # Add special tokens if needed, resizing the embedding
                    # matrix to match the enlarged vocabulary.
                    if tokenizer.pad_token is None:
                        tokenizer.add_special_tokens({"pad_token": "[PAD]"})
                        model.resize_token_embeddings(len(tokenizer))
                    if tokenizer.eos_token is None:
                        tokenizer.add_special_tokens({"eos_token": "[EOS]"})
                        model.resize_token_embeddings(len(tokenizer))
                    model.config.pad_token_id = tokenizer.pad_token_id
                    model.config.eos_token_id = tokenizer.eos_token_id

                    st.session_state["model"] = model
                    st.session_state["tokenizer"] = tokenizer
                    st.session_state["qa_pipeline"] = None  # Not needed for calculation model
                elif model_choice == model_options["2"]:
                    # Load the T5 model for general QA
                    device = 0 if torch.cuda.is_available() else -1
                    model = T5ForConditionalGeneration.from_pretrained(model_choice, cache_dir="./model_cache")
                    tokenizer = T5Tokenizer.from_pretrained(model_choice, cache_dir="./model_cache")
                    qa_pipe = pipeline(
                        "text2text-generation",
                        model=model,
                        tokenizer=tokenizer,
                        device=device
                    )
                    st.session_state["model"] = model
                    st.session_state["tokenizer"] = tokenizer
                    st.session_state["qa_pipeline"] = qa_pipe

                st.session_state["model_name"] = model_choice
                st.success("Model loaded successfully and ready!")
            except Exception as e:
                st.error(f"Error loading model: {e}")

if load_model_button:
    load_model()

# ----- Clear Model -----
if clear_model_button:
    st.session_state["model"] = None
    st.session_state["tokenizer"] = None
    st.session_state["qa_pipeline"] = None
    st.session_state["model_name"] = None
    st.success("Model cleared.")

# ----- Clear Conversation -----
if clear_conversation_button:
    st.session_state["conversation"] = []
    st.success("Conversation cleared.")

# ----- Title -----
st.title("Chat Conversation UI")

# ----- User Input and Processing -----
user_input = st.chat_input("Enter your query:")
if user_input:
    # Save user input
    st.session_state["conversation"].append({
        "role": "user",
        "content": user_input
    })

    # Generate a response with whichever model is loaded
    if st.session_state["qa_pipeline"]:
        try:
            response = st.session_state["qa_pipeline"](f"Q: {user_input}", max_length=250)
            answer = response[0]["generated_text"]
        except Exception as e:
            answer = f"Error: {str(e)}"
    elif st.session_state["model"] and model_choice == model_options["1"]:
        try:
            tokenizer = st.session_state["tokenizer"]
            model = st.session_state["model"]
            inputs = tokenizer(
                f"Input: {user_input}\nOutput:",
                return_tensors="pt",
                padding=True,
                truncation=True
            )
            # Pass the attention mask so padded positions are ignored
            # during generation.
            output = model.generate(
                inputs.input_ids,
                attention_mask=inputs.attention_mask,
                max_length=250,
                pad_token_id=tokenizer.pad_token_id
            )
            # Keep only the text generated after the "Output:" marker.
            answer = tokenizer.decode(output[0], skip_special_tokens=True).split("Output:")[-1].strip()
        except Exception as e:
            answer = f"Error: {str(e)}"
    else:
        answer = "No model is loaded. Please select and load a model."

    # Save assistant response
    st.session_state["conversation"].append({
        "role": "assistant",
        "content": answer
    })

# Display conversation
for message in st.session_state["conversation"]:
    with st.chat_message(message["role"]):
        st.write(message["content"])
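
# ----- Usage -----
# A minimal sketch of how to run this app locally, assuming the file is
# saved as app.py (the filename is an assumption, not part of the source):
#
#   pip install streamlit transformers torch sentencepiece
#   streamlit run app.py
#
# sentencepiece is listed because T5Tokenizer depends on it; the first
# "Load Model" click downloads weights into ./model_cache, so later loads
# are served from that local cache.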