import os
import pickle
from json import dumps, loads
from typing import Any, List, Mapping, Optional

import numpy as np
import openai
import pandas as pd
import streamlit as st
from dotenv import load_dotenv
from huggingface_hub import HfFileSystem
from langchain.llms.base import LLM
from llama_index import (
    Document,
    GPTVectorStoreIndex,
    LLMPredictor,
    PromptHelper,
    ServiceContext,
    SimpleDirectoryReader,
    StorageContext,
    load_index_from_storage,
)
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

# Load environment variables (e.g. API keys) from a local .env file.
load_dotenv()

# Filesystem-style interface to the Hugging Face Hub.
fs = HfFileSystem()

# Prompt-helper settings for the local model.
CONTEXT_WINDOW = 2048
NUM_OUTPUT = 525
CHUNK_OVERLAP_RATIO = 0.2

# Load the local causal language model and its tokenizer.
llm_model_name = "bigscience/bloom-560m"
tokenizer = AutoTokenizer.from_pretrained(llm_model_name)
model = AutoModelForCausalLM.from_pretrained(llm_model_name)

# Text-generation pipeline with sampling enabled.
model_pipeline = pipeline(
    model=model,
    tokenizer=tokenizer,
    task="text-generation",
    do_sample=True,
    top_p=0.95,
    top_k=50,
    temperature=0.7,
)


class CustomLLM(LLM):
    """LangChain-compatible wrapper around the local Hugging Face pipeline."""

    pipeline = model_pipeline

    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        prompt_length = len(prompt)
        response = self.pipeline(prompt, max_new_tokens=NUM_OUTPUT)[0]["generated_text"]
        # The pipeline echoes the prompt, so return only the newly generated text.
        return response[prompt_length:]

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        return {"name_of_model": llm_model_name}

    @property
    def _llm_type(self) -> str:
        return "custom"


class LlamaCustom:
    """Builds or loads a llama_index vector index and answers queries against it."""

    def __init__(self, model_name: str) -> None:
        self.vector_index = self.initialize_index(model_name=model_name)

    def initialize_index(self, model_name: str):
        index_name = model_name.split("/")[-1]
        file_path = f"./vectorStores/{index_name}"

        if os.path.exists(path=file_path):
            # A persisted index already exists for this model; load it from disk.
            storage_context = StorageContext.from_defaults(persist_dir=file_path)
            index = load_index_from_storage(storage_context)
            return index
        else:
            # Build the index from scratch using the local custom LLM.
            prompt_helper = PromptHelper(
                context_window=CONTEXT_WINDOW,
                num_output=NUM_OUTPUT,
                chunk_overlap_ratio=CHUNK_OVERLAP_RATIO,
            )
            llm_predictor = LLMPredictor(llm=CustomLLM())
            service_context = ServiceContext.from_defaults(
                llm_predictor=llm_predictor, prompt_helper=prompt_helper
            )

            # Read the source PDFs and build a vector index over them.
            documents = SimpleDirectoryReader(input_dir="./assets/pdf").load_data()
            index = GPTVectorStoreIndex.from_documents(
                documents, service_context=service_context
            )

            # Persist the index so subsequent runs can load it directly.
            index.storage_context.persist(persist_dir=file_path)
            return index

    def get_response(self, query_str):
        print("query_str: ", query_str)
        query_engine = self.vector_index.as_query_engine()
        response = query_engine.query(query_str)
        return str(response)
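

# Example usage (illustrative sketch only; the Streamlit UI that consumes these
# classes is assumed to live elsewhere in this app):
#
#     llama = LlamaCustom(model_name=llm_model_name)
#     print(llama.get_response("What does the indexed PDF say about this topic?"))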