JHigg committed on
Commit
a270c67
·
verified ·
1 Parent(s): 2a57c8f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -21
app.py CHANGED
@@ -1,27 +1,36 @@
 
1
  import pandas as pd
2
- import gradio as gr
3
 
4
- # Load the injury data
5
  data = pd.read_csv("Injury_History.csv")
6
 
7
- # Define a function to look up injuries by player name
8
- def lookup_injuries(player_name):
9
- # Filter injury data for the given player name
10
- injuries = data[data['Name'].str.lower() == player_name.lower()]
11
- if injuries.empty:
12
- return f"No injury records found for {player_name}."
13
-
14
- # Format the output to show only Date and Notes columns
15
- return injuries[['Date', 'Notes']].to_string(index=False)
16
 
17
- # Set up the Gradio interface
18
- interface = gr.Interface(
19
- fn=lookup_injuries,
20
- inputs="text",
21
- outputs="text",
22
- title="NBA Player Injury Lookup",
23
- description="Enter a player's name to see their injury history."
24
- )
 
 
 
 
 
 
 
 
 
 
 
25
 
26
- # Launch the app
27
- interface.launch()
 
 
1
+ from transformers import pipeline, AutoTokenizer, AutoModel
2
  import pandas as pd
3
+ from sentence_transformers import SentenceTransformer, util
4
 
5
+ # Load data
6
  data = pd.read_csv("Injury_History.csv")
7
 
8
+ # Load embeddings model for retrieval
9
+ embedding_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
10
+
11
+ # Load language model for generation
12
+ generator = pipeline("text-generation", model="gpt-3.5-turbo") # Example model
 
 
 
 
13
 
14
+ # Step 1: Create embeddings of injury data
15
+ data['embedding'] = data['Notes'].apply(lambda x: embedding_model.encode(x, convert_to_tensor=True))
16
+
17
+ # Define RAG function
18
+ def RAG_injury_info(player_query):
19
+ # Step 2: Create embedding for the user query
20
+ query_embedding = embedding_model.encode(player_query, convert_to_tensor=True)
21
+
22
+ # Step 3: Compute cosine similarities
23
+ data['similarity'] = data['embedding'].apply(lambda x: util.cos_sim(query_embedding, x).item())
24
+ top_injuries = data.sort_values(by='similarity', ascending=False).head(3) # Get top matches
25
+
26
+ # Step 4: Prepare context for generation
27
+ context = ". ".join(top_injuries['Notes'].values)
28
+
29
+ # Step 5: Generate response
30
+ generated_response = generator(f"Based on the injury history: {context}", max_length=100)[0]["generated_text"]
31
+
32
+ return generated_response
33
 
34
+ # Example usage
35
+ player_query = "Tell me about Jaylen Brown's injuries."
36
+ print(RAG_injury_info(player_query))