# Import modules
import os

import torch
import gradio as gr
from langchain_community.llms import HuggingFacePipeline
from langchain_community.embeddings.sentence_transformer import SentenceTransformerEmbeddings
from langchain_community.vectorstores import Chroma
from langchain_core.runnables import RunnablePassthrough
from langchain_core.prompts import PromptTemplate
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, GenerationConfig, pipeline

HUGGINGFACE_ACCESS_TOKEN = os.environ["HUGGINGFACE_ACCESS_TOKEN"]

base_model = "microsoft/phi-2"

# Define the embedding function
# I use the "all-MiniLM-L6-v2" model
embedding_function = SentenceTransformerEmbeddings(
    model_name="all-MiniLM-L6-v2",
    model_kwargs={"device": "cuda"},  # Use the GPU
)
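
# For reference (not used by the app itself): the embedding function can be
# exercised directly; with "all-MiniLM-L6-v2" each text is mapped to a
# 384-dimensional vector, e.g.:
#   vec = embedding_function.embed_query("example query")  # len(vec) == 384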

# Load the tokenizer of the base model
tokenizer = AutoTokenizer.from_pretrained(
    base_model,
    use_fast=True,
    token=HUGGINGFACE_ACCESS_TOKEN,
)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

# 4-bit quantization config (NF4 weights, float16 compute) so the model fits on a small GPU
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=False,
)

# Load the fine-tuned model: the quantized base model with the LoRA adapter applied on top
# (checkpointed at 1 epoch = 77 steps)
adapter = "./results/checkpoint-77"
model = AutoModelForCausalLM.from_pretrained(
    base_model,
    quantization_config=bnb_config,
    trust_remote_code=True,
    device_map={"": 0},
    token=HUGGINGFACE_ACCESS_TOKEN,
)
model_ft = PeftModel.from_pretrained(model, adapter)

# For inference, use a text-generation pipeline
# NOTE: you may get a warning such as "The model 'PeftModelForCausalLM' is not
# supported for text-generation", but it's not a problem
config = GenerationConfig(max_new_tokens=200)
pipe = pipeline(
    "text-generation",
    model=model_ft,
    tokenizer=tokenizer,
    generation_config=config,
    framework="pt",
)
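
# Optional sanity check (a sketch; the prompt below is only an example): the
# pipeline can be called directly and returns the prompt followed by the
# generated continuation:
#   out = pipe("Instruct: Who is the CEO of Enron?\nOutput:")
#   print(out[0]["generated_text"])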
""" | |
NOTE: Although not strictly required by the assignment, considering that for | |
Point 1 we created the embeddings of the emails and saved them in Chroma, it is | |
trivial to add a simple RAG system. Basically, when a question is asked, some | |
emails (or part of them) similar to the question are also sent to the model as | |
context. | |
""" | |

# Load the saved database
persist_directory = "./chroma_db"
db = Chroma(
    persist_directory=persist_directory,
    embedding_function=embedding_function,
)

# Set up a retriever so that we get the 2 most similar texts
retriever = db.as_retriever(search_type="similarity", search_kwargs={"k": 2})
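
# For reference (a sketch, the query string is just an example): the retriever
# can be queried on its own and returns the 2 most similar chunks as Documents:
#   docs = retriever.invoke("gas prices")  # [Document(page_content=...), ...]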

# Wrap the Hugging Face pipeline for LangChain
llm = HuggingFacePipeline(pipeline=pipe)

# This is the template we will use for the text submitted to the model.
# The {context} placeholder is filled with the sentences retrieved from the
# vector store, and {question} is filled with the user's question.
template = """Instruct:
You are an AI assistant for answering questions about the provided context.
You are given the following extracted parts of a document database and a question. Provide a short answer.
If you don't know the answer, just say "Hmm, I'm not sure." Don't try to make up an answer.
=======
{context}
=======
Question: {question}
Output:"""

custom_rag_prompt = PromptTemplate.from_template(template)
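
# For illustration only (a sketch with placeholder values): the prompt can be
# rendered by hand to inspect exactly what the model will receive:
#   print(custom_rag_prompt.format(context="<retrieved emails>", question="<question>"))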

def format_docs(docs):
    # Separate the retrieved texts with a blank line
    return "\n\n".join(doc.page_content for doc in docs)

# RAG chain: retrieve the context, fill the prompt, and generate with the LLM
rag_chain = (
    {"context": retriever | format_docs, "question": RunnablePassthrough()}
    | custom_rag_prompt
    | llm
)
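
# Note: by default the text-generation pipeline returns the full text (prompt +
# completion), so the answer is whatever follows the "Output:" marker;
# get_answer() below relies on this. A direct call would look like (sketch,
# example question only):
#   raw = rag_chain.invoke("Who handled the gas contracts?")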

def get_answer(question):
    if not question.strip():
        return "Please enter a question."
    try:
        # Run the RAG chain and keep only the text generated after "Output:"
        answer = rag_chain.invoke(question).split("Output:")[1].strip()
    except Exception as e:
        answer = str(e)
    return answer

# Define and launch the Gradio interface
interface = gr.Interface(
    fn=get_answer,
    inputs=gr.Textbox(label="Enter your question"),
    outputs=gr.Textbox(label="Answer"),
    title="Enron QA",
    examples=[
["What is the strategy in agricultural commodities training?"] | |
    ],
)

interface.launch()