BIOML committed on
Commit
2cf2d78
1 Parent(s): 28b349c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +76 -31
app.py CHANGED
@@ -1,31 +1,76 @@
1
from analze import *

app = Flask(__name__)

@app.route('/')
def home():
    # Landing page: renders the upload form.
    return render_template('home.html')

@app.route('/upload', methods=['GET', 'POST'])
def upload_file():
    """Handle an email-file upload and render the analysis results.

    On POST, validates that a file was supplied, then renders `home.html`
    with the outputs of the feature/prediction helpers imported from `analze`.
    On GET (or after falling through), renders the bare page.
    """
    if request.method == 'POST':
        # Check if a file was uploaded
        if 'file' not in request.files:
            return render_template('home.html', content='No file uploaded.')
        file = request.files['file']
        # Check if the file has a filename
        if file.filename == '':
            return render_template('home.html', content='No file selected.')
        filepath = 'email files/' + file.filename
        # NOTE(review): the upload is never written to `filepath` here (no
        # `file.save(filepath)`), yet the helpers below receive that path —
        # presumably they expect a pre-existing file; verify against `analze`.
        return render_template('home.html',
                               content=check_file_type(file),
                               features = get_features(filepath),
                               pre_content=predict_content(text_feature(filepath)),
                               pre_tag=predict_html(html_tags_feature(filepath)),
                               pre_num=predict_num(num_feature(filepath)),
                               pre_extra=predict_extra(extra_feature(filepath)))

    return render_template('home.html')

if __name__ == '__main__':
    app.run(host='127.0.0.1',port=5000)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import streamlit as st
3
+ from transformers import AutoTokenizer, OPTForCausalLM
4
+
5
+
6
@st.cache_resource
def load_model():
    """Load the Galactica-30B tokenizer and model once per Streamlit server.

    Returns:
        tuple: (tokenizer, model) — the model is loaded in fp16 and sharded
        across available devices by accelerate via `device_map='auto'`.
    """
    tokenizer = AutoTokenizer.from_pretrained("facebook/galactica-30b")
    model = OPTForCausalLM.from_pretrained(
        "facebook/galactica-30b",
        device_map='auto',
        low_cpu_mem_usage=True,
        torch_dtype=torch.float16,
    )
    # FIX: removed `model.gradient_checkpointing_enable()`. Gradient
    # checkpointing trades compute for memory during *training* (activations
    # are recomputed in backward()); this app only calls `model.generate()`,
    # so enabling it gains nothing and can force the KV cache off in
    # transformers, slowing generation.
    return tokenizer, model
12
+
13
+
14
st.set_page_config(
    page_title='BioML-SVM',
    layout="wide"
)

with st.spinner("Loading Models and Tokens..."):
    tokenizer, model = load_model()

# --- Primary form: amino-acid sequence in, generated continuation out. ---
with st.form(key='my_form'):
    col1, col2 = st.columns([10, 1])
    text_input = col1.text_input(label='Enter the amino sequence')
    with col2:
        # Empty text widgets vertically align the button with the input box.
        st.text('')
        st.text('')
    submit_button = st.form_submit_button(label='Submit')

if submit_button:
    st.session_state['result_done'] = False
    # input_text = "[START_AMINO]GHMQSITAGQKVISKHKNGRFYQCEVVRLTTETFYEVNFDDGSFSDNLYPEDIVSQDCLQFGPPAEGEVVQVRWTDGQVYGAKFVASHPIQMYQVEFEDGSQLVVKRDDVYTLDEELP[END_AMINO]"
    with st.spinner('Generating...'):
        # The raw input is used as the prompt; earlier experiments with
        # [START_AMINO]/[END_AMINO] wrappers were deliberately left disabled.
        # formatted_text = f"[START_AMINO]{text_input}[END_AMINO]"
        # formatted_text = f"Here is the sequence: [START_AMINO]{text_input}[END_AMINO]"
        formatted_text = f"{text_input}"
        # FIX: send inputs to wherever accelerate placed the model's first
        # shard instead of hard-coding "cuda", which crashes on CPU-only
        # hosts and can mismatch the `device_map='auto'` placement.
        input_ids = tokenizer(formatted_text, return_tensors="pt").input_ids.to(model.device)
        outputs = model.generate(
            input_ids=input_ids,
            max_new_tokens=500
        )
        # Strip the echoed prompt so only the continuation is displayed.
        result = tokenizer.decode(outputs[0]).replace(formatted_text, "")
        st.markdown(result)

        # Persist the answer so it survives Streamlit reruns (e.g. when the
        # follow-up form below triggers a rerun).
        if 'result_done' not in st.session_state or not st.session_state.result_done:
            st.session_state['result_done'] = True
            st.session_state['previous_state'] = result
else:
    # On a rerun without a fresh submit, re-show the cached answer.
    if 'result_done' in st.session_state and st.session_state.result_done:
        st.markdown(st.session_state.previous_state)
51
+
52
# --- Follow-up Q&A form, shown only once a first result exists. ---
if 'result_done' in st.session_state and st.session_state.result_done:
    with st.form(key='ask_more'):
        col1, col2 = st.columns([10, 1])
        text_input = col1.text_input(label='Ask more question')
        with col2:
            # Empty text widgets vertically align the button with the input box.
            st.text('')
            st.text('')
        submit_button = st.form_submit_button(label='Submit')

    if submit_button:
        with st.spinner('Generating...'):
            # formatted_text = f"[START_AMINO]{text_input}[END_AMINO]"
            formatted_text = f"Q:{text_input}\n\nA:\n\n"
            # FIX: use the model's actual device rather than hard-coding
            # "cuda" (crashes on CPU-only hosts; see the first form above
            # for the same pattern).
            input_ids = tokenizer(formatted_text, return_tensors="pt").input_ids.to(model.device)

            outputs = model.generate(
                input_ids=input_ids,
                # FIX: `max_length` is measured in *tokens*. The original
                # added 500 to the prompt's *character* count, inflating the
                # budget unpredictably; count prompt tokens instead.
                max_length=input_ids.shape[1] + 500,
                do_sample=True,
                top_k=40,
                num_beams=1,
                num_return_sequences=1
            )
            # Strip the echoed prompt so only the answer is displayed.
            result = tokenizer.decode(outputs[0]).replace(formatted_text, "")
            st.markdown(result)