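# Spaces: Sleeping
# NOTE(review): the line above appears to be status text captured from the
# hosting page together with the file; it is not part of the program.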
import streamlit as st | |
from transformers import AutoModelForCausalLM, AutoTokenizer | |
import torch | |
import os | |
# Apply custom CSS for retro 80s green theme
def apply_custom_css():
    """Inject the contents of ``style.css`` into the page as a ``<style>`` tag.

    The stylesheet is looked up relative to the current working directory.
    If it does not exist, a warning is shown and the app keeps Streamlit's
    default styling.
    """
    try:
        # Explicit encoding so the CSS is decoded the same way on every
        # platform instead of depending on the locale default.
        with open("style.css", encoding="utf-8") as css_file:
            css = css_file.read()
    except FileNotFoundError:
        st.warning("style.css not found. Using default styles.")
        return

    # unsafe_allow_html is required for the raw <style> element to be emitted.
    st.markdown(f"<style>{css}</style>", unsafe_allow_html=True)
def load_model():
    """Load the Zephyr base model, its tokenizer and the FDA-guidance adapter.

    Streamlit re-executes the script on every widget interaction, so the
    expensive loading step is wrapped in ``st.cache_resource``: the weights
    are loaded once and the same objects are reused on later reruns.

    Returns:
        tuple: ``(tokenizer, model)`` ready for generation.

    If loading fails for any reason, the error is displayed and the current
    script run is halted via ``st.stop()``.
    """
    model_path = "HuggingFaceH4/zephyr-7b-beta"
    peft_model_path = "yitzashapiro/FDA-guidance-zephyr-7b-beta-PEFT"

    @st.cache_resource
    def _load_tokenizer_and_model(base_path, adapter_path):
        # Pure loading logic, kept free of UI calls so only the heavy objects
        # are cached.
        tokenizer = AutoTokenizer.from_pretrained(base_path)
        model = AutoModelForCausalLM.from_pretrained(
            base_path,
            device_map="auto",
            torch_dtype=torch.float16,  # Adjust if necessary
        )
        model.load_adapter(adapter_path)
        # Switch to eval mode only after the adapter is attached so that the
        # adapter's own layers are put in inference mode as well.
        model.eval()
        return tokenizer, model

    try:
        tokenizer, model = _load_tokenizer_and_model(model_path, peft_model_path)
    except Exception as e:
        st.error(f"Error loading model: {e}")
        st.stop()
    else:
        st.success("Model loaded successfully.")
    return tokenizer, model
def generate_response(tokenizer, model, user_input):
    """Generate the model's reply to a single user question.

    Args:
        tokenizer: Tokenizer belonging to ``model``. Its chat template is used
            when the tokenizer provides ``apply_chat_template``; otherwise the
            raw question is tokenized (truncated to 45 tokens).
        model: Causal language model exposing ``generate`` and ``device``.
        user_input: The question text entered by the user.

    Returns:
        str: The decoded completion with the prompt tokens stripped, or a
        fixed error message if anything fails (the error is also shown with
        ``st.error``).
    """
    messages = [{"role": "user", "content": user_input}]
    try:
        if hasattr(tokenizer, "apply_chat_template"):
            # No max_length here: without truncation it has no effect, and
            # truncating a templated prompt would cut off the trailing
            # generation prompt the model needs in order to answer.
            input_ids = tokenizer.apply_chat_template(
                conversation=messages,
                tokenize=True,
                add_generation_prompt=True,
                return_tensors="pt",
            )
            # NOTE(review): some transformers releases may return a dict-like
            # encoding instead of a bare tensor; accept both.
            if not torch.is_tensor(input_ids):
                input_ids = input_ids["input_ids"]
        else:
            input_ids = tokenizer(
                user_input,
                return_tensors="pt",
                truncation=True,
                max_length=45,
            )["input_ids"]

        # The prompt is a single, unpadded sequence, so every position is a
        # real token. Deriving the mask from `input_ids != pad_token_id` would
        # hide genuine tokens whenever the pad id doubles as another token
        # (commonly EOS) or when the fallback id 0 occurs in the prompt.
        attention_mask = torch.ones_like(input_ids)

        # Only max_new_tokens is passed: supplying max_length as well gives
        # generate() two conflicting length limits.
        output_ids = model.generate(
            input_ids.to(model.device),
            attention_mask=attention_mask.to(model.device),
            max_new_tokens=500,
        )

        # generate() returns prompt + completion; keep only the new tokens.
        prompt_length = input_ids.shape[1]
        return tokenizer.decode(
            output_ids[0][prompt_length:], skip_special_tokens=True
        )
    except Exception as e:
        st.error(f"Error generating response: {e}")
        return "An error occurred while generating the response."
def main():
    """Render the assistant page and answer the user's question on demand.

    Configures the page, applies the custom stylesheet, loads the model and
    tokenizer, then shows a text box and a button. When the button is pressed
    with a non-blank question, the generated answer is written to the page.
    """
    # set_page_config must be the first Streamlit command of the run.
    # apply_custom_css() emits st.markdown / st.warning, so it has to come
    # after it, not before.
    st.set_page_config(page_title="FDA NDA Submission Assistant", layout="centered")
    apply_custom_css()

    st.title("FDA NDA Submission Assistant")
    st.write("Ask the model about submitting an NDA to the FDA.")

    tokenizer, model = load_model()

    user_input = st.text_input(
        "Enter your question:",
        "What's the best way to submit an NDA to the FDA?",
    )

    if not st.button("Generate Response"):
        return

    # Reject empty or whitespace-only questions before touching the model.
    if not user_input.strip():
        st.error("Please enter a valid question.")
        return

    try:
        with st.spinner("Generating response..."):
            response = generate_response(tokenizer, model, user_input)
        st.success("Response:")
        st.write(response)
    except Exception as e:
        # Last-resort boundary so an unexpected failure shows up in the UI.
        st.error(f"An error occurred: {e}")
# Start the app only when this file is executed as the main script, so that
# importing the module does not trigger model loading or UI rendering.
if __name__ == "__main__":
    main()