Spaces:
Runtime error
Runtime error
MikeTrizna
commited on
Commit
·
c923f4c
1
Parent(s):
380651c
Initial Space commit
Browse files- app.py +110 -0
- models/fish_classification_model.pkl +3 -0
- models/fish_mask_model.pth +3 -0
- requirements.txt +5 -0
- test_fish.jpg +0 -0
app.py
ADDED
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import os
|
3 |
+
import urllib
|
4 |
+
import fastai.vision.all as fai_vision
|
5 |
+
import numpy as np
|
6 |
+
from pathlib import Path
|
7 |
+
import pathlib
|
8 |
+
from PIL import Image
|
9 |
+
import platform
|
10 |
+
import altair as alt
|
11 |
+
import pandas as pd
|
12 |
+
|
13 |
+
def main():
|
14 |
+
st.title('Fish Masker and Classifier')
|
15 |
+
|
16 |
+
data_loader, segmenter = load_unet_model()
|
17 |
+
classification_model = load_classification_model()
|
18 |
+
|
19 |
+
st.markdown("Upload an Amazonian fish photo for masking.")
|
20 |
+
uploaded_image = st.file_uploader("", IMAGE_TYPES)
|
21 |
+
if uploaded_image:
|
22 |
+
image_data = uploaded_image.read()
|
23 |
+
st.markdown('## Original image')
|
24 |
+
st.image(image_data, use_column_width=True)
|
25 |
+
|
26 |
+
original_pil = Image.open(uploaded_image)
|
27 |
+
|
28 |
+
original_pil.save('original.jpg')
|
29 |
+
|
30 |
+
single_file = [Path('original.jpg')]
|
31 |
+
single_pil = Image.open(single_file[0])
|
32 |
+
input_dl = segmenter.dls.test_dl(single_file)
|
33 |
+
masks, _ = segmenter.get_preds(dl=input_dl)
|
34 |
+
masked_pil, percentage_fish = mask_fish_pil(single_pil, masks[0])
|
35 |
+
|
36 |
+
st.markdown('## Masked image')
|
37 |
+
st.markdown(f'**{percentage_fish:.1f}%** of pixels were labeled as "fish"')
|
38 |
+
st.image(masked_pil, use_column_width=True)
|
39 |
+
|
40 |
+
masked_pil.save('masked.jpg')
|
41 |
+
|
42 |
+
st.markdown('## Classification')
|
43 |
+
|
44 |
+
prediction = classification_model.predict('masked.jpg')
|
45 |
+
pred_chart = predictions_to_chart(prediction, classes = classification_model.dls.vocab)
|
46 |
+
st.altair_chart(pred_chart, use_container_width=True)
|
47 |
+
|
48 |
+
|
49 |
+
def mask_fish_pil(unmasked_fish, fastai_mask):
|
50 |
+
unmasked_np = np.array(unmasked_fish)
|
51 |
+
np_mask = fastai_mask.argmax(dim=0).numpy()
|
52 |
+
total_pixels = np_mask.size
|
53 |
+
fish_pixels = np.count_nonzero(np_mask)
|
54 |
+
percentage_fish = (fish_pixels / total_pixels) * 100
|
55 |
+
np_mask = (255 / np_mask.max() * (np_mask - np_mask.min())).astype(np.uint8)
|
56 |
+
np_mask = np.array(Image.fromarray(np_mask).resize(unmasked_np.shape[1::-1], Image.BILINEAR))
|
57 |
+
np_mask = np_mask.reshape(*np_mask.shape, 1) / 255
|
58 |
+
masked_fish_np = (unmasked_np * np_mask).astype(np.uint8)
|
59 |
+
masked_fish_pil = Image.fromarray(masked_fish_np)
|
60 |
+
return masked_fish_pil, percentage_fish
|
61 |
+
|
62 |
+
def predictions_to_chart(prediction, classes):
|
63 |
+
pred_rows = []
|
64 |
+
for i, conf in enumerate(list(prediction[2])):
|
65 |
+
pred_row = {'class': classes[i],
|
66 |
+
'probability': round(float(conf) * 100,2)}
|
67 |
+
pred_rows.append(pred_row)
|
68 |
+
pred_df = pd.DataFrame(pred_rows)
|
69 |
+
pred_df.head()
|
70 |
+
top_probs = pred_df.sort_values('probability', ascending=False).head(4)
|
71 |
+
chart = (
|
72 |
+
alt.Chart(top_probs)
|
73 |
+
.mark_bar()
|
74 |
+
.encode(
|
75 |
+
x=alt.X("probability:Q", scale=alt.Scale(domain=(0, 100))),
|
76 |
+
y=alt.Y("class:N",
|
77 |
+
sort=alt.EncodingSortField(field="probability", order="descending"))
|
78 |
+
)
|
79 |
+
)
|
80 |
+
return chart
|
81 |
+
|
82 |
+
@st.cache(allow_output_mutation=True)
|
83 |
+
def load_unet_model():
|
84 |
+
data_loader = fai_vision.SegmentationDataLoaders.from_label_func(
|
85 |
+
path = Path("."),
|
86 |
+
bs = 1,
|
87 |
+
fnames = [Path('test_fish.jpg')],
|
88 |
+
label_func = lambda x: x,
|
89 |
+
codes = np.array(["Photo", "Masks"], dtype=str),
|
90 |
+
item_tfms = [fai_vision.Resize(256, method = 'squish'),],
|
91 |
+
batch_tfms = [fai_vision.IntToFloatTensor(div_mask = 255)],
|
92 |
+
valid_pct = 0.2, num_workers = 0)
|
93 |
+
segmenter = fai_vision.unet_learner(data_loader, fai_vision.resnet34)
|
94 |
+
segmenter.load('fish_mask_model')
|
95 |
+
return data_loader, segmenter
|
96 |
+
|
97 |
+
@st.cache(allow_output_mutation=True)
|
98 |
+
def load_classification_model():
|
99 |
+
plt = platform.system()
|
100 |
+
|
101 |
+
if plt == 'Linux' or plt == 'Darwin':
|
102 |
+
pathlib.WindowsPath = pathlib.PosixPath
|
103 |
+
inf_model = fai_vision.load_learner('models/fish_classification_model.pkl', cpu=True)
|
104 |
+
|
105 |
+
return inf_model
|
106 |
+
|
107 |
+
IMAGE_TYPES = ["png", "jpg","jpeg"]
|
108 |
+
|
109 |
+
if __name__ == "__main__":
|
110 |
+
main()
|
models/fish_classification_model.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:2ac16550590dd60da201ce13e2f1b057d5343ef490db8663c463f8bbefef610e
|
3 |
+
size 179319095
|
models/fish_mask_model.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:29b8afc516eb9f19e99dc53e924839a7157ac241d13f0945aec4717574c7908a
|
3 |
+
size 494929527
|
requirements.txt
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
streamlit==0.89
|
2 |
+
fastai==2.2
|
3 |
+
protobuf==3.20
|
4 |
+
altair
|
5 |
+
pandas
|
test_fish.jpg
ADDED