# NOTE(review): the following banner text ("Spaces: / Sleeping / Sleeping") is a
# Hugging Face Spaces page-status artifact from scraping, not part of the module.
import os | |
import json | |
import logging | |
from enum import Enum | |
from pydantic import BaseModel, Field | |
import pandas as pd | |
from huggingface_hub import InferenceClient | |
from tenacity import retry, stop_after_attempt, wait_exponential | |
# --- Logging: INFO-level records to both the console and hf_api.log ---
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# One shared formatter for both destinations.
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')

console_handler = logging.StreamHandler()
console_handler.setLevel(logging.INFO)
console_handler.setFormatter(formatter)

file_handler = logging.FileHandler("hf_api.log")
file_handler.setLevel(logging.INFO)
file_handler.setFormatter(formatter)

# Only attach handlers once so re-imports/reloads don't duplicate log lines.
if not logger.handlers:
    logger.addHandler(console_handler)
    logger.addHandler(file_handler)
# Fail fast at import time when the Hugging Face credential is absent:
# every client below needs it, so there is no point continuing without it.
HF_TOKEN = os.getenv('HF_TOKEN')
if not HF_TOKEN:
    logger.error("Hugging Face API token not found. Set the HF_TOKEN environment variable.")
    raise EnvironmentError("HF_TOKEN environment variable is not set.")
# --- Inference clients: one small model for scoring, one large for composing ---
MODEL_NAME1 = "meta-llama/Llama-3.1-8B-Instruct"
MODEL_NAME2 = "Qwen/Qwen2.5-72B-Instruct"


def _build_client(model_name):
    """Instantiate an InferenceClient for `model_name`, logging success or failure."""
    try:
        client = InferenceClient(model=model_name, token=HF_TOKEN)
        logger.info(f"InferenceClient for model '{model_name}' instantiated successfully.")
    except Exception as e:
        logger.error(f"Failed to instantiate InferenceClient for model '{model_name}': {e}")
        raise
    return client


client1 = _build_client(MODEL_NAME1)
client2 = _build_client(MODEL_NAME2)
# Define Pydantic schemas used to constrain and parse the LLM JSON outputs.
class EvaluationSchema(BaseModel):
    """Structured output for relevance grading of a single abstract."""
    # Free-text justification the model writes before committing to a score.
    reasoning: str
    # Clinical relevance for rheumatologists, constrained to the range 0..10.
    relevance_score: int = Field(ge=0, le=10)
class TopicEnum(Enum):
    """Closed set of newsletter topics; member values are the human-readable labels
    used both in prompts and as section headings."""
    Rheumatoid_Arthritis = "Rheumatoid Arthritis"
    Systemic_Lupus_Erythematosus = "Systemic Lupus Erythematosus"
    Scleroderma = "Scleroderma"
    Sjogren_s_Disease = "Sjogren's Disease"
    Ankylosing_Spondylitis = "Ankylosing Spondylitis"
    Psoriatic_Arthritis = "Psoriatic Arthritis"
    Gout = "Gout"
    Vasculitis = "Vasculitis"
    Osteoarthritis = "Osteoarthritis"
    Infectious_Diseases = "Infectious Diseases"
    Immunology = "Immunology"
    Genetics = "Genetics"
    Biologics = "Biologics"
    Biosimilars = "Biosimilars"
    Small_Molecules = "Small Molecules"
    Clinical_Trials = "Clinical Trials"
    Health_Policy = "Health Policy"
    Patient_Education = "Patient Education"
    # Catch-all bucket; also the default topic in SummarySchema.
    Other_Rheumatic_Diseases = "Other Rheumatic Diseases"
class SummarySchema(BaseModel):
    """Structured output for single-abstract summarization."""
    # One-sentence summary of the abstract.
    summary: str
    # Best-matching topic; falls back to the catch-all bucket when unspecified.
    topic: TopicEnum = TopicEnum.Other_Rheumatic_Diseases
class PaperSchema(BaseModel):
    """Citation metadata for one paper, matching the lowercase keys serialized
    into the newsletter prompt (see compose_newsletter's column renaming)."""
    title: str
    authors: str
    journal: str
    # PubMed identifier, used to build pubmed.ncbi.nlm.nih.gov reference links.
    pmid: str
class TopicSummarySchema(BaseModel):
    """Structured output for a per-topic newsletter section."""
    # Optional scratch space the model may use to organize thoughts first.
    planning: str
    # Final 1-3 paragraph section body, rendered verbatim into the Markdown.
    summary: str
def evaluate_relevance(title: str, abstract: str) -> dict:
    """Grade an abstract's clinical relevance for rheumatologists (0-10).

    Args:
        title: Paper title.
        abstract: Paper abstract text.

    Returns:
        The parsed JSON response as a dict with keys per EvaluationSchema
        ('reasoning', 'relevance_score').  NOTE: the original annotation said
        ``EvaluationSchema`` but the function has always returned a plain dict.

    Raises:
        Exception: re-raised from the underlying API call after logging.
    """
    # Serialize the schema once: it is embedded in the prompt AND used as the
    # grammar constraint for generation.
    schema = EvaluationSchema.model_json_schema()
    prompt = f"""
Title: {title}
Abstract: {abstract}
Instructions: Evaluate the relevance of this medical abstract for an audience of rheumatologists on a scale of 0 to 10 with 10 being reserved only for large clinical trials in rheumatology.
Be very discerning and only give a score above 8 for papers that are highly clinically relevant to rheumatologists.
Respond in JSON format using the following schema:
{json.dumps(schema)}
"""
    try:
        # Delegate to the shared helper so every generation call in this module
        # goes through one code path (was a duplicated inline call).
        return _make_api_call(client1, prompt, max_tokens=512, temp=0.2, schema=schema)
    except Exception as e:
        logger.error(f"Error in evaluate_relevance: {e}")
        raise
def summarize_abstract(abstract: str) -> dict:
    """Summarize an abstract in one sentence and classify it into a topic.

    Args:
        abstract: Paper abstract text.

    Returns:
        The parsed JSON response as a dict with keys per SummarySchema
        ('summary', 'topic').  NOTE: the original annotation said
        ``SummarySchema`` but the function has always returned a plain dict.

    Raises:
        Exception: re-raised from the underlying API call after logging.
    """
    schema = SummarySchema.model_json_schema()
    # BUG FIX: the original interpolated `TopicEnum.__doc__` — the generic enum
    # docstring (or None), never the topic labels — so the model was never shown
    # the allowed topics. Send the actual member values instead.
    topics = json.dumps([t.value for t in TopicEnum])
    prompt = f"""
Abstract: {abstract}
Instructions: Summarize this medical abstract in 1 sentence and select the most relevant topic from the following enum:
{topics}
Respond in JSON format using the following schema:
{json.dumps(schema)}
"""
    try:
        # Shared helper keeps the call/parse/log logic in one place.
        return _make_api_call(client1, prompt, max_tokens=512, temp=0.2, schema=schema)
    except Exception as e:
        logger.error(f"Error in summarize_abstract: {e}")
        raise
def _make_api_call(client, prompt, max_tokens=4096, temp=0.2, schema=None): | |
try: | |
response = client.text_generation( | |
prompt, | |
max_new_tokens=max_tokens, | |
temperature=temp, | |
grammar={"type": "json", "value": schema} if schema else None | |
) | |
return json.loads(response) | |
except Exception as e: | |
logger.error(f"API call failed: {e}") | |
raise | |
def compose_newsletter(papers: pd.DataFrame) -> str:
    """Build a Markdown newsletter with one LLM-written section per topic.

    Args:
        papers: DataFrame with columns 'Topic', 'Title', 'Authors', 'Journal',
            'PMID', and 'Summary'.

    Returns:
        Markdown text, or "" when `papers` is empty. Topics whose API call or
        response parsing fails are logged and skipped, never fatal.
    """
    if papers.empty:
        logger.info("No papers provided to compose the newsletter.")
        return ""
    content = ["# This Week in Rheumatology\n"]
    for topic in papers['Topic'].unique():
        # BUG FIX: `result` was referenced in the except handler below, but was
        # unbound whenever the failure happened before the API call returned
        # (NameError inside the handler). Keep it defined for the whole topic.
        result = None
        try:
            relevant_papers = papers[papers['Topic'] == topic]
            # Convert to dicts with lowercase keys to match the PaperSchema-style
            # field names expected in the prompt.
            papers_dict = relevant_papers.rename(columns={
                'Title': 'title',
                'Authors': 'authors',
                'Journal': 'journal',
                'PMID': 'pmid',
                'Summary': 'summary'
            }).to_dict('records')
            prompt = f"""
Instructions: Generate a brief summary of the latest research on {topic} using the following papers.
Papers: {json.dumps(papers_dict)}
Respond in JSON format using the following schema:
{json.dumps(TopicSummarySchema.model_json_schema())}
You have the option of using the planning field first to organize your thoughts before writing the summary.
The summary should be concise, but because you are summarizing several papers, it should be detailed enough to give the reader a good idea of the latest research in the field.
The papers may be somewhat disjointed, so you will need to think carefully about how you can transition between them with clever wording.
You can use anywhere from 1 to 3 paragraphs for the summary.
"""
            result = _make_api_call(
                client2,
                prompt,
                max_tokens=4096,
                temp=0.2,
                schema=TopicSummarySchema.model_json_schema()
            )
            # Log the raw response for debugging.
            logger.debug(f"Raw response from Hugging Face: {result}")
            # Validate the response structure before rendering it.
            summary = TopicSummarySchema(**result)
            # Render the section: heading, summary body, then references.
            topic_content = f"## {topic}\n\n"
            topic_content += f"{summary.summary}\n\n"
            topic_content += "### References\n\n"
            # Reuse the already-filtered frame (the original re-filtered `papers`
            # here redundantly).
            for _, paper in relevant_papers.iterrows():
                topic_content += (f"- {paper['Title']} by {paper['Authors']}. {paper['Journal']}. "
                                  f"[PMID: {paper['PMID']}](https://pubmed.ncbi.nlm.nih.gov/{paper['PMID']}/)\n")
            content.append(topic_content)
        except Exception as e:
            logger.error(f"Error processing topic {topic}: {e}")
            logger.error(f"Raw response: {result}")
            continue
    return "\n".join(content)