regraded01 committed on
Commit
65db96a
1 Parent(s): 10c623f

build: create a new Streamlit app file that will be built with LangChain. TODO: build out the LangChain version, then remove/rename this file back once the migration is successful

Browse files
Files changed (1) hide show
  1. app_langchain.py +110 -0
app_langchain.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import yaml
3
+ import requests
4
+ import re
5
+ import os
6
+ from src.pdfParser import get_pdf_text
7
+
8
# Get HuggingFace API key
api_key_name = "HUGGINGFACE_HUB_TOKEN"
api_key = os.getenv(api_key_name)
if api_key is None:
    st.error(f"Failed to read `{api_key_name}`. Ensure the token is correctly located")
    # Halt this script run: without a token every later request would be sent
    # with the header "Bearer None" and fail with a misleading error.
    st.stop()


# Load the model settings (system prompt and model identifier) from YAML.
# The encoding is given explicitly so the file reads the same on every
# platform regardless of the locale default.
with open("config/model_config.yml", "r", encoding="utf-8") as file:
    model_config = yaml.safe_load(file)

system_message = model_config["system_message"]
model_id = model_config["model_id"]
20
+
21
+
22
def query(payload, model_id, timeout=120):
    """POST ``payload`` to the Hugging Face Inference API and return the JSON reply.

    Args:
        payload: JSON-serialisable request body sent to the endpoint.
        model_id: Model identifier appended to the inference API URL.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout ``requests`` waits indefinitely, which would hang the
            script run if the endpoint stalls.

    Returns:
        The decoded JSON body of the HTTP response.

    Raises:
        requests.exceptions.RequestException: On connection failures or when
            the timeout elapses.
    """
    # The bearer token comes from the module-level ``api_key``.
    headers = {"Authorization": f"Bearer {api_key}"}
    api_url = f"https://api-inference.huggingface.co/models/{model_id}"
    response = requests.post(api_url, headers=headers, json=payload, timeout=timeout)
    return response.json()
27
+
28
+
29
def prompt_generator(system_message, user_message):
    """Wrap a system and a user message in the instruction prompt template.

    The template places ``system_message`` between ``<<SYS>>`` / ``<</SYS>>``
    markers and ``user_message`` before the closing ``[/INST]`` tag.
    NOTE(review): this looks like the Llama-2 chat format -- confirm it
    matches the model configured in ``model_config.yml``.

    Args:
        system_message: Text placed inside the system section.
        user_message: Text placed inside the instruction section.

    Returns:
        The assembled prompt string, with a leading and a trailing newline.
    """
    template_lines = (
        "",  # the template starts with a newline
        "<s>[INST] <<SYS>>",
        f"{system_message}",
        "<</SYS>>",
        f"{user_message} [/INST]",
        "",  # ...and ends with one
    )
    return "\n".join(template_lines)
36
+
37
+
38
# Regex that captures everything following the closing [/INST] tag in the
# text returned by the API.
pattern = r".*\[/INST\]([\s\S]*)$"

# Seed the per-session containers on the first run only; Streamlit re-executes
# the script on every interaction and session_state persists between runs.
if "messages" not in st.session_state:
    st.session_state["messages"] = []

if "key_inputs" not in st.session_state:
    st.session_state["key_inputs"] = {}

# PDF upload widget.
pdf_upload = st.file_uploader("Upload a .PDF here", type=".pdf")

# pdf_text is only bound once a file has actually been uploaded.
if pdf_upload is not None:
    pdf_text = get_pdf_text(pdf_upload)
57
+
58
# Three-column form: key name, optional description, and the "add" button.
col1, col2, col3 = st.columns([3, 3, 2])

with col1:
    key_name = st.text_input("Key/Column Name (e.g. patient_name)", key="key_name")

with col2:
    key_description = st.text_area(
        "*(Optional) Description of key/column", key="key_description"
    )

with col3:
    if st.button("Extract this column"):
        if not key_name.strip():
            # A blank name would otherwise be stored as an empty-string column
            # and end up in the extraction prompt.
            st.warning("Please enter a key/column name first.")
        else:
            # Fall back to a placeholder when no description was typed.
            st.session_state.key_inputs[key_name] = (
                key_description or "No further description provided"
            )
74
+
75
# Show the keys/columns queued for extraction so far.
if st.session_state.key_inputs:
    st.write("\nKeys/Columns for extraction:")
    st.write(st.session_state.key_inputs)

if st.button("Extract data!"):
    if pdf_upload is None:
        # `pdf_text` is only assigned after a file has been uploaded; building
        # the prompt without it would raise NameError.
        st.warning("Please upload a PDF before extracting data.")
    else:
        # One "name: description" entry per requested column.
        columns = "; ".join(
            f"{key}: {description}"
            for key, description in st.session_state.key_inputs.items()
        )
        user_message = (
            "\n"
            f"Use the text provided and denoted by 3 backticks ```{pdf_text}```.\n"
            "Extract the following columns and return a table that could be uploaded to an SQL database.\n"
            f"{columns}\n"
        )
        the_prompt = prompt_generator(
            system_message=system_message, user_message=user_message
        )
        # Show the spinner only while the request is actually in flight.
        with st.spinner("Extracting requested data"):
            response = query(
                {
                    "inputs": the_prompt,
                    "parameters": {"max_new_tokens": 500, "temperature": 0.1},
                },
                model_id,
            )
        try:
            # Keep only the text the model produced after the [/INST] tag.
            match = re.search(
                pattern, response[0]["generated_text"], re.MULTILINE | re.DOTALL
            )
            if match:
                response = match.group(1).strip()

            # SECURITY: eval() executes model-generated text, which is derived
            # from the uploaded PDF and is therefore untrusted input. Prefer
            # ast.literal_eval or json.loads once the expected output format
            # is pinned down.
            response = eval(response)

            st.success("Data Extracted Successfully!")
            st.write(response)
        except Exception:
            # `except Exception` rather than a bare `except` so that
            # BaseException-derived control-flow signals are not swallowed.
            st.error("Unable to connect to model. Please try again later.")