kataniccc commited on
Commit
060541a
1 Parent(s): c480dfd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -16
app.py CHANGED
@@ -1,34 +1,30 @@
 
 
1
  import gradio as gr
2
  import pandas as pd
 
3
 
4
- from transformers import pipeline
5
-
6
-
7
- from datasets import Dataset,DatasetDict
8
 
 
 
9
  classifier = pipeline("text-classification", model=model_nm)
10
 
11
-
12
- df = pd.read_csv(path/'train.csv')
13
-
14
  df.describe(include='object')
15
-
16
  df['input'] = 'TEXT1: ' + df.context + '; TEXT2: ' + df.target + '; ANC1: ' + df.anchor
17
-
18
- df.input.head()
19
-
20
-
21
  ds = Dataset.from_pandas(df)
22
 
 
23
  def predict_text(input_text):
24
  prediction = classifier(input_text)
25
  return prediction
26
 
27
-
28
-
29
  text_input = gr.inputs.Textbox(lines=7, label="Unesite tekst")
30
  output_text = gr.outputs.Textbox(label="Predikcija")
31
-
32
  gr.Interface(predict_text, inputs=text_input, outputs=output_text).launch()
33
 
34
-
 
1
+ import torch
2
+ from transformers import pipeline
3
  import gradio as gr
4
  import pandas as pd
5
+ from datasets import Dataset
6
 
7
+ # Enable SafeTensors if available
8
+ if torch.__version__ >= "1.10":
9
+ torch.set_safety_enabled(True)
 
10
 
11
+ # Load the model
12
+ model_nm = 'microsoft/deberta-v3-small'
13
  classifier = pipeline("text-classification", model=model_nm)
14
 
15
+ # Read and preprocess data
16
+ df = pd.read_csv("path/to/train.csv") # Replace "path/to/train.csv" with the actual path
 
17
  df.describe(include='object')
 
18
  df['input'] = 'TEXT1: ' + df.context + '; TEXT2: ' + df.target + '; ANC1: ' + df.anchor
 
 
 
 
19
  ds = Dataset.from_pandas(df)
20
 
21
+ # Define prediction function
22
  def predict_text(input_text):
23
  prediction = classifier(input_text)
24
  return prediction
25
 
26
+ # Define Gradio interface
 
27
  text_input = gr.inputs.Textbox(lines=7, label="Unesite tekst")
28
  output_text = gr.outputs.Textbox(label="Predikcija")
 
29
  gr.Interface(predict_text, inputs=text_input, outputs=output_text).launch()
30