# Hugging Face Space: RLHF-tuned FLAN-T5 conversation summarizer.
# Runtime dependency pinning + imports for the summarization Space.
# NOTE(review): installing a package at import time is a Spaces workaround;
# prefer pinning gradio==4.44.1 in requirements.txt instead.
import os
import subprocess
import sys

# subprocess with an argument list (shell=False) instead of os.system's
# shell string; same effect, no shell injection surface.
subprocess.run([sys.executable, "-m", "pip", "install", "gradio==4.44.1"], check=False)

import gc
import random
import re

import requests
import torch
import gradio as gr
from datasets import concatenate_datasets, load_dataset
from peft import PeftConfig, PeftModel
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline

from langchain.chains import RetrievalQA
from langchain.chains.question_answering import load_qa_chain
from langchain.docstore.document import Document
from langchain.document_loaders import WebBaseLoader
from langchain.embeddings import HuggingFaceEmbeddings, OpenAIEmbeddings
from langchain.llms import HuggingFacePipeline
from langchain.memory import ConversationBufferMemory
from langchain.prompts import PromptTemplate
from langchain.text_splitter import CharacterTextSplitter
from langchain_community.vectorstores import FAISS
# Select GPU when available; every model below is moved to this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# SAMSum dialogues (train + validation), concatenated for generating questions.
train_dataset = load_dataset("samsum", split='train', trust_remote_code=True)
val_dataset = load_dataset("samsum", split='validation', trust_remote_code=True)
samsum_dataset = concatenate_datasets([train_dataset, val_dataset])

# Base seq2seq model; bfloat16 keeps the memory footprint low.
model_name = "google/flan-t5-base"
tokenizer = AutoTokenizer.from_pretrained(model_name)
base_model = AutoModelForSeq2SeqLM.from_pretrained(model_name, torch_dtype=torch.bfloat16).to(device)

# RLHF-tuned PEFT adapter, merged into the base weights for plain inference.
rlhf_model_path = "raghav-gaggar/PEFT_RLHF_TextSummarizer"
config = PeftConfig.from_pretrained(rlhf_model_path)
ppo_model = PeftModel.from_pretrained(base_model, rlhf_model_path).to(device)
merged_model = ppo_model.merge_and_unload().to(device)

# Inference mode: disables dropout and other train-time behavior.
base_model.eval()
ppo_model.eval()
merged_model.eval()

# DialogSum: (dialogue, summary) pairs used below as retrieval examples.
dialogsum_dataset = load_dataset("knkarthick/dialogsum", trust_remote_code=True)
def format_dialogsum_as_document(example):
    """Wrap one DialogSum record as a LangChain Document.

    Args:
        example: mapping with 'dialogue' and 'summary' keys.

    Returns:
        Document whose page_content holds the dialogue followed by its
        reference summary, formatted as a one-shot example for the prompt.
    """
    return Document(page_content=f"Dialogue:\n {example['dialogue']}\n\nSummary: {example['summary']}")
# Build retrieval documents from every DialogSum split.
documents = []
for split in ['train', 'validation', 'test']:
    documents.extend([format_dialogsum_as_document(example) for example in dialogsum_dataset[split]])

# Chunk the documents; zero overlap since each document is one self-contained
# (dialogue, summary) example.
text_splitter = CharacterTextSplitter(chunk_size=5200, chunk_overlap=0)
docs = text_splitter.split_documents(documents)

# Embed with a small sentence-transformer and index the vectors in FAISS.
embeddings = HuggingFaceEmbeddings(
    model_name="sentence-transformers/all-MiniLM-L6-v2",
    model_kwargs={"device": "cuda" if torch.cuda.is_available() else "cpu"},
    encode_kwargs={"batch_size": 32}
)
vector_store = FAISS.from_documents(docs, embeddings)

# Retrieve only the single most similar worked example per query (k=1).
retriever = vector_store.as_retriever(search_kwargs={"k": 1})
# One-shot prompt: {context} is the retrieved (dialogue, summary) example,
# {question} is the new dialogue to summarize.
prompt_template = """
Concisely summarize the dialogue in the end, like the example provided -
Example -
{context}
Dialogue to be summarized:
{question}
Summary:"""
PROMPT = PromptTemplate(
    template=prompt_template, input_variables=["context", "question"]
)
# Hugging Face summarization pipeline over the merged RLHF model.
summarization_pipeline = pipeline(
    "summarization",
    model=merged_model,
    tokenizer=tokenizer,
    max_length=150,
    min_length=20,
    do_sample=False,  # greedy decoding -> deterministic summaries
)
# Wrap the pipeline so LangChain can drive it as an LLM.
llm = HuggingFacePipeline(pipeline=summarization_pipeline)
# RetrievalQA: fetch one similar example, fill PROMPT, run the summarizer.
qa_chain = RetrievalQA.from_chain_type(
    llm, retriever=retriever, chain_type_kwargs={"prompt": PROMPT}
)
# Callback wired into the Gradio interface below.
def summarize_conversation(question):
    """Return a concise summary of the given conversation.

    Args:
        question: raw conversation text entered by the user.

    Returns:
        The model-generated summary string produced by the RetrievalQA chain.
    """
    result = qa_chain({"query": question})
    return result["result"]
# Pre-filled example shown under the input box.
examples = [["Amanda: I baked cookies. Do you want some? \nJerry: Sure! \nAmanda: I'll bring you tomorrow :-)"]]

# Gradio UI: one multi-line text box in, one text box out.
iface = gr.Interface(
    fn=summarize_conversation,
    inputs=gr.Textbox(lines=10, label="Enter conversation here"),
    outputs=gr.Textbox(label="Summary"),
    title="Conversation Summarizer",
    description="Enter a conversation, and the AI will provide a concise summary.",
    examples=examples,
)

# Launch the app (blocking call).
iface.launch()