# Product-doc-classifier / utils / util_classifier.py
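"""Utility classification pipeline: wraps a fine-tuned BERT model and a
TF-IDF baseline behind a single predict() interface, plus a PDF-loading
helper for document classification."""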
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
import joblib
from datetime import datetime
import logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class TextClassificationPipeline:
def __init__(self, model_path='./models', method='bertbased'):
"""
Initialize the classification pipeline
Args:
model_path: Path to saved models
method: 'bertbased' or 'baseline'
"""
try:
self.method = method
if method == 'bertbased':
logger.info("Loading BERT model...")
self.tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
self.model = AutoModelForSequenceClassification.from_pretrained(f"{model_path}/bert-model")
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
self.model.to(self.device)
self.model.eval()
logger.info(f"BERT model loaded successfully. Using device: {self.device}")
else:
logger.info("Loading baseline model...")
self.tfidf = joblib.load(f"{model_path}/baseline-model/tfidf_vectorizer.pkl")
self.baseline_model = joblib.load(f"{model_path}/baseline-model/baseline_model.pkl")
logger.info("Baseline model loaded successfully")
# Load label encoder for both methods
self.label_encoder = joblib.load(f"{model_path}/label_encoder.pkl")
except Exception as e:
logger.error(f"Error initializing model: {str(e)}")
raise
    def preprocess_text(self, text):
        """Clean and normalize input text"""
        if isinstance(text, str):
            # Basic cleaning: strip and collapse extra whitespace
            text = text.strip()
            text = ' '.join(text.split())
            # Title-case the text (capitalizes every word, not just the
            # first letter) to match the casing of the training data
            text = text.title()
            return text
        return text
    def preprocess(self, text):
        """
        Preprocess the input text (or list of texts) based on method
        """
        try:
            # Clean a single string, or each string in a batch
            if isinstance(text, str):
                text = self.preprocess_text(text)
            else:
                text = [self.preprocess_text(t) for t in text]
            if self.method == 'bertbased':
                # BERT preprocessing: tokenize and move tensors to the model's device
                encodings = self.tokenizer(
                    text,
                    truncation=True,
                    padding=True,
                    max_length=512,
                    return_tensors='pt'
                )
                encodings = {k: v.to(self.device) for k, v in encodings.items()}
                return encodings
            else:
                # Baseline preprocessing: TF-IDF features
                return self.tfidf.transform([text] if isinstance(text, str) else text)
        except Exception as e:
            logger.error(f"Error in preprocessing: {str(e)}")
            raise
def predict(self, text, return_probability=False):
"""
Predict using either BERT or baseline model
Args:
text: Input text or list of texts
return_probability: Whether to return probability scores
Returns:
Predictions with metadata
"""
try:
# Handle both single string and list of strings
if isinstance(text, str):
text = [text]
# Preprocess
inputs = self.preprocess(text)
if self.method == 'bertbased':
# BERT predictions
with torch.no_grad():
outputs = self.model(**inputs)
probabilities = torch.softmax(outputs.logits, dim=-1)
predictions = torch.argmax(probabilities, dim=-1)
predictions = predictions.cpu().numpy()
probabilities = probabilities.cpu().numpy()
else:
# Baseline predictions
predictions = self.baseline_model.predict(inputs)
probabilities = self.baseline_model.predict_proba(inputs)
# Convert numeric predictions to original labels
predicted_labels = self.label_encoder.inverse_transform(predictions)
# Ensure consistent casing with training data
predicted_labels = [label.title() for label in predicted_labels]
if return_probability:
results = []
for t, label, prob, probs in zip(text, predicted_labels,
probabilities.max(axis=1),
probabilities):
                    result = {
                        'text': t[:200] + '...' if len(t) > 200 else t,  # Truncate long text
                        'predicted_label': label,  # Already title-cased above
                        'confidence': float(prob),
                        'model_type': self.method,
                        'probabilities': {
                            self.label_encoder.inverse_transform([i])[0].title(): float(p)  # Consistent casing
                            for i, p in enumerate(probs)
                        },
                        'timestamp': datetime.now().isoformat(),
                        'metadata': {
                            'model_name': 'BERT' if self.method == 'bertbased' else 'Baseline',
                            'text_length': len(t),
                            'preprocessing_steps': ['cleaning', 'tokenization']
                        }
                    }
results.append(result)
return results[0] if len(text) == 1 else results
return predicted_labels[0] if len(text) == 1 else predicted_labels
except Exception as e:
logger.error(f"Error in prediction: {str(e)}")
raise
def get_model_info(self):
"""Return model information"""
return {
'model_type': self.method,
'model_name': 'BERT' if self.method == 'bertbased' else 'Baseline',
'device': str(self.device) if self.method == 'bertbased' else 'CPU',
'max_sequence_length': 512 if self.method == 'bertbased' else None,
'number_of_classes': len(self.label_encoder.classes_),
'classes': list(self.label_encoder.classes_)
}
def load_and_process_pdf(url_or_file):
    """
    Load a PDF from a URL or local file path and return its extracted text.
    Sketch implementation: assumes `pypdf` and `requests` are installed.
    """
    try:
        import io
        import requests
        from pypdf import PdfReader
        if isinstance(url_or_file, str) and url_or_file.startswith(('http://', 'https://')):
            # Download the PDF into an in-memory buffer
            response = requests.get(url_or_file, timeout=30)
            response.raise_for_status()
            source = io.BytesIO(response.content)
        else:
            source = url_or_file
        reader = PdfReader(source)
        # Concatenate the text of all pages (extract_text() may return None)
        return '\n'.join(page.extract_text() or '' for page in reader.pages)
    except Exception as e:
        logger.error(f"Error processing PDF: {str(e)}")
        raise
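# Example (hypothetical path; requires the optional PDF dependencies above):
#   pdf_text = load_and_process_pdf('document.pdf')
#   label = TextClassificationPipeline().predict(pdf_text)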
# Example usage
if __name__ == "__main__":
# Test the pipeline
classifier = TextClassificationPipeline()
# Test single prediction
text = "Example construction document text"
result = classifier.predict(text, return_probability=True)
print("\nSingle Prediction Result:")
print(result)
# Test batch prediction
texts = ["First document", "Second document"]
results = classifier.predict(texts, return_probability=True)
print("\nBatch Prediction Results:")
for result in results:
print(f"\nText: {result['text']}")
print(f"Prediction: {result['predicted_label']}")
print(f"Confidence: {result['confidence']:.4f}")