|
from io import BytesIO |
|
import gradio as gr |
|
import tensorflow as tf |
|
import logging |
|
import requests |
|
import numpy as np |
|
from PIL import Image |
|
from tensorflow.keras.utils import CustomObjectScope |
|
from tensorflow.keras.layers.experimental.preprocessing import RandomHeight |
|
|
|
# Runtime setup: logging, GPU configuration, then model loading.
logging.basicConfig(
    level=logging.WARNING,
    format="%(asctime)s - %(levelname)s - %(message)s",
)

print("TensorFlow version:", tf.__version__)

# Listing physical devices does not initialize them, so it is safe to do this
# before configuring memory growth.
gpus = tf.config.list_physical_devices('GPU')
print("GPU Available:", gpus)

# Memory growth has to be configured BEFORE the GPUs are initialized.
# TensorFlow raises RuntimeError otherwise, which is why this now runs ahead
# of the model load (loading the model creates variables and is expected to
# initialize the devices).
if gpus:
    try:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
    except RuntimeError as e:
        # Best effort: keep running with TensorFlow's default allocation.
        print(e)

# The saved model references the RandomHeight preprocessing layer, so it must
# be resolvable by name while the file is deserialized.
with CustomObjectScope({'RandomHeight': RandomHeight}):
    model_0 = tf.keras.models.load_model('bestmodel.h5')
|
|
|
def clean_url(input_string):
    """Return *input_string* cut down to the URL it contains.

    Everything before the first occurrence of ``"http"`` is dropped, and
    trailing whitespace (spaces, newlines left over from a copy/paste) is
    removed so it is not sent along as part of the URL.

    Args:
        input_string: Text that may contain a URL, possibly preceded by
            other characters.

    Returns:
        The substring starting at the first ``"http"`` with trailing
        whitespace stripped, or *input_string* unchanged when it contains
        no ``"http"`` at all.
    """
    start = input_string.find("http")
    if start == -1:
        # Nothing that looks like a URL: hand the text back untouched.
        return input_string
    return input_string[start:].rstrip()
|
|
|
|
|
# Labels for model output indices 0-4; any other index maps to
# _FALLBACK_LABEL.
_CLASS_LABELS = (
    "Rifle violence",
    "guns violence",
    "knife violence",
    "image porno",
    "personne habillée",
)
_FALLBACK_LABEL = "tank violence"
# Returned when the best score is below _CONFIDENCE_THRESHOLD.
_SAFE_LABEL = "bonne image"
_CONFIDENCE_THRESHOLD = 0.84
# Width and height the image is resized to before prediction.
_INPUT_SIZE = (224, 224)
# Seconds to wait for the image server before giving up.
_DOWNLOAD_TIMEOUT = 10


def _load_image(inp):
    """Turn a NumPy array, a URL string or a PIL image into a PIL image."""
    if isinstance(inp, np.ndarray):
        logging.warning("1")
        return Image.fromarray(inp)

    if isinstance(inp, str):
        logging.warning("2")
        url = clean_url(inp)
        # Download exactly once, with a timeout so a stalled server cannot
        # block the handler forever (requests has no default timeout).
        response = requests.get(url, timeout=_DOWNLOAD_TIMEOUT)
        response.raise_for_status()
        return Image.open(BytesIO(response.content))

    if isinstance(inp, Image.Image):
        logging.warning("4")
        return inp

    raise TypeError(
        f"unsupported input type for classify_image: {type(inp).__name__}"
    )


def classify_image(inp):
    """Classify an image and return a human-readable label.

    Args:
        inp: One of
            * a ``numpy.ndarray`` holding pixel data,
            * a string containing an image URL (cleaned with ``clean_url``
              and then downloaded),
            * an already opened ``PIL.Image.Image``.

    Returns:
        ``"bonne image"`` when the highest score is below 0.84, otherwise
        the label associated with the highest-scoring class index.

    Raises:
        TypeError: If *inp* is none of the supported types.
        requests.RequestException: If the URL is invalid, the download
            fails or times out, or the server answers with an error status.
    """
    logging.warning("entree dans ckassify_image")
    logging.warning(inp)

    image = _load_image(inp)

    # Converting to RGB also covers grayscale and palette images, so the
    # array below always has three channels.
    if image.mode != "RGB":
        image = image.convert("RGB")
    image = image.resize(_INPUT_SIZE, Image.Resampling.LANCZOS)

    # Add the leading batch axis before calling predict.
    batch = np.array(image).reshape((-1, *_INPUT_SIZE, 3))
    prediction = model_0.predict(batch)

    # Only one image is submitted, so only the first row of the output is read.
    # assumes the model emits one score per class - TODO confirm
    scores = prediction[0]
    best = int(np.argmax(scores))

    if scores[best] < _CONFIDENCE_THRESHOLD:
        return _SAFE_LABEL
    if best < len(_CLASS_LABELS):
        return _CLASS_LABELS[best]
    return _FALLBACK_LABEL
|
|
|
|
|
# Gradio front end: one text box (the image URL) in, the predicted label out.
# live=True asks Gradio to re-run classify_image as the input changes.
demo = gr.Interface(
    fn=classify_image,
    inputs=gr.Text(label="Url de l image"),
    outputs="text",
    live=True,
    title="API de détection des images violentes",
)

# share=True asks Gradio for a public share link in addition to the local server.
demo.launch(share=True)
|
|