import base64
import hashlib
import os
import subprocess
from dataclasses import dataclass
from typing import Final
import faiss
import numpy as np
import pandas as pd
import streamlit as st
from streamlit.logger import get_logger
from pipeline import clip_wrapper
from pipeline.process_videos import DATAFRAME_PATH
# Number of nearest-neighbour frames requested from the FAISS index per query (its k).
NUM_FRAMES_TO_RETURN: Final = 21
# Streamlit's logger, namespaced to this module.
logger = get_logger(__name__)
class SemanticSearcher:
    """Text-to-frame semantic search over precomputed frame embeddings.

    The dataset's ``dim_*`` columns are loaded into an exact inner-product
    FAISS index. Every other column is kept as per-frame metadata, row-aligned
    with the index: FAISS vector id ``i`` is metadata row ``i``.
    """

    def __init__(self, dataset: pd.DataFrame):
        dim_columns = dataset.filter(regex="^dim_").columns
        self.embedder = clip_wrapper.ClipWrapper().texts2vec
        # Both the metadata and the index follow the dataset's row order.
        self.metadata = dataset.drop(columns=dim_columns)
        self.index = faiss.IndexFlatIP(len(dim_columns))
        # FAISS requires a C-contiguous float32 matrix.
        self.index.add(np.ascontiguousarray(dataset[dim_columns].to_numpy(np.float32)))

    def search(self, query: str) -> list["SearchResult"]:
        """Return up to NUM_FRAMES_TO_RETURN frames most similar to *query*.

        Results keep the order FAISS returns them in (highest inner product
        first). Fewer results are returned when the index holds fewer than
        NUM_FRAMES_TO_RETURN vectors.
        """
        # .cpu() is a no-op for CPU tensors and makes .numpy() safe otherwise.
        v = self.embedder([query]).detach().cpu().numpy()
        # FAISS expects a C-contiguous float32 (n_queries, dim) matrix.
        v = np.ascontiguousarray(v, dtype=np.float32)
        scores, ids = self.index.search(v, NUM_FRAMES_TO_RETURN)
        # FAISS pads missing neighbours with id -1. Passed to .iloc, -1 would
        # silently select the LAST row as a bogus hit, so drop those slots.
        valid = ids[0] >= 0
        hits = self.metadata.iloc[ids[0][valid]]
        return [
            SearchResult(
                video_id=row["video_id"],
                frame_idx=row["frame_idx"],
                timestamp=row["timestamp"],
                base64_image=row["base64_image"],
                score=float(score),
            )
            for score, (_, row) in zip(scores[0][valid], hits.iterrows())
        ]
@st.cache_resource
def get_semantic_searcher():
    """Load the frame dataset and build the shared SemanticSearcher.

    Decorated with ``st.cache_resource``, so the parquet load and index build
    are reused across reruns instead of being repeated each time.
    """
    frames = pd.read_parquet(DATAFRAME_PATH)
    searcher = SemanticSearcher(frames)
    return searcher
@st.cache_data
def get_git_hash() -> str:
    """Return the checked-out commit hash, or ``"unknown"`` if it can't be read.

    Used as a build identifier in the page footer. The result is cached by
    ``st.cache_data``, so git is invoked once rather than on every rerun.

    When git is not installed or this file is not inside a git work tree
    (e.g. a deployment without ``.git``), this returns ``"unknown"`` instead
    of raising and breaking the whole page over a cosmetic footer.
    """
    try:
        output = subprocess.check_output(
            ["git", "rev-parse", "HEAD"],
            # Resolve the repository from this file's location, not the launch directory.
            cwd=os.path.dirname(os.path.abspath(__file__)),
            # Keep "fatal: not a git repository" noise out of the server logs.
            stderr=subprocess.DEVNULL,
        )
    except (OSError, subprocess.CalledProcessError):
        logger.warning("Could not determine git commit hash")
        return "unknown"
    return output.decode().strip()
@dataclass()
class SearchResult:
    """One frame returned by a semantic search, with the fields the UI renders."""

    video_id: str  # YouTube video id; get_video_url builds the watch link from it
    frame_idx: int  # frame number as stored in the dataset
    timestamp: float  # frame position in seconds (used as the YouTube `t=` offset)
    base64_image: str  # base64-encoded image of the frame
    score: float  # inner-product similarity from the index (higher = more similar)
def get_video_url(video_id: str, timestamp: float) -> str:
    """Build a YouTube watch URL that starts playback just before *timestamp*.

    Playback begins one second early, clamped at zero, and the offset is
    truncated to whole seconds.
    """
    start_seconds = int(max(0, timestamp - 1))
    return f"https://www.youtube.com/watch?v={video_id}&t={start_seconds}"
def display_search_results(results: list[SearchResult], col_count: int = 3) -> None:
    """Render search results as a grid of clickable frame thumbnails.

    Each thumbnail links to its YouTube video, starting just before the
    matched frame (via get_video_url).

    Args:
        results: Frames to show, laid out left-to-right, top-to-bottom.
        col_count: Number of thumbnails per grid row.

    Raises:
        ValueError: If ``col_count`` is less than 1.
    """
    import html  # local import: only this renderer needs HTML escaping

    if col_count < 1:
        raise ValueError(f"col_count must be >= 1, got {col_count}")

    # Emit the stylesheet once for the whole grid, not once per result.
    st.markdown(
        "<style>.video-thumb img { width: 100%; border-radius: 4px; }</style>",
        unsafe_allow_html=True,
    )

    columns = None
    for i, result in enumerate(results):
        col_num = i % col_count
        if col_num == 0:
            columns = st.columns(col_count)  # start a new grid row
        # Escape everything interpolated into raw HTML (unsafe_allow_html is on).
        url = html.escape(get_video_url(result.video_id, result.timestamp), quote=True)
        # NOTE(review): assumes base64_image holds JPEG bytes; confirm against the pipeline.
        src = "data:image/jpeg;base64," + html.escape(result.base64_image, quote=True)
        with columns[col_num]:
            st.markdown(
                f'<div class="video-thumb"><a href="{url}" target="_blank">'
                f'<img src="{src}"></a></div>',
                unsafe_allow_html=True,
            )
def main():
    """Streamlit entry point: draw the search page and results for a non-empty query."""
    st.set_page_config(page_title="video-semantic-search", layout="wide")
    st.header("Visual content search over music videos")
    st.markdown("_App by Ben Tenmann and Sidney Radcliffe_")

    searcher = get_semantic_searcher()
    video_count = len(searcher.metadata["video_id"].unique())
    prompt = f"What are you looking for? Search over {video_count} music videos."
    # The widget's return value is the same as st.session_state["query"].
    query = st.text_input(prompt, key="query")

    if query:
        # Log only a digest of the query, never its raw text.
        query_digest = hashlib.md5(query.encode()).hexdigest()
        logger.info(f"Recieved query... {query_digest}")
        st.text("Click image to open video")
        display_search_results(searcher.search(query))

    short_hash = get_git_hash()[:7]
    st.text(f"Build: {short_hash}")
# Entry point when this file is executed as the main script.
if __name__ == "__main__":
    main()