MikeTrizna commited on
Commit
c923f4c
·
1 Parent(s): 380651c

Initial Space commit

Browse files
app.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import os
3
+ import urllib
4
+ import fastai.vision.all as fai_vision
5
+ import numpy as np
6
+ from pathlib import Path
7
+ import pathlib
8
+ from PIL import Image
9
+ import platform
10
+ import altair as alt
11
+ import pandas as pd
12
+
13
+ def main():
14
+ st.title('Fish Masker and Classifier')
15
+
16
+ data_loader, segmenter = load_unet_model()
17
+ classification_model = load_classification_model()
18
+
19
+ st.markdown("Upload an Amazonian fish photo for masking.")
20
+ uploaded_image = st.file_uploader("", IMAGE_TYPES)
21
+ if uploaded_image:
22
+ image_data = uploaded_image.read()
23
+ st.markdown('## Original image')
24
+ st.image(image_data, use_column_width=True)
25
+
26
+ original_pil = Image.open(uploaded_image)
27
+
28
+ original_pil.save('original.jpg')
29
+
30
+ single_file = [Path('original.jpg')]
31
+ single_pil = Image.open(single_file[0])
32
+ input_dl = segmenter.dls.test_dl(single_file)
33
+ masks, _ = segmenter.get_preds(dl=input_dl)
34
+ masked_pil, percentage_fish = mask_fish_pil(single_pil, masks[0])
35
+
36
+ st.markdown('## Masked image')
37
+ st.markdown(f'**{percentage_fish:.1f}%** of pixels were labeled as "fish"')
38
+ st.image(masked_pil, use_column_width=True)
39
+
40
+ masked_pil.save('masked.jpg')
41
+
42
+ st.markdown('## Classification')
43
+
44
+ prediction = classification_model.predict('masked.jpg')
45
+ pred_chart = predictions_to_chart(prediction, classes = classification_model.dls.vocab)
46
+ st.altair_chart(pred_chart, use_container_width=True)
47
+
48
+
49
+ def mask_fish_pil(unmasked_fish, fastai_mask):
50
+ unmasked_np = np.array(unmasked_fish)
51
+ np_mask = fastai_mask.argmax(dim=0).numpy()
52
+ total_pixels = np_mask.size
53
+ fish_pixels = np.count_nonzero(np_mask)
54
+ percentage_fish = (fish_pixels / total_pixels) * 100
55
+ np_mask = (255 / np_mask.max() * (np_mask - np_mask.min())).astype(np.uint8)
56
+ np_mask = np.array(Image.fromarray(np_mask).resize(unmasked_np.shape[1::-1], Image.BILINEAR))
57
+ np_mask = np_mask.reshape(*np_mask.shape, 1) / 255
58
+ masked_fish_np = (unmasked_np * np_mask).astype(np.uint8)
59
+ masked_fish_pil = Image.fromarray(masked_fish_np)
60
+ return masked_fish_pil, percentage_fish
61
+
62
+ def predictions_to_chart(prediction, classes):
63
+ pred_rows = []
64
+ for i, conf in enumerate(list(prediction[2])):
65
+ pred_row = {'class': classes[i],
66
+ 'probability': round(float(conf) * 100,2)}
67
+ pred_rows.append(pred_row)
68
+ pred_df = pd.DataFrame(pred_rows)
69
+ pred_df.head()
70
+ top_probs = pred_df.sort_values('probability', ascending=False).head(4)
71
+ chart = (
72
+ alt.Chart(top_probs)
73
+ .mark_bar()
74
+ .encode(
75
+ x=alt.X("probability:Q", scale=alt.Scale(domain=(0, 100))),
76
+ y=alt.Y("class:N",
77
+ sort=alt.EncodingSortField(field="probability", order="descending"))
78
+ )
79
+ )
80
+ return chart
81
+
82
+ @st.cache(allow_output_mutation=True)
83
+ def load_unet_model():
84
+ data_loader = fai_vision.SegmentationDataLoaders.from_label_func(
85
+ path = Path("."),
86
+ bs = 1,
87
+ fnames = [Path('test_fish.jpg')],
88
+ label_func = lambda x: x,
89
+ codes = np.array(["Photo", "Masks"], dtype=str),
90
+ item_tfms = [fai_vision.Resize(256, method = 'squish'),],
91
+ batch_tfms = [fai_vision.IntToFloatTensor(div_mask = 255)],
92
+ valid_pct = 0.2, num_workers = 0)
93
+ segmenter = fai_vision.unet_learner(data_loader, fai_vision.resnet34)
94
+ segmenter.load('fish_mask_model')
95
+ return data_loader, segmenter
96
+
97
+ @st.cache(allow_output_mutation=True)
98
+ def load_classification_model():
99
+ plt = platform.system()
100
+
101
+ if plt == 'Linux' or plt == 'Darwin':
102
+ pathlib.WindowsPath = pathlib.PosixPath
103
+ inf_model = fai_vision.load_learner('models/fish_classification_model.pkl', cpu=True)
104
+
105
+ return inf_model
106
+
107
+ IMAGE_TYPES = ["png", "jpg","jpeg"]
108
+
109
+ if __name__ == "__main__":
110
+ main()
models/fish_classification_model.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2ac16550590dd60da201ce13e2f1b057d5343ef490db8663c463f8bbefef610e
3
+ size 179319095
models/fish_mask_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:29b8afc516eb9f19e99dc53e924839a7157ac241d13f0945aec4717574c7908a
3
+ size 494929527
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ streamlit==0.89
2
+ fastai==2.2
3
+ protobuf==3.20
4
+ altair
5
+ pandas
test_fish.jpg ADDED