"""Gradio app that classifies short texts into news categories with a fine-tuned BERT model."""
import json
import logging

import gradio as gr
import numpy as np  # noqa: F401  (imported by the original script; currently unused)
import requests  # noqa: F401  (only used by the disabled API-posting helper below)
import tensorflow as tf
from transformers import BertTokenizer, TFBertForSequenceClassification

logger = logging.getLogger(__name__)

# Fine-tuned tokenizer and classifier, loaded from local checkpoint directories at import time.
bert_tokenizer = BertTokenizer.from_pretrained('MultiTokenizer_ep10')
bert_model = TFBertForSequenceClassification.from_pretrained('MultiModel_ep10')

# Token budget per text (special tokens included); longer texts are truncated.
MAX_LENGTH = 128

# Number of texts run through the model per forward pass; bounds peak memory
# when a request contains many texts.
BATCH_SIZE = 32

# Index of the argmax logit -> category name.
# NOTE(review): assumed to match the label order used when the model was trained — confirm.
LABELS = {
    0: 'BUSINESS',
    1: 'COMEDY',
    2: 'CRIME',
    3: 'FOOD & DRINK',
    4: 'POLITICS',
    5: 'SPORTS',
    6: 'TRAVEL',
}

# Disabled: posting results back to a caller-supplied API endpoint
# (previously driven by the "api" and "job_id" request fields).
# def send_results_to_api(data, result_url):
#     headers = {'Content-Type': 'application/json'}
#     response = requests.post(result_url, json=data, headers=headers)
#     if response.status_code == 200:
#         return response.json()
#     return {'error': f"failed to send result to API: {response.status_code}"}


def _classify_batch(batch):
    """Return the predicted label index for each text in *batch* (a list of str)."""
    encoding = bert_tokenizer(
        batch,
        add_special_tokens=True,
        max_length=MAX_LENGTH,
        padding='max_length',
        truncation=True,
        return_token_type_ids=True,
        return_attention_mask=True,
        return_tensors='tf',
    )
    # Pass inputs by name: a positional list is interpreted in the order
    # (input_ids, attention_mask, token_type_ids), so naming them avoids
    # silently swapping the attention mask and segment ids.
    outputs = bert_model(
        {
            'input_ids': encoding['input_ids'],
            'attention_mask': encoding['attention_mask'],
            'token_type_ids': encoding['token_type_ids'],
        },
        training=False,
    )
    return [int(idx) for idx in tf.argmax(outputs.logits, axis=-1).numpy()]


def predict_text(params):
    """Classify each text in a JSON request.

    Args:
        params: JSON string of the form ``{"texts": ["text1", "text2", ...]}``.
            A single string for ``"texts"`` is treated as one text.

    Returns:
        On success, a JSON *string* ``{"solutions": [{"text": ..., "label": [...]}, ...]}``
        with one entry per input text, in input order. On invalid input, a dict
        ``{"error": "<message>"}``. The two return types are kept as in the
        original for backward compatibility.
    """
    try:
        params = json.loads(params)
    except json.JSONDecodeError as e:
        message = f"Invalid JSON input: {e.msg} at line {e.lineno} column {e.colno}"
        logger.error("%s", message)
        return {"error": message}

    if not isinstance(params, dict):
        return {"error": 'Input must be a JSON object, e.g. {"texts": ["text1", "text2"]}'}

    texts = params.get("texts", [])
    if isinstance(texts, str):
        texts = [texts]
    if not texts:
        return {"error": "Missing required parameters: 'texts'"}
    if not isinstance(texts, list) or not all(isinstance(t, str) for t in texts):
        return {"error": "'texts' must be a list of strings"}

    label_ids = []
    for start in range(0, len(texts), BATCH_SIZE):
        label_ids.extend(_classify_batch(texts[start:start + BATCH_SIZE]))

    # Each label is wrapped in a one-element list to keep the original response shape.
    solutions = [
        {'text': text, 'label': [LABELS[label_id]]}
        for text, label_id in zip(texts, label_ids)
    ]
    return json.dumps({"solutions": solutions})


inputt = gr.Textbox(label='Parameters in JSON format, e.g. {"texts": ["text1", "text2"]}')
outputt = gr.JSON()
application = gr.Interface(
    fn=predict_text,
    inputs=inputt,
    outputs=outputt,
    title='Multi Text Classification with API Integration..',
)
application.launch()