import streamlit as st import torch from langchain import HuggingFacePipeline from langchain.chains import RetrievalQA from streamlit_extras.row import row if 'model' not in st.session_state: st.session_state['model'] = 0 if 'max_length' not in st.session_state: st.session_state['max_length'] = 0 if 'temperature' not in st.session_state: st.session_state['temperature'] = 0 if 'repetition_penalty' not in st.session_state: st.session_state['repetition_penalty'] = 0 def load_llm_model(max_length, temperature, repetition_penalty): # llm = HuggingFacePipeline.from_model_id(model_id= 'lmsys/fastchat-t5-3b-v1.0', # task= 'text2text-generation', # model_kwargs={ "device_map": "auto", # "load_in_8bit": True,"max_length": 256, "temperature": 0, # "repetition_penalty": 1.5}) llm = HuggingFacePipeline.from_model_id(model_id= 'lmsys/fastchat-t5-3b-v1.0', task= 'text2text-generation', model_kwargs={ "max_length": max_length, "temperature": temperature, "torch_dtype":torch.float32, "repetition_penalty": repetition_penalty}) return llm st.title("Model Download") # st.subheader("This page allows users to adjust some parameters of the model before downloading") st.subheader("ผู้ใช้สามารถปรับเลือกการตั้งค่าต่อไปนี้เพื่อทำการดาวน์โลดโมเดล") # model_row = row([2, 2, 2], vertical_align="bottom") # max_length = model_row.number_input("max_length", value = 256) # temperature = model_row.number_input("temperature", value = 0) # repetition_penalty = model_row.number_input("repetition_penalty", value = 1.3) max_length = st.number_input("max_length", value = 256, step = 128) st.caption(""" กำหนดจำนวนคำของโมเดลภาษา หากตั้งค่าน้อย โมเดลจะตอบสั้นและกระชับ ถ้าตั้งค่าให้มากๆ การตอบกลับอาจมีรายละเอียดมากขึ้น แต่ต้องระวังเพราะคำตอบที่ยาวเกินไปอาจไม่มีจุดโฟกัส """) st.divider() temperature = st.number_input("temperature", value = 0.0, step = 0.1, max_value = 1.0) st.caption(""" กำหนดความคิดสร้างสรรค์และความหลากหลายในการตอบของโมเดล ค่าที่ต่ำสุด 0 จะเป็นการตอบแบบมีการควบคุมสูงสุด โดย 1 จะมีความหลากหลายสูงสุด และสร้างสรรค์มากสุด แต่ความมีเหตุผลอาจลดลง """) st.divider() repetition_penalty = st.number_input("repetition_penalty", value = 1.3, step = 0.1, max_value = 2.0) st.caption(""" กำหนดให้โมเดลพยายามหลีกเลี่ยงการใช้คำหรือวลีเดียวกันซ้ำๆ ค่าที่สูงขึ้นจะทำให้โมเดลเลี่ยงการตอบโดยใช้คำ หรือวลีเดิมๆ """) load_model_button = st.button("ดาวน์โลดโมเดล") if load_model_button: st.session_state['max_length'] = max_length st.session_state['temperature'] = temperature st.session_state['repetition_penalty'] = repetition_penalty st.session_state['model'] = load_llm_model(max_length, temperature, repetition_penalty) st.write("⚠️ Please expect to wait **1 - 2 minutes ** for the application to download the 3-billion-parameter LLM") st.write('Successfully model loaded ✅') # st.write('Successfully mผ te['repetition_penalty']) # st.markdown(type(st.session_state['model']))