diff --git a/build.gradle b/build.gradle index 6e67c5792..320aa7841 100644 --- a/build.gradle +++ b/build.gradle @@ -699,9 +699,6 @@ List jacocoExclusions = [ // TODO: add test coverage (kaituo) 'org.opensearch.forecast.*', - 'org.opensearch.timeseries.transport.ResultBulkTransportAction', - 'org.opensearch.timeseries.transport.handler.IndexMemoryPressureAwareResultHandler', - 'org.opensearch.timeseries.transport.handler.ResultIndexingHandler', 'org.opensearch.timeseries.ml.Sample', 'org.opensearch.timeseries.ratelimit.FeatureRequest', 'org.opensearch.ad.transport.ADHCImputeNodeRequest', diff --git a/release-notes/opensearch-anomaly-detection.release-notes-2.18.0.0.md b/release-notes/opensearch-anomaly-detection.release-notes-2.18.0.0.md index fee194023..b62a3ab10 100644 --- a/release-notes/opensearch-anomaly-detection.release-notes-2.18.0.0.md +++ b/release-notes/opensearch-anomaly-detection.release-notes-2.18.0.0.md @@ -7,6 +7,7 @@ Compatible with OpenSearch 2.18.0 ### Bug Fixes * Bump RCF Version and Fix Default Rules Bug in AnomalyDetector ([#1334](https://github.com/opensearch-project/anomaly-detection/pull/1334)) +* Fix race condition in PageListener ([#1351](https://github.com/opensearch-project/anomaly-detection/pull/1351)) ### Infrastructure * forward port flaky test fix and add forecasting security tests ([#1329](https://github.com/opensearch-project/anomaly-detection/pull/1329)) diff --git a/src/main/java/org/opensearch/forecast/transport/ForecastResultBulkTransportAction.java b/src/main/java/org/opensearch/forecast/transport/ForecastResultBulkTransportAction.java index dcb792fba..dcdd0680a 100644 --- a/src/main/java/org/opensearch/forecast/transport/ForecastResultBulkTransportAction.java +++ b/src/main/java/org/opensearch/forecast/transport/ForecastResultBulkTransportAction.java @@ -58,7 +58,7 @@ public ForecastResultBulkTransportAction( } @Override - protected BulkRequest prepareBulkRequest(float indexingPressurePercent, ForecastResultBulkRequest request) { + public BulkRequest prepareBulkRequest(float indexingPressurePercent, ForecastResultBulkRequest request) { BulkRequest bulkRequest = new BulkRequest(); List results = request.getResults(); diff --git a/src/main/java/org/opensearch/timeseries/rest/handler/IntervalCalculation.java b/src/main/java/org/opensearch/timeseries/rest/handler/IntervalCalculation.java index ab17f91cf..53c503626 100644 --- a/src/main/java/org/opensearch/timeseries/rest/handler/IntervalCalculation.java +++ b/src/main/java/org/opensearch/timeseries/rest/handler/IntervalCalculation.java @@ -252,8 +252,9 @@ private void findMinimumInterval(LongBounds timeStampBounds, ActionListener searchResponseListener = ActionListener.wrap(response -> { List timestamps = aggregationPrep.getTimestamps(response); - if (timestamps.isEmpty()) { - logger.warn("empty data, return one minute by default"); + if (timestamps.size() < 2) { + // to calculate the difference we need at least 2 timestamps + logger.warn("not enough data, return one minute by default"); listener.onResponse(new IntervalTimeConfiguration(1, ChronoUnit.MINUTES)); return; } diff --git a/src/main/java/org/opensearch/timeseries/transport/handler/ResultIndexingHandler.java b/src/main/java/org/opensearch/timeseries/transport/handler/ResultIndexingHandler.java index d7beb64a8..1647fa01a 100644 --- a/src/main/java/org/opensearch/timeseries/transport/handler/ResultIndexingHandler.java +++ b/src/main/java/org/opensearch/timeseries/transport/handler/ResultIndexingHandler.java @@ -34,7 +34,6 @@ import org.opensearch.core.concurrency.OpenSearchRejectedExecutionException; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.threadpool.ThreadPool; -import org.opensearch.timeseries.common.exception.EndRunException; import org.opensearch.timeseries.common.exception.TimeSeriesException; import org.opensearch.timeseries.indices.IndexManagement; import org.opensearch.timeseries.indices.TimeSeriesIndex; @@ -109,22 +108,28 @@ public void setFixedDoc(boolean fixedDoc) { } // TODO: check if user has permission to index. - public void index(ResultType toSave, String detectorId, String indexOrAliasName) { - try { - if (indexOrAliasName != null) { - if (indexUtils.checkIndicesBlocked(clusterService.state(), ClusterBlockLevel.WRITE, indexOrAliasName)) { - LOG.warn(String.format(Locale.ROOT, CANNOT_SAVE_ERR_MSG, detectorId)); - return; - } - // We create custom result index when creating a detector. Custom result index can be rolled over and thus we may need to - // create a new one. - if (!timeSeriesIndices.doesIndexExist(indexOrAliasName) && !timeSeriesIndices.doesAliasExist(indexOrAliasName)) { - timeSeriesIndices.initCustomResultIndexDirectly(indexOrAliasName, ActionListener.wrap(response -> { - if (response.isAcknowledged()) { - save(toSave, detectorId, indexOrAliasName); - } else { - throw new TimeSeriesException( - detectorId, + /** + * Run async index operation. Cannot guarantee index is done after finishing executing the function as several calls + * in the method are asynchronous. + * @param toSave Result to save + * @param configId config id + * @param indexOrAliasName custom index or alias name + */ + public void index(ResultType toSave, String configId, String indexOrAliasName) { + if (indexOrAliasName != null) { + if (indexUtils.checkIndicesBlocked(clusterService.state(), ClusterBlockLevel.WRITE, indexOrAliasName)) { + LOG.warn(String.format(Locale.ROOT, CANNOT_SAVE_ERR_MSG, configId)); + return; + } + // We create custom result index when creating a detector. Custom result index can be rolled over and thus we may need to + // create a new one. + if (!timeSeriesIndices.doesIndexExist(indexOrAliasName) && !timeSeriesIndices.doesAliasExist(indexOrAliasName)) { + timeSeriesIndices.initCustomResultIndexDirectly(indexOrAliasName, ActionListener.wrap(response -> { + if (response.isAcknowledged()) { + save(toSave, configId, indexOrAliasName); + } else { + LOG + .error( String .format( Locale.ROOT, @@ -132,65 +137,49 @@ public void index(ResultType toSave, String detectorId, String indexOrAliasName) indexOrAliasName ) ); - } - }, exception -> { - if (ExceptionsHelper.unwrapCause(exception) instanceof ResourceAlreadyExistsException) { - // It is possible the index has been created while we sending the create request - save(toSave, detectorId, indexOrAliasName); - } else { - throw new TimeSeriesException( - detectorId, - String.format(Locale.ROOT, "cannot create result index %s", indexOrAliasName), - exception - ); - } - })); - } else { - timeSeriesIndices.validateResultIndexMapping(indexOrAliasName, ActionListener.wrap(valid -> { - if (!valid) { - throw new EndRunException(detectorId, "wrong index mapping of custom AD result index", true); - } else { - save(toSave, detectorId, indexOrAliasName); - } - }, exception -> { - throw new TimeSeriesException( - detectorId, - String.format(Locale.ROOT, "cannot validate result index %s", indexOrAliasName), - exception - ); - })); - } + } + }, exception -> { + if (ExceptionsHelper.unwrapCause(exception) instanceof ResourceAlreadyExistsException) { + // It is possible the index has been created while we sending the create request + save(toSave, configId, indexOrAliasName); + } else { + LOG.error(String.format(Locale.ROOT, "cannot create result index %s", indexOrAliasName), exception); + } + })); } else { - if (indexUtils.checkIndicesBlocked(clusterService.state(), ClusterBlockLevel.WRITE, this.defaultResultIndexName)) { - LOG.warn(String.format(Locale.ROOT, CANNOT_SAVE_ERR_MSG, detectorId)); - return; - } - if (!timeSeriesIndices.doesDefaultResultIndexExist()) { - timeSeriesIndices - .initDefaultResultIndexDirectly( - ActionListener.wrap(initResponse -> onCreateIndexResponse(initResponse, toSave, detectorId), exception -> { - if (ExceptionsHelper.unwrapCause(exception) instanceof ResourceAlreadyExistsException) { - // It is possible the index has been created while we sending the create request - save(toSave, detectorId); - } else { - throw new TimeSeriesException( - detectorId, + timeSeriesIndices.validateResultIndexMapping(indexOrAliasName, ActionListener.wrap(valid -> { + if (!valid) { + LOG.error("wrong index mapping of custom result index"); + } else { + save(toSave, configId, indexOrAliasName); + } + }, exception -> { LOG.error(String.format(Locale.ROOT, "cannot validate result index %s", indexOrAliasName), exception); }) + ); + } + } else { + if (indexUtils.checkIndicesBlocked(clusterService.state(), ClusterBlockLevel.WRITE, this.defaultResultIndexName)) { + LOG.warn(String.format(Locale.ROOT, CANNOT_SAVE_ERR_MSG, configId)); + return; + } + if (!timeSeriesIndices.doesDefaultResultIndexExist()) { + timeSeriesIndices + .initDefaultResultIndexDirectly( + ActionListener.wrap(initResponse -> onCreateIndexResponse(initResponse, toSave, configId), exception -> { + if (ExceptionsHelper.unwrapCause(exception) instanceof ResourceAlreadyExistsException) { + // It is possible the index has been created while we sending the create request + save(toSave, configId); + } else { + LOG + .error( String.format(Locale.ROOT, "Unexpected error creating index %s", defaultResultIndexName), exception ); - } - }) - ); - } else { - save(toSave, detectorId); - } + } + }) + ); + } else { + save(toSave, configId); } - } catch (Exception e) { - throw new TimeSeriesException( - detectorId, - String.format(Locale.ROOT, "Error in saving %s for detector %s", defaultResultIndexName, detectorId), - e - ); } } diff --git a/src/test/java/org/opensearch/ad/transport/handler/AnomalyResultHandlerTests.java b/src/test/java/org/opensearch/ad/transport/handler/AnomalyResultHandlerTests.java index 73efdf4a1..8db056c0c 100644 --- a/src/test/java/org/opensearch/ad/transport/handler/AnomalyResultHandlerTests.java +++ b/src/test/java/org/opensearch/ad/transport/handler/AnomalyResultHandlerTests.java @@ -12,11 +12,14 @@ package org.opensearch.ad.transport.handler; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; import java.io.IOException; import java.time.Clock; @@ -31,6 +34,8 @@ import org.junit.rules.ExpectedException; import org.mockito.ArgumentMatchers; import org.mockito.Mock; +import org.opensearch.ResourceAlreadyExistsException; +import org.opensearch.action.admin.indices.create.CreateIndexResponse; import org.opensearch.action.index.IndexRequest; import org.opensearch.action.index.IndexResponse; import org.opensearch.ad.constant.ADCommonName; @@ -44,7 +49,6 @@ import org.opensearch.core.concurrency.OpenSearchRejectedExecutionException; import org.opensearch.timeseries.NodeStateManager; import org.opensearch.timeseries.TestHelpers; -import org.opensearch.timeseries.common.exception.TimeSeriesException; import org.opensearch.timeseries.transport.handler.ResultIndexingHandler; public class AnomalyResultHandlerTests extends AbstractIndexHandlerTest { @@ -181,9 +185,6 @@ public void testAdResultIndexExist() throws IOException { @Test public void testAdResultIndexOtherException() throws IOException { - expectedEx.expect(TimeSeriesException.class); - expectedEx.expectMessage("Error in saving .opendistro-anomaly-results for detector " + detectorId); - setUpSavingAnomalyResultIndex(false, IndexCreation.RUNTIME_EXCEPTION); ResultIndexingHandler handler = new ResultIndexingHandler<>( client, @@ -199,6 +200,7 @@ public void testAdResultIndexOtherException() throws IOException { ); handler.index(TestHelpers.randomAnomalyDetectResult(), detectorId, null); verify(client, never()).index(any(), any()); + assertTrue(testAppender.containsMessage(String.format(Locale.ROOT, "Unexpected error creating index .opendistro-anomaly-results"))); } /** @@ -212,7 +214,6 @@ public void testAdResultIndexOtherException() throws IOException { * @throws InterruptedException if thread execution is interrupted * @throws IOException if IO failures */ - @SuppressWarnings("unchecked") private void savingFailureTemplate(boolean throwOpenSearchRejectedExecutionException, int latchCount, boolean adResultIndexExists) throws InterruptedException, IOException { @@ -262,4 +263,218 @@ private void savingFailureTemplate(boolean throwOpenSearchRejectedExecutionExcep backoffLatch.await(1, TimeUnit.MINUTES); } + + @Test + public void testCustomIndexCreate() { + String testIndex = "test"; + setWriteBlockAdResultIndex(false); + when(anomalyDetectionIndices.doesIndexExist(anyString())).thenReturn(false); + when(anomalyDetectionIndices.doesAliasExist(anyString())).thenReturn(false); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(new CreateIndexResponse(true, true, testIndex)); + return null; + }).when(anomalyDetectionIndices).initCustomResultIndexDirectly(eq(testIndex), any()); + + ResultIndexingHandler handler = new ResultIndexingHandler<>( + client, + settings, + threadPool, + ADCommonName.ANOMALY_RESULT_INDEX_ALIAS, + anomalyDetectionIndices, + clientUtil, + indexUtil, + clusterService, + AnomalyDetectorSettings.AD_BACKOFF_INITIAL_DELAY, + AnomalyDetectorSettings.AD_MAX_RETRY_FOR_BACKOFF + ); + handler.index(TestHelpers.randomAnomalyDetectResult(), detectorId, testIndex); + verify(client, times(1)).index(any(), any()); + } + + @Test + public void testCustomIndexCreateNotAcked() { + String testIndex = "test"; + setWriteBlockAdResultIndex(false); + when(anomalyDetectionIndices.doesIndexExist(anyString())).thenReturn(false); + when(anomalyDetectionIndices.doesAliasExist(anyString())).thenReturn(false); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(new CreateIndexResponse(false, false, testIndex)); + return null; + }).when(anomalyDetectionIndices).initCustomResultIndexDirectly(eq(testIndex), any()); + + ResultIndexingHandler handler = new ResultIndexingHandler<>( + client, + settings, + threadPool, + ADCommonName.ANOMALY_RESULT_INDEX_ALIAS, + anomalyDetectionIndices, + clientUtil, + indexUtil, + clusterService, + AnomalyDetectorSettings.AD_BACKOFF_INITIAL_DELAY, + AnomalyDetectorSettings.AD_MAX_RETRY_FOR_BACKOFF + ); + handler.index(TestHelpers.randomAnomalyDetectResult(), detectorId, testIndex); + + assertTrue( + testAppender + .containsMessage( + String.format(Locale.ROOT, "Creating custom result index %s with mappings call not acknowledged", testIndex) + ) + ); + } + + @Test + public void testCustomIndexCreateExists() { + String testIndex = "test"; + setWriteBlockAdResultIndex(false); + when(anomalyDetectionIndices.doesIndexExist(anyString())).thenReturn(false); + when(anomalyDetectionIndices.doesAliasExist(anyString())).thenReturn(false); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onFailure(new ResourceAlreadyExistsException("index already exists")); + return null; + }).when(anomalyDetectionIndices).initCustomResultIndexDirectly(eq(testIndex), any()); + + ResultIndexingHandler handler = new ResultIndexingHandler<>( + client, + settings, + threadPool, + ADCommonName.ANOMALY_RESULT_INDEX_ALIAS, + anomalyDetectionIndices, + clientUtil, + indexUtil, + clusterService, + AnomalyDetectorSettings.AD_BACKOFF_INITIAL_DELAY, + AnomalyDetectorSettings.AD_MAX_RETRY_FOR_BACKOFF + ); + handler.index(TestHelpers.randomAnomalyDetectResult(), detectorId, testIndex); + verify(client, times(1)).index(any(), any()); + } + + @Test + public void testCustomIndexOtherException() { + String testIndex = "test"; + setWriteBlockAdResultIndex(false); + when(anomalyDetectionIndices.doesIndexExist(anyString())).thenReturn(false); + when(anomalyDetectionIndices.doesAliasExist(anyString())).thenReturn(false); + + Exception testException = new OpenSearchRejectedExecutionException("Test exception"); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onFailure(testException); + return null; + }).when(anomalyDetectionIndices).initCustomResultIndexDirectly(eq(testIndex), any()); + + ResultIndexingHandler handler = new ResultIndexingHandler<>( + client, + settings, + threadPool, + ADCommonName.ANOMALY_RESULT_INDEX_ALIAS, + anomalyDetectionIndices, + clientUtil, + indexUtil, + clusterService, + AnomalyDetectorSettings.AD_BACKOFF_INITIAL_DELAY, + AnomalyDetectorSettings.AD_MAX_RETRY_FOR_BACKOFF + ); + handler.index(TestHelpers.randomAnomalyDetectResult(), detectorId, testIndex); + + assertTrue(testAppender.containsMessage(String.format(Locale.ROOT, "cannot create result index %s", testIndex))); + } + + @Test + public void testInvalid() { + String testIndex = "test"; + setWriteBlockAdResultIndex(false); + when(anomalyDetectionIndices.doesIndexExist(anyString())).thenReturn(false); + when(anomalyDetectionIndices.doesAliasExist(anyString())).thenReturn(true); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(false); + return null; + }).when(anomalyDetectionIndices).validateResultIndexMapping(eq(testIndex), any()); + + ResultIndexingHandler handler = new ResultIndexingHandler<>( + client, + settings, + threadPool, + ADCommonName.ANOMALY_RESULT_INDEX_ALIAS, + anomalyDetectionIndices, + clientUtil, + indexUtil, + clusterService, + AnomalyDetectorSettings.AD_BACKOFF_INITIAL_DELAY, + AnomalyDetectorSettings.AD_MAX_RETRY_FOR_BACKOFF + ); + handler.index(TestHelpers.randomAnomalyDetectResult(), detectorId, testIndex); + + assertTrue(testAppender.containsMessage("wrong index mapping of custom result index", false)); + } + + @Test + public void testValid() { + String testIndex = "test"; + setWriteBlockAdResultIndex(false); + when(anomalyDetectionIndices.doesIndexExist(anyString())).thenReturn(false); + when(anomalyDetectionIndices.doesAliasExist(anyString())).thenReturn(true); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(true); + return null; + }).when(anomalyDetectionIndices).validateResultIndexMapping(eq(testIndex), any()); + + ResultIndexingHandler handler = new ResultIndexingHandler<>( + client, + settings, + threadPool, + ADCommonName.ANOMALY_RESULT_INDEX_ALIAS, + anomalyDetectionIndices, + clientUtil, + indexUtil, + clusterService, + AnomalyDetectorSettings.AD_BACKOFF_INITIAL_DELAY, + AnomalyDetectorSettings.AD_MAX_RETRY_FOR_BACKOFF + ); + + handler.index(TestHelpers.randomAnomalyDetectResult(), detectorId, testIndex); + verify(client, times(1)).index(any(), any()); + } + + @Test + public void testValidationException() { + String testIndex = "test"; + setWriteBlockAdResultIndex(false); + when(anomalyDetectionIndices.doesIndexExist(anyString())).thenReturn(false); + when(anomalyDetectionIndices.doesAliasExist(anyString())).thenReturn(true); + + Exception testException = new OpenSearchRejectedExecutionException("Test exception"); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onFailure(testException); + return null; + }).when(anomalyDetectionIndices).validateResultIndexMapping(eq(testIndex), any()); + + ResultIndexingHandler handler = new ResultIndexingHandler<>( + client, + settings, + threadPool, + ADCommonName.ANOMALY_RESULT_INDEX_ALIAS, + anomalyDetectionIndices, + clientUtil, + indexUtil, + clusterService, + AnomalyDetectorSettings.AD_BACKOFF_INITIAL_DELAY, + AnomalyDetectorSettings.AD_MAX_RETRY_FOR_BACKOFF + ); + + handler.index(TestHelpers.randomAnomalyDetectResult(), detectorId, testIndex); + assertTrue(testAppender.containsMessage(String.format(Locale.ROOT, "cannot validate result index %s", testIndex), false)); + } } diff --git a/src/test/java/org/opensearch/forecast/AbstractForecastSyntheticDataTest.java b/src/test/java/org/opensearch/forecast/AbstractForecastSyntheticDataTest.java index 2449ccd33..d17783e1b 100644 --- a/src/test/java/org/opensearch/forecast/AbstractForecastSyntheticDataTest.java +++ b/src/test/java/org/opensearch/forecast/AbstractForecastSyntheticDataTest.java @@ -21,18 +21,22 @@ import java.io.IOException; import java.util.ArrayList; +import java.util.Arrays; import java.util.List; import java.util.Locale; import java.util.Set; import org.apache.hc.core5.http.ParseException; import org.apache.hc.core5.http.io.entity.EntityUtils; +import org.opensearch.action.search.SearchResponse; import org.opensearch.client.Response; import org.opensearch.client.RestClient; +import org.opensearch.common.xcontent.json.JsonXContent; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.forecast.constant.ForecastCommonName; import org.opensearch.forecast.model.ForecastTaskProfile; import org.opensearch.forecast.model.Forecaster; +import org.opensearch.search.SearchHit; import org.opensearch.timeseries.AbstractSyntheticDataTest; import org.opensearch.timeseries.TestHelpers; import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; @@ -158,4 +162,14 @@ protected List waitUntilTaskReachState(String forecasterId, Set return results; } + protected List toHits(Response response) throws UnsupportedOperationException, IOException { + SearchResponse searchResponse = SearchResponse + .fromXContent(createParser(JsonXContent.jsonXContent, response.getEntity().getContent())); + long total = searchResponse.getHits().getTotalHits().value; + if (total == 0) { + return new ArrayList<>(); + } + return Arrays.asList(searchResponse.getHits().getHits()); + } + } diff --git a/src/test/java/org/opensearch/forecast/rest/ForecastRestApiIT.java b/src/test/java/org/opensearch/forecast/rest/ForecastRestApiIT.java index b45dfece1..3d35cac46 100644 --- a/src/test/java/org/opensearch/forecast/rest/ForecastRestApiIT.java +++ b/src/test/java/org/opensearch/forecast/rest/ForecastRestApiIT.java @@ -11,9 +11,12 @@ import java.time.Duration; import java.time.Instant; import java.time.temporal.ChronoUnit; +import java.util.Arrays; +import java.util.HashSet; import java.util.List; import java.util.Locale; import java.util.Map; +import java.util.Set; import java.util.regex.Matcher; import java.util.regex.Pattern; @@ -27,6 +30,7 @@ import org.opensearch.forecast.AbstractForecastSyntheticDataTest; import org.opensearch.forecast.model.ForecastTaskProfile; import org.opensearch.forecast.settings.ForecastEnabledSetting; +import org.opensearch.search.SearchHit; import org.opensearch.timeseries.TestHelpers; import org.opensearch.timeseries.constant.CommonMessages; import org.opensearch.timeseries.model.EntityTaskProfile; @@ -39,16 +43,22 @@ /** * Test the following Restful API: - * - Suggest - * - Validate - * - Create - * - run once + * - top forecast * - start * - stop + * - Create + * - run once + * - Validate + * - Suggest * - update */ public class ForecastRestApiIT extends AbstractForecastSyntheticDataTest { public static final int MAX_RETRY_TIMES = 200; + private static final String CITY_NAME = "cityName"; + private static final String CONFIDENCE_INTERVAL_WIDTH = "confidence_interval_width"; + private static final String FORECAST_VALUE = "forecast_value"; + private static final String MIN_CONFIDENCE_INTERVAL = "MIN_CONFIDENCE_INTERVAL_WIDTH"; + private static final String MAX_CONFIDENCE_INTERVAL = "MAX_CONFIDENCE_INTERVAL_WIDTH"; @Override @Before @@ -85,10 +95,10 @@ private static Instant loadSparseCategoryData(int trainTestSplit) throws Excepti JsonObject row = data.get(i); // Get the value of the "cityName" field - String cityName = row.get("cityName").getAsString(); + String cityName = row.get(CITY_NAME).getAsString(); // Replace the field based on the value of "cityName" - row.remove("cityName"); // Remove the original "cityName" field + row.remove(CITY_NAME); // Remove the original "cityName" field if ("Phoenix".equals(cityName)) { if (phonenixIndex % 2 == 0) { @@ -539,7 +549,7 @@ public void testSuggestSparseData() throws Exception { */ public void testFailToSuggest() throws Exception { int trainTestSplit = 100; - String categoricalField = "cityName"; + String categoricalField = CITY_NAME; GenData dataGenerated = genUniformSingleFeatureData( 70, trainTestSplit, @@ -1931,7 +1941,7 @@ public void testCreate() throws Exception { ); MatcherAssert.assertThat(ex.getMessage(), containsString("Can't create more than 1 feature(s)")); - // case 2: create forecaster with custom index + // Case 2: users cannot specify forecaster id when creating a forecaster forecasterDef = "{\n" + " \"name\": \"Second-Test-Forecaster-4\",\n" + " \"description\": \"ok rate\",\n" @@ -1946,28 +1956,11 @@ public void testCreate() throws Exception { + " \"feature_enabled\": true,\n" + " \"importance\": 1,\n" + " \"aggregation_query\": {\n" - + " \"filtered_max_1\": {\n" - + " \"filter\": {\n" - + " \"bool\": {\n" - + " \"must\": [\n" - + " {\n" - + " \"range\": {\n" - + " \"timestamp\": {\n" - + " \"lt\": %d\n" - + " }\n" - + " }\n" - + " }\n" - + " ]\n" - + " }\n" - + " },\n" - + " \"aggregations\": {\n" + " \"max1\": {\n" + " \"max\": {\n" + " \"field\": \"visitCount\"\n" + " }\n" + " }\n" - + " }\n" - + " }\n" + " }\n" + " }\n" + " ],\n" @@ -1989,26 +1982,25 @@ public void testCreate() throws Exception { + " \"interval\": 10,\n" + " \"unit\": \"MINUTES\"\n" + " }\n" - + " },\n" - + " \"result_index\": \"opensearch-forecast-result-b\"\n" + + " }\n" + "}"; // +1 to make sure it is big enough windowDelayMinutes = Duration.between(trainTime, Instant.now()).toMinutes() + 1; - // we have 100 timestamps (2 entities per timestamp). Timestamps are 10 minutes apart. If we subtract 70 * 10 = 700 minutes, we have - // sparse data. - String formattedForecaster2 = String.format(Locale.ROOT, forecasterDef, RULE_DATASET_NAME, filterTimestamp, windowDelayMinutes); + final String formattedForecasterId = String.format(Locale.ROOT, forecasterDef, RULE_DATASET_NAME, windowDelayMinutes); + String blahId = "__blah__"; Response response = TestHelpers .makeRequest( client(), "POST", String.format(Locale.ROOT, CREATE_FORECASTER), - ImmutableMap.of(), - TestHelpers.toHttpEntity(formattedForecaster2), + ImmutableMap.of(RestHandlerUtils.FORECASTER_ID, blahId), + TestHelpers.toHttpEntity(formattedForecasterId), null ); Map responseMap = entityAsMap(response); - assertEquals("opensearch-forecast-result-b", ((Map) responseMap.get("forecaster")).get("result_index")); + String forecasterId = (String) responseMap.get("_id"); + assertNotEquals("response is missing Id", blahId, forecasterId); } public void testRunOnce() throws Exception { @@ -2054,12 +2046,14 @@ public void testRunOnce() throws Exception { + " \"interval\": 10,\n" + " \"unit\": \"MINUTES\"\n" + " }\n" - + " }\n" + + " },\n" + + " \"result_index\": \"opensearch-forecast-result-b\",\n" + + " \"category_field\": [\"%s\"]\n" + "}"; // +1 to make sure it is big enough long windowDelayMinutes = Duration.between(trainTime, Instant.now()).toMinutes() + 1; - final String formattedForecaster = String.format(Locale.ROOT, forecasterDef, RULE_DATASET_NAME, windowDelayMinutes); + final String formattedForecaster = String.format(Locale.ROOT, forecasterDef, RULE_DATASET_NAME, windowDelayMinutes, CITY_NAME); Response response = TestHelpers .makeRequest( client(), @@ -2071,6 +2065,7 @@ public void testRunOnce() throws Exception { ); Map responseMap = entityAsMap(response); String forecasterId = (String) responseMap.get("_id"); + assertEquals("opensearch-forecast-result-b", ((Map) responseMap.get("forecaster")).get("result_index")); // run once response = TestHelpers @@ -2100,6 +2095,30 @@ public void testRunOnce() throws Exception { int total = (int) (((Map) ((Map) responseMap.get("hits")).get("total")).get("value")); assertTrue("actual: " + total, total > 40); + List hits = toHits(response); + long forecastFrom = -1; + for (SearchHit hit : hits) { + Map source = hit.getSourceAsMap(); + if (source.get("forecast_value") != null) { + forecastFrom = (long) (source.get("data_end_time")); + break; + } + } + assertTrue(forecastFrom != -1); + + // top forecast verification + minConfidenceIntervalVerification(forecasterId, forecastFrom); + maxConfidenceIntervalVerification(forecasterId, forecastFrom); + minForecastValueVerification(forecasterId, forecastFrom); + maxForecastValueVerification(forecasterId, forecastFrom); + distanceToThresholdGreaterThan(forecasterId, forecastFrom); + distanceToThresholdGreaterThanEqual(forecasterId, forecastFrom); + distanceToThresholdLessThan(forecasterId, forecastFrom); + distanceToThresholdLessThanEqual(forecasterId, forecastFrom); + customMaxForecastValue(forecasterId, forecastFrom); + customMinForecastValue(forecasterId, forecastFrom); + topForecastSizeVerification(forecasterId, forecastFrom); + // case 2: cannot run once while forecaster is started response = TestHelpers .makeRequest( @@ -2144,6 +2163,442 @@ public void testRunOnce() throws Exception { assertEquals(forecasterId, responseMap.get("_id")); } + private void maxForecastValueVerification(String forecasterId, long forecastFrom) throws IOException { + Response response; + Map responseMap; + String topForcastRequest; + List parsedBuckets; + double previousValue; + topForcastRequest = String + .format( + Locale.ROOT, + "{\n" + + " \"split_by\": \"%s\",\n" + + " \"filter_by\": \"BUILD_IN_QUERY\",\n" + + " \"build_in_query\": \"MAX_VALUE_WITHIN_THE_HORIZON\",\n" + + " \"forecast_from\": %d,\n" + + " \"run_once\": true\n" + + "}", + CITY_NAME, + forecastFrom + ); + + response = TestHelpers + .makeRequest( + client(), + "POST", + String.format(Locale.ROOT, TOP_FORECASTER, forecasterId), + ImmutableMap.of(), + TestHelpers.toHttpEntity(topForcastRequest), + null + ); + responseMap = entityAsMap(response); + parsedBuckets = (List) responseMap.get("buckets"); + assertTrue("actual content: " + parsedBuckets, parsedBuckets.size() == 2); + + previousValue = Double.MAX_VALUE; // Initialize to positive infinity + double largestValue = Double.MIN_VALUE; + + largestValue = isDesc(parsedBuckets, previousValue, largestValue, "MAX_VALUE_WITHIN_THE_HORIZON"); + + String maxValueRequest = String + .format( + Locale.ROOT, + "{\n" + + " \"size\": 1,\n" + + " \"query\": {\n" + + " \"bool\": {\n" + + " \"filter\": [\n" + + " {\n" + + " \"range\": {\n" + + " \"data_end_time\": {\n" + + " \"from\": %d,\n" + + " \"to\": null,\n" + + " \"include_lower\": true\n" + + " }\n" + + " }\n" + + " }\n" + + " ]\n" + + " }\n" + + " },\n" + + " \"sort\": [\n" + + " {\n" + + " \"%s\": {\n" + + " \"order\": \"desc\"\n" + + " }\n" + + " }\n" + + " ]\n" + + "}", + forecastFrom, + FORECAST_VALUE + ); + + Response maxValueResponse = TestHelpers + .makeRequest(client(), "GET", SEARCH_RESULTS, ImmutableMap.of(), TestHelpers.toHttpEntity(maxValueRequest), null); + List maxValueHits = toHits(maxValueResponse); + assertEquals("actual: " + maxValueHits, 1, maxValueHits.size()); + double maxValue = (double) (maxValueHits.get(0).getSourceAsMap().get(FORECAST_VALUE)); + assertEquals(String.format(Locale.ROOT, "actual: %f, expect: %f", maxValue, largestValue), maxValue, largestValue, 0.001); + } + + private void minForecastValueVerification(String forecasterId, long forecastFrom) throws IOException { + Response response; + Map responseMap; + String topForcastRequest; + List parsedBuckets; + Set cities; + topForcastRequest = String + .format( + Locale.ROOT, + "{\n" + + " \"split_by\": \"%s\",\n" + + " \"filter_by\": \"BUILD_IN_QUERY\",\n" + + " \"build_in_query\": \"MIN_VALUE_WITHIN_THE_HORIZON\",\n" + + " \"forecast_from\": %d,\n" + + " \"run_once\": true\n" + + "}", + CITY_NAME, + forecastFrom + ); + + response = TestHelpers + .makeRequest( + client(), + "POST", + String.format(Locale.ROOT, TOP_FORECASTER, forecasterId), + ImmutableMap.of(), + TestHelpers.toHttpEntity(topForcastRequest), + null + ); + responseMap = entityAsMap(response); + parsedBuckets = (List) responseMap.get("buckets"); + assertTrue("actual content: " + parsedBuckets, parsedBuckets.size() == 2); + + double previousValue = -Double.MAX_VALUE; // Initialize to negative infinity + double smallestValue = Double.MAX_VALUE; + cities = new HashSet<>(); + + smallestValue = isAsc(parsedBuckets, cities, previousValue, smallestValue, "MIN_VALUE_WITHIN_THE_HORIZON"); + + String minValueRequest = String + .format( + Locale.ROOT, + "{\n" + + " \"size\": 1,\n" + + " \"query\": {\n" + + " \"bool\": {\n" + + " \"filter\": [\n" + + " {\n" + + " \"range\": {\n" + + " \"data_end_time\": {\n" + + " \"from\": %d,\n" + + " \"to\": null,\n" + + " \"include_lower\": true\n" + + " }\n" + + " }\n" + + " }\n" + + " ]\n" + + " }\n" + + " },\n" + + " \"sort\": [\n" + + " {\n" + + " \"%s\": {\n" + + " \"order\": \"asc\"\n" + + " }\n" + + " }\n" + + " ]\n" + + "}", + forecastFrom, + FORECAST_VALUE + ); + + Response minValueResponse = TestHelpers + .makeRequest(client(), "GET", SEARCH_RESULTS, ImmutableMap.of(), TestHelpers.toHttpEntity(minValueRequest), null); + List minValueHits = toHits(minValueResponse); + assertEquals("actual: " + minValueHits, 1, minValueHits.size()); + double minValue = (double) (minValueHits.get(0).getSourceAsMap().get("forecast_value")); + assertEquals(String.format(Locale.ROOT, "actual: %f, expect: %f", minValue, smallestValue), minValue, smallestValue, 0.001); + } + + private double isAsc(List parsedBuckets, Set cities, double previousValue, double smallestValue, String valueKey) { + for (Object obj : parsedBuckets) { + assertTrue("Each element in the list must be a Map.", obj instanceof Map); + + @SuppressWarnings("unchecked") + Map bucket = (Map) obj; + + // Extract value using keys + Object valueObj = bucket.get(valueKey); + assertTrue("actual: " + valueObj, valueObj instanceof Number); + + double value = ((Number) valueObj).doubleValue(); + if (smallestValue > value) { + smallestValue = value; + } + + // Check ascending order + assertTrue(String.format(Locale.ROOT, "value %f previousValue %f", value, previousValue), value >= previousValue); + + previousValue = value; + + // Extract the key + Object keyObj = bucket.get("key"); + assertTrue("actual: " + keyObj, keyObj instanceof Map); + + @SuppressWarnings("unchecked") + Map keyMap = (Map) keyObj; + String cityName = (String) keyMap.get(CITY_NAME); + + assertTrue("cityName is null", cityName != null); + + // Check that service is either "Phoenix" or "Scottsdale" + assertTrue("cityName is " + cityName, cityName.equals("Phoenix") || cityName.equals("Scottsdale")); + + // Check for unique services + assertTrue("Duplicate city found: " + cityName, cities.add(cityName)); + } + return smallestValue; + } + + private void maxConfidenceIntervalVerification(String forecasterId, long forecastFrom) throws IOException { + Response response; + Map responseMap; + String topForcastRequest; + List parsedBuckets; + double previousWidth; + topForcastRequest = String + .format( + Locale.ROOT, + "{\n" + + " \"split_by\": \"%s\",\n" + + " \"filter_by\": \"BUILD_IN_QUERY\",\n" + + " \"build_in_query\": \"MAX_CONFIDENCE_INTERVAL_WIDTH\",\n" + + " \"forecast_from\": %d,\n" + + " \"run_once\": true\n" + + "}", + CITY_NAME, + forecastFrom + ); + + response = TestHelpers + .makeRequest( + client(), + "POST", + String.format(Locale.ROOT, TOP_FORECASTER, forecasterId), + ImmutableMap.of(), + TestHelpers.toHttpEntity(topForcastRequest), + null + ); + responseMap = entityAsMap(response); + parsedBuckets = (List) responseMap.get("buckets"); + assertTrue("actual content: " + parsedBuckets, parsedBuckets.size() == 2); + + previousWidth = Double.MAX_VALUE; // Initialize to positive infinity + double largestWidth = Double.MIN_VALUE; + + largestWidth = isDesc(parsedBuckets, previousWidth, largestWidth, MAX_CONFIDENCE_INTERVAL); + + String maxConfidenceIntervalRequest = String + .format( + Locale.ROOT, + "{\n" + + " \"size\": 1,\n" + + " \"query\": {\n" + + " \"bool\": {\n" + + " \"filter\": [\n" + + " {\n" + + " \"term\": {\n" + + " \"horizon_index\": 24\n" + + " }\n" + + " },\n" + + " {\n" + + " \"range\": {\n" + + " \"data_end_time\": {\n" + + " \"from\": %d,\n" + + " \"to\": null,\n" + + " \"include_lower\": true\n" + + " }\n" + + " }\n" + + " }\n" + + " ]\n" + + " }\n" + + " },\n" + + " \"sort\": [\n" + + " {\n" + + " \"%s\": {\n" + + " \"order\": \"desc\"\n" + + " }\n" + + " }\n" + + " ]\n" + + "}", + forecastFrom, + CONFIDENCE_INTERVAL_WIDTH + ); + + Response maxConfidenceIntervalResponse = TestHelpers + .makeRequest(client(), "GET", SEARCH_RESULTS, ImmutableMap.of(), TestHelpers.toHttpEntity(maxConfidenceIntervalRequest), null); + List maxConfidenceIntervalHits = toHits(maxConfidenceIntervalResponse); + assertEquals("actual: " + maxConfidenceIntervalHits, 1, maxConfidenceIntervalHits.size()); + double maxWidth = (double) (maxConfidenceIntervalHits.get(0).getSourceAsMap().get(CONFIDENCE_INTERVAL_WIDTH)); + assertEquals(String.format(Locale.ROOT, "actual: %f, expect: %f", maxWidth, largestWidth), maxWidth, largestWidth, 0.001); + } + + private void validateKeyValue( + Map keyMap, + String keyName, + String valueDescription, + Set expectedValues, + Set uniqueValuesSet + ) { + // Extract the value from the keyMap using the keyName + String value = (String) keyMap.get(keyName); + + // Ensure the value is not null + assertTrue(valueDescription + " is null", value != null); + + // Check that the value is one of the expected values + assertTrue(valueDescription + " is " + value, expectedValues.contains(value)); + + // Check for uniqueness in the provided set + assertTrue("Duplicate " + valueDescription + " found: " + value, uniqueValuesSet.add(value)); + } + + private double isDesc( + List parsedBuckets, + double previousWidth, + Set uniqueValuesSet, + double largestWidth, + String valueKey, + String keyName, + String valueDescription, + Set expectedValues + ) { + for (Object obj : parsedBuckets) { + assertTrue("Each element in the list must be a Map.", obj instanceof Map); + + @SuppressWarnings("unchecked") + Map bucket = (Map) obj; + + // Extract valueKey + Object widthObj = bucket.get(valueKey); + assertTrue("actual: " + widthObj, widthObj instanceof Number); + + double width = ((Number) widthObj).doubleValue(); + if (largestWidth < width) { + largestWidth = width; + } + + // Check descending order + assertTrue(String.format(Locale.ROOT, "width %f previousWidth %f", width, previousWidth), width <= previousWidth); + + previousWidth = width; + + // Extract the key + Object keyObj = bucket.get("key"); + assertTrue("actual: " + keyObj, keyObj instanceof Map); + + @SuppressWarnings("unchecked") + Map keyMap = (Map) keyObj; + + // Use the helper method for validation + validateKeyValue(keyMap, keyName, valueDescription, expectedValues, uniqueValuesSet); + } + return largestWidth; + } + + private double isDesc(List parsedBuckets, double previousWidth, double largestWidth, String valueKey) { + Set cities = new HashSet<>(); + Set expectedCities = new HashSet<>(Arrays.asList("Phoenix", "Scottsdale")); + return isDesc(parsedBuckets, previousWidth, cities, largestWidth, valueKey, CITY_NAME, "cityName", expectedCities); + } + + private double isDescTwoCategorical(List parsedBuckets, double previousWidth, double largestWidth, String valueKey) { + Set regions = new HashSet<>(); + Set expectedRegions = new HashSet<>(Arrays.asList("pdx", "iad")); + return isDesc(parsedBuckets, previousWidth, regions, largestWidth, valueKey, "region", "regionName", expectedRegions); + } + + private void minConfidenceIntervalVerification(String forecasterId, long forecastFrom) throws IOException { + Response response; + Map responseMap; + String topForcastRequest = String + .format( + Locale.ROOT, + "{\n" + + " \"split_by\": \"%s\",\n" + + " \"filter_by\": \"BUILD_IN_QUERY\",\n" + + " \"build_in_query\": \"MIN_CONFIDENCE_INTERVAL_WIDTH\",\n" + + " \"forecast_from\": %d,\n" + + " \"run_once\": true\n" + + "}", + CITY_NAME, + forecastFrom + ); + + response = TestHelpers + .makeRequest( + client(), + "POST", + String.format(Locale.ROOT, TOP_FORECASTER, forecasterId), + ImmutableMap.of(), + TestHelpers.toHttpEntity(topForcastRequest), + null + ); + responseMap = entityAsMap(response); + List parsedBuckets = (List) responseMap.get("buckets"); + assertTrue("actual content: " + parsedBuckets, parsedBuckets.size() == 2); + + double previousWidth = -Double.MAX_VALUE; // Initialize to negative infinity + double smallestWidth = Double.MAX_VALUE; + Set cities = new HashSet<>(); + + smallestWidth = isAsc(parsedBuckets, cities, previousWidth, smallestWidth, MIN_CONFIDENCE_INTERVAL); + + String minConfidenceIntervalRequest = String + .format( + Locale.ROOT, + "{\n" + + " \"size\": 1,\n" + + " \"query\": {\n" + + " \"bool\": {\n" + + " \"filter\": [\n" + + " {\n" + + " \"term\": {\n" + + " \"horizon_index\": 24\n" + + " }\n" + + " },\n" + + " {\n" + + " \"range\": {\n" + + " \"data_end_time\": {\n" + + " \"from\": %d,\n" + + " \"to\": null,\n" + + " \"include_lower\": true\n" + + " }\n" + + " }\n" + + " }\n" + + " ]\n" + + " }\n" + + " },\n" + + " \"sort\": [\n" + + " {\n" + + " \"%s\": {\n" + + " \"order\": \"asc\"\n" + + " }\n" + + " }\n" + + " ]\n" + + "}", + forecastFrom, + CONFIDENCE_INTERVAL_WIDTH + ); + + Response minConfidenceIntervalResponse = TestHelpers + .makeRequest(client(), "GET", SEARCH_RESULTS, ImmutableMap.of(), TestHelpers.toHttpEntity(minConfidenceIntervalRequest), null); + List minConfidenceIntervalHits = toHits(minConfidenceIntervalResponse); + assertEquals("actual: " + minConfidenceIntervalHits, 1, minConfidenceIntervalHits.size()); + double minWidth = (double) (minConfidenceIntervalHits.get(0).getSourceAsMap().get(CONFIDENCE_INTERVAL_WIDTH)); + assertEquals(String.format(Locale.ROOT, "actual: %f, expect: %f", minWidth, smallestWidth), minWidth, smallestWidth, 0.001); + } + public Response searchTaskResult(String taskId) throws IOException { Response response = TestHelpers .makeRequest( @@ -2153,7 +2608,9 @@ public Response searchTaskResult(String taskId) throws IOException { ImmutableMap.of(), TestHelpers .toHttpEntity( - "{\"query\":{\"bool\":{\"filter\":[{\"term\":{\"task_id\":\"" + taskId + "\"}}]}},\"track_total_hits\":true}" + "{\"query\":{\"bool\":{\"filter\":[{\"term\":{\"task_id\":\"" + + taskId + + "\"}}]}},\"track_total_hits\":true,\"size\":10000}" ), null ); @@ -2454,4 +2911,624 @@ public void testUpdateDetector() throws Exception { responseMap = entityAsMap(response); assertEquals(responseMap.get("last_update_time"), responseMap.get("last_ui_breaking_change_time")); } + + private void distanceToThresholdGreaterThan(String forecasterId, long forecastFrom) throws IOException { + distanceToThresholdGreaterTemplate(forecasterId, forecastFrom, false); + } + + private void distanceToThresholdGreaterTemplate(String forecasterId, long forecastFrom, boolean equal) throws IOException { + Response response; + Map responseMap; + String topForcastRequest; + List parsedBuckets; + double previousWidth; + int threshold = 4587; + topForcastRequest = String + .format( + Locale.ROOT, + "{\n" + + " \"split_by\": \"%s\",\n" + + " \"filter_by\": \"BUILD_IN_QUERY\",\n" + + " \"build_in_query\": \"DISTANCE_TO_THRESHOLD_VALUE\",\n" + + " \"forecast_from\": %d,\n" + + " \"run_once\": true,\n" + + " \"threshold\": %d,\n" + + " \"relation_to_threshold\": \"%s\"" + + "}", + CITY_NAME, + forecastFrom, + threshold, + equal ? "GREATER_THAN_OR_EQUAL_TO" : "GREATER_THAN" + ); + + response = TestHelpers + .makeRequest( + client(), + "POST", + String.format(Locale.ROOT, TOP_FORECASTER, forecasterId), + ImmutableMap.of(), + TestHelpers.toHttpEntity(topForcastRequest), + null + ); + responseMap = entityAsMap(response); + parsedBuckets = (List) responseMap.get("buckets"); + assertTrue("actual content: " + parsedBuckets, parsedBuckets.size() == 2); + + previousWidth = Double.MAX_VALUE; // Initialize to positive infinity + double largestValue = Double.MIN_VALUE; + + largestValue = isDesc(parsedBuckets, previousWidth, largestValue, "DISTANCE_TO_THRESHOLD_VALUE"); + + String maxDistanceToThresholdRequest = String + .format( + Locale.ROOT, + "{\n" + + " \"size\": 1,\n" + + " \"query\": {\n" + + " \"bool\": {\n" + + " \"filter\": [\n" + + " {\n" + + " \"range\": {\n" + + " \"forecast_value\": {\n" + + " \"%s\": " + + threshold + + "\n" + + " }\n" + + " }\n" + + " },\n" + + " {\n" + + " \"range\": {\n" + + " \"data_end_time\": {\n" + + " \"from\": %d,\n" + + " \"to\": null,\n" + + " \"include_lower\": true\n" + + " }\n" + + " }\n" + + " }\n" + + " ]\n" + + " }\n" + + " },\n" + + " \"sort\": [\n" + + " {\n" + + " \"forecast_value\": {\n" + + " \"order\": \"desc\"\n" + + " }\n" + + " }\n" + + " ]\n" + + "}", + equal ? "gte" : "gt", + forecastFrom + ); + + Response maxDistanceResponse = TestHelpers + .makeRequest(client(), "GET", SEARCH_RESULTS, ImmutableMap.of(), TestHelpers.toHttpEntity(maxDistanceToThresholdRequest), null); + List maxDistanceHits = toHits(maxDistanceResponse); + assertEquals("actual: " + maxDistanceHits, 1, maxDistanceHits.size()); + double maxValue = (double) (maxDistanceHits.get(0).getSourceAsMap().get("forecast_value")); + assertEquals(String.format(Locale.ROOT, "actual: %f, expect: %f", maxValue, largestValue), maxValue, largestValue, 0.001); + } + + private void distanceToThresholdGreaterThanEqual(String forecasterId, long forecastFrom) throws IOException { + distanceToThresholdGreaterTemplate(forecasterId, forecastFrom, true); + } + + private void distanceToThresholdLessThan(String forecasterId, long forecastFrom) throws IOException { + distanceToThresholdLessTemplate(forecasterId, forecastFrom, false); + } + + private void distanceToThresholdLessTemplate(String forecasterId, long forecastFrom, boolean equal) throws IOException { + Response response; + Map responseMap; + String topForcastRequest; + List parsedBuckets; + double previousWidth; + Set cities; + int threshold = 7000; + topForcastRequest = String + .format( + Locale.ROOT, + "{\n" + + " \"split_by\": \"%s\",\n" + + " \"filter_by\": \"BUILD_IN_QUERY\",\n" + + " \"build_in_query\": \"DISTANCE_TO_THRESHOLD_VALUE\",\n" + + " \"forecast_from\": %d,\n" + + " \"run_once\": true,\n" + + " \"threshold\": %d,\n" + + " \"relation_to_threshold\": \"%s\"" + + "}", + CITY_NAME, + forecastFrom, + threshold, + equal ? "LESS_THAN_OR_EQUAL_TO" : "LESS_THAN" + ); + + response = TestHelpers + .makeRequest( + client(), + "POST", + String.format(Locale.ROOT, TOP_FORECASTER, forecasterId), + ImmutableMap.of(), + TestHelpers.toHttpEntity(topForcastRequest), + null + ); + responseMap = entityAsMap(response); + parsedBuckets = (List) responseMap.get("buckets"); + assertTrue("actual content: " + parsedBuckets, parsedBuckets.size() == 2); + + previousWidth = Double.MIN_VALUE; // Initialize to negative infinity + double smallestValue = Double.MAX_VALUE; + cities = new HashSet<>(); + + smallestValue = isAsc(parsedBuckets, cities, previousWidth, smallestValue, "DISTANCE_TO_THRESHOLD_VALUE"); + + String maxDistanceToThresholdRequest = String + .format( + Locale.ROOT, + "{\n" + + " \"size\": 1,\n" + + " \"query\": {\n" + + " \"bool\": {\n" + + " \"filter\": [\n" + + " {\n" + + " \"range\": {\n" + + " \"forecast_value\": {\n" + + " \"%s\": " + + threshold + + "\n" + + " }\n" + + " }\n" + + " },\n" + + " {\n" + + " \"range\": {\n" + + " \"data_end_time\": {\n" + + " \"from\": %d,\n" + + " \"to\": null,\n" + + " \"include_lower\": true\n" + + " }\n" + + " }\n" + + " }\n" + + " ]\n" + + " }\n" + + " },\n" + + " \"sort\": [\n" + + " {\n" + + " \"forecast_value\": {\n" + + " \"order\": \"asc\"\n" + + " }\n" + + " }\n" + + " ]\n" + + "}", + equal ? "lte" : "lt", + forecastFrom + ); + + Response maxDistanceResponse = TestHelpers + .makeRequest(client(), "GET", SEARCH_RESULTS, ImmutableMap.of(), TestHelpers.toHttpEntity(maxDistanceToThresholdRequest), null); + List maxDistanceHits = toHits(maxDistanceResponse); + assertEquals("actual: " + maxDistanceHits, 1, maxDistanceHits.size()); + double maxValue = (double) (maxDistanceHits.get(0).getSourceAsMap().get("forecast_value")); + assertEquals(String.format(Locale.ROOT, "actual: %f, expect: %f", maxValue, smallestValue), maxValue, smallestValue, 0.001); + } + + private void distanceToThresholdLessThanEqual(String forecasterId, long forecastFrom) throws IOException { + distanceToThresholdLessTemplate(forecasterId, forecastFrom, true); + } + + private void customMaxForecastValue(String forecasterId, long forecastFrom) throws IOException { + customForecastValueTemplate(forecasterId, forecastFrom, true); + } + + private void customMinForecastValue(String forecasterId, long forecastFrom) throws IOException { + customForecastValueTemplate(forecasterId, forecastFrom, false); + } + + private void customForecastValueTemplate(String forecasterId, long forecastFrom, boolean max) throws IOException { + Response response; + Map responseMap; + String topForcastRequest; + List parsedBuckets; + double previousValue; + topForcastRequest = String + .format( + Locale.ROOT, + "{\n" + + " \"forecast_from\": %d,\n" + + " \"filter_by\": \"CUSTOM_QUERY\",\n" + + " \"filter_query\": {\n" + + " \"nested\": {\n" + + " \"path\": \"entity\",\n" + + " \"query\": {\n" + + " \"bool\": {\n" + + " \"must\": [\n" + + " {\n" + + " \"term\": {\n" + + " \"entity.name\": \"%s\"\n" + + " }\n" + + " },\n" + + " {\n" + + " \"wildcard\": {\n" + + " \"entity.value\": \"S*\"\n" + + " }\n" + + " }\n" + + " ]\n" + + " }\n" + + " }\n" + + " }\n" + + " },\n" + + " \"subaggregations\": [\n" + + " {\n" + + " \"aggregation_query\": {\n" + + " \"forecast_value_max\": {\n" + + " \"%s\": {\n" + + " \"field\": \"forecast_value\"\n" + + " }\n" + + " }\n" + + " },\n" + + " \"order\": \"DESC\"\n" + + " }\n" + + " ],\n" + + " \"run_once\": true\n" + + "}", + forecastFrom, + CITY_NAME, + max ? "max" : "min" + ); + + response = TestHelpers + .makeRequest( + client(), + "POST", + String.format(Locale.ROOT, TOP_FORECASTER, forecasterId), + ImmutableMap.of(), + TestHelpers.toHttpEntity(topForcastRequest), + null + ); + responseMap = entityAsMap(response); + parsedBuckets = (List) responseMap.get("buckets"); + assertTrue("actual content: " + parsedBuckets, parsedBuckets.size() == 1); + + previousValue = Double.MAX_VALUE; // Initialize to positive infinity + double largestValue = Double.MIN_VALUE; + + largestValue = isDesc(parsedBuckets, previousValue, largestValue, "forecast_value_max"); + + String maxValueRequest = String + .format( + Locale.ROOT, + "{\n" + + " \"size\": 1,\n" + + " \"sort\": [\n" + + " {\n" + + " \"%s\": {\n" + + " \"order\": \"%s\"\n" + + " }\n" + + " }\n" + + " ],\n" + + " \"query\": {\n" + + " \"bool\": {\n" + + " \"must\": [\n" + + " {\n" + + " \"nested\": {\n" + + " \"path\": \"entity\",\n" + + " \"query\": {\n" + + " \"bool\": {\n" + + " \"must\": [\n" + + " {\n" + + " \"term\": {\n" + + " \"entity.name\": \"%s\"\n" + + " }\n" + + " },\n" + + " {\n" + + " \"wildcard\": {\n" + + " \"entity.value\": \"S*\"\n" + + " }\n" + + " }\n" + + " ]\n" + + " }\n" + + " }\n" + + " }\n" + + " },\n" + + " {\n" + + " \"range\": {\n" + + " \"data_end_time\": {\n" + + " \"from\": %d,\n" + + " \"to\": null,\n" + + " \"include_lower\": true\n" + + " }\n" + + " }\n" + + " }\n" + + " ]\n" + + " }\n" + + " }\n" + + "}", + FORECAST_VALUE, // First %s + max ? "desc" : "asc", // Second %s + CITY_NAME, // Third %s + forecastFrom // %d + ); + + Response maxValueResponse = TestHelpers + .makeRequest(client(), "GET", SEARCH_RESULTS, ImmutableMap.of(), TestHelpers.toHttpEntity(maxValueRequest), null); + List maxValueHits = toHits(maxValueResponse); + assertEquals("actual: " + maxValueHits, 1, maxValueHits.size()); + double maxValue = (double) (maxValueHits.get(0).getSourceAsMap().get(FORECAST_VALUE)); + assertEquals(String.format(Locale.ROOT, "actual: %f, expect: %f", maxValue, largestValue), maxValue, largestValue, 0.001); + } + + public void testTopForecast() throws Exception { + Instant trainTime = loadTwoCategoricalFieldData(200); + // case 1: happy case + String forecasterDef = "{\n" + + " \"name\": \"Second-Test-Forecaster-4\",\n" + + " \"description\": \"ok rate\",\n" + + " \"time_field\": \"timestamp\",\n" + + " \"indices\": [\n" + + " \"%s\"\n" + + " ],\n" + + " \"feature_attributes\": [\n" + + " {\n" + + " \"feature_id\": \"max1\",\n" + + " \"feature_name\": \"max1\",\n" + + " \"feature_enabled\": true,\n" + + " \"importance\": 1,\n" + + " \"aggregation_query\": {\n" + + " \"max1\": {\n" + + " \"max\": {\n" + + " \"field\": \"visitCount\"\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " ],\n" + + " \"window_delay\": {\n" + + " \"period\": {\n" + + " \"interval\": %d,\n" + + " \"unit\": \"MINUTES\"\n" + + " }\n" + + " },\n" + + " \"ui_metadata\": {\n" + + " \"aabb\": {\n" + + " \"ab\": \"bb\"\n" + + " }\n" + + " },\n" + + " \"schema_version\": 2,\n" + + " \"horizon\": 24,\n" + + " \"forecast_interval\": {\n" + + " \"period\": {\n" + + " \"interval\": 10,\n" + + " \"unit\": \"MINUTES\"\n" + + " }\n" + + " },\n" + + " \"result_index\": \"opensearch-forecast-result-b\",\n" + + " \"category_field\": [%s]\n" + + "}"; + + // +1 to make sure it is big enough + long windowDelayMinutes = Duration.between(trainTime, Instant.now()).toMinutes() + 1; + final String formattedForecaster = String + .format(Locale.ROOT, forecasterDef, RULE_DATASET_NAME, windowDelayMinutes, "\"account\",\"region\""); + Response response = TestHelpers + .makeRequest( + client(), + "POST", + String.format(Locale.ROOT, CREATE_FORECASTER), + ImmutableMap.of(), + TestHelpers.toHttpEntity(formattedForecaster), + null + ); + Map responseMap = entityAsMap(response); + String forecasterId = (String) responseMap.get("_id"); + assertEquals("opensearch-forecast-result-b", ((Map) responseMap.get("forecaster")).get("result_index")); + + // run once + response = TestHelpers + .makeRequest( + client(), + "POST", + String.format(Locale.ROOT, RUN_ONCE_FORECASTER, forecasterId), + ImmutableMap.of(), + (HttpEntity) null, + null + ); + + ForecastTaskProfile forecastTaskProfile = (ForecastTaskProfile) waitUntilTaskReachState( + forecasterId, + ImmutableSet.of(TaskState.TEST_COMPLETE.name()), + client() + ).get(0); + assertTrue(forecastTaskProfile != null); + assertTrue(forecastTaskProfile.getTask().isLatest()); + + responseMap = entityAsMap(response); + String taskId = (String) responseMap.get(EntityTaskProfile.TASK_ID_FIELD); + assertEquals(taskId, forecastTaskProfile.getTaskId()); + + response = searchTaskResult(taskId); + responseMap = entityAsMap(response); + int total = (int) (((Map) ((Map) responseMap.get("hits")).get("total")).get("value")); + assertTrue("actual: " + total, total > 40); + + List hits = toHits(response); + long forecastFrom = -1; + for (SearchHit hit : hits) { + Map source = hit.getSourceAsMap(); + if (source.get("forecast_value") != null) { + forecastFrom = (long) (source.get("data_end_time")); + break; + } + } + assertTrue(forecastFrom != -1); + + // top forecast verification + customForecastValueDoubleCategories(forecasterId, forecastFrom, true, taskId); + customForecastValueDoubleCategories(forecasterId, forecastFrom, false, taskId); + } + + private void topForecastSizeVerification(String forecasterId, long forecastFrom) throws IOException { + Response response; + Map responseMap; + String topForcastRequest = String + .format( + Locale.ROOT, + "{\n" + + " \"size\": 1,\n" + + " \"split_by\": \"%s\",\n" + + " \"filter_by\": \"BUILD_IN_QUERY\",\n" + + " \"build_in_query\": \"MIN_CONFIDENCE_INTERVAL_WIDTH\",\n" + + " \"forecast_from\": %d,\n" + + " \"run_once\": true\n" + + "}", + CITY_NAME, + forecastFrom + ); + + response = TestHelpers + .makeRequest( + client(), + "POST", + String.format(Locale.ROOT, TOP_FORECASTER, forecasterId), + ImmutableMap.of(), + TestHelpers.toHttpEntity(topForcastRequest), + null + ); + responseMap = entityAsMap(response); + List parsedBuckets = (List) responseMap.get("buckets"); + assertTrue("actual content: " + parsedBuckets, parsedBuckets.size() == 1); + } + + private void customForecastValueDoubleCategories(String forecasterId, long forecastFrom, boolean max, String taskId) + throws IOException { + Response response; + Map responseMap; + String topForcastRequest; + List parsedBuckets; + double previousValue; + topForcastRequest = String + .format( + Locale.ROOT, + "{\n" + + " \"forecast_from\": %d,\n" + + " \"filter_by\": \"CUSTOM_QUERY\",\n" + + " \"filter_query\": {\n" + + " \"nested\": {\n" + + " \"path\": \"entity\",\n" + + " \"query\": {\n" + + " \"bool\": {\n" + + " \"must\": [\n" + + " {\n" + + " \"term\": {\n" + + " \"entity.name\": \"%s\"\n" + + " }\n" + + " },\n" + + " {\n" + + " \"wildcard\": {\n" + + " \"entity.value\": \"i*\"\n" + + " }\n" + + " }\n" + + " ]\n" + + " }\n" + + " }\n" + + " }\n" + + " },\n" + + " \"subaggregations\": [\n" + + " {\n" + + " \"aggregation_query\": {\n" + + " \"forecast_value_max\": {\n" + + " \"%s\": {\n" + + " \"field\": \"forecast_value\"\n" + + " }\n" + + " }\n" + + " },\n" + + " \"order\": \"DESC\"\n" + + " }\n" + + " ],\n" + + " \"run_once\": true,\n" + + " \"task_id\": \"%s\"\n" + + "}", + forecastFrom, + "region", + max ? "max" : "min", + taskId + ); + + response = TestHelpers + .makeRequest( + client(), + "POST", + String.format(Locale.ROOT, TOP_FORECASTER, forecasterId), + ImmutableMap.of(), + TestHelpers.toHttpEntity(topForcastRequest), + null + ); + responseMap = entityAsMap(response); + parsedBuckets = (List) responseMap.get("buckets"); + assertTrue("actual content: " + parsedBuckets, parsedBuckets.size() == 1); + + previousValue = Double.MAX_VALUE; // Initialize to positive infinity + double largestValue = Double.MIN_VALUE; + + largestValue = isDescTwoCategorical(parsedBuckets, previousValue, largestValue, "forecast_value_max"); + + String maxValueRequest = String + .format( + Locale.ROOT, + "{\n" + + " \"size\": 1,\n" + + " \"sort\": [\n" + + " {\n" + + " \"%s\": {\n" + + " \"order\": \"%s\"\n" + + " }\n" + + " }\n" + + " ],\n" + + " \"query\": {\n" + + " \"bool\": {\n" + + " \"must\": [\n" + + " {\n" + + " \"nested\": {\n" + + " \"path\": \"entity\",\n" + + " \"query\": {\n" + + " \"bool\": {\n" + + " \"must\": [\n" + + " {\n" + + " \"term\": {\n" + + " \"entity.name\": \"%s\"\n" + + " }\n" + + " },\n" + + " {\n" + + " \"wildcard\": {\n" + + " \"entity.value\": \"i*\"\n" + + " }\n" + + " }\n" + + " ]\n" + + " }\n" + + " }\n" + + " }\n" + + " },\n" + + " {\n" + + " \"range\": {\n" + + " \"data_end_time\": {\n" + + " \"from\": %d,\n" + + " \"to\": null,\n" + + " \"include_lower\": true\n" + + " }\n" + + " }\n" + + " }\n" + + " ]\n" + + " }\n" + + " }\n" + + "}", + FORECAST_VALUE, // First %s + max ? "desc" : "asc", // Second %s + "region", // Third %s + forecastFrom // %d + ); + + Response maxValueResponse = TestHelpers + .makeRequest(client(), "GET", SEARCH_RESULTS, ImmutableMap.of(), TestHelpers.toHttpEntity(maxValueRequest), null); + List maxValueHits = toHits(maxValueResponse); + assertEquals("actual: " + maxValueHits, 1, maxValueHits.size()); + double maxValue = (double) (maxValueHits.get(0).getSourceAsMap().get(FORECAST_VALUE)); + assertEquals(String.format(Locale.ROOT, "actual: %f, expect: %f", maxValue, largestValue), maxValue, largestValue, 0.001); + } } diff --git a/src/test/java/org/opensearch/forecast/rest/SecureForecastRestIT.java b/src/test/java/org/opensearch/forecast/rest/SecureForecastRestIT.java index acdcf68f8..c293af9d4 100644 --- a/src/test/java/org/opensearch/forecast/rest/SecureForecastRestIT.java +++ b/src/test/java/org/opensearch/forecast/rest/SecureForecastRestIT.java @@ -588,16 +588,6 @@ protected List waitUntilResultAvailable(RestClient client) throws Int return hits; } - private List toHits(Response response) throws UnsupportedOperationException, IOException { - SearchResponse searchResponse = SearchResponse - .fromXContent(createParser(JsonXContent.jsonXContent, response.getEntity().getContent())); - long total = searchResponse.getHits().getTotalHits().value; - if (total == 0) { - return new ArrayList<>(); - } - return Arrays.asList(searchResponse.getHits().getHits()); - } - private Response enableFilterBy() throws IOException { return TestHelpers .makeRequest( diff --git a/src/test/java/org/opensearch/timeseries/AbstractSyntheticDataTest.java b/src/test/java/org/opensearch/timeseries/AbstractSyntheticDataTest.java index b83e89f9f..26b7fe843 100644 --- a/src/test/java/org/opensearch/timeseries/AbstractSyntheticDataTest.java +++ b/src/test/java/org/opensearch/timeseries/AbstractSyntheticDataTest.java @@ -279,6 +279,43 @@ protected static Instant loadRuleData(int trainTestSplit) throws Exception { return loadData(RULE_DATASET_NAME, trainTestSplit, RULE_DATA_MAPPING); } + // convert 1 categorical field (cityName) rule data with two categorical field (account and region) rule data + protected static Instant loadTwoCategoricalFieldData(int trainTestSplit) throws Exception { + RestClient client = client(); + + String dataFileName = String.format(Locale.ROOT, "org/opensearch/ad/e2e/data/%s.data", RULE_DATASET_NAME); + + List data = readJsonArrayWithLimit(dataFileName, trainTestSplit); + + for (int i = 0; i < trainTestSplit && i < data.size(); i++) { + JsonObject jsonObject = data.get(i); + String city = jsonObject.get("cityName").getAsString(); + if (city.equals("Phoenix")) { + jsonObject.addProperty("account", "1234"); + jsonObject.addProperty("region", "iad"); + } else if (city.equals("Scottsdale")) { + jsonObject.addProperty("account", "5678"); + jsonObject.addProperty("region", "pdx"); + } + } + + String mapping = "{ \"mappings\": { \"properties\": { " + + "\"timestamp\": { \"type\": \"date\" }, " + + "\"visitCount\": { \"type\": \"integer\" }, " + + "\"cityName\": { \"type\": \"keyword\" }, " + + "\"account\": { \"type\": \"keyword\" }, " + + "\"region\": { \"type\": \"keyword\" } " + + "} } }"; + + bulkIndexTrainData(RULE_DATASET_NAME, data, trainTestSplit, client, mapping); + String trainTimeStr = data.get(trainTestSplit - 1).get("timestamp").getAsString(); + if (canBeParsedAsLong(trainTimeStr)) { + return Instant.ofEpochMilli(Long.parseLong(trainTimeStr)); + } else { + return Instant.parse(trainTimeStr); + } + } + public static boolean canBeParsedAsLong(String str) { if (str == null || str.isEmpty()) { return false; // Handle null or empty strings as not parsable diff --git a/src/test/java/org/opensearch/timeseries/TestHelpers.java b/src/test/java/org/opensearch/timeseries/TestHelpers.java index c36b4b686..0c8f45a63 100644 --- a/src/test/java/org/opensearch/timeseries/TestHelpers.java +++ b/src/test/java/org/opensearch/timeseries/TestHelpers.java @@ -118,6 +118,7 @@ import org.opensearch.forecast.model.ForecastResult; import org.opensearch.forecast.model.ForecastTask; import org.opensearch.forecast.model.Forecaster; +import org.opensearch.forecast.ratelimit.ForecastResultWriteRequest; import org.opensearch.index.get.GetResult; import org.opensearch.index.query.BoolQueryBuilder; import org.opensearch.index.query.MatchAllQueryBuilder; @@ -1585,6 +1586,16 @@ public static Entity randomEntity(Config config) { return entity; } + private static Entity randomEntity() { + String name = randomAlphaOfLength(10); + List values = new ArrayList<>(); + int size = random.nextInt(3) + 1; // At least one value + for (int i = 0; i < size; i++) { + values.add(randomAlphaOfLength(10)); + } + return Entity.createEntityByReordering(ImmutableMap.of(name, values)); + } + public static HttpEntity toHttpEntity(ToXContentObject object) throws IOException { return new StringEntity(toJsonString(object), APPLICATION_JSON); } @@ -2224,4 +2235,57 @@ public Job build() { } } + public static ForecastResultWriteRequest randomForecastResultWriteRequest() { + // Generate random values for required fields + long expirationEpochMs = Instant.now().plusSeconds(random.nextInt(3600)).toEpochMilli(); // Expire within the next hour + String forecasterId = randomAlphaOfLength(10); + RequestPriority priority = RequestPriority.MEDIUM; // Use NORMAL priority for testing + ForecastResult result = randomForecastResult(forecasterId); + String resultIndex = random.nextBoolean() ? randomAlphaOfLength(10) : null; // Randomly decide to set resultIndex or not + + return new ForecastResultWriteRequest(expirationEpochMs, forecasterId, priority, result, resultIndex); + } + + public static ForecastResult randomForecastResult(String forecasterId) { + String taskId = randomAlphaOfLength(10); + Double dataQuality = random.nextDouble(); + List featureData = ImmutableList.of(randomFeatureData()); + Instant dataStartTime = Instant.now().minusSeconds(random.nextInt(3600)); + Instant dataEndTime = Instant.now(); + Instant executionStartTime = Instant.now().minusSeconds(random.nextInt(3600)); + Instant executionEndTime = Instant.now(); + String error = random.nextBoolean() ? randomAlphaOfLength(20) : null; + Optional entity = random.nextBoolean() ? Optional.of(randomEntity()) : Optional.empty(); + User user = random.nextBoolean() ? randomUser() : null; + Integer schemaVersion = random.nextInt(10); + String featureId = randomAlphaOfLength(10); + Float forecastValue = random.nextFloat(); + Float lowerBound = forecastValue - random.nextFloat(); + Float upperBound = forecastValue + random.nextFloat(); + Instant forecastDataStartTime = dataEndTime.plusSeconds(random.nextInt(3600)); + Instant forecastDataEndTime = forecastDataStartTime.plusSeconds(random.nextInt(3600)); + Integer horizonIndex = random.nextInt(100); + + return new ForecastResult( + forecasterId, + taskId, + dataQuality, + featureData, + dataStartTime, + dataEndTime, + executionStartTime, + executionEndTime, + error, + entity, + user, + schemaVersion, + featureId, + forecastValue, + lowerBound, + upperBound, + forecastDataStartTime, + forecastDataEndTime, + horizonIndex + ); + } } diff --git a/src/test/java/org/opensearch/timeseries/transport/ForecastResultBulkTransportActionTests.java b/src/test/java/org/opensearch/timeseries/transport/ForecastResultBulkTransportActionTests.java new file mode 100644 index 000000000..f5cead05e --- /dev/null +++ b/src/test/java/org/opensearch/timeseries/transport/ForecastResultBulkTransportActionTests.java @@ -0,0 +1,150 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.timeseries.transport; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import java.io.IOException; + +import org.junit.After; +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.BeforeClass; +import org.opensearch.action.bulk.BulkRequest; +import org.opensearch.action.bulk.BulkResponse; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.PlainActionFuture; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Settings; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.forecast.model.ForecastResult; +import org.opensearch.forecast.ratelimit.ForecastResultWriteRequest; +import org.opensearch.forecast.settings.ForecastSettings; +import org.opensearch.forecast.transport.ForecastResultBulkRequest; +import org.opensearch.forecast.transport.ForecastResultBulkTransportAction; +import org.opensearch.index.IndexingPressure; +import org.opensearch.timeseries.AbstractTimeSeriesTest; +import org.opensearch.timeseries.TestHelpers; +import org.opensearch.transport.TransportService; + +public class ForecastResultBulkTransportActionTests extends AbstractTimeSeriesTest { + + private ForecastResultBulkTransportAction resultBulk; + private TransportService transportService; + private ClusterService clusterService; + private IndexingPressure indexingPressure; + private Client client; + + @BeforeClass + public static void setUpBeforeClass() { + setUpThreadPool(ForecastResultBulkTransportActionTests.class.getSimpleName()); + } + + @AfterClass + public static void tearDownAfterClass() { + tearDownThreadPool(); + } + + @Override + @Before + public void setUp() throws Exception { + super.setUp(); + Settings settings = Settings + .builder() + .put(IndexingPressure.MAX_INDEXING_BYTES.getKey(), "1KB") + .put("forecast.index_pressure.soft_limit", 0.8) + .build(); + + // Setup test nodes and services + setupTestNodes(ForecastSettings.FORECAST_INDEX_PRESSURE_SOFT_LIMIT, ForecastSettings.FORECAST_INDEX_PRESSURE_HARD_LIMIT); + transportService = testNodes[0].transportService; + clusterService = testNodes[0].clusterService; + + ActionFilters actionFilters = mock(ActionFilters.class); + indexingPressure = mock(IndexingPressure.class); + + client = mock(Client.class); + + resultBulk = new ForecastResultBulkTransportAction( + transportService, + actionFilters, + indexingPressure, + settings, + clusterService, + client + ); + } + + @Override + @After + public final void tearDown() throws Exception { + tearDownTestNodes(); + super.tearDown(); + } + + @SuppressWarnings("unchecked") + public void testBulkIndexingFailure() throws IOException { + // Set indexing pressure below soft limit to ensure requests are processed + when(indexingPressure.getCurrentCombinedCoordinatingAndPrimaryBytes()).thenReturn(0L); + when(indexingPressure.getCurrentReplicaBytes()).thenReturn(0L); + + // Create a ForecastResultBulkRequest with some results + ForecastResultBulkRequest originalRequest = new ForecastResultBulkRequest(); + originalRequest.add(TestHelpers.randomForecastResultWriteRequest()); + originalRequest.add(TestHelpers.randomForecastResultWriteRequest()); + + // Mock client.execute to throw an exception + doAnswer(invocation -> { + Object[] args = invocation.getArguments(); + ActionListener listener = (ActionListener) args[2]; + listener.onFailure(new RuntimeException("Simulated bulk indexing failure")); + return null; + }).when(client).execute(any(), any(), any()); + + // Execute the action + PlainActionFuture future = PlainActionFuture.newFuture(); + resultBulk.doExecute(null, originalRequest, future); + + // Verify that the exception is propagated to the listener + Exception exception = expectThrows(Exception.class, () -> future.actionGet()); + assertTrue(exception.getMessage().contains("Simulated bulk indexing failure")); + } + + public void testPrepareBulkRequestFailure() throws IOException { + // Set indexing pressure below soft limit to ensure requests are processed + when(indexingPressure.getCurrentCombinedCoordinatingAndPrimaryBytes()).thenReturn(0L); + when(indexingPressure.getCurrentReplicaBytes()).thenReturn(0L); + + // Create a ForecastResultWriteRequest with a result that throws IOException when toXContent is called + ForecastResultWriteRequest faultyWriteRequest = mock(ForecastResultWriteRequest.class); + ForecastResult faultyResult = mock(ForecastResult.class); + + when(faultyWriteRequest.getResult()).thenReturn(faultyResult); + when(faultyWriteRequest.getResultIndex()).thenReturn(null); + + // Mock the toXContent method to throw IOException + doThrow(new IOException("Simulated IOException in toXContent")).when(faultyResult).toXContent(any(XContentBuilder.class), any()); + + // Create a ForecastResultBulkRequest with the faulty write request + ForecastResultBulkRequest originalRequest = new ForecastResultBulkRequest(); + originalRequest.add(faultyWriteRequest); + + // Execute the prepareBulkRequest method directly + BulkRequest bulkRequest = resultBulk.prepareBulkRequest(0.5f, originalRequest); + + // Since the exception is caught inside addResult, bulkRequest should have zero actions + assertEquals(0, bulkRequest.numberOfActions()); + } +}