JHigg commited on
Commit
2a24d0c
1 Parent(s): 2520b8c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -22
app.py CHANGED
@@ -1,44 +1,63 @@
1
  import pandas as pd
 
 
 
2
  from transformers import pipeline
3
  import gradio as gr
4
 
5
  # Load the injury data
6
  injury_data = pd.read_csv("Injury_History.csv")
7
 
8
- # Initialize a text generation model (use 'gpt2' or similar for simplicity)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  generator = pipeline("text-generation", model="gpt2")
10
 
11
- # Define the RAG function for injury lookup and generation
12
  def injury_query(player_query):
13
- # Extract player name and optional year from the query
14
- words = player_query.lower().split()
15
- player_name = " ".join(words[:2]) # assuming the first two words are the player's name
16
- year = next((word for word in words if word.isdigit()), None)
17
-
18
- # Filter the data by player name and year if provided
19
- filtered_data = injury_data[injury_data['Name'].str.lower() == player_name]
20
- if year:
21
- filtered_data = filtered_data[filtered_data['Date'].str.contains(year)]
22
-
23
- if filtered_data.empty:
24
- return "No injury records found for this player in the specified timeframe."
25
 
26
- # Concatenate injury records for context
27
- injury_details = ". ".join(filtered_data['Notes'].tolist())
28
- context = f"{player_name.capitalize()}'s injuries: {injury_details}"
29
 
30
- # Generate a response
31
- response = generator(f"Based on available data, here are the injuries: {context}", max_length=100)[0]['generated_text']
32
  return response
33
 
34
- # Set up Gradio interface
35
  interface = gr.Interface(
36
  fn=injury_query,
37
  inputs="text",
38
  outputs="text",
39
  title="NBA Player Injury Q&A",
40
- description="Ask about a player's injury history, e.g., 'What injuries did Jaylen Brown have in 2017?'"
41
  )
42
 
43
- # Launch the app
44
  interface.launch()
 
1
  import pandas as pd
2
+ import torch
3
+ from sentence_transformers import SentenceTransformer
4
+ import faiss
5
  from transformers import pipeline
6
  import gradio as gr
7
 
8
  # Load the injury data
9
  injury_data = pd.read_csv("Injury_History.csv")
10
 
11
+ # Initialize an embedding model for creating embeddings of the injury descriptions
12
+ embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
13
+
14
+ # Generate embeddings for each injury record in the dataset
15
+ injury_data['embedding'] = injury_data['Notes'].apply(lambda x: embedding_model.encode(x, convert_to_tensor=True))
16
+
17
+ # Convert embeddings to numpy arrays for FAISS
18
+ embeddings = torch.stack(injury_data['embedding'].to_list()).cpu().numpy()
19
+
20
+ # Set up a FAISS index for efficient similarity search
21
+ index = faiss.IndexFlatL2(embeddings.shape[1])
22
+ index.add(embeddings)
23
+
24
+ # Define a function to retrieve injuries based on similarity to the query
25
+ def retrieve_injuries(query):
26
+ # Generate an embedding for the user query
27
+ query_embedding = embedding_model.encode(query, convert_to_tensor=True).cpu().numpy()
28
+
29
+ # Search the FAISS index for the top 3 similar injuries
30
+ k = 3 # number of results to retrieve
31
+ distances, indices = index.search(query_embedding.reshape(1, -1), k)
32
+
33
+ # Retrieve the most relevant injury records
34
+ results = injury_data.iloc[indices[0]]
35
+ return results
36
+
37
+ # Initialize a text generation model for generating responses
38
  generator = pipeline("text-generation", model="gpt2")
39
 
40
+ # Define the main function to handle the user query, retrieve relevant injuries, and generate a response
41
  def injury_query(player_query):
42
+ # Retrieve relevant injury data
43
+ retrieved_injuries = retrieve_injuries(player_query)
 
 
 
 
 
 
 
 
 
 
44
 
45
+ # Combine injury details into a context string for generation
46
+ injury_details = ". ".join(retrieved_injuries['Notes'].tolist())
47
+ context = f"Injury history: {injury_details}"
48
 
49
+ # Generate a response based on the retrieved data
50
+ response = generator(f"Answer based on data: {context}", max_length=100)[0]['generated_text']
51
  return response
52
 
53
+ # Set up the Gradio interface for the app
54
  interface = gr.Interface(
55
  fn=injury_query,
56
  inputs="text",
57
  outputs="text",
58
  title="NBA Player Injury Q&A",
59
+ description="Ask about a player's injury history, or inquire about common injuries."
60
  )
61
 
62
+ # Launch the Gradio app
63
  interface.launch()