import streamlit as st
import os
import time
from PIL import Image
import math
from streamlit_sparrow_labeling import st_sparrow_labeling
import requests
from config import settings
import json


class DataInference:
    """Streamlit page that shows an uploaded document and runs remote inference on it."""

    class Model:
        """UI strings plus accessors for the page state kept in ``st.session_state``."""

        # pageTitle = "Data Inference"
        subheader_2 = "Upload"
        initial_msg = "Please upload a file for inference"

        upload_help = "Upload a file to extract data from it"
        upload_button_text = "Upload"
        upload_button_text_desc = "Choose a file"

        extract_data = "Extract Data"

        # Value sent to the inference API as ``model_in_use``.
        model_in_use = "donut"

        img_file = None

        def set_image_file(self, img_file):
            """Store the path of the current document image in the session state."""
            st.session_state['img_file'] = img_file

        def get_image_file(self):
            """Return the stored document image path, or ``None`` when none is set."""
            return st.session_state.get('img_file')

        data_result = None

        def set_data_result(self, data_result):
            """Store the latest inference result (or ``None`` to clear it)."""
            st.session_state['data_result'] = data_result

        def get_data_result(self):
            """Return the stored inference result, or ``None`` when none is set."""
            return st.session_state.get('data_result')

    def view(self, model, ui_width, device_type, device_width):
        """Render the page: upload form in the sidebar, document and results in the body.

        Args:
            model: ``Model`` instance providing UI strings and session accessors.
            ui_width: Width available to the UI, forwarded to
                ``canvas_available_width``.
            device_type: ``"desktop"`` or ``"mobile"``.
            device_width: Width of the device, forwarded to
                ``canvas_available_width``.
        """
        # st.title(model.pageTitle)

        with st.sidebar:
            st.markdown("---")
            st.subheader(model.subheader_2)

            with st.form("upload-form", clear_on_submit=True):
                # NOTE(review): both widgets are created with disabled=True, so
                # the upload branch below cannot currently be triggered from the
                # UI. Left as-is; presumably intentional for a hosted demo.
                uploaded_file = st.file_uploader(
                    model.upload_button_text_desc,
                    accept_multiple_files=False,
                    type=['png', 'jpg', 'jpeg'],
                    help=model.upload_help,
                    disabled=True,
                )
                submitted = st.form_submit_button(model.upload_button_text, disabled=True)

                if submitted and uploaded_file is not None:
                    ret = self.upload_file(uploaded_file)

                    if ret is not False:
                        model.set_image_file(ret)
                        # A new document invalidates any previous result.
                        model.set_data_result(None)

        image_path = model.get_image_file()
        if image_path is None:
            st.title(model.initial_msg)
            return

        doc_img = Image.open(image_path)
        doc_height = doc_img.height
        doc_width = doc_img.width

        canvas_width, number_of_columns = self.canvas_available_width(
            ui_width, doc_width, device_type, device_width
        )

        if number_of_columns > 1:
            # The two column weights always add up to 10.
            col1, col2 = st.columns([number_of_columns, 10 - number_of_columns])
            with col1:
                self.render_doc(model, doc_img, canvas_width, doc_height, doc_width)
            with col2:
                self.render_results(model)
        else:
            self.render_doc(model, doc_img, canvas_width, doc_height, doc_width)
            self.render_results(model)

    def upload_file(self, uploaded_file):
        """Save an uploaded file under ``docs/inference/`` with a timestamped name.

        Args:
            uploaded_file: Object exposing ``name`` and ``getbuffer()``. Its
                ``name`` attribute is rewritten to the timestamped file name.

        Returns:
            The path of the saved file, or ``False`` when the file was rejected
            (target already exists or the name is longer than 500 characters).
        """
        target_dir = "docs/inference/"

        timestamp = str(time.time()).replace(".", "")

        # Drop any directory component of the client-supplied name so the file
        # can only be written directly inside target_dir.
        original_name = os.path.basename(uploaded_file.name)
        file_name, file_extension = os.path.splitext(original_name)
        uploaded_file.name = file_name + "_" + timestamp + file_extension

        target_path = os.path.join(target_dir, uploaded_file.name)

        if os.path.exists(target_path):
            st.write("File already exists")
            return False

        if len(uploaded_file.name) > 500:
            st.write("File name too long")
            return False

        # Make sure the target directory exists before writing into it.
        os.makedirs(target_dir, exist_ok=True)
        with open(target_path, "wb") as f:
            f.write(uploaded_file.getbuffer())

        st.success("File uploaded successfully")

        return target_path

    def canvas_available_width(self, ui_width, doc_width, device_type, device_width):
        """Choose the canvas width and the column weight for the document view.

        On a desktop with ``ui_width`` above 700 the canvas takes 37 %, 49 % or
        60 % of ``ui_width`` depending on how wide the document is relative to
        it, paired with a column weight of 4, 5 or 6. Otherwise a single-column
        layout is used, sized from ``device_width``.

        Returns:
            A ``(canvas_width, number_of_columns)`` tuple. ``number_of_columns``
            is 1 for the single-column layout.
        """
        if ui_width > 700 and device_type == "desktop":
            # The ratio is computed only here, where ui_width is known to be
            # positive, so a zero ui_width can no longer divide by zero.
            doc_width_pct = (doc_width * 100) / ui_width
            if doc_width_pct < 45:
                canvas_width_pct, columns = 37, 4
            elif doc_width_pct < 55:
                canvas_width_pct, columns = 49, 5
            else:
                canvas_width_pct, columns = 60, 6
            return math.floor(canvas_width_pct * ui_width / 100), columns

        # Single-column layout: subtract a fixed share of the device width
        # (22 % on desktop, 13 % on mobile). Any other device type keeps
        # ui_width as passed in.
        if device_type == "desktop":
            ui_width = device_width - math.floor((device_width * 22) / 100)
        elif device_type == "mobile":
            ui_width = device_width - math.floor((device_width * 13) / 100)
        return ui_width, 1

    def render_doc(self, model, doc_img, canvas_width, doc_height, doc_width):
        """Display the document image in a labeling canvas without annotations.

        Args:
            model: ``Model`` instance; its image path is used in the widget key.
            doc_img: Image used as the canvas background.
            canvas_width: Width passed to the canvas component.
            doc_height: Height of the document image.
            doc_width: Width of the document image.
        """
        # Fixed height/width handed to the component as-is.
        height = 1296
        width = 864

        # Empty annotation set: only the image metadata, no words.
        annotations_json = {
            "meta": {
                "version": "v0.1",
                "split": "train",
                "image_id": 0,
                "image_size": {
                    "width": doc_width,
                    "height": doc_height
                }
            },
            "words": []
        }

        st_sparrow_labeling(
            fill_color="rgba(0, 151, 255, 0.3)",
            stroke_width=2,
            stroke_color="rgba(0, 50, 255, 0.7)",
            background_image=doc_img,
            initial_rects=annotations_json,
            height=height,
            width=width,
            drawing_mode="transform",
            display_toolbar=False,
            update_streamlit=False,
            canvas_width=canvas_width,
            doc_height=doc_height,
            doc_width=doc_width,
            image_rescale=True,
            # The key includes the image path, so each document gets its own key.
            key="doc_annotation" + model.get_image_file()
        )

    def _request_inference(self, model, file_path):
        """POST the document to the inference API; return the response or ``None`` on a transport error."""
        api_url = "https://katanaml-org-sparrow-ml.hf.space/api-inference/v1/sparrow-ml/inference"

        with open(file_path, "rb") as file:
            # Prepare the payload
            # NOTE(review): the content type is always sent as image/jpeg, even
            # though PNG uploads are accepted -- confirm the service ignores it.
            files = {
                'file': (file.name, file, 'image/jpeg')
            }

            data = {
                'image_url': '',
                'model_in_use': model.model_in_use,
                'sparrow_key': settings.sparrow_key
            }

            try:
                with st.spinner("Extracting data from document..."):
                    return requests.post(api_url, data=data, files=files, timeout=180)
            except requests.exceptions.RequestException as exc:
                # Timeouts and connection failures are reported through the
                # same error path as a non-200 response.
                print('Request failed:', exc)
                return None

    def render_results(self, model):
        """Render the extract button and the inference result for the current document.

        On submit the document is sent to the inference API. A successful
        response is stored in the session, written next to the image with a
        ``.json`` extension, and the script is rerun. A failure stores an error
        message in the session, which is shown (once) on the next run.
        """
        with st.form(key="results_form"):
            button_placeholder = st.empty()

            submit = button_placeholder.form_submit_button(model.extract_data, type="primary")
            if 'inference_error' in st.session_state:
                st.error(st.session_state.inference_error)
                # Show the error only once.
                del st.session_state.inference_error

            if not submit:
                if model.get_data_result() is not None:
                    st.markdown("---")
                    st.json(model.get_data_result())
                return

            button_placeholder.empty()

            file_path = model.get_image_file()
            response = self._request_inference(model, file_path)

            if response is None or response.status_code != 200:
                if response is not None:
                    print('Request failed with status code:', response.status_code)
                    print('Response:', response.text)
                st.session_state["inference_error"] = "Error extracting data from document"
                st.experimental_rerun()

            model.set_data_result(response.text)

            # Display JSON data in Streamlit
            st.markdown("---")
            st.json(response.text)

            # Swap only the real extension for .json. A plain
            # str.replace(".jpg", ".json") leaves .png/.jpeg paths unchanged,
            # which would overwrite the uploaded image with the JSON result.
            json_path = os.path.splitext(file_path)[0] + ".json"
            with open(json_path, "w") as f:
                # NOTE(review): response.text is a str, so this writes the body
                # as a JSON string literal. Kept unchanged so the on-disk
                # format stays the same for existing readers.
                json.dump(response.text, f, indent=2)

            st.experimental_rerun()