# Page-header residue from the hosting site ("Spaces: Sleeping"); not part of the program.
import pandas as pd
import torch
from sentence_transformers import SentenceTransformer
import faiss
from transformers import pipeline
import gradio as gr
# Load the injury records. The free-text "Notes" column is what gets embedded
# and searched.
injury_data = pd.read_csv("Injury_History.csv")

# Sentence-embedding model shared by the index build below and by query
# encoding at request time.
embedding_model = SentenceTransformer('all-MiniLM-L6-v2')

# Encode every note in a single batched call instead of one model call per
# row. Missing notes are encoded as empty strings so that row positions in the
# index stay aligned with row positions in `injury_data`.
note_texts = injury_data['Notes'].fillna("").astype(str).tolist()
embeddings = embedding_model.encode(note_texts, convert_to_numpy=True)

# Keep the per-row vectors on the frame as well (stored as NumPy rows).
injury_data['embedding'] = list(embeddings)

# Flat L2 index over the note embeddings, used for nearest-neighbour lookup.
index = faiss.IndexFlatL2(embeddings.shape[1])
index.add(embeddings)
def retrieve_injuries(query, k=3):
    """Return the injury records whose notes are most similar to ``query``.

    Args:
        query: Free-text question or description to match against the notes.
        k: Maximum number of records to return. Defaults to 3.

    Returns:
        A DataFrame with up to ``k`` rows of ``injury_data``, in the order
        FAISS reports them. Fewer rows come back when the index holds fewer
        than ``k`` vectors; the frame is empty when the index is empty or
        ``k`` is not positive.
    """
    # Never ask FAISS for more neighbours than it holds: unfilled result slots
    # come back as index -1, which ``iloc`` would silently read as "last row".
    k = min(k, index.ntotal)
    if k <= 0:
        return injury_data.iloc[0:0]

    # Encode straight to NumPy and present the single query as a one-row
    # matrix, which is the shape ``index.search`` takes.
    query_embedding = embedding_model.encode(query, convert_to_numpy=True)
    _, indices = index.search(query_embedding.reshape(1, -1), k)

    # Defensive: drop any -1 padding that may still be present.
    hits = [int(i) for i in indices[0] if i >= 0]
    return injury_data.iloc[hits]
# GPT-2 text-generation pipeline; injury_query feeds it the prompt built from
# the retrieved notes.
generator = pipeline(task="text-generation", model="gpt2")
def injury_query(player_query):
    """Answer a free-text injury question using the most similar records.

    Retrieves the closest injury notes for the question, builds a prompt from
    them plus the question itself, and runs the text generator on it.

    Args:
        player_query: The question typed by the user.

    Returns:
        The generator's ``generated_text`` for the prompt, or a short hint
        when the question is blank.
    """
    if not player_query or not player_query.strip():
        return "Please enter a question about a player's injury history."

    retrieved_injuries = retrieve_injuries(player_query)

    # Skip missing notes: a NaN entry is a float and would make join() raise.
    notes = retrieved_injuries['Notes'].dropna().astype(str).tolist()
    injury_details = ". ".join(notes)
    context = f"Injury history: {injury_details}"

    # Include the question in the prompt; otherwise the generator only ever
    # sees the retrieved notes and has nothing to answer.
    prompt = (
        f"Answer based on data: {context}\n"
        f"Question: {player_query}\n"
        "Answer:"
    )

    # max_new_tokens budgets only the continuation. max_length also counts the
    # prompt tokens, so a prompt near or above 100 tokens left little or no
    # room for any generated text.
    response = generator(prompt, max_new_tokens=100)[0]['generated_text']
    return response
# Gradio front end: a single text input routed through injury_query, with the
# returned string shown as text.
interface = gr.Interface(
    title="NBA Player Injury Q&A",
    description="Ask about a player's injury history, or inquire about common injuries.",
    fn=injury_query,
    inputs="text",
    outputs="text",
)

# Start serving the app.
interface.launch()