import glob
import inspect
import os
import pathlib
import pickle
import shutil
import subprocess
import sys
import tempfile
import traceback
import uuid
import zipfile
from collections import defaultdict
from datetime import datetime
from functools import reduce
from operator import concat

from joblib import Parallel, delayed

from utils import wrapped_partial, EThread, import_matplotlib, sanitize_filename, makedirs, get_url, flatten_list, \
    get_device

import_matplotlib()

import numpy as np
import pandas as pd
import requests
from langchain.chains.qa_with_sources import load_qa_with_sources_chain
# , GCSDirectoryLoader, GCSFileLoader
# , OutlookMessageLoader  # GPL3
# ImageCaptionLoader,  # use our own wrapper
# ReadTheDocsLoader,  # no special file, some path, so have to give as special option
from langchain.document_loaders import PyPDFLoader, TextLoader, CSVLoader, PythonLoader, TomlLoader, \
    UnstructuredURLLoader, UnstructuredHTMLLoader, UnstructuredWordDocumentLoader, UnstructuredMarkdownLoader, \
    EverNoteLoader, UnstructuredEmailLoader, UnstructuredODTLoader, UnstructuredPowerPointLoader, \
    UnstructuredEPubLoader, UnstructuredImageLoader, UnstructuredRTFLoader, ArxivLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import FAISS
from langchain.chains.question_answering import load_qa_chain
from langchain.docstore.document import Document
from langchain import PromptTemplate
from langchain.vectorstores import Chroma


def get_db(sources, use_openai_embedding=False, db_type='faiss', persist_directory="db_dir", langchain_mode='notset',
           hf_embedding_model="sentence-transformers/all-MiniLM-L6-v2"):
    if not sources:
        return None
    # get embedding model
    embedding = get_embedding(use_openai_embedding, hf_embedding_model=hf_embedding_model)

    # Create vector database
    if db_type == 'faiss':
        db = FAISS.from_documents(sources, embedding)
    elif db_type == 'chroma':
        collection_name = langchain_mode.replace(' ', '_')
        os.makedirs(persist_directory, exist_ok=True)
        db = Chroma.from_documents(documents=sources,
                                   embedding=embedding,
                                   persist_directory=persist_directory,
                                   collection_name=collection_name,
                                   anonymized_telemetry=False)
        db.persist()
        # FIXME: below just proves can load persistent dir, regenerates its embedding files, so a bit wasteful
        if False:
            db = Chroma(embedding_function=embedding,
                        persist_directory=persist_directory,
                        collection_name=collection_name)
    else:
        raise RuntimeError("No such db_type=%s" % db_type)

    return db


def add_to_db(db, sources, db_type='faiss', avoid_dup=True):
    if not sources:
        return db
    if db_type == 'faiss':
        db.add_documents(sources)
    elif db_type == 'chroma':
        if avoid_dup:
            collection = db.get()
            metadata_sources = set([x['source'] for x in collection['metadatas']])
            sources = [x for x in sources if x.metadata['source'] not in metadata_sources]
        if len(sources) == 0:
            return db
        db.add_documents(documents=sources)
        db.persist()
    else:
        raise RuntimeError("No such db_type=%s" % db_type)
    return db
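

# Hedged usage sketch for get_db()/add_to_db(): build an in-memory FAISS index from a couple of
# Documents, then append another one.  The texts and source names below are illustrative
# assumptions, not data used anywhere else in this module; kept commented out so importing the
# module stays side-effect free.
#
#   docs = [Document(page_content="FAISS keeps vectors in memory.", metadata={"source": "note1.txt"}),
#           Document(page_content="Chroma can persist vectors to disk.", metadata={"source": "note2.txt"})]
#   db = get_db(docs, db_type='faiss')
#   db = add_to_db(db, [Document(page_content="A third note.", metadata={"source": "note3.txt"})],
#                  db_type='faiss')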


def get_embedding(use_openai_embedding, hf_embedding_model="sentence-transformers/all-MiniLM-L6-v2"):
    # Get embedding model
    if use_openai_embedding:
        assert os.getenv("OPENAI_API_KEY") is not None, "Set ENV OPENAI_API_KEY"
        from langchain.embeddings import OpenAIEmbeddings
        embedding = OpenAIEmbeddings()
    else:
        # to ensure can fork without deadlock
        from langchain.embeddings import HuggingFaceEmbeddings

        device, torch_dtype, context_class = get_device_dtype()
        model_kwargs = dict(device=device)
        embedding = HuggingFaceEmbeddings(model_name=hf_embedding_model, model_kwargs=model_kwargs)
    return embedding


def get_answer_from_sources(chain, sources, question):
    return chain(
        {
            "input_documents": sources,
            "question": question,
        },
        return_only_outputs=True,
    )["output_text"]


def get_llm(use_openai_model=False, model_name=None, model=None,
            tokenizer=None, stream_output=False,
            max_new_tokens=256,
            temperature=0.1,
            repetition_penalty=1.0,
            top_k=40,
            top_p=0.7,
            prompt_type=None,
            ):
    if use_openai_model:
        from langchain.llms import OpenAI
        llm = OpenAI(temperature=0)
        model_name = 'openai'
        streamer = None
    elif model_name in ['gptj', 'llama']:
        from gpt4all_llm import get_llm_gpt4all
        llm = get_llm_gpt4all(model_name, model=model, max_new_tokens=max_new_tokens,
                              temperature=temperature,
                              repetition_penalty=repetition_penalty,
                              top_k=top_k,
                              top_p=top_p,
                              )
        streamer = None
        prompt_type = 'plain'
    else:
        from transformers import AutoTokenizer, AutoModelForCausalLM

        if model is None:
            # only used if didn't pass model in
            assert model_name is None
            assert tokenizer is None
            model_name = 'h2oai/h2ogpt-oasst1-512-12b'
            # model_name = 'h2oai/h2ogpt-oig-oasst1-512-6.9b'
            # model_name = 'h2oai/h2ogpt-oasst1-512-20b'
            tokenizer = AutoTokenizer.from_pretrained(model_name)
            device, torch_dtype, context_class = get_device_dtype()

            with context_class(device):
                load_8bit = True
                # FIXME: for now not to spread across hetero GPUs
                # device_map={"": 0} if load_8bit and device == 'cuda' else "auto"
                device_map = {"": 0} if device == 'cuda' else "auto"
                model = AutoModelForCausalLM.from_pretrained(model_name,
                                                             device_map=device_map,
                                                             torch_dtype=torch_dtype,
                                                             load_in_8bit=load_8bit)

        gen_kwargs = dict(max_new_tokens=max_new_tokens, return_full_text=True, early_stopping=False)
        if stream_output:
            skip_prompt = False
            from generate import H2OTextIteratorStreamer
            decoder_kwargs = {}
            streamer = H2OTextIteratorStreamer(tokenizer, skip_prompt=skip_prompt, block=False, **decoder_kwargs)
            gen_kwargs.update(dict(streamer=streamer))
        else:
            streamer = None

        if 'h2ogpt' in model_name or prompt_type == 'human_bot':
            from h2oai_pipeline import H2OTextGenerationPipeline
            pipe = H2OTextGenerationPipeline(model=model, tokenizer=tokenizer, **gen_kwargs)
            # pipe.task = "text-generation"
            # below makes it listen only to our prompt removal, not built-in prompt removal that is less general and not specific to our model
            pipe.task = "text2text-generation"
            prompt_type = 'human_bot'
        else:
            # only for non-instruct tuned cases when ok with just normal next token prediction
            from transformers import pipeline
            pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, **gen_kwargs)

        from langchain.llms import HuggingFacePipeline
        llm = HuggingFacePipeline(pipeline=pipe)
    return llm, model_name, streamer, prompt_type
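

# Hedged usage sketch for get_llm(): with no preloaded model/tokenizer passed in, the default
# h2oai/h2ogpt-oasst1-512-12b is downloaded and wrapped as a LangChain HuggingFacePipeline.
# Illustrative only (needs a GPU and the transformers stack), so it is left commented out.
#
#   llm, model_name, streamer, prompt_type = get_llm(stream_output=False, max_new_tokens=128)
#   print(llm("Q: What is a vector database?\nA:"))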


def get_device_dtype():
    # torch.device("cuda") leads to cuda:x cuda:y mismatches for multi-GPU consistently
    import torch
    n_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 0
    device = 'cpu' if n_gpus == 0 else 'cuda'
    # from utils import NullContext
    # context_class = NullContext if n_gpus > 1 or n_gpus == 0 else context_class
    context_class = torch.device
    torch_dtype = torch.float16 if device == 'cuda' else torch.float32
    return device, torch_dtype, context_class


def get_wiki_data(title, first_paragraph_only, text_limit=None, take_head=True):
    """
    Get wikipedia data from online
    :param title: page title to fetch
    :param first_paragraph_only: only fetch the intro extract
    :param text_limit: character limit on returned page content
    :param take_head: if limiting, keep the head of the text, else keep the tail
    :return: Document with page content and source URL metadata
    """
    filename = 'wiki_%s_%s_%s_%s.data' % (first_paragraph_only, title, text_limit, take_head)
    url = f"https://en.wikipedia.org/w/api.php?format=json&action=query&prop=extracts&explaintext=1&titles={title}"
    if first_paragraph_only:
        url += "&exintro=1"
    import json
    if not os.path.isfile(filename):
        data = requests.get(url).json()
        json.dump(data, open(filename, 'wt'))
    else:
        data = json.load(open(filename, "rt"))
    page_content = list(data["query"]["pages"].values())[0]["extract"]
    if take_head is not None and text_limit is not None:
        # keep head or tail of the extract, up to text_limit characters
        page_content = page_content[:text_limit] if take_head else page_content[-text_limit:]
    title_url = str(title).replace(' ', '_')
    return Document(
        page_content=page_content,
        metadata={"source": f"https://en.wikipedia.org/wiki/{title_url}"},
    )


def get_wiki_sources(first_para=True, text_limit=None):
    """
    Get specific named sources from wikipedia
    :param first_para: only use the first paragraph of each page
    :param text_limit: character limit per page
    :return: list of Documents
    """
    default_wiki_sources = ['Unix', 'Microsoft_Windows', 'Linux']
    wiki_sources = os.getenv('WIKI_SOURCES', default_wiki_sources)
    if isinstance(wiki_sources, str):
        # ENV value assumed to be a comma-separated list of page titles
        wiki_sources = wiki_sources.split(',')
    return [get_wiki_data(x, first_para, text_limit=text_limit) for x in wiki_sources]


def get_github_docs(repo_owner, repo_name):
    """
    Access github from specific repo
    :param repo_owner: GitHub organization or user
    :param repo_name: repository name
    :return: generator of Documents, one per markdown file
    """
    with tempfile.TemporaryDirectory() as d:
        subprocess.check_call(
            f"git clone --depth 1 https://github.com/{repo_owner}/{repo_name}.git .",
            cwd=d,
            shell=True,
        )
        git_sha = (
            subprocess.check_output("git rev-parse HEAD", shell=True, cwd=d)
            .decode("utf-8")
            .strip()
        )
        repo_path = pathlib.Path(d)
        markdown_files = list(repo_path.glob("*/*.md")) + list(
            repo_path.glob("*/*.mdx")
        )
        for markdown_file in markdown_files:
            with open(markdown_file, "r") as f:
                relative_path = markdown_file.relative_to(repo_path)
                github_url = f"https://github.com/{repo_owner}/{repo_name}/blob/{git_sha}/{relative_path}"
                yield Document(page_content=f.read(), metadata={"source": github_url})


def get_dai_pickle(dest="."):
    from huggingface_hub import hf_hub_download
    # True for case when locally already logged in with correct token, so don't have to set key
    token = os.getenv('HUGGINGFACE_API_TOKEN', True)
    path_to_zip_file = hf_hub_download('h2oai/dai_docs', 'dai_docs.pickle', token=token, repo_type='dataset')
    shutil.copy(path_to_zip_file, dest)


def get_dai_docs(from_hf=False, get_pickle=True):
    """
    Consume DAI documentation, or consume from public pickle
    :param from_hf: get DAI docs from HF, then generate pickle for later use by LangChain
    :param get_pickle: Avoid raw DAI docs, just get pickle directly from HF
    :return:
    """
    import pickle

    if get_pickle:
        get_dai_pickle()

    dai_store = 'dai_docs.pickle'
    dst = "working_dir_docs"
    if not os.path.isfile(dai_store):
        from create_data import setup_dai_docs
        dst = setup_dai_docs(dst=dst, from_hf=from_hf)

        import glob
        files = list(glob.glob(os.path.join(dst, '*rst'), recursive=True))

        basedir = os.path.abspath(os.getcwd())
        from create_data import rst_to_outputs
        new_outputs = rst_to_outputs(files)
        os.chdir(basedir)

        pickle.dump(new_outputs, open(dai_store, 'wb'))
    else:
        new_outputs = pickle.load(open(dai_store, 'rb'))

    sources = []
    for line, file in new_outputs:
        # gradio requires any linked file to be with app.py
        sym_src = os.path.abspath(os.path.join(dst, file))
        sym_dst = os.path.abspath(os.path.join(os.getcwd(), file))
        if os.path.lexists(sym_dst):
            os.remove(sym_dst)
        os.symlink(sym_src, sym_dst)
        itm = Document(page_content=line, metadata={"source": file})
        # NOTE: yield has issues when going into db, loses metadata
        # yield itm
        sources.append(itm)
    return sources


import distutils.spawn

have_tesseract = distutils.spawn.find_executable("tesseract")
have_libreoffice = distutils.spawn.find_executable("libreoffice")

import pkg_resources

try:
    assert pkg_resources.get_distribution('arxiv') is not None
    assert pkg_resources.get_distribution('pymupdf') is not None
    have_arxiv = True
except (pkg_resources.DistributionNotFound, AssertionError):
    have_arxiv = False

image_types = ["png", "jpg", "jpeg"]
non_image_types = ["pdf", "txt", "csv", "toml", "py", "rst", "rtf",
                   "md", "html",
                   "enex", "eml", "epub", "odt", "pptx", "ppt",
                   "zip", "urls",
                   ]
# "msg",  # GPL3

if have_libreoffice:
    non_image_types.extend(["docx", "doc"])

file_types = non_image_types + image_types


def add_meta(docs1, file):
    file_extension = pathlib.Path(file).suffix
    if not isinstance(docs1, list):
        docs1 = [docs1]
    [x.metadata.update(dict(input_type=file_extension, date=str(datetime.now()))) for x in docs1]


def file_to_doc(file, base_path=None, verbose=False, fail_any_exception=False, chunk=True, chunk_size=512,
                is_url=False, is_txt=False,
                enable_captions=True,
                captions_model=None,
                enable_ocr=False, caption_loader=None,
                headsize=50):
    if file is None:
        if fail_any_exception:
            raise RuntimeError("Unexpected None file")
        else:
            return []
    doc1 = []  # in case no support, or disabled support
    if base_path is None and not is_txt and not is_url:
        # then assume want to persist but don't care which path used
        # can't be in base_path
        dir_name = os.path.dirname(file)
        base_name = os.path.basename(file)
        # if from gradio, will have its own temp uuid too, but that's ok
        base_name = sanitize_filename(base_name) + "_" + str(uuid.uuid4())[:10]
        base_path = os.path.join(dir_name, base_name)
    if is_url:
        if file.lower().startswith('arxiv:'):
            query = file.lower().split('arxiv:')
            if len(query) == 2 and have_arxiv:
                query = query[1]
                docs1 = ArxivLoader(query=query, load_max_docs=20, load_all_available_meta=True).load()
                # ensure string, sometimes None
                [[x.metadata.update({k: str(v)}) for k, v in x.metadata.items()] for x in docs1]
                query_url = f"https://arxiv.org/abs/{query}"
                [x.metadata.update(
                    dict(source=x.metadata.get('entry_id', query_url), query=query_url,
                         input_type='arxiv', head=x.metadata.get('Title', ''), date=str(datetime.now())))
                 for x in docs1]
            else:
                docs1 = []
        else:
            docs1 = UnstructuredURLLoader(urls=[file]).load()
            [x.metadata.update(dict(input_type='url', date=str(datetime.now()))) for x in docs1]
        doc1 = chunk_sources(docs1, chunk_size=chunk_size)
    elif is_txt:
        base_path = "user_paste"
        source_file = os.path.join(base_path, "_%s" % str(uuid.uuid4())[:10])
        makedirs(os.path.dirname(source_file), exist_ok=True)
        with open(source_file, "wt") as f:
            f.write(file)
        metadata = dict(source=source_file, date=str(datetime.now()), input_type='pasted txt')
        doc1 = Document(page_content=file, metadata=metadata)
    elif file.endswith('.html') or file.endswith('.mhtml'):
        docs1 = UnstructuredHTMLLoader(file_path=file).load()
        add_meta(docs1, file)
        doc1 = chunk_sources(docs1, chunk_size=chunk_size)
    elif (file.endswith('.docx') or file.endswith('.doc')) and have_libreoffice:
        docs1 = UnstructuredWordDocumentLoader(file_path=file).load()
        add_meta(docs1, file)
        doc1 = chunk_sources(docs1, chunk_size=chunk_size)
    elif file.endswith('.odt'):
        docs1 = UnstructuredODTLoader(file_path=file).load()
        add_meta(docs1, file)
        doc1 = chunk_sources(docs1, chunk_size=chunk_size)
    elif file.endswith('pptx') or file.endswith('ppt'):
        docs1 = UnstructuredPowerPointLoader(file_path=file).load()
        add_meta(docs1, file)
        doc1 = chunk_sources(docs1, chunk_size=chunk_size)
    elif file.endswith('.txt'):
        # use UnstructuredFileLoader ?
        doc1 = TextLoader(file, encoding="utf8", autodetect_encoding=True).load()
        add_meta(doc1, file)
    elif file.endswith('.rtf'):
        docs1 = UnstructuredRTFLoader(file).load()
        add_meta(docs1, file)
        doc1 = chunk_sources(docs1, chunk_size=chunk_size)
    elif file.endswith('.md'):
        docs1 = UnstructuredMarkdownLoader(file).load()
        add_meta(docs1, file)
        doc1 = chunk_sources(docs1, chunk_size=chunk_size)
    elif file.endswith('.enex'):
        doc1 = EverNoteLoader(file).load()
        add_meta(doc1, file)
    elif file.endswith('.epub'):
        docs1 = UnstructuredEPubLoader(file).load()
        add_meta(docs1, file)
        doc1 = chunk_sources(docs1, chunk_size=chunk_size)
    elif file.endswith('.jpeg') or file.endswith('.jpg') or file.endswith('.png'):
        docs1 = []
        if have_tesseract and enable_ocr:
            # OCR, somewhat works, but not great
            docs1.extend(UnstructuredImageLoader(file).load())
            add_meta(docs1, file)
        if enable_captions:
            # BLIP
            if caption_loader is not None and not isinstance(caption_loader, (str, bool)):
                # assumes didn't fork into this process with joblib, else can deadlock
                caption_loader.set_image_paths([file])
                docs1c = caption_loader.load()
                add_meta(docs1c, file)
                [x.metadata.update(dict(head=x.page_content[:headsize].strip())) for x in docs1c]
                docs1.extend(docs1c)
            else:
                from image_captions import H2OImageCaptionLoader
                caption_loader = H2OImageCaptionLoader(caption_gpu=caption_loader == 'gpu',
                                                       blip_model=captions_model,
                                                       blip_processor=captions_model)
                caption_loader.set_image_paths([file])
                docs1c = caption_loader.load()
                add_meta(docs1c, file)
                [x.metadata.update(dict(head=x.page_content[:headsize].strip())) for x in docs1c]
                docs1.extend(docs1c)
        for doci in docs1:
            # caption docs carry image_path; OCR docs may not, so fall back to the file path
            doci.metadata['source'] = doci.metadata.get('image_path', file)
        if docs1:
            doc1 = chunk_sources(docs1, chunk_size=chunk_size)
    elif file.endswith('.msg'):
        raise RuntimeError("Not supported, GPL3 license")
        # docs1 = OutlookMessageLoader(file).load()
        # docs1[0].metadata['source'] = file
    elif file.endswith('.eml'):
        try:
            docs1 = UnstructuredEmailLoader(file).load()
            add_meta(docs1, file)
            doc1 = chunk_sources(docs1, chunk_size=chunk_size)
        except ValueError as e:
            if 'text/html content not found in email' in str(e):
                # e.g. text/plain dict key exists, but not text/html
                # doc1 = TextLoader(file, encoding="utf8").load()
                docs1 = UnstructuredEmailLoader(file, content_source="text/plain").load()
                add_meta(docs1, file)
                doc1 = chunk_sources(docs1, chunk_size=chunk_size)
            else:
                raise
    # elif file.endswith('.gcsdir'):
    #     doc1 = GCSDirectoryLoader(project_name, bucket, prefix).load()
    # elif file.endswith('.gcsfile'):
    #     doc1 = GCSFileLoader(project_name, bucket, blob).load()
    elif file.endswith('.rst'):
        with open(file, "r") as f:
            doc1 = Document(page_content=f.read(), metadata={"source": file})
        add_meta(doc1, file)
    elif file.endswith('.pdf'):
        # Some PDFs return nothing or junk from PDFMinerLoader
        # e.g. Beyond fine-tuning_ Classifying high resolution mammograms using function-preserving transformations _ Elsevier Enhanced Reader.pdf
        doc1 = PyPDFLoader(file).load_and_split()
        add_meta(doc1, file)
    elif file.endswith('.csv'):
        doc1 = CSVLoader(file).load()
        add_meta(doc1, file)
    elif file.endswith('.py'):
        doc1 = PythonLoader(file).load()
        add_meta(doc1, file)
    elif file.endswith('.toml'):
        doc1 = TomlLoader(file).load()
        add_meta(doc1, file)
    elif file.endswith('.urls'):
        with open(file, "r") as f:
            # strip newlines and blank lines so the loader gets clean URLs
            urls = [x.strip() for x in f.readlines() if x.strip()]
        docs1 = UnstructuredURLLoader(urls=urls).load()
        add_meta(docs1, file)
        doc1 = chunk_sources(docs1, chunk_size=chunk_size)
    elif file.endswith('.zip'):
        with zipfile.ZipFile(file, 'r') as zip_ref:
            # don't put into temporary path, since want to keep references to docs inside zip
            # so just extract in place
            zip_ref.extractall(base_path)
            # recurse
            doc1 = path_to_docs(base_path, verbose=verbose, fail_any_exception=fail_any_exception)
    else:
        raise RuntimeError("No file handler for %s" % os.path.basename(file))

    # allow doc1 to be list or not.  If not list, did not chunk yet, so chunk now
    if not isinstance(doc1, list):
        if chunk:
            docs = chunk_sources([doc1], chunk_size=chunk_size)
        else:
            docs = [doc1]
    else:
        docs = doc1

    assert isinstance(docs, list)
    return docs


def path_to_doc1(file, verbose=False, fail_any_exception=False, return_file=True, chunk=True, chunk_size=512,
                 is_url=False, is_txt=False,
                 enable_captions=True,
                 captions_model=None,
                 enable_ocr=False, caption_loader=None):
    if verbose:
        if is_url:
            print("Ingesting URL: %s" % file, flush=True)
        elif is_txt:
            print("Ingesting Text: %s" % file, flush=True)
        else:
            print("Ingesting file: %s" % file, flush=True)
    res = None
    try:
        # don't pass base_path=path, would infinitely recurse
        res = file_to_doc(file, base_path=None, verbose=verbose, fail_any_exception=fail_any_exception,
                          chunk=chunk, chunk_size=chunk_size,
                          is_url=is_url, is_txt=is_txt,
                          enable_captions=enable_captions,
                          captions_model=captions_model,
                          enable_ocr=enable_ocr,
                          caption_loader=caption_loader)
    except BaseException as e:
        print("Failed to ingest %s due to %s" % (file, traceback.format_exc()))
        if fail_any_exception:
            raise
        else:
            exception_doc = Document(
                page_content='',
                metadata={"source": file, "exception": str(e), "traceback": traceback.format_exc()})
            res = [exception_doc]
    if return_file:
        base_tmp = "temp_path_to_doc1"
        if not os.path.isdir(base_tmp):
            os.makedirs(base_tmp, exist_ok=True)
        filename = os.path.join(base_tmp, str(uuid.uuid4()) + ".tmp.pickle")
        with open(filename, 'wb') as f:
            pickle.dump(res, f)
        return filename
    return res


def path_to_docs(path_or_paths, verbose=False, fail_any_exception=False, n_jobs=-1,
                 chunk=True, chunk_size=512,
                 url=None, text=None,
                 enable_captions=True,
                 captions_model=None,
                 caption_loader=None,
                 enable_ocr=False,
                 ):
    globs_image_types = []
    globs_non_image_types = []
    if path_or_paths is None:
        return []
    elif url:
        globs_non_image_types = [url]
    elif text:
        globs_non_image_types = [text]
    elif isinstance(path_or_paths, str):
        # single path, only consume allowed files
        path = path_or_paths
        # Below globs should match patterns in file_to_doc()
        [globs_image_types.extend(glob.glob(os.path.join(path, "./**/*.%s" % ftype), recursive=True))
         for ftype in image_types]
        [globs_non_image_types.extend(glob.glob(os.path.join(path, "./**/*.%s" % ftype), recursive=True))
         for ftype in non_image_types]
    else:
        # list/tuple of files (consume what can, and exception those that selected but cannot consume so user knows)
        assert isinstance(path_or_paths, (list, tuple)), "Wrong type for path_or_paths: %s" % type(path_or_paths)
        # reform out of allowed types
        globs_image_types.extend(flatten_list([[x for x in path_or_paths if x.endswith(y)] for y in image_types]))
        # could do below:
        # globs_non_image_types = flatten_list([[x for x in path_or_paths if x.endswith(y)] for y in non_image_types])
        # But instead, allow fail so can collect unsupported too
        set_globs_image_types = set(globs_image_types)
        globs_non_image_types.extend([x for x in path_or_paths if x not in set_globs_image_types])

    # could use generator, but messes up metadata handling in recursive case
    if (caption_loader and not isinstance(caption_loader, (bool, str)) and
            caption_loader.device != 'cpu') or \
            get_device() == 'cuda':
        # to avoid deadlocks, presume was preloaded and so can't fork due to cuda context
        n_jobs_image = 1
    else:
        n_jobs_image = n_jobs

    return_file = True  # local choice
    is_url = url is not None
    is_txt = text is not None
    kwargs = dict(verbose=verbose, fail_any_exception=fail_any_exception,
                  return_file=return_file,
                  chunk=chunk, chunk_size=chunk_size,
                  is_url=is_url,
                  is_txt=is_txt,
                  enable_captions=enable_captions,
                  captions_model=captions_model,
                  caption_loader=caption_loader,
                  enable_ocr=enable_ocr,
                  )

    if n_jobs != 1 and len(globs_non_image_types) > 1:
        # avoid nesting, e.g. upload 1 zip and then inside many files
        # harder to handle if upload many zips with many files, inner parallel one will be disabled by joblib
        documents = Parallel(n_jobs=n_jobs, verbose=10 if verbose else 0, backend='multiprocessing')(
            delayed(path_to_doc1)(file, **kwargs) for file in globs_non_image_types
        )
    else:
        documents = [path_to_doc1(file, **kwargs) for file in globs_non_image_types]

    # do images separately since can't fork after cuda in parent, so can't be parallel
    if n_jobs_image != 1 and len(globs_image_types) > 1:
        # avoid nesting, e.g. upload 1 zip and then inside many files
        # harder to handle if upload many zips with many files, inner parallel one will be disabled by joblib
        image_documents = Parallel(n_jobs=n_jobs_image, verbose=10 if verbose else 0, backend='multiprocessing')(
            delayed(path_to_doc1)(file, **kwargs) for file in globs_image_types
        )
    else:
        image_documents = [path_to_doc1(file, **kwargs) for file in globs_image_types]

    # add image docs in
    documents += image_documents

    if return_file:
        # then documents really are files
        files = documents.copy()
        documents = []
        for fil in files:
            with open(fil, 'rb') as f:
                documents.extend(pickle.load(f))
            # remove temp pickle
            os.remove(fil)
    else:
        documents = reduce(concat, documents)
    return documents
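

# Hedged usage sketch for path_to_docs(): ingest a local folder of supported files, or a single URL,
# into a list of chunked Documents.  './user_path' and the wiki URL are illustrative assumptions,
# so the example is left commented out.
#
#   docs = path_to_docs('./user_path', chunk=True, chunk_size=512, n_jobs=2)
#   url = 'https://en.wikipedia.org/wiki/Unix'
#   url_docs = path_to_docs(url, url=url)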


def prep_langchain(persist_directory, load_db_if_exists, db_type, use_openai_embedding, langchain_mode, user_path,
                   hf_embedding_model, n_jobs=-1, kwargs_make_db={}):
    """
    do prep first time, involving downloads
    # FIXME: Add github caching then add here
    :return:
    """
    assert langchain_mode not in ['MyData'], "Should not prep scratch data"

    if os.path.isdir(persist_directory):
        print("Prep: persist_directory=%s exists, using" % persist_directory, flush=True)
        db = get_existing_db(persist_directory, load_db_if_exists, db_type, use_openai_embedding, langchain_mode,
                             hf_embedding_model)
    else:
        print("Prep: persist_directory=%s does not exist, regenerating" % persist_directory, flush=True)
        db = None
        if langchain_mode in ['All', 'DriverlessAI docs']:
            # FIXME: Could also just use dai_docs.pickle directly and upload that
            get_dai_docs(from_hf=True)

        if langchain_mode in ['All', 'wiki']:
            get_wiki_sources(first_para=kwargs_make_db['first_para'], text_limit=kwargs_make_db['text_limit'])

        langchain_kwargs = kwargs_make_db.copy()
        langchain_kwargs.update(locals())
        db = make_db(**langchain_kwargs)

    return db


def get_existing_db(persist_directory, load_db_if_exists, db_type, use_openai_embedding, langchain_mode,
                    hf_embedding_model):
    if load_db_if_exists and db_type == 'chroma' and os.path.isdir(persist_directory) and os.path.isdir(
            os.path.join(persist_directory, 'index')):
        print("DO Loading db: %s" % langchain_mode, flush=True)
        embedding = get_embedding(use_openai_embedding, hf_embedding_model=hf_embedding_model)
        db = Chroma(persist_directory=persist_directory, embedding_function=embedding,
                    collection_name=langchain_mode.replace(' ', '_'))
        print("DONE Loading db: %s" % langchain_mode, flush=True)
        return db
    return None


def make_db(**langchain_kwargs):
    func_names = list(inspect.signature(_make_db).parameters)
    missing_kwargs = [x for x in func_names if x not in langchain_kwargs]
    defaults_db = {k: v.default for k, v in dict(inspect.signature(run_qa_db).parameters).items()}
    for k in missing_kwargs:
        if k in defaults_db:
            langchain_kwargs[k] = defaults_db[k]
    # final check for missing
    missing_kwargs = [x for x in func_names if x not in langchain_kwargs]
    assert not missing_kwargs, "Missing kwargs: %s" % missing_kwargs
    # only keep actual used
    langchain_kwargs = {k: v for k, v in langchain_kwargs.items() if k in func_names}
    return _make_db(**langchain_kwargs)


def _make_db(use_openai_embedding=False,
             hf_embedding_model="sentence-transformers/all-MiniLM-L6-v2",
             first_para=False, text_limit=None, chunk=False, chunk_size=1024,
             langchain_mode=None,
             user_path=None,
             db_type='faiss',
             load_db_if_exists=False,
             db=None,
             n_jobs=-1):
    persist_directory = 'db_dir_%s' % langchain_mode  # single place, no special names for each case
    if not db and load_db_if_exists and db_type == 'chroma' and os.path.isdir(persist_directory) and os.path.isdir(
            os.path.join(persist_directory, 'index')):
        assert langchain_mode not in ['MyData'], "Should not load MyData db this way"
        print("Loading db", flush=True)
        embedding = get_embedding(use_openai_embedding, hf_embedding_model=hf_embedding_model)
        db = Chroma(persist_directory=persist_directory, embedding_function=embedding,
                    collection_name=langchain_mode.replace(' ', '_'))
    elif not db:
        assert langchain_mode not in ['MyData'], "Should not make MyData db this way"
        sources = []
        print("Generating sources", flush=True)
        if langchain_mode in ['wiki_full', 'All', "'All'"]:
            from read_wiki_full import get_all_documents
            small_test = None
            print("Generating new wiki", flush=True)
            sources1 = get_all_documents(small_test=small_test, n_jobs=os.cpu_count() // 2)
            print("Got new wiki", flush=True)
            if chunk:
                sources1 = chunk_sources(sources1, chunk_size=chunk_size)
                print("Chunked new wiki", flush=True)
            sources.extend(sources1)
        if langchain_mode in ['wiki', 'All', "'All'"]:
            sources1 = get_wiki_sources(first_para=first_para, text_limit=text_limit)
            if chunk:
                sources1 = chunk_sources(sources1, chunk_size=chunk_size)
            sources.extend(sources1)
        if langchain_mode in ['github h2oGPT', 'All', "'All'"]:
            # sources = get_github_docs("dagster-io", "dagster")
            sources1 = get_github_docs("h2oai", "h2ogpt")
            # FIXME: always chunk for now
            sources1 = chunk_sources(sources1, chunk_size=chunk_size)
            sources.extend(sources1)
        if langchain_mode in ['DriverlessAI docs', 'All', "'All'"]:
            sources1 = get_dai_docs(from_hf=True)
            if chunk and False:  # FIXME: DAI docs are already chunked well, should only chunk more if over limit
                sources1 = chunk_sources(sources1, chunk_size=chunk_size)
            sources.extend(sources1)
        if langchain_mode in ['All', 'UserData']:
            if user_path:
                # chunk internally for speed over multiple docs
                sources1 = path_to_docs(user_path, n_jobs=n_jobs, chunk=chunk, chunk_size=chunk_size)
                sources.extend(sources1)
            else:
                print("Chose UserData but user_path is empty/None", flush=True)
        if False and langchain_mode in ['urls', 'All', "'All'"]:
            # from langchain.document_loaders import UnstructuredURLLoader
            # loader = UnstructuredURLLoader(urls=urls)
            urls = ["https://www.birdsongsf.com/who-we-are/"]
            from langchain.document_loaders import PlaywrightURLLoader
            loader = PlaywrightURLLoader(urls=urls, remove_selectors=["header", "footer"])
            sources1 = loader.load()
            sources.extend(sources1)
        if not sources:
            print("langchain_mode %s has no sources, not making db" % langchain_mode, flush=True)
            return None
        print("Generating db", flush=True)
        db = get_db(sources, use_openai_embedding=use_openai_embedding, db_type=db_type,
                    persist_directory=persist_directory, langchain_mode=langchain_mode,
                    hf_embedding_model=hf_embedding_model)
        print("Generated db", flush=True)
    return db


source_prefix = "Sources [Score | Link]:"
source_postfix = "End Sources<p>"


def run_qa_db(**kwargs):
    func_names = list(inspect.signature(_run_qa_db).parameters)
    # hard-coded defaults
    kwargs['answer_with_sources'] = True
    kwargs['sanitize_bot_response'] = True
    kwargs['show_rank'] = False
    missing_kwargs = [x for x in func_names if x not in kwargs]
    assert not missing_kwargs, "Missing kwargs: %s" % missing_kwargs
    # only keep actual used
    kwargs = {k: v for k, v in kwargs.items() if k in func_names}
    return _run_qa_db(**kwargs)


def _run_qa_db(query=None,
               use_openai_model=False, use_openai_embedding=False,
               first_para=False, text_limit=None, k=4, chunk=False, chunk_size=1024,
               user_path=None,
               db_type='faiss',
               model_name=None, model=None, tokenizer=None,
               hf_embedding_model="sentence-transformers/all-MiniLM-L6-v2",
               stream_output=False,
               prompter=None,
               prompt_type=None,
               answer_with_sources=True,
               cut_distance=1.1,
               sanitize_bot_response=True,
               show_rank=False,
               load_db_if_exists=False,
               db=None,
               max_new_tokens=256,
               temperature=0.1,
               repetition_penalty=1.0,
               top_k=40,
               top_p=0.7,
               langchain_mode=None,
               n_jobs=-1):
    """
    :param query: question to ask of the documents
    :param use_openai_model: use OpenAI LLM instead of local model
    :param use_openai_embedding: use OpenAI embeddings instead of HF embeddings
    :param first_para: only use first paragraph of wiki sources
    :param text_limit: character limit per wiki source
    :param k: number of document chunks to pass to the LLM
    :param chunk: whether to chunk sources
    :param chunk_size: chunk size in characters
    :param user_path: user path to glob recursively from
    :param db_type: 'faiss' for in-memory db or 'chroma' for persistent db
    :param model_name: model name, used to switch behaviors
    :param model: pre-initialized model, else will make new one
    :param tokenizer: pre-initialized tokenizer, else will make new one.  Required not None if model is not None
    :param answer_with_sources: whether to append ranked source links to the answer
    :return: generator of (partial) answer strings
    """
    # FIXME: For All just go over all dbs instead of a separate db for All
    db = make_db(**locals())

    prompt_type = prompter.prompt_type if prompter is not None else prompt_type

    llm, model_name, streamer, prompt_type_out = get_llm(use_openai_model=use_openai_model, model_name=model_name,
                                                         model=model, tokenizer=tokenizer,
                                                         stream_output=stream_output,
                                                         max_new_tokens=max_new_tokens,
                                                         temperature=temperature,
                                                         repetition_penalty=repetition_penalty,
                                                         top_k=top_k,
                                                         top_p=top_p,
                                                         prompt_type=prompt_type,
                                                         )

    if model_name in ['llama', 'gptj']:
        # FIXME: for now, streams to stdout/stderr currently
        stream_output = False

    if not use_openai_model and prompt_type not in ['plain'] or model_name in ['llama', 'gptj']:
        # instruct-like, rather than few-shot prompt_type='plain' as default
        # but then sources confuse the model with how inserted among rest of text, so avoid
        prefix = ""
        if langchain_mode in ['Disabled', 'ChatLLM', 'LLM']:
            use_context = False
            template = """%s{context}{question}""" % prefix
        else:
            use_context = True
            template = """%s
==
{context}
==
{question}""" % prefix
        prompt = PromptTemplate(
            # input_variables=["summaries", "question"],
            input_variables=["context", "question"],
            template=template,
        )
        chain = load_qa_chain(llm, prompt=prompt)
    else:
        chain = load_qa_with_sources_chain(llm)
        use_context = True

    if query is None:
        query = "What are the main differences between Linux and Windows?"

    # https://github.com/hwchase17/langchain/issues/1946
    # FIXME: Seems to be no way to get size of chroma db to limit k, to avoid
    #  "Chroma collection MyData contains fewer than 4 elements" type logger error
    k_db = 1000 if db_type == 'chroma' else k  # k=100 works ok too

    if db and use_context:
        docs_with_score = db.similarity_search_with_score(query, k=k_db)[:k]
        # cut off so no high distance docs/sources considered
        docs = [x[0] for x in docs_with_score if x[1] < cut_distance]
        scores = [x[1] for x in docs_with_score if x[1] < cut_distance]
        if len(scores) > 0:
            print("Distance: min: %s max: %s mean: %s median: %s" %
                  (np.min(scores), np.max(scores), np.mean(scores), np.median(scores)), flush=True)
    else:
        docs = []
        scores = []

    if not docs and use_context:
        return None

    common_words_file = "data/NGSL_1.2_stats.csv.zip"
    if os.path.isfile(common_words_file):
        df = pd.read_csv(common_words_file)
        import string
        reduced_query = query.translate(str.maketrans(string.punctuation, ' ' * len(string.punctuation))).strip()
        reduced_query_words = reduced_query.split(' ')
        set_common = set(df['Lemma'].values.tolist())
        # fraction of query words that are common English words
        num_common = sum(1 for x in reduced_query_words if x.lower() in set_common)
        frac_common = num_common / len(reduced_query_words) if reduced_query_words else 0.0
        # FIXME: report to user bad query that uses too many common words
        print("frac_common: %s" % frac_common, flush=True)

    if langchain_mode in ['Disabled', 'ChatLLM', 'LLM']:
        chain_kwargs = dict(input_documents=[], question=query)
    else:
        chain_kwargs = dict(input_documents=docs, question=query)

    if stream_output:
        answer = None
        assert streamer is not None
        target = wrapped_partial(chain, chain_kwargs)
        import queue
        bucket = queue.Queue()
        thread = EThread(target=target, streamer=streamer, bucket=bucket)
        thread.start()
        outputs = ""
        prompt = None  # FIXME
        try:
            for new_text in streamer:
                # print("new_text: %s" % new_text, flush=True)
                if bucket.qsize() > 0 or thread.exc:
                    thread.join()
                outputs += new_text
                if prompter:  # and False:  # FIXME: pipeline can already use prompter
                    output1 = prompter.get_response(outputs, prompt=prompt,
                                                    sanitize_bot_response=sanitize_bot_response)
                    yield output1
                else:
                    yield outputs
        except BaseException:
            # if any exception, raise that exception if was from thread, first
            if thread.exc:
                raise thread.exc
            raise
        finally:
            # in case no exception and didn't join with thread yet, then join
            if not thread.exc:
                answer = thread.join()
        # in case raise StopIteration or broke queue loop in streamer, but still have exception
        if thread.exc:
            raise thread.exc
        # FIXME: answer is not string outputs from streamer.  How to get actual final output?
        # answer = outputs
    else:
        answer = chain(chain_kwargs)

    if not use_context:
        ret = answer['output_text']
        yield ret
    elif answer is not None:
        print("query: %s" % query, flush=True)
        print("answer: %s" % answer['output_text'], flush=True)
        # link
        answer_sources = [(max(0.0, 1.5 - score) / 1.5, get_url(doc)) for score, doc in
                          zip(scores, answer['input_documents'])]
        answer_sources_dict = defaultdict(list)
        [answer_sources_dict[url].append(score) for score, url in answer_sources]
        answers_dict = {}
        for url, scores_url in answer_sources_dict.items():
            answers_dict[url] = np.max(scores_url)
        answer_sources = [(score, url) for url, score in answers_dict.items()]
        answer_sources.sort(key=lambda x: x[0], reverse=True)
        if show_rank:
            # answer_sources = ['%d | %s' % (1 + rank, url) for rank, (score, url) in enumerate(answer_sources)]
            # sorted_sources_urls = "Sources [Rank | Link]:<br>" + "<br>".join(answer_sources)
            answer_sources = ['%s' % url for rank, (score, url) in enumerate(answer_sources)]
            sorted_sources_urls = "Ranked Sources:<br>" + "<br>".join(answer_sources)
        else:
            answer_sources = ['<li>%.2g | %s</li>' % (score, url) for score, url in answer_sources]
            sorted_sources_urls = f"{source_prefix}<p><ul>" + "<p>".join(answer_sources)
            sorted_sources_urls += f"</ul></p>{source_postfix}"

        if not answer['output_text'].endswith('\n'):
            answer['output_text'] += '\n'

        if answer_with_sources:
            ret = answer['output_text'] + '\n' + sorted_sources_urls
        else:
            ret = answer['output_text']
        yield ret
    return


def chunk_sources(sources, chunk_size=1024):
    source_chunks = []
    # Below for known separator
    # splitter = CharacterTextSplitter(separator=" ", chunk_size=chunk_size, chunk_overlap=0)
    splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=0)
    for source in sources:
        # print(source.metadata['source'], flush=True)
        for chunky in splitter.split_text(source.page_content):
            source_chunks.append(Document(page_content=chunky, metadata=source.metadata))
    return source_chunks


def get_db_from_hf(dest=".", db_dir='db_dir_DriverlessAI_docs.zip'):
    from huggingface_hub import hf_hub_download
    # True for case when locally already logged in with correct token, so don't have to set key
    token = os.getenv('HUGGINGFACE_API_TOKEN', True)
    path_to_zip_file = hf_hub_download('h2oai/db_dirs', db_dir, token=token, repo_type='dataset')
    import zipfile
    with zipfile.ZipFile(path_to_zip_file, 'r') as zip_ref:
        zip_ref.extractall(dest)
    return path_to_zip_file


# Note dir has space in some cases, while zip does not
some_db_zips = [['db_dir_DriverlessAI_docs.zip', 'db_dir_DriverlessAI docs', 'CC-BY-NC license'],
                ['db_dir_UserData.zip', 'db_dir_UserData', 'CC-BY license for ArXiv'],
                ['db_dir_github_h2oGPT.zip', 'db_dir_github h2oGPT', 'ApacheV2 license'],
                ['db_dir_wiki.zip', 'db_dir_wiki', 'CC-BY-SA Wikipedia license'],
                # ['db_dir_wiki_full.zip', 'db_dir_wiki_full', '23GB, 05/04/2023 CC-BY-SA Wiki license'],
                ]

all_db_zips = some_db_zips + \
              [['db_dir_wiki_full.zip', 'db_dir_wiki_full', '23GB, 05/04/2023 CC-BY-SA Wiki license'],
               ]


def get_some_dbs_from_hf(dest='.', db_zips=None):
    if db_zips is None:
        db_zips = some_db_zips
    for db_dir, dir_expected, license1 in db_zips:
        path_to_zip_file = get_db_from_hf(dest=dest, db_dir=db_dir)
        assert os.path.isfile(path_to_zip_file), "Missing zip in %s" % path_to_zip_file
        if dir_expected:
            assert os.path.isdir(os.path.join(dest, dir_expected)), "Missing path for %s" % dir_expected
            assert os.path.isdir(os.path.join(dest, dir_expected, 'index')), "Missing index in %s" % dir_expected


if __name__ == '__main__':
    pass
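
# Hedged end-to-end sketch, kept commented out so importing or running this module stays
# side-effect free.  It would build a persistent Chroma db from an assumed './user_path' folder and
# stream an answer; the folder name, question, and langchain_mode value are illustrative
# assumptions, and running it would download the default embedding and LLM weights.
#
#   sources = path_to_docs('./user_path', chunk=True, chunk_size=512)
#   db = get_db(sources, db_type='chroma', persist_directory='db_dir_UserData',
#               langchain_mode='UserData')
#   for partial_answer in _run_qa_db(query="Summarize these documents.", db=db, db_type='chroma',
#                                    langchain_mode='UserData', chunk=True, chunk_size=512):
#       print(partial_answer)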