Skip to content

Commit

Permalink
Add attempt to fix invalid json before throwing error (#106)
Browse files Browse the repository at this point in the history
* Add attempt to fix invalid json before throwing error

* Add checks to balance braces

* Add function to balance braces and brackets

* Fixed multiple issues test
  • Loading branch information
willtai authored Aug 16, 2024
1 parent 2b78391 commit 9054023
Show file tree
Hide file tree
Showing 3 changed files with 277 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import enum
import json
import logging
import re
from datetime import datetime
from typing import Any, Dict, List, Union

Expand Down Expand Up @@ -50,6 +51,79 @@ class OnError(enum.Enum):
NODE_TO_CHUNK_RELATIONSHIP_TYPE = "FROM_CHUNK"


def balance_curly_braces(json_string: str) -> str:
"""
Balances curly braces `{}` in a JSON string. This function ensures that every opening brace has a corresponding
closing brace, but only when they are not part of a string value. If there are unbalanced closing braces,
they are ignored. If there are missing closing braces, they are appended at the end of the string.
Args:
json_string (str): A potentially malformed JSON string with unbalanced curly braces.
Returns:
str: A JSON string with balanced curly braces.
"""
stack = []
fixed_json = []
in_string = False
escape = False

for char in json_string:
if char == '"' and not escape:
in_string = not in_string
elif char == "\\" and in_string:
escape = not escape
fixed_json.append(char)
continue
else:
escape = False

if not in_string:
if char == "{":
stack.append(char)
fixed_json.append(char)
elif char == "}" and stack and stack[-1] == "{":
stack.pop()
fixed_json.append(char)
elif char == "}" and (not stack or stack[-1] != "{"):
continue
else:
fixed_json.append(char)
else:
fixed_json.append(char)

# If stack is not empty, add missing closing braces
while stack:
stack.pop()
fixed_json.append("}")

return "".join(fixed_json)


def fix_invalid_json(invalid_json_string: str) -> str:
# Fix missing quotes around field names
invalid_json_string = re.sub(
r"([{,]\s*)(\w+)(\s*:)", r'\1"\2"\3', invalid_json_string
)

# Fix missing quotes around string values, correctly ignoring null, true, false, and numeric values
invalid_json_string = re.sub(
r"(?<=:\s)(?!(null|true|false|\d+\.?\d*))([a-zA-Z_][a-zA-Z0-9_]*)\s*(?=[,}])",
r'"\2"',
invalid_json_string,
)

# Correct the specific issue: remove trailing commas within arrays or objects before closing braces or brackets
invalid_json_string = re.sub(r",\s*(?=[}\]])", "", invalid_json_string)

# Normalize excessive curly braces
invalid_json_string = re.sub(r"{{+", "{", invalid_json_string)
invalid_json_string = re.sub(r"}}+", "}", invalid_json_string)

# Balance curly braces
return balance_curly_braces(invalid_json_string)


class EntityRelationExtractor(Component, abc.ABC):
"""Abstract class for entity relation extraction components.
Expand Down Expand Up @@ -200,16 +274,20 @@ async def extract_for_chunk(
llm_result = self.llm.invoke(prompt)
try:
result = json.loads(llm_result.content)
except json.JSONDecodeError as e:
if self.on_error == OnError.RAISE:
raise LLMGenerationError(
f"LLM response is not valid JSON {llm_result.content}: {e}"
)
else:
logger.error(
f"LLM response is not valid JSON {llm_result.content} for chunk_index={chunk_index}"
)
result = {"nodes": [], "relationships": []}
except json.JSONDecodeError:
fixed_content = fix_invalid_json(llm_result.content)
try:
result = json.loads(fixed_content)
except json.JSONDecodeError as e:
if self.on_error == OnError.RAISE:
raise LLMGenerationError(
f"LLM response is not valid JSON {fixed_content}: {e}"
)
else:
logger.error(
f"LLM response is not valid JSON {llm_result.content} for chunk_index={chunk_index}"
)
result = {"nodes": [], "relationships": []}
try:
chunk_graph = Neo4jGraph(**result)
except ValidationError as e:
Expand Down
2 changes: 2 additions & 0 deletions src/neo4j_genai/generation/prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,8 @@ class ERExtractionTemplate(PromptTemplate):
Do respect the source and target node types for relationship and
the relationship direction.
Do not return any additional information other than the JSON in it.
Examples:
{examples}
Expand Down
187 changes: 187 additions & 0 deletions tests/unit/experimental/components/test_entity_relation_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# limitations under the License.
from __future__ import annotations

import json
from unittest.mock import MagicMock

import pytest
Expand All @@ -22,6 +23,8 @@
EntityRelationExtractor,
LLMEntityRelationExtractor,
OnError,
balance_curly_braces,
fix_invalid_json,
)
from neo4j_genai.experimental.components.types import (
Neo4jGraph,
Expand Down Expand Up @@ -214,3 +217,187 @@ async def test_extractor_custom_prompt() -> None:
chunks = TextChunks(chunks=[TextChunk(text="some text")])
await extractor.run(chunks=chunks)
llm.invoke.assert_called_once_with("this is my prompt")


def test_fix_unquoted_keys() -> None:
json_string = '{name: "John", age: "30"}'
expected_result = '{"name": "John", "age": "30"}'

fixed_json = fix_invalid_json(json_string)

assert json.loads(fixed_json)
assert fixed_json == expected_result


def test_fix_unquoted_string_values() -> None:
json_string = '{"name": John, "age": 30}'
expected_result = '{"name": "John", "age": 30}'

fixed_json = fix_invalid_json(json_string)

assert json.loads(fixed_json)
assert fixed_json == expected_result


def test_remove_trailing_commas() -> None:
json_string = '{"name": "John", "age": 30,}'
expected_result = '{"name": "John", "age": 30}'

fixed_json = fix_invalid_json(json_string)

assert json.loads(fixed_json)
assert fixed_json == expected_result


def test_fix_excessive_braces() -> None:
json_string = '{{"name": "John"}}'
expected_result = '{"name": "John"}'

fixed_json = fix_invalid_json(json_string)

assert json.loads(fixed_json)
assert fixed_json == expected_result


def test_fix_multiple_issues() -> None:
json_string = '{name: John, "hobbies": ["reading", "swimming",], "age": 30}'
expected_result = '{"name": "John", "hobbies": ["reading", "swimming"], "age": 30}'

fixed_json = fix_invalid_json(json_string)

assert json.loads(fixed_json)
assert fixed_json == expected_result


def test_fix_null_values() -> None:
json_string = '{"name": John, "nickname": null}'
expected_result = '{"name": "John", "nickname": null}'

fixed_json = fix_invalid_json(json_string)

assert json.loads(fixed_json)
assert fixed_json == expected_result


def test_fix_numeric_values() -> None:
json_string = '{"age": 30, "score": 95.5}'
expected_result = '{"age": 30, "score": 95.5}'

fixed_json = fix_invalid_json(json_string)

assert json.loads(fixed_json)
assert fixed_json == expected_result


def test_balance_curly_braces_missing_closing() -> None:
json_string = '{"name": "John", "hobbies": {"reading": "yes"'
expected_result = '{"name": "John", "hobbies": {"reading": "yes"}}'

fixed_json = balance_curly_braces(json_string)

assert json.loads(fixed_json)
assert fixed_json == expected_result


def test_balance_curly_braces_extra_closing() -> None:
json_string = '{"name": "John", "hobbies": {"reading": "yes"}}}'
expected_result = '{"name": "John", "hobbies": {"reading": "yes"}}'

fixed_json = balance_curly_braces(json_string)

assert json.loads(fixed_json)
assert fixed_json == expected_result


def test_balance_curly_braces_balanced_input() -> None:
json_string = '{"name": "John", "hobbies": {"reading": "yes"}, "age": 30}'
expected_result = json_string

fixed_json = balance_curly_braces(json_string)

assert json.loads(fixed_json)
assert fixed_json == expected_result


def test_balance_curly_braces_nested_structure() -> None:
json_string = '{"person": {"name": "John", "hobbies": {"reading": "yes"}}}'
expected_result = json_string

fixed_json = balance_curly_braces(json_string)

assert json.loads(fixed_json)
assert fixed_json == expected_result


def test_balance_curly_braces_unbalanced_nested() -> None:
json_string = '{"person": {"name": "John", "hobbies": {"reading": "yes"}}'
expected_result = '{"person": {"name": "John", "hobbies": {"reading": "yes"}}}'

fixed_json = balance_curly_braces(json_string)

assert json.loads(fixed_json)
assert fixed_json == expected_result


def test_balance_curly_braces_unmatched_openings() -> None:
json_string = '{"name": "John", "hobbies": {"reading": "yes"'
expected_result = '{"name": "John", "hobbies": {"reading": "yes"}}'

fixed_json = balance_curly_braces(json_string)

assert json.loads(fixed_json)
assert fixed_json == expected_result


def test_balance_curly_braces_unmatched_closings() -> None:
json_string = '{"name": "John", "hobbies": {"reading": "yes"}}}'
expected_result = '{"name": "John", "hobbies": {"reading": "yes"}}'

fixed_json = balance_curly_braces(json_string)

assert json.loads(fixed_json)
assert fixed_json == expected_result


def test_balance_curly_braces_complex_structure() -> None:
json_string = (
'{"name": "John", "details": {"age": 30, "hobbies": {"reading": "yes"}}}'
)
expected_result = json_string

fixed_json = balance_curly_braces(json_string)

assert json.loads(fixed_json)
assert fixed_json == expected_result


def test_balance_curly_braces_incorrect_nested_closings() -> None:
json_string = '{"key1": {"key2": {"reading": "yes"}}, "key3": {"age": 30}}}'
expected_result = '{"key1": {"key2": {"reading": "yes"}}, "key3": {"age": 30}}'

fixed_json = balance_curly_braces(json_string)

assert json.loads(fixed_json)
assert fixed_json == expected_result


def test_balance_curly_braces_braces_inside_string() -> None:
json_string = '{"name": "John", "example": "a{b}c", "age": 30}'
expected_result = json_string

fixed_json = balance_curly_braces(json_string)

assert json.loads(fixed_json)
assert fixed_json == expected_result


def test_balance_curly_braces_unbalanced_with_string() -> None:
json_string = '{"name": "John", "example": "a{b}c", "hobbies": {"reading": "yes"'
expected_result = (
'{"name": "John", "example": "a{b}c", "hobbies": {"reading": "yes"}}'
)

fixed_json = balance_curly_braces(json_string)

assert json.loads(fixed_json)
assert fixed_json == expected_result

0 comments on commit 9054023

Please sign in to comment.