Spaces:
Sleeping
Sleeping
File size: 4,653 Bytes
d19b3e0 c9c2e0b d19b3e0 1bb9e65 c9c2e0b d19b3e0 e832753 d19b3e0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 |
import csv
import datetime
import html
import os
import time

import matplotlib.pyplot as plt
import pandas as pd
import requests
import streamlit as st
# --- Hugging Face Inference API configuration ---------------------------------
# The bearer token comes from the "api_secret" environment variable; indexing
# os.environ means a missing variable raises KeyError as soon as the script runs.
SECRET = os.environ["api_secret"]
headers = {"Authorization": f"Bearer {SECRET}"}
API_URL = "https://api-inference.huggingface.co/models/cccmatthew/surrey-gp30"
# -------------------------------------------------------------------------------
def load_response_times():
    """Read the timestamp and response-time columns from the interaction log.

    Returns:
        A DataFrame with a datetime-parsed ``Timestamp`` column and a
        ``Response Time`` column. If the log cannot be read or parsed, an
        error is shown in the app and an empty DataFrame is returned instead.
    """
    try:
        frame = pd.read_csv('model_interactions.csv', usecols=["Timestamp", "Response Time"])
        frame['Timestamp'] = pd.to_datetime(frame['Timestamp'])
    except Exception as exc:  # UI boundary: report any read/parse failure to the user
        st.error(f"Failed to read response times: {exc}")
        return pd.DataFrame()
    return frame
def plot_response_times(df):
    """Render a line chart of response time against timestamp in the app.

    Args:
        df: DataFrame with ``Timestamp`` and ``Response Time`` columns (as
            produced by ``load_response_times``). When it is empty, a short
            message is shown instead of a chart.
    """
    if df.empty:
        st.write("No response time data to display.")
        return

    # Use an explicit figure instead of pyplot's implicit global one, and hand
    # that figure to Streamlit directly.
    fig, ax = plt.subplots(figsize=(10, 5))
    ax.plot(df['Timestamp'], df['Response Time'], marker='o', linestyle='-')
    ax.set_title('Response Times Over Time')
    ax.set_xlabel('Timestamp')
    ax.set_ylabel('Response Time (seconds)')
    ax.grid(True)
    st.pyplot(fig)
    # Streamlit re-executes the script on every interaction; without closing,
    # each run leaves another open figure in pyplot's registry (memory leak).
    plt.close(fig)
def setup_csv_logger():
    """Ensure the interaction log exists and starts with a header row.

    ``model_interactions.csv`` is opened in append mode, so existing entries
    are never truncated; the header is written only when the file is empty.
    """
    header = ["Timestamp", "User Input", "Model Prediction", "Response Time"]
    with open('model_interactions.csv', 'a', newline='') as log_file:
        # Opening in append mode positions the stream at end-of-file,
        # so a position of 0 means the file has no content yet.
        if log_file.tell() != 0:
            return
        csv.writer(log_file).writerow(header)
def log_to_csv(sentence, results, response_time):
    """Append one model interaction as a row of ``model_interactions.csv``.

    Args:
        sentence: The text that was sent to the model.
        results: The decoded model output; ``csv`` writes it via ``str()``.
        response_time: Duration of the request in seconds.
    """
    row = [datetime.datetime.now(), sentence, results, response_time]
    # Write UTF-8 explicitly: user input may contain any Unicode character,
    # which the platform default encoding (e.g. cp1252) cannot always encode.
    # This also matches pandas.read_csv's default decoding when reading back.
    with open('model_interactions.csv', 'a', newline='', encoding='utf-8') as log_file:
        csv.writer(log_file).writerow(row)
# Create the interaction log (with its header row) before any UI is drawn.
setup_csv_logger()

st.title('Group 30 - DistilBERT')
st.write('This application uses DistilBERT to classify Abbreviations (AC) and Long Forms (LF)')

# Canned demo inputs; the trailing option switches to a free-text box.
example_sentences = [
    "RAFs are plotted for a selection of neurons in the dorsal zone (DZ) of auditory cortex in Figure 1.",
    "Light dissolved inorganic carbon (DIC) resulting from the oxidation of hydrocarbons.",
    "Images were acquired using a GE 3.0T MRI scanner with an upgrade for echo-planar imaging (EPI).",
]
sentence = st.selectbox(
    'Choose an example sentence or type your own below:',
    [*example_sentences, 'Custom Input...'],
)
if sentence == 'Custom Input...':
    sentence = st.text_input('Input your sentence here', '')
def merge_entities(sentence, entities):
    """Return *sentence* as HTML with each predicted entity highlighted.

    Args:
        sentence: The raw text that was sent to the model.
        entities: Iterable of dicts carrying ``'start'``/``'end'`` character
            offsets into *sentence* and an ``'entity_group'`` label.

    Returns:
        An HTML string in which every entity span is wrapped in a ``<mark>``
        tag followed by its bracketed label. All text taken from *sentence*
        and from the labels is HTML-escaped, because the result is rendered
        with ``unsafe_allow_html=True`` and the sentence can be user-typed.
    """
    parts = []
    last_end = 0
    for entity in sorted(entities, key=lambda e: e['start']):
        # Clamp spans that overlap an already-emitted one so no character is
        # output twice; spans left empty (or fully covered) are skipped.
        start = max(entity['start'], last_end)
        end = entity['end']
        if start >= end:
            continue
        parts.append(html.escape(sentence[last_end:start]))
        parts.append(
            "<mark style='background-color: #ffcccb;'>"
            f"<b>{html.escape(sentence[start:end])}</b> "
            f"[{html.escape(str(entity['entity_group']))}]</mark>"
        )
        last_end = end
    parts.append(html.escape(sentence[last_end:]))
    return "".join(parts)
def send_request_with_retry(url, headers, json_data, retries=3, backoff_factor=1, timeout=30):
    """POST *json_data* to *url*, retrying on timeouts, connection errors and HTTP 503.

    Args:
        url: Endpoint to call.
        headers: HTTP headers to send (e.g. the Authorization header).
        json_data: JSON-serialisable request body.
        retries: Maximum number of attempts.
        backoff_factor: Base delay in seconds; after failed attempt ``k``
            (0-based) the function waits ``backoff_factor * 2**k`` seconds
            before trying again. No wait follows the final attempt.
        timeout: Seconds to wait for the server on each attempt. ``requests``
            applies no timeout by default, so without one a stalled server
            would block the app indefinitely and never trigger a retry.

    Returns:
        ``(response, elapsed_seconds)`` on success, or ``(None, None)`` after
        all attempts fail (an error message is shown in the app).

    Raises:
        requests.exceptions.HTTPError: for any HTTP error status other than 503.
    """
    for attempt in range(retries):
        started = time.perf_counter()  # monotonic clock: immune to wall-clock jumps
        try:
            response = requests.post(url, headers=headers, json=json_data, timeout=timeout)
            response.raise_for_status()
        except requests.exceptions.HTTPError:
            if response.status_code != 503:
                raise
            st.info('Server is unavailable, retrying...')
        except (requests.exceptions.ConnectionError, requests.exceptions.Timeout) as e:
            st.info(f"Network issue ({e}), retrying...")
        else:
            return response, time.perf_counter() - started
        # Back off before the next attempt; sleeping after the last one only delays the error.
        if attempt < retries - 1:
            time.sleep(backoff_factor * (2 ** attempt))
    st.error("Failed to process request after several attempts.")
    return None, None
def _classify_and_render(text):
    """Classify *text* with the remote model and show the outcome in the app.

    On success the highlighted entities are rendered, the interaction is
    appended to the CSV log and the response-time chart is redrawn. Request
    and response failures are reported with ``st.error`` rather than a traceback.
    """
    try:
        response, response_time = send_request_with_retry(API_URL, headers, {"inputs": text})
    except requests.exceptions.HTTPError as err:
        # send_request_with_retry re-raises HTTP errors other than 503
        # (e.g. an invalid token); report them instead of crashing the page.
        st.error(f"The model API returned an error: {err}")
        return
    if response is None:
        st.error("Unable to classify the sentence due to server issues.")
        return
    try:
        results = response.json()
    except ValueError:  # response body is not valid JSON
        st.error("The model API returned a response that is not valid JSON.")
        return
    if not isinstance(results, list):
        # merge_entities needs a list of entity dicts; any other payload
        # (presumably an error object from the API) is shown as-is.
        st.error(f"Unexpected response from the model API: {results}")
        return
    st.write('Results:')
    st.markdown(merge_entities(text, results), unsafe_allow_html=True)
    log_to_csv(text, results, response_time)
    plot_response_times(load_response_times())


# Require some visible text: whitespace-only input is not sent to the API.
if st.button('Classify'):
    if sentence and sentence.strip():
        _classify_and_render(sentence)
    else:
        st.error('Please enter a sentence.')
# Independent button: redraw the response-time chart from the log without classifying.
if st.button('Show Response Times'):
    plot_response_times(load_response_times())
|