|
import json |
|
from typing import Any, Dict, Union |
|
|
|
from autogpt.call_ai_function import call_ai_function |
|
from autogpt.config import Config |
|
from autogpt.json_utils import correct_json |
|
from autogpt.logger import logger |
|
|
|
# Module-level Config instance; fix_json reads cfg.fast_llm_model from it.
cfg = Config()

# Expected shape of the AI's reply. fix_and_parse_json passes this text to
# fix_json, which embeds it (as a quoted argument) in the prompt sent to the
# model; the code here only uses it as prompt text.
# NOTE: this is a regular (non-raw) string, so the "\n" sequences inside the
# "plan" value become real line breaks in the schema text.
JSON_SCHEMA = """
{
    "command": {
        "name": "command name",
        "args": {
            "arg name": "value"
        }
    },
    "thoughts":
    {
        "text": "thought",
        "reasoning": "reasoning",
        "plan": "- short bulleted\n- list that conveys\n- long-term plan",
        "criticism": "constructive self-criticism",
        "speak": "thoughts summary to say to user"
    }
}
"""
|
|
|
|
|
def fix_and_parse_json(
    json_str: str, try_to_fix_with_gpt: bool = True
) -> Union[str, Dict[Any, Any]]:
    """Parse ``json_str`` as JSON, escalating through progressively heavier repairs.

    Stages, tried in order until one parses:

    1. Remove every tab character and parse.
    2. Pass the text through ``correct_json`` and parse again.
    3. Keep only the span from the first ``{`` to the last ``}`` and parse.
    4. If ``try_to_fix_with_gpt`` is true, ask the LLM (via ``fix_json``) to
       repair the text against ``JSON_SCHEMA``.

    Args:
        json_str: Raw text, typically an AI reply that should contain JSON.
        try_to_fix_with_gpt: Whether stage 4 is allowed to run.

    Returns:
        The value produced by ``json.loads`` at the first successful stage.
        If stage 4 runs and ``fix_json`` reports ``"failed"``, the partially
        repaired text itself is returned instead.

    Raises:
        json.JSONDecodeError / ValueError: if stage 3 fails and
            ``try_to_fix_with_gpt`` is false. Exceptions raised by
            ``correct_json`` or ``fix_json`` (other than JSONDecodeError from
            ``correct_json``) propagate unchanged.
    """
    # Stage 1: tabs are dropped once; every later stage works on this text.
    text = json_str.replace("\t", "")
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        pass

    # Stage 2: ``text`` is only replaced if correct_json returns normally.
    try:
        text = correct_json(text)
        return json.loads(text)
    except json.JSONDecodeError:
        pass

    # Stage 3: the AI sometimes writes prose before/after the object, e.g.
    # "I'm sorry, ... {"text": ...}". Trim to the outermost braces.
    # str.index / str.rindex raise ValueError when a brace is missing; the
    # leading trim is kept even if the trailing search then fails.
    try:
        text = text[text.index("{"):]
        text = text[: text.rindex("}") + 1]
        return json.loads(text)
    except (json.JSONDecodeError, ValueError):
        if not try_to_fix_with_gpt:
            raise

        logger.warn(
            "Warning: Failed to parse AI output, attempting to fix."
            "\n If you see this warning frequently, it's likely that"
            " your prompt is confusing the AI. Try changing it up"
            " slightly."
        )
        # Stage 4: let the LLM-backed repairer have a go.
        repaired = fix_json(text, JSON_SCHEMA)
        if repaired == "failed":
            # Returning the raw text lets the AI react to its own bad output.
            logger.error("Failed to fix AI output, telling the AI.")
            return text
        return json.loads(repaired)
|
|
|
|
|
def fix_json(json_str: str, schema: str) -> str:
    """Fix the given JSON string to make it parseable and fully compliant with the provided schema.

    The repair is delegated to the fast LLM via ``call_ai_function``: the
    request is phrased as a call to an imaginary Python function whose
    arguments are the JSON text and the schema text, each wrapped in triple
    single quotes.

    Args:
        json_str: The malformed JSON text to repair.
        schema: Schema text describing the expected structure.

    Returns:
        The model's reply if it parses with ``json.loads``; otherwise the
        sentinel string ``"failed"``.
    """
    # Describe the "function" the model is asked to emulate.
    function_string = "def fix_json(json_str: str, schema:str=None) -> str:"
    args = [f"'''{json_str}'''", f"'''{schema}'''"]
    description_string = (
        "Fixes the provided JSON string to make it parseable"
        " and fully compliant with the provided schema.\n If an object or"
        " field specified in the schema isn't contained within the correct"
        " JSON, it is omitted.\n This function is brilliant at guessing"
        " when the format is incorrect."
    )

    # NOTE(review): ``args`` is built above, before this fence is added, so
    # the fenced copy never reaches the model -- it only changes the
    # "Original JSON" debug line below. Confirm whether the model was meant
    # to receive the fenced text.
    if not json_str.startswith("`"):
        json_str = "```json\n" + json_str + "\n```"
    result_string = call_ai_function(
        function_string, args, description_string, model=cfg.fast_llm_model
    )
    logger.debug("------------ JSON FIX ATTEMPT ---------------")
    logger.debug(f"Original JSON: {json_str}")
    logger.debug("-----------")
    logger.debug(f"Fixed JSON: {result_string}")
    logger.debug("----------- END OF FIX ATTEMPT ----------------")

    try:
        # Validity check only; the parsed value is discarded here.
        json.loads(result_string)
    except (ValueError, TypeError, RecursionError) as err:
        # ValueError covers json.JSONDecodeError, TypeError a non-string reply
        # (e.g. None), RecursionError pathologically deep nesting. Deliberately
        # not a bare ``except:`` so KeyboardInterrupt/SystemExit propagate.
        logger.debug(f"Fixed JSON is still invalid: {err}")
        return "failed"
    return result_string
|
|