Spaces:
Sleeping
Sleeping
import sqlite3 | |
import gradio as gr | |
from hashlib import md5 as hash_algo | |
from re import match | |
from io import BytesIO | |
from pypdf import PdfReader | |
from llm_rs import AutoModel,SessionConfig,GenerationConfig,Precision | |
# Hugging Face repository and the quantised GGML weight file loaded from it.
repo_name = "rustformers/mpt-7b-ggml"
file_name = "mpt-7b-instruct-q5_1-ggjt.bin"

# When set to 'test', process_stream emits a dummy number stream instead of
# calling the model.
script_env = 'prod'

# The model is loaded once at import time and shared by every request.
session_config = SessionConfig(threads=2, batch_size=2)
model = AutoModel.from_pretrained(
    repo_name,
    model_file=file_name,
    session_config=session_config,
    verbose=True,
)
# SQLite file used to cache finished responses, and the schema it relies on.
_HISTORY_DB = "history.db"
_HISTORY_SCHEMA = (
    'CREATE TABLE IF NOT EXISTS queries(hexdigest TEXT PRIMARY KEY, response TEXT)'
)


def _next_rule_number(text):
    """Return one more than the largest ``<number>.`` prefix starting a line of *text*."""
    largest = 0
    # splitlines() handles LF, CRLF and a mix of both, so numbered rules are
    # found even when the rule files use different line endings.
    for line in text.splitlines():
        numbered = match(r'^(\d+)\.', line)
        if numbered is not None:
            largest = max(largest, int(numbered.group(1)))
    return largest + 1


def _build_prompt(instruction, query):
    """Assemble the instruct-style prompt from the rules text and the log text."""
    instruction = instruction.replace('\r\r\n', '\n')
    full_req = (
        "A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.\r\n\r\n"
        "Q: Read the rules stated below and check the queries for any violation. State the rules which are violated by a query (if any). Also suggest a possible remediation, if possible. Do not make any assumptions outside of the rules stated below.\r\n\r\n"
        + instruction
        + 'The queries are as follows:\r\n'
        + query
        + '\r\n \r\nA: '
    )
    full_req = full_req.replace('\r\n', '\n')
    return (
        "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n"
        "### Instruction:\n"
        f"{full_req}\n"
        "### Response:\n"
        "Answer:"
    )


def _cached_response(cache_key):
    """Return the response stored for *cache_key*, or ``None`` if there is none."""
    con = sqlite3.connect(_HISTORY_DB)
    try:
        # Create the table on first use so a fresh deployment does not fail
        # on the SELECT below.
        con.execute(_HISTORY_SCHEMA)
        con.commit()
        row = con.execute(
            'SELECT response FROM queries WHERE hexdigest = ?', [cache_key]
        ).fetchone()
    finally:
        con.close()
    return None if row is None else row[0]


def _store_response(cache_key, response):
    """Persist *response* under *cache_key* in the history database."""
    con = sqlite3.connect(_HISTORY_DB)
    try:
        con.execute(_HISTORY_SCHEMA)
        con.execute(
            'INSERT OR REPLACE INTO queries VALUES(?, ?)', (cache_key, response)
        )
        con.commit()
    finally:
        con.close()


def process_stream(rules, log, temperature, top_p, top_k, max_new_tokens, seed):
    """Check a log file against compliance rules and stream the model's answer.

    Generator used as the Gradio click handler: every yielded value is the
    full response text produced so far.

    Args:
        rules: One uploaded rule file, or a list of them. Each object must
            expose the path of the upload as ``.name``.
        log: The uploaded log file (same ``.name`` contract).
        temperature, top_p, top_k, max_new_tokens, seed: Sampling settings
            forwarded unchanged to ``GenerationConfig``.

    Yields:
        str: The accumulated response; prefixed with ``"Cached Result:\\n"``
        when it is served from ``history.db``.

    NOTE(review): the cache key is derived only from the file contents, so
    changing the sampling settings for the same files still returns the
    cached answer. Confirm this is intended.
    """
    if rules is None:
        rules = []
    elif not isinstance(rules, list):
        rules = [rules]
    if not rules or log is None:
        # Submitting without uploads used to crash with AttributeError.
        yield "Please upload at least one compliance document and a log file before submitting."
        return

    rule_texts = []
    rule_digests = []
    for rule in rules:
        text, digest = get_file_contents(rule)
        rule_texts.append(text + '\n')
        rule_digests.append(digest)
    instruction = ''.join(rule_texts)
    # Continue the numbered list of rules with the next free number.
    instruction += str(_next_rule_number(instruction)) + '. '

    # Sorting makes the key independent of the order the rule files were uploaded in.
    rule_digests.sort()
    cache_key = hash_algo(''.join(rule_digests).encode()).hexdigest()
    query, log_digest = get_file_contents(log)
    cache_key = hash_algo((cache_key + log_digest).encode()).hexdigest()

    # Database connections are opened and closed around each access rather
    # than being held open (and possibly leaked) across the streaming yields.
    cached = _cached_response(cache_key)
    if cached is not None:
        yield "Cached Result:\n" + cached
        return

    prompt = _build_prompt(instruction, query)
    response = ""
    if script_env != 'test':
        generation_config = GenerationConfig(
            seed=seed,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            max_new_tokens=max_new_tokens,
        )
        streamer = model.stream(prompt=prompt, generation_config=generation_config)
        for new_text in streamer:
            response += new_text
            yield response
    else:
        # Test mode: emit a deterministic dummy stream without touching the model.
        for num in range(100):
            response += " " + str(num)
            yield response

    # Only reached when generation ran to completion, so partial answers are never cached.
    _store_response(cache_key, response)
def get_file_contents(file):
    """Read an uploaded file and return its text together with a content hash.

    Args:
        file: Object whose ``.name`` attribute is the path of the file on disk.

    Returns:
        tuple[str, str]: ``(text, byte_hash)`` where ``byte_hash`` is the hex
        digest of the raw bytes. PDF files are converted with ``PdfReader``;
        anything else is decoded as UTF-8, and CSV files additionally have
        their commas replaced by spaces.
    """
    with open(file.name, 'rb') as f:
        raw = f.read()
    byte_hash = hash_algo(raw).hexdigest()

    # Compare extensions case-insensitively so "RULES.PDF" is not fed to the
    # text decoder and "LOG.CSV" still gets its commas replaced.
    lowered_name = file.name.lower()
    if lowered_name.endswith('.pdf'):
        reader = PdfReader(BytesIO(raw))
        # "or ''" guards against an extractor that returns None for a page
        # without text (presumably only older releases do; harmless otherwise).
        text = ''.join(page.extract_text() or '' for page in reader.pages)
    else:
        # utf-8-sig drops a leading BOM, which would otherwise stop the first
        # line from matching as a numbered rule. Undecodable bytes are
        # replaced rather than aborting the whole request.
        text = raw.decode('utf-8-sig', errors='replace')
        if lowered_name.endswith('.csv'):
            text = text.replace(',', ' ')
    return (text, byte_hash)
def upload_log_file(files):
    """Return the on-disk path (``.name``) of every uploaded log file, in order."""
    paths = []
    for uploaded in files:
        paths.append(uploaded.name)
    return paths
def upload_file(files):
    """Map the uploaded compliance documents to the list of their paths."""
    return [document.name for document in files]
# Gradio UI: upload compliance documents and a log file, tune the sampling
# options, and stream the model's verdict into a Markdown panel.
# NOTE(review): the source this was recovered from had lost its indentation,
# so the layout nesting below is a reconstruction — confirm against the
# deployed app before relying on it.
with gr.Blocks(
    theme=gr.themes.Soft(),
    css=".disclaimer {font-variant-caps: all-small-caps;}",
) as demo:
    gr.Markdown(
        """<h1><center>Grid 5.0 Information Security Track</center></h1>
        """
    )
    # Compliance documents (rules); several files may be supplied.
    rules = gr.File(file_count="multiple")
    upload_button = gr.UploadButton("Click to upload a new Compliance Document", file_types=[".txt", ".pdf"], file_count="multiple")
    # upload_file turns the uploads into paths that populate the `rules` component.
    upload_button.upload(upload_file, upload_button, rules)
    with gr.Row():
        with gr.Column():
            # Log file whose entries are checked against the rules.
            log = gr.File()
            upload_log_button = gr.UploadButton("Click to upload a log file", file_types=[".txt", ".csv", ".pdf"], file_count="multiple")
            upload_log_button.upload(upload_log_file, upload_log_button, log)
            # Sampling controls; their values are passed to process_stream on submit.
            with gr.Accordion("Advanced Options:", open=False):
                with gr.Row():
                    with gr.Column():
                        with gr.Row():
                            temperature = gr.Slider(
                                label="Temperature",
                                value=0.8,
                                minimum=0.1,
                                maximum=1.0,
                                step=0.1,
                                interactive=True,
                                info="Higher values produce more diverse outputs",
                            )
                    with gr.Column():
                        with gr.Row():
                            top_p = gr.Slider(
                                label="Top-p (nucleus sampling)",
                                value=0.95,
                                minimum=0.0,
                                maximum=1.0,
                                step=0.01,
                                interactive=True,
                                info=(
                                    "Sample from the smallest possible set of tokens whose cumulative probability "
                                    "exceeds top_p. Set to 1 to disable and sample from all tokens."
                                ),
                            )
                    with gr.Column():
                        with gr.Row():
                            # NOTE(review): the info text mentions 0 to disable, but minimum is 5.
                            top_k = gr.Slider(
                                label="Top-k",
                                value=40,
                                minimum=5,
                                maximum=80,
                                step=1,
                                interactive=True,
                                info="Sample from a shortlist of top-k tokens — 0 to disable and sample from all tokens.",
                            )
                    with gr.Column():
                        with gr.Row():
                            max_new_tokens = gr.Slider(
                                label="Maximum new tokens",
                                value=256,
                                minimum=0,
                                maximum=1024,
                                step=5,
                                interactive=True,
                                info="The maximum number of new tokens to generate",
                            )
                    with gr.Column():
                        with gr.Row():
                            seed = gr.Number(
                                label="Seed",
                                value=42,
                                interactive=True,
                                info="The seed to use for the generation",
                                precision=0
                            )
    with gr.Row():
        submit = gr.Button("Submit")
    with gr.Row():
        with gr.Box():
            gr.Markdown("**Output**")
            # Receives the text yielded by process_stream.
            output_7b = gr.Markdown()
    # Clicking Submit runs process_stream with the uploads and sampling options.
    submit.click(
        process_stream,
        inputs=[rules, log, temperature, top_p, top_k, max_new_tokens,seed],
        outputs=output_7b,
    )
# Enable request queuing with the given limits, then start the app.
demo.queue(max_size=4, concurrency_count=1).launch(debug=True)