import gradio as gr
import torch
import cv2
import pytesseract
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# Load YOLOv5s (pretrained on COCO) via torch.hub for object detection
model = torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=True)
# Load the Hugging Face model for text classification
classification_model_name = "distilbert-base-uncased"
classification_tokenizer = AutoTokenizer.from_pretrained(classification_model_name)
classification_model = AutoModelForSequenceClassification.from_pretrained(classification_model_name)
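# Note: distilbert-base-uncased ships no fine-tuned classification head, so
# AutoModelForSequenceClassification attaches a freshly initialized 2-label head
# with random weights. For meaningful Safe/Out predictions this should point to a
# checkpoint fine-tuned on the OCR text; the mapping below assumes label 1 == "Out".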
# Define function for OCR
def perform_ocr(image):
    # Convert image to grayscale
    gray_image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    # Perform OCR
    text = pytesseract.image_to_string(gray_image)
    return text
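# pytesseract only wraps the Tesseract CLI, so the tesseract-ocr binary must be
# installed on the host (on a Hugging Face Space, typically via a packages.txt
# entry) for perform_ocr to work.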
# Define function to process video and predict
def predict_runner_status(video_file):
    # Recent Gradio versions pass the uploaded video as a file path string
    video_path = video_file if isinstance(video_file, str) else video_file.name
    cap = cv2.VideoCapture(video_path)
    results = []

    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break

        # Object detection on the current frame
        results_detection = model(frame)

        # Logic for determining runner status using detected objects
        # Example: if a person is detected, extract text and classify it
        detections = results_detection.pred[0].cpu().numpy()  # (n, 6): x1, y1, x2, y2, conf, cls
        objects = detections[:, -1]
        if 0 in objects:  # class 0 corresponds to "person" in COCO
            # Get the cropped region containing the person for OCR
            person_bbox = detections[objects == 0][0][:4].astype(int)
            person_img = frame[person_bbox[1]:person_bbox[3], person_bbox[0]:person_bbox[2]]

            # Perform OCR on the cropped image
            text = perform_ocr(person_img)

            # Classification using the text classification model
            inputs_classification = classification_tokenizer(text, return_tensors="pt", padding=True, truncation=True)
            outputs_classification = classification_model(**inputs_classification)
            predicted_class = torch.argmax(outputs_classification.logits).item()

            runner_status = "Out" if predicted_class == 1 else "Safe"

            result = {
                "frame_number": int(cap.get(cv2.CAP_PROP_POS_FRAMES)),
                "runner_status": runner_status
            }
            results.append(result)

    cap.release()
    return results
inputs = gr.Video(label="Upload a baseball video")
outputs = gr.JSON(label="Runner Status")
interface = gr.Interface(fn=predict_runner_status, inputs=inputs, outputs=outputs, title="Baseball Runner Status Predictor")
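# share=True is only needed to create a temporary public link when running
# locally; on a Hugging Face Space the app is already served publicly.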
interface.launch(share=True)