"""Streamlit app: text-to-frame semantic search over video frames.

Frame embeddings (the ``dim_*`` columns of a parquet dataset) are loaded into a
FAISS inner-product index; a text query is embedded with the project's CLIP
wrapper and the nearest frames are shown as thumbnails linking to the YouTube
video at the matching moment.
"""

import base64
import html
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Final

import faiss
import numpy as np
import pandas as pd
import streamlit as st

from pipeline import clip_wrapper


@dataclass
class SearchResult:
    """One frame returned by ``SemanticSearcher.search``."""

    video_id: str
    frame_idx: int
    # Presumably seconds: it is passed (truncated) as YouTube's ``t`` parameter.
    timestamp: float
    # Raw inner product reported by the FAISS index (higher is more similar).
    score: float


class SemanticSearcher:
    """Nearest-neighbour text-to-frame search over precomputed embeddings.

    ``dataset`` must hold the embedding vector in columns whose names match
    ``^dim_`` plus the metadata columns ``video_id``, ``frame_idx`` and
    ``timestamp``. Ranking uses the raw inner product, which equals cosine
    similarity only if both stored and query embeddings are L2-normalised
    (TODO: confirm ``clip_wrapper`` normalises its outputs).
    """

    def __init__(self, dataset: pd.DataFrame) -> None:
        dim_columns = dataset.filter(regex="^dim_").columns
        if dim_columns.empty:
            raise ValueError(
                "dataset has no embedding columns (expected names matching '^dim_')"
            )
        self.embedder = clip_wrapper.ClipWrapper().texts2vec
        self.metadata = dataset.drop(columns=dim_columns)
        self.index = faiss.IndexFlatIP(len(dim_columns))
        # FAISS only accepts C-contiguous float32 matrices.
        self.index.add(np.ascontiguousarray(dataset[dim_columns].to_numpy(np.float32)))

    def search(self, query: str, k: int = 10) -> list[SearchResult]:
        """Return up to ``k`` frames ranked by similarity to ``query``.

        Args:
            query: Free-text query, embedded with the CLIP text encoder.
            k: Maximum number of results to return (default 10).

        Returns:
            Results ordered best-first. Fewer than ``k`` when the index holds
            fewer vectors; empty when ``k <= 0`` or the index is empty.
        """
        # With fewer than k vectors FAISS pads the ids with -1, and
        # ``iloc[-1]`` would silently return the last row as a bogus hit.
        k = min(k, self.index.ntotal)
        if k <= 0:
            return []
        # The embedder presumably returns a torch tensor (it has .detach());
        # move it to host memory and coerce it to the contiguous float32
        # (1, dim) matrix FAISS expects.
        query_vec = self.embedder([query]).detach().cpu().numpy()
        query_vec = np.ascontiguousarray(query_vec.reshape(1, -1), dtype=np.float32)
        scores, ids = self.index.search(query_vec, k)
        found = ids[0] >= 0  # defensive: drop any -1 padding that slips through
        rows = self.metadata.iloc[ids[0][found]]
        return [
            SearchResult(
                video_id=row["video_id"],
                frame_idx=row["frame_idx"],
                timestamp=row["timestamp"],
                score=float(score),
            )
            for score, (_, row) in zip(scores[0][found], rows.iterrows())
        ]


DATASET_PATH: Final[str] = os.environ.get("DATASET_PATH", "data/dataset.parquet")
IMAGES_DIR: Final[Path] = Path(os.environ.get("IMAGES_DIR", "data/images"))


@st.cache_resource(show_spinner=False)
def _load_searcher(dataset_path: str) -> SemanticSearcher:
    """Build the searcher once per process and reuse it across reruns.

    Streamlit re-executes this script on every interaction; without caching,
    the CLIP model, the parquet file and the FAISS index would be rebuilt each
    time. ``show_spinner=False`` keeps this call free of UI output so that
    ``st.set_page_config`` in ``main`` remains the first Streamlit command.
    The cache is keyed on the path only: a modified file at the same path is
    not reloaded until the cache is cleared or the app restarts.
    """
    return SemanticSearcher(pd.read_parquet(dataset_path))


SEARCHER: Final[SemanticSearcher] = _load_searcher(DATASET_PATH)


def get_video_url(video_id: str, timestamp: float) -> str:
    """Return the YouTube watch URL for ``video_id`` starting at ``timestamp``.

    The timestamp is truncated to whole seconds for the ``t`` query parameter.
    """
    return f"https://www.youtube.com/watch?v={video_id}&t={int(timestamp)}"


def _render_result(result: SearchResult) -> None:
    """Show one result's frame thumbnail, linked to the video at that moment."""
    image_path = IMAGES_DIR / str(result.video_id) / f"{result.frame_idx}.jpg"
    try:
        encoded = base64.b64encode(image_path.read_bytes()).decode("ascii")
    except FileNotFoundError:
        # One missing thumbnail should not take down the whole results page.
        st.warning(f"Missing frame image: {image_path}")
        return
    # Escape values interpolated into raw HTML (unsafe_allow_html=True).
    url = html.escape(get_video_url(result.video_id, result.timestamp))
    caption = html.escape(f"frame {result.frame_idx}")
    st.markdown(
        f'<a href="{url}" target="_blank" rel="noopener noreferrer">'
        f'<img src="data:image/jpeg;base64,{encoded}" alt="{caption}" '
        'style="width:100%; border-radius:4px;"></a>'
        f"<p>{caption}</p>",
        unsafe_allow_html=True,
    )


def display_search_results(results: list[SearchResult], col_count: int = 3) -> None:
    """Lay ``results`` out in rows of ``col_count`` columns, in ranking order.

    Raises:
        ValueError: if ``col_count`` is less than 1.
    """
    if col_count < 1:
        raise ValueError(f"col_count must be >= 1, got {col_count}")
    for start in range(0, len(results), col_count):
        columns = st.columns(col_count)
        for column, result in zip(columns, results[start : start + col_count]):
            with column:
                _render_result(result)


def main() -> None:
    """Render the page: a query box and, once a query is entered, its results."""
    st.set_page_config(page_title="video-semantic-search", layout="wide")
    st.header("Video Semantic Search")
    # text_input returns the same value it stores in st.session_state["query"].
    query = st.text_input("What are you looking for?", key="query")
    if query and query.strip():
        display_search_results(SEARCHER.search(query))


if __name__ == "__main__":
    main()