PyaeSoneK committed
Commit 58d7897
1 Parent(s): 3ed3601

Upload app (2).py

Files changed (1)
  1. app (2).py +120 -0
app (2).py ADDED
@@ -0,0 +1,120 @@
+ # LangChain wiring: the user's text runs through a PromptTemplate, then an LLMChain generates the response
+ from langchain import HuggingFacePipeline, PromptTemplate, LLMChain
+ import streamlit as st
+
+ import textwrap
+ import torch
+ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
+
+ # App framework
+ st.title("🦜 Seon's Legal QA For Dummies 🔗")
+
+ # device_map='auto' needs the `accelerate` package to place the model
+ model = AutoModelForCausalLM.from_pretrained(
+     "PyaeSoneK/pythia_70m_legalQA",
+     device_map="auto",
+     torch_dtype=torch.float32,
+     use_auth_token=st.secrets["hf_access_token"],
+ )
+
+ # load_in_4bit=True  # optional: 4-bit quantization (requires bitsandbytes)
+
+ tokenizer = AutoTokenizer.from_pretrained(
+     "PyaeSoneK/pythia_70m_legalQA",
+     use_auth_token=st.secrets["hf_access_token"],
+ )
+
+ # Text-generation pipeline that LangChain wraps below
+ pipe = pipeline(
+     "text-generation",
+     model=model,
+     tokenizer=tokenizer,
+     torch_dtype=torch.float32,
+     device_map="auto",
+     max_new_tokens=512,
+     do_sample=True,
+     top_k=30,
+     num_return_sequences=1,  # the chain only reads the first generation
+     eos_token_id=tokenizer.eos_token_id,
+ )
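+
+ # Hypothetical sanity check (not part of the app flow); a text-generation
+ # pipeline returns a list of dicts with a "generated_text" key:
+ #   print(pipe("What is consideration in contract law?")[0]["generated_text"])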
+
+ # Llama-2-style chat delimiters
+ B_INST, E_INST = "[INST]", "[/INST]"
+ B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
+ DEFAULT_SYSTEM_PROMPT = """\
+ You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
+
+ If a question does not make any sense, or is not factually coherent, explain why instead of answering something incorrect. If you don't know the answer to a question, please don't share false information. Just say you don't know and you are sorry!"""
+
+
+ def get_prompt(instruction, new_system_prompt=DEFAULT_SYSTEM_PROMPT, citation=None):
+     """Wrap an instruction (and optional citation) in the chat prompt format."""
+     SYSTEM_PROMPT = B_SYS + new_system_prompt + E_SYS
+     prompt_template = B_INST + SYSTEM_PROMPT + instruction + E_INST
+     if citation:
+         prompt_template += f"\n\nCitation: {citation}"  # append the citation at the end
+     return prompt_template
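+
+ # For example, get_prompt("Summarize {text}") produces roughly:
+ #   [INST]<<SYS>>\n...default system prompt...\n<</SYS>>\n\nSummarize {text}[/INST]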
+
+ def cut_off_text(text, prompt):
+     """Truncate `text` at the first occurrence of `prompt`, if present."""
+     index = text.find(prompt)
+     return text[:index] if index != -1 else text
+
+
+ def remove_substring(string, substring):
+     return string.replace(substring, "")
+
+ def generate(text, citation=None):
+     prompt = get_prompt(text, citation=citation)
+     inputs = tokenizer(prompt, return_tensors="pt")
+     with torch.no_grad():
+         outputs = model.generate(
+             **inputs,
+             max_length=512,
+             eos_token_id=tokenizer.eos_token_id,
+             pad_token_id=tokenizer.eos_token_id,
+         )
+     final_outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
+     final_outputs = cut_off_text(final_outputs, '</s>')  # drop anything after the end-of-sequence marker
+     final_outputs = remove_substring(final_outputs, prompt)  # strip the echoed prompt
+     return final_outputs
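+
+ # Hypothetical direct use, bypassing the Streamlit UI (citation text is illustrative):
+ #   print(generate("Is a verbal contract binding?", citation="Black's Law Dictionary"))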
+
+ def parse_text(text):
+     wrapped_text = textwrap.fill(text, width=100)
+     print(wrapped_text + '\n\n')
+     return wrapped_text  # return the wrapped text so st.write has something to display
+
+ # Wrap the transformers pipeline for LangChain; temperature 0 keeps answers mostly deterministic
+ llm = HuggingFacePipeline(pipeline=pipe, model_kwargs={'temperature': 0})
+
+ system_prompt = "You are an advanced legal assistant that excels at giving advice. "
+ instruction = "Convert the following input text into legally reasoned, thorough step-by-step advice:\n\n {text}"
+ template = get_prompt(instruction, system_prompt)
+ print(template)  # log the assembled template for debugging
+
+ prompt = PromptTemplate(template=template, input_variables=["text"])
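+
+ # e.g. prompt.format(text="My landlord kept my deposit") fills the {text} slot
+ # with the user's input before the chain calls the LLM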
+
+ llm_chain = LLMChain(prompt=prompt, llm=llm)
+
+ # Prompt box; a response is rendered below it whenever the user submits text
+ text = st.text_input('Plug in your prompt here')
+
+ if text:
+     response = llm_chain.run(text)
+     st.write(parse_text(response))
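+
+ # Assumed local setup: `streamlit run "app (2).py"` with hf_access_token defined
+ # in .streamlit/secrets.toml, which is where st.secrets reads from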