Spaces:
Runtime error
Runtime error
File size: 2,340 Bytes
04d625a 2a24d0c 0d2e65b 04d625a 0d2e65b 04d625a 2a24d0c 0d2e65b a270c67 2a24d0c 0d2e65b 2a24d0c a270c67 2a24d0c a270c67 2a24d0c 0d2e65b 2a24d0c 0d2e65b 2a24d0c 0d2e65b 04d625a 2a24d0c 0d2e65b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 |
import pandas as pd
import torch
from sentence_transformers import SentenceTransformer
import faiss
from transformers import pipeline
import gradio as gr
# Load the injury data
injury_data = pd.read_csv("Injury_History.csv")

# Initialize an embedding model for creating embeddings of the injury descriptions
embedding_model = SentenceTransformer('all-MiniLM-L6-v2')

# Batch-encode every injury note in ONE call: a single encode() over the whole
# column is far faster than one model invocation per row via .apply().
_note_embeddings = embedding_model.encode(
    injury_data['Notes'].tolist(),
    convert_to_tensor=True,
)

# Keep a per-row embedding column for compatibility with any downstream
# consumers of the DataFrame (mirrors the original per-row layout).
injury_data['embedding'] = list(_note_embeddings)

# FAISS requires a contiguous float32 matrix of shape (n_rows, dim).
embeddings = _note_embeddings.cpu().numpy().astype('float32')

# Set up a FAISS index for efficient exact L2 similarity search
index = faiss.IndexFlatL2(embeddings.shape[1])
index.add(embeddings)
# Define a function to retrieve injuries based on similarity to the query
def retrieve_injuries(query, k=3):
    """Return the injury records most similar to a free-text query.

    Args:
        query: Free-text description of the injury/player to look up.
        k: Number of nearest-neighbour records to retrieve (default 3,
           matching the original hard-coded value).

    Returns:
        A DataFrame slice of ``injury_data`` containing the most similar rows.
    """
    # Embed the query with the same model used to build the corpus index.
    query_embedding = embedding_model.encode(query, convert_to_tensor=True).cpu().numpy()
    # FAISS expects a 2-D float32 array of shape (n_queries, dim).
    query_matrix = query_embedding.reshape(1, -1).astype('float32')
    distances, indices = index.search(query_matrix, k)
    # FAISS pads results with -1 when the index holds fewer than k vectors;
    # iloc[-1] would silently return the LAST row, so drop the sentinels.
    valid_ids = [i for i in indices[0] if i >= 0]
    return injury_data.iloc[valid_ids]
# Initialize a text generation model for generating responses.
# gpt2 is a small CPU-friendly baseline; NOTE(review): its context window is
# 1024 tokens, so very long retrieved contexts may need truncation — confirm.
generator = pipeline("text-generation", model="gpt2")
# Define the main function to handle the user query, retrieve relevant injuries, and generate a response
def injury_query(player_query):
    """Answer an injury question via retrieval-augmented generation.

    Args:
        player_query: Free-text question about a player's injury history.

    Returns:
        The model-generated answer text (with the prompt echo removed).
    """
    # Retrieve relevant injury data
    retrieved_injuries = retrieve_injuries(player_query)
    # Combine injury details into a context string for generation
    injury_details = ". ".join(retrieved_injuries['Notes'].tolist())
    context = f"Injury history: {injury_details}"
    prompt = f"Answer based on data: {context}"
    # Use max_new_tokens instead of max_length: max_length counts PROMPT
    # tokens too, so a context near/over 100 tokens would leave no room to
    # generate (or raise). truncation=True guards against prompts exceeding
    # the model's context window.
    generated = generator(prompt, max_new_tokens=100, truncation=True)[0]['generated_text']
    # The pipeline echoes the prompt at the start of generated_text; return
    # only the newly generated continuation to the user.
    if generated.startswith(prompt):
        return generated[len(prompt):].strip()
    return generated
# Set up the Gradio interface for the app: a single text input routed through
# injury_query, with the generated answer shown as plain text output.
interface = gr.Interface(
    fn=injury_query,
    inputs="text",
    outputs="text",
    title="NBA Player Injury Q&A",
    description="Ask about a player's injury history, or inquire about common injuries."
)
# Launch the Gradio app (blocks and serves the web UI).
interface.launch()
|