# PAYADOR-experiments / models.py
# Commit f404d37 ("Update models.py") by sgongora27; 1.88 kB.
"""Load models to use them as a narrator and a common-sense oracle in the PAYADOR pipeline."""
import google.generativeai as genai
import requests
import os
class GeminiModel():
    """Thin wrapper around a Google Gemini ``GenerativeModel``.

    Every request is sent with each category in ``_SAFETY_CATEGORIES`` set to
    the ``BLOCK_NONE`` threshold.
    """

    # Safety categories sent with every request; each gets threshold BLOCK_NONE.
    # NOTE(review): "HARM_CATEGORY_DANGEROUS" looks like a legacy (PaLM-era)
    # category rather than a Gemini one — confirm the Gemini API accepts it.
    _SAFETY_CATEGORIES = (
        "HARM_CATEGORY_DANGEROUS",
        "HARM_CATEGORY_HARASSMENT",
        "HARM_CATEGORY_HATE_SPEECH",
        "HARM_CATEGORY_SEXUALLY_EXPLICIT",
        "HARM_CATEGORY_DANGEROUS_CONTENT",
    )

    def __init__(self, api_key_file: str, model_name: str = "gemini-pro") -> None:
        """Initialize the Gemini model using an API key.

        Args:
            api_key_file: Despite its name, this is the NAME OF AN ENVIRONMENT
                VARIABLE holding the API key (it is passed to ``os.getenv``),
                not a path to a file.
            model_name: Gemini model identifier passed to
                ``genai.GenerativeModel``.
        """
        # A fresh list per instance, so mutating one model's settings does
        # not affect another instance.
        self.safety_settings = [
            {"category": category, "threshold": "BLOCK_NONE"}
            for category in self._SAFETY_CATEGORIES
        ]
        # If the variable is unset, os.getenv returns None and that None is
        # forwarded to genai.configure unchanged.
        genai.configure(api_key=os.getenv(api_key_file))
        self.model = genai.GenerativeModel(model_name)

    def prompt_model(self, prompt: str) -> str:
        """Prompt the Gemini model and return the generated text.

        Args:
            prompt: Text passed to ``generate_content``.

        Returns:
            The ``.text`` of the response, generated with this instance's
            ``safety_settings``.
        """
        response = self.model.generate_content(
            prompt, safety_settings=self.safety_settings
        )
        return response.text
def prompt_HF_API(prompt: str, model: str = "microsoft/Phi-3-mini-4k-instruct", api_key_file: str = "HF_API_key", timeout=None) -> str:
    """Prompt a model hosted on the Hugging Face Inference API.

    Args:
        prompt: Text sent as the ``inputs`` field of the JSON payload.
        model: Hugging Face model id, appended to the Inference API URL.
        api_key_file: Path to a file whose first line is the access token
            (read with ``get_api_key``).
        timeout: Optional ``requests`` timeout in seconds. The default
            ``None`` waits indefinitely (the previous behavior); callers are
            encouraged to pass a value.

    Returns:
        The ``generated_text`` field of the first element of the JSON response.

    Raises:
        RuntimeError: if the API replies with a JSON object carrying an
            ``"error"`` key instead of a list of generations.
    """
    api_url = f"https://api-inference.huggingface.co/models/{model}"
    # Strip surrounding whitespace: a trailing newline from the key file makes
    # the Authorization header value invalid and the request fails.
    token = get_api_key(api_key_file).strip()
    headers = {"Authorization": f"Bearer {token}"}
    payload = {"inputs": prompt}
    output = requests.post(api_url, headers=headers, json=payload, timeout=timeout).json()
    # The Inference API reports failures as {"error": ...}; indexing that dict
    # with [0] would otherwise surface only as an opaque KeyError(0).
    if isinstance(output, dict) and "error" in output:
        raise RuntimeError(
            f"Hugging Face Inference API error for {model!r}: {output['error']}"
        )
    return output[0]["generated_text"]
def get_api_key(path: str) -> str:
    """Load an API key from the first line of the file at *path*.

    Surrounding whitespace is stripped, including the trailing newline that
    ``readline`` keeps, so the key can be embedded directly in an HTTP header.

    Args:
        path: Path to a text file whose first line is the key.

    Returns:
        The key without surrounding whitespace ("" for an empty file).

    Raises:
        OSError: if the file cannot be opened (e.g. ``FileNotFoundError``).
    """
    with open(path, encoding="utf-8") as key_file:
        return key_file.readline().strip()