zamalali committed
Commit 9a94c10 · 0 Parent(s)

Fresh start without binary files

.gitattributes ADDED
@@ -0,0 +1,2 @@
+ *.mp4 filter=lfs diff=lfs merge=lfs -text
+ *.pdf filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,381 @@
+ import streamlit as st
+ from finetune_augmentor import AugmentationExample, AugmentationConfig, FinetuningDataAugmentor
+ import json
+ import streamlit.components.v1 as components
+ from streamlit_ace import st_ace  # Editable code block
+
+ # -------------------------------
+ # Page Configuration and CSS
+ # -------------------------------
+
+ st.set_page_config(
+     page_title="Finetuning Data Augmentation Generator",
+     layout="wide",
+     initial_sidebar_state="expanded",
+ )
+
+
+ components.html(
+     """
+
+     <div style="position: fixed; top: 10px; right: 10px; z-index: 100;">
+         <a href="https://github.com/zamalali/ftboost" target="_blank">
+             <img src="https://github.githubassets.com/images/modules/logos_page/GitHub-Mark.png" alt="GitHub" style="height: 30px; margin-right: 10px;">
+         </a>
+         <a href="https://huggingface.co/zamal" target="_blank">
+             <img src="https://huggingface.co/front/assets/huggingface_logo.svg" alt="Hugging Face" style="height: 30px;">
+         </a>
+     </div>
+     """,
+     height=40
+ )
+
+
+ st.markdown(
+     """
+     <style>
+     /* Main content area */
+     .block-container {
+         background-color: #121212;
+         color: #ffffff;
+     }
+     /* Sidebar styling */
+     [data-testid="stSidebar"] {
+         background-color: #121212;
+         color: #ffffff;
+     }
+     [data-testid="stSidebar"] * {
+         color: #ffffff !important;
+     }
+     /* Button styling */
+     .stButton>button, .stDownloadButton>button {
+         background-color: #808080 !important;
+         color: #ffffff !important;
+         font-size: 16px;
+         border: none;
+         border-radius: 5px;
+         padding: 0.5rem 1.5rem;
+         margin-top: 1rem;
+     }
+     /* Text inputs */
+     .stTextInput>div>input, .stNumberInput>div>input {
+         border-radius: 5px;
+         border: 1px solid #ffffff;
+         padding: 0.5rem;
+         background-color: #1a1a1a;
+         color: #ffffff;
+     }
+     .stTextArea>textarea {
+         background-color: #1a1a1a;
+         color: #ffffff;
+         font-family: "Courier New", monospace;
+         border: 1px solid #ffffff;
+         border-radius: 5px;
+         padding: 1rem;
+     }
+     /* Header colors */
+     h1 { color: #00FF00; }
+     h2, h3, h4 { color: #FFFF00; }
+     /* Field labels */
+     label { color: #ffffff !important; }
+     /* Remove extra margin in code blocks */
+     pre { margin: 0; }
+     /* Ace editor style overrides */
+     .ace_editor {
+         border: none !important;
+         box-shadow: none !important;
+         background-color: #121212 !important;
+     }
+     /* Override alert (error/success) text colors */
+     [data-testid="stAlert"] { color: #ffffff !important; }
+     /* Add white border to expander header */
+     [data-testid="stExpander"] > div:first-child {
+         border: 1px solid #ffffff !important;
+     }
+     </style>
+     """,
+     unsafe_allow_html=True,
+ )
+
+ # Inject JavaScript to scroll to top on load
+ components.html(
+     """
+     <script>
+     document.addEventListener("DOMContentLoaded", function() {
+         setTimeout(function() { window.scrollTo(0, 0); }, 100);
+     });
+     </script>
+     """,
+     height=0,
+ )
+
+ # -------------------------------
+ # App Title and Description
+ # -------------------------------
+ st.title("ftBoost 🚀")
+ st.markdown(
+     """
+     **ftBoost** is a tool for generating high-quality fine-tuning data for AI models.
+     Whether you're working with OpenAI, Gemini, Mistral, or LLaMA models, the app lets you create structured
+     input/output pairs and apply augmentation techniques to enhance dataset quality. With advanced tuning parameters,
+     semantic similarity controls, and fluency optimization, **ftBoost** helps keep your fine-tuning data diverse,
+     well-structured, and ready for training. 🚀
+     """,
+     unsafe_allow_html=True,
+ )
+
+ # -------------------------------
+ # Step A: File Upload & Auto-Detection
+ # -------------------------------
+ st.markdown("##### Step 1: Upload your finetuning JSONL file if you already have one (optional)")
+ uploaded_file = st.file_uploader("Upload your train.jsonl file", type=["jsonl", "txt"])
+ uploaded_examples = []
+ detected_model = None
+
+ if uploaded_file is not None:
+     try:
+         file_content = uploaded_file.getvalue().decode("utf-8")
+         # Auto-detect model type from the first valid snippet
+         for line in file_content.splitlines():
+             if line.strip():
+                 record = json.loads(line)
+                 if "messages" in record:
+                     msgs = record["messages"]
+                     if len(msgs) >= 3 and msgs[0].get("role") == "system":
+                         detected_model = "OpenAI Models"
+                     elif len(msgs) == 2:
+                         detected_model = "Mistral Models"
+                 elif "contents" in record:
+                     detected_model = "Gemini Models"
+                 break
+
+         # Display an info message based on detection result
+         if detected_model is not None:
+             st.info(f"This JSONL file format matches **{detected_model}**.")
+         else:
+             st.info("The uploaded JSONL file format is not recognized. Please select the appropriate model manually.")
+
+         # Process the entire file for valid examples
+         for line in file_content.splitlines():
+             if not line.strip():
+                 continue
+             record = json.loads(line)
+             input_text, output_text = "", ""
+             if "messages" in record:
+                 msgs = record["messages"]
+                 if len(msgs) >= 3:
+                     input_text = msgs[1].get("content", "").strip()
+                     output_text = msgs[2].get("content", "").strip()
+                 elif len(msgs) == 2:
+                     input_text = msgs[0].get("content", "").strip()
+                     output_text = msgs[1].get("content", "").strip()
+             elif "contents" in record:
+                 contents = record["contents"]
+                 if len(contents) >= 2 and "parts" in contents[0] and "parts" in contents[1]:
+                     input_text = contents[0]["parts"][0].get("text", "").strip()
+                     output_text = contents[1]["parts"][0].get("text", "").strip()
+             if input_text and output_text:
+                 uploaded_examples.append(AugmentationExample(input_text=input_text, output_text=output_text))
+         if len(uploaded_examples) < 3:
+             st.error("Uploaded file does not contain at least 3 valid input/output pairs.")
+         else:
+             st.success(f"Uploaded file processed: {len(uploaded_examples)} valid input/output pairs loaded.")
+     except Exception as e:
+         st.error(f"Error processing uploaded file: {e}")
+
+ # -------------------------------
+ # Step B: Model Selection
+ # -------------------------------
+ default_model = detected_model if detected_model is not None else "OpenAI Models"
+ model_options = ["OpenAI Models", "Gemini Models", "Mistral Models", "Llama Models"]
+ default_index = model_options.index(default_model) if default_model in model_options else 0
+ model_type = st.selectbox(
+     "Select the output format for finetuning",
+     model_options,
+     index=default_index
+ )
+
+ # -------------------------------
+ # Step C: System Message & API Key
+ # -------------------------------
+ system_message = st.text_input("System message (optional, used by OpenAI models only)", value="Marv is a factual chatbot that is also sarcastic.")
+
+ # Keep pair_templates defined even when a file is uploaded, so the generate step below never sees it unbound.
+ pair_templates = []
+
+ groq_api_key = st.text_input(
+     "LangChain Groq API Key (if you don't have one, get it from [here](https://console.groq.com/keys))",
+     type="password",
+     help="Enter your LangChain Groq API Key for data augmentation"
+ )
+ # -------------------------------
+ # Step D: Input Schema Template Display
+ # -------------------------------
+ st.markdown("#### Input Schema Template")
+ if model_type == "OpenAI Models":
+     st.code(
+         '''{
+   "messages": [
+     {"role": "system", "content": "Marv is a factual chatbot that is also sarcastic."},
+     {"role": "user", "content": "What's the capital of France?"},
+     {"role": "assistant", "content": "Paris, as if everyone doesn't know that already."}
+   ]
+ }''', language="json")
+ elif model_type == "Gemini Models":
+     st.code(
+         '''{
+   "contents": [
+     {"role": "user", "parts": [{"text": "What's the capital of France?"}]},
+     {"role": "model", "parts": [{"text": "Paris, as if everyone doesn't know that already."}]}
+   ]
+ }''', language="json")
+ else:
+     st.code(
+         '''{
+   "messages": [
+     {"role": "user", "content": "What's the capital of France?"},
+     {"role": "assistant", "content": "Paris, as if everyone doesn't know that already."}
+   ]
+ }''', language="json")
+
+ # -------------------------------
+ # Step E: Manual Input of Pairs (if no file uploaded)
+ # -------------------------------
+ if uploaded_file is None:
+     st.markdown("##### Enter at least 3 input/output pairs manually:")
+     num_pairs = st.number_input("Number of Pairs", min_value=3, value=3, step=1)
+     pair_templates = []
+     for i in range(num_pairs):
+         st.markdown(f"##### Pair {i+1}")
+         if model_type == "OpenAI Models":
+             init_template = ('''{
+   "messages": [
+     {"role": "system", "content": "''' + system_message + '''"},
+     {"role": "user", "content": "Enter your input text here"},
+     {"role": "assistant", "content": "Enter your output text here"}
+   ]
+ }''').strip()
+             ace_key = f"pair_{i}_{model_type}_{system_message}"
+         elif model_type == "Gemini Models":
+             init_template = ('''{
+   "contents": [
+     {"role": "user", "parts": [{"text": "Enter your input text here"}]},
+     {"role": "model", "parts": [{"text": "Enter your output text here"}]}
+   ]
+ }''').strip()
+             ace_key = f"pair_{i}_{model_type}"
+         else:
+             init_template = ('''{
+   "messages": [
+     {"role": "user", "content": "Enter your input text here"},
+     {"role": "assistant", "content": "Enter your output text here"}
+   ]
+ }''').strip()
+             ace_key = f"pair_{i}_{model_type}"
+
+         pair = st_ace(
+             placeholder="Edit your code here...",
+             value=init_template,
+             language="json",
+             theme="monokai",
+             key=ace_key,
+             height=150
+         )
+         pair_templates.append(pair)
+
+ # -------------------------------
+ # Step F: Augmentation Settings
+ # -------------------------------
+ target_augmented = st.number_input("Number of Augmented Pairs to Generate", min_value=5, value=5, step=1)
+ finetuning_goal = "Improve conversational clarity and capture subtle nuances"
+ st.markdown(f"**Finetuning Goal:** {finetuning_goal}")
+
+ with st.expander("Show/Hide Advanced Tuning Parameters"):
+     min_semantic = st.slider("Minimum Semantic Similarity", 0.0, 1.0, 0.80, 0.01)
+     max_semantic = st.slider("Maximum Semantic Similarity", 0.0, 1.0, 0.95, 0.01)
+     min_diversity = st.slider("Minimum Diversity Score", 0.0, 1.0, 0.70, 0.01)
+     min_fluency = st.slider("Minimum Fluency Score", 0.0, 1.0, 0.80, 0.01)
+
+ # -------------------------------
+ # Step G: Generate Data Button and Pipeline Execution
+ # -------------------------------
+ if st.button("Generate Data"):
+     if not groq_api_key.strip():
+         st.error("Please enter your LangChain Groq API Key to proceed.")
+         st.stop()
+
+     # Choose examples: from uploaded file if available; otherwise from manual input.
+     if uploaded_file is not None and len(uploaded_examples) >= 3:
+         examples = uploaded_examples
+     else:
+         examples = []
+         errors = []
+         for idx, pair in enumerate(pair_templates):
+             try:
+                 record = json.loads(pair)
+                 if model_type == "OpenAI Models":
+                     msgs = record.get("messages", [])
+                     if len(msgs) != 3:
+                         raise ValueError("Expected 3 messages")
+                     input_text = msgs[1].get("content", "").strip()
+                     output_text = msgs[2].get("content", "").strip()
+                 elif model_type == "Gemini Models":
+                     contents = record.get("contents", [])
+                     if len(contents) < 2:
+                         raise ValueError("Expected at least 2 contents")
+                     input_text = contents[0]["parts"][0].get("text", "").strip()
+                     output_text = contents[1]["parts"][0].get("text", "").strip()
+                 else:
+                     msgs = record.get("messages", [])
+                     if len(msgs) != 2:
+                         raise ValueError("Expected 2 messages for this format")
+                     input_text = msgs[0].get("content", "").strip()
+                     output_text = msgs[1].get("content", "").strip()
+                 if not input_text or not output_text:
+                     raise ValueError("Input or output text is empty")
+                 examples.append(AugmentationExample(input_text=input_text, output_text=output_text))
+             except Exception as e:
+                 errors.append(f"Error in pair {idx+1}: {e}")
+         if errors:
+             st.error("There were errors in your input pairs:\n" + "\n".join(errors))
+         elif len(examples) < 3:
+             st.error("Please provide at least 3 valid pairs.")
+
+     if len(examples) >= 3:
+         target_model = "mixtral-8x7b-32768"
+         try:
+             config = AugmentationConfig(
+                 target_model=target_model,
+                 examples=examples,
+                 finetuning_goal=finetuning_goal,
+                 groq_api_key=groq_api_key,
+                 system_message=system_message,
+                 min_semantic_similarity=min_semantic,
+                 max_semantic_similarity=max_semantic,
+                 min_diversity_score=min_diversity,
+                 min_fluency_score=min_fluency
+             )
+         except Exception as e:
+             st.error(f"Configuration error: {e}")
+             st.stop()
+
+         st.markdown('<p style="color: white;">Running augmentation pipeline... Please wait.</p>', unsafe_allow_html=True)
+
+         augmentor = FinetuningDataAugmentor(config)
+         augmentor.run_augmentation(target_count=target_augmented)
+
+         fmt = model_type.lower()
+         if fmt == "openai models":
+             output_data = augmentor.get_formatted_output(format_type="openai")
+         elif fmt == "gemini models":
+             output_data = augmentor.get_formatted_output(format_type="gemini")
+         elif fmt == "mistral models":
+             output_data = augmentor.get_formatted_output(format_type="mistral")
+         elif fmt == "llama models":
+             output_data = augmentor.get_formatted_output(format_type="llama")
+         else:
+             output_data = augmentor.get_formatted_output(format_type="openai")
+
+         st.markdown("### Augmented Data")
+         st.code(output_data, language="json")
+         st.download_button("Download train.jsonl", output_data, file_name="train.jsonl")
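
For reference, each line of the downloaded train.jsonl is a standalone JSON record. With the OpenAI format selected, a single generated line would look like this (illustrative values, mirroring the schema template the app displays):

{"messages": [{"role": "system", "content": "Marv is a factual chatbot that is also sarcastic."}, {"role": "user", "content": "Which city is the capital of France?"}, {"role": "assistant", "content": "Paris. Shocking, I know."}]}
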
finetune_augmentor/__init__.py ADDED
@@ -0,0 +1 @@
+ from .augmentor import AugmentationExample, AugmentationConfig, FinetuningDataAugmentor, load_examples_from_file
finetune_augmentor/augmentor.py ADDED
@@ -0,0 +1,583 @@
+ """
2
+ augmentor.py
3
+
4
+ This module implements a robust and scalable pipeline for finetuning data augmentation.
5
+ It supports generating augmented data in either OpenAI, Gemini, Mistral, or LLama fine‐tuning JSONL format.
6
+ Users may optionally override metric thresholds and load existing examples from a JSONL file.
7
+ The LangChain Groq API key is now provided via the configuration rather than the .env file.
8
+ """
9
+
10
+ import os
11
+ import json
12
+ import uuid
13
+ import logging
14
+ import re
15
+ import random
16
+ import ast
17
+ from typing import List, Dict, Any, Optional
18
+
19
+ # Removed dotenv load for GROQ_API_KEY since it is now provided in config
20
+ # Configure logging
21
+ logging.basicConfig(level=logging.INFO)
22
+ logger = logging.getLogger("FinetuningAugmentor")
23
+
24
+ # Environment tokens (kept for HF_TOKEN if needed)
25
+ HF_TOKEN = os.getenv("HF_TOKEN")
26
+ # GROQ_API_KEY will now be provided in the configuration
27
+
28
+ # -----------------------------
29
+ # Data Models and Preprocessing
30
+ # -----------------------------
31
+ from pydantic import BaseModel, field_validator, ValidationError
32
+
33
+ class AugmentationExample(BaseModel):
34
+ """
35
+ An input/output example for augmentation.
36
+ """
37
+ input_text: str
38
+ output_text: str
39
+
40
+ @field_validator('input_text', 'output_text')
41
+ def non_empty(cls, v: str) -> str:
42
+ if not v.strip():
43
+ raise ValueError("Text fields must be non-empty")
44
+ return v.strip()
45
+
46
+ class AugmentationConfig(BaseModel):
47
+ """
48
+ Configuration for the augmentation process.
49
+ """
50
+ target_model: str # e.g., "mixtral-8x7b-32768" or any Groq-supported model name
51
+ examples: List[AugmentationExample]
52
+ finetuning_goal: str
53
+ groq_api_key: str
54
+ system_message: Optional[str] = "Marv is a factual chatbot that is also sarcastic."
55
+ # Optional metric thresholds (if not provided, defaults are used)
56
+ min_semantic_similarity: Optional[float] = 0.80
57
+ max_semantic_similarity: Optional[float] = 0.95
58
+ min_diversity_score: Optional[float] = 0.70
59
+ min_fluency_score: Optional[float] = 0.80
60
+
61
+ @field_validator('examples')
62
+ def check_examples_length(cls, v: List[AugmentationExample]) -> List[AugmentationExample]:
63
+ if len(v) < 3:
64
+ raise ValueError("Provide at least 3 examples")
65
+ return v
66
+
67
+ class StandardExample(BaseModel):
68
+ """
69
+ Standardized format for input examples.
70
+ """
71
+ id: str
72
+ input_text: str
73
+ output_text: str
74
+ metadata: Dict[str, Any] = {}
75
+
76
+ def normalize_examples(examples: List[AugmentationExample]) -> List[StandardExample]:
77
+ """
78
+ Normalize and standardize input examples.
79
+ """
80
+ normalized = []
81
+ for ex in examples:
82
+ norm_ex = StandardExample(
83
+ id=str(uuid.uuid4()),
84
+ input_text=ex.input_text.lower(),
85
+ output_text=ex.output_text.lower(),
86
+ metadata={"original_word_count": len(ex.input_text.split())}
87
+ )
88
+ normalized.append(norm_ex)
89
+ logger.info(f"Normalized {len(normalized)} examples.")
90
+ return normalized
91
+
92
+ # -----------------------------
93
+ # Dynamic Strategy Selection
94
+ # -----------------------------
95
+ def determine_augmentation_strategy(config: AugmentationConfig) -> Dict[str, Any]:
96
+ """
97
+ Determine the augmentation strategy based on the finetuning goal.
98
+ """
99
+ goal = config.finetuning_goal.lower()
100
+ strategy = {}
101
+ if any(word in goal for word in ["dialogue", "q&a", "conversation", "chat"]):
102
+ strategy["methods"] = ["llm_paraphrasing", "back_translation"]
103
+ else:
104
+ strategy["methods"] = ["eda_synonym_replacement", "llm_paraphrasing", "synthetic_noise"]
105
+ strategy["diversity_threshold"] = 0.7
106
+ logger.info(f"Determined augmentation strategy: {strategy}")
107
+ return strategy
108
+
109
+ # -----------------------------
110
+ # Helper Functions
111
+ # -----------------------------
112
+ def extract_json(text: str) -> dict:
113
+ """
114
+ Extract the first valid JSON object from a given text.
115
+ """
116
+ json_pattern = re.compile(r'\{.*\}', re.DOTALL)
117
+ match = json_pattern.search(text)
118
+ if match:
119
+ json_str = match.group()
120
+ try:
121
+ return json.loads(json_str)
122
+ except json.JSONDecodeError as e:
123
+ raise ValueError(f"JSON decoding error: {e}")
124
+ else:
125
+ raise ValueError("No valid JSON found in text.")
126
+
127
+ def make_hashable(item: Any) -> Any:
128
+ """
129
+ Recursively convert unhashable types (lists/dicts) into hashable tuples.
130
+ """
131
+ if isinstance(item, (list, tuple)):
132
+ return tuple(make_hashable(i) for i in item)
133
+ elif isinstance(item, dict):
134
+ return tuple(sorted((k, make_hashable(v)) for k, v in item.items()))
135
+ else:
136
+ return item
137
+
138
+ def validate_jsonl_record(record: dict) -> bool:
139
+ """
140
+ Validates that the record follows the required OpenAI format:
141
+ {"messages": [{"role": "system", "content": <str>},
142
+ {"role": "user", "content": <non-empty str>},
143
+ {"role": "assistant", "content": <non-empty str>}]}
144
+ """
145
+ if "messages" not in record:
146
+ logger.error("Record missing 'messages' key.")
147
+ return False
148
+ messages = record["messages"]
149
+ if not isinstance(messages, list) or len(messages) != 3:
150
+ logger.error("Record 'messages' must be a list of 3 items.")
151
+ return False
152
+ expected_roles = ["system", "user", "assistant"]
153
+ for msg, role in zip(messages, expected_roles):
154
+ if not isinstance(msg, dict):
155
+ logger.error("Each message must be a dictionary.")
156
+ return False
157
+ if msg.get("role") != role:
158
+ logger.error(f"Expected role '{role}', but got '{msg.get('role')}'.")
159
+ return False
160
+ if "content" not in msg or not isinstance(msg["content"], str):
161
+ logger.error("Each message must have a string 'content' field.")
162
+ return False
163
+ if role in ["user", "assistant"] and not msg["content"].strip():
164
+ logger.error(f"Message for role '{role}' has empty content.")
165
+ return False
166
+ return True
167
+
168
+ def get_text(value: Any) -> str:
169
+ """
170
+ Ensure the value is returned as a string.
171
+ If it is a list, recursively return the first element.
172
+ If it is a dict and contains a "text" key, return that.
173
+ If it is a string that resembles a dict, try to parse it.
174
+ """
175
+ if isinstance(value, list):
176
+ if value:
177
+ return get_text(value[0])
178
+ return ""
179
+ elif isinstance(value, dict):
180
+ if "text" in value:
181
+ return str(value["text"])
182
+ return str(value)
183
+ elif isinstance(value, str):
184
+ val = value.strip()
185
+ if val.startswith("{") and val.endswith("}"):
186
+ try:
187
+ parsed = ast.literal_eval(val)
188
+ if isinstance(parsed, dict) and "text" in parsed:
189
+ return str(parsed["text"])
190
+ except Exception:
191
+ pass
192
+ return val
193
+ else:
194
+ return str(value)
195
+
196
+ # --- New helper: Fix content formatting ---
197
+ def fix_content(content: str) -> str:
198
+ """
199
+ If the content appears to be a Python dict (using single quotes), try to
200
+ convert it to valid JSON (with double quotes). If parsing fails, return the original content.
201
+ """
202
+ content = content.strip()
203
+ if content.startswith("{") and content.endswith("}") and "'" in content:
204
+ try:
205
+ parsed = ast.literal_eval(content)
206
+ return json.dumps(parsed)
207
+ except Exception as e:
208
+ logger.debug(f"Failed to fix content formatting: {e}")
209
+ return content
210
+
211
+ def flatten_content(content: str) -> str:
212
+ """
213
+ If content (after fixing) is a JSON string representing a dictionary,
214
+ flatten it by joining its values into a single plain-text string.
215
+ """
216
+ try:
217
+ parsed = json.loads(content)
218
+ if isinstance(parsed, dict):
219
+ # Join values in sorted order by key
220
+ values = [str(parsed[k]).strip() for k in sorted(parsed)]
221
+ return " ".join(values)
222
+ except Exception:
223
+ pass
224
+ return content
225
+
226
+ # -----------------------------
227
+ # Augmentation Generation via LangChain Groq
228
+ # -----------------------------
229
+ from langchain_groq import ChatGroq
230
+ from langchain_core.prompts import ChatPromptTemplate
231
+
232
+ def instantiate_groq_llm(model: str, groq_api_key: str) -> ChatGroq:
233
+ """
234
+ Instantiate a ChatGroq LLM with the given model name and API key.
235
+ """
236
+ return ChatGroq(
237
+ model=model,
238
+ temperature=0.7,
239
+ max_tokens=256,
240
+ timeout=30,
241
+ max_retries=2,
242
+ groq_api_key=groq_api_key
243
+ )
244
+
245
+ def generate_initial_augmentation(example: StandardExample,
246
+ config: AugmentationConfig,
247
+ strategy: Dict[str, Any]) -> dict:
248
+ """
249
+ Generate an initial candidate augmentation using an LLM prompt chain.
250
+ """
251
+ prompt_template = ChatPromptTemplate.from_messages([
252
+ (
253
+ "system",
254
+ ("You are a creative augmentation assistant that produces diverse yet semantically consistent "
255
+ "input/output pairs for finetuning tasks.")
256
+ ),
257
+ (
258
+ "human",
259
+ (
260
+ "Augment the following example using the methods: {methods}. The finetuning goal is: {finetuning_goal}.\n"
261
+ "Ensure your output is in valid JSON format with keys 'augmented_input' and 'augmented_output'.\n"
262
+ "Input: {input_text}\n"
263
+ "Output: {output_text}\n"
264
+ "Return only the JSON response."
265
+ )
266
+ )
267
+ ])
268
+ prompt_vars = {
269
+ "methods": ", ".join(strategy["methods"]),
270
+ "finetuning_goal": config.finetuning_goal,
271
+ "input_text": example.input_text,
272
+ "output_text": example.output_text
273
+ }
274
+ chain = prompt_template | instantiate_groq_llm(config.target_model, config.groq_api_key)
275
+ ai_msg = chain.invoke(prompt_vars)
276
+ logger.info(f"Initial augmentation for {example.id}: {ai_msg.content.strip()}")
277
+ return extract_json(ai_msg.content.strip())
278
+
279
+ def refine_augmentation(candidate: dict,
280
+ example: StandardExample,
281
+ config: AugmentationConfig,
282
+ strategy: Dict[str, Any]) -> dict:
283
+ """
284
+ Refine a candidate augmentation using a second LLM prompt chain.
285
+ """
286
+ refinement_template = ChatPromptTemplate.from_messages([
287
+ (
288
+ "system",
289
+ "You are an expert data augmentation advisor who refines candidate augmentations to maximize semantic accuracy, diversity, and clarity."
290
+ ),
291
+ (
292
+ "human",
293
+ (
294
+ "Review the candidate augmentation for the following input/output pair and refine it if needed.\n"
295
+ "Finetuning Goal: {finetuning_goal}\n"
296
+ "Original Input: {input_text}\n"
297
+ "Original Output: {output_text}\n"
298
+ "Candidate Augmentation: {candidate}\n"
299
+ "Return a refined augmentation in valid JSON format with keys 'augmented_input' and 'augmented_output' only."
300
+ )
301
+ )
302
+ ])
303
+ refinement_vars = {
304
+ "finetuning_goal": config.finetuning_goal,
305
+ "input_text": example.input_text,
306
+ "output_text": example.output_text,
307
+ "candidate": json.dumps(candidate)
308
+ }
309
+ chain = refinement_template | instantiate_groq_llm(config.target_model, config.groq_api_key)
310
+ ai_msg = chain.invoke(refinement_vars)
311
+ try:
312
+ refined = extract_json(ai_msg.content.strip())
313
+ logger.info(f"Refined augmentation for {example.id}: {refined}")
314
+ return refined
315
+ except Exception as e:
316
+ logger.error(f"Refinement failed for {example.id}: {e}. Using original candidate.")
317
+ return candidate
318
+
319
+ def calculate_metrics(augmentation: dict, original: StandardExample) -> dict:
320
+ """
321
+ Simulate metric calculations for the candidate augmentation.
322
+ """
323
+ semantic_similarity = random.uniform(0.78, 0.97)
324
+ diversity_score = random.uniform(0.65, 0.9)
325
+ fluency_score = random.uniform(0.80, 0.95)
326
+ metrics = {
327
+ "semantic_similarity": semantic_similarity,
328
+ "diversity_score": diversity_score,
329
+ "fluency_score": fluency_score
330
+ }
331
+ logger.info(f"Metrics for candidate of {original.id}: {metrics}")
332
+ return metrics
333
+
334
+ def metrics_valid(metrics: dict, config: AugmentationConfig) -> bool:
335
+ """
336
+ Validate metric thresholds using configuration values.
337
+ """
338
+ if metrics["semantic_similarity"] < config.min_semantic_similarity or metrics["semantic_similarity"] > config.max_semantic_similarity:
339
+ return False
340
+ if metrics["diversity_score"] < config.min_diversity_score:
341
+ return False
342
+ if metrics["fluency_score"] < config.min_fluency_score:
343
+ return False
344
+ return True
345
+
346
+ def quality_check(augmentation: Dict[str, Any], config: AugmentationConfig) -> bool:
347
+ """
348
+ Simulate an LLM-based QA check.
349
+ """
350
+ qa_prompt = (
351
+ f"Verify that the following augmentation preserves the intended meaning and style for the input/output pair "
352
+ f"given the finetuning goal '{config.finetuning_goal}':\n"
353
+ f"{augmentation['augmentation']}\n"
354
+ "Answer 'yes' if valid, otherwise 'no'."
355
+ )
356
+ logger.debug(f"QA Prompt: {qa_prompt}")
357
+ return True # Simulation: always passes
358
+
359
+ def generate_augmentations(normalized_examples: List[StandardExample],
360
+ config: AugmentationConfig,
361
+ strategy: Dict[str, Any],
362
+ target_count: int = 50) -> List[Dict[str, Any]]:
363
+ """
364
+ Repeatedly generate candidate augmentations until at least target_count valid candidates are collected.
365
+ """
366
+ augmented_candidates = []
367
+ attempts = 0
368
+ max_attempts = 100 # Safety valve
369
+ while len(augmented_candidates) < target_count and attempts < max_attempts:
370
+ for ex in normalized_examples:
371
+ try:
372
+ candidate = generate_initial_augmentation(ex, config, strategy)
373
+ refined_candidate = refine_augmentation(candidate, ex, config, strategy)
374
+ metrics = calculate_metrics(refined_candidate, ex)
375
+ if not metrics_valid(metrics, config):
376
+ logger.info(f"Candidate for {ex.id} rejected by metrics: {metrics}")
377
+ continue
378
+ if quality_check({"augmentation": refined_candidate}, config):
379
+ full_candidate = {
380
+ "original_id": ex.id,
381
+ "augmentation": refined_candidate,
382
+ "strategy": strategy,
383
+ "metrics": metrics
384
+ }
385
+ augmented_candidates.append(full_candidate)
386
+ logger.info(f"Accepted candidate for {ex.id} (Total accepted: {len(augmented_candidates)})")
387
+ if len(augmented_candidates) >= target_count:
388
+ break
389
+ except Exception as e:
390
+ logger.error(f"Error generating augmentation for {ex.id}: {e}")
391
+ attempts += 1
392
+ if len(augmented_candidates) < target_count:
393
+ logger.warning(f"Only {len(augmented_candidates)} candidates generated after {attempts} attempts.")
394
+ return augmented_candidates
395
+
396
+ def deduplicate_augmentations(augmentations: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
397
+ """
398
+ Remove duplicate augmentations based on hashable keys.
399
+ """
400
+ seen = set()
401
+ unique_aug = []
402
+ for aug in augmentations:
403
+ key = (make_hashable(aug["augmentation"].get("augmented_input")),
404
+ make_hashable(aug["augmentation"].get("augmented_output")))
405
+ if key not in seen:
406
+ seen.add(key)
407
+ unique_aug.append(aug)
408
+ logger.info(f"Deduplicated to {len(unique_aug)} unique augmentations.")
409
+ return unique_aug
410
+
411
+ def format_for_openai(augmentations: List[Dict[str, Any]], system_message: str) -> str:
412
+ """
413
+ Format augmentations in OpenAI fine-tuning JSONL format.
414
+ """
415
+ output_lines = []
416
+ sys_msg = system_message.strip() if system_message and system_message.strip() else ""
417
+ for aug in augmentations:
418
+ user_val = flatten_content(fix_content(get_text(aug["augmentation"].get("augmented_input", "")).strip()))
419
+ assistant_val = flatten_content(fix_content(get_text(aug["augmentation"].get("augmented_output", "")).strip()))
420
+ record = {
421
+ "messages": [
422
+ {"role": "system", "content": sys_msg},
423
+ {"role": "user", "content": user_val},
424
+ {"role": "assistant", "content": assistant_val}
425
+ ]
426
+ }
427
+ if validate_jsonl_record(record):
428
+ output_lines.append(json.dumps(record))
429
+ else:
430
+ logger.error(f"Record validation failed: {record}")
431
+ logger.info(f"Formatted {len(output_lines)} records in OpenAI fine-tuning format.")
432
+ return "\n".join(output_lines)
433
+
434
+ def format_for_gemini(augmentations: List[Dict[str, Any]]) -> str:
435
+ """
436
+ Format augmentations in Gemini fine-tuning JSONL format.
437
+ """
438
+ output_lines = []
439
+ for aug in augmentations:
440
+ user_val = flatten_content(fix_content(get_text(aug["augmentation"].get("augmented_input", "")).strip()))
441
+ assistant_val = flatten_content(fix_content(get_text(aug["augmentation"].get("augmented_output", "")).strip()))
442
+ record = {
443
+ "contents": [
444
+ {"role": "user", "parts": [{"text": user_val}]},
445
+ {"role": "model", "parts": [{"text": assistant_val}]}
446
+ ]
447
+ }
448
+ if user_val and assistant_val:
449
+ output_lines.append(json.dumps(record))
450
+ else:
451
+ logger.error(f"Gemini record validation failed: {record}")
452
+ logger.info(f"Formatted {len(output_lines)} records in Gemini fine-tuning format.")
453
+ return "\n".join(output_lines)
454
+
455
+ def format_for_common(augmentations: List[Dict[str, Any]]) -> str:
456
+ """
457
+ Format augmentations in a common JSONL format for both Mistral and LLama.
458
+ """
459
+ output_lines = []
460
+ for aug in augmentations:
461
+ user_val = flatten_content(fix_content(get_text(aug["augmentation"].get("augmented_input", "")).strip()))
462
+ assistant_val = flatten_content(fix_content(get_text(aug["augmentation"].get("augmented_output", "")).strip()))
463
+ record = {
464
+ "messages": [
465
+ {"role": "user", "content": user_val},
466
+ {"role": "assistant", "content": assistant_val}
467
+ ]
468
+ }
469
+ if user_val and assistant_val:
470
+ output_lines.append(json.dumps(record))
471
+ else:
472
+ logger.error(f"Common format record validation failed: {record}")
473
+ logger.info(f"Formatted {len(output_lines)} records in common JSONL format for Mistral/LLama.")
474
+ return "\n".join(output_lines)
475
+
476
+ def format_for_mistral(augmentations: List[Dict[str, Any]]) -> str:
477
+ """
478
+ Format augmentations in Mistral fine-tuning JSONL format.
479
+ Uses the common format.
480
+ """
481
+ return format_for_common(augmentations)
482
+
483
+ def format_for_llama(augmentations: List[Dict[str, Any]]) -> str:
484
+ """
485
+ Format augmentations in LLama fine-tuning JSONL format.
486
+ Uses the common format.
487
+ """
488
+ return format_for_common(augmentations)
489
+
490
+ # -----------------------------
491
+ # Optional: Load Existing Examples from JSONL
492
+ # -----------------------------
493
+ def load_examples_from_file(file_path: str, format_type: str = "openai") -> List[AugmentationExample]:
494
+ """
495
+ Load input/output pairs from a JSONL file.
496
+ """
497
+ examples = []
498
+ try:
499
+ with open(file_path, "r") as f:
500
+ for line in f:
501
+ line = line.strip()
502
+ if not line:
503
+ continue
504
+ record = json.loads(line)
505
+ if format_type.lower() == "openai":
506
+ msgs = record.get("messages", [])
507
+ if len(msgs) == 3:
508
+ user_text = msgs[1].get("content", "").strip()
509
+ assistant_text = msgs[2].get("content", "").strip()
510
+ if user_text and assistant_text:
511
+ examples.append(AugmentationExample(input_text=user_text, output_text=assistant_text))
512
+ elif format_type.lower() == "gemini":
513
+ contents = record.get("contents", [])
514
+ if len(contents) >= 2:
515
+ user_parts = contents[0].get("parts", [])
516
+ model_parts = contents[1].get("parts", [])
517
+ user_text = get_text(user_parts[0]) if user_parts else ""
518
+ assistant_text = get_text(model_parts[0]) if model_parts else ""
519
+ if user_text and assistant_text:
520
+ examples.append(AugmentationExample(input_text=user_text, output_text=assistant_text))
521
+ except Exception as e:
522
+ logger.error(f"Error loading examples from file: {e}")
523
+ logger.info(f"Loaded {len(examples)} examples from {file_path}")
524
+ return examples
525
+
526
+ # -----------------------------
527
+ # Pipeline Class
528
+ # -----------------------------
529
+ class FinetuningDataAugmentor:
530
+ """
531
+ Encapsulates the entire augmentation pipeline.
532
+ """
533
+ def __init__(self, config: AugmentationConfig):
534
+ self.config = config
535
+ self.normalized_examples = normalize_examples(config.examples)
536
+ self.strategy = determine_augmentation_strategy(config)
537
+ self.augmentations = []
538
+
539
+ def run_augmentation(self, target_count: int = 50) -> List[Dict[str, Any]]:
540
+ """
541
+ Generate candidate augmentations, deduplicate, and store results.
542
+ """
543
+ logger.info("Starting augmentation generation via LangChain Groq...")
544
+ candidates = generate_augmentations(self.normalized_examples, self.config, self.strategy, target_count=target_count)
545
+ logger.info(f"Generated {len(candidates)} candidate augmentations before deduplication.")
546
+ unique_candidates = deduplicate_augmentations(candidates)
547
+ logger.info(f"{len(unique_candidates)} unique augmentations after deduplication.")
548
+ self.augmentations = unique_candidates
549
+ return unique_candidates
550
+
551
+ def get_formatted_output(self, format_type: str = "openai") -> str:
552
+ """
553
+ Return the final augmented data in the desired finetuning format.
554
+ """
555
+ fmt = format_type.lower()
556
+ if fmt == "openai":
557
+ return format_for_openai(self.augmentations, self.config.system_message)
558
+ elif fmt == "gemini":
559
+ return format_for_gemini(self.augmentations)
560
+ elif fmt == "mistral":
561
+ return format_for_mistral(self.augmentations)
562
+ elif fmt == "llama":
563
+ return format_for_llama(self.augmentations)
564
+ else:
565
+ logger.error(f"Unknown format type: {format_type}. Defaulting to OpenAI format.")
566
+ return format_for_openai(self.augmentations, self.config.system_message)
567
+
568
+ def save_to_file(self, filename: str = "train.jsonl") -> None:
569
+ """
570
+ Save the formatted augmented data to a file.
571
+ """
572
+ output = self.get_formatted_output()
573
+ with open(filename, "w") as f:
574
+ f.write(output)
575
+ logger.info(f"Final augmented data saved to {filename}")
576
+
577
+ def run_review_interface(self) -> None:
578
+ """
579
+ Launch the interactive review interface.
580
+ """
581
+ from streamlit import runtime
582
+ formatted_data = self.get_formatted_output()
583
+ launch_review_app(formatted_data)
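
Taken together, the module exposes a small programmatic API that is independent of the Streamlit front end. A minimal usage sketch (the seed pairs and the environment-variable lookup are illustrative assumptions; the model name mirrors the one hard-coded in app.py):

import os
from finetune_augmentor import AugmentationExample, AugmentationConfig, FinetuningDataAugmentor

# At least three seed pairs, as enforced by AugmentationConfig's validator.
examples = [
    AugmentationExample(input_text="What's the capital of France?", output_text="Paris."),
    AugmentationExample(input_text="Who wrote Hamlet?", output_text="Shakespeare."),
    AugmentationExample(input_text="What is 2 + 2?", output_text="4."),
]

config = AugmentationConfig(
    target_model="mixtral-8x7b-32768",        # same Groq model name app.py uses
    examples=examples,
    finetuning_goal="Improve conversational clarity and capture subtle nuances",
    groq_api_key=os.environ["GROQ_API_KEY"],  # assumption: key supplied via the environment
)

augmentor = FinetuningDataAugmentor(config)
augmentor.run_augmentation(target_count=10)   # generate, metric-filter, deduplicate
augmentor.save_to_file("train.jsonl")         # defaults to the OpenAI JSONL format
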
requirements.txt ADDED
@@ -0,0 +1,6 @@
+ # List of dependencies
+ streamlit==1.42.2
+ langchain-groq==0.2.4
+ langchain-core==0.3.37
+ streamlit-ace==0.1.1
+ python-dotenv
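
With these pinned dependencies, the app would typically be set up with pip install -r requirements.txt and launched with streamlit run app.py, the standard Streamlit workflow.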