JHigg committed on
Commit
0d2e65b
·
verified ·
1 Parent(s): 532a423

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -28
app.py CHANGED
@@ -1,37 +1,44 @@
1
- from transformers import pipeline, AutoTokenizer, AutoModel
2
  import pandas as pd
3
- from sentence_transformers import SentenceTransformer, util
 
4
 
5
# Load the injury history dataset; code below reads the 'Notes' column
# (free-text injury descriptions). NOTE(review): path is relative to the
# working directory — confirm the CSV ships alongside the app.
data = pd.read_csv("Injury_History.csv")

# Sentence-embedding model used for semantic retrieval over the notes.
embedding_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')

# Small causal LM for response generation (gpt2 keeps the demo lightweight).
generator = pipeline("text-generation", model="gpt2") # Lightweight option

# Step 1: pre-compute one embedding tensor per row of injury notes so each
# query only needs a single encode + cosine-similarity pass at request time.
data['embedding'] = data['Notes'].apply(lambda x: embedding_model.encode(x, convert_to_tensor=True))
17
-
18
# Define RAG function
def RAG_injury_info(player_query):
    """Retrieve the injury notes most similar to *player_query* and
    generate a free-text answer from them.

    NOTE(review): writes a 'similarity' column into the module-level
    DataFrame `data` on every call — a global side effect.
    """
    # Step 2: embed the user's query with the same model used for the notes.
    query_embedding = embedding_model.encode(player_query, convert_to_tensor=True)

    # Step 3: cosine similarity between the query and every pre-computed
    # note embedding, then keep the three closest notes.
    data['similarity'] = data['embedding'].apply(lambda x: util.cos_sim(query_embedding, x).item())
    top_injuries = data.sort_values(by='similarity', ascending=False).head(3) # Get top matches

    # Step 4: join the best-matching notes into a single context string.
    context = ". ".join(top_injuries['Notes'].values)

    # Step 5: ask the language model to continue from the retrieved context.
    generated_response = generator(f"Based on the injury history: {context}", max_length=100)[0]["generated_text"]

    return generated_response
 
 
 
 
 
 
 
 
 
 
 
34
 
35
# Demo: run one retrieval-augmented lookup and print the generated answer.
print(RAG_injury_info("Tell me about Jaylen Brown's injuries."))
 
 
1
  import pandas as pd
2
+ from transformers import pipeline
3
+ import gradio as gr
4
 
5
# Load the injury data. Code below reads the 'Name', 'Date', and 'Notes'
# columns; the CSV path is relative to the working directory.
injury_data = pd.read_csv("Injury_History.csv")

# Initialize a text generation model (use 'gpt2' or similar for simplicity).
# Loaded once at import time so every request reuses the same pipeline.
generator = pipeline("text-generation", model="gpt2")
10
 
11
# Define the RAG function for injury lookup and generation
def injury_query(player_query):
    """Answer a natural-language question about a player's injury history.

    Finds the player by matching names from injury_data's 'Name' column
    against the query text, optionally narrows to a year mentioned in the
    query, then asks the text-generation model to phrase the matching
    injury notes as a response.

    Parameters
    ----------
    player_query : str
        Free-text question, e.g. "What injuries did Jaylen Brown have in 2017?".

    Returns
    -------
    str
        Generated text, or a fixed not-found message when nothing matches.
    """
    # Normalize: lowercase and strip surrounding punctuation so tokens like
    # "Brown's" or "2017?" still match names and years.
    words = [w.strip(".,!?;:'\"()") for w in player_query.lower().split()]
    normalized_query = " ".join(words)

    # A standalone all-digit token (e.g. "2017") is treated as the year filter.
    year = next((w for w in words if w.isdigit()), None)

    # Match the query against names actually present in the data instead of
    # assuming the first two words are the player's name — that assumption
    # breaks the app's own example query "What injuries did Jaylen Brown
    # have in 2017?" (first two words: "what injuries").
    player_name = next(
        (name for name in injury_data['Name'].dropna().unique()
         if str(name).lower() in normalized_query),
        None,
    )
    if player_name is None:
        return "No injury records found for this player in the specified timeframe."

    # Filter rows for that player, optionally narrowing to the requested year.
    filtered_data = injury_data[injury_data['Name'] == player_name]
    if year:
        # astype(str) guards against a non-string Date dtype; regex=False
        # because the year is a literal substring, not a pattern.
        filtered_data = filtered_data[
            filtered_data['Date'].astype(str).str.contains(year, regex=False)
        ]

    if filtered_data.empty:
        return "No injury records found for this player in the specified timeframe."

    # Concatenate the matching notes as context for the generator. Using the
    # dataset's own name casing fixes the old `.capitalize()` bug that
    # lowercased the surname ("Jaylen brown").
    injury_details = ". ".join(filtered_data['Notes'].astype(str).tolist())
    context = f"{player_name}'s injuries: {injury_details}"

    # Generate a conversational response grounded in the retrieved notes.
    response = generator(
        f"Based on available data, here are the injuries: {context}",
        max_length=100,
    )[0]['generated_text']
    return response
33
+
34
# Set up Gradio interface: a single free-text box in, generated text out,
# backed by injury_query.
interface = gr.Interface(
    fn=injury_query,
    inputs="text",
    outputs="text",
    title="NBA Player Injury Q&A",
    description="Ask about a player's injury history, e.g., 'What injuries did Jaylen Brown have in 2017?'"
)

# Launch the app (serves the UI; blocks until the process exits).
interface.launch()