|
import gradio as gr |
|
import torch |
|
import cv2 |
|
import pytesseract |
|
from transformers import AutoTokenizer, AutoModelForSequenceClassification |
|
|
|
|
|
# --- Object detector -------------------------------------------------------
# NOTE(review): `model_name` is not referenced anywhere in the visible code;
# the detector actually loaded below is the `yolov5s` entry point of the
# `ultralytics/yolov5` hub repository. Kept so importers of this name still work.
model_name = "flax-community/yolov5s-v1-coco"
model = torch.hub.load(
    repo_or_dir="ultralytics/yolov5",
    model="yolov5s",
    pretrained=True,
)

# --- Text classifier -------------------------------------------------------
# NOTE(review): "distilbert-base-uncased" looks like a base checkpoint, so the
# sequence-classification head is presumably freshly initialised rather than
# fine-tuned for Safe/Out decisions — confirm before trusting predictions.
classification_model_name = "distilbert-base-uncased"
classification_tokenizer = AutoTokenizer.from_pretrained(
    classification_model_name
)
classification_model = AutoModelForSequenceClassification.from_pretrained(
    classification_model_name
)
|
|
|
|
|
def perform_ocr(image):
    """Extract text from an image using Tesseract.

    Args:
        image: Image array in OpenCV BGR channel order (it is converted to
            grayscale with ``cv2.COLOR_BGR2GRAY`` before recognition).

    Returns:
        The string returned by ``pytesseract.image_to_string``, or ``""``
        when *image* is ``None`` or contains no pixels.
    """
    # cv2.cvtColor raises on empty input (e.g. a zero-area crop), so treat
    # "nothing to read" as "no text" instead of crashing the caller.
    if image is None or image.size == 0:
        return ""

    gray_image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    return pytesseract.image_to_string(gray_image)
|
|
|
|
|
def _first_person_box(detections):
    """Return the first class-0 box as ``[x1, y1, x2, y2]`` ints, or ``None``.

    *detections* is the per-image prediction tensor; the last column is read
    as the class id and the first four columns as the box corners.
    """
    person_rows = detections[detections[:, -1] == 0]
    if person_rows.shape[0] == 0:
        return None
    # .int() truncates toward zero; .tolist() also copies off the GPU if needed.
    return person_rows[0, :4].int().tolist()


def _classify_runner_status(text):
    """Map OCR text to ``"Out"`` (predicted class 1) or ``"Safe"`` (otherwise)."""
    encoded = classification_tokenizer(
        text, return_tensors="pt", padding=True, truncation=True
    )
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        logits = classification_model(**encoded).logits
    predicted_class = torch.argmax(logits).item()
    return "Out" if predicted_class == 1 else "Safe"


def predict_runner_status(video_file):
    """Classify the runner status for every frame in which a person is found.

    Each frame is run through the detector; when at least one class-0
    detection exists, the first such box is cropped from the frame, OCR'd
    with ``perform_ocr`` and the resulting text is classified.

    Args:
        video_file: Either an object exposing the video path as ``.name``
            (e.g. an uploaded temp file) or the path itself. ``None`` yields
            an empty result.

    Returns:
        A list of dicts, one per frame with a usable person crop, each with:
            ``"frame_number"``: value of ``cv2.CAP_PROP_POS_FRAMES`` read
                right after the frame was decoded.
            ``"runner_status"``: ``"Out"`` or ``"Safe"``.
        Frames without a class-0 detection, or whose crop is empty, are
        skipped.
    """
    if video_file is None:
        return []

    # Accept both file-like uploads and plain path strings.
    video_path = getattr(video_file, "name", video_file)
    cap = cv2.VideoCapture(video_path)
    results = []

    try:
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break

            # OpenCV decodes frames as BGR; the YOLOv5 hub model expects RGB
            # for NumPy inputs.
            results_detection = model(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))

            box = _first_person_box(results_detection.pred[0])
            if box is None:
                continue

            x1, y1, x2, y2 = box
            person_img = frame[y1:y2, x1:x2]
            if person_img.size == 0:
                # Degenerate box after integer truncation: nothing to OCR.
                continue

            text = perform_ocr(person_img)
            results.append(
                {
                    "frame_number": cap.get(cv2.CAP_PROP_POS_FRAMES),
                    "runner_status": _classify_runner_status(text),
                }
            )
    finally:
        # Always free the capture handle, even if a model call raises.
        cap.release()

    return results
|
|
|
# Gradio UI wiring: one video upload in, one label component out.
# NOTE(review): `gr.inputs` / `gr.outputs` appear to be legacy namespaces in
# newer Gradio releases — confirm against the installed version.
# NOTE(review): `predict_runner_status` returns a list of dicts; verify that
# the Label component can render that shape.
inputs = gr.inputs.Video(
    type="file",
    label="Upload a baseball video",
)
outputs = gr.outputs.Label(
    type="auto",
    label="Runner Status",
)

interface = gr.Interface(
    fn=predict_runner_status,
    inputs=inputs,
    outputs=outputs,
    title="Baseball Runner Status Predictor",
)

# share=True asks Gradio to create a publicly reachable link for the app.
interface.launch(share=True)
|
|