Skip to content

Commit

Permalink
Cleaned up code some more and added extra cases for input and output …
Browse files Browse the repository at this point in the history
…tests
  • Loading branch information
BrentBlanckaert committed Dec 28, 2024
1 parent a35c49a commit 10b0eb3
Showing 1 changed file with 67 additions and 61 deletions.
128 changes: 67 additions & 61 deletions tested/nat_translation.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,35 @@
_validate_testcase_combinations,
)

def natural_langauge_map_translation(value: YamlObject, language: str):
if isinstance(value, NaturalLanguageMap):
assert language in value
value = value[language]
return value

def parse_value(
value: list | str | int | float | dict, flattened_stack: dict
) -> list | str | int | float | dict:
def translate_input_files(dsl_object: dict, language: str, flattened_stack: dict) -> dict:
if (files := dsl_object.get("files")) is not None:
# Translation map can happen at the top level.
files = natural_langauge_map_translation(files, language)
assert isinstance(files, list)
for i in range(len(files)):
file = files[i]

# Do the formatting.
if isinstance(file, dict):
name = file["name"]
assert isinstance(name, str)
file["name"] = format_string(name, flattened_stack)
url = file["url"]
assert isinstance(url, str)
file["url"] = format_string(url, flattened_stack)
files[i] = file

dsl_object["files"] = files
return dsl_object


def parse_value(value: YamlObject, flattened_stack: dict) -> YamlObject:
# Will format the strings in different values.

if isinstance(value, str):
Expand Down Expand Up @@ -48,19 +73,15 @@ def format_string(string: str, flattened) -> str:
def translate_io(
io_object: YamlObject, key: str, language: str, flat_stack: dict
) -> YamlObject:
if isinstance(io_object, NaturalLanguageMap):
assert language in io_object
io_object = io_object[language]
# Translate NaturalLanguageMap
io_object = natural_langauge_map_translation(io_object, language)

if isinstance(io_object, dict):
data = io_object[key]
if isinstance(data, dict):
assert language in data
data = data[language]
data = natural_langauge_map_translation(io_object[key], language)
assert isinstance(data, str)
io_object[key] = format_string(data, flat_stack)

# Perform translation based of translation stack.
print(io_object)
if isinstance(io_object, str):
return format_string(io_object, flat_stack)

Expand All @@ -75,10 +96,8 @@ def translate_testcase(

key_to_set = "statement" if "statement" in testcase else "expression"
if (expr_stmt := testcase.get(key_to_set)) is not None:
# Must use !natural_language
if isinstance(expr_stmt, NaturalLanguageMap):
assert language in expr_stmt
expr_stmt = expr_stmt[language]
# Translate NaturalLanguageMap
expr_stmt = natural_langauge_map_translation(expr_stmt, language)

# Perform translation based of translation stack.
if isinstance(expr_stmt, dict):
Expand All @@ -90,35 +109,34 @@ def translate_testcase(

else:
if (stdin_stmt := testcase.get("stdin")) is not None:
if isinstance(stdin_stmt, dict):
assert language in stdin_stmt
stdin_stmt = stdin_stmt[language]
# Translate NaturalLanguageMap
stdin_stmt = natural_langauge_map_translation(stdin_stmt, language)

# Perform translation based of translation stack.
assert isinstance(stdin_stmt, str)
testcase["stdin"] = format_string(stdin_stmt, flat_stack)

# Translate NaturalLanguageMap
arguments = testcase.get("arguments", [])
if isinstance(arguments, dict):
assert language in arguments
arguments = arguments[language]
arguments = natural_langauge_map_translation(arguments, language)

# Perform translation based of translation stack.
assert isinstance(arguments, list)
testcase["arguments"] = [
format_string(str(arg), flat_stack) for arg in arguments
]
testcase["arguments"] = parse_value(arguments, flat_stack)

if (stdout := testcase.get("stdout")) is not None:
# Must use !natural_language
testcase["stdout"] = translate_io(stdout, "data", language, flat_stack)

if (file := testcase.get("file")) is not None:
# Must use !natural_language
if isinstance(file, NaturalLanguageMap):
assert language in file
testcase["file"] = file[language]
# TODO: SHOULD I ADD SUPPORT FOR TRANSLATION STACK HERE?
# Translate NaturalLanguageMap
file = natural_langauge_map_translation(file, language)

assert isinstance(file, dict)
file["content"] = format_string(str(file["content"]), flat_stack)
file["location"] = format_string(str(file["location"]), flat_stack)

testcase["file"] = file

if (stderr := testcase.get("stderr")) is not None:
testcase["stderr"] = translate_io(stderr, "data", language, flat_stack)

Expand All @@ -128,51 +146,37 @@ def translate_testcase(
if (result := testcase.get("return")) is not None:
if isinstance(result, ReturnOracle):
arguments = result.get("arguments", [])
if isinstance(arguments, dict):
assert language in arguments
arguments = arguments[language]
arguments = natural_langauge_map_translation(arguments, language)

# Perform translation based of translation stack.
result["arguments"] = [
format_string(str(arg), flat_stack) for arg in arguments
]
result["arguments"] = parse_value(arguments, flat_stack)

value = result.get("value")
# Must use !natural_language
if isinstance(value, NaturalLanguageMap):
assert language in value
value = value[language]
value = natural_langauge_map_translation(value, language)

assert isinstance(value, str)
result["value"] = parse_value(value, flat_stack)
testcase["return"] = result

elif isinstance(result, NaturalLanguageMap):
# Must use !natural_language
assert language in result
testcase["return"] = parse_value(result[language], flat_stack)
elif result is not None:
testcase["return"] = parse_value(result, flat_stack)

if (description := testcase.get("description")) is not None:
# Must use !natural_language
if isinstance(description, NaturalLanguageMap):
assert language in description
description = description[language]
description = natural_langauge_map_translation(description, language)

if isinstance(description, str):
testcase["description"] = format_string(description, flat_stack)

if isinstance(description, dict):
else:
assert isinstance(description, dict)
dd = description["description"]
if isinstance(dd, dict):
assert language in dd
dd = dd[language]
dd = natural_langauge_map_translation(dd, language)

assert isinstance(dd, str)
description["description"] = format_string(dd, flat_stack)

testcase["description"] = description
testcase = translate_input_files(testcase, language, flat_stack)

return testcase

Expand All @@ -192,6 +196,8 @@ def translate_contexts(contexts: list, language: str, translation_stack: list) -
result = []
for context in contexts:
assert isinstance(context, dict)

# Add translation to stack
if "translation" in context:
translation_stack.append(context["translation"])

Expand All @@ -201,12 +207,12 @@ def translate_contexts(contexts: list, language: str, translation_stack: list) -
context[key_to_set] = translate_testcases(
raw_testcases, language, translation_stack
)
if "files" in context:
files = context.get("files")
if isinstance(files, NaturalLanguageMap):
assert language in files
context["files"] = files[language]

flat_stack = flatten_stack(translation_stack, language)
context = translate_input_files(context, language, flat_stack)
result.append(context)

# Pop translation from stack
if "translation" in context:
translation_stack.pop()
context.pop("translation")
Expand All @@ -217,10 +223,7 @@ def translate_contexts(contexts: list, language: str, translation_stack: list) -
def translate_tab(tab: YamlDict, language: str, translation_stack: list) -> YamlDict:
key_to_set = "unit" if "unit" in tab else "tab"
name = tab.get(key_to_set)

if isinstance(name, dict):
assert language in name
name = name[language]
name = natural_langauge_map_translation(name, language)

assert isinstance(name, str)
tab[key_to_set] = format_string(name, flatten_stack(translation_stack, language))
Expand Down Expand Up @@ -283,6 +286,9 @@ def translate_dsl(dsl_object: YamlObject, language: str) -> YamlObject:
if "translation" in dsl_object:
translation_stack.append(dsl_object["translation"])
dsl_object.pop("translation")

flat_stack = flatten_stack(translation_stack, language)
dsl_object = translate_input_files(dsl_object, language, flat_stack)
dsl_object[key_to_set] = translate_tabs(tab_list, language, translation_stack)
return dsl_object

Expand Down

0 comments on commit 10b0eb3

Please sign in to comment.