# gclone-4125-x2 / app.py
import sqlite3
import gradio as gr
from hashlib import md5 as hash_algo  # MD5 is used only as a cache key, not for security
from re import match
from io import BytesIO
from pypdf import PdfReader
from llm_rs import AutoModel, SessionConfig, GenerationConfig

repo_name = "rustformers/mpt-7b-ggml"
file_name = "mpt-7b-instruct-q5_1-ggjt.bin"
script_env = 'prod'  # set to 'test' to skip inference and stream dummy output

session_config = SessionConfig(threads=2, batch_size=2)
model = AutoModel.from_pretrained(repo_name, model_file=file_name, session_config=session_config, verbose=True)
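
# Request flow: the uploaded compliance documents become a numbered rule list,
# the uploaded log becomes the query text, and an MD5 digest of the raw file
# bytes keys a sqlite cache ("history.db") so repeated submissions reuse
# earlier answers. Loading the model above pulls the quantized weights from
# the Hugging Face Hub on first run (a multi-gigabyte download).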
def process_stream(rules, log, temperature, top_p, top_k, max_new_tokens, seed):
    con = sqlite3.connect("history.db")
    cur = con.cursor()
    # The queries table holds (hexdigest, response) pairs; create it on first
    # use so a fresh deployment does not fail on the SELECT below.
    cur.execute('CREATE TABLE IF NOT EXISTS queries(hexdigest TEXT PRIMARY KEY, response TEXT)')
    instruction = ''
    hashes = []
    if not isinstance(rules, list):
        rules = [rules]
    for rule in rules:
        data, file_hash = get_file_contents(rule)
        instruction += data + '\n'
        hashes.append(file_hash)
    # Sort the per-file digests so the combined cache key is independent of
    # upload order.
    hashes.sort()
    hashes = hash_algo(''.join(hashes).encode()).hexdigest()
    # Find the highest rule number already present so the queries can be
    # appended as the next numbered item.
    largest = 0
    lines = instruction.split('\r\n')
    if len(lines) == 1:
        lines = instruction.split('\n')
    for line in lines:
        m = match(r'^(\d+)\.', line)
        if m is not None:
            num = int(m.group(1))
            if num > largest:
                largest = num
    instruction += str(largest + 1) + '. '
    query, file_hash = get_file_contents(log)
    hashes = hash_algo((hashes + file_hash).encode()).hexdigest()
    # Collapse stray '\r\r\n' sequences before building the prompt.
    instruction = instruction.replace('\r\r\n', '\n')
    full_req = ("A conversation between a user and an LLM-based AI assistant. "
                "The assistant gives helpful and honest answers.\r\n\r\n"
                "Q: Read the rules stated below and check the queries for any violation. "
                "State the rules which are violated by a query (if any). "
                "Also suggest a possible remediation, if possible. "
                "Do not make any assumptions outside of the rules stated below.\r\n\r\n"
                + instruction + 'The queries are as follows:\r\n' + query + '\r\n \r\nA: ')
    full_req = full_req.replace('\r\n', '\n')
    prompt = f"""Below is an instruction that describes a task. Write a response that appropriately completes the request.
### Instruction:
{full_req}
### Response:
Answer:"""
    response = ""
    row = cur.execute('SELECT response FROM queries WHERE hexdigest = ?', [hashes]).fetchone()
    if row is not None:
        response += "Cached Result:\n" + row[0]
        yield response
    else:
        if script_env != 'test':
            generation_config = GenerationConfig(seed=seed, temperature=temperature, top_p=top_p, top_k=top_k, max_new_tokens=max_new_tokens)
            streamer = model.stream(prompt=prompt, generation_config=generation_config)
            for new_text in streamer:
                response += new_text
                yield response
        else:
            # Test mode: stream dummy tokens instead of running the model.
            num = 0
            while num < 100:
                response += " " + str(num)
                num += 1
                yield response
        cur.execute('INSERT INTO queries VALUES(?, ?)', (hashes, response))
        con.commit()
    cur.close()
    con.close()
def get_file_contents(file):
    """Return (text, md5-of-raw-bytes) for an uploaded file.

    The digest is taken over the raw bytes before any text extraction, so
    re-uploading the same file always yields the same cache key.
    """
    data = None
    byte_hash = ''
    with open(file.name, 'rb') as f:
        data = f.read()
        byte_hash = hash_algo(data).hexdigest()
    if file.name.endswith('.pdf'):
        # Extract text from every page of the PDF.
        rdr = PdfReader(BytesIO(data))
        data = ''
        for page in rdr.pages:
            data += page.extract_text()
    else:
        data = data.decode()
        if file.name.endswith(".csv"):
            data = data.replace(',', ' ')
    return (data, byte_hash)
# Both upload handlers simply pass the selected file paths through to the
# corresponding File component.
def upload_log_file(files):
    file_paths = [file.name for file in files]
    return file_paths

def upload_file(files):
    file_paths = [file.name for file in files]
    return file_paths
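
# Gradio UI: compliance documents and a log file are uploaded, sampling
# parameters live under "Advanced Options", and the streamed answer renders
# as markdown below the Submit button.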
with gr.Blocks(
    theme=gr.themes.Soft(),
    css=".disclaimer {font-variant-caps: all-small-caps;}",
) as demo:
    gr.Markdown(
        """<h1><center>Grid 5.0 Information Security Track</center></h1>
"""
    )
    rules = gr.File(file_count="multiple")
    upload_button = gr.UploadButton("Click to upload a new Compliance Document", file_types=[".txt", ".pdf"], file_count="multiple")
    upload_button.upload(upload_file, upload_button, rules)
    with gr.Row():
        with gr.Column():
            log = gr.File()
            upload_log_button = gr.UploadButton("Click to upload a log file", file_types=[".txt", ".csv", ".pdf"], file_count="multiple")
            upload_log_button.upload(upload_log_file, upload_log_button, log)
    with gr.Accordion("Advanced Options:", open=False):
        with gr.Row():
            with gr.Column():
                with gr.Row():
                    temperature = gr.Slider(
                        label="Temperature",
                        value=0.8,
                        minimum=0.1,
                        maximum=1.0,
                        step=0.1,
                        interactive=True,
                        info="Higher values produce more diverse outputs",
                    )
            with gr.Column():
                with gr.Row():
                    top_p = gr.Slider(
                        label="Top-p (nucleus sampling)",
                        value=0.95,
                        minimum=0.0,
                        maximum=1.0,
                        step=0.01,
                        interactive=True,
                        info=(
                            "Sample from the smallest possible set of tokens whose cumulative probability "
                            "exceeds top_p. Set to 1 to disable and sample from all tokens."
                        ),
                    )
            with gr.Column():
                with gr.Row():
                    top_k = gr.Slider(
                        label="Top-k",
                        value=40,
                        minimum=5,
                        maximum=80,
                        step=1,
                        interactive=True,
                        info="Sample from a shortlist of top-k tokens; set to 0 to disable and sample from all tokens.",
                    )
            with gr.Column():
                with gr.Row():
                    max_new_tokens = gr.Slider(
                        label="Maximum new tokens",
                        value=256,
                        minimum=0,
                        maximum=1024,
                        step=5,
                        interactive=True,
                        info="The maximum number of new tokens to generate",
                    )
            with gr.Column():
                with gr.Row():
                    seed = gr.Number(
                        label="Seed",
                        value=42,
                        interactive=True,
                        info="The seed to use for the generation",
                        precision=0,
                    )
    with gr.Row():
        submit = gr.Button("Submit")
    with gr.Row():
        with gr.Box():
            gr.Markdown("**Output**")
            output_7b = gr.Markdown()

    submit.click(
        process_stream,
        inputs=[rules, log, temperature, top_p, top_k, max_new_tokens, seed],
        outputs=output_7b,
    )

demo.queue(max_size=4, concurrency_count=1).launch(debug=True)
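
# On Hugging Face Spaces this file is executed directly; running `python app.py`
# locally serves the same interface (by default at http://127.0.0.1:7860).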