Upload app (2).py
Browse files- app (2).py +120 -0
app (2).py
ADDED
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from langchain.chains import LLMChain, SequentialChain # our langchain is going to allow us to run our topic through our prompt template and then run it through our llm chain and then run it through our sequential chain(generate our response)
|
2 |
+
from langchain.memory import ConversationBufferMemory
|
3 |
+
import streamlit as st
|
4 |
+
|
5 |
+
from transformers import AutoModel
|
6 |
+
import torch
|
7 |
+
import transformers
|
8 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
9 |
+
|
10 |
+
|
11 |
+
# App framework
|
12 |
+
st.title('🦜Seon\'s Legal QA For Dummies 🔗 ')
|
13 |
+
|
14 |
+
|
15 |
+
model = AutoModelForCausalLM.from_pretrained("PyaeSoneK/pythia_70m_legalQA",
|
16 |
+
device_map='auto',
|
17 |
+
torch_dtype=torch.float32,
|
18 |
+
use_auth_token= st.secrets['hf_access_token'],
|
19 |
+
)
|
20 |
+
|
21 |
+
|
22 |
+
|
23 |
+
# load_in_4bit=True
|
24 |
+
|
25 |
+
tokenizer = AutoTokenizer.from_pretrained("PyaeSoneK/pythia_70m_legalQA",
|
26 |
+
use_auth_token=st.secrets['hf_access_token'],)
|
27 |
+
|
28 |
+
# Use a pipeline for later
|
29 |
+
from transformers import pipeline
|
30 |
+
|
31 |
+
pipe = pipeline("text-generation",
|
32 |
+
model=model,
|
33 |
+
tokenizer= tokenizer,
|
34 |
+
torch_dtype=torch.float32,
|
35 |
+
device_map="auto",
|
36 |
+
max_new_tokens = 512,
|
37 |
+
do_sample=True,
|
38 |
+
top_k=30,
|
39 |
+
num_return_sequences=2,
|
40 |
+
eos_token_id=tokenizer.eos_token_id
|
41 |
+
)
|
42 |
+
|
43 |
+
import json
|
44 |
+
import textwrap
|
45 |
+
import torch
|
46 |
+
|
47 |
+
|
48 |
+
B_INST, E_INST = "[INST]", "[/INST]"
|
49 |
+
B_SYS, E_SYS = "<>\n", "\n<>\n\n"
|
50 |
+
DEFAULT_SYSTEM_PROMPT = """\
|
51 |
+
You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
|
52 |
+
|
53 |
+
If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.Just say you don't know and you are sorry!"""
|
54 |
+
|
55 |
+
|
56 |
+
def get_prompt(instruction, new_system_prompt=DEFAULT_SYSTEM_PROMPT, citation=None):
|
57 |
+
SYSTEM_PROMPT = B_SYS + new_system_prompt + E_SYS
|
58 |
+
prompt_template = B_INST + SYSTEM_PROMPT + instruction + E_INST
|
59 |
+
|
60 |
+
if citation:
|
61 |
+
prompt_template += f"\n\nCitation: {citation}" # Insert citation here
|
62 |
+
|
63 |
+
return prompt_template
|
64 |
+
|
65 |
+
def cut_off_text(text, prompt):
|
66 |
+
cutoff_phrase = prompt
|
67 |
+
index = text.find(cutoff_phrase)
|
68 |
+
if index != -1:
|
69 |
+
return text[:index]
|
70 |
+
else:
|
71 |
+
return text
|
72 |
+
|
73 |
+
def remove_substring(string, substring):
|
74 |
+
return string.replace(substring, "")
|
75 |
+
|
76 |
+
def generate(text, citation=None):
|
77 |
+
prompt = get_prompt(text, citation=citation)
|
78 |
+
inputs = tokenizer(prompt, return_tensors="pt")
|
79 |
+
with torch.no_grad():
|
80 |
+
outputs = model.generate(**inputs,
|
81 |
+
max_length=512,
|
82 |
+
eos_token_id=tokenizer.eos_token_id,
|
83 |
+
pad_token_id=tokenizer.eos_token_id,
|
84 |
+
)
|
85 |
+
final_outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
|
86 |
+
final_outputs = cut_off_text(final_outputs, '')
|
87 |
+
final_outputs = remove_substring(final_outputs, prompt)
|
88 |
+
|
89 |
+
return final_outputs
|
90 |
+
|
91 |
+
def parse_text(text):
|
92 |
+
wrapped_text = textwrap.fill(text, width=100)
|
93 |
+
print(wrapped_text + '\n\n')
|
94 |
+
# return assistant_text
|
95 |
+
|
96 |
+
|
97 |
+
from langchain import HuggingFacePipeline
|
98 |
+
from langchain import PromptTemplate, LLMChain
|
99 |
+
|
100 |
+
llm = HuggingFacePipeline(pipeline = pipe, model_kwargs = {'temperature':0})
|
101 |
+
|
102 |
+
|
103 |
+
system_prompt = "You are an advanced legal assistant that excels at giving advice. "
|
104 |
+
instruction = "Convert the following input text from stupid to legally reasoned and step-by-step throughout advice:\n\n {text}"
|
105 |
+
template = get_prompt(instruction, system_prompt)
|
106 |
+
print(template)
|
107 |
+
|
108 |
+
prompt = PromptTemplate(template=template, input_variables=["text"])
|
109 |
+
|
110 |
+
llm_chain = LLMChain(prompt=prompt, llm=llm)
|
111 |
+
|
112 |
+
text = st.text_input('Plug in your prompt here')
|
113 |
+
# Instantiate the prompt template # this will show stuff to the screen if there's a prompt
|
114 |
+
|
115 |
+
if text:
|
116 |
+
response = llm_chain.run(text)
|
117 |
+
st.write(parse_text(response))
|
118 |
+
|
119 |
+
|
120 |
+
|