diff --git a/agents-api/agents_api/common/nlp.py b/agents-api/agents_api/common/nlp.py
index 00ba3d881..c89339ae2 100644
--- a/agents-api/agents_api/common/nlp.py
+++ b/agents-api/agents_api/common/nlp.py
@@ -1,15 +1,15 @@
 import re
-from collections import Counter, defaultdict
+from collections import Counter
 from functools import lru_cache
 
 import spacy
-from spacy.matcher import PhraseMatcher
 from spacy.tokens import Doc
 from spacy.util import filter_spans
 
 # Precompile regex patterns
 WHITESPACE_RE = re.compile(r"\s+")
 NON_ALPHANUM_RE = re.compile(r"[^\w\s\-_]+")
+LONE_HYPHEN_RE = re.compile(r"\s*-\s*(?!\w)|(?<!\w)\s*-\s*")
-    def _create_pattern(self, text: str) -> Doc:
-        return nlp.make_doc(text)
-
-    def find_matches(self, doc: Doc, keywords: list[str]) -> dict[str, list[int]]:
-        """Batch process keywords for better performance."""
-        keyword_positions = defaultdict(list)
-
-        # Process keywords in batches to avoid memory issues
-        for i in range(0, len(keywords), self.batch_size):
-            batch = keywords[i : i + self.batch_size]
-            patterns = [self._create_pattern(kw) for kw in batch]
-
-            # Clear previous patterns and add new batch
-            if "KEYWORDS" in self.matcher:
-                self.matcher.remove("KEYWORDS")
-            self.matcher.add("KEYWORDS", patterns)
-
-            # Find matches for this batch
-            matches = self.matcher(doc)
-            for match_id, start, end in matches:
-                span_text = doc[start:end].text
-                normalized = WHITESPACE_RE.sub(" ", span_text).lower().strip()
-                keyword_positions[normalized].append(start)
-
-        return keyword_positions
-
-
-# Initialize global matcher
-keyword_matcher = KeywordMatcher()
-
-
 @lru_cache(maxsize=10000)
 def clean_keyword(kw: str) -> str:
     """Cache cleaned keywords for reuse."""
-    return NON_ALPHANUM_RE.sub("", kw).strip()
+    # First remove non-alphanumeric chars (except whitespace, hyphens, underscores)
+    cleaned = NON_ALPHANUM_RE.sub("", kw).strip()
+    # Replace lone hyphens with spaces
+    cleaned = LONE_HYPHEN_RE.sub(" ", cleaned)
+    # Clean up any resulting multiple spaces
+    return WHITESPACE_RE.sub(" ", cleaned).strip()
 
 
-def extract_keywords(doc: Doc, top_n: int = 10, clean: bool = True) -> list[str]:
+def extract_keywords(doc: Doc, top_n: int = 25, split_chunks: bool = True) -> list[str]:
     """Optimized keyword extraction with minimal behavior change."""
     excluded_labels = {
-        "DATE",
-        "TIME",
-        "PERCENT",
-        "MONEY",
-        "QUANTITY",
-        "ORDINAL",
-        "CARDINAL",
+        "DATE",  # Absolute or relative dates or periods.
+        "TIME",  # Times smaller than a day.
+        "PERCENT",  # Percentage, including "%".
+        "MONEY",  # Monetary values, including unit.
+        "QUANTITY",  # Measurements, as of weight or distance.
+        "ORDINAL",  # "first", "second", etc.
+        "CARDINAL",  # Numerals that do not fall under another type.
+        # "PERSON",  # People, including fictional.
+        # "NORP",  # Nationalities or religious or political groups.
+        # "FAC",  # Buildings, airports, highways, bridges, etc.
+        # "ORG",  # Companies, agencies, institutions, etc.
+        # "GPE",  # Countries, cities, states.
+        # "LOC",  # Non-GPE locations, mountain ranges, bodies of water.
+        # "PRODUCT",  # Objects, vehicles, foods, etc. (Not services.)
+        # "EVENT",  # Named hurricanes, battles, wars, sports events, etc.
+        # "WORK_OF_ART",  # Titles of books, songs, etc.
+        # "LAW",  # Named documents made into laws.
+        # "LANGUAGE",  # Any named language.
     }
 
     # Extract and filter spans in a single pass
     ent_spans = [ent for ent in doc.ents if ent.label_ not in excluded_labels]
-    chunk_spans = [chunk for chunk in doc.noun_chunks if not chunk.root.is_stop]
+    # Add more comprehensive stopword filtering for noun chunks
+    chunk_spans = [
+        chunk
+        for chunk in doc.noun_chunks
+        if not chunk.root.is_stop and not all(token.is_stop for token in chunk)
+    ]
     all_spans = filter_spans(ent_spans + chunk_spans)
 
-    # Process spans efficiently
+    # Process spans efficiently and filter out spans that are entirely stopwords
     keywords = []
+    ent_keywords = []
     seen_texts = set()
 
+    # Convert ent_spans to set for faster lookup
+    ent_spans_set = set(ent_spans)
+
     for span in all_spans:
+        # Skip if all tokens in span are stopwords
+        if all(token.is_stop for token in span):
+            continue
+
         text = span.text.strip()
         lower_text = text.lower()
 
@@ -110,187 +95,108 @@ def extract_keywords(doc: Doc, top_n: int = 10, clean: bool = True) -> list[str]
             continue
 
         seen_texts.add(lower_text)
-        keywords.append(text)
+        ent_keywords.append(text) if span in ent_spans_set else keywords.append(text)
 
     # Normalize keywords by replacing multiple spaces with single space and stripping
+    normalized_ent_keywords = [WHITESPACE_RE.sub(" ", kw).strip() for kw in ent_keywords]
     normalized_keywords = [WHITESPACE_RE.sub(" ", kw).strip() for kw in keywords]
 
+    if split_chunks:
+        normalized_keywords = [word for kw in normalized_keywords for word in kw.split()]
+
     # Count frequencies efficiently
+    ent_freq = Counter(normalized_ent_keywords)
     freq = Counter(normalized_keywords)
-    top_keywords = [kw for kw, _ in freq.most_common(top_n)]
-    if clean:
-        return [clean_keyword(kw) for kw in top_keywords]
-    return top_keywords
+    top_keywords = [kw for kw, _ in ent_freq.most_common(top_n)]
+    remaining_slots = max(0, top_n - len(top_keywords))
+    top_keywords += [kw for kw, _ in freq.most_common(remaining_slots)]
 
+    return [clean_keyword(kw) for kw in top_keywords]
 
-def find_proximity_groups(
-    keywords: list[str], keyword_positions: dict[str, list[int]], n: int = 10
-) -> list[set[str]]:
-    """Optimized proximity grouping using sorted positions."""
-    # Early return for single or no keywords
-    if len(keywords) <= 1:
-        return [{kw} for kw in keywords]
 
-    # Create flat list of positions for efficient processing
-    positions: list[tuple[int, str]] = [
-        (pos, kw) for kw in keywords for pos in keyword_positions[kw]
-    ]
-
-    # Sort positions once
-    positions.sort()
-
-    # Initialize Union-Find with path compression and union by rank
-    parent = {kw: kw for kw in keywords}
-    rank = dict.fromkeys(keywords, 0)
-
-    def find(u: str) -> str:
-        if parent[u] != u:
-            parent[u] = find(parent[u])
-        return parent[u]
-
-    def union(u: str, v: str) -> None:
-        u_root, v_root = find(u), find(v)
-        if u_root != v_root:
-            if rank[u_root] < rank[v_root]:
-                u_root, v_root = v_root, u_root
-            parent[v_root] = u_root
-            if rank[u_root] == rank[v_root]:
-                rank[u_root] += 1
-
-    # Use sliding window for proximity checking
-    window = []
-    for pos, kw in positions:
-        # Remove positions outside window
-        while window and pos - window[0][0] > n:
-            window.pop(0)
-
-        # Union with all keywords in window
-        for _, w_kw in window:
-            union(kw, w_kw)
-
-        window.append((pos, kw))
-
-    # Group keywords efficiently
-    groups = defaultdict(set)
-    for kw in keywords:
-        root = find(kw)
-        groups[root].add(kw)
-
-    return list(groups.values())
-
-
-def build_query_pattern(group_size: int, n: int) -> str:
-    """Cache query patterns for common group sizes."""
-    if group_size == 1:
-        return '"{}"'
-    return f"NEAR/{n}(" + " ".join('"{}"' for _ in range(group_size)) + ")"
-
-
-def build_query(groups: list[set[str]], n: int = 10) -> str:
-    """Build query with cached patterns."""
-    clauses = []
-
-    for group in groups:
-        if len(group) == 1:
-            clauses.append(f'"{next(iter(group))}"')
-        else:
-            # Sort by length descending to prioritize longer phrases
-            sorted_group = sorted(group, key=len, reverse=True)
-            # Get cached pattern and format with keywords
-            pattern = build_query_pattern(len(group), n)
-            clause = pattern.format(*sorted_group)
-            clauses.append(clause)
-
-    return " OR ".join(clauses)
-
-
-@lru_cache(maxsize=100)
-def paragraph_to_custom_queries(
-    paragraph: str, top_n: int = 10, proximity_n: int = 10, min_keywords: int = 1
-) -> list[str]:
+@lru_cache(maxsize=1000)
+def text_to_tsvector_query(
+    paragraph: str, top_n: int = 25, min_keywords: int = 1, split_chunks: bool = True
+) -> str:
     """
-    Optimized paragraph processing with minimal behavior changes.
-    Added min_keywords parameter to filter out low-value queries.
+    Extracts meaningful keywords/phrases from text and joins them with OR.
+
+    Example:
+        Input: "I like basketball especially Michael Jordan"
+        Output: "basketball OR Michael Jordan"
 
     Args:
-        paragraph (str): The input paragraph to convert.
-        top_n (int): Number of top keywords to extract per sentence.
-        proximity_n (int): The proximity window for NEAR/n.
-        min_keywords (int): Minimum number of keywords required to form a query.
+        paragraph (str): The input text to process
+        top_n (int): Number of top keywords to extract per sentence
+        min_keywords (int): Minimum number of keywords required
+        split_chunks (bool): If True, breaks multi-word noun chunks into individual words
 
     Returns:
-        list[str]: The list of custom query strings.
+        str: Keywords/phrases joined by OR
     """
     if not paragraph or not paragraph.strip():
-        return []
+        return ""
 
-    # Process entire paragraph once
     doc = nlp(paragraph)
-    queries = []
+    queries = set()  # Use set to avoid duplicates
 
-    # Process sentences
     for sent in doc.sents:
-        # Convert to doc for consistent API
         sent_doc = sent.as_doc()
 
-        # Extract and clean keywords
-        keywords = extract_keywords(sent_doc, top_n)
+        # Extract keywords
+        keywords = extract_keywords(sent_doc, top_n, split_chunks=split_chunks)
         if len(keywords) < min_keywords:
            continue
 
-        # Find keyword positions using matcher
-        keyword_positions = keyword_matcher.find_matches(sent_doc, keywords)
-
-        # Skip if no keywords found in positions
-        if not keyword_positions:
-            continue
-
-        # Find proximity groups and build query
-        groups = find_proximity_groups(keywords, keyword_positions, proximity_n)
-        query = build_query(groups, proximity_n)
-
-        if query:
-            queries.append(query)
-
-    return queries
-
-
-def batch_paragraphs_to_custom_queries(
-    paragraphs: list[str],
-    top_n: int = 10,
-    proximity_n: int = 10,
-    min_keywords: int = 1,
-    n_process: int = 1,
-) -> list[list[str]]:
-    """
-    Processes multiple paragraphs using nlp.pipe for better performance.
-
-    Args:
-        paragraphs (list[str]): list of paragraphs to process.
-        top_n (int): Number of top keywords to extract per sentence.
-        proximity_n (int): The proximity window for NEAR/n.
-        min_keywords (int): Minimum number of keywords required to form a query.
-        n_process (int): Number of processes to use for multiprocessing.
-
-    Returns:
-        list[list[str]]: A list where each element is a list of queries for a paragraph.
-    """
-    results = []
-    for doc in nlp.pipe(paragraphs, disable=["lemmatizer", "textcat"], n_process=n_process):
-        queries = []
-        for sent in doc.sents:
-            sent_doc = sent.as_doc()
-            keywords = extract_keywords(sent_doc, top_n)
-            if len(keywords) < min_keywords:
-                continue
-            keyword_positions = keyword_matcher.find_matches(sent_doc, keywords)
-            if not keyword_positions:
-                continue
-            groups = find_proximity_groups(keywords, keyword_positions, proximity_n)
-            query = build_query(groups, proximity_n)
-            if query:
-                queries.append(query)
-        results.append(queries)
-
-    return results
+        queries.update(keywords)
+
+    # Join all terms with " OR "
+    return " OR ".join(queries) if queries else ""
+
+
+# def batch_text_to_tsvector_queries(
+#     paragraphs: list[str],
+#     top_n: int = 10,
+#     proximity_n: int = 10,
+#     min_keywords: int = 1,
+#     n_process: int = 1,
+# ) -> list[str]:
+#     """
+#     Processes multiple paragraphs using nlp.pipe for better performance.
+
+#     Args:
+#         paragraphs (list[str]): List of paragraphs to process
+#         top_n (int): Number of top keywords to include per paragraph
+
+#     Returns:
+#         list[str]: List of tsquery strings
+#     """
+#     results = []
+
+#     for doc in nlp.pipe(paragraphs, disable=["lemmatizer", "textcat"], n_process=n_process):
+#         queries = set()  # Use set to avoid duplicates
+#         for sent in doc.sents:
+#             sent_doc = sent.as_doc()
+#             keywords = extract_keywords(sent_doc, top_n)
+#             if len(keywords) < min_keywords:
+#                 continue
+#             keyword_positions = keyword_matcher.find_matches(sent_doc, keywords)
+#             if not keyword_positions:
+#                 continue
+#             groups = find_proximity_groups(keywords, keyword_positions, proximity_n)
+#             # Add each group as a single term to our set
+#             for group in groups:
+#                 if len(group) > 1:
+#                     # Sort by length descending to prioritize longer phrases
+#                     sorted_group = sorted(group, key=len, reverse=True)
+#                     # For truly proximate multi-word groups, group words
+#                     queries.add(" OR ".join(sorted_group))
+#                 else:
+#                     # For non-proximate words or single words, add them separately
+#                     queries.update(group)
+
+#         # Join all terms with " OR "
+#         results.append(" OR ".join(queries) if queries else "")
+
+#     return results
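For reference, a minimal sketch of how the reworked helper behaves, based on its docstring and the tests added later in this change. Exact keywords depend on the en_core_web_sm model version, term order is unspecified because results are collected in a set, and casing follows the input tokens.

from agents_api.common.nlp import text_to_tsvector_query

# Entities stay as phrases; noun chunks are split into words when split_chunks=True.
print(text_to_tsvector_query("I like basketball especially Michael Jordan", split_chunks=False))
# -> e.g. "basketball OR Michael Jordan"
print(text_to_tsvector_query("Machine Learning algorithm", split_chunks=True))
# -> e.g. "Machine OR Learning OR algorithm"
print(text_to_tsvector_query("the and or"))
# -> "" (stopword-only input produces an empty query)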
diff --git a/agents-api/agents_api/queries/docs/search_docs_by_text.py b/agents-api/agents_api/queries/docs/search_docs_by_text.py
index 77fb3a0e6..6632d3162 100644
--- a/agents-api/agents_api/queries/docs/search_docs_by_text.py
+++ b/agents-api/agents_api/queries/docs/search_docs_by_text.py
@@ -5,6 +5,7 @@
 from fastapi import HTTPException
 
 from ...autogen.openapi_model import DocReference
+from ...common.nlp import text_to_tsvector_query
 from ...common.utils.db_exceptions import common_db_exceptions
 from ..utils import pg_query, rewrap_exceptions, wrap_in_class
 from .utils import transform_to_doc_reference
@@ -60,6 +61,8 @@ async def search_docs_by_text(
     # Extract owner types and IDs
     owner_types: list[str] = [owner[0] for owner in owners]
     owner_ids: list[str] = [str(owner[1]) for owner in owners]
+    # Pre-process raw text query
+    query = text_to_tsvector_query(query, split_chunks=True)
 
     return (
         search_docs_text_query,
diff --git a/agents-api/agents_api/queries/docs/search_docs_hybrid.py b/agents-api/agents_api/queries/docs/search_docs_hybrid.py
index fe68bc075..6047069f8 100644
--- a/agents-api/agents_api/queries/docs/search_docs_hybrid.py
+++ b/agents-api/agents_api/queries/docs/search_docs_hybrid.py
@@ -5,6 +5,7 @@
 from fastapi import HTTPException
 
 from ...autogen.openapi_model import DocReference
+from ...common.nlp import text_to_tsvector_query
 from ...common.utils.db_exceptions import common_db_exceptions
 from ..utils import (
     pg_query,
@@ -81,6 +82,9 @@ async def search_docs_hybrid(
     owner_types: list[str] = [owner[0] for owner in owners]
     owner_ids: list[str] = [str(owner[1]) for owner in owners]
 
+    # Pre-process raw text query
+    text_query = text_to_tsvector_query(text_query, split_chunks=True)
+
     return (
         search_docs_hybrid_query,
         [
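Both query modules now hand the pre-processed string to their existing SQL templates. The SQL itself is not shown in this diff; assuming it ultimately relies on Postgres full-text search, the "term OR term" shape is convenient because websearch_to_tsquery treats a literal OR as the | operator. A standalone sketch of that behavior follows; the connection string is a placeholder and the use of websearch_to_tsquery is an assumption, not taken from this change.

import asyncpg

from agents_api.common.nlp import text_to_tsvector_query

async def preview_tsquery(dsn: str, text: str) -> None:
    conn = await asyncpg.connect(dsn)
    try:
        q = text_to_tsvector_query(text, split_chunks=True)
        # websearch_to_tsquery parses the literal word OR as the tsquery | operator
        parsed = await conn.fetchval("SELECT websearch_to_tsquery('english', $1)::text", q)
        # e.g. "basketball OR Michael Jordan" -> "'basketball' | 'michael' & 'jordan'"
        print(q, "->", parsed)
    finally:
        await conn.close()

# asyncio.run(preview_tsquery("postgresql://localhost/agents", "I like basketball especially Michael Jordan"))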
diff --git a/agents-api/tests/fixtures.py b/agents-api/tests/fixtures.py
index b14078d68..5b0ff68cc 100644
--- a/agents-api/tests/fixtures.py
+++ b/agents-api/tests/fixtures.py
@@ -42,6 +42,7 @@ from .utils import (
     get_localstack,
     get_pg_dsn,
+    make_vector_with_similarity,
 )
 from .utils import (
     patch_embed_acompletion as patch_embed_acompletion_ctx,
@@ -164,6 +165,10 @@ async def test_doc(dsn=pg_dsn, developer=test_developer, agent=test_agent):
 @fixture(scope="test")
 async def test_doc_with_embedding(dsn=pg_dsn, developer=test_developer, doc=test_doc):
     pool = await create_db_pool(dsn=dsn)
+    embedding_with_confidence_0 = make_vector_with_similarity(d=0.0)
+    embedding_with_confidence_05 = make_vector_with_similarity(d=0.5)
+    embedding_with_confidence_05_neg = make_vector_with_similarity(d=-0.5)
+    embedding_with_confidence_1_neg = make_vector_with_similarity(d=-1.0)
     await pool.execute(
         """
         INSERT INTO docs_embeddings_store (developer_id, doc_id, index, chunk_seq, chunk, embedding)
         VALUES ($1, $2, 0, 1, $3, $4)
         """,
         developer.id,
         doc.id,
@@ -175,6 +180,54 @@ async def test_doc_with_embedding(dsn=pg_dsn, developer=test_developer, doc=test
         f"[{', '.join([str(x) for x in [1.0] * 1024])}]",
     )
 
+    # Insert embedding with confidence 0 with respect to unit vector
+    await pool.execute(
+        """
+        INSERT INTO docs_embeddings_store (developer_id, doc_id, index, chunk_seq, chunk, embedding)
+        VALUES ($1, $2, 0, 1, $3, $4)
+        """,
+        developer.id,
+        doc.id,
+        "Test content 1",
+        f"[{', '.join([str(x) for x in embedding_with_confidence_0])}]",
+    )
+
+    # Insert embedding with confidence 0.5 with respect to unit vector
+    await pool.execute(
+        """
+        INSERT INTO docs_embeddings_store (developer_id, doc_id, index, chunk_seq, chunk, embedding)
+        VALUES ($1, $2, 0, 2, $3, $4)
+        """,
+        developer.id,
+        doc.id,
+        "Test content 2",
+        f"[{', '.join([str(x) for x in embedding_with_confidence_05])}]",
+    )
+
+    # Insert embedding with confidence -0.5 with respect to unit vector
+    await pool.execute(
+        """
+        INSERT INTO docs_embeddings_store (developer_id, doc_id, index, chunk_seq, chunk, embedding)
+        VALUES ($1, $2, 0, 3, $3, $4)
+        """,
+        developer.id,
+        doc.id,
+        "Test content 3",
+        f"[{', '.join([str(x) for x in embedding_with_confidence_05_neg])}]",
+    )
+
+    # Insert embedding with confidence -1 with respect to unit vector
+    await pool.execute(
+        """
+        INSERT INTO docs_embeddings_store (developer_id, doc_id, index, chunk_seq, chunk, embedding)
+        VALUES ($1, $2, 0, 4, $3, $4)
+        """,
+        developer.id,
+        doc.id,
+        "Test content 4",
+        f"[{', '.join([str(x) for x in embedding_with_confidence_1_neg])}]",
+    )
+
     yield await get_doc(developer_id=developer.id, doc_id=doc.id, connection_pool=pool)
diff --git a/agents-api/tests/test_docs_queries.py b/agents-api/tests/test_docs_queries.py
index 6690badfd..7782b3bf7 100644
--- a/agents-api/tests/test_docs_queries.py
+++ b/agents-api/tests/test_docs_queries.py
@@ -257,6 +257,68 @@ async def _(dsn=pg_dsn, agent=test_agent, developer=test_developer):
     assert result[0].metadata == {"test": "test"}, "Metadata should match"
 
 
+@test("query: search docs by text with technical terms and phrases")
+async def _(dsn=pg_dsn, developer=test_developer, agent=test_agent):
+    pool = await create_db_pool(dsn=dsn)
+
+    # Create documents with technical content
+    doc1 = await create_doc(
+        developer_id=developer.id,
+        owner_type="agent",
+        owner_id=agent.id,
+        data=CreateDocRequest(
+            title="Technical Document",
+            content="API endpoints using REST architecture with JSON payloads",
+            metadata={"domain": "technical"},
+            embed_instruction="Embed the document",
+        ),
+        connection_pool=pool,
+    )
+
+    doc2 = await create_doc(
+        developer_id=developer.id,
+        owner_type="agent",
+        owner_id=agent.id,
+        data=CreateDocRequest(
+            title="More Technical Terms",
+            content="Database optimization using indexing and query planning",
+            metadata={"domain": "technical"},
+            embed_instruction="Embed the document",
+        ),
+        connection_pool=pool,
+    )
+
+    # Test with technical terms
+    technical_queries = [
+        "API endpoints",
+        "REST architecture",
+        "database optimization",
+        "indexing",
+    ]
+
+    for query in technical_queries:
+        results = await search_docs_by_text(
+            developer_id=developer.id,
+            owners=[("agent", agent.id)],
+            query=query,
+            k=3,
+            search_language="english",
+            connection_pool=pool,
+        )
+
+        print(f"\nSearch results for '{query}':", results)
+
+        # Verify appropriate document is found based on query
+        if "API" in query or "REST" in query:
+            assert any(doc.id == doc1.id for doc in results), (
+                f"Doc1 should be found with query '{query}'"
+            )
+        if "database" in query.lower() or "indexing" in query:
+            assert any(doc.id == doc2.id for doc in results), (
+                f"Doc2 should be found with query '{query}'"
+            )
+
+
 @test("query: search docs by embedding")
 async def _(
     dsn=pg_dsn, agent=test_agent, developer=test_developer, doc=test_doc_with_embedding
@@ -304,3 +366,48 @@ async def _(
 
     assert len(result) >= 1
     assert result[0].metadata is not None
+
+
+# @test("query: search docs by embedding with different confidence levels")
+# async def _(
+#     dsn=pg_dsn, agent=test_agent, developer=test_developer, doc=test_doc_with_embedding
+# ):
+#     pool = await create_db_pool(dsn=dsn)
+
+#     # Get query embedding (using original doc's embedding)
+#     query_embedding = make_vector_with_similarity(EMBEDDING_SIZE, 0.7)
+
+#     # Test with different confidence levels
+#     confidence_tests = [
+#         (0.99, 0),  # Very high similarity threshold - should find no results
+#         (0.7, 1),  # High similarity - should find 1 result (the embedding with all 1.0s)
+#         (0.3, 2),  # Medium similarity - should find 2 results (including 0.3-0.7 embedding)
+#         (-0.8, 3),  # Low similarity - should find 3 results (including -0.8 to 0.8 embedding)
+#         (-1.0, 4)  # Lowest similarity - should find all 4 results (including alternating -1/1)
+#     ]
+
+#     for confidence, expected_min_results in confidence_tests:
+#         results = await search_docs_by_embedding(
+#             developer_id=developer.id,
+#             owners=[("agent", agent.id)],
+#             embedding=query_embedding,
+#             k=3,
+#             confidence=confidence,
+#             metadata_filter={"test": "test"},
+#             connection_pool=pool,
+#         )
+
+#         print(f"\nSearch results with confidence {confidence}:")
+#         for r in results:
+#             print(f"- Doc ID: {r.id}, Distance: {r.distance}")
+
+#         assert len(results) >= expected_min_results, (
+#             f"Expected at least {expected_min_results} results with confidence {confidence}, got {len(results)}"
+#         )
+
+#         if results:
+#             # Verify that all returned results meet the confidence threshold
+#             for result in results:
+#                 assert result.distance >= confidence, (
+#                     f"Result distance {result.distance} is below confidence threshold {confidence}"
+#                 )
diff --git a/agents-api/tests/test_nlp_utilities.py b/agents-api/tests/test_nlp_utilities.py
new file mode 100644
index 000000000..733f695d5
--- /dev/null
+++ b/agents-api/tests/test_nlp_utilities.py
@@ -0,0 +1,191 @@
+import spacy
+from agents_api.common.nlp import clean_keyword, extract_keywords, text_to_tsvector_query
+from ward import test
+
+
+@test("utility: clean_keyword")
+async def _():
+    assert clean_keyword("Hello, World!") == "Hello World"
+
+    # Basic cleaning
+    # assert clean_keyword("test@example.com") == "test example com"
+    assert clean_keyword("user-name_123") == "user-name_123"
+    assert clean_keyword(" spaces ") == "spaces"
+
+    # Special characters
+    assert clean_keyword("$price: 100%") == "price 100"
+    assert clean_keyword("#hashtag!") == "hashtag"
+
+    # Multiple spaces and punctuation
+    assert clean_keyword("multiple, spaces...") == "multiple spaces"
+
+    # Empty and whitespace
+    assert clean_keyword("") == ""
+    assert clean_keyword(" ") == ""
+
+    assert clean_keyword("- try") == "try"
+
+
+@test("utility: extract_keywords - split_chunks=False")
+async def _():
+    nlp = spacy.load("en_core_web_sm", exclude=["lemmatizer", "textcat"])
+    doc = nlp("John Doe is a software engineer at Google.")
+    assert set(extract_keywords(doc, split_chunks=False)) == {
+        "John Doe",
+        "a software engineer",
+        "Google",
+    }
+
+
+@test("utility: extract_keywords - split_chunks=True")
+async def _():
+    nlp = spacy.load("en_core_web_sm", exclude=["lemmatizer", "textcat"])
+    doc = nlp("John Doe is a software engineer at Google.")
+    assert set(extract_keywords(doc, split_chunks=True)) == {
+        "John Doe",
+        "a",
+        "software",
+        "engineer",
+        "Google",
+    }
+
+
+@test("utility: text_to_tsvector_query - split_chunks=False")
+async def _():
+    test_cases = [
+        # Single words
+        ("test", "test"),
+        # Multiple words in single sentence
+        (
+            "quick brown fox",
+            "quick brown fox",  # Kept together as a single noun-chunk phrase
+        ),
+        # Technical terms and phrases
+        (
+            "Machine Learning algorithm",
+            "machine learning algorithm",  # Common technical phrase
+        ),
+        # Multiple sentences
+        (
+            "I love basketball especially Michael Jordan. LeBron James is also great.",
+            "basketball OR lebron james OR michael jordan",
+        ),
+        # Quoted phrases
+        (
+            '"quick brown fox"',
+            "quick brown fox",  # Quotes removed, phrase kept together
+        ),
+        ('Find "machine learning" algorithms', "machine learning"),
+        # Multiple quoted phrases
+        ('"data science" and "machine learning"', "machine learning OR data science"),
+        # Edge cases
+        ("", ""),
+        (
+            "the and or",
+            "",  # All stop words should result in empty string
+        ),
+        (
+            "a",
+            "",  # Single stop word should result in empty string
+        ),
+        ("X", "X"),
+        # Empty quotes
+        ('""', ""),
+        ('test "" phrase', "phrase OR test"),
+        (
+            "John Doe is a software engineer at Google.",
+            "google OR john doe OR a software engineer",
+        ),
+        ("- google", "google"),
+        # Test duplicate keyword handling
+        (
+            "John Doe is great. John Doe is awesome.",
+            "john doe",  # Should only include "John Doe" once
+        ),
+        (
+            "Software Engineer at Google. Also, a Software Engineer.",
+            "Google OR Also a Software Engineer OR Software Engineer",  # Should only include "Software Engineer" once
+        ),
+    ]
+
+    for input_text, expected_output in test_cases:
+        print(f"Input: '{input_text}'")
+        result = text_to_tsvector_query(input_text, split_chunks=False)
+        print(f"Generated query: '{result}'")
+        print(f"Expected: '{expected_output}'\n")
+
+        result_terms = {term.lower() for term in result.split(" OR ") if term}
+        expected_terms = {term.lower() for term in expected_output.split(" OR ") if term}
+        assert result_terms == expected_terms, (
+            f"Expected terms {expected_terms} but got {result_terms} for input '{input_text}'"
+        )
+
+
+@test("utility: text_to_tsvector_query - split_chunks=True")
+async def _():
+    test_cases = [
+        # Single words
+        ("test", "test"),
+        # Multiple words in single sentence
+        (
+            "quick brown fox",
+            "quick OR brown OR fox",  # Noun chunk split into individual words
+        ),
+        # Technical terms and phrases
+        (
+            "Machine Learning algorithm",
+            "machine OR learning OR algorithm",  # Common technical phrase, split
+        ),
+        # Multiple sentences
+        (
+            "I love basketball especially Michael Jordan. LeBron James is also great.",
+            "basketball OR lebron james OR michael jordan",
+        ),
+        # Quoted phrases
+        (
+            '"quick brown fox"',
+            "quick OR brown OR fox",  # Quotes removed, words split
+        ),
+        ('Find "machine learning" algorithms', "machine OR learning"),
+        # Multiple quoted phrases
+        ('"data science" and "machine learning"', "machine OR learning OR data OR science"),
+        # Edge cases
+        ("", ""),
+        (
+            "the and or",
+            "",  # All stop words should result in empty string
+        ),
+        (
+            "a",
+            "",  # Single stop word should result in empty string
+        ),
+        ("X", "X"),
+        # Empty quotes
+        ('""', ""),
+        ('test "" phrase', "phrase OR test"),
+        (
+            "John Doe is a software engineer at Google.",
+            "google OR john doe OR a OR software OR engineer",
+        ),
+        # Test duplicate keyword handling
+        (
+            "John Doe is great. John Doe is awesome.",
+            "john doe",  # Should only include "John Doe" once even with split_chunks=True
+        ),
+        (
+            "Software Engineer at Google. Also, a Software Engineer.",
+            "Also OR a OR google OR software OR engineer",  # When split, each word appears once
+        ),
+    ]
+
+    for input_text, expected_output in test_cases:
+        print(f"Input: '{input_text}'")
+        result = text_to_tsvector_query(input_text, split_chunks=True)
+        print(f"Generated query: '{result}'")
+        print(f"Expected: '{expected_output}'\n")
+
+        result_terms = {term.lower() for term in result.split(" OR ") if term}
+        expected_terms = {term.lower() for term in expected_output.split(" OR ") if term}
+        assert result_terms == expected_terms, (
+            f"Expected terms {expected_terms} but got {result_terms} for input '{input_text}'"
+        )
diff --git a/agents-api/tests/utils.py b/agents-api/tests/utils.py
index 05544e048..45489befd 100644
--- a/agents-api/tests/utils.py
+++ b/agents-api/tests/utils.py
@@ -1,5 +1,6 @@
 import asyncio
 import logging
+import math
 import os
 import subprocess
 from contextlib import asynccontextmanager, contextmanager
@@ -19,6 +20,52 @@
 EMBEDDING_SIZE: int = 1024
 
 
+def make_vector_with_similarity(n: int = EMBEDDING_SIZE, d: float = 0.5):
+    """
+    Returns a list `v` of length `n` such that the cosine similarity
+    between `v` and the all-ones vector of length `n` is approximately d.
+    """
+    if not -1.0 <= d <= 1.0:
+        msg = "d must lie in [-1, 1]."
+        raise ValueError(msg)
+
+    # Handle special cases exactly:
+    if abs(d - 1.0) < 1e-12:  # d ~ +1
+        return [1.0] * n
+    if abs(d + 1.0) < 1e-12:  # d ~ -1
+        return [-1.0] * n
+    if abs(d) < 1e-12:  # d ~ 0
+        v = [0.0] * n
+        if n >= 2:
+            v[0] = 1.0
+            v[1] = -1.0
+        return v
+
+    sign_d = 1.0 if d >= 0 else -1.0
+
+    # Base part: sign(d)*[1,1,...,1]
+    base = [sign_d] * n
+
+    # Orthogonal unit vector u with sum(u)=0; for simplicity:
+    # u = [1/sqrt(2), -1/sqrt(2), 0, 0, ..., 0]
+    u = [0.0] * n
+    if n >= 2:
+        u[0] = 1.0 / math.sqrt(2)
+        u[1] = -1.0 / math.sqrt(2)
+    # (if n=1, there's no truly orthogonal vector to [1], so skip)
+
+    # Solve for alpha:
+    #   alpha^2 = n*(1 - d^2)/d^2
+    alpha = math.sqrt(n * (1 - d * d)) / abs(d)
+
+    # Construct v
+    v = [0.0] * n
+    for i in range(n):
+        v[i] = base[i] + alpha * u[i]
+
+    return v
+
+
 @asynccontextmanager
 async def patch_testing_temporal():
     # Set log level to ERROR to avoid spamming the console
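A quick numerical check of the construction above: with u orthogonal to the all-ones vector and alpha^2 = n*(1 - d^2)/d^2, the cosine of v against all-ones works out to sign(d)*n / (sqrt(n) * sqrt(n + alpha^2)) = d. The sketch below verifies this for the d values used by the test_doc_with_embedding fixture; the tests.utils import path is assumed from the repo layout.

import math

from tests.utils import make_vector_with_similarity  # import path assumed

def cosine_to_ones(v: list[float]) -> float:
    n = len(v)
    dot = sum(v)  # dot product with the all-ones vector
    norm_v = math.sqrt(sum(x * x for x in v))
    return dot / (norm_v * math.sqrt(n))

for d in (0.0, 0.5, -0.5, -1.0):  # the similarities inserted by the fixture
    v = make_vector_with_similarity(n=1024, d=d)
    assert abs(cosine_to_ones(v) - d) < 1e-9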