from llama_index.core import VectorStoreIndex, SimpleDirectoryReader, ServiceContext, SummaryIndex
from llama_index.llms.huggingface import HuggingFaceLLM
from llama_index.core import Settings
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
import torch
import spaces
import subprocess
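
# Install FlashAttention at runtime. FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE skips
# compiling the CUDA kernels during the pip install, which is needed on ZeroGPU
# Spaces where no GPU is visible while the app is starting up.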
subprocess.run(
    "pip install flash-attn --no-build-isolation",
    env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    shell=True,
)
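
# Load every file under ./data into LlamaIndex Document objects.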
documents = SimpleDirectoryReader("./data").load_data()
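
# Build a summary index over the same documents. It is not wired into the chat
# engine below, but could be queried on its own, e.g.:
#   summary_index.as_query_engine(response_mode="tree_summarize").query("Summarize the documents.")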
summary_index = SummaryIndex.from_documents(documents)
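
# Convert LlamaIndex chat messages into the Phi-3 chat template:
# <|system|> / <|user|> / <|assistant|> turns, each terminated by <|end|>.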
def messages_to_prompt(messages):
    prompt = ""
    system_found = False
    for message in messages:
        if message.role == "system":
            prompt += f"<|system|>\n{message.content}<|end|>\n"
            system_found = True
        elif message.role == "user":
            prompt += f"<|user|>\n{message.content}<|end|>\n"
        elif message.role == "assistant":
            prompt += f"<|assistant|>\n{message.content}<|end|>\n"
        else:
            # Treat any unrecognised role as a user turn.
            prompt += f"<|user|>\n{message.content}<|end|>\n"

    # Leave the prompt open for the assistant's reply.
    prompt += "<|assistant|>\n"

    # Fall back to the default system prompt if none was supplied.
    if not system_found:
        prompt = (
            "<|system|>\nYou are a helpful AI research assistant built by Justin. You only answer from the context provided.<|end|>\n"
            + prompt
        )

    return prompt
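
# Load the ORPO-tuned Phi-3 model in bfloat16 and sample at temperature 0.7.
# query_wrapper_prompt wraps plain completion queries in the same chat template,
# with the system prompt pinned to the research-assistant persona.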
llm = HuggingFaceLLM(
    model_name="justinj92/phi3-orpo",
    model_kwargs={
        "trust_remote_code": True,
        "torch_dtype": torch.bfloat16,
    },
    generate_kwargs={"do_sample": True, "temperature": 0.7},
    tokenizer_name="justinj92/phi3-orpo",
    query_wrapper_prompt=(
        "<|system|>\n"
        "You are a helpful AI research assistant built by Justin. You only answer from the context provided.<|end|>\n"
        "<|user|>\n"
        "{query_str}<|end|>\n"
        "<|assistant|>\n"
    ),
    messages_to_prompt=messages_to_prompt,
    is_chat_model=True,
)
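
# Register the LLM and a BGE small embedding model as the global defaults.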
Settings.llm = llm
Settings.embed_model = HuggingFaceEmbedding(
    model_name="BAAI/bge-small-en-v1.5"
)
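
# Bundle the LLM, embeddings, and a 1024-token chunk size for indexing.
# Note: ServiceContext is deprecated in recent llama_index releases; there the
# global Settings (e.g. Settings.chunk_size = 1024) replace it.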
service_context = ServiceContext.from_defaults(
    chunk_size=1024,
    llm=llm,
    embed_model=Settings.embed_model,
)
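
# Embed and index the documents, then expose a default retrieval query engine.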
index = VectorStoreIndex.from_documents(documents, service_context=service_context)
query_engine = index.as_query_engine()
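
# Gradio chat handler; @spaces.GPU requests a GPU for the duration of each call.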
@spaces.GPU
def predict(input, history):
    response = query_engine.query(input)
    return str(response)
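
# Serve the query engine behind a simple Gradio chat UI.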
import gradio as gr

gr.ChatInterface(predict).launch(share=True)