Florian valade committed
Commit 9d2c30f · 2 Parent(s): a0417ab 92603a4

Merge branch 'main' of hf.co:spaces/valcore/Branchy-phi-2

Files changed (3)
  1. README.md +2 -2
  2. app.py +3 -1
  3. requirements.txt +2 -2
README.md CHANGED
@@ -4,7 +4,7 @@ emoji: ⚡
 colorFrom: purple
 colorTo: pink
 sdk: gradio
-sdk_version: 4.36.1
+sdk_version: 4.37.2
 app_file: app.py
 pinned: false
 ---
@@ -45,4 +45,4 @@ The command will start a local web server and open the application in your defau
 
 - **app.py**: The main Streamlit application script.
 - **requirements.txt**: Lists all the Python dependencies required by the project.
-- **src/**: Contains the source code for the project.
+- **src/**: Contains the source code for the project.
app.py CHANGED
@@ -2,6 +2,7 @@ import gradio as gr
 import torch
 import pandas as pd
 import plotly.graph_objects as go
+import spaces
 from plotly.subplots import make_subplots
 from transformers import AutoModelForCausalLM, AutoTokenizer
 import time
@@ -73,7 +74,8 @@ def truncate_context(input_ids, max_length=2048):
     if len(input_ids[0]) > max_length:
         return input_ids[:, -max_length:]
     return input_ids
-
+
+@spaces.GPU
 def generate_response(message, chat_history, epsilon):
     global data, stop_generation
     data = pd.DataFrame(columns=["Time taken (in ms)", "Early exit depth", "Token"])
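
The `spaces` import and the `@spaces.GPU` decorator added above are the ZeroGPU hooks Hugging Face Spaces use to attach a GPU only while the decorated function runs. A minimal sketch of that pattern follows; the model id, dtype, and generation settings are illustrative assumptions, not code taken from this Space.

```python
# Minimal ZeroGPU sketch (assumptions: the model id, dtype, and generation
# settings are illustrative only and not taken from this Space's app.py).
import spaces  # Hugging Face `spaces` package, provides the spaces.GPU decorator
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_ID = "microsoft/phi-2"  # assumption: any causal LM follows the same pattern

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype=torch.float16)
model.to("cuda")  # ZeroGPU allows this at startup; the GPU itself is attached per call


@spaces.GPU  # a GPU is held only while this function executes
def generate(prompt: str) -> str:
    inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
    output_ids = model.generate(**inputs, max_new_tokens=64)
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)
```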
requirements.txt CHANGED
@@ -1,5 +1,5 @@
-gradio==4.32.2
-torch==2.3.1
+gradio==4.37.2
+torch==2.2.0
 pandas==2.0.3
 transformers==4.41.1
 plotly==5.22.0