|
import os |
|
from typing import Dict, Any |
|
from flow_modules.aiflows.CodeFileEditFlowModule import CodeFileEditAtomicFlow |
|
|
|
class TestCodeFileEditAtomicFlow(CodeFileEditAtomicFlow):
    """Refer to: https://huggingface.co/aiflows/CodeFileEditFlowModule/tree/main

    Specialization of ``CodeFileEditAtomicFlow`` whose helpers build a temporary
    test file (``temp_tests.py``, placed next to the code library) containing an
    import of the code library, the freshly generated code, and a commented
    area for test code.
    NOTE(review): how the parent class invokes these helpers is not visible
    here -- confirm against ``CodeFileEditAtomicFlow``.

    *Input Interface*:
    - `code`: str
    - `memory_files`: Dict[str, str]

    *Output Interface*:
    - `code_editor_output`: str, the code editor output
    - `temp_code_file_location`: str, the location of the temporary code file
    """

    def _generate_import_statement(self, code_lib_location):
        """
        Generate the import statement for the code library.

        The generated snippet puts the library's directory at the front of
        ``sys.path`` and star-imports the library module by its file stem.

        :param code_lib_location: the location of the code library
        :type code_lib_location: str
        :return: the import statement
        :rtype: str
        """
        module_dir = os.path.dirname(code_lib_location)
        module_name = os.path.splitext(os.path.basename(code_lib_location))[0]

        # Emit the directory through repr() rather than hand-written quotes so
        # that backslashes (Windows paths) and quote characters are escaped and
        # the generated line is always a valid Python string literal. For
        # ordinary paths the output is identical to plain single-quoting.
        import_code = (
            "import sys\n"
            f"sys.path.insert(0, {module_dir!r})\n"
            f"from {module_name} import *\n"
        )
        return import_code

    def _generate_content(self, code_lib_location, code_str) -> str:
        """
        Generate the content of the temporary code file.

        The content is, in order: a guarded import of the code library, the
        generated code, and an instruction banner followed by a delimited
        "Test Code" area.

        :param code_lib_location: the location of the code library
        :type code_lib_location: str
        :param code_str: the code string
        :type code_str: str
        :return: the content of the temporary code file
        :rtype: str
        """
        import_code_lib_str = self._generate_import_statement(code_lib_location)
        content = (
            "# Don't touch this import statement \n"
            + import_code_lib_str
            + "\n"
            + "# Here is the code just generated \n"
            + code_str
            + "\n"
            + "# Below, please provide code to test it.\n"
            "# The simplest form could be just calling it with appropriate parameters. \n"
            "# You could also assert the output, anyway, the test results will be informed to JARVIS. \n"
            "# If you do not write anything, JARVIS just checks if the syntax is alright. \n"
            "###########\n"
            "# Test Code:\n"
            "\n############\n"
        )
        return content

    def _generate_temp_file_location(self, code_lib_location):
        """
        Generate the location of the temporary code file.

        The file is always named ``temp_tests.py`` and lives in the same
        directory as the code library.

        :param code_lib_location: the location of the code library
        :type code_lib_location: str
        :return: the location of the temporary code file
        :rtype: str
        """
        directory = os.path.dirname(code_lib_location)
        return os.path.join(directory, 'temp_tests.py')

    def _check_input(self, input_data: Dict[str, Any]):
        """
        Check if the input data is valid.

        :param input_data: the input data
        :type input_data: Dict[str, Any]
        :raises AssertionError: if code or memory_files is not passed to TestCodeFileEditAtomicFlow
        """
        # Raise explicitly instead of using ``assert`` statements: asserts are
        # stripped when Python runs with -O, which would silently disable this
        # validation. The exception type and messages are unchanged.
        for required_key in ("code", "memory_files"):
            if required_key not in input_data:
                raise AssertionError(
                    f"{required_key} is not passed to TestCodeFileEditAtomicFlow"
                )
|
|
|
|