WoodLB commited on
Commit
e3b75a8
·
1 Parent(s): 1afa5e7
Files changed (1) hide show
  1. app.py +263 -41
app.py CHANGED
@@ -1,42 +1,264 @@
1
  import streamlit as st
2
- import pandas as pd
3
- import numpy as np
4
-
5
- st.title('Uber pickups in NYC')
6
-
7
- DATE_COLUMN = 'date/time'
8
- DATA_URL = ('https://s3-us-west-2.amazonaws.com/'
9
- 'streamlit-demo-data/uber-raw-data-sep14.csv.gz')
10
-
11
- @st.cache_data
12
- def load_data(nrows):
13
- data = pd.read_csv(DATA_URL, nrows=nrows)
14
- lowercase = lambda x: str(x).lower()
15
- data.rename(lowercase, axis='columns', inplace=True)
16
- data[DATE_COLUMN] = pd.to_datetime(data[DATE_COLUMN])
17
- return data
18
-
19
- data_load_state = st.text('Loading data...')
20
- data = load_data(10000)
21
- data_load_state.text("Done! (using st.cache)")
22
-
23
- if st.checkbox('Show raw data'):
24
- st.subheader('Raw data')
25
- st.write(data)
26
-
27
- st.subheader('Number of pickups by hour')
28
- hist_values = np.histogram(data[DATE_COLUMN].dt.hour, bins=24, range=(0,24))[0]
29
- st.bar_chart(hist_values)
30
-
31
- # Some number in the range 0-23
32
- hour_to_filter = st.slider('hour', 0, 23, 17)
33
- filtered_data = data[data[DATE_COLUMN].dt.hour == hour_to_filter]
34
-
35
- st.subheader('Map of all pickups at %s:00' % hour_to_filter)
36
- st.map(filtered_data)
37
-
38
- uploaded_file = st.file_uploader("Choose a file")
39
- if uploaded_file is not None:
40
- st.write(uploaded_file.name)
41
- bytes_data = uploaded_file.getvalue()
42
- st.write(len(bytes_data), "bytes")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
+
3
+ x = st.slider("Select a value")
4
+ st.write(x, "squared is", x * x)
5
+
6
+ # -*- coding: utf-8 -*-
7
+ """Accelerator_Model_Training_Notebook.ipynb
8
+
9
+ Automatically generated by Colaboratory.
10
+
11
+ Original file is located at
12
+ https://colab.research.google.com/drive/1CSyAE9DhwGTl7bLaSoo7QSyMuoEqJpCj
13
+
14
+ ##This is the Image Classification Model Training Accelerator Notebook
15
+
16
+ In this notebook, you will input your labelbox API Key, the Model Run ID and Ontology ID associated with the dataset you created using the labelbox platform.
17
+
18
+ Please note this Notebook will run through given you have followed the beginning of the accelerator tutorial and set up a project that labels **images as one option of a radio classification list**.
19
+
20
+ label names must be lower case.
21
+
22
+ Inout your API_Key, Ontology_ID, and Model_Run_ID
23
+ """
24
+ from pydantic import PydanticUserError
25
+ def train_and_inference(api_key, ontology_id, model_run_id):
26
+ st.write('thisisstarting')
27
+ api_key = api_key # insert Labelbox API key
28
+ ontology_id = ontology_id # get the ontology ID from the Settings tab at the top left of your model run
29
+ model_run_id = model_run_id #get the model run ID from the settings gear icon on the right side of your Model Run
30
+ st.write('1')
31
+ import pydantic
32
+ st.write(pydantic.__version__)
33
+
34
+ import numpy as np
35
+ st.write('2')
36
+ import tensorflow as tf
37
+ st.write('3')
38
+ from tensorflow.keras import layers
39
+ st.write('4')
40
+ from tensorflow.keras.models import Sequential
41
+ st.write('5')
42
+ from tensorflow.keras.preprocessing.image import ImageDataGenerator
43
+ st.write('6')
44
+ import os
45
+ st.write('7')
46
+ import labelbox
47
+ st.write('zat')
48
+ from labelbox import Client
49
+ st.write('8')
50
+ from labelbox import (
51
+ Label, ImageData,
52
+ Radio,
53
+ ClassificationAnnotation, ClassificationAnswer
54
+ )
55
+ st.write('9')
56
+ import pandas as pd
57
+ import shutil
58
+
59
+ import json
60
+ import uuid
61
+ import time
62
+ import requests
63
+ st.write('madeithrhougtheimports')
64
+
65
+ """Connect to labelbox client
66
+ Define Model Variables
67
+ """
68
+
69
+ client = Client(api_key)
70
+ EPOCHS = 10
71
+
72
+ """#Setup Training
73
+
74
+ Export Classifications from Model Run
75
+ """
76
+
77
+ model_run = client.get_model_run(model_run_id)
78
+
79
+ client.enable_experimental = True
80
+ data_json = model_run.export_labels(download=True)
81
+ print(data_json)
82
+
83
+ """Separate datarows into folders."""
84
+
85
+ import requests
86
+ import os
87
+
88
+ def download_and_save_image(url, destination_folder, filename):
89
+ if not os.path.exists(destination_folder):
90
+ os.makedirs(destination_folder)
91
+
92
+ response = requests.get(url, stream=True)
93
+ response.raise_for_status()
94
+
95
+ with open(os.path.join(destination_folder, filename), 'wb') as file:
96
+ for chunk in response.iter_content(8192):
97
+ file.write(chunk)
98
+
99
+ BASE_DIR = 'dataset'
100
+
101
+ for entry in data_json:
102
+ data_split = entry['Data Split']
103
+ if data_split not in ['training', 'validation']: # we are skipping 'test' for now
104
+ continue
105
+
106
+ image_url = entry['Labeled Data']
107
+ label = entry['Label']['classifications'][0]['answer']['value']
108
+
109
+ destination_folder = os.path.join(BASE_DIR, data_split, label)
110
+ filename = os.path.basename(image_url)
111
+
112
+ download_and_save_image(image_url, destination_folder, filename)
113
+
114
+ """#Train Model"""
115
+
116
+ import tensorflow as tf
117
+ from tensorflow.keras.preprocessing.image import ImageDataGenerator
118
+ from tensorflow.keras.applications import MobileNetV2
119
+ from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
120
+ from tensorflow.keras.models import Model
121
+ from tensorflow.keras.optimizers import Adam
122
+
123
+ TRAIN_DIR = 'dataset/training'
124
+ VALIDATION_DIR = 'dataset/validation'
125
+ IMG_HEIGHT, IMG_WIDTH = 224, 224 # default size for MobileNetV2
126
+ BATCH_SIZE = 32
127
+
128
+ train_datagen = ImageDataGenerator(
129
+ rescale=1./255,
130
+ rotation_range=20,
131
+ width_shift_range=0.2,
132
+ height_shift_range=0.2,
133
+ shear_range=0.2,
134
+ zoom_range=0.2,
135
+ horizontal_flip=True,
136
+ fill_mode='nearest'
137
+ )
138
+
139
+ validation_datagen = ImageDataGenerator(rescale=1./255)
140
+
141
+ train_ds = train_datagen.flow_from_directory(
142
+ TRAIN_DIR,
143
+ target_size=(IMG_HEIGHT, IMG_WIDTH),
144
+ batch_size=BATCH_SIZE,
145
+ class_mode='categorical'
146
+ )
147
+
148
+ validation_ds = validation_datagen.flow_from_directory(
149
+ VALIDATION_DIR,
150
+ target_size=(IMG_HEIGHT, IMG_WIDTH),
151
+ batch_size=BATCH_SIZE,
152
+ class_mode='categorical'
153
+ )
154
+
155
+ base_model = MobileNetV2(input_shape=(IMG_HEIGHT, IMG_WIDTH, 3),
156
+ include_top=False,
157
+ weights='imagenet')
158
+
159
+ # Freeze the base model
160
+ for layer in base_model.layers:
161
+ layer.trainable = False
162
+
163
+ # Create custom classification head
164
+ x = base_model.output
165
+ x = GlobalAveragePooling2D()(x)
166
+ x = Dense(1024, activation='relu')(x)
167
+ predictions = Dense(train_ds.num_classes, activation='softmax')(x)
168
+
169
+ model = Model(inputs=base_model.input, outputs=predictions)
170
+
171
+ model.compile(optimizer=Adam(learning_rate=0.0001),
172
+ loss='categorical_crossentropy',
173
+ metrics=['accuracy'])
174
+
175
+
176
+ history = model.fit(
177
+ train_ds,
178
+ validation_data=validation_ds,
179
+ epochs=EPOCHS
180
+ )
181
+
182
+ """#Run Inference on Model run Datarows"""
183
+
184
+ import numpy as np
185
+ import requests
186
+ from tensorflow.keras.preprocessing import image
187
+ from PIL import Image
188
+ from io import BytesIO
189
+ # Fetch the image from the URL
190
+ def load_image_from_url(img_url, target_size=(224, 224)):
191
+ response = requests.get(img_url)
192
+ img = Image.open(BytesIO(response.content))
193
+ img = img.resize(target_size)
194
+ img_array = image.img_to_array(img)
195
+ return np.expand_dims(img_array, axis=0)
196
+ def make_prediction(img_url):
197
+ # Image URL
198
+ img_url = img_url
199
+
200
+ # Load and preprocess the image
201
+ img_data = load_image_from_url(img_url)
202
+ img_data = img_data / 255.0 # Normalize the image data to [0,1]
203
+
204
+ # Make predictions
205
+ predictions = model.predict(img_data)
206
+ predicted_class = np.argmax(predictions[0])
207
+
208
+ # Retrieve the confidence score (probability) for the predicted class
209
+ confidence = predictions[0][predicted_class]
210
+
211
+ # Map the predicted class index to its corresponding label
212
+ class_map = train_ds.class_indices
213
+ inverse_map = {v: k for k, v in class_map.items()}
214
+ predicted_label = inverse_map[predicted_class]
215
+
216
+ return predicted_label, confidence
217
+
218
+ from tensorflow.errors import InvalidArgumentError # Add this import
219
+ ontology = client.get_ontology(ontology_id)
220
+ label_list = []
221
+ for datarow in model_run.export_labels(download=True):
222
+ try:
223
+ label, confidence = make_prediction(datarow['Labeled Data'])
224
+ except InvalidArgumentError as e:
225
+ print(f"InvalidArgumentError: {e}. Skipping this data row.")
226
+ continue # Skip to the next datarow if an exception occurs
227
+ my_checklist_answer = ClassificationAnswer(
228
+ name = label,
229
+ confidence=confidence)
230
+ checklist_prediction = ClassificationAnnotation(
231
+ name=ontology.classifications()[0].instructions,
232
+ value=Radio(
233
+ answer = my_checklist_answer
234
+ ))
235
+ # print(datarow["DataRow ID"])
236
+ label_prediction = Label(
237
+ data=ImageData(uid=datarow['DataRow ID']),
238
+ annotations = [checklist_prediction])
239
+ label_list.append(label_prediction)
240
+
241
+ prediction_import = model_run.add_predictions(
242
+ name="prediction_upload_job"+str(uuid.uuid4()),
243
+ predictions=label_list)
244
+
245
+ prediction_import.wait_until_done()
246
+
247
+ st.write(prediction_import.errors == [])
248
+ if prediction_import.errors == []:
249
+ return "you're a wizard harry"
250
+
251
+ st.title("Key Input and Button Example")
252
+ api_key = st.text_input("Enter your api key:", type="password")
253
+ model_run_id = st.text_input("Enter your model run ID:")
254
+ ontology_id = st.text_input("Enter your ontology ID:")
255
+
256
+ if st.button("Train and run inference"):
257
+ st.write('letsgo')
258
+ # Check if the key is not empty
259
+ if api_key + model_run_id + ontology_id:
260
+ result = train_and_inference(api_key, ontology_id, model_run_id)
261
+ st.write(result)
262
+ else:
263
+ st.warning("Please enter all keys.")
264
+