fix(agents-api): Configure spacy for postgresql
Ahmad-mtos committed Jan 14, 2025
1 parent cb86135 commit 61d32bd
Showing 2 changed files with 86 additions and 189 deletions.
255 changes: 79 additions & 176 deletions agents-api/agents_api/common/nlp.py
@@ -29,67 +29,33 @@
},
)


# Singleton PhraseMatcher for better performance
class KeywordMatcher:
_instance = None

def __new__(cls):
if cls._instance is None:
cls._instance = super().__new__(cls)
cls._instance.matcher = PhraseMatcher(nlp.vocab, attr="LOWER")
cls._instance.batch_size = 1000 # Adjust based on memory constraints
cls._instance.patterns_cache = {}
return cls._instance

@lru_cache(maxsize=10000)
def _create_pattern(self, text: str) -> Doc:
return nlp.make_doc(text)

def find_matches(self, doc: Doc, keywords: list[str]) -> dict[str, list[int]]:
"""Batch process keywords for better performance."""
keyword_positions = defaultdict(list)

# Process keywords in batches to avoid memory issues
for i in range(0, len(keywords), self.batch_size):
batch = keywords[i : i + self.batch_size]
patterns = [self._create_pattern(kw) for kw in batch]

# Clear previous patterns and add new batch
if "KEYWORDS" in self.matcher:
self.matcher.remove("KEYWORDS")
self.matcher.add("KEYWORDS", patterns)

# Find matches for this batch
matches = self.matcher(doc)
for match_id, start, end in matches:
span_text = doc[start:end].text
normalized = WHITESPACE_RE.sub(" ", span_text).lower().strip()
keyword_positions[normalized].append(start)

return keyword_positions


# Initialize global matcher
keyword_matcher = KeywordMatcher()
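
For context, a minimal usage sketch of the matcher above (illustrative only; it assumes the module-level nlp pipeline loaded earlier in this file):

doc = nlp("Michael Jordan played basketball; basketball fans cheered.")
positions = keyword_matcher.find_matches(doc, ["michael jordan", "basketball"])
# positions maps each normalized, lower-cased keyword to the token offsets where it
# starts, e.g. {"michael jordan": [0], "basketball": [3, 5]}; exact offsets depend
# on tokenization.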


@lru_cache(maxsize=10000)
def clean_keyword(kw: str) -> str:
"""Cache cleaned keywords for reuse."""
return NON_ALPHANUM_RE.sub("", kw).strip()


def extract_keywords(doc: Doc, top_n: int = 10, clean: bool = True) -> list[str]:
def extract_keywords(doc: Doc, top_n: int = 25, clean: bool = True) -> list[str]:
"""Optimized keyword extraction with minimal behavior change."""
excluded_labels = {
"DATE",
"TIME",
"PERCENT",
"MONEY",
"QUANTITY",
"ORDINAL",
"CARDINAL",
"DATE", # Absolute or relative dates or periods.
"TIME", # Times smaller than a day.
"PERCENT", # Percentage, including ”%“.
"MONEY", # Monetary values, including unit.
"QUANTITY", # Measurements, as of weight or distance.
"ORDINAL", # “first”, “second”, etc.
"CARDINAL", # Numerals that do not fall under another type.
# "PERSON", # People, including fictional.
# "NORP", # Nationalities or religious or political groups.
# "FAC", # Buildings, airports, highways, bridges, etc.
# "ORG", # Companies, agencies, institutions, etc.
# "GPE", # Countries, cities, states.
# "LOC", # Non-GPE locations, mountain ranges, bodies of water.
# "PRODUCT", # Objects, vehicles, foods, etc. (Not services.)
# "EVENT", # Named hurricanes, battles, wars, sports events, etc.
# "WORK_OF_ART", # Titles of books, songs, etc.
# "LAW", # Named documents made into laws.
# "LANGUAGE", # Any named language.
}

# Extract and filter spans in a single pass
@@ -104,8 +70,12 @@ def extract_keywords(doc: Doc, top_n: int = 10, clean: bool = True) -> list[str]

# Process spans efficiently and filter out spans that are entirely stopwords
keywords = []
ent_keywords = []
seen_texts = set()

# Convert ent_spans to set for faster lookup
ent_spans_set = set(ent_spans)

for span in all_spans:
# Skip if all tokens in span are stopwords
if all(token.is_stop for token in span):
@@ -119,79 +89,30 @@
continue

seen_texts.add(lower_text)
keywords.append(text)
(ent_keywords if span in ent_spans_set else keywords).append(text)


# Normalize keywords by replacing multiple spaces with single space and stripping
normalized_ent_keywords = [WHITESPACE_RE.sub(" ", kw).strip() for kw in ent_keywords]
normalized_keywords = [WHITESPACE_RE.sub(" ", kw).strip() for kw in keywords]

# Count frequencies efficiently
ent_freq = Counter(normalized_ent_keywords)
freq = Counter(normalized_keywords)
top_keywords = [kw for kw, _ in freq.most_common(top_n)]


top_keywords = [kw for kw, _ in ent_freq.most_common(top_n)]
remaining_slots = max(0, top_n - len(top_keywords))
top_keywords += [kw for kw, _ in freq.most_common(remaining_slots)]

if clean:
return [clean_keyword(kw) for kw in top_keywords]
return top_keywords
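
To illustrate the entity-first selection above (a hedged sketch; it assumes the module-level nlp pipeline with NER enabled): named-entity spans fill the top-N slots before ordinary noun chunks do.

doc = nlp("I love basketball especially Michael Jordan.")
print(extract_keywords(doc, top_n=3, clean=False))
# e.g. ['Michael Jordan', 'basketball'] -- the PERSON entity outranks the plain
# noun chunk; exact output depends on the model.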


def find_proximity_groups(
keywords: list[str], keyword_positions: dict[str, list[int]], n: int = 10
) -> list[set[str]]:
"""Optimized proximity grouping using sorted positions."""
# Early return for single or no keywords
if len(keywords) <= 1:
return [{kw} for kw in keywords]

# Create flat list of positions for efficient processing
positions: list[tuple[int, str]] = [
(pos, kw) for kw in keywords for pos in keyword_positions[kw]
]

# Sort positions once
positions.sort()

# Initialize Union-Find with path compression and union by rank
parent = {kw: kw for kw in keywords}
rank = dict.fromkeys(keywords, 0)

def find(u: str) -> str:
if parent[u] != u:
parent[u] = find(parent[u])
return parent[u]

def union(u: str, v: str) -> None:
u_root, v_root = find(u), find(v)
if u_root != v_root:
if rank[u_root] < rank[v_root]:
u_root, v_root = v_root, u_root
parent[v_root] = u_root
if rank[u_root] == rank[v_root]:
rank[u_root] += 1

# Use sliding window for proximity checking
window = []
for pos, kw in positions:
# Remove positions outside window
while window and pos - window[0][0] > n:
window.pop(0)

# Union with all keywords in window
for _, w_kw in window:
union(kw, w_kw)

window.append((pos, kw))

# Group keywords efficiently
groups = defaultdict(set)
for kw in keywords:
root = find(kw)
groups[root].add(kw)

return list(groups.values())
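
A small worked example of the proximity grouping above (illustrative; positions are token offsets such as those returned by KeywordMatcher.find_matches):

keyword_positions = {"michael jordan": [0], "basketball": [3], "tennis": [40]}
groups = find_proximity_groups(["michael jordan", "basketball", "tennis"], keyword_positions, n=10)
# e.g. [{"michael jordan", "basketball"}, {"tennis"}] -- occurrences that fall
# within 10 tokens of each other are merged into one group.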


@lru_cache(maxsize=1000)
def text_to_tsvector_query(
paragraph: str, top_n: int = 10, proximity_n: int = 10, min_keywords: int = 1
paragraph: str, top_n: int = 25, min_keywords: int = 1
) -> str:
"""
Extracts meaningful keywords/phrases from text and joins them with OR.
@@ -203,7 +124,6 @@ def text_to_tsvector_query(
Args:
paragraph (str): The input text to process
top_n (int): Number of top keywords to extract per sentence
proximity_n (int): The proximity window for grouping related keywords
min_keywords (int): Minimum number of keywords required
Returns:
@@ -223,71 +143,54 @@
if len(keywords) < min_keywords:
continue

# Find keyword positions
keyword_positions = keyword_matcher.find_matches(sent_doc, keywords)
if not keyword_positions:
continue

# Group related keywords by proximity
groups = find_proximity_groups(keywords, keyword_positions, proximity_n)

# Add each group as a single term to our set
for group in groups:
if len(group) > 1:
# Sort by length descending to prioritize longer phrases
sorted_group = sorted(group, key=len, reverse=True)
# For truly proximate multi-word groups, group words
queries.add(" OR ".join(sorted_group))
else:
# For non-proximate words or single words, add them separately
queries.update(group)
queries.add(" OR ".join(keywords))

# Join all terms with " OR "
return " OR ".join(queries) if queries else ""


def batch_text_to_tsvector_queries(
paragraphs: list[str],
top_n: int = 10,
proximity_n: int = 10,
min_keywords: int = 1,
n_process: int = 1,
) -> list[str]:
"""
Processes multiple paragraphs using nlp.pipe for better performance.
Args:
paragraphs (list[str]): List of paragraphs to process
top_n (int): Number of top keywords to include per paragraph
Returns:
list[str]: List of tsquery strings
"""
results = []

for doc in nlp.pipe(paragraphs, disable=["lemmatizer", "textcat"], n_process=n_process):
queries = set() # Use set to avoid duplicates
for sent in doc.sents:
sent_doc = sent.as_doc()
keywords = extract_keywords(sent_doc, top_n)
if len(keywords) < min_keywords:
continue
keyword_positions = keyword_matcher.find_matches(sent_doc, keywords)
if not keyword_positions:
continue
groups = find_proximity_groups(keywords, keyword_positions, proximity_n)
# Add each group as a single term to our set
for group in groups:
if len(group) > 1:
# Sort by length descending to prioritize longer phrases
sorted_group = sorted(group, key=len, reverse=True)
# For truly proximate multi-word groups, group words
queries.add(" OR ".join(sorted_group))
else:
# For non-proximate words or single words, add them separately
queries.update(group)

# Join all terms with " OR "
results.append(" OR ".join(queries) if queries else "")

return results
# def batch_text_to_tsvector_queries(
# paragraphs: list[str],
# top_n: int = 10,
# proximity_n: int = 10,
# min_keywords: int = 1,
# n_process: int = 1,
# ) -> list[str]:
# """
# Processes multiple paragraphs using nlp.pipe for better performance.

# Args:
# paragraphs (list[str]): List of paragraphs to process
# top_n (int): Number of top keywords to include per paragraph

# Returns:
# list[str]: List of tsquery strings
# """
# results = []

# for doc in nlp.pipe(paragraphs, disable=["lemmatizer", "textcat"], n_process=n_process):
# queries = set() # Use set to avoid duplicates
# for sent in doc.sents:
# sent_doc = sent.as_doc()
# keywords = extract_keywords(sent_doc, top_n)
# if len(keywords) < min_keywords:
# continue
# keyword_positions = keyword_matcher.find_matches(sent_doc, keywords)
# if not keyword_positions:
# continue
# groups = find_proximity_groups(keywords, keyword_positions, proximity_n)
# # Add each group as a single term to our set
# for group in groups:
# if len(group) > 1:
# # Sort by length descending to prioritize longer phrases
# sorted_group = sorted(group, key=len, reverse=True)
# # For truly proximate multi-word groups, group words
# queries.add(" OR ".join(sorted_group))
# else:
# # For non-proximate words or single words, add them separately
# queries.update(group)

# # Join all terms with " OR "
# results.append(" OR ".join(queries) if queries else "")

# return results
20 changes: 7 additions & 13 deletions agents-api/tests/test_docs_queries.py
@@ -387,11 +387,7 @@ async def _():
# Multiple sentences
(
"I love basketball especially Michael Jordan. LeBron James is also great.",
[
"basketball OR lebron james OR michael jordan",
"LeBron James OR Michael Jordan OR basketball",
"Michael Jordan OR basketball OR LeBron James",
],
"basketball OR lebron james OR michael jordan",
),
# Quoted phrases
(
@@ -422,14 +418,12 @@
result = text_to_tsvector_query(input_text)
print(f"Generated query: '{result}'")
print(f"Expected: '{expected_output}'\n")
if isinstance(expected_output, list):
assert any(
result.lower() == expected_output.lower() for expected_output in expected_output
), f"Expected '{expected_output}' but got '{result}' for input '{input_text}'"
else:
assert result.lower() == expected_output.lower(), (
f"Expected '{expected_output}' but got '{result}' for input '{input_text}'"
)

result_terms = set(term.lower() for term in result.split(" OR ") if term)
expected_terms = set(term.lower() for term in expected_output.split(" OR ") if term)
assert result_terms == expected_terms, (
f"Expected terms {expected_terms} but got {result_terms} for input '{input_text}'"
)
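
The set-based comparison above makes the assertion order-insensitive: because the query builder accumulates terms in a set, "a OR b" and "b OR a" are equally valid outputs. A tiny illustration (not taken from the diff):

result_terms = {t.lower() for t in "basketball OR michael jordan".split(" OR ") if t}
expected_terms = {t.lower() for t in "Michael Jordan OR basketball".split(" OR ") if t}
assert result_terms == expected_terms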


# @test("query: search docs by embedding with different confidence levels")