"""Streamlit playground for trying topic-labelling prompts against hosted LLMs."""
from huggingface_hub import list_models
import streamlit as st
from model import ReplicateModel
import os
import pandas as pd

DATASETS_PATH = 'datasets'

# Models offered in the "Select a model" box, keyed by the name shown to the user.
models = {
    'mistral_instruct': ReplicateModel('mistralai/mistral-7b-instruct-v0.1:83b6a56e7c828e667f21fd596c338fd4f0039b46bcfa18d973e8e70e455fda70'),
}

# Prompt templates; the literal "[KEYWORDS]" is replaced with the selected topic's keywords.
# NOTE(review): these strings look like they lost their line breaks (and possibly a
# "<...>" placeholder after "topic:") in an earlier reformat. They are kept byte-for-byte
# here -- confirm the intended text against version control.
prompts = {
    'simple_prompt': (
        ' I have topic that is described by the following keywords: [KEYWORDS]'
        ' Based on the information above, extract a short topic label in the'
        ' following format: topic: '
    ),
    'few_shot_examples': (
        ' I have a topic that is described by the following keywords: [KEYWORDS]'
        ' Example 1: Keywords: apple,fruit,healthy,snack,red,orchard'
        ' Topic label: Healthy Fruit Snacks'
        ' Example 2: Keywords: computer,technology,silicon,programming,internet,hardware'
        ' Topic label: Computer Technology'
        ' Example 3: Keywords: democracy,government,elections,vote,political,representation'
        ' Topic label: Democratic Governance'
        ' Based on the information above, extract a short topic label in the'
        ' following format: topic: '
    ),
    # 'custom_prompt': ''  (re-enabling this key activates the free-text prompt branch below)
}

# Topic sets offered in the "Select a dataset" box: display name -> CSV path.
topicsets = {
    'example_topics': os.path.join(DATASETS_PATH, 'topics.csv'),
}


@st.cache_data(show_spinner=False)
def get_available_models():
    """Return the names of the selectable models as a list.

    ``st.cache_data`` pickles return values, and a ``dict_keys`` view cannot be
    pickled, so the keys are materialized into a list.
    """
    # return [model.modelId for model in list_models(author='textminr')]
    return list(models)


@st.cache_resource(show_spinner='Loading model...')
def load_model(model_name: str):
    """Load the model registered under ``model_name``, cached across reruns.

    Args:
        model_name: A key of ``models``.

    Returns:
        Whatever ``ReplicateModel.load()`` returns; the UI below calls ``.generate(prompt)`` on it.
    """
    # model = AutoGPTQForCausalLM.from_quantized(model_name, device_map='auto')
    # return pipeline('text-generation', model=model, tokenizer=model_name)
    return models[model_name].load()


@st.cache_data(show_spinner=False)
def load_topicset(path: str) -> pd.DataFrame:
    """Read a topic-set CSV and drop the columns that are not shown in the playground.

    Cached so the file is not re-read on every Streamlit rerun; ``st.cache_data``
    hands each caller its own copy, so callers may mutate the result safely.
    """
    dataset = pd.read_csv(path)
    return dataset.drop(columns=['topic_id', 'domain'])


def build_keywords(row: pd.Series) -> str:
    """Join a topic row's keyword cells into one comma-separated string.

    The first remaining cell is skipped (presumably the topic's reference label
    rather than a keyword -- NOTE(review): confirm against datasets/topics.csv).
    Missing cells are dropped and the rest converted to ``str`` so the join cannot
    fail on NaN or numeric values.
    """
    return ','.join(str(value) for value in row.tolist()[1:] if pd.notna(value))


st.set_page_config(page_title='TL playground', page_icon='🚀', layout='wide')
st.title('🚀 Topic Labelling playground')

percentage_width_main = 70
# NOTE(review): this f-string is effectively empty, so percentage_width_main is never
# interpolated; presumably an HTML/CSS snippet setting the main width was lost.
# Confirm against version control.
st.markdown(
    f''' ''',
    unsafe_allow_html=True,
)

col1, col2 = st.columns(2, gap='medium')

# --- Left column: model, dataset and topic selection ---
sel_model_name = col1.selectbox('Select a model', models, index=None, placeholder='Select a model')
if sel_model_name:
    model = load_model(sel_model_name)

sel_row_index = None  # stays None until a dataset with at least one row is selected
sel_dataset_name = col1.selectbox('Select a dataset', topicsets.keys(), index=None)
if sel_dataset_name:
    sel_dataset = load_topicset(topicsets[sel_dataset_name])
    col1.dataframe(sel_dataset)
    sel_row_index = col1.selectbox('Select a topic', sel_dataset.index)

# --- Right column: prompt selection and generation ---
sel_prompt = col2.selectbox('Select a prompt', prompts.keys())
if sel_prompt != 'custom_prompt':
    col2.code(prompts[sel_prompt], language='text')
    sel_prompt_text = prompts[sel_prompt]
else:
    # Rendered in col2, next to its caption, rather than in the page's main area.
    sel_prompt_text = col2.text_area('Custom prompt', height=200)
    col2.caption('Make sure to use "[KEYWORDS]" to indicate where the keywords should be inserted.')

# Generation needs a loaded model and a concrete topic row.
btn_generate = col2.button(
    'Generate',
    disabled=(sel_model_name is None or sel_dataset_name is None or sel_row_index is None),
)
if btn_generate:
    # sel_row_index is an index *label* taken from sel_dataset.index, so use .loc, not .iloc.
    keywords = build_keywords(sel_dataset.loc[sel_row_index])
    placeholder = col2.empty()
    with placeholder, st.spinner('Generating...'):
        prompt = sel_prompt_text.replace('[KEYWORDS]', keywords)
        # result = model(prompt, max_new_tokens=100, return_full_text=False)[0]['generated_text']
        result = model.generate(prompt)

    message = col2.chat_message("ai")
    message.write(result)
    message.caption('Keywords: ' + keywords)