fix(agents-api): Configure spacy for postgresql #1055

Merged
11 commits merged on Jan 15, 2025
256 changes: 92 additions & 164 deletions agents-api/agents_api/common/nlp.py
@@ -1,15 +1,15 @@
import re
from collections import Counter, defaultdict
from collections import Counter
from functools import lru_cache

import spacy
from spacy.matcher import PhraseMatcher
from spacy.tokens import Doc
from spacy.util import filter_spans

# Precompile regex patterns
WHITESPACE_RE = re.compile(r"\s+")
NON_ALPHANUM_RE = re.compile(r"[^\w\s\-_]+")
LONE_HYPHEN_RE = re.compile(r"\s*-\s*(?!\w)|(?<!\w)\s*-\s*")

# Initialize spaCy with minimal pipeline
nlp = spacy.load("en_core_web_sm", exclude=["lemmatizer", "textcat"])
@@ -30,66 +30,40 @@
)


# Singleton PhraseMatcher for better performance
class KeywordMatcher:
_instance = None

def __new__(cls):
if cls._instance is None:
cls._instance = super().__new__(cls)
cls._instance.matcher = PhraseMatcher(nlp.vocab, attr="LOWER")
cls._instance.batch_size = 1000 # Adjust based on memory constraints
cls._instance.patterns_cache = {}
return cls._instance

@lru_cache(maxsize=10000)
def _create_pattern(self, text: str) -> Doc:
return nlp.make_doc(text)

def find_matches(self, doc: Doc, keywords: list[str]) -> dict[str, list[int]]:
"""Batch process keywords for better performance."""
keyword_positions = defaultdict(list)

# Process keywords in batches to avoid memory issues
for i in range(0, len(keywords), self.batch_size):
batch = keywords[i : i + self.batch_size]
patterns = [self._create_pattern(kw) for kw in batch]

# Clear previous patterns and add new batch
if "KEYWORDS" in self.matcher:
self.matcher.remove("KEYWORDS")
self.matcher.add("KEYWORDS", patterns)

# Find matches for this batch
matches = self.matcher(doc)
for match_id, start, end in matches:
span_text = doc[start:end].text
normalized = WHITESPACE_RE.sub(" ", span_text).lower().strip()
keyword_positions[normalized].append(start)

return keyword_positions


# Initialize global matcher
keyword_matcher = KeywordMatcher()


@lru_cache(maxsize=10000)
def clean_keyword(kw: str) -> str:
"""Cache cleaned keywords for reuse."""
return NON_ALPHANUM_RE.sub("", kw).strip()
# First remove non-alphanumeric chars (except whitespace, hyphens, underscores)
cleaned = NON_ALPHANUM_RE.sub("", kw).strip()
# Replace lone hyphens with spaces
cleaned = LONE_HYPHEN_RE.sub(" ", cleaned)
# Clean up any resulting multiple spaces
return WHITESPACE_RE.sub(" ", cleaned).strip()


def extract_keywords(doc: Doc, top_n: int = 10, clean: bool = True) -> list[str]:
def extract_keywords(
doc: Doc, top_n: int = 25, split_chunks: bool = False
) -> list[str]:
"""Optimized keyword extraction with minimal behavior change."""
excluded_labels = {
"DATE",
"TIME",
"PERCENT",
"MONEY",
"QUANTITY",
"ORDINAL",
"CARDINAL",
"DATE", # Absolute or relative dates or periods.
"TIME", # Times smaller than a day.
"PERCENT", # Percentage, including ”%“.
"MONEY", # Monetary values, including unit.
"QUANTITY", # Measurements, as of weight or distance.
"ORDINAL", # “first”, “second”, etc.
"CARDINAL", # Numerals that do not fall under another type.
# "PERSON", # People, including fictional.
# "NORP", # Nationalities or religious or political groups.
# "FAC", # Buildings, airports, highways, bridges, etc.
# "ORG", # Companies, agencies, institutions, etc.
# "GPE", # Countries, cities, states.
# "LOC", # Non-GPE locations, mountain ranges, bodies of water.
# "PRODUCT", # Objects, vehicles, foods, etc. (Not services.)
# "EVENT", # Named hurricanes, battles, wars, sports events, etc.
# "WORK_OF_ART", # Titles of books, songs, etc.
# "LAW", # Named documents made into laws.
# "LANGUAGE", # Any named language.
}

# Extract and filter spans in a single pass
@@ -104,8 +78,12 @@ def extract_keywords(doc: Doc, top_n: int = 10, clean: bool = True) -> list[str]

# Process spans efficiently and filter out spans that are entirely stopwords
keywords = []
ent_keywords = []
seen_texts = set()

# Convert ent_spans to set for faster lookup
ent_spans_set = set(ent_spans)

for span in all_spans:
# Skip if all tokens in span are stopwords
if all(token.is_stop for token in span):
@@ -119,79 +97,28 @@
continue

seen_texts.add(lower_text)
keywords.append(text)
ent_keywords.append(text) if span in ent_spans_set else keywords.append(text)

# Normalize keywords by replacing multiple spaces with single space and stripping
normalized_ent_keywords = [WHITESPACE_RE.sub(" ", kw).strip() for kw in ent_keywords]
normalized_keywords = [WHITESPACE_RE.sub(" ", kw).strip() for kw in keywords]

if split_chunks:
normalized_keywords = [word for kw in normalized_keywords for word in kw.split()]

# Count frequencies efficiently
ent_freq = Counter(normalized_ent_keywords)
freq = Counter(normalized_keywords)
top_keywords = [kw for kw, _ in freq.most_common(top_n)]

if clean:
return [clean_keyword(kw) for kw in top_keywords]
return top_keywords


def find_proximity_groups(
keywords: list[str], keyword_positions: dict[str, list[int]], n: int = 10
) -> list[set[str]]:
"""Optimized proximity grouping using sorted positions."""
# Early return for single or no keywords
if len(keywords) <= 1:
return [{kw} for kw in keywords]

# Create flat list of positions for efficient processing
positions: list[tuple[int, str]] = [
(pos, kw) for kw in keywords for pos in keyword_positions[kw]
]

# Sort positions once
positions.sort()

# Initialize Union-Find with path compression and union by rank
parent = {kw: kw for kw in keywords}
rank = dict.fromkeys(keywords, 0)

def find(u: str) -> str:
if parent[u] != u:
parent[u] = find(parent[u])
return parent[u]

def union(u: str, v: str) -> None:
u_root, v_root = find(u), find(v)
if u_root != v_root:
if rank[u_root] < rank[v_root]:
u_root, v_root = v_root, u_root
parent[v_root] = u_root
if rank[u_root] == rank[v_root]:
rank[u_root] += 1

# Use sliding window for proximity checking
window = []
for pos, kw in positions:
# Remove positions outside window
while window and pos - window[0][0] > n:
window.pop(0)

# Union with all keywords in window
for _, w_kw in window:
union(kw, w_kw)

window.append((pos, kw))

# Group keywords efficiently
groups = defaultdict(set)
for kw in keywords:
root = find(kw)
groups[root].add(kw)

return list(groups.values())
top_keywords = [kw for kw, _ in ent_freq.most_common(top_n)]
remaining_slots = max(0, top_n - len(top_keywords))
top_keywords += [kw for kw, _ in freq.most_common(remaining_slots)]

return [clean_keyword(kw) for kw in top_keywords]

@lru_cache(maxsize=1000)
def text_to_tsvector_query(
paragraph: str, top_n: int = 10, proximity_n: int = 10, min_keywords: int = 1
paragraph: str, top_n: int = 25, min_keywords: int = 1, split_chunks: bool = False
) -> str:
"""
Extracts meaningful keywords/phrases from text and joins them with OR.
@@ -203,8 +130,8 @@ def text_to_tsvector_query(
Args:
paragraph (str): The input text to process
top_n (int): Number of top keywords to extract per sentence
proximity_n (int): The proximity window for grouping related keywords
min_keywords (int): Minimum number of keywords required
split_chunks (bool): If True, breaks multi-word noun chunks into individual words

Returns:
str: Keywords/phrases joined by OR
@@ -219,57 +146,58 @@ def text_to_tsvector_query(
sent_doc = sent.as_doc()

# Extract keywords
keywords = extract_keywords(sent_doc, top_n)
keywords = extract_keywords(sent_doc, top_n, split_chunks=split_chunks)
if len(keywords) < min_keywords:
continue

# Find keyword positions
keyword_positions = keyword_matcher.find_matches(sent_doc, keywords)
if not keyword_positions:
continue

# Group related keywords by proximity
groups = find_proximity_groups(keywords, keyword_positions, proximity_n)

# Add each group as a single term to our set
for group in groups:
if len(group) > 1:
# Sort by length descending to prioritize longer phrases
sorted_group = sorted(group, key=len, reverse=True)
# For truly proximate multi-word groups, group words
queries.add(" OR ".join(sorted_group))
else:
# For non-proximate words or single words, add them separately
queries.update(group)
queries.update(keywords)

# Join all terms with " OR "
return " OR ".join(queries) if queries else ""


def batch_text_to_tsvector_queries(
paragraphs: list[str],
top_n: int = 10,
proximity_n: int = 10,
min_keywords: int = 1,
n_process: int = 1,
) -> list[str]:
"""
Processes multiple paragraphs using nlp.pipe for better performance.

Args:
paragraphs (list[str]): List of paragraphs to process
top_n (int): Number of top keywords to include per paragraph

Returns:
list[str]: List of tsquery strings
"""
# Use a set to avoid duplicates
results = set()

for doc in nlp.pipe(paragraphs, disable=["lemmatizer", "textcat"], n_process=n_process):
# Generate tsquery string for each paragraph
queries = text_to_tsvector_query(doc, top_n, proximity_n, min_keywords)
# Add to results set
results.add(queries)

return list(results)
# def batch_text_to_tsvector_queries(
# paragraphs: list[str],
# top_n: int = 10,
# proximity_n: int = 10,
# min_keywords: int = 1,
# n_process: int = 1,
# ) -> list[str]:
# """
# Processes multiple paragraphs using nlp.pipe for better performance.

# Args:
# paragraphs (list[str]): List of paragraphs to process
# top_n (int): Number of top keywords to include per paragraph

# Returns:
# list[str]: List of tsquery strings
# """
# results = []

# for doc in nlp.pipe(paragraphs, disable=["lemmatizer", "textcat"], n_process=n_process):
# queries = set() # Use set to avoid duplicates
# for sent in doc.sents:
# sent_doc = sent.as_doc()
# keywords = extract_keywords(sent_doc, top_n)
# if len(keywords) < min_keywords:
# continue
# keyword_positions = keyword_matcher.find_matches(sent_doc, keywords)
# if not keyword_positions:
# continue
# groups = find_proximity_groups(keywords, keyword_positions, proximity_n)
# # Add each group as a single term to our set
# for group in groups:
# if len(group) > 1:
# # Sort by length descending to prioritize longer phrases
# sorted_group = sorted(group, key=len, reverse=True)
# # For truly proximate multi-word groups, group words
# queries.add(" OR ".join(sorted_group))
# else:
# # For non-proximate words or single words, add them separately
# queries.update(group)

# # Join all terms with " OR "
# results.append(" OR ".join(queries) if queries else "")

# return results
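
A minimal usage sketch of the reworked helper, for reviewers. It assumes the en_core_web_sm pipeline loaded above is installed; the keywords shown in the output comments are illustrative only, since the exact result depends on the model version.

from agents_api.common.nlp import text_to_tsvector_query

# Single paragraph in, OR-joined keyword string out, for PostgreSQL full-text search.
tsquery = text_to_tsvector_query(
    "Configure spaCy so keyword search works against PostgreSQL full-text indexes."
)
print(tsquery)
# e.g. "spacy OR keyword search OR postgresql OR full-text indexes"

# split_chunks=True breaks multi-word noun chunks into individual words.
print(text_to_tsvector_query("large language models", split_chunks=True))
# e.g. "large OR language OR models"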
3 changes: 2 additions & 1 deletion agents-api/agents_api/queries/docs/search_docs_by_text.py
@@ -5,6 +5,7 @@
from fastapi import HTTPException

from ...autogen.openapi_model import DocReference
from ...common.nlp import text_to_tsvector_query
from ...common.utils.db_exceptions import common_db_exceptions
from ..utils import pg_query, rewrap_exceptions, wrap_in_class
from .utils import transform_to_doc_reference
@@ -61,7 +62,7 @@ async def search_docs_by_text(
owner_types: list[str] = [owner[0] for owner in owners]
owner_ids: list[str] = [str(owner[1]) for owner in owners]
# Pre-process rawtext query
# query = text_to_tsvector_query(query)
query = text_to_tsvector_query(query)

return (
search_docs_text_query,
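
search_docs_hybrid.py below applies the same change, so both search paths now share this pre-processing. A hedged sketch of the flow it enables (the literal output depends on the spaCy model, and search_docs_text_query itself is not part of this diff):

from agents_api.common.nlp import text_to_tsvector_query  # as imported above

# Hypothetical illustration of the pre-processing step introduced above.
raw = "how do I configure spacy for postgresql?"
processed = text_to_tsvector_query(raw)  # e.g. "spacy OR postgresql OR configure"
# `processed`, not the raw sentence, is what presumably gets bound into the
# parameter list returned alongside search_docs_text_query, so Postgres
# receives tsquery-friendly terms instead of free-form text.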
4 changes: 4 additions & 0 deletions agents-api/agents_api/queries/docs/search_docs_hybrid.py
@@ -5,6 +5,7 @@
from fastapi import HTTPException

from ...autogen.openapi_model import DocReference
from ...common.nlp import text_to_tsvector_query
from ...common.utils.db_exceptions import common_db_exceptions
from ..utils import (
pg_query,
@@ -81,6 +82,9 @@ async def search_docs_hybrid(
owner_types: list[str] = [owner[0] for owner in owners]
owner_ids: list[str] = [str(owner[1]) for owner in owners]

# Pre-process rawtext query
text_query = text_to_tsvector_query(text_query)

return (
search_docs_hybrid_query,
[