import os
import pathlib
from typing import List

import streamlit as st
from huggingface_hub import HfApi
from llama_index.core import (
    SimpleDirectoryReader,
    Document,
    VectorStoreIndex,
    StorageContext,
    Settings,
    load_index_from_storage,
)
from llama_index.core.base.llms.types import ChatMessage
from llama_index.core.memory import ChatMemoryBuffer

from models.embeddings import hf_embed_model, openai_embed_model
from models.llamaCustom import LlamaCustom
from models.llms import load_llm, integrated_llms
from utils.chatbox import show_previous_messages, show_chat_input
from utils.util import validate_openai_api_key

SAVE_DIR = "uploaded_files"
VECTOR_STORE_DIR = "vectorStores"
HF_REPO_ID = "zhtet/RegBotBeta"

Settings.embed_model = hf_embed_model

hf_api = HfApi()

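# Seed every session-state key the app reads later, so downstream code can
# assume they all exist.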
def init_session_state():
    if "llama_messages" not in st.session_state:
        st.session_state.llama_messages = [
            {"role": "assistant", "content": "How can I help you today?"}
        ]

    if "llama_chat_history" not in st.session_state:
        st.session_state.llama_chat_history = [
            ChatMessage.from_str(role="assistant", content="How can I help you today?")
        ]

    if "llama_custom" not in st.session_state:
        st.session_state.llama_custom = None

    if "openai_api_key" not in st.session_state:
        st.session_state.openai_api_key = ""

    if "replicate_api_token" not in st.session_state:
        st.session_state.replicate_api_token = ""

    if "hf_token" not in st.session_state:
        st.session_state.hf_token = ""


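# Build the vector index for the chosen file, or reload it from disk if it was
# persisted on a previous run (keyed by the filename with "." replaced by "_").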
def index_docs(
    filename: str,
) -> VectorStoreIndex:
    try:
        index_path = pathlib.Path(f"{VECTOR_STORE_DIR}/{filename.replace('.', '_')}")
        if index_path.exists():
            print("Loading index from storage ...")
            storage_context = StorageContext.from_defaults(persist_dir=str(index_path))
            index = load_index_from_storage(storage_context=storage_context)
        else:
            reader = SimpleDirectoryReader(input_files=[f"{SAVE_DIR}/{filename}"])
            docs = reader.load_data(show_progress=True)
            index = VectorStoreIndex.from_documents(
                documents=docs,
                show_progress=True,
            )
            # Persist to the same path checked above (previously a hardcoded
            # "vectorStores" string that could drift from VECTOR_STORE_DIR).
            index.storage_context.persist(persist_dir=str(index_path))
    except Exception as e:
        print(f"Error: {e}")
        raise
    return index


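# Ask for whichever credential the selected model source needs. OpenAI keys are
# validated before being stored; Replicate and Hugging Face tokens are stored
# and exported through the environment variables their clients read.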
def check_api_key(model_name: str, source: str):
    if source.startswith("openai"):
        if not st.session_state.openai_api_key:
            with st.expander("OpenAI API Key", expanded=True):
                openai_api_key = st.text_input(
                    label="Enter your OpenAI API Key:",
                    type="password",
                    help="Get your key from https://platform.openai.com/account/api-keys",
                    value=st.session_state.openai_api_key,
                )

                if openai_api_key:
                    # st.spinner is a context manager; using it inside a boolean
                    # condition never actually displayed the spinner.
                    with st.spinner("Validating OpenAI API Key ..."):
                        result = validate_openai_api_key(openai_api_key)
                    if result["status"] == "success":
                        st.session_state.openai_api_key = openai_api_key
                        st.success(result["message"])
                    else:
                        st.error(result["message"])
                        st.info("You can still select a different model to proceed.")
                        st.stop()

elif source.startswith("replicate"): |
|
if not st.session_state.replicate_api_token: |
|
with st.expander("Replicate API Token", expanded=True): |
|
replicate_api_token = st.text_input( |
|
label="Enter your Replicate API Token:", |
|
type="password", |
|
help="Get your key from https://replicate.ai/account", |
|
value=st.session_state.replicate_api_token, |
|
) |
|
|
|
|
|
|
|
if replicate_api_token: |
|
st.session_state.replicate_api_token = replicate_api_token |
|
|
|
os.environ["REPLICATE_API_TOKEN"] = replicate_api_token |
|
|
|
elif source.startswith("huggingface"): |
|
if not st.session_state.hf_token: |
|
with st.expander("Hugging Face Token", expanded=True): |
|
hf_token = st.text_input( |
|
label="Enter your Hugging Face Token:", |
|
type="password", |
|
help="Get your key from https://huggingface.co/settings/token", |
|
value=st.session_state.hf_token, |
|
) |
|
|
|
if hf_token: |
|
st.session_state.hf_token = hf_token |
|
|
|
os.environ["HF_TOKEN"] = hf_token |
|
|
|
|
|
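# Page setup. st.set_page_config must run before any other Streamlit UI call;
# touching st.session_state beforehand (in init_session_state) is allowed.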
init_session_state()

st.set_page_config(page_title="Llama", page_icon="🦙")

st.header("Llama Index with Custom LLM Demo")

tab1, tab2 = st.tabs(["Config", "Chat"])

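# Config tab: pick a model and a previously uploaded file, then load the model
# and build (or reload) the index on Submit.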
with tab1:
    selected_llm_name = st.selectbox(
        label="Select a model:",
        options=[f"{key} | {value}" for key, value in integrated_llms.items()],
    )
    # Options are formatted as "<model name> | <source>"; split on the first
    # "|" only, in case a model name ever contains one itself.
    model_name, source = selected_llm_name.split("|", maxsplit=1)

    check_api_key(model_name=model_name.strip(), source=source.strip())

    selected_file = st.selectbox(
        label="Choose a file to chat with: ", options=os.listdir(SAVE_DIR)
    )

if st.button("Submit", key="submit", help="Submit the form"): |
|
with st.status("Loading ...", expanded=True) as status: |
|
try: |
|
st.write("Loading Model ...") |
|
llama_llm = load_llm( |
|
model_name=model_name.strip(), source=source.strip() |
|
) |
|
if llama_llm is None: |
|
raise ValueError("Model not found!") |
|
Settings.llm = llama_llm |
|
|
|
st.write("Processing Data ...") |
|
index = index_docs(selected_file) |
|
|
|
st.write("Finishing Up ...") |
|
llama_custom = LlamaCustom(model_name=selected_llm_name, index=index) |
|
st.session_state.llama_custom = llama_custom |
|
|
|
status.update(label="Ready to query!", state="complete", expanded=False) |
|
except Exception as e: |
|
status.update(label="Error!", state="error", expanded=False) |
|
st.error(f"Error: {e}") |
|
st.stop() |
|
|
|
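# Chat tab: replay the stored conversation and accept new input against the
# model configured in the Config tab.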
with tab2:
    messages_container = st.container(height=300)
    show_previous_messages(framework="llama", messages_container=messages_container)
    show_chat_input(
        disabled=False,
        framework="llama",
        model=st.session_state.llama_custom,
        messages_container=messages_container,
    )

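    # Reset both the display messages and the LLM-facing ChatMessage history to
    # the initial greeting so the two stores stay in sync.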
    def clear_history():
        messages_container.empty()
        st.session_state.llama_messages = [
            {"role": "assistant", "content": "How can I help you today?"}
        ]

        st.session_state.llama_chat_history = [
            ChatMessage.from_str(role="assistant", content="How can I help you today?")
        ]

    if st.button("Clear Chat History"):
        clear_history()
        st.rerun()