# chatchat / app.py — Hugging Face Space by PyaeSoneK
# Commit 8751c2c: "Rename app .py to app.py" (4.38 kB)
from langchain.chains import LLMChain, SequentialChain # our langchain is going to allow us to run our topic through our prompt template and then run it through our llm chain and then run it through our sequential chain(generate our response)
from langchain.memory import ConversationBufferMemory
import streamlit as st
from transformers import AutoModel
import torch
import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM
@st.cache_resource
def _load_model_and_tokenizer():
    """Load the fine-tuned Llama-2 model and its tokenizer, cached across reruns.

    Streamlit re-executes this whole script on every widget interaction.
    ``st.cache_resource`` keeps the returned objects alive between reruns, so
    the weights are loaded once instead of on every submission.

    Returns:
        ``(model, tokenizer)`` for ``"PyaeSoneK/LlamaV2LegalFineTuned"``.
    """
    hf_token = st.secrets['hf_access_token']
    loaded_model = AutoModelForCausalLM.from_pretrained(
        "PyaeSoneK/LlamaV2LegalFineTuned",
        device_map='auto',
        torch_dtype=torch.float16,
        use_auth_token=hf_token,
        # load_in_4bit=True  # disabled alternative to float16 loading
    )
    loaded_tokenizer = AutoTokenizer.from_pretrained(
        "PyaeSoneK/LlamaV2LegalFineTuned",
        use_auth_token=hf_token,
    )
    return loaded_model, loaded_tokenizer


# Module-level names used by the pipeline and by generate() below.
model, tokenizer = _load_model_and_tokenizer()
# Text-generation pipeline built on the already-loaded model/tokenizer.
from transformers import pipeline

# Generation settings forwarded to the pipeline on every call.
_generation_kwargs = dict(
    max_new_tokens=512,
    do_sample=True,
    top_k=30,
    num_return_sequences=1,
    eos_token_id=tokenizer.eos_token_id,
)

# NOTE(review): `model` is passed as an instantiated object (loaded in float16
# with device_map='auto'), so torch_dtype/device_map here presumably only
# matter when pipeline() loads a model itself; bfloat16 also does not match
# the float16 load — confirm intent.
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    **_generation_kwargs,
)
import json
import textwrap
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
# Llama-2 chat prompt delimiters: [INST] ... [/INST] wraps a user turn and
# <<SYS>> ... <</SYS>> wraps the system prompt inside the first turn.
# NOTE(review): assumes the fine-tuned checkpoint kept the standard Llama-2
# chat format — confirm against its training data.
B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
# Fallback system prompt used by get_prompt() when none is supplied.
DEFAULT_SYSTEM_PROMPT = """\
You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."""
def get_prompt(instruction, new_system_prompt=DEFAULT_SYSTEM_PROMPT, citation=None):
    """Assemble a single-turn chat prompt string.

    Layout: ``B_INST + B_SYS + new_system_prompt + E_SYS + instruction + E_INST``,
    optionally followed by a citation line.

    Args:
        instruction: The user instruction; may contain template placeholders.
        new_system_prompt: System prompt placed between ``B_SYS`` and ``E_SYS``.
        citation: If truthy, appended as ``"\\n\\nCitation: <citation>"``.

    Returns:
        The assembled prompt string.
    """
    parts = [B_INST, B_SYS, new_system_prompt, E_SYS, instruction, E_INST]
    # NOTE(review): the citation is appended after E_INST, i.e. where the
    # model's reply begins — confirm this placement is intended.
    if citation:
        parts.append(f"\n\nCitation: {citation}")
    return "".join(parts)
def cut_off_text(text, prompt):
    """Truncate *text* just before the first occurrence of *prompt*.

    Args:
        text: The string to truncate.
        prompt: The cutoff phrase. An empty phrase means "no cutoff": without
            this guard ``str.find("")`` returns 0 and the whole text would be
            discarded.

    Returns:
        Everything before the first occurrence of *prompt*, or *text*
        unchanged when *prompt* is empty or not found.
    """
    if not prompt:
        return text
    # partition() yields (text, "", "") when the phrase is absent.
    head, _sep, _tail = text.partition(prompt)
    return head
def remove_substring(string, substring):
    """Return a copy of *string* with every occurrence of *substring* deleted."""
    without_matches = string.replace(substring, "")
    return without_matches
def generate(text, citation=None):
    """Return the model's reply to *text* by calling ``model.generate`` directly.

    The prompt is built with ``get_prompt`` (optionally with *citation*),
    tokenized, and generated from. Only the newly generated tokens are
    decoded, so the prompt is never echoed back.

    Args:
        text: The user instruction.
        citation: Optional citation string forwarded to ``get_prompt``.

    Returns:
        The decoded completion with special tokens removed.
    """
    prompt = get_prompt(text, citation=citation)
    # The model was loaded with device_map='auto', so its weights may not be on
    # the CPU where the tokenizer puts its tensors; align the inputs with it.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    prompt_length = inputs["input_ids"].shape[1]
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            # max_length counted the prompt tokens too, so a long system prompt
            # could leave little or no budget for the answer.
            max_new_tokens=512,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.eos_token_id,
        )
    # Decode only the continuation. Removing the prompt from a full decode by
    # string match is fragile: decode(encode(prompt)) need not equal prompt.
    completion = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
    # Defensive: stop at a literal end-of-sequence marker if one survives decoding.
    return cut_off_text(completion, "</s>")
def parse_text(text):
    """Wrap *text* to 100 columns, print it to stdout, and return it.

    The wrapped string is returned (previously ``None``), so callers such as
    ``st.write(parse_text(response))`` display the text rather than "None".

    Args:
        text: The string to wrap.

    Returns:
        ``textwrap.fill(text, width=100)``.
    """
    wrapped_text = textwrap.fill(text, width=100)
    print(wrapped_text + '\n\n')
    return wrapped_text
from langchain import HuggingFacePipeline
from langchain import PromptTemplate, LLMChain
# LangChain wrapper around the transformers text-generation pipeline `pipe`.
# NOTE(review): when a prebuilt `pipeline` object is passed, LangChain's
# HuggingFacePipeline does not appear to forward `model_kwargs` to generation,
# so temperature=0 probably has no effect — confirm for the installed version.
llm = HuggingFacePipeline(model_kwargs={'temperature': 0}, pipeline=pipe)

system_prompt = "You are an advanced legal assistant that excels at giving advice. "
# "{text}" stays in the string as a PromptTemplate placeholder for user input.
instruction = "Convert the following input text from stupid to legally reasoned and step-by-step throughout advice:\n\n {text}"
template = get_prompt(instruction, new_system_prompt=system_prompt)
print(template)  # echo the assembled prompt template to stdout

# App framework
st.title('🦜Seon\'s Legal QA For Dummies 🔗 ')
prompt = PromptTemplate(input_variables=["text"], template=template)
llm_chain = LLMChain(llm=llm, prompt=prompt)

# Run the chain and render the answer only when the input box is non-empty.
text = st.text_input('Plug in your prompt here')
if text:
    response = llm_chain.run(text)
    st.write(parse_text(response))