diff --git a/gradle.properties b/gradle.properties index e5ba6c4a..22d74bf2 100644 --- a/gradle.properties +++ b/gradle.properties @@ -16,4 +16,4 @@ group=com.nike artifactId=cerberus-lifecycle-cli -version=1.5.0 +version=1.6.0 diff --git a/src/main/java/com/nike/cerberus/cli/CerberusRunner.java b/src/main/java/com/nike/cerberus/cli/CerberusRunner.java index 99268f6f..ddcad63c 100644 --- a/src/main/java/com/nike/cerberus/cli/CerberusRunner.java +++ b/src/main/java/com/nike/cerberus/cli/CerberusRunner.java @@ -30,6 +30,7 @@ import com.nike.cerberus.command.cms.CreateCmsClusterCommand; import com.nike.cerberus.command.cms.CreateCmsConfigCommand; import com.nike.cerberus.command.cms.UpdateCmsConfigCommand; +import com.nike.cerberus.command.core.RollingRebootWithHealthCheckCommand; import com.nike.cerberus.command.core.ViewConfigCommand; import com.nike.cerberus.command.consul.CreateConsulClusterCommand; import com.nike.cerberus.command.consul.CreateConsulConfigCommand; @@ -198,6 +199,7 @@ private void registerAllCommands() { registerCommand(new RestoreCompleteCerberusDataFromS3BackupCommand()); registerCommand(new ViewConfigCommand()); registerCommand(new UpdateCmsConfigCommand()); + registerCommand(new RollingRebootWithHealthCheckCommand()); } /** diff --git a/src/main/java/com/nike/cerberus/client/CerberusAdminClient.java b/src/main/java/com/nike/cerberus/client/CerberusAdminClient.java index fb2e5705..e3c266df 100644 --- a/src/main/java/com/nike/cerberus/client/CerberusAdminClient.java +++ b/src/main/java/com/nike/cerberus/client/CerberusAdminClient.java @@ -16,7 +16,6 @@ package com.nike.cerberus.client; -import com.fasterxml.jackson.databind.JsonNode; import com.nike.vault.client.UrlResolver; import com.nike.vault.client.VaultAdminClient; import com.nike.vault.client.VaultClientException; diff --git a/src/main/java/com/nike/cerberus/command/cms/UpdateCmsConfigCommand.java b/src/main/java/com/nike/cerberus/command/cms/UpdateCmsConfigCommand.java index b4379c34..95cc3a67 100644 --- a/src/main/java/com/nike/cerberus/command/cms/UpdateCmsConfigCommand.java +++ b/src/main/java/com/nike/cerberus/command/cms/UpdateCmsConfigCommand.java @@ -26,7 +26,7 @@ import java.util.HashMap; import java.util.Map; -import static com.nike.cerberus.command.cms.CreateCmsClusterCommand.COMMAND_NAME; +import static com.nike.cerberus.command.cms.UpdateCmsConfigCommand.COMMAND_NAME; /** * Command to create the CMS cluster. diff --git a/src/main/java/com/nike/cerberus/command/core/RollingRebootWithHealthCheckCommand.java b/src/main/java/com/nike/cerberus/command/core/RollingRebootWithHealthCheckCommand.java new file mode 100644 index 00000000..1db261a4 --- /dev/null +++ b/src/main/java/com/nike/cerberus/command/core/RollingRebootWithHealthCheckCommand.java @@ -0,0 +1,40 @@ +package com.nike.cerberus.command.core; + +import com.beust.jcommander.Parameter; +import com.beust.jcommander.Parameters; +import com.nike.cerberus.command.Command; +import com.nike.cerberus.domain.environment.StackName; +import com.nike.cerberus.operation.Operation; +import com.nike.cerberus.operation.core.RollingRebootWithHealthCheckOperation; + +import static com.nike.cerberus.command.core.RollingRebootWithHealthCheckCommand.COMMAND_NAME; + +/** + * Command to reboot the CMS cluster. + */ +@Parameters( + commandNames = COMMAND_NAME, + commandDescription = "Performs a safe rolling reboot on instances in the given cluster, checking that " + + "the previous instance is healthy before rebooting the next one." +) +public class RollingRebootWithHealthCheckCommand implements Command { + + public static final String COMMAND_NAME = "rolling-reboot-with-health-check"; + + @Parameter(names = {"--stack-name"}, required = true, description = "The stack name to reboot.") + private StackName stackName = StackName.CMS; + + public StackName getStackName() { + return stackName; + } + + @Override + public String getCommandName() { + return COMMAND_NAME; + } + + @Override + public Class> getOperationClass() { + return RollingRebootWithHealthCheckOperation.class; + } +} diff --git a/src/main/java/com/nike/cerberus/command/core/ViewConfigCommand.java b/src/main/java/com/nike/cerberus/command/core/ViewConfigCommand.java index a35e289f..671b3c20 100644 --- a/src/main/java/com/nike/cerberus/command/core/ViewConfigCommand.java +++ b/src/main/java/com/nike/cerberus/command/core/ViewConfigCommand.java @@ -20,14 +20,14 @@ import com.beust.jcommander.Parameters; import com.nike.cerberus.command.Command; import com.nike.cerberus.operation.Operation; -import com.nike.cerberus.operation.cms.ViewConfigOperation; +import com.nike.cerberus.operation.core.ViewConfigOperation; import static com.nike.cerberus.command.cms.CreateCmsClusterCommand.COMMAND_NAME; /** - * Command to create the CMS cluster. + * Command to view configuration files in S3. */ -@Parameters(commandNames = COMMAND_NAME, commandDescription = "Shows the CMS config.") +@Parameters(commandNames = COMMAND_NAME, commandDescription = "Shows configuration files from S3.") public class ViewConfigCommand implements Command { public static final String COMMAND_NAME = "view-config"; diff --git a/src/main/java/com/nike/cerberus/operation/core/RollingRebootWithHealthCheckOperation.java b/src/main/java/com/nike/cerberus/operation/core/RollingRebootWithHealthCheckOperation.java new file mode 100644 index 00000000..c879aadd --- /dev/null +++ b/src/main/java/com/nike/cerberus/operation/core/RollingRebootWithHealthCheckOperation.java @@ -0,0 +1,229 @@ +package com.nike.cerberus.operation.core; + +import com.amazonaws.services.ec2.model.Filter; +import com.amazonaws.services.ec2.model.Instance; +import com.github.tomaslanger.chalk.Chalk; +import com.google.common.collect.ImmutableMap; +import com.google.inject.Inject; +import com.nike.cerberus.command.core.RollingRebootWithHealthCheckCommand; +import com.nike.cerberus.domain.environment.StackName; +import com.nike.cerberus.operation.Operation; +import com.nike.cerberus.service.AutoScalingService; +import com.nike.cerberus.service.CloudFormationService; +import com.nike.cerberus.service.Ec2Service; +import com.nike.cerberus.store.ConfigStore; +import com.nike.vault.client.http.HttpStatus; +import okhttp3.Call; +import okhttp3.OkHttpClient; +import okhttp3.Request; +import okhttp3.Response; +import org.apache.http.conn.ssl.NoopHostnameVerifier; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.net.Proxy; +import java.util.List; +import java.util.Map; +import java.util.concurrent.TimeUnit; + +import static com.nike.cerberus.service.CloudFormationService.MIN_INSTANCES_STACK_PARAMETER_KEY; +import static com.nike.cerberus.service.Ec2Service.EC2_ASG_GROUP_NAME_TAG_KEY; +import static com.nike.cerberus.service.Ec2Service.INSTANCE_STATE_FILTER_NAME; +import static com.nike.cerberus.service.Ec2Service.INSTANCE_STATE_RUNNING_FILTER_VALUE; + +/** + * Reboots all EC2 instances in the given cluster. + */ +public class RollingRebootWithHealthCheckOperation implements Operation { + + private final Logger logger = LoggerFactory.getLogger(getClass()); + + private final static ImmutableMap HEALTH_CHECK_MAP = ImmutableMap.of( + StackName.CMS.getName(), "http://%s:8080/healthcheck", + StackName.GATEWAY.getName(), "https://%s:443/sys/health" + // TODO: Test that command works with remaining stacks +// StackName.VAULT.getName(), "https://%s:8200/v1/sys/health?standbyok", +// StackName.CONSUL.getName(), "https://%s:8500/v1/sys/health" + ); + + private final static int DEFAULT_HTTP_TIMEOUT = 15; + + private final static TimeUnit DEFAULT_HTTP_TIMEOUT_UNIT = TimeUnit.SECONDS; + + private final static int NUM_SECS_BETWEEN_HEALTH_CHECKS = 5; + + private final static int EXPECTED_NUM_SUCCESSES_AFTER_REBOOT = 10; + + private final static int EXPECTED_NUM_FAILURES_AFTER_REBOOT = 3; + + private final static int HEALTH_CHECK_FAILED_CODE = -1; + + private final ConfigStore configStore; + + private final CloudFormationService cloudFormationService; + + private final Ec2Service ec2Service; + + private final AutoScalingService autoScalingService; + + private final Proxy proxy; + + @Inject + public RollingRebootWithHealthCheckOperation(final ConfigStore configStore, + final CloudFormationService cloudFormationService, + final Ec2Service ec2Service, + final AutoScalingService autoScalingService, + final Proxy proxy) { + this.configStore = configStore; + this.cloudFormationService = cloudFormationService; + this.ec2Service = ec2Service; + this.autoScalingService = autoScalingService; + this.proxy = proxy; + } + + @Override + public void run(final RollingRebootWithHealthCheckCommand command) { + + logger.warn(Chalk.on( + "If this command fails: the minimum instance size may need to be increased and an EC2 instance" + + " may need to be set to 'in-service' state on the auto scaling group").yellow().toString()); + + final StackName stackName = command.getStackName(); + final String stackId = configStore.getStackId(stackName); + final Map stackOutputs = cloudFormationService.getStackOutputs(stackId); + + final Map stackParameters = cloudFormationService.getStackParameters(stackId); + final int minInstances = Integer.parseInt(stackParameters.get(MIN_INSTANCES_STACK_PARAMETER_KEY)); + + final String autoScalingGroupId = stackOutputs.get(CloudFormationService.AUTO_SCALING_GROUP_LOGICAL_ID_OUTPUT_KEY); + logger.debug("Found auto scaling group id for stack: {}", stackId); + + final Filter isRunningFilter = new Filter(INSTANCE_STATE_FILTER_NAME).withValues(INSTANCE_STATE_RUNNING_FILTER_VALUE); + final List instances = ec2Service.getInstancesByTag(EC2_ASG_GROUP_NAME_TAG_KEY, autoScalingGroupId, isRunningFilter); + logger.debug("Found {} instances by tag: '{}:{}'", instances.size(), EC2_ASG_GROUP_NAME_TAG_KEY, autoScalingGroupId); + + logger.info("Temporarily decreasing min instances for ASG: {}", autoScalingGroupId); + autoScalingService.updateMinInstancesForAutoScalingGroup(autoScalingGroupId, minInstances - 1); + + instances.forEach(instance -> { + rebootInstance(stackName, autoScalingGroupId, instance); + }); + + logger.info("Increasing min instances for ASG: {}", autoScalingGroupId); + autoScalingService.updateMinInstancesForAutoScalingGroup(autoScalingGroupId, minInstances); + } + + /** + * Reboot an instance and make sure it comes back healthy + */ + private void rebootInstance(StackName stackName, String autoScalingGroupId, Instance instance) { + final String instanceId = instance.getInstanceId(); + logger.info("Setting instance state to standby: {}", instanceId); + autoScalingService.setInstanceStateToStandby(autoScalingGroupId, instanceId); + + logger.info("Rebooting instance: {}", instanceId); + ec2Service.rebootEc2Instance(instanceId); + + final String healthCheckUrlTmpl = HEALTH_CHECK_MAP.get(stackName.getName()); + final String healthCheckUrl = String.format(healthCheckUrlTmpl, instance.getPublicDnsName()); + + // wait for health check fail to confirm box reboot + logger.info("Waiting for health check failure to confirm reboot..."); + waitForHealthCheckStatusCode(healthCheckUrl, HEALTH_CHECK_FAILED_CODE, EXPECTED_NUM_FAILURES_AFTER_REBOOT); + + // wait for health check pass to confirm instance is healthy after reboot + logger.warn(Chalk.on( + "If a proxy is required to talk to the EC2 instance, then make sure it is set up." + + " Otherwise this command will never succeed.").yellow().toString()); + logger.info("Waiting for health check to pass again to confirm instance is healthy..."); + waitForHealthCheckStatusCode(healthCheckUrl, HttpStatus.OK, EXPECTED_NUM_SUCCESSES_AFTER_REBOOT); + + logger.info("Setting instance state to in-service: {}", instanceId); + autoScalingService.setInstanceStateToInService(autoScalingGroupId, instanceId); + } + + /** + * Poll the health check 'n' times, looking for the given response + * @param healthCheckUrl - The health check URL + * @param numConsecutiveResponsesExpected - The number of times to poll health check + */ + private void waitForHealthCheckStatusCode(final String healthCheckUrl, + final long expectedStatusCode, + final int numConsecutiveResponsesExpected) { + + int responseCode; + int consecutiveResponses = 0; + while (consecutiveResponses < numConsecutiveResponsesExpected) { + + responseCode = executeHealthCheck(healthCheckUrl); + + if (responseCode == expectedStatusCode) { + consecutiveResponses++; + } else if (consecutiveResponses > 0) { + final String message = Chalk.on("Instance health check did not repeat response code ({}), {} times").red().bold().toString(); + logger.debug(message, expectedStatusCode, numConsecutiveResponsesExpected); + consecutiveResponses = 0; + } + + try { + TimeUnit.SECONDS.sleep(NUM_SECS_BETWEEN_HEALTH_CHECKS); + } catch (InterruptedException ie) { + logger.error(Chalk.on("Timeout between health checks has been interrupted").red().bold().toString()); + return; + } + } + } + + /** + * Execute the given health check + * @param healthCheckUrl - Name of that EC2 instance belongs to + * @return - Response code of the health check + */ + private int executeHealthCheck(final String healthCheckUrl) { + + final OkHttpClient okHttpClient = new OkHttpClient.Builder() + .hostnameVerifier(new NoopHostnameVerifier()) + .proxy(proxy) + .connectTimeout(DEFAULT_HTTP_TIMEOUT, DEFAULT_HTTP_TIMEOUT_UNIT) + .writeTimeout(DEFAULT_HTTP_TIMEOUT, DEFAULT_HTTP_TIMEOUT_UNIT) + .readTimeout(DEFAULT_HTTP_TIMEOUT, DEFAULT_HTTP_TIMEOUT_UNIT) + .build(); + + final Request requestBuilder = new Request.Builder() + .url(healthCheckUrl) + .get() + .build(); + + final Call healthCheckCall = okHttpClient.newCall(requestBuilder); + + try(final Response response = healthCheckCall.execute()) { + logger.debug("Health check returned status: {}, URL: {}", response.code(), healthCheckUrl); + return response.code(); + } catch (IOException ioe) { + final String message = Chalk.on("Health check failed, Cause: {}, URL: {}").red().toString(); + logger.debug(message, ioe.getMessage(), healthCheckUrl); + } + + return HEALTH_CHECK_FAILED_CODE; + } + + @Override + public boolean isRunnable(final RollingRebootWithHealthCheckCommand command) { + + final StackName stackName = command.getStackName(); + final String stackNameStr = stackName.getName(); + final String stackId = configStore.getStackId(stackName); + final Map stackParameters = cloudFormationService.getStackParameters(stackId); + + if (! HEALTH_CHECK_MAP.containsKey(stackNameStr)) { + logger.error("Cannot reboot cluster: {}. Allowed stacks: {}", stackName, HEALTH_CHECK_MAP.keySet()); + return false; + } else if (! stackParameters.containsKey(MIN_INSTANCES_STACK_PARAMETER_KEY)) { + logger.error("Could not find parameter 'minInstances' on stack: {}", stackId); + return false; + } else { + return true; + } + } +} diff --git a/src/main/java/com/nike/cerberus/operation/cms/ViewConfigOperation.java b/src/main/java/com/nike/cerberus/operation/core/ViewConfigOperation.java similarity index 92% rename from src/main/java/com/nike/cerberus/operation/cms/ViewConfigOperation.java rename to src/main/java/com/nike/cerberus/operation/core/ViewConfigOperation.java index f3e3e4e9..eb060168 100644 --- a/src/main/java/com/nike/cerberus/operation/cms/ViewConfigOperation.java +++ b/src/main/java/com/nike/cerberus/operation/core/ViewConfigOperation.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package com.nike.cerberus.operation.cms; +package com.nike.cerberus.operation.core; import com.nike.cerberus.command.core.ViewConfigCommand; import com.nike.cerberus.operation.Operation; @@ -26,7 +26,7 @@ import java.util.Optional; /** - * Gathers all of the CMS environment configuration and puts it in the config bucket. + * Displays the given configuration file from S3. */ public class ViewConfigOperation implements Operation { diff --git a/src/main/java/com/nike/cerberus/service/AutoScalingService.java b/src/main/java/com/nike/cerberus/service/AutoScalingService.java index a61ca37c..1e5a19db 100644 --- a/src/main/java/com/nike/cerberus/service/AutoScalingService.java +++ b/src/main/java/com/nike/cerberus/service/AutoScalingService.java @@ -20,6 +20,9 @@ import com.amazonaws.services.autoscaling.model.AutoScalingGroup; import com.amazonaws.services.autoscaling.model.DescribeAutoScalingGroupsRequest; import com.amazonaws.services.autoscaling.model.DescribeAutoScalingGroupsResult; +import com.amazonaws.services.autoscaling.model.EnterStandbyRequest; +import com.amazonaws.services.autoscaling.model.ExitStandbyRequest; +import com.amazonaws.services.autoscaling.model.UpdateAutoScalingGroupRequest; import com.amazonaws.services.ec2.AmazonEC2; import com.amazonaws.services.ec2.model.DescribeInstancesRequest; import com.amazonaws.services.ec2.model.DescribeInstancesResult; @@ -52,10 +55,7 @@ public AutoScalingService(final AmazonAutoScaling autoScalingClient, final Amazo */ public List getPublicDnsForAutoScalingGroup(final String logicalId) { final List instanceIds = Lists.newLinkedList(); - final DescribeAutoScalingGroupsRequest describeAsg = new DescribeAutoScalingGroupsRequest() - .withAutoScalingGroupNames(logicalId); - final DescribeAutoScalingGroupsResult result = autoScalingClient.describeAutoScalingGroups(describeAsg); - final Optional autoScalingGroup = result.getAutoScalingGroups().stream().findFirst(); + final Optional autoScalingGroup = describeAutoScalingGroup(logicalId); final List publicDnsNames = Lists.newLinkedList(); if (autoScalingGroup.isPresent()) { @@ -75,4 +75,55 @@ public List getPublicDnsForAutoScalingGroup(final String logicalId) { return publicDnsNames; } + + /** + * Updates the minimum number of instances allowed in the auto scaling group + * @param logicalId - Name of the auto scaling group + */ + public void updateMinInstancesForAutoScalingGroup(final String logicalId, final int minInstances) { + + final UpdateAutoScalingGroupRequest request = new UpdateAutoScalingGroupRequest() + .withAutoScalingGroupName(logicalId) + .withMinSize(minInstances); + + autoScalingClient.updateAutoScalingGroup(request); + } + + /** + * Set an EC2 instance to standby state, so that the desired instance count on the AutoScaling group is decreased + * and a new instance is not spun up on instance reboot. This also removes the instance from the ELB, so that the + * instance is not terminated when the health check fails. + * @param logicalId - Name of the auto scaling group + * @param instanceId - ID of the EC2 instance + */ + public void setInstanceStateToStandby(final String logicalId, final String instanceId) { + final EnterStandbyRequest request = new EnterStandbyRequest() + .withAutoScalingGroupName(logicalId) + .withInstanceIds(instanceId) + .withShouldDecrementDesiredCapacity(true); + + autoScalingClient.enterStandby(request); + } + + /** + * Signify that the EC2 instance is now in service and ready to be re-added to the ELB and AutoScaling group. This + * will also increase the desired instance count for the ASG. + * @param logicalId - Name of the auto scaling group + * @param instanceId - ID of the EC2 instance + */ + public void setInstanceStateToInService(final String logicalId, final String instanceId) { + final ExitStandbyRequest request = new ExitStandbyRequest() + .withAutoScalingGroupName(logicalId) + .withInstanceIds(instanceId); + + autoScalingClient.exitStandby(request); + } + + private Optional describeAutoScalingGroup(final String autoscalingGroupName) { + final DescribeAutoScalingGroupsRequest describeAsg = new DescribeAutoScalingGroupsRequest() + .withAutoScalingGroupNames(autoscalingGroupName); + final DescribeAutoScalingGroupsResult result = autoScalingClient.describeAutoScalingGroups(describeAsg); + + return result.getAutoScalingGroups().stream().findFirst(); + } } diff --git a/src/main/java/com/nike/cerberus/service/CloudFormationService.java b/src/main/java/com/nike/cerberus/service/CloudFormationService.java index c400db1e..39ed995e 100644 --- a/src/main/java/com/nike/cerberus/service/CloudFormationService.java +++ b/src/main/java/com/nike/cerberus/service/CloudFormationService.java @@ -64,6 +64,10 @@ public class CloudFormationService { private final Logger logger = LoggerFactory.getLogger(getClass()); + public final static String AUTO_SCALING_GROUP_LOGICAL_ID_OUTPUT_KEY = "autoscalingGroupLogicalId"; + + public final static String MIN_INSTANCES_STACK_PARAMETER_KEY = "minimumInstances"; + private final AmazonCloudFormation cloudFormationClient; private final EnvironmentMetadata environmentMetadata; diff --git a/src/main/java/com/nike/cerberus/service/Ec2Service.java b/src/main/java/com/nike/cerberus/service/Ec2Service.java index 1a456ad0..f2476882 100644 --- a/src/main/java/com/nike/cerberus/service/Ec2Service.java +++ b/src/main/java/com/nike/cerberus/service/Ec2Service.java @@ -21,13 +21,25 @@ import com.amazonaws.services.ec2.model.AvailabilityZone; import com.amazonaws.services.ec2.model.AvailabilityZoneState; import com.amazonaws.services.ec2.model.DescribeAvailabilityZonesResult; +import com.amazonaws.services.ec2.model.DescribeInstanceStatusRequest; +import com.amazonaws.services.ec2.model.DescribeInstancesRequest; +import com.amazonaws.services.ec2.model.DescribeInstancesResult; import com.amazonaws.services.ec2.model.DescribeKeyPairsRequest; import com.amazonaws.services.ec2.model.DescribeKeyPairsResult; +import com.amazonaws.services.ec2.model.Filter; import com.amazonaws.services.ec2.model.ImportKeyPairRequest; import com.amazonaws.services.ec2.model.ImportKeyPairResult; +import com.amazonaws.services.ec2.model.Instance; +import com.amazonaws.services.ec2.model.InstanceStatus; +import com.amazonaws.services.ec2.model.RebootInstancesRequest; +import com.amazonaws.services.ec2.model.Reservation; +import com.google.common.collect.Lists; +import com.google.common.collect.Sets; import javax.inject.Inject; import java.util.List; +import java.util.Optional; +import java.util.Set; import java.util.stream.Collectors; /** @@ -35,6 +47,14 @@ */ public class Ec2Service { + public final static String EC2_ASG_GROUP_NAME_TAG_KEY = "aws:autoscaling:groupName"; + + public final static String INSTANCE_STATE_FILTER_NAME = "instance-state-name"; + + public final static String INSTANCE_STATE_RUNNING_FILTER_VALUE = "running"; + + protected final static String FILTER_NAME_TEMPL_FOR_EC2_TAGS = "tag:%s"; + private final AmazonEC2 ec2Client; @Inject @@ -77,7 +97,7 @@ public boolean isKeyPairPresent(final String keyName) { } /** - * Determines all availabity zones for a region that are marked as available. + * Determines all availability zones for a region that are marked as available. * * @return List of availability zones */ @@ -89,4 +109,41 @@ public List getAvailabilityZones() { .filter(az -> AvailabilityZoneState.Available == AvailabilityZoneState.fromValue(az.getState())) .map(AvailabilityZone::getZoneName).collect(Collectors.toList()); } + + /** + * Gets all EC2 instances with the given tag key/value pair + * @param tagKey - Key of the tag + * @param tagValue - Value of the tag + * @param filters - Array of EC2 filters + * @return - List of instances with the given tag + */ + public List getInstancesByTag(final String tagKey, final String tagValue, final Filter... filters) { + final String filterName = String.format(FILTER_NAME_TEMPL_FOR_EC2_TAGS, tagKey); + final Filter tagFilter = new Filter().withName(filterName).withValues(tagValue); + + final Set filterSet = Sets.newHashSet(filters); + filterSet.add(tagFilter); + final DescribeInstancesRequest request = new DescribeInstancesRequest().withFilters(filterSet); + + DescribeInstancesResult result = ec2Client.describeInstances(request); + List instances = Lists.newArrayList(); + + result.getReservations().forEach(reservation -> { + instances.addAll(reservation.getInstances()); + }); + + return instances; + } + + /** + * Reboots the EC2 instance with the given ID + * @param instanceId - EC2 instance ID + */ + public void rebootEc2Instance(final String instanceId) { + + final RebootInstancesRequest request = new RebootInstancesRequest().withInstanceIds(instanceId); + + ec2Client.rebootInstances(request); + } + } diff --git a/src/test/java/com/nike/cerberus/service/AutoScalingServiceTest.java b/src/test/java/com/nike/cerberus/service/AutoScalingServiceTest.java index 41662036..197dc98c 100644 --- a/src/test/java/com/nike/cerberus/service/AutoScalingServiceTest.java +++ b/src/test/java/com/nike/cerberus/service/AutoScalingServiceTest.java @@ -20,28 +20,42 @@ import com.amazonaws.services.autoscaling.model.AutoScalingGroup; import com.amazonaws.services.autoscaling.model.DescribeAutoScalingGroupsRequest; import com.amazonaws.services.autoscaling.model.DescribeAutoScalingGroupsResult; +import com.amazonaws.services.autoscaling.model.EnterStandbyRequest; +import com.amazonaws.services.autoscaling.model.ExitStandbyRequest; import com.amazonaws.services.autoscaling.model.Instance; +import com.amazonaws.services.autoscaling.model.UpdateAutoScalingGroupRequest; import com.amazonaws.services.ec2.AmazonEC2; import com.amazonaws.services.ec2.model.DescribeInstancesRequest; import com.amazonaws.services.ec2.model.DescribeInstancesResult; import com.amazonaws.services.ec2.model.Reservation; +import org.junit.Before; import org.junit.Test; import java.util.List; import static org.junit.Assert.assertEquals; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; public class AutoScalingServiceTest { - @Test - public void testGetPublicDnsForAutoScalingGroup() { + private AmazonAutoScaling autoScalingClient; + + private AmazonEC2 ec2Client; + + private AutoScalingService autoScalingService; + + @Before + public void setup() { + autoScalingClient = mock(AmazonAutoScaling.class); + ec2Client = mock(AmazonEC2.class); - AmazonAutoScaling autoScalingClient = mock(AmazonAutoScaling.class); - AmazonEC2 ec2Client = mock(AmazonEC2.class); + autoScalingService = new AutoScalingService(autoScalingClient, ec2Client); + } - AutoScalingService autoScalingService = new AutoScalingService(autoScalingClient, ec2Client); + @Test + public void testGetPublicDnsForAutoScalingGroup() { String logicalId = "fake-logical-id"; String instanceId = "fake-instance-id"; @@ -88,11 +102,6 @@ public void testGetPublicDnsForAutoScalingGroup() { @Test public void testGetPublicDnsForAutoScalingGroupNoInstancesFound() { - AmazonAutoScaling autoScalingClient = mock(AmazonAutoScaling.class); - AmazonEC2 ec2Client = mock(AmazonEC2.class); - - AutoScalingService autoScalingService = new AutoScalingService(autoScalingClient, ec2Client); - String logicalId = "fake-logical-id"; when(autoScalingClient.describeAutoScalingGroups( @@ -109,4 +118,58 @@ public void testGetPublicDnsForAutoScalingGroupNoInstancesFound() { assertEquals(0, results.size()); } + @Test + public void testIncrementMinInstancesForAsgHappy() { + + String logicalId = "asg id"; + String instanceId = "instance id"; + int minSize = 2; + + when(autoScalingClient.describeAutoScalingGroups( + new DescribeAutoScalingGroupsRequest().withAutoScalingGroupNames(logicalId)) + ).thenReturn( + new DescribeAutoScalingGroupsResult() + .withAutoScalingGroups( + new AutoScalingGroup().withInstances( + new Instance().withInstanceId(instanceId)) + .withMinSize(minSize) + ) + ); + + autoScalingService.updateMinInstancesForAutoScalingGroup(logicalId, minSize - 1); + + verify(autoScalingClient).updateAutoScalingGroup(new UpdateAutoScalingGroupRequest() + .withAutoScalingGroupName(logicalId) + .withMinSize(minSize - 1)); + } + + @Test + public void testSetInstanceStateToInServiceHappy() { + + String logicalId = "asg id"; + String instanceId = "instance id"; + + autoScalingService.setInstanceStateToInService(logicalId, instanceId); + + verify(autoScalingClient).exitStandby(new ExitStandbyRequest() + .withAutoScalingGroupName(logicalId) + .withInstanceIds(instanceId) + ); + } + + @Test + public void testSetInstanceStateToStandbyHappy() { + + String logicalId = "asg id"; + String instanceId = "instance id"; + + autoScalingService.setInstanceStateToStandby(logicalId, instanceId); + + verify(autoScalingClient).enterStandby(new EnterStandbyRequest() + .withAutoScalingGroupName(logicalId) + .withInstanceIds(instanceId) + .withShouldDecrementDesiredCapacity(true) + ); + } + } \ No newline at end of file diff --git a/src/test/java/com/nike/cerberus/service/Ec2ServiceTest.java b/src/test/java/com/nike/cerberus/service/Ec2ServiceTest.java index 28c32033..19976453 100644 --- a/src/test/java/com/nike/cerberus/service/Ec2ServiceTest.java +++ b/src/test/java/com/nike/cerberus/service/Ec2ServiceTest.java @@ -21,30 +21,49 @@ import com.amazonaws.services.ec2.model.AvailabilityZone; import com.amazonaws.services.ec2.model.AvailabilityZoneState; import com.amazonaws.services.ec2.model.DescribeAvailabilityZonesResult; +import com.amazonaws.services.ec2.model.DescribeInstancesRequest; +import com.amazonaws.services.ec2.model.DescribeInstancesResult; import com.amazonaws.services.ec2.model.DescribeKeyPairsRequest; import com.amazonaws.services.ec2.model.DescribeKeyPairsResult; +import com.amazonaws.services.ec2.model.Filter; import com.amazonaws.services.ec2.model.ImportKeyPairRequest; import com.amazonaws.services.ec2.model.ImportKeyPairResult; +import com.amazonaws.services.ec2.model.Instance; import com.amazonaws.services.ec2.model.KeyPairInfo; +import com.amazonaws.services.ec2.model.RebootInstancesRequest; +import com.amazonaws.services.ec2.model.Reservation; +import org.junit.Before; import org.junit.Test; import java.util.List; +import static com.nike.cerberus.service.Ec2Service.FILTER_NAME_TEMPL_FOR_EC2_TAGS; +import static com.nike.cerberus.service.Ec2Service.INSTANCE_STATE_FILTER_NAME; +import static com.nike.cerberus.service.Ec2Service.INSTANCE_STATE_RUNNING_FILTER_VALUE; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; public class Ec2ServiceTest { + private AmazonEC2 ec2Client; + + private Ec2Service ec2Service; + + @Before + public void setup() { + ec2Client = mock(AmazonEC2.class); + + ec2Service = new Ec2Service(ec2Client); + } + @Test public void testImportKey() { - AmazonEC2 ec2Client = mock(AmazonEC2.class); - Ec2Service ec2Service = new Ec2Service(ec2Client); - String keyName = "key-name"; String publicKeyMaterial = "public-key-material"; String keyNameResult = "key-name-result"; @@ -60,8 +79,6 @@ public void testImportKey() { @Test public void testIsKeyPairPresentTrue() { - AmazonEC2 ec2Client = mock(AmazonEC2.class); - Ec2Service ec2Service = new Ec2Service(ec2Client); String keyName = "key-name"; @@ -82,8 +99,6 @@ public void testIsKeyPairPresentTrue() { @Test public void testIsKeyPairPresentFalse() { - AmazonEC2 ec2Client = mock(AmazonEC2.class); - Ec2Service ec2Service = new Ec2Service(ec2Client); String keyName = "key-name"; @@ -96,8 +111,6 @@ public void testIsKeyPairPresentFalse() { @Test public void testIsKeyPairPresentFalseNotFound() { - AmazonEC2 ec2Client = mock(AmazonEC2.class); - Ec2Service ec2Service = new Ec2Service(ec2Client); String keyName = "key-name"; @@ -113,8 +126,6 @@ public void testIsKeyPairPresentFalseNotFound() { @Test public void testIsKeyPairPresentException() { - AmazonEC2 ec2Client = mock(AmazonEC2.class); - Ec2Service ec2Service = new Ec2Service(ec2Client); String keyName = "key-name"; String fakeExceptionMessage = "fake-exception"; @@ -134,8 +145,6 @@ public void testIsKeyPairPresentException() { @Test public void testGetAvailabilityZones() { - AmazonEC2 ec2Client = mock(AmazonEC2.class); - Ec2Service ec2Service = new Ec2Service(ec2Client); String zoneName = "zone-name"; @@ -158,4 +167,43 @@ public void testGetAvailabilityZones() { assertEquals(zoneName, results.get(0)); } + @Test + public void testGetInstancesByTagHappy() { + + String tagKey = "tag key"; + String tagValue = "tag value"; + Filter filter = new Filter().withName(INSTANCE_STATE_FILTER_NAME).withValues(INSTANCE_STATE_RUNNING_FILTER_VALUE); + Instance instance = mock(Instance.class); + + when(ec2Client.describeInstances(new DescribeInstancesRequest() + .withFilters( + filter, + new Filter() + .withName(String.format(FILTER_NAME_TEMPL_FOR_EC2_TAGS, tagKey)) + .withValues(tagValue) + ) + )).thenReturn( + new DescribeInstancesResult() + .withReservations( + new Reservation() + .withInstances(instance)) + ); + + List instances = ec2Service.getInstancesByTag(tagKey, tagValue, filter); + + assertTrue(instances.contains(instance)); + } + + @Test + public void testRebootInstancesHappy() { + + String instanceId = "instance id"; + + ec2Service.rebootEc2Instance(instanceId); + + verify(ec2Client).rebootInstances(new RebootInstancesRequest() + .withInstanceIds(instanceId) + ); + } + } \ No newline at end of file