# JSON-repair helpers added to
# src/neo4j_genai/experimental/components/entity_relation_extractor.py by
# "Add attempt to fix invalid json before throwing error (#106)".
import json
import re

# A complete double-quoted string literal, honouring backslash escapes. The
# single capturing group makes ``split`` return the literals at odd indices.
_STRING_LITERAL = re.compile(r'("(?:\\.|[^"\\])*")', re.DOTALL)
# An unquoted field name directly after ``{`` or ``,`` and before ``:``.
_UNQUOTED_KEY = re.compile(r"([{,]\s*)(\w+)(\s*:)")
# An unquoted identifier used as an object value. ``\b`` keeps the keyword
# exclusion from also skipping identifiers such as ``nullable``.
_UNQUOTED_VALUE = re.compile(
    r"(:\s*)(?!(?:null|true|false)\b)([a-zA-Z_][a-zA-Z0-9_]*)\s*(?=[,}])"
)
# A comma that is immediately followed by a closing brace or bracket.
_TRAILING_COMMA = re.compile(r",\s*(?=[}\]])")
_REPEATED_OPEN = re.compile(r"{{+")
_REPEATED_CLOSE = re.compile(r"}}+")


def balance_curly_braces(json_string: str) -> str:
    """Balance curly braces `{}` in a JSON string.

    Braces that appear inside string values are left alone. Closing braces
    that have no matching opening brace are dropped, and missing closing
    braces are appended at the end. If the input stops in the middle of a
    string value, that string is terminated first so the appended braces are
    structural rather than part of the string.

    Args:
        json_string (str): A potentially malformed JSON string with unbalanced curly braces.

    Returns:
        str: A JSON string with balanced curly braces.
    """
    depth = 0  # number of currently open "{" outside string values
    pieces = []
    in_string = False
    escaped = False

    for char in json_string:
        if in_string:
            pieces.append(char)
            if escaped:
                # This character is consumed by the preceding backslash.
                escaped = False
            elif char == "\\":
                escaped = True
            elif char == '"':
                in_string = False
            continue

        if char == '"':
            in_string = True
        elif char == "{":
            depth += 1
        elif char == "}":
            if depth == 0:
                # Unmatched closing brace: ignore it.
                continue
            depth -= 1
        pieces.append(char)

    if in_string:
        if escaped:
            # A dangling backslash would escape the quote added below.
            pieces.pop()
        pieces.append('"')

    pieces.append("}" * depth)
    return "".join(pieces)


def _repair_outside_strings(text: str, *, collapse_closers: bool) -> str:
    """Apply the regex repairs to everything except complete string literals."""
    repaired = []
    for index, piece in enumerate(_STRING_LITERAL.split(text)):
        if index % 2:
            # Odd indices are the captured string literals: keep verbatim.
            repaired.append(piece)
            continue
        piece = _UNQUOTED_KEY.sub(r'\1"\2"\3', piece)
        piece = _UNQUOTED_VALUE.sub(r'\1"\2"', piece)
        piece = _TRAILING_COMMA.sub("", piece)
        # "{{" can never occur in valid JSON, so collapsing it is always safe.
        piece = _REPEATED_OPEN.sub("{", piece)
        if collapse_closers:
            piece = _REPEATED_CLOSE.sub("}", piece)
        repaired.append(piece)
    return "".join(repaired)


def _is_valid_json(text: str) -> bool:
    """Return True when ``json.loads`` accepts ``text``."""
    try:
        json.loads(text)
    except ValueError:  # json.JSONDecodeError is a ValueError subclass
        return False
    return True


def fix_invalid_json(invalid_json_string: str) -> str:
    """Try to repair common JSON mistakes.

    The following repairs are applied outside of string literals only, so the
    content of string values is never rewritten:

    * quote unquoted field names;
    * quote unquoted identifier values (``null``, ``true``, ``false`` and
      numbers are left untouched);
    * remove trailing commas before ``}`` or ``]``;
    * collapse repeated opening braces (``{{`` -> ``{``);
    * balance curly braces.

    Runs of closing braces are legitimate in nested JSON, so they are kept on
    the first attempt. Only when that attempt does not parse are they
    collapsed as well (``}}`` -> ``}``), which handles input whose braces
    were doubled throughout.

    Args:
        invalid_json_string (str): The malformed JSON text.

    Returns:
        str: The repaired text. It is not guaranteed to be valid JSON;
        callers should still parse it and handle failures.
    """
    conservative = balance_curly_braces(
        _repair_outside_strings(invalid_json_string, collapse_closers=False)
    )
    if _is_valid_json(conservative):
        return conservative

    return balance_curly_braces(
        _repair_outside_strings(invalid_json_string, collapse_closers=True)
    )
@@ -200,16 +274,20 @@ async def extract_for_chunk( llm_result = self.llm.invoke(prompt) try: result = json.loads(llm_result.content) - except json.JSONDecodeError as e: - if self.on_error == OnError.RAISE: - raise LLMGenerationError( - f"LLM response is not valid JSON {llm_result.content}: {e}" - ) - else: - logger.error( - f"LLM response is not valid JSON {llm_result.content} for chunk_index={chunk_index}" - ) - result = {"nodes": [], "relationships": []} + except json.JSONDecodeError: + fixed_content = fix_invalid_json(llm_result.content) + try: + result = json.loads(fixed_content) + except json.JSONDecodeError as e: + if self.on_error == OnError.RAISE: + raise LLMGenerationError( + f"LLM response is not valid JSON {fixed_content}: {e}" + ) + else: + logger.error( + f"LLM response is not valid JSON {llm_result.content} for chunk_index={chunk_index}" + ) + result = {"nodes": [], "relationships": []} try: chunk_graph = Neo4jGraph(**result) except ValidationError as e: diff --git a/src/neo4j_genai/generation/prompts.py b/src/neo4j_genai/generation/prompts.py index 132dda6e..bff5765f 100644 --- a/src/neo4j_genai/generation/prompts.py +++ b/src/neo4j_genai/generation/prompts.py @@ -140,6 +140,8 @@ class ERExtractionTemplate(PromptTemplate): Do respect the source and target node types for relationship and the relationship direction. +Do not return any additional information other than the JSON in it. + Examples: {examples} diff --git a/tests/unit/experimental/components/test_entity_relation_extractor.py b/tests/unit/experimental/components/test_entity_relation_extractor.py index e502a21f..41239f53 100644 --- a/tests/unit/experimental/components/test_entity_relation_extractor.py +++ b/tests/unit/experimental/components/test_entity_relation_extractor.py @@ -14,6 +14,7 @@ # limitations under the License. 
import json

from neo4j_genai.experimental.components.entity_relation_extractor import (
    balance_curly_braces,
    fix_invalid_json,
)


def test_fix_unquoted_keys() -> None:
    """Unquoted field names are wrapped in double quotes."""
    repaired = fix_invalid_json('{name: "John", age: "30"}')

    assert json.loads(repaired)
    assert repaired == '{"name": "John", "age": "30"}'


def test_fix_unquoted_string_values() -> None:
    """Unquoted identifier values are quoted; numbers are left alone."""
    repaired = fix_invalid_json('{"name": John, "age": 30}')

    assert json.loads(repaired)
    assert repaired == '{"name": "John", "age": 30}'


def test_remove_trailing_commas() -> None:
    """A trailing comma before the closing brace is removed."""
    repaired = fix_invalid_json('{"name": "John", "age": 30,}')

    assert json.loads(repaired)
    assert repaired == '{"name": "John", "age": 30}'


def test_fix_excessive_braces() -> None:
    """Doubled outer braces are reduced to a single pair."""
    repaired = fix_invalid_json('{{"name": "John"}}')

    assert json.loads(repaired)
    assert repaired == '{"name": "John"}'


def test_fix_multiple_issues() -> None:
    """Several defects in the same document are repaired together."""
    repaired = fix_invalid_json(
        '{name: John, "hobbies": ["reading", "swimming",], "age": 30}'
    )

    assert json.loads(repaired)
    assert repaired == '{"name": "John", "hobbies": ["reading", "swimming"], "age": 30}'


def test_fix_null_values() -> None:
    """The ``null`` keyword is not quoted."""
    repaired = fix_invalid_json('{"name": John, "nickname": null}')

    assert json.loads(repaired)
    assert repaired == '{"name": "John", "nickname": null}'


def test_fix_numeric_values() -> None:
    """Integer and decimal values pass through unchanged."""
    repaired = fix_invalid_json('{"age": 30, "score": 95.5}')

    assert json.loads(repaired)
    assert repaired == '{"age": 30, "score": 95.5}'


def test_balance_curly_braces_missing_closing() -> None:
    """Missing closing braces are appended."""
    balanced = balance_curly_braces('{"name": "John", "hobbies": {"reading": "yes"')

    assert json.loads(balanced)
    assert balanced == '{"name": "John", "hobbies": {"reading": "yes"}}'


def test_balance_curly_braces_extra_closing() -> None:
    """A surplus closing brace is dropped."""
    balanced = balance_curly_braces('{"name": "John", "hobbies": {"reading": "yes"}}}')

    assert json.loads(balanced)
    assert balanced == '{"name": "John", "hobbies": {"reading": "yes"}}'


def test_balance_curly_braces_balanced_input() -> None:
    """Already balanced input is returned unchanged."""
    source = '{"name": "John", "hobbies": {"reading": "yes"}, "age": 30}'
    balanced = balance_curly_braces(source)

    assert json.loads(balanced)
    assert balanced == source


def test_balance_curly_braces_nested_structure() -> None:
    """Balanced nested objects are returned unchanged."""
    source = '{"person": {"name": "John", "hobbies": {"reading": "yes"}}}'
    balanced = balance_curly_braces(source)

    assert json.loads(balanced)
    assert balanced == source


def test_balance_curly_braces_unbalanced_nested() -> None:
    """One missing closer in a nested document is appended."""
    balanced = balance_curly_braces(
        '{"person": {"name": "John", "hobbies": {"reading": "yes"}}'
    )

    assert json.loads(balanced)
    assert balanced == '{"person": {"name": "John", "hobbies": {"reading": "yes"}}}'


def test_balance_curly_braces_unmatched_openings() -> None:
    """Every unmatched opening brace receives a closer."""
    balanced = balance_curly_braces('{"name": "John", "hobbies": {"reading": "yes"')

    assert json.loads(balanced)
    assert balanced == '{"name": "John", "hobbies": {"reading": "yes"}}'


def test_balance_curly_braces_unmatched_closings() -> None:
    """Unmatched closing braces are ignored."""
    balanced = balance_curly_braces('{"name": "John", "hobbies": {"reading": "yes"}}}')

    assert json.loads(balanced)
    assert balanced == '{"name": "John", "hobbies": {"reading": "yes"}}'


def test_balance_curly_braces_complex_structure() -> None:
    """A balanced three-level document is returned unchanged."""
    source = '{"name": "John", "details": {"age": 30, "hobbies": {"reading": "yes"}}}'
    balanced = balance_curly_braces(source)

    assert json.loads(balanced)
    assert balanced == source


def test_balance_curly_braces_incorrect_nested_closings() -> None:
    """A surplus closer after sibling nested objects is dropped."""
    balanced = balance_curly_braces(
        '{"key1": {"key2": {"reading": "yes"}}, "key3": {"age": 30}}}'
    )

    assert json.loads(balanced)
    assert balanced == '{"key1": {"key2": {"reading": "yes"}}, "key3": {"age": 30}}'


def test_balance_curly_braces_braces_inside_string() -> None:
    """Braces inside a string value do not take part in balancing."""
    source = '{"name": "John", "example": "a{b}c", "age": 30}'
    balanced = balance_curly_braces(source)

    assert json.loads(balanced)
    assert balanced == source


def test_balance_curly_braces_unbalanced_with_string() -> None:
    """Missing closers are appended even when a string value holds braces."""
    balanced = balance_curly_braces(
        '{"name": "John", "example": "a{b}c", "hobbies": {"reading": "yes"'
    )

    assert json.loads(balanced)
    assert balanced == (
        '{"name": "John", "example": "a{b}c", "hobbies": {"reading": "yes"}}'
    )