J.A.R.V.I.S / train_model.py
varun324242's picture
Upload folder using huggingface_hub
fe2a0f2 verified
import pandas as pd
import together
import json
import os
from sklearn.model_selection import train_test_split
class ModelTrainer:
    """Turn a labelled CSV into Together AI fine-tuning files and start a job.

    Intended call order: ``prepare_data`` (CSV -> train/validation JSONL),
    ``upload_file`` once per JSONL file, then ``create_fine_tuning_job``.
    """

    def __init__(self, api_key):
        """Set the module-level ``together.api_key`` to *api_key*."""
        together.api_key = api_key

    def prepare_data(self, csv_path, test_size=0.2, random_state=42):
        """Split the CSV at *csv_path* and write prompt/completion JSONL files.

        Args:
            csv_path: Path to a UTF-8 encoded CSV with a text and a label column.
            test_size: Fraction of rows held out for validation.
            random_state: Seed passed to ``train_test_split`` for a repeatable split.

        Returns:
            A ``(training_path, validation_path)`` tuple of the files written
            to the current working directory.

        Raises:
            ValueError: If the CSV has no rows, or if the text and label
                columns resolve to the same column.
        """
        print("Loading CSV file...")
        df = pd.read_csv(csv_path, encoding='utf-8')

        # Print column names to help debug column auto-detection.
        print("Available columns in CSV:", df.columns.tolist())

        # Fail early with a readable message instead of an IndexError from
        # the positional column fallbacks below.
        if df.empty:
            raise ValueError(f"No data rows found in '{csv_path}'")

        text_column = self._get_text_column(df)
        label_column = self._get_label_column(df)
        if text_column == label_column:
            # Happens when neither name is recognised and the CSV has a single
            # column: the first-column and last-column fallbacks coincide.
            raise ValueError(
                f"Could not find distinct text and label columns in '{csv_path}' "
                f"(both resolved to '{text_column}')"
            )
        print(f"Using '{text_column}' as text column and '{label_column}' as label column")

        train_df, val_df = train_test_split(
            df, test_size=test_size, random_state=random_state
        )
        print(f"Training samples: {len(train_df)}, Validation samples: {len(val_df)}")

        train_data = self._convert_to_together_format(train_df, text_column, label_column)
        val_data = self._convert_to_together_format(val_df, text_column, label_column)

        self._save_jsonl(train_data, 'training_data.jsonl')
        self._save_jsonl(val_data, 'validation_data.jsonl')

        return 'training_data.jsonl', 'validation_data.jsonl'

    def _get_text_column(self, df):
        """Return the first known text column name in *df*, else its first column."""
        text_column_options = ['text', 'message', 'content', 'Text', 'MESSAGE', 'CONTENT']
        for col in text_column_options:
            if col in df.columns:
                return col
        # No recognised name: assume the text is in the first column.
        return df.columns[0]

    def _get_label_column(self, df):
        """Return the first known label column name in *df*, else its last column."""
        label_column_options = ['label', 'Label', 'class', 'Class', 'target', 'Target']
        for col in label_column_options:
            if col in df.columns:
                return col
        # No recognised name: assume the label is in the last column.
        return df.columns[-1]

    @staticmethod
    def _is_scam_label(value):
        """Return True when *value* denotes the positive ("scam") class.

        Numeric values are positive only when ``int(value) == 1``. Strings may
        be numeric text (``"1"``, ``"1.0"``) or one of yes/true/scam and
        no/false/not scam, case-insensitively.

        Raises:
            ValueError: If a string label is not recognised.
        """
        if isinstance(value, str):
            cleaned = value.strip().casefold()
            if cleaned in ("yes", "true", "scam"):
                return True
            if cleaned in ("no", "false", "not scam"):
                return False
            try:
                # float() accepts "1.0" as well as "1"; int() alone rejects it.
                return int(float(cleaned)) == 1
            except (ValueError, OverflowError):
                raise ValueError(f"Unrecognised label value: {value!r}") from None
        return int(value) == 1

    def _convert_to_together_format(self, df, text_column, label_column):
        """Build prompt/completion dicts from the rows of *df*.

        Rows whose text or label is missing are skipped (and counted in a
        printed notice) so they neither crash the conversion nor produce
        meaningless ``"Text: nan"`` prompts.

        Returns:
            A list of ``{"prompt": ..., "completion": "Yes" | "No"}`` dicts.
        """
        formatted_data = []
        skipped = 0
        # Iterating the two columns directly avoids iterrows(), which builds a
        # Series per row and can upcast values across columns.
        for text, label in zip(df[text_column], df[label_column]):
            if pd.isna(text) or pd.isna(label):
                skipped += 1
                continue
            prompt = (
                "Analyze the following text and determine if it's a scam or not.\n\n"
                f"Text: {text}\n\n"
                "Is this a scam? "
            )
            completion = "Yes" if self._is_scam_label(label) else "No"
            formatted_data.append({
                "prompt": prompt,
                "completion": completion
            })
        if skipped:
            print(f"Skipped {skipped} row(s) with a missing text or label value")
        return formatted_data

    def _save_jsonl(self, data, filename):
        """Write *data* to *filename* as UTF-8 JSON Lines, one object per line."""
        with open(filename, 'w', encoding='utf-8') as f:
            f.writelines(json.dumps(item) + '\n' for item in data)

    def upload_file(self, file_path):
        """Upload *file_path* via ``together.Files.upload`` and return the file ID."""
        print(f"Uploading {file_path}...")
        result = together.Files.upload(file_path)
        print(f"File uploaded with ID: {result['id']}")
        return result['id']

    def create_fine_tuning_job(
        self,
        training_file_id,
        validation_file_id,
        model="togethercomputer/RedPajama-INCITE-7B-Chat",
        n_epochs=3,
        batch_size=4,
        learning_rate=0.00001,
        suffix="scam_detector_v1",
    ):
        """Create a fine-tuning job and return its ID.

        Args:
            training_file_id: ID returned by ``upload_file`` for the training set.
            validation_file_id: ID returned by ``upload_file`` for the validation set.
            model: Name of the base model to fine-tune.
            n_epochs: Number of training epochs.
            batch_size: Training batch size.
            learning_rate: Optimiser learning rate.
            suffix: Suffix passed to the job; presumably becomes part of the
                fine-tuned model's name - TODO confirm.

        Returns:
            The ``id`` field of the job-creation response.
        """
        job_params = {
            "training_file": training_file_id,
            "validation_file": validation_file_id,
            "model": model,
            "n_epochs": n_epochs,
            "batch_size": batch_size,
            "learning_rate": learning_rate,
            "suffix": suffix,
        }
        # NOTE(review): the legacy together SDK appears to expose this class as
        # `together.Finetune` (lowercase "t") and may not accept
        # `validation_file` - confirm against the installed SDK version.
        result = together.FineTune.create(**job_params)
        return result['id']
def main():
    """Prepare the scam dataset, upload it, and start a fine-tuning job.

    The Together AI API key is read from the ``TOGETHER_API_KEY`` environment
    variable. When running in Google Colab the CSV is read from the mounted
    Drive; otherwise ``scam4.csv`` in the working directory is used.

    Raises:
        SystemExit: If ``TOGETHER_API_KEY`` is not set.
    """
    # Never embed credentials in source. Any key that was previously committed
    # to this file should be treated as compromised and rotated.
    api_key = os.environ.get("TOGETHER_API_KEY")
    if not api_key:
        raise SystemExit(
            "TOGETHER_API_KEY environment variable is not set; "
            "export your Together AI API key before running this script."
        )
    trainer = ModelTrainer(api_key=api_key)

    # Default to the local file; switch to Drive only if the mount succeeds.
    csv_path = 'scam4.csv'
    try:
        from google.colab import drive
    except ImportError:
        # Not running in Colab: keep the local path.
        pass
    else:
        try:
            print("Mounting Google Drive...")
            drive.mount('/content/drive')
            csv_path = '/content/drive/MyDrive/scam4.csv'
        except Exception as mount_error:
            # Best-effort: keep the local fallback, but say why it was used.
            print(f"Could not mount Google Drive ({mount_error}); using local file instead")

    generated_files = []
    try:
        # Prepare and upload data files
        train_file, val_file = trainer.prepare_data(csv_path)
        generated_files = [train_file, val_file]

        training_file_id = trainer.upload_file(train_file)
        validation_file_id = trainer.upload_file(val_file)

        # Start fine-tuning
        print("\nStarting fine-tuning job...")
        job_id = trainer.create_fine_tuning_job(training_file_id, validation_file_id)
        print(f"Fine-tuning job created with ID: {job_id}")
    except Exception as e:
        print(f"An error occurred: {str(e)}")
        # Print more detailed error information
        import traceback
        traceback.print_exc()
    finally:
        # Remove the intermediate JSONL files even when an upload or the job
        # creation failed, so failed runs do not leave stale data behind.
        for path in generated_files:
            try:
                os.remove(path)
            except FileNotFoundError:
                pass
        if generated_files:
            print("\nTemporary files cleaned up")


if __name__ == "__main__":
    main()