|
import requests |
|
import streamlit as st |
|
from annotated_text import annotated_text |
|
from openfoodfacts.images import generate_image_url, generate_json_ocr_url |
|
|
|
|
|
@st.cache_data
def send_prediction_request(ocr_url: str, model_version: str, timeout: float = 30.0):
    """Ask the Robotoff ingredient-list endpoint for a prediction on one OCR file.

    The function is wrapped with ``st.cache_data`` so that Streamlit can reuse
    the response for repeated calls with the same arguments across reruns.

    Args:
        ocr_url: URL of the OCR JSON file, forwarded as the ``ocr_url`` query
            parameter.
        model_version: Forwarded as the ``model_version`` query parameter.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout ``requests`` waits indefinitely, which would freeze the
            app on a stalled connection.

    Returns:
        The decoded JSON body of the response. The HTTP status code is not
        checked here, so error payloads are returned to the caller as-is.

    Raises:
        requests.exceptions.RequestException: On connection failure or
            timeout, or when the response body is not valid JSON.
    """
    response = requests.get(
        "https://robotoff.openfoodfacts.net/api/v1/predict/ingredient_list",
        params={"ocr_url": ocr_url, "model_version": model_version},
        timeout=timeout,
    )
    return response.json()
|
|
|
|
|
def get_product(barcode: str, timeout: float = 30.0):
    """Fetch a product from the Open Food Facts v2 product endpoint.

    Args:
        barcode: Barcode interpolated into the request URL.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout ``requests`` waits indefinitely.

    Returns:
        The value stored under the ``"product"`` key of the JSON response, or
        ``None`` when the server answers 404 or the payload has no
        ``"product"`` entry.

    Raises:
        requests.exceptions.RequestException: On connection failure or
            timeout, or when the response body is not valid JSON.
    """
    r = requests.get(
        f"https://world.openfoodfacts.org/api/v2/product/{barcode}",
        timeout=timeout,
    )

    if r.status_code == 404:
        return None

    # A non-404 answer is not guaranteed to contain a "product" entry; report
    # that case as "no product" instead of crashing with a KeyError.
    return r.json().get("product")
|
|
|
|
|
def display_ner_tags(text: str, entities: list[dict]):
    """Render ``text`` with each entity highlighted via ``annotated_text``.

    The text is cut into alternating pieces: plain strings for the parts
    outside any entity, and ``(substring, label)`` tuples for each entity,
    where the label shows the entity score and language.

    Args:
        text: Full text the entity offsets refer to.
        entities: Dicts providing ``"start"``, ``"end"``, ``"score"`` and a
            nested ``"lang"`` dict with a ``"lang"`` key. They are consumed in
            the order given; presumably sorted by offset and non-overlapping,
            since nothing here reorders them.
    """
    chunks = []
    cursor = 0  # offset in ``text`` up to which chunks were already emitted
    for ent in entities:
        confidence = ent["score"]
        language = ent["lang"]["lang"]
        begin, stop = ent["start"], ent["end"]
        chunks.extend(
            [
                # plain text between the previous entity and this one
                text[cursor:begin],
                # the entity itself, paired with its label
                (text[begin:stop], f"score={confidence:.3f} | lang={language}"),
            ]
        )
        cursor = stop
    # whatever follows the last entity (the whole text when there is none)
    chunks.append(text[cursor:])
    annotated_text(chunks)
|
|
|
|
|
def run(
    barcode: str,
    model_version: str,
    min_threshold: float = 0.5,
):
    """Render ingredient predictions for every numeric-ID image of a product.

    A link to the product page is always shown. If the product or its images
    cannot be found, an error is displayed and nothing else is rendered.
    Otherwise, for each image whose ID consists only of digits, the image and
    links to it and its OCR file are shown, followed by either a warning (when
    the prediction payload contains an ``"error"`` key) or the predicted text
    with the entities scoring at least ``min_threshold`` highlighted.

    Args:
        barcode: Barcode of the product to look up.
        model_version: Model version forwarded to the prediction request.
        min_threshold: Minimum ``"score"`` an entity needs to be displayed.
    """
    product = get_product(barcode)
    st.markdown(f"[Product page](https://world.openfoodfacts.org/product/{barcode})")

    if not product:
        st.error(f"Product {barcode} not found")
        return

    # "images" is iterated as a mapping below, so the fallback is a dict too.
    images = product.get("images", {})

    if not images:
        st.error(f"No images found for product {barcode}")
        return

    # Only the keys are needed; the per-image metadata is not used.
    for image_id in images:
        # Skip non-numeric keys. NOTE(review): presumably those are derived
        # entries rather than raw uploads; confirm against the API schema.
        if not image_id.isdigit():
            continue

        ocr_url = generate_json_ocr_url(barcode, image_id)
        prediction = send_prediction_request(ocr_url, model_version)

        st.divider()
        image_url = generate_image_url(barcode, image_id)
        st.markdown(f"[Image {image_id}]({image_url}), [OCR]({ocr_url})")
        st.image(image_url)

        if "error" in prediction:
            # Fall back to the raw error value when no description is given,
            # rather than failing with a KeyError while reporting an error.
            description = prediction.get("description", prediction["error"])
            st.warning(f"Error: {description}")
            continue

        entities = prediction["entities"]
        text = prediction["text"]
        filtered_entities = [e for e in entities if e["score"] >= min_threshold]
        display_ner_tags(text, filtered_entities)
|
|
|
|
|
# --- Page layout (Streamlit executes this top to bottom on every rerun) ---

# Pre-fill the barcode field from the "barcode" URL query parameter, if any.
# NOTE(review): newer Streamlit releases replace the experimental_* query
# parameter helpers with st.query_params; confirm the target version before
# migrating.
query_params = st.experimental_get_query_params()
default_barcode = query_params.get("barcode", [""])[0]

st.title("Ingredient extraction demo")
st.markdown(
    "This demo leverages the ingredient entity detection model, that takes "
    "the OCR text as input and predict ingredient lists."
)

barcode = st.text_input(
    label="barcode",
    value=default_barcode,
    help="Barcode of the product",
).strip()
# Mirror the current barcode back into the URL query string.
st.experimental_set_query_params(barcode=barcode)

threshold = st.number_input(
    label="threshold",
    min_value=0.0,
    max_value=1.0,
    value=0.98,
    help="Minimum threshold for entity predictions",
)

# Model version sent with every prediction request.
model_version = "1"

# Nothing to show until a barcode has been entered.
if barcode:
    run(barcode=barcode, model_version=model_version, min_threshold=threshold)
|
|