Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -1,44 +1,63 @@
|
|
1 |
import pandas as pd
|
|
|
|
|
|
|
2 |
from transformers import pipeline
|
3 |
import gradio as gr
|
4 |
|
5 |
# Load the injury data
|
6 |
injury_data = pd.read_csv("Injury_History.csv")
|
7 |
|
8 |
-
# Initialize
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
generator = pipeline("text-generation", model="gpt2")
|
10 |
|
11 |
-
# Define the
|
12 |
def injury_query(player_query):
|
13 |
-
#
|
14 |
-
|
15 |
-
player_name = " ".join(words[:2]) # assuming the first two words are the player's name
|
16 |
-
year = next((word for word in words if word.isdigit()), None)
|
17 |
-
|
18 |
-
# Filter the data by player name and year if provided
|
19 |
-
filtered_data = injury_data[injury_data['Name'].str.lower() == player_name]
|
20 |
-
if year:
|
21 |
-
filtered_data = filtered_data[filtered_data['Date'].str.contains(year)]
|
22 |
-
|
23 |
-
if filtered_data.empty:
|
24 |
-
return "No injury records found for this player in the specified timeframe."
|
25 |
|
26 |
-
#
|
27 |
-
injury_details = ". ".join(
|
28 |
-
context = f"
|
29 |
|
30 |
-
# Generate a response
|
31 |
-
response = generator(f"
|
32 |
return response
|
33 |
|
34 |
-
# Set up Gradio interface
|
35 |
interface = gr.Interface(
|
36 |
fn=injury_query,
|
37 |
inputs="text",
|
38 |
outputs="text",
|
39 |
title="NBA Player Injury Q&A",
|
40 |
-
description="Ask about a player's injury history,
|
41 |
)
|
42 |
|
43 |
-
# Launch the app
|
44 |
interface.launch()
|
|
|
1 |
import pandas as pd
|
2 |
+
import torch
|
3 |
+
from sentence_transformers import SentenceTransformer
|
4 |
+
import faiss
|
5 |
from transformers import pipeline
|
6 |
import gradio as gr
|
7 |
|
8 |
# Load the injury data
|
9 |
injury_data = pd.read_csv("Injury_History.csv")
|
10 |
|
11 |
+
# Initialize an embedding model for creating embeddings of the injury descriptions
|
12 |
+
embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
|
13 |
+
|
14 |
+
# Generate embeddings for each injury record in the dataset
|
15 |
+
injury_data['embedding'] = injury_data['Notes'].apply(lambda x: embedding_model.encode(x, convert_to_tensor=True))
|
16 |
+
|
17 |
+
# Convert embeddings to numpy arrays for FAISS
|
18 |
+
embeddings = torch.stack(injury_data['embedding'].to_list()).cpu().numpy()
|
19 |
+
|
20 |
+
# Set up a FAISS index for efficient similarity search
|
21 |
+
index = faiss.IndexFlatL2(embeddings.shape[1])
|
22 |
+
index.add(embeddings)
|
23 |
+
|
24 |
+
# Define a function to retrieve injuries based on similarity to the query
|
25 |
+
def retrieve_injuries(query):
|
26 |
+
# Generate an embedding for the user query
|
27 |
+
query_embedding = embedding_model.encode(query, convert_to_tensor=True).cpu().numpy()
|
28 |
+
|
29 |
+
# Search the FAISS index for the top 3 similar injuries
|
30 |
+
k = 3 # number of results to retrieve
|
31 |
+
distances, indices = index.search(query_embedding.reshape(1, -1), k)
|
32 |
+
|
33 |
+
# Retrieve the most relevant injury records
|
34 |
+
results = injury_data.iloc[indices[0]]
|
35 |
+
return results
|
36 |
+
|
37 |
+
# Initialize a text generation model for generating responses
|
38 |
generator = pipeline("text-generation", model="gpt2")
|
39 |
|
40 |
+
# Define the main function to handle the user query, retrieve relevant injuries, and generate a response
|
41 |
def injury_query(player_query):
|
42 |
+
# Retrieve relevant injury data
|
43 |
+
retrieved_injuries = retrieve_injuries(player_query)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
44 |
|
45 |
+
# Combine injury details into a context string for generation
|
46 |
+
injury_details = ". ".join(retrieved_injuries['Notes'].tolist())
|
47 |
+
context = f"Injury history: {injury_details}"
|
48 |
|
49 |
+
# Generate a response based on the retrieved data
|
50 |
+
response = generator(f"Answer based on data: {context}", max_length=100)[0]['generated_text']
|
51 |
return response
|
52 |
|
53 |
+
# Set up the Gradio interface for the app
|
54 |
interface = gr.Interface(
|
55 |
fn=injury_query,
|
56 |
inputs="text",
|
57 |
outputs="text",
|
58 |
title="NBA Player Injury Q&A",
|
59 |
+
description="Ask about a player's injury history, or inquire about common injuries."
|
60 |
)
|
61 |
|
62 |
+
# Launch the Gradio app
|
63 |
interface.launch()
|