# ftBoost — Streamlit app for generating fine-tuning data augmentations.
import json

import streamlit as st
import streamlit.components.v1 as components
from streamlit_ace import st_ace  # Editable code block

from finetune_augmentor import AugmentationExample, AugmentationConfig, FinetuningDataAugmentor

# -------------------------------
# Page Configuration and CSS
# -------------------------------
st.set_page_config(
    page_title="Finetuning Data Augmentation Generator",
    layout="wide",
    initial_sidebar_state="expanded",
)

# Floating GitHub / Hugging Face badges pinned to the top-right corner.
_BADGE_HTML = """
    <div style="position: fixed; top: 10px; right: 10px; z-index: 100;">
        <a href="https://github.com/zamalali/ftboost" target="_blank">
            <img src="https://github.githubassets.com/images/modules/logos_page/GitHub-Mark.png" alt="GitHub" style="height: 30px; margin-right: 10px;">
        </a>
        <a href="https://huggingface.co/zamal" target="_blank">
            <img src="https://huggingface.co/front/assets/huggingface_logo.svg" alt="Hugging Face" style="height: 30px;">
        </a>
    </div>
    """
components.html(_BADGE_HTML, height=40)
# Global dark-theme overrides for the main area, sidebar, widgets, and the
# embedded Ace editor.
_DARK_THEME_CSS = """
    <style>
    /* Main content area */
    .block-container {
        background-color: #121212;
        color: #ffffff;
    }
    /* Sidebar styling */
    [data-testid="stSidebar"] {
        background-color: #121212;
        color: #ffffff;
    }
    [data-testid="stSidebar"] * {
        color: #ffffff !important;
    }
    /* Button styling */
    .stButton>button, .stDownloadButton>button {
        background-color: #808080 !important;
        color: #ffffff !important;
        font-size: 16px;
        border: none;
        border-radius: 5px;
        padding: 0.5rem 1.5rem;
        margin-top: 1rem;
    }
    /* Text inputs */
    .stTextInput>div>input, .stNumberInput>div>input {
        border-radius: 5px;
        border: 1px solid #ffffff;
        padding: 0.5rem;
        background-color: #1a1a1a;
        color: #ffffff;
    }
    .stTextArea>textarea {
        background-color: #1a1a1a;
        color: #ffffff;
        font-family: "Courier New", monospace;
        border: 1px solid #ffffff;
        border-radius: 5px;
        padding: 1rem;
    }
    /* Header colors */
    h1 { color: #00FF00; }
    h2, h3, h4 { color: #FFFF00; }
    /* Field labels */
    label { color: #ffffff !important; }
    /* Remove extra margin in code blocks */
    pre { margin: 0; }
    /* Ace editor style overrides */
    .ace_editor {
        border: none !important;
        box-shadow: none !important;
        background-color: #121212 !important;
    }
    /* Override alert (error/success) text colors */
    [data-testid="stAlert"] { color: #ffffff !important; }
    /* Add white border to expander header */
    [data-testid="stExpander"] > div:first-child {
        border: 1px solid #ffffff !important;
    }
    </style>
    """
st.markdown(_DARK_THEME_CSS, unsafe_allow_html=True)

# Inject JavaScript to scroll to top on load (height=0 keeps it invisible).
_SCROLL_TO_TOP_JS = """
    <script>
    document.addEventListener("DOMContentLoaded", function() {
        setTimeout(function() { window.scrollTo(0, 0); }, 100);
    });
    </script>
    """
components.html(_SCROLL_TO_TOP_JS, height=0)
# -------------------------------
# App Title and Description
# -------------------------------
# NOTE(review): the "π" glyphs in the title and description look like a
# mojibake'd emoji from an encoding round-trip — confirm the intended
# character before changing them; they are reproduced as-is here.
st.title("ftBoost π")

_APP_DESCRIPTION = """
    **ftBoost Hero** is a powerful tool designed to help you generate high-quality fine-tuning data for AI models.
    Whether you're working with OpenAI, Gemini, Mistral, or LLaMA models, this app allows you to create structured
    input-output pairs and apply augmentation techniques to enhance dataset quality. With advanced tuning parameters,
    semantic similarity controls, and fluency optimization, **ftBoost Hero** ensures that your fine-tuning data is diverse,
    well-structured, and ready for training. π
    """
st.markdown(_APP_DESCRIPTION, unsafe_allow_html=True)
# -------------------------------
# Step A: File Upload & Auto-Detection
# -------------------------------
st.markdown("##### Step 1: Upload Your Finetuning data JSONL File if you have one already (Optional)")
uploaded_file = st.file_uploader("Upload your train.jsonl file", type=["jsonl", "txt"])

# Examples parsed out of the uploaded JSONL, and the model family inferred
# from its record shape (None when nothing was recognized).
uploaded_examples = []
detected_model = None

if uploaded_file is not None:
    try:
        file_content = uploaded_file.getvalue().decode("utf-8")

        # Parse every non-empty line exactly once (the original parsed the
        # whole file twice — once for detection, once for extraction). A
        # malformed line raises here and is reported by the except below,
        # preserving the original all-or-nothing behavior.
        records = [json.loads(line) for line in file_content.splitlines() if line.strip()]

        # Auto-detect the model family from the first record only.
        # FIX: the original loop's break was reachable only on the Gemini
        # branch, so for "messages"-style files detection kept scanning and
        # re-assigning detected_model for every subsequent line.
        if records:
            first = records[0]
            if "messages" in first:
                msgs = first["messages"]
                if len(msgs) >= 3 and msgs[0].get("role") == "system":
                    detected_model = "OpenAI Models"
                elif len(msgs) == 2:
                    detected_model = "Mistral Models"
            elif "contents" in first:
                detected_model = "Gemini Models"

        # Display an info message based on detection result.
        if detected_model is not None:
            st.info(f"This JSONL file format supports the **{detected_model}**.")
        else:
            st.info("The uploaded JSONL file format is not recognized. Please manually select the appropriate model.")

        # Collect valid input/output pairs from all records.
        for record in records:
            input_text, output_text = "", ""
            if "messages" in record:
                msgs = record["messages"]
                if len(msgs) >= 3:
                    # system / user / assistant layout: skip the system turn.
                    input_text = msgs[1].get("content", "").strip()
                    output_text = msgs[2].get("content", "").strip()
                elif len(msgs) == 2:
                    input_text = msgs[0].get("content", "").strip()
                    output_text = msgs[1].get("content", "").strip()
            elif "contents" in record:
                contents = record["contents"]
                # FIX: require the "parts" lists to be non-empty before
                # indexing [0]; the original only checked the key existed and
                # could raise IndexError on an empty list.
                if len(contents) >= 2 and contents[0].get("parts") and contents[1].get("parts"):
                    input_text = contents[0]["parts"][0].get("text", "").strip()
                    output_text = contents[1]["parts"][0].get("text", "").strip()
            if input_text and output_text:
                uploaded_examples.append(AugmentationExample(input_text=input_text, output_text=output_text))

        if len(uploaded_examples) < 3:
            st.error("Uploaded file does not contain at least 3 valid input/output pairs.")
        else:
            st.success(f"Uploaded file processed: {len(uploaded_examples)} valid input/output pairs loaded.")
    except Exception as e:
        st.error(f"Error processing uploaded file: {e}")
# -------------------------------
# Step B: Model Selection
# -------------------------------
model_options = ["OpenAI Models", "Gemini Models", "Mistral Models", "Llama Models"]
# Pre-select the auto-detected family when available; fall back to OpenAI.
default_model = detected_model or "OpenAI Models"
default_index = model_options.index(default_model) if default_model in model_options else 0
model_type = st.selectbox(
    "Select the output format for finetuning",
    model_options,
    index=default_index,
)

# -------------------------------
# Step C: System Message & API Key
# -------------------------------
system_message = st.text_input(
    "System Message (optional) only for OpenAI models",
    value="Marv is a factual chatbot that is also sarcastic.",
)
groq_api_key = st.text_input(
    "LangChain Groq API Key (if you don't have one, get it from [here](https://console.groq.com/keys))",
    type="password",
    help="Enter your LangChain Groq API Key for data augmentation",
)
# -------------------------------
# Step D: Input Schema Template Display
# -------------------------------
st.markdown("#### Input Schema Template")

# Two-message layout used by both Mistral and Llama (and any other family
# without a dedicated template).
_DEFAULT_SCHEMA = '''{
  "messages": [
    {"role": "user", "content": "What's the capital of France?"},
    {"role": "assistant", "content": "Paris, as if everyone doesn't know that already."}
  ]
}'''

# Model-family-specific schema examples rendered as read-only JSON.
_SCHEMA_TEMPLATES = {
    "OpenAI Models": '''{
  "messages": [
    {"role": "system", "content": "Marv is a factual chatbot that is also sarcastic."},
    {"role": "user", "content": "What's the capital of France?"},
    {"role": "assistant", "content": "Paris, as if everyone doesn't know that already."}
  ]
}''',
    "Gemini Models": '''{
  "contents": [
    {"role": "user", "parts": [{"text": "What's the capital of France?"}]},
    {"role": "model", "parts": [{"text": "Paris, as if everyone doesn't know that already."}]}
  ]
}''',
}

st.code(_SCHEMA_TEMPLATES.get(model_type, _DEFAULT_SCHEMA), language="json")
# -------------------------------
# Step E: Manual Input of Pairs (if no file uploaded)
# -------------------------------
# FIX: define pair_templates unconditionally. The generation step below
# iterates it even when a file was uploaded but yielded fewer than 3 valid
# pairs, which previously raised NameError on that path.
pair_templates = []
if uploaded_file is None:
    st.markdown("##### Enter at least 3 input/output pairs manually:")
    num_pairs = st.number_input("Number of Pairs", min_value=3, value=3, step=1)
    for i in range(num_pairs):
        st.markdown(f"##### Pair {i+1}")
        if model_type == "OpenAI Models":
            init_template = ('''{
  "messages": [
    {"role": "system", "content": "''' + system_message + '''"},
    {"role": "user", "content": "Enter your input text here"},
    {"role": "assistant", "content": "Enter your output text here"}
  ]
}''').strip()
            # Embed the system message in the widget key so editing it above
            # re-seeds the editor with the new template.
            ace_key = f"pair_{i}_{model_type}_{system_message}"
        elif model_type == "Gemini Models":
            init_template = ('''{
  "contents": [
    {"role": "user", "parts": [{"text": "Enter your input text here"}]},
    {"role": "model", "parts": [{"text": "Enter your output text here"}]}
  ]
}''').strip()
            ace_key = f"pair_{i}_{model_type}"
        else:
            init_template = ('''{
  "messages": [
    {"role": "user", "content": "Enter your input text here"},
    {"role": "assistant", "content": "Enter your output text here"}
  ]
}''').strip()
            ace_key = f"pair_{i}_{model_type}"
        # Editable JSON block for this pair; the raw editor text is collected
        # and parsed when the user clicks "Generate Data".
        pair = st_ace(
            placeholder="Edit your code here...",
            value=init_template,
            language="json",
            theme="monokai",
            key=ace_key,
            height=150,
        )
        pair_templates.append(pair)
# -------------------------------
# Step F: Augmentation Settings
# -------------------------------
target_augmented = st.number_input("Number of Augmented Pairs to Generate", min_value=5, value=5, step=1)
finetuning_goal = "Improve conversational clarity and capture subtle nuances"
st.markdown(f"**Finetuning Goal:** {finetuning_goal}")

with st.expander("Show/Hide Advanced Tuning Parameters"):
    def _quality_slider(label, default):
        # Every quality threshold shares the same [0, 1] range and 0.01 step.
        return st.slider(label, 0.0, 1.0, default, 0.01)

    min_semantic = _quality_slider("Minimum Semantic Similarity", 0.80)
    max_semantic = _quality_slider("Maximum Semantic Similarity", 0.95)
    min_diversity = _quality_slider("Minimum Diversity Score", 0.70)
    min_fluency = _quality_slider("Minimum Fluency Score", 0.80)
# -------------------------------
# Step G: Generate Data Button and Pipeline Execution
# -------------------------------
if st.button("Generate Data"):
    if not groq_api_key.strip():
        st.error("Please enter your LangChain Groq API Key to proceed.")
        st.stop()

    # Choose examples: from uploaded file if available; otherwise from manual input.
    if uploaded_file is not None:
        # FIX: the original fell through to iterate pair_templates when the
        # upload yielded < 3 valid pairs; pair_templates is undefined on the
        # upload path, so that raised NameError. Stop with a clear message.
        if len(uploaded_examples) < 3:
            st.error("Please provide at least 3 valid pairs.")
            st.stop()
        examples = uploaded_examples
    else:
        examples = []
        errors = []
        for idx, pair in enumerate(pair_templates):
            try:
                record = json.loads(pair)
                if model_type == "OpenAI Models":
                    msgs = record.get("messages", [])
                    if len(msgs) != 3:
                        raise ValueError("Expected 3 messages")
                    input_text = msgs[1].get("content", "").strip()
                    output_text = msgs[2].get("content", "").strip()
                elif model_type == "Gemini Models":
                    contents = record.get("contents", [])
                    if len(contents) < 2:
                        raise ValueError("Expected at least 2 contents")
                    input_text = contents[0]["parts"][0].get("text", "").strip()
                    output_text = contents[1]["parts"][0].get("text", "").strip()
                else:
                    # Mistral / Llama: plain two-message layout.
                    msgs = record.get("messages", [])
                    if len(msgs) != 2:
                        raise ValueError("Expected 2 messages for this format")
                    input_text = msgs[0].get("content", "").strip()
                    output_text = msgs[1].get("content", "").strip()
                if not input_text or not output_text:
                    raise ValueError("Input or output text is empty")
                examples.append(AugmentationExample(input_text=input_text, output_text=output_text))
            except Exception as e:
                errors.append(f"Error in pair {idx+1}: {e}")
        if errors:
            # FIX: the original reported errors but still ran the pipeline
            # whenever >= 3 pairs happened to parse; stop so the user can
            # correct the flagged pairs first.
            st.error("There were errors in your input pairs:\n" + "\n".join(errors))
            st.stop()
        if len(examples) < 3:
            st.error("Please provide at least 3 valid pairs.")
            st.stop()

    # Groq-hosted model that drives the augmentation pipeline.
    # NOTE(review): hard-coded model id — verify it is still served by Groq.
    target_model = "mixtral-8x7b-32768"
    try:
        config = AugmentationConfig(
            target_model=target_model,
            examples=examples,
            finetuning_goal=finetuning_goal,
            groq_api_key=groq_api_key,
            system_message=system_message,
            min_semantic_similarity=min_semantic,
            max_semantic_similarity=max_semantic,
            min_diversity_score=min_diversity,
            min_fluency_score=min_fluency,
        )
    except Exception as e:
        st.error(f"Configuration error: {e}")
        st.stop()

    st.markdown('<p style="color: white;">Running augmentation pipeline... Please wait.</p>', unsafe_allow_html=True)
    augmentor = FinetuningDataAugmentor(config)
    augmentor.run_augmentation(target_count=target_augmented)

    # Map the UI selection onto the formatter identifier; default to OpenAI
    # for any unrecognized entry (matches the original else branch).
    _FORMAT_BY_MODEL = {
        "openai models": "openai",
        "gemini models": "gemini",
        "mistral models": "mistral",
        "llama models": "llama",
    }
    output_data = augmentor.get_formatted_output(
        format_type=_FORMAT_BY_MODEL.get(model_type.lower(), "openai")
    )

    st.markdown("### Augmented Data")
    st.code(output_data, language="json")
    st.download_button("Download train.jsonl", output_data, file_name="train.jsonl")