Spaces:
Sleeping
Sleeping
import html
from urllib.parse import quote_plus

import pandas as pd
import streamlit as st
from transformers import pipeline
# Streamlit re-executes this script from top to bottom on every widget
# interaction, so the expensive loads are wrapped in cached loaders: the model
# and the CSV are then built once and reused across reruns.
@st.cache_resource
def _load_classifier():
    """Build the zero-shot classification pipeline (facebook/bart-large-mnli)."""
    return pipeline("zero-shot-classification",
                    model="facebook/bart-large-mnli")


@st.cache_data
def _load_datasets(path='datasets.csv'):
    """Read the keyword/dataset table from the CSV file at *path*."""
    return pd.read_csv(path)


# Load the zero-shot classification model
classifier = _load_classifier()
# Sample dataset (replace this with your actual dataset)
df = _load_datasets()
def tag_finder(user_input):
    """Return the keyword tags that best match *user_input* and their datasets.

    The unique values of ``df['Keyword']`` are used as candidate labels for the
    zero-shot ``classifier``. A score cutoff is derived from the ranked scores
    and every label scoring at or above it is kept.

    Args:
        user_input: Free-text description to classify.

    Returns:
        A ``(useful_tags, relevant_datasets)`` tuple. ``useful_tags`` lists the
        selected labels in the classifier's ranking order; ``relevant_datasets``
        lists the ``df['Datasets']`` values of every row whose keyword is one of
        those tags, grouped in the same tag order. Both lists are empty when
        the input is blank or there are no keywords to classify against.
    """
    # The zero-shot pipeline rejects an empty sequence, and a whitespace-only
    # description carries nothing to classify, so bail out early.
    if not user_input or not user_input.strip():
        return [], []

    # Drop missing keywords (NaN would become the literal label "nan" and can
    # never match a row below) and hand the pipeline a plain list of labels.
    keywords = df['Keyword'].dropna().unique().tolist()
    if not keywords:
        return [], []

    result = classifier(user_input, keywords)
    scores = result['scores']

    # Walk the scores from the top: lower the cutoff to each successive score
    # that sits at least 10% below the current cutoff, and stop at the first
    # score that is closer than that. Scores equal to the cutoff are skipped.
    # NOTE(review): this descends through large drops and stops at the first
    # small one; confirm that this direction is the intended behavior.
    threshold = scores[0]
    for score in scores:
        if score == threshold:
            continue
        if (threshold - score) >= threshold / 10:
            threshold = score
        else:
            break

    useful_tags = [
        label
        for label, score in zip(result['labels'], scores)
        if score >= threshold
    ]

    # Collect datasets tag by tag so the output follows the tag ranking.
    relevant_datasets = []
    for tag in useful_tags:
        relevant_datasets.extend(df[df['Keyword'] == tag]['Datasets'].tolist())
    return useful_tags, relevant_datasets
def _dataset_card_html(dataset, tag):
    """Return the HTML card for one dataset with its search link and tag badge.

    The dataset name and tag are HTML-escaped before being placed in the
    markup, and the dataset name is URL-encoded for the search query, so
    values containing characters such as ``&``, ``<`` or spaces cannot break
    the link or the page markup.
    """
    name = html.escape(str(dataset))
    label = html.escape(str(tag))
    query = quote_plus(str(dataset))
    # Lines are joined without leading indentation so Markdown does not treat
    # the HTML as an indented code block.
    return "\n".join((
        '',
        '<div style="border: 2px solid #555; border-radius: 10px; padding: 10px; margin-bottom: 10px; background-color: #333; color: white; display: flex; justify-content: space-between; align-items: center;">',
        f'<div>{name}</div>',
        f'<div style="padding: 5px 10px; border: #fff 2px solid; border-radius: 5px;transition: background-color 0.3s;"><a href="https://datasetsearch.research.google.com/search?search&src=0&query={query}" style = "text-decoration: none; color: white;">link</a></div>',
        '<div style="border: 1px solid #666; padding: 5px; background-color: #444; border-radius: 12px;">',
        f'<img width="20" height="20" style="margin: 5px;" src="https://img.icons8.com/ios/50/ffffff/price-tag--v2.png" alt="price-tag--v2"/>{label}',
        '</div>',
        '</div>',
        '',
    ))


# Define the Streamlit app
def main():
    """Render the Dataset Finder page and show matching datasets on submit.

    Draws the title, description and text input. When "Submit" is pressed, the
    description is passed to ``tag_finder`` and one card is rendered for each
    returned dataset; a warning is shown when the description is blank or no
    tags are returned.
    """
    # Set title and description
    st.title("Dataset Finder")
    st.write("Enter short description about your ML model below and get relevant tags for your dataset.")

    # Get user input
    user_input = st.text_input("Enter the description:")
    if not st.button("Submit"):
        return

    # A blank description has nothing to classify; ask for input instead of
    # sending it to the model.
    if not user_input or not user_input.strip():
        st.warning("Please enter a description first.")
        return

    # Find relevant tags and datasets
    relevant_tags, relevant_datasets = tag_finder(user_input)
    if not relevant_tags:
        st.warning("No relevant tags found.")
        return

    st.subheader("Datasets:")
    for dataset in relevant_datasets:
        # Prefer a keyword that was actually matched for this dataset; fall
        # back to the dataset's first keyword in the table otherwise.
        dataset_keywords = df.loc[df['Datasets'] == dataset, 'Keyword']
        matched_keywords = dataset_keywords[dataset_keywords.isin(relevant_tags)]
        candidates = dataset_keywords if matched_keywords.empty else matched_keywords
        tag = candidates.iloc[0]
        st.markdown(_dataset_card_html(dataset, tag), unsafe_allow_html=True)
# Run the app only when this file is executed as a script, not when imported
if __name__ == "__main__":
    main()