Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize ML inference connection retry logic #1054

Merged
merged 1 commit into from
Jan 4, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
- Support new knn query parameter expand_nested ([#1013](https://github.com/opensearch-project/neural-search/pull/1013))
- Implement pruning for neural sparse ingestion pipeline and two phase search processor ([#988](https://github.com/opensearch-project/neural-search/pull/988))
- Support empty string for fields in text embedding processor ([#1041](https://github.com/opensearch-project/neural-search/pull/1041))
- Optimize ML inference connection retry logic ([#1054](https://github.com/opensearch-project/neural-search/pull/1054))
### Bug Fixes
- Address inconsistent scoring in hybrid query results ([#998](https://github.com/opensearch-project/neural-search/pull/998))
- Fix bug where ingested document has list of nested objects ([#1040](https://github.com/opensearch-project/neural-search/pull/1040))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -162,14 +162,14 @@ private void retryableInferenceSentencesWithMapResult(
mlClient.predict(modelId, mlInput, ActionListener.wrap(mlOutput -> {
final List<Map<String, ?>> result = buildMapResultFromResponse(mlOutput);
listener.onResponse(result);
}, e -> {
if (RetryUtil.shouldRetry(e, retryTime)) {
final int retryTimeAdd = retryTime + 1;
retryableInferenceSentencesWithMapResult(modelId, inputText, retryTimeAdd, listener);
} else {
listener.onFailure(e);
}
}));
},
e -> RetryUtil.handleRetryOrFailure(
e,
retryTime,
() -> retryableInferenceSentencesWithMapResult(modelId, inputText, retryTime + 1, listener),
listener
)
));
}

private void retryableInferenceSentencesWithVectorResult(
Expand All @@ -183,14 +183,14 @@ private void retryableInferenceSentencesWithVectorResult(
mlClient.predict(modelId, mlInput, ActionListener.wrap(mlOutput -> {
final List<List<Float>> vector = buildVectorFromResponse(mlOutput);
listener.onResponse(vector);
}, e -> {
if (RetryUtil.shouldRetry(e, retryTime)) {
final int retryTimeAdd = retryTime + 1;
retryableInferenceSentencesWithVectorResult(targetResponseFilters, modelId, inputText, retryTimeAdd, listener);
} else {
listener.onFailure(e);
}
}));
},
e -> RetryUtil.handleRetryOrFailure(
e,
retryTime,
() -> retryableInferenceSentencesWithVectorResult(targetResponseFilters, modelId, inputText, retryTime + 1, listener),
listener
)
));
}

private void retryableInferenceSimilarityWithVectorResult(
Expand All @@ -204,13 +204,14 @@ private void retryableInferenceSimilarityWithVectorResult(
mlClient.predict(modelId, mlInput, ActionListener.wrap(mlOutput -> {
final List<Float> scores = buildVectorFromResponse(mlOutput).stream().map(v -> v.get(0)).collect(Collectors.toList());
listener.onResponse(scores);
}, e -> {
if (RetryUtil.shouldRetry(e, retryTime)) {
retryableInferenceSimilarityWithVectorResult(modelId, queryText, inputText, retryTime + 1, listener);
} else {
listener.onFailure(e);
}
}));
},
e -> RetryUtil.handleRetryOrFailure(
e,
retryTime,
() -> retryableInferenceSimilarityWithVectorResult(modelId, queryText, inputText, retryTime + 1, listener),
listener
)
));
}

private MLInput createMLTextInput(final List<String> targetResponseFilters, List<String> inputText) {
Expand Down Expand Up @@ -272,14 +273,20 @@ private void retryableInferenceSentencesWithSingleVectorResult(
final List<Float> vector = buildSingleVectorFromResponse(mlOutput);
log.debug("Inference Response for input sentence is : {} ", vector);
listener.onResponse(vector);
}, e -> {
if (RetryUtil.shouldRetry(e, retryTime)) {
final int retryTimeAdd = retryTime + 1;
retryableInferenceSentencesWithSingleVectorResult(targetResponseFilters, modelId, inputObjects, retryTimeAdd, listener);
} else {
listener.onFailure(e);
}
}));
},
e -> RetryUtil.handleRetryOrFailure(
e,
retryTime,
() -> retryableInferenceSentencesWithSingleVectorResult(
targetResponseFilters,
modelId,
inputObjects,
retryTime + 1,
listener
),
listener
)
));
}

private MLInput createMLMultimodalInput(final List<String> targetResponseFilters, final Map<String, String> input) {
Expand Down
48 changes: 39 additions & 9 deletions src/main/java/org/opensearch/neuralsearch/util/RetryUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,30 +6,60 @@

import java.util.List;

import lombok.extern.log4j.Log4j2;
import org.apache.commons.lang3.exception.ExceptionUtils;
import org.opensearch.core.action.ActionListener;
import org.opensearch.transport.NodeDisconnectedException;
import org.opensearch.transport.NodeNotConnectedException;

import com.google.common.collect.ImmutableList;
import org.opensearch.common.Randomness;

@Log4j2
public class RetryUtil {

private static final int MAX_RETRY = 3;

private static final int DEFAULT_MAX_RETRY = 3;
private static final long DEFAULT_BASE_DELAY_MS = 500;
private static final List<Class<? extends Throwable>> RETRYABLE_EXCEPTIONS = ImmutableList.of(
NodeNotConnectedException.class,
NodeDisconnectedException.class
);

/**
*
* @param e {@link Exception} which is the exception received to check if retryable.
* @param retryTime {@link int} which is the current retried times.
* @return {@link boolean} which is the result of if current exception needs retry or not.
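* Retries the given action after a backoff delay when the exception is retryable and the retry limit
* has not been reached; otherwise reports the exception to the listener as a failure.
* @param e exception raised by the failed attempt
* @param retryTime number of retries already performed before this call
* @param retryAction action that performs the next attempt
* @param listener listener that is notified of the failure when no retry is performed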
*/
public static boolean shouldRetry(final Exception e, int retryTime) {
// A retry is only considered when one of RETRYABLE_EXCEPTIONS occurs somewhere in the
// exception chain of e; indexOfThrowable yields -1 when the given type is not found.
boolean hasRetryException = RETRYABLE_EXCEPTIONS.stream().anyMatch(x -> ExceptionUtils.indexOfThrowable(e, x) != -1);
// Even a retryable exception stops being retried once retryTime has reached MAX_RETRY.
return hasRetryException && retryTime < MAX_RETRY;
}
public static void handleRetryOrFailure(Exception e, int retryTime, Runnable retryAction, ActionListener<?> listener) {
    // Not retryable, or the retry budget is used up: hand the original failure to the listener unchanged.
    if (!shouldRetry(e, retryTime)) {
        listener.onFailure(e);
        return;
    }

    long backoffTime = calculateBackoffTime(retryTime);
    // The exception is passed as the trailing argument so the logger records its stack trace as well.
    log.warn("Retrying connection for ML inference due to [{}] after [{}ms]", e.getMessage(), backoffTime, e);
    try {
        // NOTE(review): this blocks the calling thread for the whole backoff period; confirm that
        // callers never reach this method on a thread that must not block.
        Thread.sleep(backoffTime);
    } catch (InterruptedException interruptedException) {
        // Restore the interrupt flag so code further up the stack can still observe the interruption.
        Thread.currentThread().interrupt();
        // Keep the failure that triggered the retry attached; otherwise the listener would only see
        // the interruption and the root cause of the failed inference call would be lost.
        interruptedException.addSuppressed(e);
        listener.onFailure(interruptedException);
        return;
    }
    // Backoff completed without interruption: perform the next attempt.
    retryAction.run();
}

/**
 * Decides whether another attempt should be made for a failed call.
 *
 * @param e exception raised by the failed attempt
 * @param retryTime number of retries already performed
 * @return {@code true} only when the exception is retryable and {@code retryTime} is below {@code DEFAULT_MAX_RETRY}
 */
private static boolean shouldRetry(final Exception e, int retryTime) {
    // Exceptions outside the retryable set are never retried, regardless of the retry count.
    if (!isRetryableException(e)) {
        return false;
    }
    final boolean retriesRemaining = retryTime < DEFAULT_MAX_RETRY;
    return retriesRemaining;
}

/**
 * Checks whether any type listed in {@code RETRYABLE_EXCEPTIONS} occurs in the exception chain of {@code e}.
 *
 * @param e exception to inspect
 * @return {@code true} when at least one retryable type is found in the chain
 */
private static boolean isRetryableException(final Exception e) {
    for (final Class<? extends Throwable> retryableType : RETRYABLE_EXCEPTIONS) {
        // indexOfThrowable reports -1 when the type is absent, so any non-negative index is a match.
        if (ExceptionUtils.indexOfThrowable(e, retryableType) >= 0) {
            return true;
        }
    }
    return false;
}

/**
 * Computes the delay to wait before the next retry: {@code DEFAULT_BASE_DELAY_MS} doubled once per
 * retry already performed, plus a small random jitter.
 *
 * @param retryTime number of retries already performed
 * @return delay in milliseconds
 */
private static long calculateBackoffTime(int retryTime) {
    // Exponential part: base delay multiplied by 2^retryTime.
    final long multiplier = 1L << retryTime;
    final long exponentialDelayMs = DEFAULT_BASE_DELAY_MS * multiplier;
    // Jitter drawn from the range starting at 10ms with bound 50ms.
    // NOTE(review): presumably Randomness.get() returns a java.util.Random, in which case the
    // upper bound is exclusive (10..49ms) — confirm.
    final long jitterMs = Randomness.get().nextLong(10, 50);
    return exponentialDelayMs + jitterMs;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,37 @@ public void testInferenceSentences_whenExceptionFromMLClient_thenFailure() {
Mockito.verifyNoMoreInteractions(resultListener);
}

public void testInferenceSentences_whenNodeNotConnectedException_thenRetry_3Times() {
public void testInferenceSimilarity_whenNodeNotConnectedException_ThenRetry() {
    // Failure that every predict call will report to its listener.
    final NodeNotConnectedException notConnected = new NodeNotConnectedException(mock(DiscoveryNode.class), "Node not connected");
    final List<String> candidateTexts = List.of("it is sunny today", "roses are red");

    // Stub the ML client so that each predict invocation fails immediately with the exception above.
    Mockito.doAnswer(invocation -> {
        final ActionListener<MLOutput> predictListener = invocation.getArgument(2);
        predictListener.onFailure(notConnected);
        return null;
    }).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));

    accessor.inferenceSimilarity(TestCommonConstants.MODEL_ID, "is it sunny", candidateTexts, singleSentenceResultListener);

    // One initial attempt plus three retries gives four predict calls in total.
    Mockito.verify(client, times(4))
        .predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));

    // Once the retries are used up, the same exception must reach the caller's listener.
    Mockito.verify(singleSentenceResultListener).onFailure(notConnected);

    // The listener must not have been notified in any other way.
    Mockito.verifyNoMoreInteractions(singleSentenceResultListener);
}

public void testInferenceSentences_whenExceptionFromMLClient_thenRetry_thenFailure() {
final NodeNotConnectedException nodeNodeConnectedException = new NodeNotConnectedException(
mock(DiscoveryNode.class),
"Node not connected"
Expand Down Expand Up @@ -293,18 +323,28 @@ public void testInferenceMultimodal_whenValidInput_thenSuccess() {
Mockito.verifyNoMoreInteractions(singleSentenceResultListener);
}

public void testInferenceMultimodal_whenExceptionFromMLClient_thenFailure() {
final RuntimeException exception = new RuntimeException();
public void testInferenceMultimodal_whenExceptionFromMLClient_thenRetry_thenFailure() {
final NodeNotConnectedException nodeNodeConnectedException = new NodeNotConnectedException(
mock(DiscoveryNode.class),
"Node not connected"
);

Mockito.doAnswer(invocation -> {
final ActionListener<MLOutput> actionListener = invocation.getArgument(2);
actionListener.onFailure(exception);
actionListener.onFailure(nodeNodeConnectedException);
return null;
}).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));

accessor.inferenceSentences(TestCommonConstants.MODEL_ID, TestCommonConstants.SENTENCES_MAP, singleSentenceResultListener);

Mockito.verify(client)
// Verify client.predict is called 4 times (1 initial + 3 retries)
Mockito.verify(client, times(4))
.predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));
Mockito.verify(singleSentenceResultListener).onFailure(exception);

// Verify failure is propagated to the listener after retries
Mockito.verify(singleSentenceResultListener).onFailure(nodeNodeConnectedException);

// Verify no further interactions with the listener
Mockito.verifyNoMoreInteractions(singleSentenceResultListener);
}

Expand Down Expand Up @@ -367,29 +407,6 @@ public void testInferencesSimilarity_whenExceptionFromMLClient_ThenFail() {
Mockito.verifyNoMoreInteractions(singleSentenceResultListener);
}

public void testInferenceSimilarity_whenNodeNotConnectedException_ThenTryThreeTimes() {
// Failure that every stubbed predict call reports to its listener.
final NodeNotConnectedException nodeNodeConnectedException = new NodeNotConnectedException(
mock(DiscoveryNode.class),
"Node not connected"
);
// Stub the ML client: argument index 2 of predict is the listener, which is failed immediately.
Mockito.doAnswer(invocation -> {
final ActionListener<MLOutput> actionListener = invocation.getArgument(2);
actionListener.onFailure(nodeNodeConnectedException);
return null;
}).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));

accessor.inferenceSimilarity(
TestCommonConstants.MODEL_ID,
"is it sunny",
List.of("it is sunny today", "roses are red"),
singleSentenceResultListener
);

// Expect four predict calls in total before the failure is handed to the result listener.
Mockito.verify(client, times(4))
.predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));
// The exception delivered to the caller must be the same instance the client reported.
Mockito.verify(singleSentenceResultListener).onFailure(nodeNodeConnectedException);
}

private ModelTensorOutput createModelTensorOutput(final Float[] output) {
final List<ModelTensors> tensorsList = new ArrayList<>();
final List<ModelTensor> mlModelTensorList = new ArrayList<>();
Expand Down
Loading