From 61d32bd1b68add043640e77db54b497f920014f4 Mon Sep 17 00:00:00 2001 From: Ahmad Haidar Date: Tue, 14 Jan 2025 15:27:13 +0300 Subject: [PATCH] fix(agents-api): Configure spacy for postgresql --- agents-api/agents_api/common/nlp.py | 255 ++++++++------------------ agents-api/tests/test_docs_queries.py | 20 +- 2 files changed, 86 insertions(+), 189 deletions(-) diff --git a/agents-api/agents_api/common/nlp.py b/agents-api/agents_api/common/nlp.py index be86d8936..62895f7f9 100644 --- a/agents-api/agents_api/common/nlp.py +++ b/agents-api/agents_api/common/nlp.py @@ -29,67 +29,33 @@ }, ) - -# Singleton PhraseMatcher for better performance -class KeywordMatcher: - _instance = None - - def __new__(cls): - if cls._instance is None: - cls._instance = super().__new__(cls) - cls._instance.matcher = PhraseMatcher(nlp.vocab, attr="LOWER") - cls._instance.batch_size = 1000 # Adjust based on memory constraints - cls._instance.patterns_cache = {} - return cls._instance - - @lru_cache(maxsize=10000) - def _create_pattern(self, text: str) -> Doc: - return nlp.make_doc(text) - - def find_matches(self, doc: Doc, keywords: list[str]) -> dict[str, list[int]]: - """Batch process keywords for better performance.""" - keyword_positions = defaultdict(list) - - # Process keywords in batches to avoid memory issues - for i in range(0, len(keywords), self.batch_size): - batch = keywords[i : i + self.batch_size] - patterns = [self._create_pattern(kw) for kw in batch] - - # Clear previous patterns and add new batch - if "KEYWORDS" in self.matcher: - self.matcher.remove("KEYWORDS") - self.matcher.add("KEYWORDS", patterns) - - # Find matches for this batch - matches = self.matcher(doc) - for match_id, start, end in matches: - span_text = doc[start:end].text - normalized = WHITESPACE_RE.sub(" ", span_text).lower().strip() - keyword_positions[normalized].append(start) - - return keyword_positions - - -# Initialize global matcher -keyword_matcher = KeywordMatcher() - - @lru_cache(maxsize=10000) def clean_keyword(kw: str) -> str: """Cache cleaned keywords for reuse.""" return NON_ALPHANUM_RE.sub("", kw).strip() -def extract_keywords(doc: Doc, top_n: int = 10, clean: bool = True) -> list[str]: +def extract_keywords(doc: Doc, top_n: int = 25, clean: bool = True) -> list[str]: """Optimized keyword extraction with minimal behavior change.""" excluded_labels = { - "DATE", - "TIME", - "PERCENT", - "MONEY", - "QUANTITY", - "ORDINAL", - "CARDINAL", + "DATE", # Absolute or relative dates or periods. + "TIME", # Times smaller than a day. + "PERCENT", # Percentage, including ”%“. + "MONEY", # Monetary values, including unit. + "QUANTITY", # Measurements, as of weight or distance. + "ORDINAL", # “first”, “second”, etc. + "CARDINAL", # Numerals that do not fall under another type. + # "PERSON", # People, including fictional. + # "NORP", # Nationalities or religious or political groups. + # "FAC", # Buildings, airports, highways, bridges, etc. + # "ORG", # Companies, agencies, institutions, etc. + # "GPE", # Countries, cities, states. + # "LOC", # Non-GPE locations, mountain ranges, bodies of water. + # "PRODUCT", # Objects, vehicles, foods, etc. (Not services.) + # "EVENT", # Named hurricanes, battles, wars, sports events, etc. + # "WORK_OF_ART", # Titles of books, songs, etc. + # "LAW", # Named documents made into laws. + # "LANGUAGE", # Any named language. 
} # Extract and filter spans in a single pass @@ -104,8 +70,12 @@ def extract_keywords(doc: Doc, top_n: int = 10, clean: bool = True) -> list[str] # Process spans efficiently and filter out spans that are entirely stopwords keywords = [] + ent_keywords = [] seen_texts = set() + # Convert ent_spans to set for faster lookup + ent_spans_set = set(ent_spans) + for span in all_spans: # Skip if all tokens in span are stopwords if all(token.is_stop for token in span): @@ -119,79 +89,30 @@ def extract_keywords(doc: Doc, top_n: int = 10, clean: bool = True) -> list[str] continue seen_texts.add(lower_text) - keywords.append(text) + ent_keywords.append(text) if span in ent_spans_set else keywords.append(text) + # Normalize keywords by replacing multiple spaces with single space and stripping + normalized_ent_keywords = [WHITESPACE_RE.sub(" ", kw).strip() for kw in ent_keywords] normalized_keywords = [WHITESPACE_RE.sub(" ", kw).strip() for kw in keywords] # Count frequencies efficiently + ent_freq = Counter(normalized_ent_keywords) freq = Counter(normalized_keywords) - top_keywords = [kw for kw, _ in freq.most_common(top_n)] + + + top_keywords = [kw for kw, _ in ent_freq.most_common(top_n)] + remaining_slots = max(0, top_n - len(top_keywords)) + top_keywords += [kw for kw, _ in freq.most_common(remaining_slots)] if clean: return [clean_keyword(kw) for kw in top_keywords] return top_keywords -def find_proximity_groups( - keywords: list[str], keyword_positions: dict[str, list[int]], n: int = 10 -) -> list[set[str]]: - """Optimized proximity grouping using sorted positions.""" - # Early return for single or no keywords - if len(keywords) <= 1: - return [{kw} for kw in keywords] - - # Create flat list of positions for efficient processing - positions: list[tuple[int, str]] = [ - (pos, kw) for kw in keywords for pos in keyword_positions[kw] - ] - - # Sort positions once - positions.sort() - - # Initialize Union-Find with path compression and union by rank - parent = {kw: kw for kw in keywords} - rank = dict.fromkeys(keywords, 0) - - def find(u: str) -> str: - if parent[u] != u: - parent[u] = find(parent[u]) - return parent[u] - - def union(u: str, v: str) -> None: - u_root, v_root = find(u), find(v) - if u_root != v_root: - if rank[u_root] < rank[v_root]: - u_root, v_root = v_root, u_root - parent[v_root] = u_root - if rank[u_root] == rank[v_root]: - rank[u_root] += 1 - - # Use sliding window for proximity checking - window = [] - for pos, kw in positions: - # Remove positions outside window - while window and pos - window[0][0] > n: - window.pop(0) - - # Union with all keywords in window - for _, w_kw in window: - union(kw, w_kw) - - window.append((pos, kw)) - - # Group keywords efficiently - groups = defaultdict(set) - for kw in keywords: - root = find(kw) - groups[root].add(kw) - - return list(groups.values()) - - @lru_cache(maxsize=1000) def text_to_tsvector_query( - paragraph: str, top_n: int = 10, proximity_n: int = 10, min_keywords: int = 1 + paragraph: str, top_n: int = 25, min_keywords: int = 1 ) -> str: """ Extracts meaningful keywords/phrases from text and joins them with OR. 
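The rewritten `extract_keywords` above splits candidate spans into named-entity keywords (`ent_keywords`) and the remaining candidates (`keywords`), counts each group separately, and fills the `top_n` slots from the entity frequencies first, topping up with the other candidates only if slots remain. Below is a minimal sketch of just that selection step, assuming nothing beyond what the hunk shows; `pick_top_keywords` and the sample inputs are illustrative, not part of the patch:

```python
# Sketch of the new selection order only (not a copy of extract_keywords):
# entity keywords claim the top_n slots first; other candidates fill what's left.
# The inputs below are made up for illustration.
from collections import Counter


def pick_top_keywords(
    ent_keywords: list[str], keywords: list[str], top_n: int = 25
) -> list[str]:
    ent_freq = Counter(ent_keywords)   # named-entity phrases, counted first
    freq = Counter(keywords)           # remaining candidate spans
    top = [kw for kw, _ in ent_freq.most_common(top_n)]
    remaining_slots = max(0, top_n - len(top))
    top += [kw for kw, _ in freq.most_common(remaining_slots)]
    return top


# Entities are kept ahead of a more frequent generic term:
assert pick_top_keywords(
    ["michael jordan", "lebron james"], ["basketball", "basketball"], top_n=3
) == ["michael jordan", "lebron james", "basketball"]
```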
@@ -203,7 +124,6 @@ def text_to_tsvector_query( Args: paragraph (str): The input text to process top_n (int): Number of top keywords to extract per sentence - proximity_n (int): The proximity window for grouping related keywords min_keywords (int): Minimum number of keywords required Returns: @@ -223,71 +143,54 @@ def text_to_tsvector_query( if len(keywords) < min_keywords: continue - # Find keyword positions - keyword_positions = keyword_matcher.find_matches(sent_doc, keywords) - if not keyword_positions: - continue - - # Group related keywords by proximity - groups = find_proximity_groups(keywords, keyword_positions, proximity_n) - - # Add each group as a single term to our set - for group in groups: - if len(group) > 1: - # Sort by length descending to prioritize longer phrases - sorted_group = sorted(group, key=len, reverse=True) - # For truly proximate multi-word groups, group words - queries.add(" OR ".join(sorted_group)) - else: - # For non-proximate words or single words, add them separately - queries.update(group) + queries.add(" OR ".join(keywords)) # Join all terms with " OR " return " OR ".join(queries) if queries else "" -def batch_text_to_tsvector_queries( - paragraphs: list[str], - top_n: int = 10, - proximity_n: int = 10, - min_keywords: int = 1, - n_process: int = 1, -) -> list[str]: - """ - Processes multiple paragraphs using nlp.pipe for better performance. - - Args: - paragraphs (list[str]): List of paragraphs to process - top_n (int): Number of top keywords to include per paragraph - - Returns: - list[str]: List of tsquery strings - """ - results = [] - - for doc in nlp.pipe(paragraphs, disable=["lemmatizer", "textcat"], n_process=n_process): - queries = set() # Use set to avoid duplicates - for sent in doc.sents: - sent_doc = sent.as_doc() - keywords = extract_keywords(sent_doc, top_n) - if len(keywords) < min_keywords: - continue - keyword_positions = keyword_matcher.find_matches(sent_doc, keywords) - if not keyword_positions: - continue - groups = find_proximity_groups(keywords, keyword_positions, proximity_n) - # Add each group as a single term to our set - for group in groups: - if len(group) > 1: - # Sort by length descending to prioritize longer phrases - sorted_group = sorted(group, key=len, reverse=True) - # For truly proximate multi-word groups, group words - queries.add(" OR ".join(sorted_group)) - else: - # For non-proximate words or single words, add them separately - queries.update(group) - - # Join all terms with " OR " - results.append(" OR ".join(queries) if queries else "") - - return results +# def batch_text_to_tsvector_queries( +# paragraphs: list[str], +# top_n: int = 10, +# proximity_n: int = 10, +# min_keywords: int = 1, +# n_process: int = 1, +# ) -> list[str]: +# """ +# Processes multiple paragraphs using nlp.pipe for better performance. 
+ +# Args: +# paragraphs (list[str]): List of paragraphs to process +# top_n (int): Number of top keywords to include per paragraph + +# Returns: +# list[str]: List of tsquery strings +# """ +# results = [] + +# for doc in nlp.pipe(paragraphs, disable=["lemmatizer", "textcat"], n_process=n_process): +# queries = set() # Use set to avoid duplicates +# for sent in doc.sents: +# sent_doc = sent.as_doc() +# keywords = extract_keywords(sent_doc, top_n) +# if len(keywords) < min_keywords: +# continue +# keyword_positions = keyword_matcher.find_matches(sent_doc, keywords) +# if not keyword_positions: +# continue +# groups = find_proximity_groups(keywords, keyword_positions, proximity_n) +# # Add each group as a single term to our set +# for group in groups: +# if len(group) > 1: +# # Sort by length descending to prioritize longer phrases +# sorted_group = sorted(group, key=len, reverse=True) +# # For truly proximate multi-word groups, group words +# queries.add(" OR ".join(sorted_group)) +# else: +# # For non-proximate words or single words, add them separately +# queries.update(group) + +# # Join all terms with " OR " +# results.append(" OR ".join(queries) if queries else "") + +# return results diff --git a/agents-api/tests/test_docs_queries.py b/agents-api/tests/test_docs_queries.py index feec3a6c2..fea7f4fbf 100644 --- a/agents-api/tests/test_docs_queries.py +++ b/agents-api/tests/test_docs_queries.py @@ -387,11 +387,7 @@ async def _(): # Multiple sentences ( "I love basketball especially Michael Jordan. LeBron James is also great.", - [ - "basketball OR lebron james OR michael jordan", - "LeBron James OR Michael Jordan OR basketball", - "Michael Jordan OR basketball OR LeBron James", - ], + "basketball OR lebron james OR michael jordan", ), # Quoted phrases ( @@ -422,14 +418,12 @@ async def _(): result = text_to_tsvector_query(input_text) print(f"Generated query: '{result}'") print(f"Expected: '{expected_output}'\n") - if isinstance(expected_output, list): - assert any( - result.lower() == expected_output.lower() for expected_output in expected_output - ), f"Expected '{expected_output}' but got '{result}' for input '{input_text}'" - else: - assert result.lower() == expected_output.lower(), ( - f"Expected '{expected_output}' but got '{result}' for input '{input_text}'" - ) + + result_terms = set(term.lower() for term in result.split(" OR ") if term) + expected_terms = set(term.lower() for term in expected_output.split(" OR ") if term) + assert result_terms == expected_terms, ( + f"Expected terms {expected_terms} but got {result_terms} for input '{input_text}'" + ) # @test("query: search docs by embedding with different confidence levels")
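Because `text_to_tsvector_query` now accumulates terms in a Python `set` before joining them, the order of the OR-joined terms is not deterministic between runs; the updated test therefore compares the set of terms rather than the exact string or a list of accepted permutations. The same check can be expressed as a small helper; the name `tsquery_terms` is hypothetical and not part of the test suite:

```python
# Hypothetical helper (not in the patch): split an OR-joined query string into a
# set of lowercased terms so assertions ignore term order.
def tsquery_terms(query: str) -> set[str]:
    return {term.strip().lower() for term in query.split(" OR ") if term.strip()}


assert tsquery_terms("Michael Jordan OR basketball") == tsquery_terms(
    "basketball OR michael jordan"
)
```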
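The patch only produces the OR-joined keyword string; the SQL side that consumes it is not shown here. The snippet below is a sketch under assumptions: it supposes the string is passed to PostgreSQL's `websearch_to_tsquery` (which understands `a OR b` syntax) against a `tsvector` column, and the table and column names (`docs`, `doc_id`, `search_tsv`), the connection string, and the use of asyncpg are illustrative rather than taken from the repository.

```python
# Sketch only: the table name, column names, and asyncpg driver are assumptions,
# not taken from this patch. It shows how an OR-joined keyword string can drive
# PostgreSQL full-text search via websearch_to_tsquery, which accepts "OR".
import asyncpg  # assumed driver; the repository may use a different client

from agents_api.common.nlp import text_to_tsvector_query


async def search_docs(dsn: str, paragraph: str) -> list[asyncpg.Record]:
    query_string = text_to_tsvector_query(paragraph)  # e.g. "basketball OR michael jordan"
    if not query_string:
        return []
    conn = await asyncpg.connect(dsn)
    try:
        return await conn.fetch(
            """
            SELECT doc_id, ts_rank(search_tsv, q) AS rank
            FROM docs, websearch_to_tsquery('english', $1) AS q
            WHERE search_tsv @@ q
            ORDER BY rank DESC
            LIMIT 10
            """,
            query_string,
        )
    finally:
        await conn.close()


# Example (hypothetical DSN):
# asyncio.run(search_docs("postgresql://user:pass@localhost/agents",
#                         "Michael Jordan played basketball."))
```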