File size: 1,003 Bytes
03f6091
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ad5defb
 
03f6091
 
 
 
 
 
 
ad5defb
03f6091
 
ad5defb
03f6091
 
ad5defb
03f6091
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import streamlit as st
from PIL import Image
from polos.models import download_model, load_checkpoint

# `st.cache` is deprecated in recent Streamlit releases; `st.cache_resource`
# is its replacement for shared, unserialisable objects such as ML models.
# Fall back to the legacy decorator so older Streamlit installs keep working.
_cache_model = (
    st.cache_resource
    if hasattr(st, "cache_resource")
    else st.cache(allow_output_mutation=True)
)


@_cache_model
def load_model():
    """Return the "polos" model, cached by Streamlit across script reruns.

    The value returned by ``download_model("polos")`` is handed to
    ``load_checkpoint`` and the loaded result is returned. Because the
    function is wrapped in a Streamlit cache decorator, the body is only
    executed on a cache miss; later calls receive the cached object itself
    (not a copy).
    """
    model_path = download_model("polos")
    return load_checkpoint(model_path)

# Obtain the model through the Streamlit-cached loader.
model = load_model()

# Built-in example: one RGB image and three reference captions for it.
default_image = Image.open("test.jpg").convert("RGB")
default_refs = [
    "there is a dog sitting on a couch with a person reaching out",
    "a dog laying on a couch with a person",
    "a dog is laying on a couch with a person",
]

# Single-sample batch passed to model.predict. "mt" starts empty and is
# overwritten with the user's sentence before prediction.
data = [dict(img=default_image, mt="", refs=default_refs)]

# Set up the Streamlit interface.
st.title("Polos Demo")

# Text field for the sentence the user wants scored.
user_input = st.text_input("Enter the input sentence:", "")

# Only run the model once the user has entered something.
if user_input:
    data[0]["mt"] = user_input
    # predict returns a pair; only the second element (scores) is displayed.
    _, scores = model.predict(data, batch_size=1, cuda=False)
    st.write("Score:", scores)