diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/regression/LinearRegression.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/regression/LinearRegression.java
index 1bd78aa478..895261201e 100644
--- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/regression/LinearRegression.java
+++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/regression/LinearRegression.java
@@ -52,7 +52,7 @@ public class LinearRegression implements Trainable, Predictable {
     public static final String VERSION = "1.0.0";

     private static final LinearRegressionParams.ObjectiveType DEFAULT_OBJECTIVE_TYPE = LinearRegressionParams.ObjectiveType.SQUARED_LOSS;
-    private static final LinearRegressionParams.OptimizerType DEFAULT_OPTIMIZER_TYPE = LinearRegressionParams.OptimizerType.SIMPLE_SGD;
+    private static final LinearRegressionParams.OptimizerType DEFAULT_OPTIMIZER_TYPE = LinearRegressionParams.OptimizerType.ADA_GRAD;
     private static final double DEFAULT_LEARNING_RATE = 0.01;
     // Momentum
     private static final double DEFAULT_MOMENTUM_FACTOR = 0;
@@ -134,15 +134,15 @@ private void createOptimiser() {
                 break;
         }
         switch (optimizerType) {
+            case SIMPLE_SGD:
+                optimiser = SGD.getSimpleSGD(learningRate, momentumFactor, momentum);
+                break;
             case LINEAR_DECAY_SGD:
                 optimiser = SGD.getLinearDecaySGD(learningRate, momentumFactor, momentum);
                 break;
             case SQRT_DECAY_SGD:
                 optimiser = SGD.getSqrtDecaySGD(learningRate, momentumFactor, momentum);
                 break;
-            case ADA_GRAD:
-                optimiser = new AdaGrad(learningRate, epsilon);
-                break;
             case ADA_DELTA:
                 optimiser = new AdaDelta(momentumFactor, epsilon);
                 break;
@@ -153,8 +153,9 @@ private void createOptimiser() {
                 optimiser = new RMSProp(learningRate, momentumFactor, epsilon, decayRate);
                 break;
             default:
-                // Use default SGD with a constant learning rate.
-                optimiser = SGD.getSimpleSGD(learningRate, momentumFactor, momentum);
+                // Use AdaGrad by default, reference issue:
+                // https://github.com/opensearch-project/ml-commons/issues/3210#issuecomment-2556119802
+                optimiser = new AdaGrad(learningRate, epsilon);
                 break;
         }
     }