# Phi-2_Chan / app.py
# yurezsml's picture
# Upload app.py
# 7a514c1 verified
# raw
# history blame
# No virus
# 6.01 kB
import torch
import transformers
from transformers import AutoModel, AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import (
PeftModel,
LoraConfig,
get_peft_model,
prepare_model_for_kbit_training
)
import bs4
import requests
from typing import List
import nltk
from nltk import sent_tokenize
from tqdm import tqdm
import numpy as np
import faiss
import re
import unicodedata
import gradio as gr
import asyncio
# Use the GPU when torch can see one; otherwise run on CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

base_model_id = "microsoft/phi-2"

# Base causal language model that the fine-tuned adapter is applied to.
model = AutoModelForCausalLM.from_pretrained(
    base_model_id,
    device_map='auto',
    trust_remote_code=True
)

# Tokenizer matching the base model. get_answer() reads the module-level
# name `tokenizer`; without this assignment the first request fails with
# a NameError.
tokenizer = AutoTokenizer.from_pretrained(base_model_id, trust_remote_code=True)

# Load the "yurezsml/phi2_chan" PEFT weights on top of the base model.
ft_model = PeftModel.from_pretrained(model, "yurezsml/phi2_chan", offload_dir="./")
def remove_accents(input_str):
    """Return ``input_str`` with all combining marks stripped.

    The string is first decomposed with NFKD normalisation, then every
    code point whose ``unicodedata.combining`` class is non-zero is dropped.
    """
    decomposed = unicodedata.normalize('NFKD', input_str)
    base_chars = filter(lambda ch: unicodedata.combining(ch) == 0, decomposed)
    return "".join(base_chars)
def preprocess(text):
    """Normalise raw article text before sentence splitting.

    Steps, in order: lower-case the text, turn non-breaking spaces into
    plain spaces, collapse a double newline into a single one, drop empty
    ``()`` / ``[]`` pairs, remove any bracketed or parenthesised span that
    sits on one line, and replace the stressed ``а́`` with a plain ``а``.

    Args:
        text: The text to clean.

    Returns:
        The cleaned, lower-cased text.
    """
    text = text.lower()
    text = text.replace('\xa0', ' ')
    text = text.replace('\n\n', '\n')
    text = text.replace('()', '')
    text = text.replace('[]', '')
    # Raw string: the pattern contains backslash escapes that are not valid
    # Python string escapes and would trigger a warning in a normal literal.
    text = re.sub(r"[\(\[].*?[\)\]]", "", text)
    text = text.replace('а́', 'а')
    return text
def split_text(text: str, n=2, character=" ") -> List[str]:
    """Split ``text`` into passages of ``n`` consecutive sentences.

    The text is cleaned with ``preprocess`` and tokenised into sentences
    with ``sent_tokenize``; sentences are then grouped into non-overlapping
    windows of ``n`` and each window is joined with a single space.

    Args:
        text: Raw text to split.
        n: Number of sentences per passage; must be at least 1.
        character: Unused; kept so existing callers keep working.

    Returns:
        The list of passages; the last one may hold fewer than ``n`` sentences.

    Raises:
        ValueError: If ``n`` is smaller than 1.
    """
    if n < 1:
        raise ValueError("n must be a positive integer")
    text = preprocess(text)
    all_sentences = sent_tokenize(text)
    # Step by n (previously hard-coded to 2) so the window size and the
    # stride always agree and no sentence is duplicated or skipped.
    return [' '.join(all_sentences[i : i + n]) for i in range(0, len(all_sentences), n)]
def split_documents(documents: List[str]) -> list:
    """Flatten a list of documents into one list of passages.

    Entries that are ``None`` are skipped; every other document is handed
    to ``split_text`` and its passages are appended in order.
    """
    return [
        passage
        for document in documents
        if document is not None
        for passage in split_text(document)
    ]
def embed(text, model, tokenizer):
    """Return mask-weighted mean-pooled embeddings for ``text``.

    ``text`` is tokenised (padded, truncated to 512 tokens) and moved to
    the model's device. The first element of the model output is taken as
    the per-token embeddings, which are averaged over the positions whose
    attention mask is set.
    """
    batch = tokenizer(
        text,
        padding=True,
        truncation=True,
        max_length=512,
        return_tensors='pt',
    ).to(model.device)

    with torch.no_grad():
        outputs = model(**batch)

    per_token = outputs[0]  # first output element: one vector per token
    mask = batch['attention_mask'].unsqueeze(-1).expand(per_token.size()).float()

    masked_total = torch.sum(per_token * mask, 1)
    # Clamp so a row with an all-zero mask cannot divide by zero.
    token_count = torch.clamp(mask.sum(1), min=1e-9)
    return masked_total / token_count
# Fetch the article that serves as the retrieval corpus. The timeout keeps
# startup from hanging forever if the remote host never answers.
response = requests.get("https://en.wikipedia.org/wiki/Chandler_Bing", timeout=30)
base_text = ''
# A Response is truthy only for a successful status code; on failure
# base_text stays empty.
if response:
    html = bs4.BeautifulSoup(response.text, 'html.parser')
    title = html.select("#firstHeading")[0].text
    paragraphs = html.select("p")
    # Concatenate the text of every <p> element without a separator.
    base_text = ''.join(para.text for para in paragraphs)
# Sentence-embedding model used for retrieval over the scraped passages.
fact_coh_tokenizer = AutoTokenizer.from_pretrained("DeepPavlov/bert-base-multilingual-cased-sentence")
fact_coh_model = AutoModel.from_pretrained("DeepPavlov/bert-base-multilingual-cased-sentence")
fact_coh_model.to(device)

# Tokenizer data needed by sent_tokenize inside split_text.
nltk.download('punkt')
subsample_documents = split_documents([base_text])
if not subsample_documents:
    # Without passages there is nothing to index; fail with a clear message
    # instead of an IndexError on base.shape[1] further down.
    raise RuntimeError("No passages were extracted from the source article; cannot build the retrieval index.")

batch_size = 8
# Ceiling division: number of batches needed to cover every passage.
total_batches = (len(subsample_documents) + batch_size - 1) // batch_size
batch_embeddings = []
for i in tqdm(range(0, len(subsample_documents), batch_size), total=total_batches, desc="Processing Batches"):
    batch_texts = subsample_documents[i:i + batch_size]
    batch_embeddings.append(embed(batch_texts, fact_coh_model, fact_coh_tokenizer).cpu().numpy())

# One row per passage, in the same order as subsample_documents.
base = np.vstack(batch_embeddings)
index = faiss.IndexFlatL2(base.shape[1])
index.add(base)
async def get_context(subsample_documents, query, index, model, tokenizer):
    """Return the passage whose embedding is closest to ``query``.

    Args:
        subsample_documents: Passages, in the order they were added to ``index``.
        query: The user question; it is lower-cased before embedding.
        index: Search index exposing ``search(vectors, k)``.
        model: Embedding model passed through to ``embed``.
        tokenizer: Tokenizer passed through to ``embed``.

    Returns:
        The single nearest passage from ``subsample_documents``.
    """
    xq = embed(query.lower(), model, tokenizer).cpu().numpy()
    # reshape(1, -1) keeps this independent of the embedding width
    # (previously hard-coded to 768). Only the best hit is used, so a
    # single neighbour is requested.
    _, neighbours = index.search(xq.reshape(1, -1), 1)
    return subsample_documents[neighbours[0][0]]
async def get_prompt(question, use_rag, answers_history: list[str]):
    """Assemble the generation prompt for ``question``.

    The prompt starts with the system instruction, replays
    ``answers_history`` (even positions as questions, odd positions as
    answers), optionally appends a retrieved context passage when
    ``use_rag`` is truthy, and ends with the new question. All runs of
    whitespace are collapsed to single spaces before returning.
    """
    pieces = ['###system: answer the question as Chandler. ']

    for position, turn in enumerate(answers_history):
        is_question = position % 2 == 0
        pieces.append(f' ###question: {turn}' if is_question else f' ###answer: {turn} ')

    if use_rag:
        context = await asyncio.wait_for(
            get_context(subsample_documents, question, index, fact_coh_model, fact_coh_tokenizer),
            timeout=60,
        )
        pieces.append(f' Chandler. {context}')

    pieces.append(f' ###question: {question} ')

    raw_prompt = ''.join(pieces)
    return ' '.join(raw_prompt.split())
async def get_answer(question, use_rag, answers_history: list[str]):
    """Generate a reply to ``question`` and return the updated dialog.

    Args:
        question: The new user message.
        use_rag: When truthy, a retrieved context passage is added to the prompt.
        answers_history: Alternating past questions and answers; the new
            question and answer are appended to this list in place.

    Returns:
        A ``(dialog, answers_history)`` tuple, where ``dialog`` is the
        whole conversation rendered as "you: ..." / "Chandler: ..." lines.
    """
    eval_prompt = await asyncio.wait_for(get_prompt(question, use_rag, answers_history), timeout=60)
    model_input = tokenizer(eval_prompt, return_tensors="pt").to(device)
    ft_model.eval()
    with torch.no_grad():
        generated = ft_model.generate(**model_input, max_new_tokens=30, repetition_penalty=1.11)[0]
    answer = tokenizer.decode(generated, skip_special_tokens=True) + '\n'
    # Collapse whitespace the same way get_prompt does so the prompt can be
    # found and removed from the decoded output.
    answer = ' '.join(answer.split())
    if eval_prompt in answer:
        answer = answer.replace(eval_prompt, '')
    # Keep the text after the first '###answer' marker. If the model did not
    # emit the marker, keep the whole remaining text instead of raising
    # IndexError.
    pieces = answer.split('###answer')
    if len(pieces) > 1:
        answer = pieces[1]

    lines = []
    for idx, text in enumerate(answers_history):
        speaker = 'you' if idx % 2 == 0 else 'Chandler'
        lines.append(f'{speaker}: {text}\n')
    lines.append(f'you: {question}\n')
    lines.append(f'Chandler: {answer}\n')
    dialog = ''.join(lines)

    answers_history.append(question)
    answers_history.append(answer)
    return dialog, answers_history
async def async_proc(question, use_rag, answers_history: list[str]):
    """Gradio entry point: answer ``question`` within a 60-second budget.

    Returns whatever ``get_answer`` returns, or a timeout message paired
    with ``answers_history`` when the wait times out.
    """
    pending = get_answer(question, use_rag, answers_history)
    try:
        outcome = await asyncio.wait_for(pending, timeout=60)
    except asyncio.TimeoutError:
        return "Processing timed out.", answers_history
    return outcome
# Build the UI: a question box, a RAG toggle and per-session chat history in,
# the rendered chat text and the updated history out.
demo = gr.Interface(
    fn=async_proc,
    inputs=[
        gr.Textbox(label="Question"),
        gr.Checkbox(label="Use RAG", info="Pick to RAG to improve factual coherence"),
        gr.State(value=[]),
    ],
    outputs=[
        gr.Textbox(label="Chat"),
        gr.State(),
    ],
    title="Асинхронный сервис для чат-бота по сериалу Друзья",
    concurrency_limit=5,
)

# Enable the request queue, then start the server with a public share link.
demo = demo.queue()
demo.launch(share=True, debug=True)