import base64
import io
import json
import os

import gradio as gr
import matplotlib.pyplot as plt
import spaces
import torch
from huggingface_hub import login
from PIL import Image
from transformers import AutoProcessor, MllamaForConditionalGeneration


def check_environment():
    required_vars = ["HF_TOKEN"]
    missing_vars = [var for var in required_vars if var not in os.environ]
    if missing_vars:
        raise ValueError(
            f"Missing required environment variables: {', '.join(missing_vars)}\n"
            "Please set the HF_TOKEN environment variable with your Hugging Face token"
        )


# Log in to Hugging Face
check_environment()
login(token=os.environ["HF_TOKEN"], add_to_git_credential=True)

# Load the processor and the 11B vision-language model; bfloat16 halves the
# memory footprint, and device_map places the weights on the GPU.
base_model_path = "taesiri/FireNet-LLama-3.2-11B-Vision-Base"
processor = AutoProcessor.from_pretrained(base_model_path)
model = MllamaForConditionalGeneration.from_pretrained(
    base_model_path, torch_dtype=torch.bfloat16, device_map="cuda"
)
model.tie_weights()


def create_color_palette_image(colors):
    if not colors or not isinstance(colors, list):
        return None
    try:
        # Validate color format
        for color in colors:
            if not isinstance(color, str) or not color.startswith("#"):
                return None
        # Create figure and axis
        fig, ax = plt.subplots(figsize=(10, 2))
        # Create rectangles for each color
        for i, color in enumerate(colors):
            ax.add_patch(plt.Rectangle((i, 0), 1, 1, facecolor=color))
        # Set the view limits and hide the axis ticks
        ax.set_xlim(0, len(colors))
        ax.set_ylim(0, 1)
        ax.set_xticks([])
        ax.set_yticks([])
        return fig  # Return the matplotlib figure directly
    except Exception as e:
        print(f"Error creating color palette: {e}")
        return None
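

# Illustrative usage of the palette helper (it is not wired into the UI below;
# the color values here are made up for demonstration):
#   fig = create_color_palette_image(["#ff4500", "#ff8c00", "#ffd700"])
#   if fig is not None:
#       fig.savefig("palette.png")
#       plt.close(fig)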


@spaces.GPU  # allocate a GPU for this call when running on ZeroGPU hardware
def inference(image):
    # The click handler below binds eight outputs, so every return path
    # must produce an eight-element tuple in the same order.
    if image is None:
        return ("Please provide an image", "", "", "", "", "No image provided", "Error", {})
    if not isinstance(image, Image.Image):
        try:
            image = Image.fromarray(image)
        except Exception as e:
            print(f"Image conversion error: {e}")
            return ("Invalid image format", "", "", "", "", str(e), "Error", {})
    # Prepare the chat-formatted input
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {
                    "type": "text",
                    "text": "Analyze this image for fire, smoke, haze, or other related conditions.",
                },
            ],
        }
    ]
    input_text = processor.apply_chat_template(messages, add_generation_prompt=True)
    try:
        # Move inputs to the same device as the model
        inputs = processor(
            image, input_text, add_special_tokens=False, return_tensors="pt"
        ).to(model.device)
        with torch.no_grad():
            output = model.generate(**inputs, max_new_tokens=2048)
        # Clear the CUDA cache after inference to release temporary buffers
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    except Exception as e:
        print(f"Inference error: {e}")
        return ("Error during inference", "", "", "", "", str(e), "Error", {})
    # Decode the output. The chat template prefixes the model's reply with an
    # "assistant" header, so keep only the text after it before parsing JSON.
    result = processor.decode(output[0], skip_special_tokens=True)
    print("DEBUG: Full decoded output:", result)
    try:
        json_str = result.strip().split("assistant\n")[1].strip()
        parsed_json = json.loads(json_str)
        # Create specific JSON subsets for each section
        fire_analysis = {
            "predictions": parsed_json.get("predictions", "N/A"),
            "description": parsed_json.get("description", "No description available"),
            "confidence_scores": parsed_json.get("confidence_score", {}),
        }
        environment_analysis = {
            "environmental_factors": parsed_json.get("environmental_factors", {})
        }
        detection_analysis = {
            "detections": parsed_json.get("detections", []),
            "detection_count": len(parsed_json.get("detections", [])),
        }
        report_analysis = {
            "uncertainty_factors": parsed_json.get("uncertainty_factors", []),
            "false_positive_indicators": parsed_json.get(
                "false_positive_indicators", []
            ),
        }
        return (
            json.dumps(fire_analysis, indent=2),
            json.dumps(environment_analysis, indent=2),
            json.dumps(detection_analysis, indent=2),
            json.dumps(report_analysis, indent=2),
            json_str,
            "",
            "Analysis complete",
            parsed_json,
        )
    except Exception as e:
        print("DEBUG: Error processing response:", e)
        return (
            "Error processing response",
            "",
            "",
            "",
            str(result),
            str(e),
            "Error",
            {},
        )
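
# Note: the parsing above assumes the model replies with a JSON object whose
# keys include "predictions", "description", "confidence_score",
# "environmental_factors", "detections", "uncertainty_factors", and
# "false_positive_indicators"; any missing key falls back to a default.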


# Build the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Fire Detection Demo")

    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(
                type="pil",
                label="Upload Image",
                elem_id="large-image",
            )
            submit_btn = gr.Button("Analyze Image", variant="primary")

            gr.Examples(
                examples=[
                    "examples/1727808849.jpg",
                    "examples/1727809389.jpg",
                    "examples/Birch MWF014-0001.jpg",
                    "examples/frame_000036.jpg",
                    "examples/frame_000168.jpg",
                ],
                inputs=image_input,
                label="Example Images",
                examples_per_page=5,
            )
    with gr.Tabs() as tabs:
        with gr.Tab("Analysis Results"):
            with gr.Row():
                with gr.Column():
                    fire_output = gr.JSON(
                        label="Fire Details",
                    )
                with gr.Column():
                    environment_output = gr.JSON(
                        label="Environment Details",
                    )
            with gr.Row():
                with gr.Column():
                    detection_output = gr.JSON(
                        label="Detection Details",
                    )
                with gr.Column():
                    report_output = gr.JSON(
                        label="Report Details",
                    )
        with gr.Tab("JSON Output", id=0):
            json_output = gr.JSON(
                label="Detailed JSON Results",
            )
        with gr.Tab("Raw Output"):
            raw_output = gr.Textbox(
                label="Raw JSON Response",
                lines=10,
            )

    error_box = gr.Textbox(label="Error Messages", visible=False)
    status_text = gr.Textbox(label="Status", value="Ready", interactive=False)

    submit_btn.click(
        fn=inference,
        inputs=[image_input],
        outputs=[
            fire_output,
            environment_output,
            detection_output,
            report_output,
            raw_output,
            error_box,
            status_text,
            json_output,
        ],
    )

demo.launch(share=True)