import logging
import os
import subprocess
import tempfile
from typing import Optional

from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, JSONResponse
from pydantic import BaseModel

from inference_transform import (
    extract_and_convert_to_sdf,
    is_valid_smiles,
    process_pdb,
    process_sdf,
    process_smiles,
)

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI()

# NOTE(review): a wildcard origin combined with allow_credentials=True makes
# Starlette echo back the caller's Origin on credentialed requests, so any site
# could call this API with a user's cookies. Restrict allow_origins before
# exposing the service publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Chat CLI, base-model directory, and the LoRA adapter that /predict passes to
# the CLI via --lora_path.
CHAT_SCRIPT = "/root/CHEMISTral7Bv0.3/mistral_chat_script.py"
MODEL_DIR = "/root/mistral_models/7B-v0.3/"
LORA_PATH = (
    "/root/CHEMISTral7Bv0.3/runs/checkpoints/checkpoint_000300/"
    "consolidated/lora.safetensors"
)

# Example conformer: returned as a download by /predict and reported by
# /predict_base.
sdf_file_path = "/root/CHEMISTral7Bv0.3/example/Conformer3D_COMPOUND_CID_240.sdf"


class InferenceRequest(BaseModel):
    """JSON body schema for an inference call.

    Not currently used by the endpoints below, which take form fields instead.
    """

    prompt: str
    max_tokens: int = 256
    temperature: float = 1.0


def _safe_upload_name(filename: Optional[str]) -> str:
    """Reduce a client-supplied filename to a bare name safe to join onto a directory."""
    # basename() strips any directory components such as "../../".
    name = os.path.basename(filename or "")
    # "", "." and ".." would resolve to the directory itself or its parent.
    return name if name not in ("", ".", "..") else "upload"


async def _augment_prompt(
    prompt: str, file: Optional[UploadFile], workdir: str
) -> str:
    """Return *prompt* extended with structure data.

    With an upload, the file is saved inside *workdir*. Its content, converted
    by process_pdb/process_sdf, is appended for ``.pdb``/``.sdf`` files, and
    other extensions are logged and ignored. Without an upload, the result of
    extract_and_convert_to_sdf(prompt) is appended when it is truthy. A
    ValueError from that call is logged and the prompt is returned unchanged.
    """
    if file:
        name = _safe_upload_name(file.filename)
        upload_path = os.path.join(workdir, name)
        with open(upload_path, "wb") as out:
            out.write(await file.read())
        lowered = name.lower()
        if lowered.endswith(".pdb"):
            return f"{prompt} {process_pdb(upload_path)}"
        if lowered.endswith(".sdf"):
            return f"{prompt} {process_sdf(upload_path)}"
        logger.warning("Ignoring upload %r: only .pdb and .sdf are supported.", name)
        return prompt

    try:
        sdf_file = extract_and_convert_to_sdf(prompt)
    except ValueError as e:
        logger.info(str(e))
        return prompt
    return f"{prompt} {sdf_file}" if sdf_file else prompt


async def _run_chat_script(
    prompt: str, max_tokens: int, temperature: float, extra_args: tuple = ()
) -> str:
    """Run the chat CLI on *prompt* and return its stripped stdout.

    The blocking subprocess runs in a worker thread so the event loop keeps
    serving other requests while the model generates.

    Raises:
        HTTPException: 500 with the CLI's stderr when it exits non-zero.
    """
    # NOTE(review): a prompt starting with "-" may be parsed as an option by
    # the chat script, depending on its argument parser. Confirm.
    command = [
        "python",
        CHAT_SCRIPT,
        MODEL_DIR,
        prompt,
        f"--max_tokens={max_tokens}",
        f"--temperature={temperature}",
        "--instruct",
        *extra_args,
    ]
    logger.info(f"Running command: {' '.join(command)}")
    result = await run_in_threadpool(
        subprocess.run, command, capture_output=True, text=True
    )
    if result.returncode != 0:
        logger.error(f"Command failed with return code {result.returncode}")
        logger.error(f"stderr: {result.stderr}")
        raise HTTPException(status_code=500, detail=result.stderr)
    return result.stdout.strip()


@app.post("/predict_base")
async def predict_base(
    prompt: str = Form(...),
    max_tokens: int = Form(256),
    temperature: float = Form(1.0),
    file: Optional[UploadFile] = File(None),
):
    """Run the base model (no LoRA) on *prompt*, optionally enriched by an upload.

    Returns:
        ``{"response": <model stdout>, "sdf_file_path": <example SDF path>}``.

    Raises:
        HTTPException: 500 if the chat CLI fails or any other step raises.
    """
    try:
        # The upload lives only for the duration of the request. It is removed
        # after the CLI finishes in case the appended text refers to the file.
        with tempfile.TemporaryDirectory(prefix="chemistral_") as workdir:
            prompt = await _augment_prompt(prompt, file, workdir)
            response = await _run_chat_script(prompt, max_tokens, temperature)
        return {"response": response, "sdf_file_path": sdf_file_path}
    except HTTPException:
        # Already carries the right status/detail; don't re-wrap it.
        raise
    except Exception as e:
        logger.exception("Exception occurred during inference.")
        raise HTTPException(status_code=500, detail=str(e))


@app.post("/predict")
async def predict_alternative(
    prompt: str = Form(...),
    max_tokens: int = Form(256),
    temperature: float = Form(1.0),
    file: Optional[UploadFile] = File(None),
):
    """Run the LoRA-adapted model, then return the example SDF as a download.

    The model's text output is currently not returned to the client. The
    response is always the fixed file at ``sdf_file_path``.

    Raises:
        HTTPException: 500 if the chat CLI fails, the SDF file is missing, or
        any other step raises.
    """
    try:
        with tempfile.TemporaryDirectory(prefix="chemistral_") as workdir:
            prompt = await _augment_prompt(prompt, file, workdir)
            response = await _run_chat_script(
                prompt,
                max_tokens,
                temperature,
                extra_args=(f"--lora_path={LORA_PATH}",),
            )
        logger.debug("Model response (not returned to client): %s", response)

        # FileResponse only opens the file while sending, which is too late to
        # turn a missing file into a clean HTTP error.
        if not os.path.isfile(sdf_file_path):
            raise HTTPException(
                status_code=500, detail=f"SDF file not found: {sdf_file_path}"
            )
        # Return the file as a direct download
        return FileResponse(
            sdf_file_path,
            media_type="chemical/x-mdl-sdfile",
            filename=os.path.basename(sdf_file_path),
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.exception("Exception occurred during inference.")
        raise HTTPException(status_code=500, detail=str(e))
@app.get("/download_sdf")
async def download_sdf():
    """Serve the example conformer SDF at ``sdf_file_path`` as an attachment.

    Raises:
        HTTPException: 404 if the file does not exist on disk.
    """
    # FileResponse does not touch the file until the response is sent. A
    # try/except around its construction therefore never sees a missing file,
    # and the failure would surface mid-response as an unhandled error.
    # Check up front instead.
    if not os.path.isfile(sdf_file_path):
        logger.error("SDF file not found: %s", sdf_file_path)
        raise HTTPException(status_code=404, detail="SDF file not found.")
    return FileResponse(path=sdf_file_path, filename=os.path.basename(sdf_file_path))


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)