Skip to content

Commit

Permalink
Merge branch 'main' into reorder-component-name-in-arguments
Browse files Browse the repository at this point in the history
  • Loading branch information
willtai authored Aug 28, 2024
2 parents a246669 + b89b932 commit 9ee3d4a
Show file tree
Hide file tree
Showing 7 changed files with 321 additions and 17 deletions.
9 changes: 9 additions & 0 deletions .github/workflows/scheduled-e2e-tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ jobs:
image: cr.weaviate.io/semitechnologies/transformers-inference:sentence-transformers-all-MiniLM-L6-v2-onnx
env:
ENABLE_CUDA: '0'
credentials:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
weaviate:
image: cr.weaviate.io/semitechnologies/weaviate:1.25.1
env:
Expand All @@ -34,6 +37,9 @@ jobs:
ports:
- 8080:8080
- 50051:50051
credentials:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
neo4j:
image: neo4j:${{ matrix.neo4j-version }}-${{ matrix.neo4j-edition }}
env:
Expand All @@ -43,6 +49,9 @@ jobs:
ports:
- 7687:7687
- 7474:7474
credentials:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}

steps:
- name: Check out repository code
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import enum
import json
import logging
import re
from datetime import datetime
from typing import Any, Dict, List, Union

Expand Down Expand Up @@ -50,6 +51,79 @@ class OnError(enum.Enum):
NODE_TO_CHUNK_RELATIONSHIP_TYPE = "FROM_CHUNK"


def balance_curly_braces(json_string: str) -> str:
"""
Balances curly braces `{}` in a JSON string. This function ensures that every opening brace has a corresponding
closing brace, but only when they are not part of a string value. If there are unbalanced closing braces,
they are ignored. If there are missing closing braces, they are appended at the end of the string.
Args:
json_string (str): A potentially malformed JSON string with unbalanced curly braces.
Returns:
str: A JSON string with balanced curly braces.
"""
stack = []
fixed_json = []
in_string = False
escape = False

for char in json_string:
if char == '"' and not escape:
in_string = not in_string
elif char == "\\" and in_string:
escape = not escape
fixed_json.append(char)
continue
else:
escape = False

if not in_string:
if char == "{":
stack.append(char)
fixed_json.append(char)
elif char == "}" and stack and stack[-1] == "{":
stack.pop()
fixed_json.append(char)
elif char == "}" and (not stack or stack[-1] != "{"):
continue
else:
fixed_json.append(char)
else:
fixed_json.append(char)

# If stack is not empty, add missing closing braces
while stack:
stack.pop()
fixed_json.append("}")

return "".join(fixed_json)


def fix_invalid_json(invalid_json_string: str) -> str:
# Fix missing quotes around field names
invalid_json_string = re.sub(
r"([{,]\s*)(\w+)(\s*:)", r'\1"\2"\3', invalid_json_string
)

# Fix missing quotes around string values, correctly ignoring null, true, false, and numeric values
invalid_json_string = re.sub(
r"(?<=:\s)(?!(null|true|false|\d+\.?\d*))([a-zA-Z_][a-zA-Z0-9_]*)\s*(?=[,}])",
r'"\2"',
invalid_json_string,
)

# Correct the specific issue: remove trailing commas within arrays or objects before closing braces or brackets
invalid_json_string = re.sub(r",\s*(?=[}\]])", "", invalid_json_string)

# Normalize excessive curly braces
invalid_json_string = re.sub(r"{{+", "{", invalid_json_string)
invalid_json_string = re.sub(r"}}+", "}", invalid_json_string)

# Balance curly braces
return balance_curly_braces(invalid_json_string)


class EntityRelationExtractor(Component, abc.ABC):
"""Abstract class for entity relation extraction components.
Expand Down Expand Up @@ -200,16 +274,20 @@ async def extract_for_chunk(
llm_result = self.llm.invoke(prompt)
try:
result = json.loads(llm_result.content)
except json.JSONDecodeError as e:
if self.on_error == OnError.RAISE:
raise LLMGenerationError(
f"LLM response is not valid JSON {llm_result.content}: {e}"
)
else:
logger.error(
f"LLM response is not valid JSON {llm_result.content} for chunk_index={chunk_index}"
)
result = {"nodes": [], "relationships": []}
except json.JSONDecodeError:
fixed_content = fix_invalid_json(llm_result.content)
try:
result = json.loads(fixed_content)
except json.JSONDecodeError as e:
if self.on_error == OnError.RAISE:
raise LLMGenerationError(
f"LLM response is not valid JSON {fixed_content}: {e}"
)
else:
logger.error(
f"LLM response is not valid JSON {llm_result.content} for chunk_index={chunk_index}"
)
result = {"nodes": [], "relationships": []}
try:
chunk_graph = Neo4jGraph(**result)
except ValidationError as e:
Expand Down
8 changes: 8 additions & 0 deletions src/neo4j_genai/experimental/pipeline/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import enum
import logging
from datetime import datetime
from timeit import default_timer
from typing import Any, AsyncGenerator, Awaitable, Callable, Optional

from pydantic import BaseModel, Field
Expand Down Expand Up @@ -119,6 +120,7 @@ async def execute(self, **kwargs: Any) -> RunResult | None:
was unsuccessful.
"""
logger.debug(f"Running component {self.name} with {kwargs}")
start_time = default_timer()
try:
await self.set_status(RunStatus.RUNNING)
except PipelineStatusUpdateError:
Expand All @@ -130,6 +132,8 @@ async def execute(self, **kwargs: Any) -> RunResult | None:
status=self.status,
result=component_result,
)
end_time = default_timer()
logger.debug(f"Component {self.name} finished in {end_time - start_time}s")
return run_result

def validate_inputs_config(self, input_data: dict[str, Any]) -> None:
Expand Down Expand Up @@ -467,8 +471,12 @@ def validate_inputs_config(self, data: dict[str, Any]) -> None:
task.validate_inputs_config(data)

async def run(self, data: dict[str, Any]) -> dict[str, Any]:
logger.debug("Starting pipeline")
start_time = default_timer()
self.validate_inputs_config(data)
self.reinitialize()
orchestrator = Orchestrator(self)
await orchestrator.run(data)
end_time = default_timer()
logger.debug(f"Pipeline finished in {end_time - start_time}s")
return self._final_results.all()
14 changes: 7 additions & 7 deletions src/neo4j_genai/generation/graphrag.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,13 +85,13 @@ def search(
DeprecationWarning,
stacklevel=2,
)
elif isinstance(query, str):
warnings.warn(
"'query' is deprecated and will be removed in a future version, please use 'query_text' instead.",
DeprecationWarning,
stacklevel=2,
)
query_text = query
elif isinstance(query, str):
warnings.warn(
"'query' is deprecated and will be removed in a future version, please use 'query_text' instead.",
DeprecationWarning,
stacklevel=2,
)
query_text = query

validated_data = RagSearchModel(
query_text=query_text,
Expand Down
2 changes: 2 additions & 0 deletions src/neo4j_genai/generation/prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,8 @@ class ERExtractionTemplate(PromptTemplate):
Do respect the source and target node types for relationship and
the relationship direction.
Do not return any additional information other than the JSON in it.
Examples:
{examples}
Expand Down
Loading

0 comments on commit 9ee3d4a

Please sign in to comment.