# Streamlit app: extract user-specified columns from an uploaded PDF via the
# HuggingFace Inference API and display the result as a table.
import streamlit as st
import yaml
import requests
import re
import os
from langchain_core.prompts import PromptTemplate
import streamlit as st
from src.pdfParser import get_pdf_text
# Get HuggingFace API key from the environment. A missing token is surfaced
# as a UI error rather than a crash, so the app still renders and the user
# sees why inference calls will fail.
api_key_name = "HUGGINGFACE_HUB_TOKEN"
api_key = os.getenv(api_key_name)  # None when the env var is unset
if api_key is None:
    st.error(f"Failed to read `{api_key_name}`. Ensure the token is correctly located")
# Load in model configuration and check the required keys are present.
model_config_dir = "config/model_config.yml"
config_keys = ["system_message", "model_id", "template"]
with open(model_config_dir, "r") as file:
    model_config = yaml.safe_load(file)

# Verify every *required* key exists in the config. The original loop was
# inverted: it iterated model_config.keys() and raised on unknown extras,
# so a config missing a required key passed validation and only failed
# later with a KeyError.
for var in config_keys:
    if var not in model_config:
        raise ValueError(f"`{var}` key missing from `{model_config_dir}`")

system_message = model_config["system_message"]
model_id = model_config["model_id"]
template = model_config["template"]

# Prompt template wiring the system and user messages into the model prompt.
prompt_template = PromptTemplate(
    template=template,
    input_variables=["system_message", "user_message"],
)
def query(payload, model_id, timeout=60):
    """POST `payload` to the HuggingFace Inference API for `model_id`.

    Args:
        payload: JSON-serialisable request body (``inputs`` / ``parameters``).
        model_id: HuggingFace model repo id, interpolated into the API URL.
        timeout: Seconds before the HTTP request is abandoned (default 60).
            The original call passed no timeout, so a stalled API could hang
            the app indefinitely.

    Returns:
        The decoded JSON response from the API (a list on success, or a
        dict carrying an error message).
    """
    headers = {"Authorization": f"Bearer {api_key}"}
    api_url = f"https://api-inference.huggingface.co/models/{model_id}"
    response = requests.post(api_url, headers=headers, json=payload, timeout=timeout)
    return response.json()
def prompt_generator(system_message, user_message):
    """Assemble a Llama-2 style instruction prompt.

    Wraps `system_message` in the `<<SYS>>` block and `user_message` in the
    `[INST]` block, matching the chat format the model was trained on.
    """
    return (
        "\n"
        "<s>[INST] <<SYS>>\n"
        f"{system_message}\n"
        "<</SYS>>\n"
        f"{user_message} [/INST]\n"
    )
# Pattern to clean up text response from API: capture everything after the
# final `[/INST]` tag, i.e. the model's answer without the echoed prompt.
pattern = r".*\[/INST\]([\s\S]*)$"

# Initialize chat history
if "messages" not in st.session_state:
    st.session_state.messages = []

# Include PDF upload ability. `pdf_text` is initialised to "" so that the
# extraction handler below cannot hit a NameError when the user clicks
# "Extract data!" before uploading a file (previously the name was only
# bound inside the upload branch).
pdf_text = ""
pdf_upload = st.file_uploader(
    "Upload a .PDF here",
    type=".pdf",
)
if pdf_upload is not None:
    pdf_text = get_pdf_text(pdf_upload)

# Mapping of key/column name -> description, accumulated from the form below.
if "key_inputs" not in st.session_state:
    st.session_state.key_inputs = {}
# Three-column form: key name, optional description, and an "add" button.
col1, col2, col3 = st.columns([3, 3, 2])
with col1:
    key_name = st.text_input("Key/Column Name (e.g. patient_name)", key="key_name")
with col2:
    key_description = st.text_area(
        "*(Optional) Description of key/column", key="key_description"
    )
with col3:
    if st.button("Extract this column"):
        if not key_name:
            # Guard: previously an empty name was silently stored as a ""
            # key, producing an unusable column in the extraction request.
            st.warning("Please enter a key/column name first.")
        elif key_description:
            st.session_state.key_inputs[key_name] = key_description
        else:
            st.session_state.key_inputs[key_name] = "No further description provided"
# Once at least one key has been registered, offer to run the extraction.
if st.session_state.key_inputs:
    st.write("\nKeys/Columns for extraction:")
    st.write(st.session_state.key_inputs)

    # Button check comes first; the original nested the button *inside*
    # st.spinner, so the "Extracting" spinner was shown on every rerun even
    # before the user asked for anything.
    if st.button("Extract data!"):
        with st.spinner("Extracting requested data"):
            user_message = f"""
Use the text provided and denoted by 3 backticks ```{pdf_text}```.
Extract the following columns and return a table that could be uploaded to an SQL database.
{'; '.join(key + ': ' + desc for key, desc in st.session_state.key_inputs.items())}
"""
            the_prompt = prompt_generator(
                system_message=system_message, user_message=user_message
            )
            response = query(
                {
                    "inputs": the_prompt,
                    "parameters": {"max_new_tokens": 500, "temperature": 0.1},
                },
                model_id,
            )
            try:
                # On success the API returns a list whose first element holds
                # "generated_text"; an error payload is a dict and raises here.
                match = re.search(
                    pattern, response[0]["generated_text"], re.MULTILINE | re.DOTALL
                )
                if match:
                    extracted = match.group(1).strip()
                    # SECURITY: eval() executes arbitrary code from untrusted
                    # LLM output. Kept for behavioral compatibility, but this
                    # should be replaced with ast.literal_eval or JSON parsing.
                    result = eval(extracted)
                    st.success("Data Extracted Successfully!")
                    st.write(result)
                else:
                    # Previously a non-matching response failed silently.
                    st.error("Model response did not contain extractable data.")
            except Exception:
                # Bare `except:` replaced; still a broad boundary handler
                # because the API error payloads vary (model loading, rate
                # limits) and eval() can raise almost anything.
                st.error("Unable to connect to model. Please try again later.")