"""Flask service exposing GoldenRetriever document retrieval over HTTP.

Two POST endpoints accept a JSON body of the form ``{"text": "..."}`` and
respond with a JSON list of the texts of the top-k retrieved documents,
from either the "interventions" or the "outcomes" document index.
"""

import os

from flask import Flask, request, jsonify
from goldenretriever import GoldenRetriever

app = Flask(__name__)

# Number of documents returned per query by default.
TOP_K = 5

# Retriever backed by the "interventions" encoder and document index.
retriever_interventions = GoldenRetriever(
    question_encoder="models/interventions/question_encoder",
    document_index="models/interventions/document_index/",
    device="cpu",
)

# Retriever backed by the "outcomes" encoder and document index.
retriever_outcomes = GoldenRetriever(
    question_encoder="models/outcomes/question_encoder",
    document_index="models/outcomes/document_index/",
    device="cpu",
)


def retrieve_documents(retriever, text, k=TOP_K):
    """Return the texts of the top-``k`` documents retrieved for ``text``.

    Args:
        retriever: A ``GoldenRetriever`` instance to query.
        text: The query string.
        k: Number of documents to retrieve (defaults to 5).

    Returns:
        A list of document text strings.
    """
    # NOTE(review): retrieve() appears to return one result list per query;
    # we submit a single query and take the first entry — confirm against
    # the goldenretriever API.
    pred_docs = retriever.retrieve(text, k=k, batch_size=1, progress_bar=False)[0]
    return [doc.document.text for doc in pred_docs]


def _handle_retrieval(retriever):
    """Validate the JSON request body and run retrieval with ``retriever``.

    Returns:
        A Flask ``(response, status)`` tuple: 200 with a JSON list of document
        texts, or 400 with ``{"error": ...}`` when the body is not a JSON
        object, ``text`` is missing/empty, or ``text`` is not a string.
    """
    # silent=True yields None (instead of raising) for a missing, malformed,
    # or non-JSON body, so such requests get a JSON 400 rather than an HTML
    # error page or a 500 from calling .get() on None / a list.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({'error': 'No text provided'}), 400

    text = data.get('text', '')
    if not text:
        return jsonify({'error': 'No text provided'}), 400
    if not isinstance(text, str):
        # Reject numbers, lists, objects, etc. before they reach the model.
        return jsonify({'error': "'text' must be a string"}), 400

    return jsonify(retrieve_documents(retriever, text)), 200


@app.route('/retrieve-intervention', methods=['POST'])
def retrieve_intervention():
    """Retrieve documents from the interventions index for the posted text."""
    return _handle_retrieval(retriever_interventions)


@app.route('/retrieve-outcomes', methods=['POST'])
def retrieve_outcomes():
    """Retrieve documents from the outcomes index for the posted text."""
    return _handle_retrieval(retriever_outcomes)


if __name__ == '__main__':
    # Debug mode enables the interactive Werkzeug debugger, which allows
    # arbitrary code execution; it must not be exposed on 0.0.0.0 by default.
    # Opt in explicitly for local development with FLASK_DEBUG=1.
    debug = os.environ.get('FLASK_DEBUG', '').strip().lower() in ('1', 'true', 'yes')
    app.run(debug=debug, host='0.0.0.0', port=8000)