# Standard imports
import torch
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import gradio as gr
import os
import subprocess
import sys
# Install the required system and Python dependencies
subprocess.run(['apt-get', 'update'], check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
packages = ['openmpi-bin', 'libopenmpi-dev']
command = ['apt-get', 'install', '-y'] + packages
subprocess.run(command, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
subprocess.check_call([sys.executable, '-m', 'pip', 'install', 'mpi4py'])
subprocess.check_call([sys.executable, '-m', 'pip', 'install', 'pydicom'])
subprocess.check_call([sys.executable, '-m', 'pip', 'install', 'SimpleITK'])
# Hugging Face imports
from huggingface_hub import hf_hub_download, login
import spaces
# Local imports
from modeling.BaseModel import BaseModel
from modeling import build_model
from utilities.distributed import init_distributed
from utilities.arguments import load_opt_from_config_files
from utilities.constants import BIOMED_CLASSES
from inference_utils.inference import interactive_infer_image
from inference_utils.output_processing import check_mask_stats
from inference_utils.processing_utils import read_rgb, get_instances
def init_huggingface():
    """Initialize the Hugging Face connection and download the model checkpoint."""
    hf_token = os.getenv('HF_TOKEN')
    if hf_token is None:
        raise ValueError("Hugging Face token not found. Please set the HF_TOKEN environment variable.")
    login(hf_token)
    pretrained_path = hf_hub_download(
        repo_id="microsoft/BiomedParse",
        filename="biomedparse_v1.pt",
        local_dir="pretrained"
    )
    return pretrained_path
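# A minimal sketch of the expected setup (an assumption about the deployment,
# not part of this file): the token is supplied via the environment, e.g.
#
#   export HF_TOKEN=hf_...   # or a Space secret named HF_TOKEN
#
# after which init_huggingface() returns a local checkpoint path such as
# "pretrained/biomedparse_v1.pt".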
def apply_distributed(opt):
    """Apply the distributed settings for multi-process mode (placeholder: only logs)."""
    print(f"Distributed configuration applied: {opt}")
def init_distributed(opt):
    """Initialize distributed mode without premature CUDA initialization.

    Note: this deliberately overrides the init_distributed imported from
    utilities.distributed above.
    """
    opt['CUDA'] = opt.get('CUDA', True) and torch.cuda.is_available()
    if 'OMPI_COMM_WORLD_SIZE' not in os.environ:
        # Application started without MPI: single-process defaults
        opt['env_info'] = 'no MPI'
        opt['world_size'] = 1
        opt['local_size'] = 1
        opt['rank'] = 0
        opt['local_rank'] = 0  # Ensure this is set to 0
        opt['master_address'] = '127.0.0.1'
        opt['master_port'] = '8673'
    else:
        # Application started with MPI: read ranks from the OpenMPI environment
        opt['world_size'] = int(os.environ['OMPI_COMM_WORLD_SIZE'])
        opt['local_size'] = int(os.environ['OMPI_COMM_WORLD_LOCAL_SIZE'])
        opt['rank'] = int(os.environ['OMPI_COMM_WORLD_RANK'])
        opt['local_rank'] = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
    if not opt['CUDA']:
        assert opt['world_size'] == 1, 'Multi-GPU training without CUDA is not supported, since NCCL is used as the communication backend'
        opt['device'] = torch.device("cpu")
    else:
        opt['device'] = torch.device("cuda", opt['local_rank'])  # local_rank is an integer
    apply_distributed(opt)
    return opt
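# A minimal sketch of what the no-MPI branch yields for a single-process
# launch (the actual device depends on torch.cuda.is_available()):
#
#   opt = init_distributed({})
#   # opt['world_size'] == 1, opt['rank'] == 0, opt['local_rank'] == 0
#   # opt['device'] == torch.device('cuda', 0) or torch.device('cpu')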
def setup_model():
    """Initialize the model on CPU without triggering CUDA initialization."""
    opt = load_opt_from_config_files(["configs/biomedparse_inference.yaml"])
    opt = init_distributed(opt)
    opt['device'] = 'cpu'  # Force CPU here; the GPU move happens at inference time
    pretrained_path = init_huggingface()
    model = BaseModel(opt, build_model(opt))
    state_dict = torch.load(pretrained_path, map_location='cpu', weights_only=True)
    model.load_state_dict(state_dict, strict=False)
    # Initialize train_class_names
    model.train_class_names = BIOMED_CLASSES + ["background"]
    return model.eval()
def preprocess_image(image):
    """Preprocess an image into the (B, C, H, W) float32 input expected by the SEEM model."""
    if isinstance(image, Image.Image):
        # Convert PIL Image to numpy array
        image = np.array(image)
    # Ensure the image is float32 and normalized to [0, 1]
    image = image.astype(np.float32) / 255.0
    # Ensure correct dimensions (B, C, H, W)
    if len(image.shape) == 3:
        image = np.transpose(image, (2, 0, 1))  # HWC -> CHW
        image = np.expand_dims(image, axis=0)   # Add batch dimension
    return image
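# A quick sanity check of the expected shapes (illustrative only):
#
#   >>> img = Image.new("RGB", (64, 32))   # PIL size is (W, H)
#   >>> preprocess_image(img).shape
#   (1, 3, 32, 64)
#   >>> preprocess_image(img).dtype
#   dtype('float32')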
def predict_image(model, image, prompts):
    """Run the model on one image and return the predicted masks."""
    try:
        # Reuse the shared preprocessing: PIL/numpy HWC uint8 -> (1, C, H, W) float32 in [0, 1]
        image = preprocess_image(image)
        # Convert to tensor
        image_tensor = torch.from_numpy(image)
        # Move model and input to GPU if available
        if torch.cuda.is_available():
            device = torch.device("cuda", 0)
            model = model.to(device)
            image_tensor = image_tensor.to(device)
        else:
            device = torch.device("cpu")
        # Create the batched input expected by the model
        batched_inputs = [{
            "image": image_tensor,
            "prompt": prompts,
            "height": image_tensor.shape[-2],
            "width": image_tensor.shape[-1]
        }]
        with torch.no_grad():
            pred_masks = model(batched_inputs)
        # Move the model and masks back to CPU to free GPU memory
        if device.type == "cuda":
            model = model.to("cpu")
            pred_masks = [mask.cpu() for mask in pred_masks]
        return pred_masks
    except Exception as e:
        print(f"Error processing image: {str(e)}")
        raise
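# Illustrative usage (a sketch; the exact structure of pred_masks depends on
# the BaseModel forward implementation). Note that interactive_infer_image,
# imported above, is the reference inference helper from the BiomedParse repo
# and could be used instead of calling the model directly:
#
#   img = Image.open("examples/T0011.jpg")
#   masks = predict_image(model, img, ["optic disc", "optic cup"])
#   # masks -> one mask per prompt, returned as CPU tensors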
def process_image(image, text, model):
    """Process an image with proper error handling and visualize the results."""
    try:
        prompts = [p.strip() for p in text.split(',') if p.strip()]
        if not prompts:
            raise ValueError("No valid prompts provided")
        pred_masks = predict_image(model, image, prompts)
        # Keep an HWC copy for display; matplotlib cannot render the
        # (1, C, H, W) array produced by preprocess_image
        display_image = np.array(image)
        # Create visualization
        fig = plt.figure(figsize=(5 * (len(pred_masks) + 1), 5))
        # Show original image
        plt.subplot(1, len(pred_masks) + 1, 1)
        plt.imshow(display_image)
        plt.title("Original")
        plt.axis('off')
        # Show predictions overlaid on the original
        for i, mask in enumerate(pred_masks):
            plt.subplot(1, len(pred_masks) + 1, i + 2)
            plt.imshow(display_image)
            plt.imshow(mask.squeeze().cpu().numpy(), alpha=0.5, cmap='Reds')
            plt.title(prompts[i])
            plt.axis('off')
        return fig
    except Exception as e:
        print(f"Error in process_image: {str(e)}")
        raise
def setup_gradio_interface(model):
    """Configure the Gradio interface."""
    return gr.Interface(
        fn=lambda img, txt: process_image(img, txt, model),
        inputs=[
            gr.Image(type="numpy", label="Medical image"),
            gr.Textbox(
                label="Prompts (comma-separated)",
                placeholder="edema, lesion, etc...",
                elem_classes="white"
            )
        ],
        outputs=gr.Plot(),
        title="Core IA - Medical Image Processing",
        description="Upload a medical image and specify the elements to segment",
        # Note: process_image splits the text on commas, so each free-text
        # example below reaches the model as a single prompt
        examples=[
            ["examples/144DME_as_F.jpeg", "In this image, show me the edema"],
            ["examples/T0011.jpg", "optic disc, optic cup"],
            ["examples/C3_EndoCV2021_00462.jpg", "Find the polyp"],
            ["examples/covid_1585.png", "What is wrong here?"],
            ["examples/Part_1_516_pathology_breast.png", "neoplastic cells, inflammatory cells, connective tissue cells"]
        ]
    )
def main():
    """Entry point that avoids CUDA initialization in the main process."""
    try:
        # setup_model() loads on CPU and downloads the checkpoint via init_huggingface()
        model = setup_model()
        interface = setup_gradio_interface(model)
        interface.launch(debug=True)
    except Exception as e:
        print(f"Error during initialization: {str(e)}")
        raise

if __name__ == "__main__":
    main()
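# To run locally (a sketch; assumes the BiomedParse repo dependencies are
# installed and HF_TOKEN is set in the environment):
#
#   python app.py
#
# Gradio then serves the interface on http://127.0.0.1:7860 by default.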