Books_recommendation / pages /recommendations.py
PolyakovK's picture
secret
3191631
import numpy as np
import pandas as pd
import streamlit as st
import requests
from sentence_transformers import util
from sentence_transformers import SentenceTransformer, util
import os
st.set_page_config(page_title="Custom Button Example", layout="wide")
from dotenv import load_dotenv
from langchain.chat_models.gigachat import GigaChat
from langchain.prompts.chat import (
ChatPromptTemplate,
HumanMessagePromptTemplate,
SystemMessagePromptTemplate,
)
load_dotenv()
credentials = os.getenv('API_KEY')
chat = GigaChat(model='GigaChat', credentials=credentials, verify_ssl_certs=False)
@st.cache_resource
def load_model_all_mpnet():
return SentenceTransformer('all-mpnet-base-v2')
model_mp = load_model_all_mpnet()
@st.cache_data
def load_embeddings(file_path):
return np.load(file_path)
book_embeddings_mp = load_embeddings('data/book_embeddings.npy')
@st.cache_data
def load_data(file_path):
return pd.read_csv(file_path)
df = load_data('data/books_data_cleaned.csv')
@st.cache_resource
def load_model_msmarco():
return SentenceTransformer('msmarco-roberta-base-v3')
model_ms = load_model_msmarco()
@st.cache_data
def load_embeddings(file_path):
return np.load(file_path)
book_embeddings_ms = load_embeddings('data/book_embeddings_ms.npy')
def get_embedding(text, model):
text = model.encode(text, convert_to_tensor=True)
return text
def get_top_10_recommendations(query, model, book_embeddings, top_k):
query_embedding = get_embedding(query, model).cpu()
similarities = util.pytorch_cos_sim(query_embedding, book_embeddings)[0]
top_results = similarities.cpu().numpy().argsort()[::-1][:top_k]
top_books = df.iloc[top_results].copy()
similarity_scores = similarities.cpu().numpy()[top_results]
top_books['similarity_score'] = similarity_scores
return top_books
st.title('Рекомендации книг')
search = st.radio(
"Выберите тип семантического поиска:",
[":blue[Симметричный]", ":blue[Асимметричный]"],
captions=[
"Используем 'all-mpnet-base-v2'",
"Используем 'msmarco-roberta-base-v3'",
],
horizontal=True,
)
def params(search):
if search == ":blue[Симметричный]":
text = '''Я ищу книги в жанре фэнтези, которые описывают приключения магов и волшебников, обучающихся в специальных магических школах и сражающихся с темными силами или злыми существами. Особенно интересуют произведения, где главные герои сталкиваются с эпическими испытаниями и развивают свои уникальные способности.'''
model = model_mp
book_embeddings = book_embeddings_mp
return text, model, book_embeddings
elif search == ":blue[Асимметричный]":
text = '''путешествие во времени'''
model = model_ms
book_embeddings = book_embeddings_ms
return text, model, book_embeddings
text, model, book_embeddings = params(search)
col1, col2 = st.columns([3, 1])
with col1:
query = st.text_area('Введите запрос, чтобы получить рекомендации', f'{text}', height=95)
with col2:
number = st.number_input(
"Сколько книг найти?", value=3
)
find_button = st.button('Найти', key='find_button', use_container_width=True)
if find_button and query:
top_10_books = get_top_10_recommendations(query, model, book_embeddings, number)
for idx, row in top_10_books.iterrows():
with st.container():
col1, col2 = st.columns([1, 3])
with col1:
st.image(row['image_url'], width = 300)
with col2:
st.subheader(f"{row['title']}")
st.write(f"**Автор:** {row['author']}")
tab1, tab2 = st.tabs(['Аннотация', 'Краткое содержание'])
with tab1:
st.write(row['annotation'])
st.metric(label="Схожесть", value=f"{row['similarity_score']:.3f}")
st.write(f"**Ссылка:** {row['page_url']}")
with tab2:
template = "ты умеешь кратко в несколько предложений описывать содержание книги по ее названию"
system_message_prompt = SystemMessagePromptTemplate.from_template(template)
human_template = "Кратко опиши содержание книги под названием: {book_title}"
human_message_prompt = HumanMessagePromptTemplate.from_template(human_template)
chat_prompt = ChatPromptTemplate.from_messages(
[system_message_prompt, human_message_prompt]
)
formatted_prompt = chat_prompt.format_prompt(
book_title=row['title']
)
response = chat(formatted_prompt.to_messages())
st.write(response.content)
st.write("---")
st.write("---")