"""NBA player injury Q&A demo.

Embeds the ``Notes`` column of ``Injury_History.csv`` with a
sentence-transformers model, indexes the vectors with FAISS, and answers user
questions by retrieving the most similar injury notes and feeding them, along
with the question, to a GPT-2 text-generation pipeline behind a Gradio UI.
"""

import pandas as pd
import torch  # noqa: F401  kept for compatibility; embeddings are now produced directly as NumPy
from sentence_transformers import SentenceTransformer
import faiss
from transformers import pipeline
import gradio as gr

# Load the injury data. Missing notes are replaced by "" so every row can be
# embedded; a NaN would make the encoder (and the later string join) fail.
injury_data = pd.read_csv("Injury_History.csv")
notes = injury_data["Notes"].fillna("").astype(str)

# Initialize an embedding model for creating embeddings of the injury descriptions
embedding_model = SentenceTransformer("all-MiniLM-L6-v2")

# Encode all notes in a single batched call (much faster than one encode() per
# row). FAISS needs a float32 matrix of shape (n_rows, dim); taking the
# dimension from the model instead of embeddings.shape[1] keeps an empty CSV
# from raising IndexError.
dimension = embedding_model.get_sentence_embedding_dimension()
embeddings = (
    embedding_model.encode(notes.tolist(), convert_to_numpy=True)
    .astype("float32")
    .reshape(len(notes), dimension)
)
# Keep the per-row embedding column for callers that inspect it.
injury_data["embedding"] = list(embeddings)

# Set up a FAISS index for exact (brute-force) L2 similarity search
index = faiss.IndexFlatL2(dimension)
index.add(embeddings)


def retrieve_injuries(query, k=3):
    """Return the injury records whose notes are most similar to *query*.

    Args:
        query: Free-text question from the user.
        k: Maximum number of records to return (default 3).

    Returns:
        A row slice of ``injury_data`` with at most ``k`` rows, ordered from
        nearest to farthest by L2 distance in embedding space.
    """
    # Never request more neighbours than the index holds. FAISS fills missing
    # slots with the sentinel index -1, which ``iloc`` would silently read as
    # "the last row" and return a spurious record.
    k = min(k, index.ntotal)
    if k <= 0:
        return injury_data.iloc[[]]

    # Generate an embedding for the user query in the same float32 space
    query_embedding = embedding_model.encode(query, convert_to_numpy=True).astype("float32")
    _, indices = index.search(query_embedding.reshape(1, -1), k)

    # Defensive: drop any -1 sentinels FAISS may still return
    hits = [i for i in indices[0] if i >= 0]
    return injury_data.iloc[hits]


# Initialize a text generation model for generating responses
generator = pipeline("text-generation", model="gpt2")


def injury_query(player_query):
    """Answer *player_query* using the most relevant injury notes as context.

    Args:
        player_query: The question typed into the Gradio textbox.

    Returns:
        The text produced by the generator. The pipeline's default output
        includes the prompt followed by the generated continuation.
    """
    if not player_query or not player_query.strip():
        return "Please enter a question about a player's injury history."

    # Retrieve relevant injury data
    retrieved_injuries = retrieve_injuries(player_query)

    # Combine injury details into a context string for generation
    injury_details = ". ".join(
        retrieved_injuries["Notes"].fillna("").astype(str).tolist()
    )
    context = f"Injury history: {injury_details}"

    # Include the user's question; otherwise the model never sees what it is
    # supposed to answer.
    prompt = f"Answer based on data: {context}\nQuestion: {player_query}\nAnswer:"

    # max_new_tokens bounds only the generated continuation. The previous
    # max_length=100 also counted prompt tokens, so a long retrieved context
    # left little or no room for an answer.
    # NOTE(review): GPT-2's context window is 1024 tokens; very long notes could
    # still overflow it. Consider truncating the context if that occurs.
    response = generator(prompt, max_new_tokens=100)[0]["generated_text"]
    return response


# Set up the Gradio interface for the app
interface = gr.Interface(
    fn=injury_query,
    inputs="text",
    outputs="text",
    title="NBA Player Injury Q&A",
    description="Ask about a player's injury history, or inquire about common injuries.",
)

if __name__ == "__main__":
    # Launch the Gradio app only when run as a script, not on import
    interface.launch()