|
import gradio as gr |
|
from src.llm import query_chroma |
|
from src.reranker_warm import rank_anime_warm |
|
import pandas as pd |
|
from pathlib import Path |
|
import requests |
|
|
|
# Custom stylesheet handed to ``gr.Blocks(css=...)``.
#   footer             - hidden entirely
#   .gradio-container  - centred and capped at 1200px wide
#   .contain           - translucent, rounded, padded panel
#   .submit-btn        - gradient button look (attached through ``elem_classes``),
#                        with a small lift and shadow on hover
#   .title             - large, centred heading with gradient-filled text
#   .output-image      - stretched to the full width of its column
# NOTE(review): ``footer``, ``.gradio-container`` and ``.contain`` presumably
# target markup generated by Gradio itself - confirm against the Gradio version in use.
css = """
footer {display: none !important}
.gradio-container {
    max-width: 1200px;
    margin: auto;
}
.contain {
    background: rgba(255, 255, 255, 0.05);
    border-radius: 12px;
    padding: 20px;
}
.submit-btn {
    background: linear-gradient(90deg, #4B79A1 0%, #283E51 100%) !important;
    border: none !important;
    color: white !important;
}
.submit-btn:hover {
    transform: translateY(-2px);
    box-shadow: 0 5px 15px rgba(0,0,0,0.2);
}
.title {
    text-align: center;
    font-size: 2.5em;
    font-weight: bold;
    margin-bottom: 1em;
    background: linear-gradient(90deg, #4B79A1 0%, #283E51 100%);
    -webkit-background-clip: text;
    -webkit-text-fill-color: transparent;
}
.output-image {
    width: 100% !important;
    max-width: 100% !important;
}
"""
|
|
|
|
|
# Upper bound (seconds) for a single cover-image request, so an unresponsive
# image host cannot hang the UI callback indefinitely.
_IMAGE_DOWNLOAD_TIMEOUT = 10


def download_pic(names: list[str]) -> tuple[list, list]:
    """Look up cover images and synopses for the given anime names.

    Reads ``src/data/final_anime_list.csv`` (next to this file), keeps the rows
    whose ``Name`` is in *names*, and makes sure each row's ``Image URL`` is
    available as a local file under ``src/data/pics``. Images already on disk
    are reused; missing ones are downloaded.

    Args:
        names: Anime names, in the order the results should be returned.

    Returns:
        ``(file_paths, synopsis_list)``, both aligned with *names*.
        ``file_paths`` holds the local image path as a string, or ``None``
        when the name has no usable image URL in the CSV. ``synopsis_list``
        holds the synopsis, or ``""`` when none is available.

    Raises:
        requests.RequestException: If an image download fails or times out.
    """
    data_dir = Path(__file__).parent / "src" / "data"
    pic_dir = data_dir / "pics"
    # parents=True: do not assume the intermediate directories already exist.
    pic_dir.mkdir(parents=True, exist_ok=True)

    catalog = pd.read_csv(data_dir / "final_anime_list.csv")
    rows = (
        catalog.loc[catalog["Name"].isin(names), ["Name", "Image URL", "Synopsis"]]
        # reindex() refuses an index with duplicate labels, so keep one row per name.
        .drop_duplicates(subset="Name")
        .set_index("Name")
        # Restores the caller's ordering; names absent from the CSV become NaN rows.
        .reindex(names)
    )

    synopsis_list = ["" if pd.isna(text) else text for text in rows["Synopsis"]]

    file_paths = []
    for url in rows["Image URL"]:
        if pd.isna(url):
            # Unknown name or empty URL cell: keep the slot so outputs stay aligned.
            file_paths.append(None)
            continue

        # The last two URL segments form a flat, collision-resistant file name.
        file_name = "_".join(url.split("/")[-2:])
        file_path = pic_dir / file_name

        if not file_path.exists():
            response = requests.get(url, timeout=_IMAGE_DOWNLOAD_TIMEOUT)
            response.raise_for_status()
            file_path.write_bytes(response.content)
            print(f"Image downloaded successfully: {url}")

        file_paths.append(str(file_path))

    return file_paths, synopsis_list
|
|
|
|
|
|
|
def integration_warm(query: str):
    """Produce the values for the four recommendation cards of the UI.

    Fetches 100 candidates for *query* with ``query_chroma``, re-ranks them
    for user 12 with ``rank_anime_warm``, keeps the first four entries and
    resolves their images and synopses through ``download_pic``.

    Args:
        query: Free-text description entered by the user.

    Returns:
        A flat list of 12 values: four ranked entries, then four image paths,
        then four synopses. When fewer than four entries come back, the
        remaining slots are filled with ``""`` / ``None`` / ``""`` so the list
        always matches the 12 output components wired to this callback.
    """
    slot_count = 4  # the UI exposes four name / image / description components

    candidates = query_chroma(query=query, anime_count=100)
    ranked = list(rank_anime_warm(userid=12, anime_list=candidates)[:slot_count])

    # Each ranked entry is indexable with the anime name at position 0
    # (presumably a (name, score) pair - confirm against rank_anime_warm).
    final_names = [entry[0] for entry in ranked]
    pictures, synopses = download_pic(final_names)

    missing = slot_count - len(ranked)
    if missing > 0:
        # Gradio rejects a callback result with fewer values than outputs.
        ranked = [*ranked, *[""] * missing]
        pictures = [*pictures, *[None] * missing]
        synopses = [*synopses, *[""] * missing]

    # NOTE(review): the whole ranked entry (not just its name) is sent to the
    # "Anime N" textboxes, as in the original code - confirm this is intended.
    return [*ranked, *pictures, *synopses]
|
|
|
|
|
|
|
|
|
def clear_prompt() -> str:
    """Return an empty string; wired to the Clear button to reset the query box."""
    return ""
|
|
|
|
|
def feedback_button(action: str, anime_name: str) -> str:
    """Build the confirmation message shown after a feedback action.

    Args:
        action: Verb in base form; a ``"d"`` is appended to form the past
            tense, so it is expected to end in "e" (e.g. ``"like"``,
            ``"dislike"``).
        anime_name: Title the feedback refers to.

    Returns:
        A sentence such as ``"You liked Naruto!"``.
    """
    return f"You {action}d {anime_name}!"
|
|
|
|
|
# Number of recommendation cards in the results row. ``integration_warm``
# returns its values in three groups (names, images, synopses) sized for four
# cards, so this must stay in sync with that callback.
RESULT_SLOTS = 4

with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
    # --- Header -----------------------------------------------------------
    gr.HTML('<div class="title">AniQuest</div>')
    gr.HTML(
        '<div style="text-align: center; margin-bottom: 2em; color: #666; font-size: 24px;">'
        'We recommend anime based on your description</div>'
    )
    gr.HTML("""
    <div style="color: red; margin-bottom: 1em; text-align: center; padding: 10px; background: rgba(255,0,0,0.1); border-radius: 8px;">
        ⚠️ Welcome, [user_id: 12] to this recommendation system ⚠️
    </div>
    """)

    # Output components are collected in lists (one entry per card) so they
    # can be passed to the click handler without dynamically named variables.
    anime_boxes = []
    image_boxes = []
    description_boxes = []

    with gr.Column():
        # --- Query input and actions ---------------------------------------
        prompt = gr.Textbox(
            label="Query",
            placeholder="Describe the anime you want to watch next ...",
            lines=1,
        )
        with gr.Row():
            # NOTE(review): both labels originally carried an icon prefix that
            # was unreadable in the reviewed source; restore it if desired.
            generate_btn = gr.Button("Submit", elem_classes=["submit-btn"])
            clear_btn = gr.Button("Clear", elem_classes=["submit-btn"])

        # --- Result cards ---------------------------------------------------
        with gr.Row():
            for slot in range(1, RESULT_SLOTS + 1):
                with gr.Column(scale=1, elem_classes=["anime-block"]):
                    anime_boxes.append(gr.Textbox(label=f"Anime {slot}"))

                    with gr.Row():
                        # Rendered only: no click handler is attached to the
                        # feedback buttons in this file.
                        gr.Button("👍 Like")
                        gr.Button("👎 Dislike")

                    image_boxes.append(
                        gr.Image(label="Image", elem_classes=["output-image", "fixed-width"])
                    )
                    description_boxes.append(
                        gr.HTML(
                            '<div class="anime-description" style="margin-top: 10px; '
                            'font-size: 14px; color: #666;">'
                            f"Description for anime {slot}</div>"
                        )
                    )

    # --- Event wiring -------------------------------------------------------
    # Output order must match integration_warm's return value:
    # all names, then all images, then all descriptions.
    generate_btn.click(
        fn=integration_warm,
        inputs=[prompt],
        outputs=[*anime_boxes, *image_boxes, *description_boxes],
    )

    clear_btn.click(
        fn=clear_prompt,
        inputs=[],
        outputs=[prompt],
    )
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Serve on every network interface (0.0.0.0), port 7860, so the app is
    # reachable from other machines / from outside a container, not only
    # through localhost.
    demo.launch(server_name="0.0.0.0", server_port=7860)
|
|