Spaces:
Running
Running
import numpy as np | |
import pandas as pd | |
import streamlit as st | |
import requests | |
from sentence_transformers import util | |
from sentence_transformers import SentenceTransformer, util | |
import os | |
st.set_page_config(page_title="Custom Button Example", layout="wide") | |
from dotenv import load_dotenv | |
from langchain.chat_models.gigachat import GigaChat | |
from langchain.prompts.chat import ( | |
ChatPromptTemplate, | |
HumanMessagePromptTemplate, | |
SystemMessagePromptTemplate, | |
) | |
load_dotenv() | |
credentials = os.getenv('API_KEY') | |
chat = GigaChat(model='GigaChat', credentials=credentials, verify_ssl_certs=False) | |
def load_model_all_mpnet(): | |
return SentenceTransformer('all-mpnet-base-v2') | |
model_mp = load_model_all_mpnet() | |
def load_embeddings(file_path): | |
return np.load(file_path) | |
book_embeddings_mp = load_embeddings('data/book_embeddings.npy') | |
def load_data(file_path): | |
return pd.read_csv(file_path) | |
df = load_data('data/books_data_cleaned.csv') | |
def load_model_msmarco(): | |
return SentenceTransformer('msmarco-roberta-base-v3') | |
model_ms = load_model_msmarco() | |
def load_embeddings(file_path): | |
return np.load(file_path) | |
book_embeddings_ms = load_embeddings('data/book_embeddings_ms.npy') | |
def get_embedding(text, model): | |
text = model.encode(text, convert_to_tensor=True) | |
return text | |
def get_top_10_recommendations(query, model, book_embeddings, top_k): | |
query_embedding = get_embedding(query, model).cpu() | |
similarities = util.pytorch_cos_sim(query_embedding, book_embeddings)[0] | |
top_results = similarities.cpu().numpy().argsort()[::-1][:top_k] | |
top_books = df.iloc[top_results].copy() | |
similarity_scores = similarities.cpu().numpy()[top_results] | |
top_books['similarity_score'] = similarity_scores | |
return top_books | |
st.title('Рекомендации книг') | |
search = st.radio( | |
"Выберите тип семантического поиска:", | |
[":blue[Симметричный]", ":blue[Асимметричный]"], | |
captions=[ | |
"Используем 'all-mpnet-base-v2'", | |
"Используем 'msmarco-roberta-base-v3'", | |
], | |
horizontal=True, | |
) | |
def params(search): | |
if search == ":blue[Симметричный]": | |
text = '''Я ищу книги в жанре фэнтези, которые описывают приключения магов и волшебников, обучающихся в специальных магических школах и сражающихся с темными силами или злыми существами. Особенно интересуют произведения, где главные герои сталкиваются с эпическими испытаниями и развивают свои уникальные способности.''' | |
model = model_mp | |
book_embeddings = book_embeddings_mp | |
return text, model, book_embeddings | |
elif search == ":blue[Асимметричный]": | |
text = '''путешествие во времени''' | |
model = model_ms | |
book_embeddings = book_embeddings_ms | |
return text, model, book_embeddings | |
text, model, book_embeddings = params(search) | |
col1, col2 = st.columns([3, 1]) | |
with col1: | |
query = st.text_area('Введите запрос, чтобы получить рекомендации', f'{text}', height=95) | |
with col2: | |
number = st.number_input( | |
"Сколько книг найти?", value=3 | |
) | |
find_button = st.button('Найти', key='find_button', use_container_width=True) | |
if find_button and query: | |
top_10_books = get_top_10_recommendations(query, model, book_embeddings, number) | |
for idx, row in top_10_books.iterrows(): | |
with st.container(): | |
col1, col2 = st.columns([1, 3]) | |
with col1: | |
st.image(row['image_url'], width = 300) | |
with col2: | |
st.subheader(f"{row['title']}") | |
st.write(f"**Автор:** {row['author']}") | |
tab1, tab2 = st.tabs(['Аннотация', 'Краткое содержание']) | |
with tab1: | |
st.write(row['annotation']) | |
st.metric(label="Схожесть", value=f"{row['similarity_score']:.3f}") | |
st.write(f"**Ссылка:** {row['page_url']}") | |
with tab2: | |
template = "ты умеешь кратко в несколько предложений описывать содержание книги по ее названию" | |
system_message_prompt = SystemMessagePromptTemplate.from_template(template) | |
human_template = "Кратко опиши содержание книги под названием: {book_title}" | |
human_message_prompt = HumanMessagePromptTemplate.from_template(human_template) | |
chat_prompt = ChatPromptTemplate.from_messages( | |
[system_message_prompt, human_message_prompt] | |
) | |
formatted_prompt = chat_prompt.format_prompt( | |
book_title=row['title'] | |
) | |
response = chat(formatted_prompt.to_messages()) | |
st.write(response.content) | |
st.write("---") | |
st.write("---") | |