"""Prepare a scam-detection CSV and launch a Together AI fine-tuning job."""

import json
import os
import traceback

import pandas as pd
import together
from sklearn.model_selection import train_test_split


class ModelTrainer:
    """Convert a labelled CSV to prompt/completion JSONL and submit it for fine-tuning."""

    def __init__(self, api_key: str):
        """Store the API key on the ``together`` module.

        Note that this sets ``together.api_key``, which is module-global state
        shared by every user of the ``together`` package in this process.

        Args:
            api_key: Together AI API key.
        """
        together.api_key = api_key

    def prepare_data(self, csv_path: str, *, test_size: float = 0.2,
                     random_state: int = 42):
        """Prepare data from a CSV file for fine-tuning.

        Reads the CSV, picks the text and label columns, splits the rows into
        training and validation sets, and writes each set to a JSONL file in
        the current working directory.

        Args:
            csv_path: Path of the CSV file to read (decoded as UTF-8).
            test_size: Fraction of rows held out for validation.
            random_state: Seed passed to ``train_test_split`` so the split is
                reproducible.

        Returns:
            Tuple ``('training_data.jsonl', 'validation_data.jsonl')`` with the
            names of the files that were written.
        """
        print("Loading CSV file...")
        df = pd.read_csv(csv_path, encoding='utf-8')

        # Printed to help debug column-detection problems.
        print("Available columns in CSV:", df.columns.tolist())

        text_column = self._get_text_column(df)
        label_column = self._get_label_column(df)
        print(f"Using '{text_column}' as text column and '{label_column}' as label column")

        train_df, val_df = train_test_split(
            df, test_size=test_size, random_state=random_state
        )
        print(f"Training samples: {len(train_df)}, Validation samples: {len(val_df)}")

        train_data = self._convert_to_together_format(train_df, text_column, label_column)
        val_data = self._convert_to_together_format(val_df, text_column, label_column)

        self._save_jsonl(train_data, 'training_data.jsonl')
        self._save_jsonl(val_data, 'validation_data.jsonl')

        return 'training_data.jsonl', 'validation_data.jsonl'

    def _get_text_column(self, df):
        """Return the first known text-column name in *df*, else its first column."""
        text_column_options = ['text', 'message', 'content', 'Text', 'MESSAGE', 'CONTENT']
        for col in text_column_options:
            if col in df.columns:
                return col
        # No known name matched: fall back to the first column.
        return df.columns[0]

    def _get_label_column(self, df):
        """Return the first known label-column name in *df*, else its last column."""
        label_column_options = ['label', 'Label', 'class', 'Class', 'target', 'Target']
        for col in label_column_options:
            if col in df.columns:
                return col
        # No known name matched: fall back to the last column.
        return df.columns[-1]

    def _convert_to_together_format(self, df, text_column, label_column):
        """Convert DataFrame rows to prompt/completion dictionaries.

        A label whose integer value is 1 becomes the completion ``"Yes"``;
        every other label becomes ``"No"``. Rows whose text or label is
        missing are skipped (and counted in a printed message) instead of
        raising on ``int(NaN)`` or producing a ``"Text: nan"`` prompt.

        Args:
            df: Rows to convert.
            text_column: Name of the column holding the message text.
            label_column: Name of the column holding the label.

        Returns:
            List of ``{"prompt": ..., "completion": ...}`` dictionaries.
        """
        formatted_data = []
        skipped = 0
        # Walk the two columns directly; iterrows() would build a Series per row.
        for text, label in zip(df[text_column], df[label_column]):
            if pd.isna(text) or pd.isna(label):
                skipped += 1
                continue
            prompt = (
                "Analyze the following text and determine if it's a scam or not.\n\n"
                f"Text: {text}\n\n"
                "Is this a scam? "
            )
            completion = "Yes" if int(label) == 1 else "No"
            formatted_data.append({
                "prompt": prompt,
                "completion": completion
            })
        if skipped:
            print(f"Skipped {skipped} row(s) with missing text or label")
        return formatted_data

    def _save_jsonl(self, data, filename):
        """Write *data* to *filename* as UTF-8 JSONL, one JSON object per line."""
        with open(filename, 'w', encoding='utf-8') as f:
            for item in data:
                f.write(json.dumps(item) + '\n')

    def upload_file(self, file_path: str):
        """Upload a file to Together AI.

        Args:
            file_path: Path of the file to upload.

        Returns:
            The ``'id'`` entry of the upload response.
        """
        print(f"Uploading {file_path}...")
        result = together.Files.upload(file_path)
        print(f"File uploaded with ID: {result['id']}")
        return result['id']

    def create_fine_tuning_job(self, training_file_id, validation_file_id, *,
                               model='togethercomputer/RedPajama-INCITE-7B-Chat',
                               n_epochs=3, batch_size=4, learning_rate=0.00001,
                               suffix='scam_detector_v1'):
        """Create and start a fine-tuning job.

        Args:
            training_file_id: ID of the uploaded training file.
            validation_file_id: ID of the uploaded validation file.
            model: Base model to fine-tune.
            n_epochs: Number of training epochs.
            batch_size: Training batch size.
            learning_rate: Optimizer learning rate.
            suffix: Suffix attached to the job parameters.

        Returns:
            The ``'id'`` entry of the job-creation response.
        """
        # NOTE(review): these keys are passed straight through as keyword
        # arguments; confirm the installed `together` SDK version accepts all
        # of them (in particular `validation_file`).
        job_params = {
            "training_file": training_file_id,
            "validation_file": validation_file_id,
            "model": model,
            "n_epochs": n_epochs,
            "batch_size": batch_size,
            "learning_rate": learning_rate,
            "suffix": suffix
        }
        result = together.FineTune.create(**job_params)
        return result['id']


def _resolve_csv_path():
    """Return the Drive CSV path when running in Colab, else the local path."""
    try:
        from google.colab import drive
    except ImportError:
        # Not running in Colab.
        return 'scam4.csv'

    print("Mounting Google Drive...")
    try:
        drive.mount('/content/drive')
    except Exception as exc:
        # Best effort: a failed mount falls back to the local file, but
        # KeyboardInterrupt/SystemExit are no longer swallowed.
        print(f"Could not mount Google Drive ({exc}); using local path")
        return 'scam4.csv'
    return '/content/drive/MyDrive/scam4.csv'


def main():
    """Prepare the data, upload it, and start the fine-tuning job."""
    # Never hard-code the key in source. Any key previously committed to this
    # file should be treated as compromised and revoked.
    api_key = os.environ.get('TOGETHER_API_KEY')
    if not api_key:
        raise SystemExit(
            "Set the TOGETHER_API_KEY environment variable to your Together AI API key."
        )
    trainer = ModelTrainer(api_key=api_key)

    csv_path = _resolve_csv_path()

    temp_files = []
    try:
        train_file, val_file = trainer.prepare_data(csv_path)
        temp_files = [train_file, val_file]

        training_file_id = trainer.upload_file(train_file)
        validation_file_id = trainer.upload_file(val_file)

        print("\nStarting fine-tuning job...")
        job_id = trainer.create_fine_tuning_job(training_file_id, validation_file_id)
        print(f"Fine-tuning job created with ID: {job_id}")
    except Exception as e:
        print(f"An error occurred: {str(e)}")
        # Print more detailed error information.
        traceback.print_exc()
    finally:
        # Remove the temporary files even when upload or job creation failed.
        for path in temp_files:
            try:
                os.remove(path)
            except FileNotFoundError:
                pass
        if temp_files:
            print("\nTemporary files cleaned up")


if __name__ == "__main__":
    main()