import torch
from transformers import AutoModel, AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
import bs4
import requests
from typing import List
import nltk
from nltk import sent_tokenize
from tqdm import tqdm
import numpy as np
import faiss
import re
import unicodedata
import gradio as gr
import asyncio

device = "cuda" if torch.cuda.is_available() else "cpu"

# Base model plus the LoRA adapter fine-tuned on Chandler's lines.
base_model_id = "microsoft/phi-2"
model = AutoModelForCausalLM.from_pretrained(
    base_model_id,
    device_map='auto',
    trust_remote_code=True
)
tokenizer = AutoTokenizer.from_pretrained(base_model_id, trust_remote_code=True)
ft_model = PeftModel.from_pretrained(model, "yurezsml/phi2_chan", offload_dir="./")


def remove_accents(input_str):
    nfkd_form = unicodedata.normalize('NFKD', input_str)
    return "".join([c for c in nfkd_form if not unicodedata.combining(c)])


def preprocess(text):
    text = text.lower()
    text = remove_accents(text)
    text = text.replace('\xa0', ' ')
    text = text.replace('\n\n', '\n')
    text = text.replace('()', '')
    text = text.replace('[]', '')
    # Drop parenthesized and bracketed fragments, e.g. Wikipedia citation markers.
    text = re.sub(r"[\(\[].*?[\)\]]", "", text)
    return text


def split_text(text: str, n=2) -> List[str]:
    """Split text into passages of n sentences each."""
    text = preprocess(text)
    all_sentences = sent_tokenize(text)
    return [' '.join(all_sentences[i:i + n]) for i in range(0, len(all_sentences), n)]


def split_documents(documents: List[str]) -> List[str]:
    texts = []
    for text in documents:
        if text is not None:
            texts.extend(split_text(text))
    return texts


def embed(text, model, tokenizer):
    """Mean-pooled sentence embeddings, with padding tokens masked out."""
    encoded_input = tokenizer(text, padding=True, truncation=True,
                              max_length=512, return_tensors='pt').to(model.device)
    with torch.no_grad():
        model_output = model(**encoded_input)
    token_embeddings = model_output[0]  # first element of model_output holds all token embeddings
    input_mask_expanded = encoded_input['attention_mask'].unsqueeze(-1).expand(token_embeddings.size()).float()
    sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
    sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
    return sum_embeddings / sum_mask


# Scrape the Wikipedia article on Chandler Bing as the RAG knowledge base.
response = requests.get("https://en.wikipedia.org/wiki/Chandler_Bing")
base_text = ''
if response:
    html = bs4.BeautifulSoup(response.text, 'html.parser')
    for para in html.select("p"):
        base_text += para.text

# Multilingual sentence encoder used only for retrieval, not for generation.
fact_coh_tokenizer = AutoTokenizer.from_pretrained("DeepPavlov/bert-base-multilingual-cased-sentence")
fact_coh_model = AutoModel.from_pretrained("DeepPavlov/bert-base-multilingual-cased-sentence")
fact_coh_model.to(device)

nltk.download('punkt')
subsample_documents = split_documents([base_text])

# Embed the passages in batches and build a flat L2 FAISS index over them.
batch_size = 8
total_batches = (len(subsample_documents) + batch_size - 1) // batch_size
base = []
for i in tqdm(range(0, len(subsample_documents), batch_size),
              total=total_batches, desc="Processing Batches"):
    batch_texts = subsample_documents[i:i + batch_size]
    base.extend(embed(batch_texts, fact_coh_model, fact_coh_tokenizer))
base = np.array([vector.cpu().numpy() for vector in base])

index = faiss.IndexFlatL2(base.shape[1])
index.add(base)


async def get_context(subsample_documents, query, index, model, tokenizer):
    """Return the indexed passage closest to the query in embedding space."""
    k = 5
    xq = embed(query.lower(), model, tokenizer).cpu().numpy()
    D, I = index.search(xq.reshape(1, -1), k)
    return subsample_documents[I[0][0]]
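
# A quick sanity check of the retriever (a minimal sketch; the query string
# below is an illustrative example, not part of the original pipeline).
# FAISS returns squared L2 distances in D and passage indices in I.
xq_test = embed("who is chandler's roommate?", fact_coh_model, fact_coh_tokenizer).cpu().numpy()
D_test, I_test = index.search(xq_test.reshape(1, -1), 3)
for dist, idx in zip(D_test[0], I_test[0]):
    print(f"{dist:.2f}  {subsample_documents[idx][:80]}")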
async def get_prompt(question, use_rag, answers_history: list[str]):
    """Assemble the prompt: system instruction, chat history, optional
    retrieved context, and the new question."""
    eval_prompt = '###system: answer the question as Chandler.'
    for idx, text in enumerate(answers_history):
        if idx % 2 == 0:
            eval_prompt += f' ###question: {text}'
        else:
            eval_prompt += f' ###answer: {text}'
    if use_rag:
        context = await asyncio.wait_for(
            get_context(subsample_documents, question, index, fact_coh_model, fact_coh_tokenizer),
            timeout=60)
        eval_prompt += f' Chandler. {context}'
    eval_prompt += f' ###question: {question}'
    return ' '.join(eval_prompt.split())


async def get_answer(question, use_rag, answers_history: list[str]):
    eval_prompt = await asyncio.wait_for(get_prompt(question, use_rag, answers_history), timeout=60)
    model_input = tokenizer(eval_prompt, return_tensors="pt").to(device)
    ft_model.eval()
    with torch.no_grad():
        answer = tokenizer.decode(
            ft_model.generate(**model_input, max_new_tokens=30, repetition_penalty=1.11)[0],
            skip_special_tokens=True)
    answer = ' '.join(answer.split())
    # The decoded sequence echoes the prompt; keep only the generated answer.
    if eval_prompt in answer:
        answer = answer.replace(eval_prompt, '')
    if '###answer' in answer:
        answer = answer.split('###answer')[1]
    answer = answer.lstrip(':').strip()
    dialog = ''
    for idx, text in enumerate(answers_history):
        if idx % 2 == 0:
            dialog += f'you: {text}\n'
        else:
            dialog += f'Chandler: {text}\n'
    dialog += f'you: {question}\n'
    dialog += f'Chandler: {answer}\n'
    answers_history.append(question)
    answers_history.append(answer)
    return dialog, answers_history


async def async_proc(question, use_rag, answers_history: list[str]):
    try:
        return await asyncio.wait_for(get_answer(question, use_rag, answers_history), timeout=60)
    except asyncio.TimeoutError:
        return "Processing timed out.", answers_history


gr.Interface(
    fn=async_proc,
    inputs=[
        gr.Textbox(label="Question"),
        gr.Checkbox(label="Use RAG", info="Enable RAG to improve factual coherence"),
        gr.State(value=[]),
    ],
    outputs=[
        gr.Textbox(label="Chat"),
        gr.State(),
    ],
    title="Asynchronous chat-bot service for the Friends TV series",
    concurrency_limit=5
).queue().launch(share=True, debug=True)
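
# For a headless smoke test (a hedged sketch, not part of the original UI
# flow), the pipeline can be driven directly; run this instead of launch()
# above, since launch(debug=True) blocks:
#
#     dialog, history = asyncio.run(async_proc("How you doin'?", True, []))
#     print(dialog)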