Leo8613 commited on
Commit
f45ee40
·
verified ·
1 Parent(s): d297eb3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -8
app.py CHANGED
@@ -3,15 +3,30 @@ import numpy as np
3
  import cv2
4
  import gradio as gr
5
 
6
- # Charger le modèle sans compilation (pour éviter les erreurs liées à batch_shape)
7
- try:
8
- generator = tf.keras.models.load_model('generator.h5', compile=False) # Charger sans compilation
9
- except ValueError as e:
10
- print("Erreur lors du chargement du modèle, vérifier batch_shape.")
11
- print(e)
12
-
13
- # Fonction pour générer une vidéo à partir d'un bruit aléatoire
 
 
 
 
 
 
 
 
 
 
 
 
14
  def generate_video():
 
 
 
15
  # Générer un bruit aléatoire (entrée pour le générateur)
16
  noise = np.random.normal(0, 1, (1, 16, 64, 64, 3)) # Exemple de bruit pour 16 frames de 64x64x3
17
  generated_video = generator.predict(noise) # Générer la vidéo
 
3
  import cv2
4
  import gradio as gr
5
 
6
+ # Fonction pour charger le modèle sans 'batch_shape'
7
+ def load_model_safe(model_path):
8
+ try:
9
+ # Charger le modèle sans compilation pour éviter des erreurs liées au batch_shape
10
+ model = tf.keras.models.load_model(model_path, compile=False)
11
+ return model
12
+ except ValueError as e:
13
+ print(f"Erreur lors du chargement du modèle: {e}")
14
+ return None
15
+
16
+ # Charger le modèle en toute sécurité
17
+ generator = load_model_safe('generator.h5') # Assurez-vous que le fichier 'generator.h5' est dans le même répertoire
18
+
19
+ # Vérifier que le modèle a bien été chargé
20
+ if generator is None:
21
+ print("Le modèle n'a pas pu être chargé. Vérifiez le fichier 'generator.h5'.")
22
+ else:
23
+ print("Modèle chargé avec succès.")
24
+
25
+ # Fonction pour générer une vidéo à partir du générateur
26
  def generate_video():
27
+ if generator is None:
28
+ return "Le modèle n'a pas pu être chargé."
29
+
30
  # Générer un bruit aléatoire (entrée pour le générateur)
31
  noise = np.random.normal(0, 1, (1, 16, 64, 64, 3)) # Exemple de bruit pour 16 frames de 64x64x3
32
  generated_video = generator.predict(noise) # Générer la vidéo