yuwd commited on
Commit
a005919
1 Parent(s): 709136d
Dockerfile CHANGED
@@ -1,4 +1,4 @@
1
- FROM python:3.9.2-slim
2
  RUN apt-get update && apt-get install -y gcc g++
3
  WORKDIR /app
4
 
 
1
+ FROM --platform=linux/amd64 python:3.9.2-slim
2
  RUN apt-get update && apt-get install -y gcc g++
3
  WORKDIR /app
4
 
README.md CHANGED
@@ -1,6 +1,6 @@
1
  ---
2
- title: Streamlit Docker Template
3
- emoji: 📉
4
  colorFrom: blue
5
  colorTo: green
6
  sdk: docker
@@ -8,8 +8,14 @@ app_port: 8501
8
  pinned: false
9
  ---
10
 
11
- ## 🧠 Streamlit Docker Template 🔎
 
 
 
 
 
12
 
13
- Streamlit Docker Template is a template for creating a Streamlit app with Docker and Hugging Face Spaces.
14
-
15
- Code from https://docs.streamlit.io/library/get-started/create-an-app
 
 
1
  ---
2
+ title: Polos Demo
3
+ emoji: 🌟
4
  colorFrom: blue
5
  colorTo: green
6
  sdk: docker
 
8
  pinned: false
9
  ---
10
 
11
+ ## Get Started on M1 Mac
12
+ ```bash
13
+ git submodule update --init --recursive
14
+ docker build . -t polos_demo
15
+ docker run -it -d -v `pwd`:/workspace -p 8080:8080 --platform linux/amd64 polos_demo
16
+ docker exec -it $process_id bash
17
 
18
+ root@28cb354f7609:~# sh install.sh
19
+ root@28cb354f7609:~# poetry run python test.py
20
+ root@28cb354f7609:~# poetry run streamlit run test.py --server.port 8080
21
+ ```
app.py CHANGED
@@ -2,7 +2,8 @@ import streamlit as st
2
  from PIL import Image
3
  from polos.models import download_model, load_checkpoint
4
 
5
- @st.cache(allow_output_mutation=True)
 
6
  def load_model():
7
  model_path = download_model("polos")
8
  model = load_checkpoint(model_path)
@@ -10,29 +11,67 @@ def load_model():
10
 
11
  model = load_model()
12
 
13
- default_image = Image.open("test.jpg").convert("RGB")
14
- default_refs = [
15
- "there is a dog sitting on a couch with a person reaching out",
16
- "a dog laying on a couch with a person",
17
- 'a dog is laying on a couch with a person'
18
- ]
19
-
20
- data = [
21
- {
22
- "img": default_image,
23
- "mt": "",
24
- "refs": default_refs
25
- }
26
- ]
27
-
28
  # Streamlitインターフェースの設定
29
  st.title('Polos Demo')
30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  # ユーザー入力のテキストフィールド
32
- user_input = st.text_input("Enter the input sentence:", '')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
 
34
- # 入力がある場合、モデルを使用してスコアを計算
35
- if user_input:
36
- data[0]['mt'] = user_input
37
- _, scores = model.predict(data, batch_size=1, cuda=False)
38
- st.write("Score:", scores)
 
2
  from PIL import Image
3
  from polos.models import download_model, load_checkpoint
4
 
5
+ # モデルのロード
6
+ @st.cache_resource()
7
  def load_model():
8
  model_path = download_model("polos")
9
  model = load_checkpoint(model_path)
 
11
 
12
  model = load_model()
13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  # Streamlitインターフェースの設定
15
  st.title('Polos Demo')
16
 
17
+ # セッションステートの初期化
18
+ if 'image' not in st.session_state:
19
+ st.session_state.image = None
20
+ if 'user_input' not in st.session_state:
21
+ st.session_state.user_input = ''
22
+ if 'user_refs' not in st.session_state:
23
+ st.session_state.user_refs = [
24
+ "there is a dog sitting on a couch with a person reaching out",
25
+ "a dog laying on a couch with a person",
26
+ 'a dog is laying on a couch with a person'
27
+ ]
28
+ if 'score' not in st.session_state:
29
+ st.session_state.score = None
30
+
31
+ # デフォルト画像の取得
32
+ @st.cache_resource()
33
+ def get_default_image():
34
+ try:
35
+ return Image.open("test.jpg").convert("RGB")
36
+ except FileNotFoundError:
37
+ return Image.new('RGB', (200, 200), color = 'gray') # デフォルト画像が見つからない場合の代替画像
38
+
39
+ default_image = get_default_image()
40
+
41
+ # 画像アップロードのためのウィジェット
42
+ uploaded_image = st.file_uploader("Upload your image:", type=["jpg", "jpeg", "png"])
43
+ if uploaded_image is not None:
44
+ st.session_state.image = Image.open(uploaded_image).convert("RGB")
45
+ elif st.session_state.image is None:
46
+ st.session_state.image = default_image
47
+
48
+ # 常に画像を表示
49
+ st.image(st.session_state.image, caption="Displayed Image", use_column_width=True)
50
+
51
+ # 参照文の入力フィールド
52
+ user_refs = st.text_area("Enter reference sentences (separate each by a newline):", "\n".join(st.session_state.user_refs))
53
+ st.session_state.user_refs = user_refs.split("\n")
54
+
55
  # ユーザー入力のテキストフィールド
56
+ user_input = st.text_input("Enter the input sentence:", value=st.session_state.user_input)
57
+ st.session_state.user_input = user_input
58
+
59
+ # Computeボタン
60
+ if st.button('Compute'):
61
+ # データの準備
62
+ data = [
63
+ {
64
+ "img": st.session_state.image,
65
+ "mt": st.session_state.user_input,
66
+ "refs": st.session_state.user_refs
67
+ }
68
+ ]
69
+
70
+ # モデル予測
71
+ if st.session_state.user_input:
72
+ _, scores = model.predict(data, batch_size=1, cuda=False)
73
+ st.session_state.score = scores[0]
74
 
75
+ # スコアの表示
76
+ if st.session_state.score is not None:
77
+ st.metric(label="Score", value=f"{st.session_state.score:.5f}")
 
 
polos/models/__init__.py CHANGED
@@ -19,13 +19,10 @@ str2model = {
19
  }
20
 
21
  def get_cache_folder():
22
- if "HOME" in os.environ:
23
- cache_directory = os.environ["HOME"] + "/.cache/torch/yuigawada/"
24
- if not os.path.exists(cache_directory):
25
- os.makedirs(cache_directory)
26
- return cache_directory
27
- else:
28
- raise Exception("HOME environment variable is not defined.")
29
 
30
  def download_model(model: str, saving_directory: str = None) -> ModelBase:
31
  """Function that loads pretrained models from AWS.
 
19
  }
20
 
21
  def get_cache_folder():
22
+ cache_directory = "./.cache/"
23
+ if not os.path.exists(cache_directory):
24
+ os.makedirs(cache_directory)
25
+ return cache_directory
 
 
 
26
 
27
  def download_model(model: str, saving_directory: str = None) -> ModelBase:
28
  """Function that loads pretrained models from AWS.
polos/models/__pycache__/__init__.cpython-39.pyc CHANGED
Binary files a/polos/models/__pycache__/__init__.cpython-39.pyc and b/polos/models/__pycache__/__init__.cpython-39.pyc differ
 
polos/models/encoders/__pycache__/xlmr.cpython-39.pyc CHANGED
Binary files a/polos/models/encoders/__pycache__/xlmr.cpython-39.pyc and b/polos/models/encoders/__pycache__/xlmr.cpython-39.pyc differ
 
polos/models/encoders/xlmr.py CHANGED
@@ -29,10 +29,7 @@ XLMR_LARGE_V0_MODEL_NAME = "xlmr.large.v0/model.pt"
29
  XLMR_BASE_V0_URL = "https://dl.fbaipublicfiles.com/fairseq/models/xlmr.base.v0.tar.gz"
30
  XLMR_BASE_V0_MODEL_NAME = "xlmr.base.v0/model.pt"
31
 
32
- if "HOME" in os.environ:
33
- saving_directory = os.environ["HOME"] + "/.cache/torch/yuigawada/"
34
- else:
35
- raise Exception("HOME environment variable is not defined.")
36
 
37
 
38
  class XLMREncoder(Encoder):
 
29
  XLMR_BASE_V0_URL = "https://dl.fbaipublicfiles.com/fairseq/models/xlmr.base.v0.tar.gz"
30
  XLMR_BASE_V0_MODEL_NAME = "xlmr.base.v0/model.pt"
31
 
32
+ saving_directory = "./.cache/"
 
 
 
33
 
34
 
35
  class XLMREncoder(Encoder):