"""Flask service exposing a sentence-similarity endpoint.

POST /match_text with a JSON body ``{"text1": ..., "text2": ...}`` returns the
cosine similarity between the SentenceTransformer embeddings of the two texts.
"""
# NOTE(review): pickle, pandas, numpy, os, json and secure_filename appear
# unused in this file; kept in case other code relies on them.
import json
import logging
import os
import pickle

import numpy as np
import pandas as pd
from flask import Flask, jsonify, request
from sentence_transformers import SentenceTransformer, util
from werkzeug.utils import secure_filename

logger = logging.getLogger(__name__)

# Hugging Face model id (a local snapshot directory path may be used instead).
MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
# Directory where SentenceTransformer caches downloaded model files.
MODEL_CACHE_FOLDER = './'
# Value assigned to the model's max_seq_length (maximum input length in tokens).
MAX_SEQ_LENGTH = 512

app = Flask(__name__)

# Shared model instance; populated by load_model() and reused across requests.
model = None


def load_model():
    """Load the SentenceTransformer model once and return the cached instance.

    Called eagerly when this file is run as a script, and lazily on the first
    request otherwise (e.g. when a WSGI server imports this module), so the
    endpoint never sees an undefined model.
    """
    global model
    if model is None:
        print('loading model...')
        model = SentenceTransformer(MODEL_NAME, cache_folder=MODEL_CACHE_FOLDER)
        model.max_seq_length = MAX_SEQ_LENGTH
        print(f'model max length is :{model.max_seq_length}')
    return model


@app.route('/match_text', methods=['POST'])
def similarity():
    """Return the cosine similarity between ``text1`` and ``text2``.

    Responses:
        200 ``{"similarity_score": float}`` on success.
        400 ``{"error": ...}`` if the body is not a JSON object containing
            both ``text1`` and ``text2``.
        500 ``{"error": ...}`` if encoding or scoring fails.

    If ``text1``/``text2`` are lists, only the score of the first element of
    each is returned.
    """
    # silent=True yields None for a missing or non-JSON body instead of
    # raising, so malformed requests get a 400 rather than a 500.
    data = request.get_json(silent=True)
    if not isinstance(data, dict) or 'text1' not in data or 'text2' not in data:
        return jsonify({'error': 'Both text1 and text2 must be provided.'}), 400

    try:
        encoder = load_model()
        embeddings1 = encoder.encode(data['text1'], convert_to_tensor=True)
        embeddings2 = encoder.encode(data['text2'], convert_to_tensor=True)
        # cos_sim returns a similarity matrix; [0][0] pairs the first text1
        # embedding with the first text2 embedding.
        score = util.cos_sim(embeddings1, embeddings2)[0][0].item()
    except Exception as e:
        logger.exception('similarity computation failed')
        return jsonify({'error': str(e)}), 500

    print(f'{score}')
    return jsonify({'similarity_score': score}), 200


if __name__ == '__main__':
    # Load up front so the first request does not pay the model-loading cost.
    load_model()
    # threaded=False: the development server handles one request at a time.
    app.run(debug=False, port=7860, host='0.0.0.0', threaded=False)