# Spaces:
# Sleeping
# Sleeping
import logging
import os

import gradio as gr
import numpy as np
import pandas as pd
import torch
#from google.colab import files
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from transformers import AutoTokenizer, AutoModelForQuestionAnswering, pipeline
class EnhancedHRAssistantModel:
    """Extractive question-answering helper for HR policy text.

    Wraps a Hugging Face ``question-answering`` pipeline, gates answers on
    the pipeline's confidence score, and offers a simple keyword-based
    extractor for highlighting relevant sentences of the supplied context.
    """

    # Answers are only shown when the pipeline score (as a percentage)
    # is strictly greater than this value.
    CONFIDENCE_THRESHOLD = 50

    # Substrings (matched case-insensitively) that mark a sentence as a
    # potential key point.
    KEY_PHRASES = (
        'benefits', 'coverage', 'policy', 'options',
        'available', 'include', 'provide', 'offer',
    )

    # Upper bound on the number of key points returned.
    MAX_KEY_POINTS = 3

    def __init__(self, model_name='bert-large-uncased-whole-word-masking-finetuned-squad'):
        """Load the tokenizer, model and QA pipeline for ``model_name``.

        Args:
            model_name: Hugging Face model identifier of a model fine-tuned
                for extractive question answering.
        """
        # Use GPU if available
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForQuestionAnswering.from_pretrained(model_name).to(self.device)
        self.qa_pipeline = pipeline(
            'question-answering',
            model=self.model,
            tokenizer=self.tokenizer,
            # The pipeline takes a CUDA device index, or -1 for CPU.
            device=0 if torch.cuda.is_available() else -1
        )

    def generate_response(self, question, context):
        """Answer ``question`` from ``context`` and format it for display.

        Args:
            question: The user's question.
            context: The policy text to search for an answer.

        Returns:
            The answer with its confidence percentage when the score is
            above ``CONFIDENCE_THRESHOLD``; otherwise a fallback message.
            If either input is empty or only whitespace, a prompt asking
            for both inputs. If the pipeline raises, an error string
            containing the exception message.
        """
        # The pipeline rejects empty inputs with an exception; answer with a
        # clear prompt instead of surfacing that internal error to the user.
        if not (question and question.strip()) or not (context and context.strip()):
            return "Please provide both a question and the policy context."

        try:
            result = self.qa_pipeline({
                'question': question,
                'context': context
            })
            confidence = result.get('score', 0) * 100
            answer = result.get('answer', 'I could not find a specific answer.')
            if confidence > self.CONFIDENCE_THRESHOLD:
                return f"{answer}\n\n(Confidence: {confidence:.2f}%)"
            return "I'm not certain about the exact details. Please consult with HR directly."
        except Exception as e:
            # UI boundary: never crash the interface, but keep the traceback
            # in the logs so the failure is not silently lost.
            logging.getLogger(__name__).exception("Question-answering pipeline failed")
            return f"Error generating response: {str(e)}"

    def extract_key_information(self, context):
        """Extract key points from the context for additional insights.

        Splits ``context`` on ``'.'`` and keeps the sentences that contain
        any of ``KEY_PHRASES`` (case-insensitive substring match).

        Args:
            context: The policy text; ``None`` or an empty string yields
                an empty list.

        Returns:
            Up to ``MAX_KEY_POINTS`` matching sentences, stripped of
            surrounding whitespace, in their original order.
        """
        if not context:
            return []

        matches = [
            sentence.strip()
            for sentence in context.split('.')
            if any(phrase in sentence.lower() for phrase in self.KEY_PHRASES)
        ]
        return matches[:self.MAX_KEY_POINTS]
# Gradio Interface
class HRAssistantInterface:
    """Gradio front end that exposes an HR assistant model as a web form.

    The wrapped ``model`` must provide ``generate_response(question, context)``
    and ``extract_key_information(context)``.
    """

    def __init__(self, model):
        """Store the model used to answer queries."""
        self.model = model

    @staticmethod
    def _format_response(primary_response, additional_info):
        """Combine the primary answer with optional key points.

        Args:
            primary_response: The answer text to show first.
            additional_info: Iterable of key-point strings; may be empty.

        Returns:
            ``primary_response`` alone when there are no key points,
            otherwise the answer followed by an "Additional Context:"
            section with one bullet per point.
        """
        # Skip the section entirely when nothing was extracted, so the
        # output never ends with a dangling, empty header.
        if not additional_info:
            return primary_response
        bullets = "\n".join(f"• {point}" for point in additional_info)
        return f"{primary_response}\n\nAdditional Context:\n{bullets}"

    def create_interface(self):
        """Build and return the ``gr.Interface`` for the assistant.

        Returns:
            A Gradio interface with two text inputs (question and policy
            context) and one text output.
        """
        def comprehensive_query_handler(question, context):
            # Primary answer from the QA model
            primary_response = self.model.generate_response(question, context)
            # Supporting sentences pulled from the same context
            additional_info = self.model.extract_key_information(context)
            return self._format_response(primary_response, additional_info)

        iface = gr.Interface(
            fn=comprehensive_query_handler,
            inputs=[
                gr.Textbox(label="HR Question"),
                gr.Textbox(label="Full Policy Context", lines=5)
            ],
            outputs=gr.Textbox(label="Comprehensive HR Assistant Response"),
            title="AJ: Advanced HR Policy Assistant",
            description="Get detailed insights into your HR policies"
        )
        return iface
# Example Usage
def main():
    """Build the HR assistant model and its Gradio UI, then launch it.

    The interface is launched with ``share=True``.
    """
    # Model first, then the UI wrapper around it
    assistant_ui = HRAssistantInterface(EnhancedHRAssistantModel())
    # Build the Gradio app and start serving
    assistant_ui.create_interface().launch(share=True)


if __name__ == "__main__":
    main()