from llama_index.llms.huggingface import HuggingFaceInferenceAPI
from llama_index.llms.openai import OpenAI
from llama_index.llms.replicate import Replicate

from dotenv import load_dotenv
import os
import streamlit as st

# Load environment variables (e.g. default API keys) from a local .env file.
load_dotenv()

# Models integrated into the app, mapped to the provider that serves them.
integrated_llms = {
    "gpt-3.5-turbo-0125": "openai",
    "meta/llama-2-13b-chat": "replicate",
    "mistralai/Mistral-7B-Instruct-v0.2": "huggingface",
}


def load_llm(model_name: str, source: str = "huggingface"):
    """Instantiate the selected LLM, or return None if it is not integrated."""
    print("model_name:", model_name, "source:", source)
    if integrated_llms.get(model_name) is None:
        return None
    try:
        if source.startswith("openai"):
            # OpenAI-hosted model; the key comes from Streamlit session state.
            return OpenAI(
                model=model_name,
                api_key=st.session_state.openai_api_key,
            )
        elif source.startswith("replicate"):
            # The Replicate client reads its token from the environment, so
            # export the user-supplied token before constructing the LLM.
            os.environ["REPLICATE_API_TOKEN"] = st.session_state.replicate_api_token
            return Replicate(
                model=model_name,
                is_chat_model=True,
                additional_kwargs={"max_new_tokens": 250},
            )
        elif source.startswith("huggingface"):
            # Query the model remotely via the Hugging Face Inference API.
            return HuggingFaceInferenceAPI(
                model_name=model_name,
                token=st.session_state.hf_token,
            )
    except Exception as e:
        # Log construction errors (bad key, unknown model) and return None.
        print(e)
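

# --- Usage sketch (illustrative) ---
# A minimal example of how load_llm might be wired into the Streamlit app.
# The selectbox UI, the prompt box, and the env-var fallbacks for the
# session-state keys are assumptions, not part of the original script.
for key, env_var in [
    ("openai_api_key", "OPENAI_API_KEY"),
    ("replicate_api_token", "REPLICATE_API_TOKEN"),
    ("hf_token", "HF_TOKEN"),
]:
    if key not in st.session_state:
        st.session_state[key] = os.getenv(env_var, "")

selected_model = st.selectbox("Model", list(integrated_llms.keys()))
llm = load_llm(selected_model, source=integrated_llms[selected_model])
if llm is None:
    st.warning("Could not load the selected model.")
else:
    prompt = st.text_input("Prompt")
    if prompt:
        # All llama_index LLMs expose complete(); the response carries .text.
        st.write(llm.complete(prompt).text)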