|
import torch |
|
import pandas as pd |
|
import torch.nn as nn |
|
|
|
from flask import Flask, render_template, request |
|
from model.bilstm import BiLSTM |
|
|
|
app = Flask(__name__)

# Trained BiLSTM checkpoint served by this app.
_CHECKPOINT_PATH = 'checkpoints/bilstm_result/epoch=23-step=3456.ckpt'

# Keyword arguments forwarded to BiLSTM.load_from_checkpoint alongside the path.
# input_size=12 matches the twelve form fields read by the index view.
_MODEL_KWARGS = dict(lr=1e-3, num_classes=1, input_size=12)

model = BiLSTM.load_from_checkpoint(_CHECKPOINT_PATH, **_MODEL_KWARGS)

# Prepare the model for inference once, at import time, before any request.
# NOTE(review): freeze() presumably comes from a LightningModule base class
# (disables grads / sets eval) — if so, the eval() call is redundant; confirm.
model.eval()
model.freeze()
|
|
|
# Form fields fed to the model, in the exact order they are placed in the
# input row. The tensor is positional, so this order presumably has to match
# the column order used at training time — do not reorder.
_FEATURE_NAMES = (
    'tau1', 'tau2', 'tau3', 'tau4',
    'p1', 'p2', 'p3', 'p4',
    'g1', 'g2', 'g3', 'g4',
)


@app.route('/', methods=['POST', 'GET'])
def index():
    """Render the prediction form and, on POST, classify the submitted values.

    GET renders ``index.html`` with no result. POST reads the twelve fields
    listed in ``_FEATURE_NAMES`` from the form, runs the model on them and
    renders ``index.html`` with ``result`` set to ``'Stable'`` (sigmoid
    probability > 0.5) or ``'Unstable'``.

    A missing, blank, non-numeric or non-finite field yields HTTP 400 with an
    explanatory ``result`` message. Previously a missing field silently became
    NaN (always reported as 'Unstable') and a non-numeric one caused a 500.
    """
    if request.method != 'POST':
        return render_template('index.html')

    try:
        features = _parse_features(request.form)
    except ValueError as err:
        return render_template('index.html', result=f'Invalid input: {err}'), 400

    return render_template('index.html', result=_classify(features))


def _parse_features(form):
    """Return the values of ``_FEATURE_NAMES`` from *form* as floats, in order.

    Raises:
        ValueError: if a field is missing or blank, is not a number, or is
            not finite (``nan``/``inf`` would otherwise reach the model).
    """
    import math

    features = []
    for name in _FEATURE_NAMES:
        raw = (form.get(name) or '').strip()
        if not raw:
            raise ValueError(f'{name} is required')
        try:
            value = float(raw)
        except ValueError:
            raise ValueError(f'{name} must be a number') from None
        if not math.isfinite(value):
            raise ValueError(f'{name} must be a finite number')
        features.append(value)
    return features


def _classify(features):
    """Run the model on one feature row and map its output to a label."""
    # A single row of shape (1, len(features)); Python floats give torch's
    # default float dtype, exactly as the previous DataFrame round-trip did.
    X = torch.tensor([features])
    with torch.no_grad():
        logits = model(X)
    # The raw output is treated as a logit. This assumes it holds exactly one
    # element (num_classes=1 at load time); .item() raises otherwise.
    probability = torch.sigmoid(logits).item()
    return 'Stable' if probability > 0.5 else 'Unstable'
|
|
|
if __name__ == '__main__':
    import os

    # Debug mode turns on the interactive Werkzeug debugger, which can run
    # arbitrary code from the browser, so never expose it publicly. It stays on
    # by default (as before) but can now be disabled with FLASK_DEBUG set to
    # 0/false/no or empty, matching how Flask itself parses that variable.
    _debug_env = os.environ.get('FLASK_DEBUG', '1')
    debug = bool(_debug_env) and _debug_env.lower() not in {'0', 'false', 'no'}
    app.run(debug=debug)