Code tips - syntax can be simpler

#1
by Manamama - opened

FYI, you do not need:

output = model(batch)
predicted_label = torch.sigmoid(output.logits).argmax().item()

The model already provides the label and score directly through the 'pipeline' API, albeit with label names slightly different from those of e.g. the DistilBERT model ("neutral" instead of "not_toxic").
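Why the two routes agree: for a classifier with two output labels, the transformers `text-classification` pipeline applies softmax to the logits by default (it switches to sigmoid only for single-label or multi-label models), and softmax, like an element-wise sigmoid, preserves the order of the logits. So the argmax label from the manual snippet is the same label the pipeline reports; only the score differs. The plain-Python sketch below checks this on made-up logits, assuming the `id2label` mapping `{0: "neutral", 1: "toxic"}` (an assumption; check the model's `config.json`).

```python
import math

# Assumed label mapping; verify against the model's config.json (id2label).
ID2LABEL = {0: "neutral", 1: "toxic"}


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def softmax(logits: list[float]) -> list[float]:
    # Subtract the max logit for numerical stability.
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def argmax(values: list[float]) -> int:
    # Index of the largest value (first one on ties).
    return max(range(len(values)), key=values.__getitem__)


def manual_label(logits: list[float]) -> str:
    """Mimics torch.sigmoid(output.logits).argmax().item() followed by an id2label lookup."""
    return ID2LABEL[argmax([sigmoid(z) for z in logits])]


def pipeline_like(logits: list[float]) -> dict:
    """Mimics the pipeline's default output for a 2-label model: top softmax label and its score."""
    probs = softmax(logits)
    i = argmax(probs)
    return {"label": ID2LABEL[i], "score": probs[i]}


# Made-up logits, for illustration only.
for logits in ([2.1, -1.7], [-3.0, 2.5], [0.2, 0.1]):
    print(manual_label(logits), pipeline_like(logits))
```

Note also that sigmoid applied to each logit independently does not yield scores that sum to 1, so the pipeline's softmax score is the more natural confidence value here.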

Below is a quick script that shows this syntax and lets you compare the results of the two models, line by line:

import sys
import os
import csv
import torch
from transformers import (
    RobertaTokenizer,
    RobertaForSequenceClassification,
    DistilBertTokenizer,
    DistilBertForSequenceClassification,
    pipeline
)
import colorama
from colorama import Style

print("Toxicity scorer, ver. 3.6")

# Initialize Colorama
colorama.init()

# Default model and toxicity score threshold
DEFAULT_MODEL = 'roberta'
TOXICITY_THRESHOLD = 0.99

# Define model paths and configurations
MODEL_CONFIGS = {
    'roberta': {
        'tokenizer': RobertaTokenizer,
        'model': RobertaForSequenceClassification,
        'path': 's-nlp/roberta_toxicity_classifier'
    },
    'distilbert': {
        'tokenizer': DistilBertTokenizer,
        'model': DistilBertForSequenceClassification,
        'path': 'citizenlab/distilbert-base-multilingual-cased-toxicity'
    }
}

# Load Toxicity Classifier
def load_classifier(model_name=DEFAULT_MODEL):
    """Load the toxicity classifier based on the specified model name."""
    print(f"Selected model: {model_name}")

    if model_name in MODEL_CONFIGS:
        config = MODEL_CONFIGS[model_name]
        tokenizer = config['tokenizer'].from_pretrained(config['path'])
        model = config['model'].from_pretrained(config['path'])
        return tokenizer, model, config['path']
    
    raise ValueError("Unsupported model name. Choose 'roberta' or 'distilbert'.")

# Generate Color Based on Toxicity
def generate_color(label, score):
    """Generate color codes based on toxicity label and score."""
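    normalized_score = min(max(score, 0), 1)
    # Map scores in [0.4, 1] to an intensity of 0-255; clamp so lower scores give 0 rather than a negative value
    intensity = int(255 * max(0.0, (normalized_score - 0.4) / 0.6))
    red = intensity if label == "toxic" else 0
    green = intensity if label != "toxic" else 0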
    return f'\033[38;2;{red};{green};0m'

# Read Input File
def read_input_file(file_path):
    """Read lines from a given input file."""
    with open(file_path, 'r', encoding='utf-8') as file:
        return file.readlines()

# Print Results to Terminal and Save to CSV
def print_results(lines, tokenizer, model, model_path, csv_file_path):
    """Process lines for toxicity analysis and save results to a CSV file."""
    print(f"Toxicity Analysis Results (non-toxic lines are shown only when below the threshold of {TOXICITY_THRESHOLD}):")
    print(f"Model path: {model_path}")

    with open(csv_file_path, 'w', newline='', encoding='utf-8') as csvfile:
        csv_writer = csv.writer(csvfile, escapechar='\\', quoting=csv.QUOTE_MINIMAL)
        csv_writer.writerow(["label", "score", "input"])

        # Reuse the tokenizer and model already loaded by load_classifier() instead of loading them again
        toxicity_classifier = pipeline("text-classification", model=model, tokenizer=tokenizer)

        for line in lines:
            line = line.strip()
            if line:
                toxicity_result = toxicity_classifier(line)
                label = toxicity_result[0]["label"]
                if label == "neutral":  # RoBERTa emits "neutral"; standardize to DistilBERT's "not_toxic"
                    label = "not_toxic"

                score_value = toxicity_result[0]["score"]

                # Report every toxic line, plus non-toxic lines whose confidence is below the threshold
                if label == "toxic" or (label == "not_toxic" and score_value < TOXICITY_THRESHOLD):
                    color_code = generate_color(label, score_value)
                    # csv.writer already doubles embedded quotes, so write the line unmodified
                    csv_writer.writerow([label, f"{score_value:.4f}", line])
                    print(f"{color_code}Toxicity: {label}, Confidence Score: {score_value:.4f}, Input: \"{line}\"{Style.RESET_ALL}")

# Main Function
def check_toxicity(file_path, model_name=DEFAULT_MODEL):
    """Main function to check the toxicity of text in a given file."""
    tokenizer, model, model_path = load_classifier(model_name)
    
    lines = read_input_file(file_path)

    # Construct CSV filename based on input filename and model name
    base_filename = os.path.splitext(os.path.basename(file_path))[0]
    directory = os.path.dirname(file_path)
    csv_file_path = os.path.join(directory, f"{base_filename}_{model_name}_toxicity_results.csv")
    
    print_results(lines, tokenizer, model, model_path, csv_file_path)

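if __name__ == "__main__":
    if len(sys.argv) not in (2, 3):
        print("Usage: python script.py <path_to_text_file> [roberta|distilbert]")
        sys.exit(1)

    file_path = sys.argv[1]
    model_name = sys.argv[2] if len(sys.argv) == 3 else DEFAULT_MODEL

    check_toxicity(file_path, model_name)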

# Reset Colorama at the end of the script (optional)
colorama.deinit()
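
On the CSV side: with its default doublequote=True, Python's csv.writer already wraps a field in quotes when needed and doubles any embedded quote character, so a manual replace('"', '""') before writerow escapes the quotes twice. A quick standard-library round-trip check (the row values are made up):

```python
import csv
import io


def write_row(text: str, pre_escape: bool) -> str:
    """Write one CSV row and return the raw text the writer produced."""
    buf = io.StringIO()
    writer = csv.writer(buf, quoting=csv.QUOTE_MINIMAL)
    # Optionally apply the manual escaping, to show its effect.
    field = text.replace('"', '""') if pre_escape else text
    writer.writerow(["toxic", "0.9876", field])
    return buf.getvalue()


def read_back(raw: str) -> str:
    """Parse the row again and return its third field (the input text)."""
    return next(csv.reader(io.StringIO(raw)))[2]


sample = 'He said "go away", twice.'
print(read_back(write_row(sample, pre_escape=False)))  # round-trips unchanged
print(read_back(write_row(sample, pre_escape=True)))   # quotes come back doubled
```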

Hi, Ma!

Thanks for your suggestion! We will update the code soon.

etomoscow changed discussion status to closed
s-nlp org

Dear Manamama,

Thank you for the tips on code optimization!

We will be glad to hear about any other experiences with the classifier.

Best,
Daryna

Glad to be of help. I have had many experiences with it; in short, it is much better than DistilBERT across a couple of languages. I use the code above to produce the results, then show them together with the source file (and thus the context), e.g. a badly OCR-ed scan, to NotebookLM AI, and we discuss the quirks of each toxicity classifier. In very short: your RoBERTa wins almost always. Where it does not, it is because it is fed line by line and so cannot grasp the full context, which would trip up even the largest LLM. BTW, the source "ultra-toxic" commands from the https://github.com/leondz/garak collection are very good for testing; do try them yourself.

BTW, it would also be interesting to combine it into a pipeline with e.g. https://github.com/ddlBoJack/emotion2vec, or with WhisperX and diarization. My quick sample code at https://github.com/ddlBoJack/emotion2vec/issues/51 implements some of this.
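
A minimal sketch of the glue step for such a pipeline, assuming transcript segments shaped as dicts with "speaker" and "text" keys (roughly what a diarized WhisperX transcript provides; treat the exact keys as an assumption). The classifier is passed in as any callable that, like a text-classification pipeline called on a list of strings, returns one {"label", "score"} dict per string. The stub_classifier below is a hypothetical stand-in so the sketch runs offline, and the "neutral" to "not_toxic" mapping mirrors the script above.

```python
from collections import defaultdict
from typing import Callable

# Any callable taking a list of texts and returning one {"label", "score"} dict per text.
Classifier = Callable[[list[str]], list[dict]]


def score_segments(segments: list[dict], classify: Classifier) -> list[dict]:
    """Attach a toxicity label and score to each diarized segment."""
    texts = [seg["text"].strip() for seg in segments]
    results = classify(texts)
    scored = []
    for seg, res in zip(segments, results):
        # Standardize RoBERTa's "neutral" to "not_toxic", as in the script above.
        label = "not_toxic" if res["label"] == "neutral" else res["label"]
        scored.append({**seg, "label": label, "score": res["score"]})
    return scored


def toxic_share_by_speaker(scored: list[dict]) -> dict[str, float]:
    """Fraction of each speaker's segments labelled toxic."""
    totals: dict[str, int] = defaultdict(int)
    toxic: dict[str, int] = defaultdict(int)
    for seg in scored:
        totals[seg["speaker"]] += 1
        toxic[seg["speaker"]] += seg["label"] == "toxic"
    return {spk: toxic[spk] / totals[spk] for spk in totals}


# Hypothetical stand-in for pipeline("text-classification", ...): flags a single keyword.
def stub_classifier(texts: list[str]) -> list[dict]:
    return [{"label": "toxic", "score": 0.97} if "idiot" in t.lower()
            else {"label": "neutral", "score": 0.99} for t in texts]


# Made-up segments, for illustration only.
segments = [
    {"start": 0.0, "end": 2.1, "speaker": "SPEAKER_00", "text": " Good morning."},
    {"start": 2.1, "end": 4.0, "speaker": "SPEAKER_01", "text": " You idiot."},
    {"start": 4.0, "end": 6.5, "speaker": "SPEAKER_01", "text": " Sorry, go on."},
]
scored = score_segments(segments, stub_classifier)
print(toxic_share_by_speaker(scored))
```

Passing the classifier in as a callable keeps the glue step independent of which model (RoBERTa or DistilBERT) is plugged in.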

s-nlp org

Great! Thanks for such interesting ideas!
