# Phi3-ORPO / app.py
# Author: justinj92 — last commit "Update app.py" (2a04ac7, verified)
# Size: 2.52 kB
import os
import subprocess
import sys

import spaces
import torch
from llama_index.core import VectorStoreIndex, SimpleDirectoryReader, ServiceContext, SummaryIndex
from llama_index.core import Settings
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.llms.huggingface import HuggingFaceLLM
# Install flash-attn at startup without compiling its CUDA extension.
# The override is merged into the current environment: subprocess.run(env=...)
# *replaces* the child's environment, so passing only the override would drop
# PATH, HOME and every other inherited variable. `sys.executable -m pip`
# installs into the interpreter running this app rather than whichever `pip`
# happens to be on PATH. Best-effort as before: failures are not raised.
subprocess.run(
    [sys.executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"],
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    check=False,
)

# Load the source documents from ./data (resolved against the working directory).
documents = SimpleDirectoryReader("./data").load_data()
# NOTE(review): summary_index is not referenced by any code shown in this file;
# kept for compatibility, but it may be removable.
summary_index = SummaryIndex.from_documents(documents)
def messages_to_prompt(messages):
    """Render a sequence of chat messages as a Phi-3 style prompt string.

    Each message becomes ``<|tag|>\\n{content}<|end|>\\n``. Roles ``"system"``
    and ``"assistant"`` keep their own tags; ``"user"`` and any other role are
    emitted as user turns. An open ``<|assistant|>\\n`` turn is appended so the
    model continues as the assistant. If no system message was present, a
    default system prompt is prepended.

    Args:
        messages: Iterable of objects exposing ``.role`` and ``.content``.

    Returns:
        The fully formatted prompt string.
    """
    parts = []
    saw_system = False
    for msg in messages:
        role = msg.role
        if role == "system":
            tag = "<|system|>"
            saw_system = True
        elif role == "assistant":
            tag = "<|assistant|>"
        else:
            # "user" and any unrecognised role are rendered as user turns.
            tag = "<|user|>"
        parts.append(f"{tag}\n{msg.content}<|end|>\n")

    # Leave an open assistant turn for the model to complete.
    parts.append("<|assistant|>\n")

    if not saw_system:
        parts.insert(
            0,
            "<|system|>\nYou are a helpful AI research assistant built by Justin. You only answer from the context provided.<|end|>\n",
        )
    return "".join(parts)
# Fine-tuned Phi-3 model served through LlamaIndex's Hugging Face wrapper.
llm = HuggingFaceLLM(
    model_name="justinj92/phi3-orpo",
    model_kwargs={
        # Allow the Hub repo's custom model code (if any) to be executed.
        "trust_remote_code": True,
        "torch_dtype": torch.bfloat16,
    },
    generate_kwargs={"do_sample": True, "temperature": 0.7},
    tokenizer_name="justinj92/phi3-orpo",
    # NOTE(review): per the HuggingFaceLLM API, query_wrapper_prompt wraps plain
    # completion prompts while messages_to_prompt formats chat messages; the
    # system text here duplicates the default in messages_to_prompt — keep in sync.
    query_wrapper_prompt=(
        "<|system|>\n"
        "You are a helpful AI research assistant built by Justin. You only answer from the context provided.<|end|>\n"
        "<|user|>\n"
        "{query_str}<|end|>\n"
        "<|assistant|>\n"
    ),
    messages_to_prompt=messages_to_prompt,
    is_chat_model=True,
)

# Global LlamaIndex configuration. ServiceContext is deprecated in
# llama_index.core (0.10+) and removed in later releases; the Settings
# singleton replaces it and is picked up by indexes and query engines.
Settings.llm = llm
Settings.embed_model = HuggingFaceEmbedding(
    model_name="BAAI/bge-small-en-v1.5"
)
# Same chunking that ServiceContext.from_defaults(chunk_size=1024) configured.
Settings.chunk_size = 1024

# Embed the documents into an in-memory vector index and expose a query engine.
index = VectorStoreIndex.from_documents(documents)
query_engine = index.as_query_engine()
@spaces.GPU
def predict(input, history):
    """Answer a single chat turn by querying the document index.

    Wrapped with the Hugging Face ``spaces.GPU`` decorator (presumably to
    obtain a ZeroGPU device for the duration of the call — confirm).

    Args:
        input: The user's latest message, passed to the query engine.
        history: Prior conversation turns supplied by the chat UI. Unused;
            each question is answered independently of earlier turns.

    Returns:
        The query engine's response converted to ``str``.
    """
    return str(query_engine.query(input))
import gradio as gr

# Chat UI: every submitted message is routed to predict().
demo = gr.ChatInterface(predict)
# NOTE(review): share=True asks Gradio for a public share link; Gradio documents
# share links as unsupported on Hugging Face Spaces — confirm whether it is needed.
demo.launch(share=True)