Ferris2dotOh / app.py
Craig Pretzinger
Updated files for enhanced PubMedBERT and GPT-4o-mini integration
1ee7467
raw
history blame
3.4 kB
import gradio as gr
import openai
import os
from dotenv import load_dotenv
import requests
from transformers import BertTokenizer, BertForSequenceClassification
import torch
import faiss
import numpy as np
# Read key/value pairs from a local .env file into the process environment
# so the lookups below can find them.
load_dotenv()

# Credentials. Each lookup yields None when the variable is not set.
# The OpenAI key and organisation are stored on the `openai` module itself.
openai.api_key = os.environ.get("OPENAI_API_KEY")
openai.organization = os.environ.get("OPENAI_ORG_ID")
# NOTE(review): nothing in the visible code reads serper_api_key.
serper_api_key = os.environ.get("SERPER_API_KEY")

# PubMedBERT tokenizer and 2-label sequence classifier, both loaded from the
# same checkpoint name.
# NOTE(review): the checkpoint name suggests a base (pre-trained) encoder, so
# the 2-label classification head is presumably freshly initialised rather
# than fine-tuned -- confirm before trusting its predictions.
_PUBMEDBERT_CHECKPOINT = "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract"
tokenizer = BertTokenizer.from_pretrained(_PUBMEDBERT_CHECKPOINT)
model = BertForSequenceClassification.from_pretrained(
    _PUBMEDBERT_CHECKPOINT,
    num_labels=2,
)

# Flat L2 FAISS index sized for 768-dimensional vectors.
# NOTE(review): nothing in the visible code adds vectors to or searches it.
dimension = 768
index = faiss.IndexFlatL2(dimension)
# Function to embed text (PubMedBERT)
def embed_text(text):
    """Embed ``text`` with PubMedBERT using attention-masked mean pooling.

    The text is tokenized with the module-level ``tokenizer`` (truncated and
    padded to 512 positions) and run through the module-level ``model``. The
    last hidden layer is averaged over the real tokens only: because the
    input is padded to a fixed length, a plain mean over every position would
    let padding states dominate the embedding of short texts.

    Args:
        text: Input accepted by the tokenizer (a string in the visible code).

    Returns:
        A NumPy array holding the pooled last-layer hidden state, with the
        sequence dimension (dim 1) averaged away.
    """
    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding="max_length",
        max_length=512,
    )

    # Inference only: skip autograd bookkeeping for the forward pass.
    with torch.no_grad():
        outputs = model(**inputs, output_hidden_states=True)

    hidden_state = outputs.hidden_states[-1]

    # Broadcast the attention mask over the hidden dimension so padded
    # positions contribute nothing to the sum.
    mask = inputs["attention_mask"].unsqueeze(-1).to(hidden_state.dtype)
    summed = (hidden_state * mask).sum(dim=1)
    # clamp guards against a division by zero if a row had no real tokens.
    token_counts = mask.sum(dim=1).clamp(min=1)

    return (summed / token_counts).detach().numpy()
# Function to retrieve info from PubMedBERT
def handle_fda_query(query):
    """Classify ``query`` with the PubMedBERT classifier and describe the result.

    The query is tokenized with the module-level ``tokenizer`` (truncated and
    padded to 512 positions), scored by the module-level ``model``, and the
    index of the highest logit selects one of two canned messages.

    NOTE(review): whether label 1 really means "regulatory" depends on how
    ``model``'s classification head was trained, which is not visible here.

    Args:
        query: The user's question as a string.

    Returns:
        A sentence quoting ``query`` and stating whether it was classified
        as label 1 ("important regulatory information") or not.
    """
    inputs = tokenizer(
        query,
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=512,
    )

    # Inference only: without no_grad every request would build and hold an
    # autograd graph for the full 512-position forward pass.
    with torch.no_grad():
        logits = model(**inputs).logits

    prediction = torch.argmax(logits, dim=1).item()

    # Simulated FDA-style answer keyed off the predicted label.
    if prediction == 1:
        return f"FDA Query Processed: '{query}' contains important regulatory information."
    return f"FDA Query Processed: '{query}' seems to be general and not regulatory-heavy."
# Function to enhance info via GPT-4o-mini
def enhance_with_gpt4o(fda_response):
    """Ask GPT-4o-mini to elaborate on ``fda_response``.

    Sends a fixed system prompt plus a user prompt that embeds
    ``fda_response`` and requests at most 150 completion tokens.

    NOTE(review): ``openai.ChatCompletion`` is the interface of the pre-1.0
    ``openai`` package -- confirm the installed version still provides it.

    Args:
        fda_response: Text to be enhanced.

    Returns:
        The content of the first returned choice, or ``"Error: <message>"``
        if anything raised while building or sending the request.
    """
    try:
        chat_messages = [
            {"role": "system", "content": "You are an expert FDA assistant."},
            {"role": "user", "content": f"Enhance this FDA info: {fda_response}"},
        ]
        completion = openai.ChatCompletion.create(
            model="gpt-4o-mini",
            messages=chat_messages,
            max_tokens=150,
        )
        first_choice = completion["choices"][0]
        return first_choice["message"]["content"]
    except Exception as exc:
        # Best-effort boundary: report the failure as text instead of raising.
        return f"Error: {exc!s}"
# Main function that gets PubMedBERT output and enhances it using GPT-4o-mini
def respond(message, system_message, max_tokens, temperature, top_p):
    """Run the PubMedBERT step, then the GPT-4o-mini step, and combine both.

    Args:
        message: The user's query; passed to ``handle_fda_query``.
        system_message: Accepted but not read by this function.
        max_tokens: Accepted but not read by this function.
        temperature: Accepted but not read by this function.
        top_p: Accepted but not read by this function.

    Returns:
        A two-section string containing the ``handle_fda_query`` result and
        the ``enhance_with_gpt4o`` result, or ``"Error: <message>"`` if
        either step raised.
    """
    try:
        # Step 1: classification-based answer from PubMedBERT.
        pubmedbert_text = handle_fda_query(message)
        # Step 2: elaboration of that answer by GPT-4o-mini.
        gpt_text = enhance_with_gpt4o(pubmedbert_text)
        # Present both, separated by a blank line.
        sections = (
            f"Original Info from PubMedBERT: {pubmedbert_text}",
            f"Enhanced Info via GPT-4o-mini: {gpt_text}",
        )
        return "\n\n".join(sections)
    except Exception as exc:
        return f"Error: {exc!s}"
# Gradio Interface
# Input components, listed in the order `respond` declares its parameters.
query_input = gr.Textbox(
    label="Enter your FDA query",
    placeholder="Ask Ferris2.0 anything FDA-related.",
)
system_message_input = gr.Textbox(
    value="You are Ferris2.0, the most advanced FDA Regulatory Assistant.",
    label="System message",
)
max_tokens_input = gr.Slider(
    minimum=1,
    maximum=2048,
    value=512,
    step=1,
    label="Max new tokens",
)
temperature_input = gr.Slider(
    minimum=0.1,
    maximum=4.0,
    value=0.7,
    step=0.1,
    label="Temperature",
)
top_p_input = gr.Slider(
    minimum=0.1,
    maximum=1.0,
    value=0.95,
    step=0.05,
    label="Top-p (nucleus sampling)",
)

# Text-in / text-out interface around `respond`.
demo = gr.Interface(
    fn=respond,
    inputs=[
        query_input,
        system_message_input,
        max_tokens_input,
        temperature_input,
        top_p_input,
    ],
    outputs="text",
)

# Start the web UI only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()