import numpy as np import pandas as pd import streamlit as st import requests from sentence_transformers import util from sentence_transformers import SentenceTransformer, util import os st.set_page_config(page_title="Custom Button Example", layout="wide") from dotenv import load_dotenv from langchain.chat_models.gigachat import GigaChat from langchain.prompts.chat import ( ChatPromptTemplate, HumanMessagePromptTemplate, SystemMessagePromptTemplate, ) load_dotenv() credentials = os.getenv('API_KEY') chat = GigaChat(model='GigaChat', credentials=credentials, verify_ssl_certs=False) @st.cache_resource def load_model_all_mpnet(): return SentenceTransformer('all-mpnet-base-v2') model_mp = load_model_all_mpnet() @st.cache_data def load_embeddings(file_path): return np.load(file_path) book_embeddings_mp = load_embeddings('data/book_embeddings.npy') @st.cache_data def load_data(file_path): return pd.read_csv(file_path) df = load_data('data/books_data_cleaned.csv') @st.cache_resource def load_model_msmarco(): return SentenceTransformer('msmarco-roberta-base-v3') model_ms = load_model_msmarco() @st.cache_data def load_embeddings(file_path): return np.load(file_path) book_embeddings_ms = load_embeddings('data/book_embeddings_ms.npy') def get_embedding(text, model): text = model.encode(text, convert_to_tensor=True) return text def get_top_10_recommendations(query, model, book_embeddings, top_k): query_embedding = get_embedding(query, model).cpu() similarities = util.pytorch_cos_sim(query_embedding, book_embeddings)[0] top_results = similarities.cpu().numpy().argsort()[::-1][:top_k] top_books = df.iloc[top_results].copy() similarity_scores = similarities.cpu().numpy()[top_results] top_books['similarity_score'] = similarity_scores return top_books st.title('Рекомендации книг') search = st.radio( "Выберите тип семантического поиска:", [":blue[Симметричный]", ":blue[Асимметричный]"], captions=[ "Используем 'all-mpnet-base-v2'", "Используем 'msmarco-roberta-base-v3'", ], horizontal=True, ) def params(search): if search == ":blue[Симметричный]": text = '''Я ищу книги в жанре фэнтези, которые описывают приключения магов и волшебников, обучающихся в специальных магических школах и сражающихся с темными силами или злыми существами. Особенно интересуют произведения, где главные герои сталкиваются с эпическими испытаниями и развивают свои уникальные способности.''' model = model_mp book_embeddings = book_embeddings_mp return text, model, book_embeddings elif search == ":blue[Асимметричный]": text = '''путешествие во времени''' model = model_ms book_embeddings = book_embeddings_ms return text, model, book_embeddings text, model, book_embeddings = params(search) col1, col2 = st.columns([3, 1]) with col1: query = st.text_area('Введите запрос, чтобы получить рекомендации', f'{text}', height=95) with col2: number = st.number_input( "Сколько книг найти?", value=3 ) find_button = st.button('Найти', key='find_button', use_container_width=True) if find_button and query: top_10_books = get_top_10_recommendations(query, model, book_embeddings, number) for idx, row in top_10_books.iterrows(): with st.container(): col1, col2 = st.columns([1, 3]) with col1: st.image(row['image_url'], width = 300) with col2: st.subheader(f"{row['title']}") st.write(f"**Автор:** {row['author']}") tab1, tab2 = st.tabs(['Аннотация', 'Краткое содержание']) with tab1: st.write(row['annotation']) st.metric(label="Схожесть", value=f"{row['similarity_score']:.3f}") st.write(f"**Ссылка:** {row['page_url']}") with tab2: template = "ты умеешь кратко в несколько предложений описывать содержание книги по ее названию" system_message_prompt = SystemMessagePromptTemplate.from_template(template) human_template = "Кратко опиши содержание книги под названием: {book_title}" human_message_prompt = HumanMessagePromptTemplate.from_template(human_template) chat_prompt = ChatPromptTemplate.from_messages( [system_message_prompt, human_message_prompt] ) formatted_prompt = chat_prompt.format_prompt( book_title=row['title'] ) response = chat(formatted_prompt.to_messages()) st.write(response.content) st.write("---") st.write("---")