import base64
import io
import json
import os
import gradio as gr
import matplotlib.pyplot as plt
import spaces
import torch
from huggingface_hub import login
from PIL import Image
from transformers import AutoProcessor, MllamaForConditionalGeneration
def check_environment():
    """Fail fast if the Hugging Face token is not configured."""
    required_vars = ["HF_TOKEN"]
    missing_vars = [var for var in required_vars if var not in os.environ]
    if missing_vars:
        raise ValueError(
            f"Missing required environment variables: {', '.join(missing_vars)}\n"
            "Please set the HF_TOKEN environment variable with your Hugging Face token."
        )
# Validate the environment, then log in to Hugging Face
check_environment()
login(token=os.environ["HF_TOKEN"], add_to_git_credential=True)
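# Load the processor and the fine-tuned Llama 3.2 Vision model (bfloat16, on the GPU)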
base_model_path = "taesiri/FireNet-LLama-3.2-11B-Vision-Base"
processor = AutoProcessor.from_pretrained(base_model_path)
model = MllamaForConditionalGeneration.from_pretrained(
    base_model_path, torch_dtype=torch.bfloat16, device_map="cuda"
)
model.tie_weights()
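# Helper that renders a list of hex colors as a horizontal palette (not referenced by the UI below)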
def create_color_palette_image(colors):
    """Render a list of hex color strings as a single-row palette figure."""
    if not colors or not isinstance(colors, list):
        return None
    try:
        # Validate color format: every entry must be a "#RRGGBB"-style string
        for color in colors:
            if not isinstance(color, str) or not color.startswith("#"):
                return None
        # Create figure and axis
        fig, ax = plt.subplots(figsize=(10, 2))
        # Draw one unit-square swatch per color
        for i, color in enumerate(colors):
            ax.add_patch(plt.Rectangle((i, 0), 1, 1, facecolor=color))
        # Set the view limits and hide the ticks
        ax.set_xlim(0, len(colors))
        ax.set_ylim(0, 1)
        ax.set_xticks([])
        ax.set_yticks([])
        return fig  # Return the matplotlib figure directly
    except Exception as e:
        print(f"Error creating color palette: {e}")
        return None
@spaces.GPU
def inference(image):
    """Analyze an image with the fire-detection model and return the eight values the UI expects."""
    # Every return path yields eight values, matching the outputs wired to submit_btn.click() below.
    if image is None:
        return ("Please provide an image", "", "", "", "", "No image provided", "Error", {})
    if not isinstance(image, Image.Image):
        try:
            image = Image.fromarray(image)
        except Exception as e:
            print(f"Image conversion error: {e}")
            return ("Invalid image format", "", "", "", "", str(e), "Error", {})
    # Prepare input
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {
                    "type": "text",
                    "text": "Analyze this image for fire, smoke, haze, or other related conditions.",
                },
            ],
        }
    ]
    input_text = processor.apply_chat_template(messages, add_generation_prompt=True)
    try:
        # Tokenize the image and prompt, and move the tensors to the model's device
        inputs = processor(
            image, input_text, add_special_tokens=False, return_tensors="pt"
        ).to(model.device)
        with torch.no_grad():
            output = model.generate(**inputs, max_new_tokens=2048)
        # Clear the CUDA cache after inference
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    except Exception as e:
        print(f"Inference error: {e}")
        return ("Error during inference", "", "", "", "", str(e), "Error", {})
    # Decode output
    result = processor.decode(output[0], skip_special_tokens=True)
    print("DEBUG: Full decoded output:", result)
    try:
        # The chat template echoes the prompt; keep only the text after "assistant\n"
        json_str = result.strip().split("assistant\n")[1].strip()
        parsed_json = json.loads(json_str)
        # Create specific JSON subsets for each section
        fire_analysis = {
            "predictions": parsed_json.get("predictions", "N/A"),
            "description": parsed_json.get("description", "No description available"),
            "confidence_scores": parsed_json.get("confidence_score", {}),
        }
        environment_analysis = {
            "environmental_factors": parsed_json.get("environmental_factors", {})
        }
        detection_analysis = {
            "detections": parsed_json.get("detections", []),
            "detection_count": len(parsed_json.get("detections", [])),
        }
        report_analysis = {
            "uncertainty_factors": parsed_json.get("uncertainty_factors", []),
            "false_positive_indicators": parsed_json.get(
                "false_positive_indicators", []
            ),
        }
        return (
            json.dumps(fire_analysis, indent=2),
            json.dumps(environment_analysis, indent=2),
            json.dumps(detection_analysis, indent=2),
            json.dumps(report_analysis, indent=2),
            json_str,
            "",
            "Analysis complete",
            parsed_json,
        )
    except Exception as e:
        print("DEBUG: Error processing response:", e)
        return (
            "Error processing response",
            "",
            "",
            "",
            str(result),
            str(e),
            "Error",
            {},
        )
# Build the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Fire Detection Demo")
    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(
                type="pil",
                label="Upload Image",
                elem_id="large-image",
            )
            submit_btn = gr.Button("Analyze Image", variant="primary")
            # Example images
            gr.Examples(
                examples=[
                    "examples/1727808849.jpg",
                    "examples/1727809389.jpg",
                    "examples/Birch MWF014-0001.jpg",
                    "examples/frame_000036.jpg",
                    "examples/frame_000168.jpg",
                ],
                inputs=image_input,
                label="Example Images",
                examples_per_page=5,
            )
    with gr.Tabs() as tabs:
        with gr.Tab("Analysis Results"):
            with gr.Row():
                with gr.Column():
                    fire_output = gr.JSON(
                        label="Fire Details",
                    )
                with gr.Column():
                    environment_output = gr.JSON(
                        label="Environment Details",
                    )
            with gr.Row():
                with gr.Column():
                    detection_output = gr.JSON(
                        label="Detection Details",
                    )
                with gr.Column():
                    report_output = gr.JSON(
                        label="Report Details",
                    )
        with gr.Tab("JSON Output", id=0):
            json_output = gr.JSON(
                label="Detailed JSON Results",
            )
        with gr.Tab("Raw Output"):
            raw_output = gr.Textbox(
                label="Raw JSON Response",
                lines=10,
            )
    error_box = gr.Textbox(label="Error Messages", visible=False)
    status_text = gr.Textbox(label="Status", value="Ready", interactive=False)
    submit_btn.click(
        fn=inference,
        inputs=[image_input],
        outputs=[
            fire_output,
            environment_output,
            detection_output,
            report_output,
            raw_output,
            error_box,
            status_text,
            json_output,
        ],
    )
demo.launch(share=True)