Spaces:
Runtime error
Runtime error
import ast | |
import copy | |
import toml | |
from string import Template | |
from pathlib import Path | |
from flatdict import FlatDict | |
import google.generativeai as genai | |
from gen.utils import parse_first_json_snippet | |
def determine_model_name(given_image=None):
    """Pick the Gemini model identifier for a request.

    Returns "gemini-pro-vision" when an image is supplied and
    "gemini-pro" for text-only requests (``given_image`` is None).
    """
    has_image = given_image is not None
    return "gemini-pro-vision" if has_image else "gemini-pro"
def construct_image_part(given_image):
    """Wrap image data in the part dict that is appended to the prompt parts.

    The MIME type is always declared as "image/jpeg"; ``given_image`` is
    stored unchanged under the "data" key.
    """
    image_part = dict(mime_type="image/jpeg", data=given_image)
    return image_part
def call_gemini(prompt="", API_KEY=None, given_text=None, given_image=None, generation_config=None, safety_settings=None):
    """Send a prompt (optionally with context text and an image) to Gemini.

    Args:
        prompt: Instruction text for the model.
        API_KEY: Key handed to ``genai.configure``.
        given_text: Optional context placed below the prompt, separated
            from it by a dashed rule.
        given_image: Optional image data. When given, the vision model is
            selected and the image is sent as an extra JPEG part.
        generation_config: Optional generation settings; defaults to
            temperature 0.8, top_p 1, top_k 32, max_output_tokens 4096.
        safety_settings: Optional safety settings; defaults to
            ``BLOCK_NONE`` for the harassment, hate-speech,
            sexually-explicit and dangerous-content categories.

    Returns:
        The ``text`` attribute of the response from ``generate_content``.
    """
    genai.configure(api_key=API_KEY)

    if generation_config is None:
        generation_config = {
            "temperature": 0.8,
            "top_p": 1,
            "top_k": 32,
            "max_output_tokens": 4096,
        }

    if safety_settings is None:
        safety_settings = [
            {
                "category": "HARM_CATEGORY_HARASSMENT",
                "threshold": "BLOCK_NONE"
            },
            {
                "category": "HARM_CATEGORY_HATE_SPEECH",
                "threshold": "BLOCK_NONE"
            },
            {
                "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
                "threshold": "BLOCK_NONE"
            },
            {
                "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
                "threshold": "BLOCK_NONE"
            },
        ]

    model_name = determine_model_name(given_image)
    model = genai.GenerativeModel(model_name=model_name,
                                  generation_config=generation_config,
                                  safety_settings=safety_settings)

    user_prompt = prompt
    if given_text is not None:
        # Build the combined prompt once. Previously this block was
        # *appended* to a string that already held the prompt, so the
        # instruction text was sent to the model twice.
        user_prompt = (
            f"{prompt}\n"
            "------------------------------------------------\n"
            f"{given_text}\n"
        )

    prompt_parts = [user_prompt]
    if given_image is not None:
        prompt_parts.append(construct_image_part(given_image))

    response = model.generate_content(prompt_parts)
    return response.text
def try_out(prompt, given_text, gemini_api_key, given_image=None, retry_num=10):
    """Call Gemini until a JSON snippet can be parsed from the answer.

    Args:
        prompt: Instruction text forwarded to ``call_gemini``.
        given_text: Context text forwarded to ``call_gemini``.
        gemini_api_key: API key forwarded to ``call_gemini``.
        given_image: Optional image data forwarded to ``call_gemini``.
        retry_num: Maximum number of attempts.

    Returns:
        The value returned by ``parse_first_json_snippet`` for the first
        attempt that yields something other than None, or None when all
        ``retry_num`` attempts fail.
    """
    attempts = 0
    # Every attempt counts against the budget. Previously only attempts
    # that raised were counted, so a response that parsed to None without
    # raising would keep the loop (and the API calls) running forever.
    while attempts < retry_num:
        attempts += 1
        try:
            raw_answer = call_gemini(
                prompt=prompt,
                given_text=given_text,
                given_image=given_image,
                API_KEY=gemini_api_key
            )
            parsed = parse_first_json_snippet(raw_answer)
        except Exception as e:
            # Best-effort retry: API errors, blocked responses and parse
            # failures are all reported and retried.
            print(f"......retry {e}")
            continue

        if parsed is not None:
            return parsed

    return None
def get_basic_qa(text, gemini_api_key, trucate=7000):
    """Generate the basic Q&A for ``text``.

    Loads the 'basic_qa' prompt from ``constants/prompts.toml`` (relative
    to the current working directory) and passes it, together with the
    first ``trucate`` characters of ``text``, to ``try_out``.

    Returns:
        Whatever ``try_out`` returns.
    """
    prompt_file = Path('.') / 'constants' / 'prompts.toml'
    basic_prompt = toml.load(prompt_file)['basic_qa']['prompt']
    excerpt = text[:trucate]
    return try_out(basic_prompt, excerpt, gemini_api_key=gemini_api_key)
def _is_complete_follow_up(response):
    """Return True if *response* has a follow-up question plus eli5 and expert answers."""
    if not isinstance(response, dict):
        return False
    answers = response.get('answers')
    return (
        'follow up question' in response
        and isinstance(answers, dict)
        and 'eli5' in answers
        and 'expert' in answers
    )


def _request_follow_up(search_prompt, excerpt, gemini_api_key):
    """Ask for a follow-up Q&A until it is complete; return None if try_out gives up."""
    while True:
        response = try_out(search_prompt, excerpt, gemini_api_key=gemini_api_key)
        # None means try_out produced nothing usable; stop instead of
        # running membership tests against None (a TypeError before).
        if response is None or _is_complete_follow_up(response):
            return response


def get_deep_qa(text, basic_qa, gemini_api_key, trucate=7000):
    """Extend ``basic_qa`` with in-depth and broad follow-up Q&As.

    For every entry of ``basic_qa['qna']`` two prompts are rendered from
    the 'deep_qa' template in ``constants/prompts.toml`` (tone "in-depth"
    and tone "broad"), each seeded with the entry's question and expert
    answer. The responses are attached to a copy of the entry under
    'additional_depth_q' / 'additional_breath_q', the entry is flattened
    with ``FlatDict`` and merged into ``basic_qa`` with every key prefixed
    by ``"<index>_"``.

    Args:
        text: Source text; only the first ``trucate`` characters are sent.
        basic_qa: Dict with a 'title' and a 'qna' list whose items carry
            'question' and 'answers' -> 'expert'. Updated in place.
        gemini_api_key: API key forwarded to ``try_out``.
        trucate: Number of leading characters of ``text`` to send.

    Returns:
        The same ``basic_qa`` object, updated with the flattened entries.
        A follow-up for which ``try_out`` returned None is omitted.
    """
    prompts = toml.load(Path('.') / 'constants' / 'prompts.toml')
    # Loop-invariant pieces: the template and the truncated text.
    deep_prompt = Template(prompts['deep_qa']['prompt'])
    excerpt = text[:trucate]

    title = basic_qa['title']
    # Work on a copy so the nested entries in basic_qa['qna'] stay untouched.
    qnas = copy.deepcopy(basic_qa['qna'])

    for idx, qna in enumerate(qnas):
        q = qna['question']
        a_expert = qna['answers']['expert']

        depth_search_prompt = deep_prompt.substitute(
            title=title, previous_question=q, previous_answer=a_expert, tone="in-depth"
        )
        breath_search_prompt = deep_prompt.substitute(
            title=title, previous_question=q, previous_answer=a_expert, tone="broad"
        )

        depth_search_response = _request_follow_up(depth_search_prompt, excerpt, gemini_api_key)
        breath_search_response = _request_follow_up(breath_search_prompt, excerpt, gemini_api_key)

        if depth_search_response is not None:
            qna['additional_depth_q'] = depth_search_response
        if breath_search_response is not None:
            qna['additional_breath_q'] = breath_search_response

        # Flatten nested keys (FlatDict joins levels with its delimiter,
        # ':' by default) and prefix each with the entry index. Reading
        # the items directly avoids the str() -> ast.literal_eval round
        # trip, which breaks on values that are not Python literals.
        flat_qna = FlatDict(qna)
        basic_qa.update({f'{idx}_{k}': v for k, v in flat_qna.items()})

    return basic_qa