from pathlib import Path

from fastapi import FastAPI, File, UploadFile
from fastapi.responses import HTMLResponse, JSONResponse
# from fastapi.staticfiles import StaticFiles
from starlette.requests import Request
from fastapi.templating import Jinja2Templates
from fastapi.middleware.cors import CORSMiddleware

from module import config, transformers_utility as tr, utils, metrics, dataio
import numpy as np

app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
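# NOTE: allow_origins=["*"] together with allow_credentials=True is wide open;
# acceptable for a local demo, but restrict the origin list before exposing
# this service publicly.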
# app.mount("FloraBERT.static", StaticFiles(directory="FloraBERT.static"), name="static")
templates = Jinja2Templates(directory="templates")

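# Paths to the trained byte-level BPE tokenizer and the fine-tuned
# expression-prediction checkpoint, both resolved under config.models.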
TOKENIZER_DIR = config.models / "byte-level-bpe-tokenizer"
PRETRAINED_MODEL = config.models / "transformer" / "prediction-model" / "saved_model.pth"
DATA_DIR = config.data

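# Thin wrapper around transformers_utility.load_model; returns the model
# config, tokenizer, and model, forwarding any extra settings as kwargs.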
def load_model(args, settings):
    return tr.load_model(
        args.model_name,
        args.tokenizer_dir,
        pretrained_model=args.pretrained_model,
        log_offset=args.log_offset,
        **settings,
    )

@app.get("/", response_class=HTMLResponse)
def read_root(request: Request):
    return templates.TemplateResponse("index.html", {"request": request})

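# Accept an uploaded sequence file and stash it in the data directory so that
# /process/{filename} can pick it up later. Note: the client-supplied filename
# is used as-is; sanitize it before running this anywhere untrusted.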
@app.post("/uploadfile/")
async def create_upload_file(file: UploadFile = File(...)):
    file_path = DATA_DIR / file.filename
    contents = await file.read()  # non-blocking read; avoids stalling the event loop
    with open(file_path, "wb") as f:
        f.write(contents)
    return {"filename": file.filename}
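
# Example requests (a sketch, assuming the app is served on localhost:8000 and
# "sequences.txt" is a plain-text file of promoter sequences):
#   curl -F "file=@sequences.txt" http://localhost:8000/uploadfile/
#   curl http://localhost:8000/process/sequences.txt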

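# Run the prediction pipeline on a previously uploaded file and return, for
# each input sequence, a list of per-tissue expression predictions as JSON.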
@app.get("/process/{filename}", response_class=HTMLResponse)
def process_file(request: Request, filename: str):
    file_path = DATA_DIR / filename
    preds = main(
        data_dir=DATA_DIR,
        train_data=file_path,
        test_data=file_path,
        pretrained_model=PRETRAINED_MODEL,
        tokenizer_dir=TOKENIZER_DIR,
        model_name="roberta-pred-mean-pool",
    )
    predictions = []
    for row in preds:
        predictions.append(
            [
                {"tissue": tissue, "prediction": value}
                for tissue, value in zip(config.tissues, row)
            ]
        )
    return JSONResponse(content=predictions)

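# End-to-end pipeline: assemble the argument namespace, load the fine-tuned
# model and tokenizer, tokenize the uploaded sequences, and return per-tissue
# expression predictions on the original (un-logged) scale.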
def main(
    data_dir: Path,
    train_data: Path,
    test_data: Path,
    pretrained_model: Path,
    tokenizer_dir: Path,
    model_name: str,
):
    args = utils.get_args(
        data_dir=data_dir,
        train_data=train_data,
        test_data=test_data,
        pretrained_model=pretrained_model,
        tokenizer_dir=tokenizer_dir,
        model_name=model_name,
    )

    settings = utils.get_model_settings(config.settings, args)
    if args.output_mode:
        settings["output_mode"] = args.output_mode
    if args.tissue_subset is not None:
        settings["num_labels"] = len(args.tissue_subset)
    
    print("Loading model...")
    config_obj, tokenizer, model = load_model(args, settings)

    print("Loading data...")
    datasets = dataio.load_datasets(
        tokenizer,
        args.train_data,
        eval_data=args.eval_data,
        test_data=args.test_data,
        seq_key="text",
        file_type="text",
        filter_empty=args.filter_empty,
        shuffle=False,
    )
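    # The uploaded file is passed as both train_data and test_data above, so
    # the "train" split holds exactly the sequences we want to score.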
    dataset_test = datasets["train"]

    print("Getting predictions:")
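    # get_predictions returns values on the log-transformed scale used during
    # training; exp(...) - 1 maps them back to raw expression estimates
    # (assuming the default log offset of 1).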
    preds = np.exp(np.array(metrics.get_predictions(model, dataset_test))) - 1

    return preds.tolist()
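
# Minimal local entry point (a sketch: assumes uvicorn is installed; in
# deployment you would typically invoke `uvicorn <module>:app` directly,
# where <module> is this file's import path).
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)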