qanastek committed on
Commit
2a493e6
·
1 Parent(s): bdba6de

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -4
app.py CHANGED
@@ -31,17 +31,20 @@ elif context == "New Text":
31
  def setModel(model_checkpoint, aggregation):
32
  model = AutoModelForTokenClassification.from_pretrained(model_checkpoint)
33
  tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
34
- return pipeline('token-classification', model=model, tokenizer=tokenizer, aggregation_strategy=aggregation)
35
 
36
  Run_Button = st.button("Run", key=None)
37
  if Run_Button == True:
38
 
39
- ner_pipeline = setModel(model_checkpoint, aggregation)
40
  output = ner_pipeline(input_text)
41
 
42
- print(output)
 
 
 
43
 
44
- df = pd.DataFrame.from_dict(output)
45
 
46
  if aggregation != "none":
47
  df.rename(index=str,columns={'entity_group':'POS Tag'},inplace=True)
 
31
  def setModel(model_checkpoint, aggregation):
32
  model = AutoModelForTokenClassification.from_pretrained(model_checkpoint)
33
  tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
34
+ return pipeline('token-classification', model=model, tokenizer=tokenizer, aggregation_strategy=aggregation), model.config.id2label
35
 
36
  Run_Button = st.button("Run", key=None)
37
  if Run_Button == True:
38
 
39
+ ner_pipeline, id2label = setModel(model_checkpoint, aggregation)
40
  output = ner_pipeline(input_text)
41
 
42
+ output_new = []
43
+ for o in output:
44
+ o["entity_group"] = id2label[o["entity_group"].split("_")[-1]]
45
+ output_new.append(o)
46
 
47
+ df = pd.DataFrame.from_dict(output_new)
48
 
49
  if aggregation != "none":
50
  df.rename(index=str,columns={'entity_group':'POS Tag'},inplace=True)