Skip to content

Commit

Permalink
[Internal] Query: Removes hack for _FullTextWordCount and adds unit t…
Browse files Browse the repository at this point in the history
…ests for Hybrid Search (#4836)

## Description

We remove the hack for `_FullTextWordCount` in the
`HybridSearchCrossPartitionQueryPipelineStage` class, since the backend
has now been updated to accept the correct system function name.

This change also adds unit tests for the
`HybridSearchCrossPartitionQueryPipelineStage` class

## Type of change

- [x] New feature (non-breaking change which adds functionality)
  • Loading branch information
neildsh authored Oct 22, 2024
1 parent b6c4507 commit 408ee12
Show file tree
Hide file tree
Showing 3 changed files with 460 additions and 58 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -118,9 +118,6 @@ TryCatch<IQueryPipelineStage> ComponentPipelineFactory(QueryInfo rewrittenQueryI
{
QueryExecutionOptions queryExecutionOptions = new QueryExecutionOptions(pageSizeHint: maxItemCount);

// TODO: Remove this once the FullTextWordCount is fixed in the backend
queryInfo.GlobalStatisticsQuery = queryInfo.GlobalStatisticsQuery.Replace("_FullTextWordCount", "_FullText_WordCount");

SqlQuerySpec globalStatisticsQuerySpec = new SqlQuerySpec(
queryInfo.GlobalStatisticsQuery,
sqlQuerySpec.Parameters);
Expand Down Expand Up @@ -293,6 +290,7 @@ private async ValueTask<bool> MoveNextAsync_DrainSingletonComponentAsync(ITrace
foreach (CosmosElement cosmosElement in page.Documents)
{
HybridSearchQueryResult hybridSearchQueryResult = HybridSearchQueryResult.Create(cosmosElement);
HybridSearchDebugTraceHelpers.TraceQueryResult(hybridSearchQueryResult);
documents.Add(hybridSearchQueryResult.Payload);
}

Expand Down Expand Up @@ -877,7 +875,7 @@ private static class Placeholders

private static class HybridSearchDebugTraceHelpers
{
private const bool Enabled = true;
private const bool Enabled = false;
#pragma warning disable CS0162 // Unreachable code detected

[Conditional("DEBUG")]
Expand Down Expand Up @@ -926,6 +924,18 @@ public static void TraceQueryResultsWithRanks(IReadOnlyList<HybridSearchQueryRes
}
}

[Conditional("DEBUG")]
public static void TraceQueryResult(HybridSearchQueryResult queryResult)
{
if (Enabled)
{
StringBuilder builder = new StringBuilder();
AppendQueryResult(queryResult, builder);
string row = builder.ToString();
System.Diagnostics.Trace.WriteLine(row);
}
}

private static StringBuilder AppendQueryResult(HybridSearchQueryResult queryResult, StringBuilder builder)
{
builder.Append(queryResult.Rid.Value.ToString());
Expand Down
Original file line number Diff line number Diff line change
@@ -1,23 +1,23 @@
namespace Microsoft.Azure.Cosmos.EmulatorTests.Query
{
using System;
using System.Collections.Generic;
namespace Microsoft.Azure.Cosmos.EmulatorTests.Query
{
using System;
using System.Collections.Generic;
using System.Collections.ObjectModel;
using System.Diagnostics;
using System.IO;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using Microsoft.Azure.Cosmos.CosmosElements;
using Microsoft.VisualStudio.TestTools.UnitTesting;

[TestClass]
using System.Threading.Tasks;
using Microsoft.Azure.Cosmos.CosmosElements;
using Microsoft.VisualStudio.TestTools.UnitTesting;

[TestClass]
[TestCategory("Query")]
public sealed class HybridSearchQueryTests : QueryTestsBase
{
private const string CollectionDataPath = "Documents\\text-3properties-1536dimensions-100documents.json";

private static readonly IndexingPolicy CompositeIndexPolicy = CreateIndexingPolicy();
private static readonly IndexingPolicy CompositeIndexPolicy = CreateIndexingPolicy();

[Ignore("This test can only be enabled after Direct package and emulator upgrade")]
[TestMethod]
Expand All @@ -26,11 +26,11 @@ public async Task SanityTests()
CosmosArray documentsArray = await LoadDocuments();
IEnumerable<string> documents = documentsArray.Select(document => document.ToString());

await this.CreateIngestQueryDeleteAsync(
connectionModes: ConnectionModes.Direct, // | ConnectionModes.Gateway,
collectionTypes: CollectionTypes.MultiPartition, // | CollectionTypes.SinglePartition,
documents: documents,
query: RunSanityTests,
await this.CreateIngestQueryDeleteAsync(
connectionModes: ConnectionModes.Direct, // | ConnectionModes.Gateway,
collectionTypes: CollectionTypes.MultiPartition, // | CollectionTypes.SinglePartition,
documents: documents,
query: RunSanityTests,
indexingPolicy: CompositeIndexPolicy);
}

Expand All @@ -43,50 +43,56 @@ private static async Task RunSanityTests(Container container, IReadOnlyList<Cosm
FROM c
WHERE FullTextContains(c.title, 'John') OR FullTextContains(c.text, 'John')
ORDER BY RANK FullTextScore(c.title, ['John'])",
new List<int>{ 2, 57, 85 }),
new List<List<int>>{ new List<int>{ 2, 57, 85 }, new List<int>{ 2, 85, 57 } }),
MakeSanityTest(@"
SELECT TOP 10 c.index AS Index, c.title AS Title, c.text AS Text
FROM c
WHERE FullTextContains(c.title, 'John') OR FullTextContains(c.text, 'John')
ORDER BY RANK FullTextScore(c.title, ['John'])",
new List<int>{ 2, 57, 85 }),
new List<List<int>>{ new List<int>{ 2, 57, 85 }, new List<int>{ 2, 85, 57 } }),
MakeSanityTest(@"
SELECT c.index AS Index, c.title AS Title, c.text AS Text
FROM c
WHERE FullTextContains(c.title, 'John') OR FullTextContains(c.text, 'John')
ORDER BY RANK FullTextScore(c.title, ['John'])
OFFSET 1 LIMIT 5",
new List<int>{ 57, 85 }),
new List<List<int>>{ new List<int>{ 57, 85 }, new List<int>{ 85, 57 } }),
MakeSanityTest(@"
SELECT c.index AS Index, c.title AS Title, c.text AS Text
FROM c
WHERE FullTextContains(c.title, 'John') OR FullTextContains(c.text, 'John') OR FullTextContains(c.text, 'United States')
ORDER BY RANK RRF(FullTextScore(c.title, ['John']), FullTextScore(c.text, ['United States']))",
new List<int>{ 61, 51, 49, 54, 75, 24, 77, 76, 80, 25, 22, 2, 66, 57, 85 }),
new List<List<int>>{
new List<int>{ 61, 51, 49, 54, 75, 24, 77, 76, 80, 25, 22, 2, 66, 57, 85 },
new List<int>{ 61, 51, 49, 54, 75, 24, 77, 76, 80, 25, 22, 2, 66, 85, 57 },
}),
MakeSanityTest(@"
SELECT TOP 10 c.index AS Index, c.title AS Title, c.text AS Text
FROM c
WHERE FullTextContains(c.title, 'John') OR FullTextContains(c.text, 'John') OR FullTextContains(c.text, 'United States')
ORDER BY RANK RRF(FullTextScore(c.title, ['John']), FullTextScore(c.text, ['United States']))",
new List<int>{ 61, 51, 49, 54, 75, 24, 77, 76, 80, 25 }),
new List<List<int>>{ new List<int>{ 61, 51, 49, 54, 75, 24, 77, 76, 80, 25 } }),
MakeSanityTest(@"
SELECT c.index AS Index, c.title AS Title, c.text AS Text
FROM c
WHERE FullTextContains(c.title, 'John') OR FullTextContains(c.text, 'John') OR FullTextContains(c.text, 'United States')
ORDER BY RANK RRF(FullTextScore(c.title, ['John']), FullTextScore(c.text, ['United States']))
OFFSET 5 LIMIT 10",
new List<int>{ 24, 77, 76, 80, 25, 22, 2, 66, 57, 85 }),
new List<List<int>>{
new List<int>{ 24, 77, 76, 80, 25, 22, 2, 66, 57, 85 },
new List<int>{ 24, 77, 76, 80, 25, 22, 2, 66, 85, 57 },
}),
MakeSanityTest(@"
SELECT TOP 10 c.index AS Index, c.title AS Title, c.text AS Text
FROM c
ORDER BY RANK RRF(FullTextScore(c.title, ['John']), FullTextScore(c.text, ['United States']))",
new List<int>{ 61, 51, 49, 54, 75, 24, 77, 76, 80, 25 }),
new List<List<int>>{new List<int>{ 61, 51, 49, 54, 75, 24, 77, 76, 80, 25 } }),
MakeSanityTest(@"
SELECT c.index AS Index, c.title AS Title, c.text AS Text
FROM c
ORDER BY RANK RRF(FullTextScore(c.title, ['John']), FullTextScore(c.text, ['United States']))
OFFSET 0 LIMIT 13",
new List<int>{ 61, 51, 49, 54, 75, 24, 77, 76, 80, 25, 22, 2, 66 }),
new List<List<int>>{ new List<int>{ 61, 51, 49, 54, 75, 24, 77, 76, 80, 25, 22, 2, 66 } }),
};

foreach (SanityTestCase testCase in testCases)
Expand All @@ -98,12 +104,27 @@ ORDER BY RANK RRF(FullTextScore(c.title, ['John']), FullTextScore(c.text, ['Unit
queryDrainingMode: QueryDrainingMode.HoldState);

IEnumerable<int> actual = result.Select(document => document.Index);
if (!testCase.ExpectedIndices.SequenceEqual(actual))

bool match = false;
foreach (IReadOnlyList<int> expectedIndices in testCase.ExpectedIndices)
{
if (expectedIndices.SequenceEqual(actual))
{
match = true;
break;
}
}

if (!match)
{
Trace.WriteLine($"Query: {testCase.Query}");
Trace.WriteLine($"Expected: {string.Join(", ", testCase.ExpectedIndices)}");
Trace.WriteLine($"Actual: {string.Join(", ", actual)}");
Assert.Fail("The query results did not match the expected results.");

string errorMessage = @"The query results did not match any of the expected results." +
"Please set HybridSearchCrossPartitionQueryPipelineStage.HybridSearchDebugTraceHelpers.Enabled = true to debug." +
"Usually, the failure may be due to some swaps in the results that have equal scores. You can see this in the debug output." +
"The solution is to add another expected result that matches the actual results (provided the scores are in decresing order).";
Assert.Fail(errorMessage);
}
}
}
Expand All @@ -119,21 +140,21 @@ private static async Task<CosmosArray> LoadDocuments()
return items;
}

private static IndexingPolicy CreateIndexingPolicy()
{
IndexingPolicy policy = new IndexingPolicy();

policy.IncludedPaths.Add(new IncludedPath { Path = IndexingPolicy.DefaultPath });
policy.CompositeIndexes.Add(new Collection<CompositePath>
{
new CompositePath { Path = $"/index" },
new CompositePath { Path = $"/mixedTypefield" },
});

return policy;
private static IndexingPolicy CreateIndexingPolicy()
{
IndexingPolicy policy = new IndexingPolicy();

policy.IncludedPaths.Add(new IncludedPath { Path = IndexingPolicy.DefaultPath });
policy.CompositeIndexes.Add(new Collection<CompositePath>
{
new CompositePath { Path = $"/index" },
new CompositePath { Path = $"/mixedTypefield" },
});

return policy;
}

private static SanityTestCase MakeSanityTest(string query, IReadOnlyList<int> expectedIndices)
private static SanityTestCase MakeSanityTest(string query, IReadOnlyList<IReadOnlyList<int>> expectedIndices)
{
return new SanityTestCase
{
Expand All @@ -146,7 +167,7 @@ private sealed class SanityTestCase
{
public string Query { get; init; }

public IReadOnlyList<int> ExpectedIndices { get; init; }
public IReadOnlyList<IReadOnlyList<int>> ExpectedIndices { get; init; }
}

private sealed class TextDocument
Expand All @@ -166,5 +187,5 @@ private static class FieldNames
public const string Text = "text";
public const string Rid = "_rid";
}
}
}
}
Loading

0 comments on commit 408ee12

Please sign in to comment.