Spaces:
Sleeping
Sleeping
File size: 3,400 Bytes
8edd1fa cd77e73 8edd1fa 1ee7467 cd77e73 b6bcc82 1ee7467 122da82 1ee7467 cd77e73 122da82 cd77e73 1ee7467 cd77e73 1ee7467 cd77e73 122da82 cd77e73 1ee7467 cd77e73 1ee7467 cd77e73 1ee7467 8edd1fa b43bca2 8edd1fa 1ee7467 8edd1fa cd77e73 8edd1fa 3d9cfb9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 |
import gradio as gr
import openai
import os
from dotenv import load_dotenv
import requests
from transformers import BertTokenizer, BertForSequenceClassification
import torch
import faiss
import numpy as np
# Load environment variables from a local .env file into os.environ.
load_dotenv()
# API Keys and Org ID pulled from the environment (None if unset).
openai.api_key = os.getenv("OPENAI_API_KEY")
openai.organization = os.getenv("OPENAI_ORG_ID")
# NOTE(review): serper_api_key is loaded but never used in this file — confirm it is needed.
serper_api_key = os.getenv("SERPER_API_KEY")
# Load PubMedBERT tokenizer and model.
# num_labels=2 attaches a binary classification head on top of the pretrained
# encoder; presumably the head is randomly initialised (not fine-tuned here) — verify.
tokenizer = BertTokenizer.from_pretrained("microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract")
model = BertForSequenceClassification.from_pretrained("microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract", num_labels=2)
# FAISS setup for vector search
dimension = 768  # hidden size of the BERT-base encoder, matches embed_text() output
# NOTE(review): the index is created but never populated or queried in this file — confirm intent.
index = faiss.IndexFlatL2(dimension)
# Function to embed text (PubMedBERT)
def embed_text(text):
    """Embed *text* with PubMedBERT by mean-pooling the last hidden state.

    Args:
        text: Input string; truncated/padded to 512 tokens.

    Returns:
        numpy.ndarray of shape (1, 768): the token-averaged final-layer
        hidden state, suitable for the 768-dim FAISS index.
    """
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding="max_length", max_length=512)
    # Inference only: no_grad skips autograd bookkeeping (less memory, faster),
    # and makes the tensor directly convertible to numpy.
    with torch.no_grad():
        outputs = model(**inputs, output_hidden_states=True)
    hidden_state = outputs.hidden_states[-1]  # (batch, seq_len, 768)
    return hidden_state.mean(dim=1).numpy()
# Function to retrieve info from PubMedBERT
def handle_fda_query(query):
    """Classify *query* with the PubMedBERT head and return a canned summary.

    Args:
        query: User question; truncated/padded to 512 tokens.

    Returns:
        A short English sentence whose wording depends on the argmax of the
        two-class logits (1 = "regulatory-heavy", 0 = general).
    """
    inputs = tokenizer(query, return_tensors="pt", padding="max_length", truncation=True, max_length=512)
    # Pure inference: disable gradient tracking.
    with torch.no_grad():
        outputs = model(**inputs)
    prediction = torch.argmax(outputs.logits, dim=1).item()
    # Simulate a meaningful FDA-related response
    if prediction == 1:
        return f"FDA Query Processed: '{query}' contains important regulatory information."
    else:
        return f"FDA Query Processed: '{query}' seems to be general and not regulatory-heavy."
# Function to enhance info via GPT-4o-mini
def enhance_with_gpt4o(fda_response):
    """Enrich a PubMedBERT-derived FDA answer via the GPT-4o-mini chat API.

    Args:
        fda_response: Text produced by handle_fda_query().

    Returns:
        The model's reply content on success, or an "Error: ..." string if
        the API call fails for any reason (no exception propagates).
    """
    messages = [
        {"role": "system", "content": "You are an expert FDA assistant."},
        {"role": "user", "content": f"Enhance this FDA info: {fda_response}"},
    ]
    try:
        completion = openai.ChatCompletion.create(
            model="gpt-4o-mini",
            messages=messages,
            max_tokens=150,
        )
        return completion['choices'][0]['message']['content']
    except Exception as e:
        return f"Error: {str(e)}"
# Main function that gets PubMedBERT output and enhances it using GPT-4o-mini
def respond(message, system_message, max_tokens, temperature, top_p):
    """Answer an FDA query: PubMedBERT retrieval, then GPT-4o-mini enhancement.

    NOTE(review): system_message, max_tokens, temperature and top_p are
    accepted to match the Gradio inputs but are not used here — confirm
    whether they should be forwarded to the OpenAI call.

    Returns:
        A combined string showing the raw PubMedBERT result and the enhanced
        version, or an "Error: ..." string if either stage raises.
    """
    try:
        # Stage 1: classify/retrieve with PubMedBERT.
        retrieved = handle_fda_query(message)
        # Stage 2: rewrite/enrich with GPT-4o-mini.
        enhanced = enhance_with_gpt4o(retrieved)
        return f"Original Info from PubMedBERT: {retrieved}\n\nEnhanced Info via GPT-4o-mini: {enhanced}"
    except Exception as e:
        return f"Error: {str(e)}"
# Gradio UI: one query box plus generation controls, wired to respond().
# (The controls are passed through to respond(); see its docstring.)
query_input = gr.Textbox(label="Enter your FDA query", placeholder="Ask Ferris2.0 anything FDA-related.")
system_input = gr.Textbox(value="You are Ferris2.0, the most advanced FDA Regulatory Assistant.", label="System message")
max_tokens_input = gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens")
temperature_input = gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature")
top_p_input = gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)")

demo = gr.Interface(
    fn=respond,
    inputs=[query_input, system_input, max_tokens_input, temperature_input, top_p_input],
    outputs="text",
)
if __name__ == "__main__":
    # Launch the Gradio server only when run as a script, not on import.
    # (Removed a stray trailing "|" that made this line a SyntaxError.)
    demo.launch()