File size: 3,055 Bytes
ab39072
7689af3
 
364b61e
0e0d778
53ad9da
254ceab
6f3a14d
 
 
364b61e
 
 
 
6f3a14d
aceaff0
6f3a14d
d36ead2
 
 
 
 
 
 
 
 
 
 
 
 
364b61e
 
 
 
 
d36ead2
7689af3
7850be6
53ad9da
364b61e
57198c0
53ad9da
 
364b61e
 
 
 
 
 
 
 
 
0e0d778
364b61e
0e0d778
 
 
364b61e
7850be6
364b61e
 
 
23b41f2
 
7850be6
22c0b35
7850be6
 
 
 
 
 
 
 
 
 
 
7689af3
0b01d71
c45ab5a
bbbaf12
 
7689af3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
abb6226
98a44fe
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
from io import BytesIO
import gradio as gr
import tensorflow as tf
import logging
import requests
import numpy as np
from PIL import Image
from tensorflow.keras.utils import CustomObjectScope
from tensorflow.keras.layers.experimental.preprocessing import RandomHeight

logging.basicConfig(
    level=logging.WARNING,
    format="%(asctime)s - %(levelname)s - %(message)s",
)

print("TensorFlow version:", tf.__version__)
print("GPU Available:", tf.config.list_physical_devices('GPU'))

# Enable on-demand GPU memory allocation. This must happen BEFORE anything
# initialises the GPU -- in particular before the model is loaded below --
# otherwise TensorFlow raises RuntimeError and the setting is never applied.
for gpu in tf.config.list_physical_devices('GPU'):
    try:
        tf.config.experimental.set_memory_growth(gpu, True)
    except RuntimeError as e:
        logging.warning("Could not enable memory growth on %s: %s", gpu.name, e)

# RandomHeight is registered as a custom object while loading; presumably the
# saved model contains that preprocessing layer (TODO confirm).
with CustomObjectScope({'RandomHeight': RandomHeight}):
    model_0 = tf.keras.models.load_model('bestmodel.h5')

def clean_url(input_string):
    """Drop everything that precedes the first "http" in *input_string*.

    Strings that do not contain "http" at all are returned unchanged.
    """
    start = input_string.find("http")
    return input_string if start == -1 else input_string[start:]


# Minimum top-class probability required before the image is reported as one
# of the flagged categories; below it the image is considered fine.
_CONFIDENCE_THRESHOLD = 0.84

# Seconds to wait for the remote server before abandoning a download.
_REQUEST_TIMEOUT_S = 10

# Output label for each model class index. Indices past the end of this tuple
# fall back to the last label.
_CLASS_LABELS = (
    "Rifle violence",
    "guns violence",
    "knife violence",
    "image porno",
    "personne habillée",
    "tank violence",
)


def _fetch_image(url):
    """Download *url* and decode the response body as a PIL image.

    Raises:
        requests.HTTPError: the server answered with an error status.
        requests.RequestException: the request failed or timed out.
    """
    response = requests.get(url, timeout=_REQUEST_TIMEOUT_S)
    response.raise_for_status()
    return Image.open(BytesIO(response.content))


def _to_pil_image(inp):
    """Turn a numpy array, an image URL string, or a PIL image into a PIL image.

    Raises:
        ValueError: *inp* is a string that does not contain an http(s) URL.
    """
    if isinstance(inp, np.ndarray):
        return Image.fromarray(inp)
    if isinstance(inp, str):
        url = clean_url(inp.strip())
        if not url.startswith(("http://", "https://")):
            raise ValueError(f"Not an http(s) image URL: {inp!r}")
        return _fetch_image(url)
    # Any other input is assumed to already be a PIL image.
    return inp


def _label_for(prediction):
    """Map the model output for a single-image batch to a text label."""
    scores = np.asarray(prediction)[0]
    best = int(np.argmax(scores))
    if scores[best] < _CONFIDENCE_THRESHOLD:
        return "bonne image"
    return _CLASS_LABELS[min(best, len(_CLASS_LABELS) - 1)]


def classify_image(inp):
    """Classify an image and return a human-readable label.

    Args:
        inp: an image URL (extra text before "http" is ignored), a numpy
            image array, or a PIL image. ``None`` or a blank string -- which
            the live Gradio textbox sends while being edited -- yields "".

    Returns:
        "bonne image" when the model's top score is below the confidence
        threshold, otherwise the label of the most likely class.

    Raises:
        ValueError: a string input that is not an http(s) URL.
        requests.RequestException: the image download failed.
    """
    logging.debug("classify_image called with %r", inp)
    if inp is None or (isinstance(inp, str) and not inp.strip()):
        return ""

    image = _to_pil_image(inp)
    if image.mode != "RGB":
        image = image.convert("RGB")
    image = image.resize((224, 224), Image.Resampling.LANCZOS)

    # After convert("RGB") the array is (224, 224, 3); add the leading batch
    # axis to get the (1, 224, 224, 3) shape the original code fed the model.
    batch = np.expand_dims(np.asarray(image), axis=0)
    prediction = model_0.predict(batch)
    return _label_for(prediction)


# Bind the interface to ``demo`` so gradio's reload mode can find it, and only
# start the (publicly shared) server when this file is run as a script rather
# than on import.
demo = gr.Interface(
    fn=classify_image,
    inputs=gr.Text(label="Url de l image"),
    outputs="text",
    live=True,
    title="API de détection des images violentes",
)

if __name__ == "__main__":
    demo.launch(share=True)