"""Streamlit app that flags references to vulnerable groups in free text.

The text entered by the user is classified with the SetFit model
``leavoigt/vulnerable_groups`` and the matching group label(s) are displayed.
"""
import streamlit as st
from setfit import SetFitModel

MODEL_ID = "leavoigt/vulnerable_groups"
NO_MATCH_MESSAGE = "No reference to a vulnerable group was identified."

# Streamlit re-executes this whole script on every widget interaction, so the
# model is cached; otherwise it would be reloaded on every submit.
# ``st.cache_resource`` only exists in newer Streamlit releases; on older ones
# fall back to loading without caching instead of crashing.
_cache_resource = getattr(st, "cache_resource", None) or (lambda func: func)


@_cache_resource
def load_model():
    """Return the SetFit classifier (cached across reruns when supported)."""
    return SetFitModel.from_pretrained(MODEL_ID)


# Load the model
model = load_model()

# Map from the model's output class index to a human-readable group name.
group_dict = {
    0: 'Coastal communities',
    1: 'Small island developing states (SIDS)',
    2: 'Landlocked countries',
    3: 'Low-income households',
    4: 'Informal settlements and slums',
    5: 'Rural communities',
    6: 'Children and youth',
    7: 'Older adults and the elderly',
    8: 'Women and girls',
    9: 'People with pre-existing health conditions',
    10: 'People with disabilities',
    11: 'Small-scale farmers and subsistence agriculture',
    12: 'Fisherfolk and fishing communities',
    13: 'Informal sector workers',
    14: 'Children with disabilities',
    15: 'Remote communities',
    16: 'Young adults',
    17: 'Elderly population',
    18: 'Urban slums',
    19: 'Men and boys',
    20: 'Gender non-conforming individuals',
    21: 'Pregnant women and new mothers',
    22: 'Mountain communities',
    23: 'Riverine and flood-prone areas',
    24: 'Drought-prone regions',
    25: 'Indigenous peoples',
    26: 'Migrants and displaced populations',
    27: 'Outdoor workers',
    28: 'Small-scale farmers',
    29: 'Other',
}


def get_label(prediction_tensor):
    """Return the group label(s) encoded by a single model prediction.

    Args:
        prediction_tensor: The prediction for one text. It is either a class
            index (int, or a 0-d tensor/array) or a 0/1 indicator vector over
            the classes in ``group_dict`` (list, tensor or array). A leading
            batch dimension of size 1 is tolerated.

    Returns:
        The matching group name, or several names joined by ``", "`` when the
        indicator vector marks more than one class. ``NO_MATCH_MESSAGE`` is
        returned when no class is marked.
    """
    # Tensors/arrays have no ``.index`` method; convert to plain Python values.
    if hasattr(prediction_tensor, "tolist"):
        values = prediction_tensor.tolist()
    else:
        values = prediction_tensor

    # Single-label prediction: the value itself is the class index.
    if isinstance(values, (int, float)):
        return group_dict[int(values)]

    values = list(values)
    # Unwrap a batch of one, e.g. [[0, 1, 0, ...]].
    if len(values) == 1 and isinstance(values[0], (list, tuple)):
        values = list(values[0])

    # Multi-label prediction: collect every class marked with 1.
    keys = [index for index, value in enumerate(values) if value == 1]
    if not keys:
        return NO_MATCH_MESSAGE
    return ", ".join(group_dict[key] for key in keys)


# App
st.title("Identify references to vulnerable groups.")
st.write("This app allows you to identify whether a text contains any references to vulnerable groups. This can, for example, be used to analyse policy documents.")

# Create text input box
input_text = st.text_area('Please enter your text here')

# Only classify once the user has entered some text; without this guard the
# model would also run on an empty string at every page load.
if input_text.strip():
    # SetFit predicts on a batch (list) of texts; take the single result out.
    preds = model([input_text])[0]
    st.text(get_label(preds))