Spaces:
Sleeping
Sleeping
from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
import torch | |
import numpy as np | |
from sklearn.feature_extraction.text import TfidfVectorizer | |
import joblib | |
import pandas as pd | |
from datetime import datetime | |
import logging | |
# Configure root logging at INFO level (a no-op if the root logger already has
# handlers) and create the logger used throughout this module.
logging.basicConfig(level="INFO")
logger = logging.getLogger(name=__name__)
class TextClassificationPipeline:
    """Classify text with either a fine-tuned BERT model or a TF-IDF baseline.

    All models are loaded once in ``__init__``. ``predict`` accepts a single
    string or a sequence of strings.
    """

    # Maximum number of tokens fed to the BERT model per text (longer inputs
    # are truncated by the tokenizer).
    MAX_LENGTH = 512
    # Texts longer than this are shortened (with '...') in detailed results.
    PREVIEW_CHARS = 200

    def __init__(self, model_path='./models', method='bertbased'):
        """
        Initialize the classification pipeline.

        Args:
            model_path: Directory containing ``bert-model/`` (for 'bertbased'),
                ``baseline-model/tfidf_vectorizer.pkl`` and
                ``baseline-model/baseline_model.pkl`` (for 'baseline'), and
                ``label_encoder.pkl`` (always).
            method: 'bertbased' or 'baseline'. Any other value falls back to
                the baseline model; a warning is logged in that case.

        Raises:
            Any exception raised while loading a model; it is logged and then
            re-raised unchanged.
        """
        try:
            self.method = method
            if method == 'bertbased':
                logger.info("Loading BERT model...")
                # NOTE(review): the tokenizer always comes from the stock
                # 'bert-base-uncased' checkpoint rather than model_path; confirm
                # it matches the tokenizer used to fine-tune the saved model.
                self.tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
                self.model = AutoModelForSequenceClassification.from_pretrained(
                    f"{model_path}/bert-model"
                )
                self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
                self.model.to(self.device)
                self.model.eval()  # evaluation mode for inference (e.g. no dropout)
                logger.info(f"BERT model loaded successfully. Using device: {self.device}")
            else:
                if method != 'baseline':
                    logger.warning(
                        f"Unknown method {method!r}; falling back to the baseline model"
                    )
                logger.info("Loading baseline model...")
                # joblib.load unpickles its input: only load model files from a
                # trusted location.
                self.tfidf = joblib.load(f"{model_path}/baseline-model/tfidf_vectorizer.pkl")
                self.baseline_model = joblib.load(f"{model_path}/baseline-model/baseline_model.pkl")
                logger.info("Baseline model loaded successfully")
            # Both methods map encoded class ids back to label names with this.
            self.label_encoder = joblib.load(f"{model_path}/label_encoder.pkl")
        except Exception as e:
            logger.error(f"Error initializing model: {str(e)}")
            raise

    def preprocess_text(self, text):
        """Normalise whitespace and title-case text to match the training data.

        Args:
            text: A string, or a list/tuple whose elements are cleaned
                individually. Any other value is returned unchanged.

        Returns:
            The cleaned string, or a new list of cleaned items for list/tuple
            input (the input sequence is not modified).
        """
        if isinstance(text, (list, tuple)):
            # predict() always passes a list; without this branch list input
            # skipped cleaning entirely.
            return [self.preprocess_text(item) for item in text]
        if isinstance(text, str):
            # split() with no argument drops leading/trailing whitespace and
            # collapses internal runs of whitespace.
            text = ' '.join(text.split())
            # title() upper-cases the first letter of each word and lower-cases
            # the rest (e.g. "BERT" -> "Bert").
            text = text.title()
        return text

    def preprocess(self, text):
        """Clean ``text`` and turn it into inputs for the configured model.

        Args:
            text: A string or a list of strings.

        Returns:
            For 'bertbased': the tokenizer's encodings as a dict of tensors
            moved to ``self.device``. Otherwise: the result of
            ``self.tfidf.transform`` on the cleaned text(s).
        """
        try:
            text = self.preprocess_text(text)
            if self.method == 'bertbased':
                encodings = self.tokenizer(
                    text,
                    truncation=True,
                    padding=True,
                    max_length=self.MAX_LENGTH,
                    return_tensors='pt',
                )
                return {k: v.to(self.device) for k, v in encodings.items()}
            # The vectorizer expects an iterable of documents.
            return self.tfidf.transform([text] if isinstance(text, str) else text)
        except Exception as e:
            logger.error(f"Error in preprocessing: {str(e)}")
            raise

    def _infer(self, inputs):
        """Run the loaded model; return (encoded predictions, class probabilities) as arrays."""
        if self.method == 'bertbased':
            with torch.no_grad():
                outputs = self.model(**inputs)
                probabilities = torch.softmax(outputs.logits, dim=-1)
                predictions = torch.argmax(probabilities, dim=-1)
            return predictions.cpu().numpy(), probabilities.cpu().numpy()
        return (
            self.baseline_model.predict(inputs),
            self.baseline_model.predict_proba(inputs),
        )

    def _detailed_results(self, texts, predicted_labels, probabilities, class_names):
        """Build one result dict per text; column i of ``probabilities`` maps to ``class_names[i]``."""
        model_name = 'BERT' if self.method == 'bertbased' else 'Baseline'
        results = []
        for t, label, probs in zip(texts, predicted_labels, probabilities):
            results.append({
                'text': t[:self.PREVIEW_CHARS] + '...' if len(t) > self.PREVIEW_CHARS else t,
                'predicted_label': label,
                'confidence': float(probs.max()),
                'model_type': self.method,
                'probabilities': {class_names[i]: float(p) for i, p in enumerate(probs)},
                'timestamp': datetime.now().isoformat(),
                'metadata': {
                    'model_name': model_name,
                    'text_length': len(t),
                    'preprocessing_steps': ['cleaning', 'tokenization'],
                },
            })
        return results

    def predict(self, text, return_probability=False):
        """
        Predict using either BERT or baseline model.

        Args:
            text: Input text or a sequence of texts.
            return_probability: Whether to return detailed result dicts with
                probability scores instead of bare labels.

        Returns:
            Without ``return_probability``: the title-cased label for a single
            text, otherwise a list of labels. With it: a result dict (or list
            of dicts) with keys 'text', 'predicted_label', 'confidence',
            'model_type', 'probabilities', 'timestamp' and 'metadata'.
            A one-element list also yields a single, unwrapped result; an
            empty sequence yields ``[]``.
        """
        try:
            texts = [text] if isinstance(text, str) else list(text)
            if not texts:
                return []
            predictions, probabilities = self._infer(self.preprocess(texts))
            # Title-case labels to match the casing of the training data.
            predicted_labels = [
                label.title() for label in self.label_encoder.inverse_transform(predictions)
            ]
            if return_probability:
                class_names = [c.title() for c in self.label_encoder.classes_]
                results = self._detailed_results(
                    texts, predicted_labels, probabilities, class_names
                )
                return results[0] if len(texts) == 1 else results
            return predicted_labels[0] if len(texts) == 1 else predicted_labels
        except Exception as e:
            logger.error(f"Error in prediction: {str(e)}")
            raise

    def predict_old(self, text, return_probability=False):
        """
        Legacy variant of ``predict`` that leaves labels exactly as the label
        encoder stores them (no title-casing).

        Args:
            text: Input text or a sequence of texts.
            return_probability: Whether to return detailed result dicts.

        Returns:
            A single label/result for one text, otherwise the labels as
            returned by ``label_encoder.inverse_transform`` (or a list of
            result dicts). An empty sequence yields ``[]``.
        """
        try:
            texts = [text] if isinstance(text, str) else list(text)
            if not texts:
                return []
            predictions, probabilities = self._infer(self.preprocess(texts))
            predicted_labels = self.label_encoder.inverse_transform(predictions)
            if return_probability:
                class_names = list(self.label_encoder.classes_)
                results = self._detailed_results(
                    texts, predicted_labels, probabilities, class_names
                )
                return results[0] if len(texts) == 1 else results
            return predicted_labels[0] if len(texts) == 1 else predicted_labels
        except Exception as e:
            logger.error(f"Error in prediction: {str(e)}")
            raise

    def get_model_info(self):
        """Return a summary dict describing the loaded model and its label set."""
        is_bert = self.method == 'bertbased'
        return {
            'model_type': self.method,
            'model_name': 'BERT' if is_bert else 'Baseline',
            'device': str(self.device) if is_bert else 'CPU',
            'max_sequence_length': self.MAX_LENGTH if is_bert else None,
            'number_of_classes': len(self.label_encoder.classes_),
            'classes': list(self.label_encoder.classes_),
        }
def load_and_process_pdf(url_or_file):
    """
    Load a PDF from a URL or a file and return its extracted text.

    NOTE(review): not implemented yet. The body is a placeholder, so every
    call currently returns None.

    Args:
        url_or_file: Location of the PDF (URL or file; the accepted types are
            not defined yet).

    Returns:
        None until PDF extraction is implemented.
    """
    extracted_text = None
    try:
        # TODO: fetch/open `url_or_file` and extract its text into extracted_text.
        pass
    except Exception as err:
        logger.error(f"Error processing PDF: {str(err)}")
        raise
    return extracted_text
# Example usage: run this module directly to smoke-test the default pipeline
# (loads the models from the default ./models directory).
if __name__ == "__main__":
    classifier = TextClassificationPipeline()

    # A single text yields one result dict.
    single_result = classifier.predict(
        "Example construction document text", return_probability=True
    )
    print("\nSingle Prediction Result:")
    print(single_result)

    # Several texts yield a list of result dicts.
    batch_results = classifier.predict(
        ["First document", "Second document"], return_probability=True
    )
    print("\nBatch Prediction Results:")
    for item in batch_results:
        print(f"\nText: {item['text']}")
        print(f"Prediction: {item['predicted_label']}")
        print(f"Confidence: {item['confidence']:.4f}")