import nltk
import torch
from spacy.cli import download

# Download the small English spaCy pipeline and the NLTK stopword list.
download("en_core_web_sm")
nltk.download("stopwords")
from nltk.corpus import stopwords
# NLTK's English stopwords, extended with additional high-frequency conversational words.
en_stopwords = set(
    list(stopwords.words("english"))
    + [
        "summary", "synopsis", "overview", "list", "good", "will", "why", "talk", "long", "above", "looks", "face",
        "men", "years", "can", "both", "have", "keep", "yeah", "said", "bring", "done", "was", "when",
        "ask", "now", "very", "kind", "they", "told", "tell", "ever", "kill", "hold", "that", "below",
        "bit", "knew", "haven", "few", "place", "could", "says", "huh", "job", "also", "ain", "may",
        "heart", "boy", "with", "over", "son", "else", "found", "see", "any", "phone", "hasn", "saw",
        "these", "maybe", "into", "thing", "mom", "god", "old", "aren", "mustn", "out", "about", "guy",
        "each", "most", "like", "then", "wasn", "being", "all", "door", "look", "run", "sorry", "again",
        "won", "man", "gone", "them", "ago", "doesn", "gonna", "girl", "feel", "work", "much", "hope",
        "never", "woman", "went", "lot", "what", "start", "only", "play", "too", "dad", "going", "yours",
        "wrong", "fine", "made", "one", "want", "isn", "our", "true", "room", "wanna", "are", "idea",
        "sure", "find", "same", "doing", "off", "put", "turn", "come", "house", "think", "meet", "hers",
        "gotta", "nor", "away", "leave", "car", "used", "happy", "the", "care", "seen", "she", "not",
        "were", "ours", "their", "first", "world", "lost", "make", "big", "left", "miss", "shan", "did",
        "thank", "ready", "those", "give", "next", "came", "who", "mind", "does", "right", "her", "let",
        "didn", "open", "has", "show", "wife", "yet", "got", "know", "whole", "some", "such", "alone",
        "baby", "him", "nice", "bad", "move", "new", "dead", "three", "weren", "whom", "well", "get",
        "which", "end", "you", "than", "while", "last", "once", "sir", "from", "need", "wait", "days",
        "how", "don", "heard", "own", "hear", "where", "hey", "okay", "just", "until", "your", "there",
        "this", "more", "been", "his", "under", "mean", "might", "here", "its", "but", "stay", "yes",
        "guess", "even", "guys", "hard", "hadn", "live", "stop", "took", "still", "other", "since", "every",
        "needn", "way", "name", "two", "back", "and", "hello", "head", "use", "must", "for", "life",
        "die", "day", "down", "wants", "after", "say", "try", "had", "night",
    ]
)
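
# Illustrative sketch (the sentence below is made up): the combined set can be used
# to strip uninformative words from a prompt before keyword search, e.g.
#   [w for w in "tell me a good summary of the plot".split() if w not in en_stopwords]
# which should leave roughly ["plot"].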

import multiprocessing
import os

# Secrets are read from the Space's environment variables.
HF_TOKEN = os.getenv("HUGGING_FACE_HUB_TOKEN")
PASSWORD = os.getenv("PASSWORD")

import tqdm
import whoosh.index as whoosh_index
from whoosh.analysis import StemmingAnalyzer
from whoosh.fields import ID, TEXT, Schema
from whoosh.index import create_in

def get_content_ext(content, bm25_field):
    # Hook for extracting the searchable text from a record; currently a pass-through.
    return content


def yield_line_by_line(file):
    # Stream a text file one line at a time so large corpora need not fit in memory.
    with open(file) as input_file:
        for line in input_file:
            yield line

def recreate_bm25_idx(
    content_data_store,
    bm25_field="search",
    idx_dir=".",
    auto_create_bm25_idx=False,
    idxs=None,
    use_tqdm=True,
):
    if type(content_data_store) is str:
        content_data_store = yield_line_by_line(content_data_store)
    schema = Schema(id=ID(stored=True), content=TEXT(analyzer=StemmingAnalyzer()))
    # TODO determine how to clear out the whoosh index besides rm -rf _M* MAIN*
    os.makedirs(f"{idx_dir}/bm25_{bm25_field}", exist_ok=True)
    need_reindex = auto_create_bm25_idx or not os.path.exists(
        f"{idx_dir}/bm25_{bm25_field}/_MAIN_1.toc"
    )  # CHECK IF THIS IS RIGHT
    if not need_reindex:
        whoosh_ix = whoosh_index.open_dir(f"{idx_dir}/bm25_{bm25_field}")
    else:
        whoosh_ix = create_in(f"{idx_dir}/bm25_{bm25_field}", schema)
    writer = whoosh_ix.writer(
        multisegment=True, limitmb=1024, procs=multiprocessing.cpu_count()
    )
    # writer = self.whoosh_ix.writer(multisegment=True, procs=multiprocessing.cpu_count())
    if hasattr(content_data_store, "tell"):
        pos = content_data_store.tell()
        content_data_store.seek(0, 0)
    if idxs is not None:
        idx_text_pairs = [(idx, content_data_store[idx]) for idx in idxs]
        if use_tqdm:
            data_iterator = tqdm.tqdm(idx_text_pairs)
        else:
            data_iterator = idx_text_pairs
    else:
        if use_tqdm:
            data_iterator = tqdm.tqdm(enumerate(content_data_store))
        else:
            data_iterator = enumerate(content_data_store)
    # TODO:
    # self.indexer.reset_bm25_idx(0)
    # data_iterator = self.indexer.process_bm25_field(content_data_store, **kwargs)
    for idx, content in data_iterator:
        content = get_content_ext(content, bm25_field)
        if not content:
            continue
        writer.add_document(id=str(idx), content=content)
    writer.commit()
    return whoosh_ix  # return the index object itself, not the whoosh.index module
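
# Usage sketch (the corpus path and query below are hypothetical): build the BM25
# index over a newline-delimited text file once, then query it with whoosh's
# standard QueryParser. Kept as a comment so nothing is indexed at startup.
#
#   ix = recreate_bm25_idx("corpus.txt", idx_dir="/tmp", auto_create_bm25_idx=True)
#   from whoosh.qparser import QueryParser
#   with ix.searcher() as searcher:
#       hits = searcher.search(QueryParser("content", ix.schema).parse("castor beans"), limit=5)
#       print([hit["id"] for hit in hits])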

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

safety_tokenizer = tokenizer = AutoTokenizer.from_pretrained(
    "salexashenko/T5-Base-ROT-epoch-2-train-loss-1.3495-val-loss-1.4164",
    use_auth_token=HF_TOKEN,
)
safety_model = model = (
    AutoModelForSeq2SeqLM.from_pretrained(
        "salexashenko/T5-Base-ROT-epoch-2-train-loss-1.3495-val-loss-1.4164",
        use_auth_token=HF_TOKEN,
    )
    .half()
    .cuda()
    .eval()
)
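
# The "ROT" model above is used later in generate_with_safety to turn a prompt such
# as "steal a car which is wrong" (the example prompt is illustrative) into a short
# rule-of-thumb judgement that is prepended to the chatbot's answer when the prompt
# looks unsafe.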

from transformers import AutoModelForCausalLM, AutoTokenizer

blackcat_tokenizer = AutoTokenizer.from_pretrained(
    "theblackcat102/galactica-1.3b-conversation-finetuned"
)
blackcat_model = (
    AutoModelForCausalLM.from_pretrained(
        "theblackcat102/galactica-1.3b-conversation-finetuned"
    )
    .half()
    .cuda()
    .eval()
)
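
# This conversation-finetuned Galactica checkpoint is the main chat model; the
# prompts built in generate_with_safety below use its "<question> ... <answer>"
# format and, for math-style questions, Galactica's "<work>" scratchpad tokens.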

t5_tokenizer = AutoTokenizer.from_pretrained("t5-small")
t5_model = (
    AutoModelForSeq2SeqLM.from_pretrained("t5-small", torch_dtype=torch.half)
    .half()
    .eval()
    .cuda()
)

from transformers import (
    AutoModel,
    AutoModelForCausalLM,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    OPTForCausalLM,
    T5EncoderModel,
    T5PreTrainedModel,
    T5Tokenizer,
)

def run_model(input_string, model, tokenizer, device="cuda", **generator_args):
    # Tokenize, move to the target device, and pass any extra generation arguments
    # straight through to model.generate.
    with torch.no_grad():
        inputs = tokenizer(input_string, padding=True, return_tensors="pt")
        inputs = inputs.to(device)
        inputs["no_repeat_ngram_size"] = 4
        for key, val in generator_args.items():
            inputs[key] = val
        res = model.generate(**inputs)
        # Light cleanup of repeated punctuation in the decoded text.
        return [
            ret.replace("..", ".")
            .replace(".-", ".")
            .replace("..", ".")
            .replace("--", "-")
            .replace("--", "-")
            for ret in tokenizer.batch_decode(res, skip_special_tokens=True)
        ]
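
# Usage sketch (the passage is illustrative; "summarize:" is the standard t5-small
# task prefix): the same helper drives both the ROT safety model and the t5-small
# summarizer applied to search snippets below.
#   run_model("summarize: Castor beans contain ricin, a potent toxin ...",
#             t5_model, t5_tokenizer, max_length=64)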

def run_python_and_return(s):
    # Execute generated Python in an isolated namespace; the snippet is expected to
    # store its result in __ret. Any failure simply yields an empty string.
    try:
        ret = {"__ret": None}
        exec(s, ret)
        return ret["__ret"]
    except Exception:
        return ""

from collections import Counter

import spacy
import wikipedia
from duckduckgo_search import ddg

nlp = spacy.load("en_core_web_sm")

def duck_duck_and_wikipedia_search(query, num_terms=4, max_docs=10):
    ret = []
    # using duckduckgo search
    data = ddg(
        query,
        region="us-en",
        safesearch="moderate",
    )
    data2 = [
        (a["title"] + ". " + a["body"]).replace("?", ".").strip("?!.") for a in data
    ]
    ret.append(data2)
    # Pull the most frequent named entities out of the snippets and use them as
    # follow-up Wikipedia queries.
    doc = nlp(" ".join(data2))
    query0 = [
        a[0].strip("!.,;")
        for a in Counter(
            [e.text for e in doc.ents if e.label_ != "CARDINAL"]
        ).most_common(num_terms)
    ]
    print(query0)
    for query2 in query0:
        search = wikipedia.search(query2)
        for s in search[: max(1, int(max_docs / num_terms))]:
            try:
                page = wikipedia.WikipediaPage(s)
            except Exception:
                continue
            # Split the article into sections on top-level headings.
            x = ["=" + x1 if "==" in x1 else x1 for x1 in page.content.split("\n=")]
            ret.append(x)
            if len(ret) > max_docs:
                return ret
    return ret
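
# Usage sketch (the query is illustrative): returns the DuckDuckGo snippets first,
# followed by Wikipedia sections for the most frequent named entities found in them.
# Kept as a comment so no network calls happen at startup.
#   docs = duck_duck_and_wikipedia_search("castor bean poisoning", num_terms=2, max_docs=4)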

def generate_with_safety(
    para,
    model,
    tokenizer,
    do_safety=True,
    do_execute_work=False,
    backtrack_on_mismatched_work_answers=False,
    return_answer_only=True,
    do_search=False,
    max_length=512,
    do_self_contrastive=True,
    contrative_guidance_embedding=None,
    max_return_sequences=4,
    ret=None,
    do_sample=True,
    do_beam=False,
    device="cuda",
    target_lang=None,
):
    global safety_model, safety_tokenizer, t5_model, t5_tokenizer
    if backtrack_on_mismatched_work_answers:
        do_execute_work = True  # TODO the backtracking inference
    background = ""
    para = para.strip()
    if do_search:
        data = ddg(
            para,
            region="us-en",
            safesearch="moderate",
        )
        data2 = [a["body"].replace("?", ".").strip("?!., ") for a in data]
        # There is a Google paper that says using the summary of the search results is better. Need to look for that paper.
        # Also need a simple ngram filter to get rid of bad summaries and use the actual search results as a backup.
        # TODO: store the reference URL so we can refer back to the URL in generated text. Use ngram overlap (ROUGE score).
        background = ". ".join(
            [
                s.replace("?", ".").lstrip(" ?,!.").rstrip(" ,")
                for s in run_model(data2[:5], t5_model, t5_tokenizer, max_length=512)
            ]
        )
    # TODO: inject background knowledge into the instruction.
    # Example prompt: give me instructions on how to eat castor beans
    background_lower = background.lower()
    is_wrong = is_dangerous = False
    # Replace with a multi-task classifier using the safety pipeline.
    if "immoral" in background_lower or "illegal" in background_lower:
        if (
            "not immoral" not in background_lower
            and "not illegal" not in background_lower
        ):
            is_wrong = True
    if (
        "lethal" in background_lower
        or "dangerous" in background_lower
        or " poison" in background_lower
    ):
        if (
            "not lethal" not in background_lower
            and "not dangerous" not in background_lower
            and "not poison" not in background_lower
        ):
            is_dangerous = True
    # print(is_wrong, is_dangerous)
    safety_prefix = ""
    if do_safety:
        para2 = para.strip(".?:-")
        if is_dangerous:
            para2 += " which is dangerous"
        elif is_wrong:
            para2 += " which is wrong"
        safety_prefix = run_model(para2, safety_model, safety_tokenizer)[0].strip(
            "\"' "
        )
        if "wrong" in safety_prefix or "not right" in safety_prefix:
            safety_prefix = f"As a chatbot, I cannot recommend this. {safety_prefix}"
    if background:
        # Probably can do a rankgen match instead of keyword matching on "who", "what", "where", etc.
        if para.split()[0].lower() not in {
            "who", "what", "when", "where", "how", "why", "does", "do", "can",
            "could", "would", "is", "are", "will", "might", "find", "write", "give",
        } and not para.endswith("?"):
            para = f"Background: {background}. <question> Complete this sentence: {para} <answer> "
        else:
            para = f"Background: {background}. <question> {para} <answer> "
    if safety_prefix:
        if "<answer>" not in para:
            para += "<answer> " + safety_prefix + " "
        else:
            para += safety_prefix + " "
    # Length of the prompt that will be trimmed from decoded output, keeping the
    # safety prefix as part of the returned answer.
    len_para = len(para)
    if "<question>" in para:
        len_para -= len("<question>")
    if "<answer>" in para:
        len_para -= len("<answer>")
    if safety_prefix:
        len_para -= len(safety_prefix + " ")
    if "<answer>" not in para:
        para += "<answer>"
    print(para)
    input_ids = tokenizer.encode(para, return_tensors="pt")
    input_ids = input_ids.to(device)
    if ret is None:
        ret = {}
    with torch.no_grad():
        if do_sample:
            # Here we use top_p / top_k random sampling. It generates more diverse queries, but of lower quality.
            outputs = model.generate(
                input_ids=input_ids,
                max_length=max_length,
                no_repeat_ngram_size=4,
                do_sample=True,
                top_p=0.95,
                penalty_alpha=0.6 if do_self_contrastive else None,
                top_k=10,
                num_return_sequences=max(1, int(max_return_sequences / 2))
                if do_beam
                else max_return_sequences,
            )
            for i in range(
                len(outputs)
            ):  # can use batch_decode, unless we want to do something special here
                query = tokenizer.decode(outputs[i], skip_special_tokens=True)
                if return_answer_only:
                    query = query[len_para:].lstrip(".? \n\t")
                ret[query] = 1
        if do_beam:
            # Here we use beam search. It generates better quality queries, but with less diversity.
            outputs = model.generate(
                input_ids=input_ids,
                max_length=max_length,
                num_beams=max(
                    int(max_return_sequences / 2)
                    if do_sample
                    else max_return_sequences,
                    5,
                ),
                no_repeat_ngram_size=4,
                penalty_alpha=0.6 if do_self_contrastive else None,
                num_return_sequences=max(1, int(max_return_sequences / 2))
                if do_sample
                else max_return_sequences,
                early_stopping=True,
            )
            for i in range(
                len(outputs)
            ):  # can use batch_decode, unless we want to do something special here
                query = tokenizer.decode(outputs[i], skip_special_tokens=True)
                if return_answer_only:
                    query = query[len_para:].lstrip(".? \n\t")
                ret[query] = 1
    # Take care of the <work> tokens - let's execute the code.
    # TODO: do backtracking when code doesn't return the same answer as the answer in the generated text.
    if do_execute_work:  # Galactica specific
        for query in list(ret.keys()):
            if "<work>" in query:
                query2 = ""
                for query_split in query.split("<work>"):
                    if "```" in query_split:
                        # Rewrite the generated file.write(...) call into an assignment
                        # to __ret so run_python_and_return can capture the result.
                        query_split = query_split.replace(
                            """with open("output.txt", "w") as file:\n file.write""",
                            "__ret=",
                        )
                        code = (
                            query_split.split("</work>")[0]
                            .split("```")[1]
                            .split("```")[0]
                        )
                        query_split1, query_split2 = query_split.split(
                            """<<read: "output.txt">>\n\n"""
                        )
                        old_answer2 = old_answer = query_split.split(
                            """<<read: "output.txt">>\n\n"""
                        )[1].split("\n")[0]
                        work_answer = run_python_and_return(code)
                        if work_answer is not None:
                            try:
                                old_answer2 = float(old_answer)
                                work_answer = float(work_answer)
                            except (TypeError, ValueError):
                                pass
                            if old_answer2 != work_answer:
                                query_split2 = query_split2.replace(
                                    old_answer, str(work_answer)
                                )
                                query_split = (
                                    query_split1 + "Computed Answer:" + query_split2
                                )
                    if query2:
                        query2 = query2 + "<work>" + query_split
                    else:
                        query2 = query_split
                if query2 != query:
                    del ret[query]
                    ret[query2] = 1
    return list(ret.keys())
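
# Usage sketch (the prompt is illustrative): run the full pipeline of retrieval,
# summarization, safety prefix, then sampling from the Galactica chat model.
# Kept as a comment so nothing is generated at startup.
#   answers = generate_with_safety(
#       "How do I dispose of old motor oil?",
#       blackcat_model,
#       blackcat_tokenizer,
#       do_safety=True,
#       do_search=True,
#   )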

import gradio as gr


def query_model(do_safety, do_search, text, access_code):
    # Gate the demo behind a simple access code stored in the PASSWORD secret.
    if access_code == PASSWORD:
        return generate_with_safety(
            text,
            blackcat_model,
            blackcat_tokenizer,
            do_safety=do_safety,
            do_search=do_search,
        )
    else:
        raise Exception("Incorrect access code")

demo = gr.Interface(
    query_model,
    [
        gr.Checkbox(label="Safety"),
        gr.Checkbox(label="Search"),
        gr.Textbox(
            label="Prompt",
            lines=5,
            value="Teach me how to take over the world.",
        ),
        gr.Textbox(label="Access Code", lines=1, value=""),
    ],
    ["text", "text", "text", "text"],
)

if __name__ == "__main__":
    demo.launch()