File size: 4,616 Bytes
58d7897
 
 
 
 
 
 
 
 
 
 
 
 
 
0cb5359
58d7897
 
 
 
 
 
 
 
 
0cb5359
58d7897
 
 
 
 
 
 
 
591fadf
58d7897
 
 
 
591fadf
58d7897
 
 
 
 
 
 
 
 
 
 
6116424
58d7897
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6116424
58d7897
 
 
 
 
 
 
 
 
 
 
472de6d
58d7897
 
 
 
6116424
58d7897
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
from langchain.chains import LLMChain, SequentialChain # our langchain is going to allow us to run our topic through our prompt template and then run it through our llm chain and then run it through our sequential chain(generate our response)
from langchain.memory import ConversationBufferMemory
import streamlit as st 

from transformers import AutoModel
import torch
import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM


# App framework
st.title('🦜Seon\'s Legal QA For Dummies 🔗 ')

# Hugging Face Hub repo holding both the fine-tuned model and its tokenizer.
MODEL_ID = "PyaeSoneK/Fine_Tuned_Pythia_smallest_140_legal"


@st.cache_resource
def _load_model_and_tokenizer():
    """Load the fine-tuned causal LM and its tokenizer once per server process.

    Streamlit re-executes this whole script on every widget interaction;
    without ``st.cache_resource`` every keystroke in the text box would
    reload the model weights into memory.

    Returns:
        A ``(model, tokenizer)`` tuple shared by all reruns and sessions.
    """
    loaded_model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        device_map='auto',
        torch_dtype=torch.float32,
        use_auth_token=st.secrets['hf_access_token'],
        # load_in_4bit=True,  # optional quantisation, currently disabled
    )
    loaded_tokenizer = AutoTokenizer.from_pretrained(
        MODEL_ID,
        use_auth_token=st.secrets['hf_access_token'],
    )
    return loaded_model, loaded_tokenizer


model, tokenizer = _load_model_and_tokenizer()

# Use a pipeline for later
from transformers import pipeline

# NOTE(review): the model above is loaded in float32; passing an already
# instantiated model here presumably does not re-cast it to float16 or
# re-place it via device_map — confirm against the installed transformers.
pipe = pipeline("text-generation",
                model=model,
                tokenizer= tokenizer,
                torch_dtype=torch.float16,
                device_map="auto",
                max_new_tokens = 512,
                do_sample=True,
                top_k=30,
                num_return_sequences=1,
                eos_token_id=tokenizer.eos_token_id
                )

import json
import textwrap
import torch


# Chat-prompt delimiters in the Llama-2 style used by get_prompt():
#   [INST] <<SYS>>\n{system}\n<</SYS>>\n\n{instruction} [/INST]
# NOTE(review): confirm this is the format the Pythia fine-tune was trained on.
B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
# System message used by get_prompt() when no other system prompt is given.
DEFAULT_SYSTEM_PROMPT = """\
You are a helpful, respectful and honest legal assistant who will answer legal questions a user would ask with step-by-step explanation and advice. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.

If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.Just say you don't know and you are sorry!"""


def get_prompt(instruction, new_system_prompt=DEFAULT_SYSTEM_PROMPT, citation=None):
    """Assemble a chat-formatted prompt string.

    Layout: ``B_INST``, then the system prompt wrapped in ``B_SYS``/``E_SYS``,
    then the instruction, then ``E_INST``. When *citation* is truthy, a blank
    line and ``Citation: <citation>`` are appended after ``E_INST``.

    Args:
        instruction: the user instruction; any template placeholders such as
            ``{text}`` are passed through untouched.
        new_system_prompt: system message placed between the B_SYS/E_SYS markers.
        citation: optional citation text appended after the closing E_INST.

    Returns:
        The assembled prompt string.
    """
    system_block = B_SYS + new_system_prompt + E_SYS
    pieces = [B_INST, system_block, instruction, E_INST]
    if citation:
        # Appended outside the instruction delimiters.
        pieces.append(f"\n\nCitation: {citation}")
    return "".join(pieces)

def cut_off_text(text, prompt):
    """Return the part of *text* before the first occurrence of *prompt*.

    Args:
        text: the string to truncate.
        prompt: the cut-off marker (e.g. an end-of-sequence token). Despite
            the name it can be any substring.

    Returns:
        ``text`` up to, but not including, the first occurrence of ``prompt``.
        ``text`` is returned unchanged if ``prompt`` is empty or not found.
        An empty marker counts as "no marker": ``str.find("")`` returns 0,
        which would otherwise truncate the whole string to ``""``.
    """
    if not prompt:
        return text
    head, _marker, _rest = text.partition(prompt)
    return head

def remove_substring(string, substring):
    """Delete every non-overlapping occurrence of *substring* from *string*.

    Matches are found left to right, exactly as ``str.replace`` does. An
    empty *substring* leaves *string* unchanged.
    """
    stripped = string.replace(substring, "", -1)
    return stripped

def generate(text, citation=None):
    """Generate an answer for *text* by calling ``model.generate`` directly.

    The prompt is built with get_prompt(). Apart from the arguments passed
    below, generation uses the model's own default generation config.

    Args:
        text: the user question / instruction.
        citation: optional citation forwarded to get_prompt().

    Returns:
        The decoded completion, without the prompt text.
    """
    prompt = get_prompt(text, citation=citation)
    # With device_map='auto' the model may sit on a GPU; the input tensors
    # must be on the same device or generate() raises a device mismatch.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    prompt_length = inputs["input_ids"].shape[-1]
    with torch.no_grad():
        # max_length counts prompt tokens plus generated tokens.
        outputs = model.generate(**inputs,
                                 max_length=512,
                                 eos_token_id=tokenizer.eos_token_id,
                                 pad_token_id=tokenizer.eos_token_id,
                                 )
    # Decode only the tokens generated after the prompt. Removing the prompt
    # by string replacement breaks whenever decode(encode(prompt)) != prompt.
    new_tokens = outputs[0][prompt_length:]
    final_outputs = tokenizer.decode(new_tokens, skip_special_tokens=True)
    # Truncate at an end-of-sequence marker if one survived decoding.
    eos = tokenizer.eos_token
    if eos:
        final_outputs = cut_off_text(final_outputs, eos)

    return final_outputs

def parse_text(text):
    """Print *text* word-wrapped to 100 columns, followed by two blank lines.

    Output goes to stdout only; the function returns None.
    """
    wrapper = textwrap.TextWrapper(width=100)
    print(wrapper.fill(text), end="\n\n\n")

    
from langchain import HuggingFacePipeline
from langchain import PromptTemplate,  LLMChain

# NOTE(review): when an existing `pipeline` object is passed,
# HuggingFacePipeline appears to use `model_kwargs` only in `from_model_id`,
# so these values are likely ignored. In that case the generation settings
# come from `pipe` above. Confirm against the installed langchain version.
llm = HuggingFacePipeline(pipeline = pipe, model_kwargs = {'temperature':0.7,'max_length': 256, 'top_k' :50})


system_prompt = "You are an advanced legal assistant that excels at giving advice. "
instruction = "Convert the following input text from stupid to legally reasoned and step-by-step throughout advice:\n\n {text}"
# `{text}` survives get_prompt() as a placeholder that PromptTemplate fills per query.
template = get_prompt(instruction, system_prompt)
print(template)

prompt = PromptTemplate(template=template, input_variables=["text"])

llm_chain = LLMChain(prompt=prompt, llm=llm)

text = st.text_input('Plug in your prompt here : Try (Employment law: Can I discuss my salary with coworkers?)')

# Streamlit reruns this script on every interaction; only call the model once
# the user has entered something other than whitespace.
if text.strip():
    with st.spinner('Generating answer...'):
        response = llm_chain.run(text)
    st.write(response)