import streamlit as st
from PIL import Image
from polos.models import download_model, load_checkpoint


# Load the model.
@st.cache_resource()
def load_model():
    """Download the "polos" checkpoint and return the loaded model.

    Decorated with ``st.cache_resource`` so the download and checkpoint load
    are not repeated on every script rerun.
    """
    model_path = download_model("polos")
    return load_checkpoint(model_path)


model = load_model()

# Configure the Streamlit interface.
st.title('Polos Demo')

# Initialise the session state on the first run of a session.
if 'image' not in st.session_state:
    st.session_state.image = None
if 'user_input' not in st.session_state:
    st.session_state.user_input = ''
if 'user_refs' not in st.session_state:
    st.session_state.user_refs = [
        "there is a dog sitting on a couch with a person reaching out",
        "a dog laying on a couch with a person",
        'a dog is laying on a couch with a person',
    ]
if 'score' not in st.session_state:
    st.session_state.score = None


# Fetch the default image.
@st.cache_resource()
def get_default_image():
    """Return ``test.jpg`` as an RGB image, or a gray placeholder if missing.

    Returns:
        A ``PIL.Image.Image`` in RGB mode. When ``test.jpg`` does not exist,
        a 200x200 gray image is returned instead.
    """
    try:
        # convert() yields a fully loaded, independent image, so the
        # underlying file handle can be released as soon as it returns.
        with Image.open("test.jpg") as source:
            return source.convert("RGB")
    except FileNotFoundError:
        # Fallback image used when the default image cannot be found.
        return Image.new('RGB', (200, 200), color='gray')


default_image = get_default_image()

# Widget for uploading an image.
uploaded_image = st.file_uploader("Upload your image:", type=["jpg", "jpeg", "png"])
if uploaded_image is not None:
    st.session_state.image = Image.open(uploaded_image).convert("RGB")
elif st.session_state.image is None:
    st.session_state.image = default_image

# Always display the current image.
# NOTE(review): `use_column_width` is reported as deprecated in recent
# Streamlit releases in favour of `use_container_width`; kept unchanged here
# for compatibility with the installed version -- confirm before switching.
st.image(st.session_state.image, caption="Displayed Image", use_column_width=True)

# Input field for the reference sentences (one per line).
user_refs = st.text_area(
    "Enter reference sentences (separate each by a newline):",
    "\n".join(st.session_state.user_refs),
)
# Stored unfiltered so the text area round-trips exactly what the user typed;
# blank lines are dropped only when the prediction payload is built.
st.session_state.user_refs = user_refs.split("\n")

# Text field for the sentence to be scored.
user_input = st.text_input("Enter the input sentence:", value=st.session_state.user_input)
st.session_state.user_input = user_input

# Compute button.
if st.button('Compute'):
    candidate = st.session_state.user_input.strip()
    # Skip blank / whitespace-only lines (e.g. a trailing newline) so they are
    # not handed to the model as empty reference sentences.
    references = [ref.strip() for ref in st.session_state.user_refs if ref.strip()]

    if not candidate:
        # Clear any previous score so a stale value is not shown as if it
        # belonged to the current (empty) input.
        st.session_state.score = None
        st.warning("Please enter an input sentence before computing the score.")
    elif not references:
        st.session_state.score = None
        st.warning("Please enter at least one reference sentence.")
    else:
        # Prepare the single-sample batch.
        data = [
            {
                "img": st.session_state.image,
                "mt": candidate,
                "refs": references,
            }
        ]
        # Run the model prediction; the second returned value holds the
        # per-sample scores.
        _, scores = model.predict(data, batch_size=1, cuda=False)
        st.session_state.score = scores[0]

# Display the score.
if st.session_state.score is not None:
    st.metric(label="Score", value=f"{st.session_state.score:.5f}")