Spaces:
Build error
Build error
import logging | |
import os | |
import json | |
import matplotlib.pyplot as plt | |
import gradio as gr | |
from faiss import read_index_binary, write_index_binary | |
from config import * | |
from videomatch import index_hashes_for_video, get_decent_distance, \ | |
get_video_index, compare_videos, get_change_points, get_videomatch_df | |
from plot import plot_segment_comparison | |
# Configure root logging so INFO messages are shown; lower the level to
# logging.DEBUG while debugging.
logging.basicConfig()
logging.root.setLevel(logging.INFO)
def transfer_data_indices_to_temp(temp_path=VIDEO_DIRECTORY, data_path='./data'):
    """Copy the pre-built binary indices from ``data_path`` into ``temp_path``.

    The binary indices created from the .json file are not stored in the
    temporary directory. This function loads each of them and writes it to the
    temporary directory. As a result, statically shipped indices and
    dynamically downloaded ones are found in the same location.

    Args:
        temp_path (str): Directory of temporary storage for binary indices.
            It is created if it does not exist yet.
        data_path (str): Directory of the indices created from the .json file.

    Returns:
        None.
    """
    # On a fresh start the destination may not exist yet; create it so the
    # writes below cannot fail on a missing directory.
    os.makedirs(temp_path, exist_ok=True)

    for index_file in os.listdir(data_path):
        source_file = os.path.join(data_path, index_file)
        # Only regular files can be binary indices; skip sub-directories.
        if not os.path.isfile(source_file):
            continue
        # Read from static location and write to temp storage
        binary_index = read_index_binary(source_file)
        write_index_binary(binary_index, os.path.join(temp_path, index_file))
def compare(url, target):
    """Compare a single user-submitted url to a single target entry.

    Args:
        url (str): User submitted url of a video which will be downloaded and cached.
        target (dict): Target entry with a 'url' and an 'mp4' key.

    Returns:
        fig (Figure): Figure that shows the comparison between the two videos.
            If no matches are found, this is an empty figure.
        segment_decision (list): .json-style list of dictionaries containing the
            decision information of the comparison. It is empty when no matches
            are found.
    """
    target_title = target['url']
    target_mp4 = target['mp4']

    # Get source and target indices
    source_index, source_hash_vectors = get_video_index(url)
    target_index, _ = get_video_index(target_mp4)

    # Get decent distance by comparing url index with the target hash vectors + target index.
    # MIN_DISTANCE / MAX_DISTANCE presumably come from the config star-import.
    distance = get_decent_distance(source_index, source_hash_vectors, target_index,
                                   MIN_DISTANCE, MAX_DISTANCE)
    if distance is None:
        logging.info(f"No matches found between {url} and {target_mp4}!")
        return plt.figure(), []

    # Compare videos with heuristic distance
    lims, D, I, hash_vectors = compare_videos(source_hash_vectors, target_index,
                                              MIN_DISTANCE=distance)

    # Get dataframe holding all information
    df = get_videomatch_df(lims, D, I, hash_vectors, distance)

    # Determine change point using ROBUST method based on column ROLL_OFFSET_MODE
    change_points = get_change_points(df, metric="ROLL_OFFSET_MODE", method="ROBUST")

    # Plot and get figure and .json-style segment decision
    fig, segment_decision = plot_segment_comparison(df, change_points,
                                                    video_id=target_title,
                                                    video_mp4=target_mp4)
    return fig, segment_decision
def multiple_comparison(url, return_figure=False):
    """Compare a user-submitted url against every entry of ``TARGET_ENTRIES``.

    ``TARGET_ENTRIES`` is the module-level list loaded from the .json file
    that holds the videos to compare to. ``compare`` is called once per
    entry, in list order.

    Args:
        url (str): User submitted url which will be downloaded and cached.
        return_figure (bool): If True, return the figures (needed for Gradio
            plotting); otherwise return the decisions.

    Returns:
        list: One Figure per target entry when ``return_figure`` is True.
        Otherwise, the .json-style decision dictionaries of all comparisons,
        concatenated into a single list.
    """
    results = [compare(url, target) for target in TARGET_ENTRIES]

    if return_figure:
        return [figure for figure, _ in results]

    # Flatten the per-target decision lists into one list.
    return [decision for _, segment in results for decision in segment]
def plot_multiple_comparison(url):
    """Return the comparison figures for *url* instead of the decisions.

    This is a thin wrapper for the Gradio plot tab, which needs figures: it
    always calls ``multiple_comparison`` with ``return_figure`` set to True.

    Args:
        url (str): User submitted url which will be downloaded and cached.

    Returns:
        list: The Figure(s) from comparing *url* to every target entry.
    """
    figures = multiple_comparison(url, return_figure=True)
    return figures
# Copy the stored target video indices to temporary storage.
# NOTE: Only works after doing 'git lfs pull' to actually obtain the .index files
transfer_data_indices_to_temp()

# Load the stored target videos that submitted videos will be compared to.
# JSON text is UTF-8, so decode it explicitly instead of relying on the
# platform's locale-dependent default encoding.
with open('apb2022.json', "r", encoding="utf-8") as json_file:
    TARGET_ENTRIES = json.load(json_file)
# Example videos offered as Gradio examples. Earlier candidates, currently
# disabled:
#   "https://www.youtube.com/watch?v=qIaqMqMweM4"
#   "https://video.twimg.com/amplify_video/1575576025651617796/vid/480x852/jP057nPfPJSUM0kR.mp4?tag=14"
#   "https://drive.google.com/uc?id=1XW0niHR1k09vPNv1cp6NvdGXe7FHJc1D&export=download"
EXAMPLE_VIDEO_URLS = [
    "https://drive.google.com/uc?id=1Y1-ypXOvLrp1x0cjAe_hMobCEdA0UbEo&export=download",
]
# "Index" tab: index a single video and report the index's ``ntotal``.
index_iface = gr.Interface(
    fn=lambda url: index_hashes_for_video(url).ntotal,
    inputs="text",
    outputs="text",
    examples=EXAMPLE_VIDEO_URLS,
)

# "PlotAutoCompare" tab: one plot per target entry, labelled with its url.
_plot_outputs = [gr.Plot(label=entry['url']) for entry in TARGET_ENTRIES]
plot_compare_iface = gr.Interface(
    fn=plot_multiple_comparison,
    inputs=["text"],
    outputs=_plot_outputs,
    examples=EXAMPLE_VIDEO_URLS,
)

# "AutoCompare" tab: the .json decision list over all target entries.
auto_compare_iface = gr.Interface(
    fn=multiple_comparison,
    inputs=["text"],
    outputs=["json"],
    examples=EXAMPLE_VIDEO_URLS,
)

# Combine the three interfaces into one tabbed app; the tab order matches the list order.
_tab_interfaces = [auto_compare_iface, plot_compare_iface, index_iface]
_tab_names = ["AutoCompare", "PlotAutoCompare", "Index"]
iface = gr.TabbedInterface(_tab_interfaces, _tab_names)
if __name__ == "__main__":
    # Switch matplotlib to the SVG backend before launching; the original
    # author notes this is needed for plots to show in Gradio as intended.
    import matplotlib as mpl
    mpl.use('SVG')

    # show_error=True surfaces Python errors in the UI.
    # For a public link with debugging, use: iface.launch(share=True, debug=True)
    iface.launch(show_error=True)