|
"""Load models to use them as a narrator and a common-sense oracle in the PAYADOR pipeline.""" |
|
import google.generativeai as genai |
|
import requests |
|
import os |
|
|
|
|
|
class GeminiModel:
    """Thin wrapper around a Google Gemini ``GenerativeModel``.

    Every prompt is sent with the safety settings in ``SAFETY_SETTINGS``, which
    set the threshold of each listed harm category to ``"BLOCK_NONE"``.
    """

    # Default safety configuration sent with every request: each listed harm
    # category uses the "BLOCK_NONE" threshold.
    # NOTE(review): "HARM_CATEGORY_DANGEROUS" looks like a legacy (PaLM-era)
    # category, distinct from the Gemini "HARM_CATEGORY_DANGEROUS_CONTENT" also
    # listed below — confirm the Gemini API accepts it.
    SAFETY_SETTINGS = (
        {"category": "HARM_CATEGORY_DANGEROUS", "threshold": "BLOCK_NONE"},
        {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"},
        {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"},
        {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"},
        {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"},
    )

    def __init__(self, api_key_file: str, model_name: str = "gemini-pro") -> None:
        """Initialize the Gemini model using an API key.

        Args:
            api_key_file: Name of the environment variable that holds the
                Gemini API key. Despite the parameter name, this value is
                looked up with ``os.getenv``; no file is opened. If the
                variable is unset, ``None`` is passed to ``genai.configure``.
            model_name: Name of the Gemini model to instantiate.
        """
        # Fresh per-instance list of dict copies, so mutating one instance's
        # settings never alters the class default or other instances.
        self.safety_settings = [dict(setting) for setting in self.SAFETY_SETTINGS]

        # genai.configure is a module-level SDK call, so the key presumably
        # applies SDK-wide (i.e. to every GeminiModel instance) — confirm.
        genai.configure(api_key=os.getenv(api_key_file))

        self.model = genai.GenerativeModel(model_name)

    def prompt_model(self, prompt: str) -> str:
        """Send ``prompt`` to the Gemini model and return the response text.

        Args:
            prompt: Text prompt passed to ``generate_content``.

        Returns:
            The ``.text`` of the response object returned by ``generate_content``.

        NOTE(review): the SDK's ``.text`` accessor is documented to raise
        ``ValueError`` when the response contains no text parts (e.g. a blocked
        prompt) — confirm callers handle this.
        """
        response = self.model.generate_content(prompt, safety_settings=self.safety_settings)
        return response.text
|
|
|
|
|
def prompt_HF_API(
    prompt: str,
    model: str = "microsoft/Phi-3-mini-4k-instruct",
    api_key_file: str = "HF_API_key",
    *,
    timeout: float = 120.0,
) -> str:
    """Prompt a model on the Hugging Face Inference API and return its generated text.

    Args:
        prompt: Text sent as the ``inputs`` field of the JSON payload.
        model: Hugging Face model id, appended to the Inference API URL.
        api_key_file: Path of a file whose first line is the Hugging Face token.
        timeout: Seconds ``requests`` waits for the server before giving up, so
            a stalled connection cannot block the caller forever.

    Returns:
        The ``generated_text`` field of the first element of the JSON response.

    Raises:
        requests.RequestException: on connection errors or timeouts.
        RuntimeError: if the API answers with an error status, or with JSON that
            is not a non-empty list whose first item has ``generated_text``.
        ValueError: if a successful response body is not valid JSON.
    """
    api_url = f"https://api-inference.huggingface.co/models/{model}"

    # strip() removes the trailing newline readline() may leave on the token;
    # requests rejects header values containing "\n".
    token = get_api_key(api_key_file).strip()
    headers = {"Authorization": f"Bearer {token}"}
    payload = {"inputs": prompt}

    response = requests.post(api_url, headers=headers, json=payload, timeout=timeout)
    if not response.ok:
        # Surface the server's status and body (presumably an error message)
        # instead of failing later with an opaque KeyError on output[0].
        raise RuntimeError(
            f"Hugging Face API request failed with status "
            f"{response.status_code}: {response.text}"
        )

    output = response.json()
    if not (
        isinstance(output, list)
        and output
        and isinstance(output[0], dict)
        and "generated_text" in output[0]
    ):
        raise RuntimeError(f"Unexpected Hugging Face API response: {output!r}")
    return output[0]["generated_text"]
|
|
|
def get_api_key(path: str) -> str:
    """Load an API key from the first line of the file at ``path``.

    Surrounding whitespace is stripped, notably the trailing newline kept by
    ``readline()``. A newline inside the key would otherwise end up in an HTTP
    header value, which ``requests`` rejects. An empty file yields ``""``.

    Raises:
        OSError: if the file cannot be opened (e.g. ``FileNotFoundError``).
    """
    with open(path, encoding="utf-8") as key_file:
        return key_file.readline().strip()