from streamlit.logger import get_logger

from models.model_seeds import seeds
from models.openai.finetuned_models import finetuned_models, get_finetuned_chain
from models.openai.role_models import get_role_chain, get_template_role_models
from models.databricks.scenario_sim_biz import get_databricks_biz_chain
from models.databricks.texter_sim_llm import get_databricks_chain

logger = get_logger(__name__)

# Language names translated to the short codes passed to the Databricks-backed
# chains; any other value is forwarded unchanged.
_LANGUAGE_CODES = {"English": "en", "Spanish": "es"}

# Key in `seeds` used when the requested issue has no seed of its own.
_DEFAULT_SEED_KEY = "GCT"


def _language_code(language):
    """Return the short code for *language* ("English" -> "en", "Spanish" -> "es").

    Any value not in the mapping is returned unchanged.
    """
    return _LANGUAGE_CODES.get(language, language)


def get_chain(issue, language, source, memory, temperature, texter_name=""):
    """Build the conversation chain for the requested model *source*.

    Args:
        issue: Issue/scenario key. Used to look up the fine-tuned model and the
            role-model seed, and forwarded to the Databricks chain factories.
        language: Language name, e.g. "English" or "Spanish". Converted to a
            short code ("en"/"es") for the CTL_* sources only.
        source: Exactly one of "OA_finetuned", "OA_rolemodel", "CTL_llama2",
            "CTL_llama3".
        memory: Conversation memory forwarded to the chain factory.
        temperature: Sampling temperature forwarded to the chain factory.
        texter_name: Forwarded to the role-model template builder and to the
            CTL_llama3 chain factory.

    Returns:
        The chain returned by the matching factory, or None when *source* is
        not recognised (a warning is logged in that case).

    Raises:
        KeyError: for "OA_finetuned" when `finetuned_models` has no entry for
            f"{issue}-{language}"; for "OA_rolemodel" when neither *issue* nor
            the default seed key is present in `seeds`.
    """
    # Exact comparisons: the previous `source in ("OA_finetuned")` tested a
    # plain string (no trailing comma, so not a tuple), i.e. a substring
    # match that accepted values like "" or "CTL_llama".
    if source == "OA_finetuned":
        OA_engine = finetuned_models[f"{issue}-{language}"]
        return get_finetuned_chain(OA_engine, memory, temperature)

    if source == "OA_rolemodel":
        # Fall back to the default seed *entry*. Falling back to the bare key
        # string made the ['prompt'] lookup raise TypeError.
        seed_entry = seeds[issue] if issue in seeds else seeds[_DEFAULT_SEED_KEY]
        seed = seed_entry['prompt']
        template = get_template_role_models(issue, language, texter_name=texter_name, seed=seed)
        return get_role_chain(template, memory, temperature)

    if source == "CTL_llama2":
        return get_databricks_biz_chain(source, issue, _language_code(language), memory, temperature)

    if source == "CTL_llama3":
        return get_databricks_chain(
            source, issue, _language_code(language), memory, temperature, texter_name=texter_name
        )

    logger.warning("Unknown model source %r; no chain built", source)
    return None


def custom_chain_predict(llm_chain, input, stop):
    """Run *llm_chain* once and record the exchange in its chat memory.

    Calls the chain's `prep_inputs`, `_validate_inputs`, `_call` and
    `_validate_outputs` steps directly. It then appends the user input and
    each generated response to `llm_chain.memory.chat_memory`.

    Args:
        llm_chain: The chain to run; must expose `memory.chat_memory`.
        input: User message text. The name shadows the builtin but is kept
            because callers may pass it by keyword.
        stop: Stop sequence(s), passed to the chain under the "stop" key.

    Returns:
        `outputs[llm_chain.output_key]` exactly as produced by the chain.
    """
    inputs = llm_chain.prep_inputs({"input": input, "stop": stop})
    llm_chain._validate_inputs(inputs)
    outputs = llm_chain._call(inputs)
    llm_chain._validate_outputs(outputs)
    prediction = outputs[llm_chain.output_key]
    # A plain string result would otherwise be iterated character by
    # character, storing one AI message per character.
    responses = [prediction] if isinstance(prediction, str) else prediction
    llm_chain.memory.chat_memory.add_user_message(inputs['input'])
    for response in responses:
        llm_chain.memory.chat_memory.add_ai_message(response)
    return prediction