zionia committed on
Commit
1dd5bbf
1 Parent(s): 4ecebd0

update interface to be consistent and allow file uploads

Browse files
Files changed (1) hide show
  1. app.py +41 -16
app.py CHANGED
@@ -1,5 +1,6 @@
1
  import gradio as gr
2
  from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline
 
3
 
4
  MODEL_URL = "https://huggingface.co/dsfsi/PuoBERTa-News"
5
  WEBSITE_URL = "https://www.kodiks.com/ai_solutions.html"
@@ -22,31 +23,55 @@ categories = {
22
 
23
  def prediction(news):
24
  clasifer = pipeline("sentiment-analysis", tokenizer=tokenizer, model=model, return_all_scores=True)
25
-
26
  preds = clasifer(news)
27
-
28
- preds_dict = {}
29
- for pred in preds[0]:
30
- label = categories.get(pred['label'], pred['label'])
31
- preds_dict[label] = pred['score']
32
-
33
  return preds_dict
34
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  gradio_ui = gr.Interface(
36
  fn=prediction,
37
  title="Setswana News Classification",
38
  description=f"Enter Setswana news article to see the category of the news.\n For this classification, the {MODEL_URL} model was used.",
39
- examples=[
40
- ['Ka Letsatsi la Aforika, Aforika Borwa e tla be e keteka mabaka a boikemelo, le diketso tse di siameng tse e di dirileng go tokafatsa dikamano tsa yona le dinaga tse dingwe tsa Aforika.'],
41
- ["Thuto ya Setswana ke nngwe ya dithuto tse di botlhokwa mo sekolong se se tlhamaletseng go ruta bana ba ba mo lefatsheng la Botswana."],
42
- ["Mo kgweding e e fetileng, dipuisano tsa ditheko tsa dijalo di ile tsa tswelela, ka batho ba rekang le barui ba ba ruileng."],
43
- ["Masole a Aforika Borwa a ne a ya kwa Mozambique go tlisetsa motlakase morago ga maduo a kgatlha."],
44
- ],
45
  inputs=gr.Textbox(lines=10, label="Paste some Setswana news here"),
46
  outputs=gr.Label(num_top_classes=5, label="News categories probabilities"),
47
- theme="huggingface",
48
- article="<p style='text-align: center'>For our other AI works: <a href='https://www.kodiks.com/ai_solutions.html' target='_blank'>https://www.kodiks.com/ai_solutions.html</a> | <a href='https://twitter.com/KodiksBilisim' target='_blank'>Contact us</a></p>",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
  )
50
 
51
- gradio_ui.launch()
52
 
 
 
1
  import gradio as gr
2
  from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline
3
+ import pandas as pd
4
 
5
  MODEL_URL = "https://huggingface.co/dsfsi/PuoBERTa-News"
6
  WEBSITE_URL = "https://www.kodiks.com/ai_solutions.html"
 
23
 
_CLASSIFIER = None  # classification pipeline, created on first use and then reused


def _get_classifier():
    """Return the shared classification pipeline, building it on first use."""
    global _CLASSIFIER
    if _CLASSIFIER is None:
        # Constructing a pipeline is costly, so do it once instead of once per
        # request (and once per row when a whole CSV is classified).
        _CLASSIFIER = pipeline(
            "sentiment-analysis",
            tokenizer=tokenizer,
            model=model,
            return_all_scores=True,
        )
    return _CLASSIFIER


def prediction(news):
    """Score one news text against every category.

    Args:
        news: The article text to classify.

    Returns:
        A dict mapping each category name to its score. Raw model labels are
        translated through ``categories``; a label missing from that mapping
        is kept unchanged.
    """
    preds = _get_classifier()(news)
    # The per-label results for the submitted text are read from preds[0].
    return {
        categories.get(pred['label'], pred['label']): pred['score']
        for pred in preds[0]
    }
29
 
def file_prediction(file):
    """Classify every article found in an uploaded text or CSV file.

    Args:
        file: The upload handed over by the ``gr.File`` input. Depending on
            the Gradio version this is presumably either an object exposing
            the stored path as ``.name`` or the path itself; both are
            accepted. ``None`` (nothing uploaded) yields no rows.

    Returns:
        A list of ``[news_text, predictions]`` rows, one per article, matching
        the two columns of the results table. ``predictions`` lists every
        category with its score, highest first.
    """
    if file is None:
        return []

    path = str(getattr(file, "name", file))

    if path.lower().endswith(".csv"):
        # Articles live in the first column; empty cells are dropped so the
        # model never receives NaN.
        first_column = pd.read_csv(path).iloc[:, 0]
        news_list = first_column.dropna().astype(str).tolist()
    else:
        # Read the stored upload from disk rather than relying on the state
        # of the handle that was passed in.
        with open(path, encoding="utf-8") as handle:
            news_list = [handle.read()]

    rows = []
    for news in news_list:
        if not news.strip():
            continue  # nothing to classify
        scores = prediction(news)
        ranked = sorted(scores.items(), key=lambda item: item[1], reverse=True)
        summary = ", ".join(f"{label}: {score:.4f}" for label, score in ranked)
        rows.append([news, summary])
    return rows
42
+
# Single-article tab: pasted text goes to ``prediction`` and the returned
# scores are shown as a label widget limited to the top five classes.
gradio_ui = gr.Interface(
    fn=prediction,
    inputs=gr.Textbox(lines=10, label="Paste some Setswana news here"),
    outputs=gr.Label(num_top_classes=5, label="News categories probabilities"),
    title="Setswana News Classification",
    description=f"Enter Setswana news article to see the category of the news.\n For this classification, the {MODEL_URL} model was used.",
    theme="default",
    # White background with black text, and light grey buttons.
    css="""
body {
background-color: white !important;
color: black !important;
}
.gradio-container {
background-color: white !important;
color: black !important;
}
.gr-button {
background-color: #f0f0f0 !important;
color: black !important;
}
""",
)
65
+
# File-upload tab: the uploaded file goes to ``file_prediction`` and the
# result is displayed as a two-column table.
gradio_file_ui = gr.Interface(
    fn=file_prediction,
    title="Upload File for Setswana News Classification",
    # Plain literal: the text contains no placeholders, so no f-prefix.
    description="Upload a text or CSV file with Setswana news articles. The first column in the CSV should contain the news text.",
    inputs=gr.File(label="Upload text or CSV file"),
    outputs=gr.Dataframe(
        headers=["News Text", "Category Predictions"],
        label="Predictions from file",
    ),
    theme="default",
)
74
 
# Serve both interfaces as tabs of a single app.
# NOTE(review): TabbedInterface takes its own ``theme``/``css``; styling set
# only on the nested interfaces does not appear to carry over to the combined
# page, so it is passed here as well. Confirm against the installed Gradio.
gradio_combined_ui = gr.TabbedInterface(
    [gradio_ui, gradio_file_ui],
    ["Text Input", "File Upload"],
    theme="default",
    # Reuse the text tab's stylesheet when the interface exposes one.
    css=getattr(gradio_ui, "css", None),
)

gradio_combined_ui.launch()