File size: 2,340 Bytes
04d625a
2a24d0c
 
 
0d2e65b
 
04d625a
0d2e65b
 
04d625a
2a24d0c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0d2e65b
a270c67
2a24d0c
0d2e65b
2a24d0c
 
a270c67
2a24d0c
 
 
a270c67
2a24d0c
 
0d2e65b
 
2a24d0c
0d2e65b
 
 
 
 
2a24d0c
0d2e65b
04d625a
2a24d0c
0d2e65b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import pandas as pd
import torch
from sentence_transformers import SentenceTransformer
import faiss
from transformers import pipeline
import gradio as gr

# Load the injury data; the free-text 'Notes' column is what gets embedded and searched
injury_data = pd.read_csv("Injury_History.csv")

# Initialize an embedding model for creating embeddings of the injury descriptions
embedding_model = SentenceTransformer('all-MiniLM-L6-v2')

# Encode every note in a single batched call rather than one model call per row.
# Missing notes (NaN) are encoded as empty strings so that FAISS row ids stay
# aligned with the DataFrame's positional index; the 'Notes' column itself is
# left untouched.
_note_texts = injury_data['Notes'].fillna("").astype(str).tolist()
_note_embeddings = embedding_model.encode(_note_texts, convert_to_tensor=True)

# Keep one embedding tensor per record on the DataFrame, as before
injury_data['embedding'] = list(_note_embeddings)

# Convert embeddings to a float32 numpy matrix, the dtype FAISS indexes work with
embeddings = _note_embeddings.cpu().numpy().astype("float32", copy=False)

# Set up a FAISS index (exact L2 search) over the note embeddings
index = faiss.IndexFlatL2(embeddings.shape[1])
index.add(embeddings)

# Retrieve the injury records whose notes are most similar to the query
def retrieve_injuries(query, k=3):
    """Return the injury records most similar to *query*.

    Args:
        query: Free-text question or description to match against the
            indexed injury notes.
        k: Maximum number of records to return. Defaults to 3.

    Returns:
        A slice of ``injury_data`` with at most ``k`` rows, in the order the
        FAISS search returned them. Empty when the index holds no vectors
        or ``k`` is not positive.
    """
    # Never ask for more neighbours than the index actually holds
    k = min(k, index.ntotal)
    if k <= 0:
        return injury_data.iloc[[]]

    # Generate an embedding for the user query, shaped (1, dim) for FAISS
    query_embedding = embedding_model.encode(query, convert_to_tensor=True).cpu().numpy()
    query_matrix = query_embedding.reshape(1, -1).astype("float32", copy=False)

    # Distances are not needed; only the row ids are used
    _, indices = index.search(query_matrix, k)

    # FAISS pads missing results with -1; handing that to .iloc would
    # silently select the last row instead of "no result"
    hits = [int(i) for i in indices[0] if i >= 0]
    return injury_data.iloc[hits]

# Text-generation pipeline (GPT-2 checkpoint) used to phrase the final response
generator = pipeline(task="text-generation", model="gpt2")

# Handle a user query end to end: retrieve relevant injuries, then generate a response
def injury_query(player_query):
    """Answer a free-text injury question using retrieved injury notes.

    Args:
        player_query: The question typed by the user.

    Returns:
        The ``generated_text`` of the first result produced by the
        text-generation pipeline, or a short prompt message when the
        query is blank.
    """
    # A blank query gives the retriever nothing meaningful to match against
    if not player_query or not player_query.strip():
        return "Please enter a question about a player's injury history."

    # Retrieve relevant injury data
    retrieved_injuries = retrieve_injuries(player_query)

    # Combine injury details into a context string for generation; records
    # with a missing note are skipped so str.join never sees a NaN float
    notes = retrieved_injuries['Notes'].dropna().astype(str).tolist()
    injury_details = ". ".join(notes)
    context = f"Injury history: {injury_details}"

    # Include the question itself: previously it only steered retrieval and
    # never reached the generator
    prompt = f"Answer based on data: {context}\nQuestion: {player_query}\nAnswer:"

    # max_new_tokens bounds only the continuation; max_length also counted
    # the prompt tokens, so a long context left little or no room to generate
    response = generator(prompt, max_new_tokens=100)[0]['generated_text']
    return response

# Set up the Gradio interface for the app: one text box in, one text box out,
# with injury_query as the handler
interface = gr.Interface(
    fn=injury_query,
    inputs="text",
    outputs="text",
    title="NBA Player Injury Q&A",
    description="Ask about a player's injury history, or inquire about common injuries.",
)

# Launch the Gradio app only when this file is run as a script, so importing
# the module (e.g. to reuse injury_query) does not start a web server
if __name__ == "__main__":
    interface.launch()