diff --git a/cmd/node-termination-handler.go b/cmd/node-termination-handler.go index 07df0a80..2e7c81da 100644 --- a/cmd/node-termination-handler.go +++ b/cmd/node-termination-handler.go @@ -179,15 +179,17 @@ func main() { } log.Debug().Msgf("AWS Credentials retrieved from provider: %s", creds.ProviderName) + completeLifecycleActionDelay := time.Duration(nthConfig.CompleteLifecycleActionDelaySeconds) * time.Second sqsMonitor := sqsevent.SQSMonitor{ - CheckIfManaged: nthConfig.CheckTagBeforeDraining, - ManagedTag: nthConfig.ManagedTag, - QueueURL: nthConfig.QueueURL, - InterruptionChan: interruptionChan, - CancelChan: cancelChan, - SQS: sqs.New(sess), - ASG: autoscaling.New(sess), - EC2: ec2.New(sess), + CheckIfManaged: nthConfig.CheckTagBeforeDraining, + ManagedTag: nthConfig.ManagedTag, + QueueURL: nthConfig.QueueURL, + InterruptionChan: interruptionChan, + CancelChan: cancelChan, + SQS: sqs.New(sess), + ASG: autoscaling.New(sess), + EC2: ec2.New(sess), + BeforeCompleteLifecycleAction: func() { <-time.After(completeLifecycleActionDelay) }, } monitoringFns[sqsEvents] = sqsMonitor } diff --git a/config/helm/aws-node-termination-handler/README.md b/config/helm/aws-node-termination-handler/README.md index acd4c308..5d8825af 100644 --- a/config/helm/aws-node-termination-handler/README.md +++ b/config/helm/aws-node-termination-handler/README.md @@ -82,6 +82,7 @@ The configuration in this table applies to all AWS Node Termination Handler mode | `podTerminationGracePeriod` | The time in seconds given to each pod to terminate gracefully. If negative, the default value specified in the pod will be used, which defaults to 30 seconds if not specified for the pod. | `-1` | | `nodeTerminationGracePeriod` | Period of time in seconds given to each node to terminate gracefully. Node draining will be scheduled based on this value to optimize the amount of compute time, but still safely drain the node before an event. 
| `120` | | `emitKubernetesEvents` | If `true`, Kubernetes events will be emitted when interruption events are received and when actions are taken on Kubernetes nodes. In IMDS Processor mode a default set of annotations with all the node metadata gathered from IMDS will be attached to each event. More information [here](https://github.com/aws/aws-node-termination-handler/blob/main/docs/kubernetes_events.md). | `false` | +| `completeLifecycleActionDelaySeconds` | Pause, in seconds, after draining the node before completing the EC2 Autoscaling lifecycle action. This may be helpful if Pods on the node have Persistent Volume Claims. | `-1` | | `kubernetesEventsExtraAnnotations` | A comma-separated list of `key=value` extra annotations to attach to all emitted Kubernetes events (e.g. `first=annotation,sample.annotation/number=two"`). | `""` | | `webhookURL` | Posts event data to URL upon instance interruption action. | `""` | | `webhookURLSecretName` | Pass the webhook URL as a Secret using the key `webhookurl`. | `""` | diff --git a/config/helm/aws-node-termination-handler/templates/deployment.yaml b/config/helm/aws-node-termination-handler/templates/deployment.yaml index f6031945..548f00ab 100644 --- a/config/helm/aws-node-termination-handler/templates/deployment.yaml +++ b/config/helm/aws-node-termination-handler/templates/deployment.yaml @@ -106,6 +106,8 @@ spec: value: {{ .Values.nodeTerminationGracePeriod | quote }} - name: EMIT_KUBERNETES_EVENTS value: {{ .Values.emitKubernetesEvents | quote }} + - name: COMPLETE_LIFECYCLE_ACTION_DELAY_SECONDS + value: {{ .Values.completeLifecycleActionDelaySeconds | quote }} {{- with .Values.kubernetesEventsExtraAnnotations }} - name: KUBERNETES_EVENTS_EXTRA_ANNOTATIONS value: {{ . 
| quote }} diff --git a/config/helm/aws-node-termination-handler/values.yaml b/config/helm/aws-node-termination-handler/values.yaml index ca5db321..5e31b302 100644 --- a/config/helm/aws-node-termination-handler/values.yaml +++ b/config/helm/aws-node-termination-handler/values.yaml @@ -100,6 +100,9 @@ nodeTerminationGracePeriod: 120 # emitKubernetesEvents If true, Kubernetes events will be emitted when interruption events are received and when actions are taken on Kubernetes nodes. In IMDS Processor mode a default set of annotations with all the node metadata gathered from IMDS will be attached to each event emitKubernetesEvents: false +# completeLifecycleActionDelaySeconds will pause for the configured duration after draining the node before completing the EC2 Autoscaling lifecycle action. This may be helpful if Pods on the node have Persistent Volume Claims. +completeLifecycleActionDelaySeconds: -1 + # kubernetesEventsExtraAnnotations A comma-separated list of key=value extra annotations to attach to all emitted Kubernetes events # Example: "first=annotation,sample.annotation/number=two" kubernetesEventsExtraAnnotations: "" diff --git a/pkg/config/config.go b/pkg/config/config.go index 8d86e36c..938531e0 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -102,57 +102,59 @@ const ( awsRegionConfigKey = "AWS_REGION" awsEndpointConfigKey = "AWS_ENDPOINT" queueURLConfigKey = "QUEUE_URL" + completeLifecycleActionDelaySecondsKey = "COMPLETE_LIFECYCLE_ACTION_DELAY_SECONDS" ) -//Config arguments set via CLI, environment variables, or defaults +// Config arguments set via CLI, environment variables, or defaults type Config struct { - DryRun bool - NodeName string - PodName string - MetadataURL string - IgnoreDaemonSets bool - DeleteLocalData bool - KubernetesServiceHost string - KubernetesServicePort string - PodTerminationGracePeriod int - NodeTerminationGracePeriod int - WebhookURL string - WebhookHeaders string - WebhookTemplate string - WebhookTemplateFile 
string - WebhookProxy string - EnableScheduledEventDraining bool - EnableSpotInterruptionDraining bool - EnableSQSTerminationDraining bool - EnableRebalanceMonitoring bool - EnableRebalanceDraining bool - CheckASGTagBeforeDraining bool - CheckTagBeforeDraining bool - ManagedAsgTag string - ManagedTag string - MetadataTries int - CordonOnly bool - TaintNode bool - TaintEffect string - ExcludeFromLoadBalancers bool - JsonLogging bool - LogLevel string - UptimeFromFile string - EnablePrometheus bool - PrometheusPort int - EnableProbes bool - ProbesPort int - ProbesEndpoint string - EmitKubernetesEvents bool - KubernetesEventsExtraAnnotations string - AWSRegion string - AWSEndpoint string - QueueURL string - Workers int - UseProviderId bool + DryRun bool + NodeName string + PodName string + MetadataURL string + IgnoreDaemonSets bool + DeleteLocalData bool + KubernetesServiceHost string + KubernetesServicePort string + PodTerminationGracePeriod int + NodeTerminationGracePeriod int + WebhookURL string + WebhookHeaders string + WebhookTemplate string + WebhookTemplateFile string + WebhookProxy string + EnableScheduledEventDraining bool + EnableSpotInterruptionDraining bool + EnableSQSTerminationDraining bool + EnableRebalanceMonitoring bool + EnableRebalanceDraining bool + CheckASGTagBeforeDraining bool + CheckTagBeforeDraining bool + ManagedAsgTag string + ManagedTag string + MetadataTries int + CordonOnly bool + TaintNode bool + TaintEffect string + ExcludeFromLoadBalancers bool + JsonLogging bool + LogLevel string + UptimeFromFile string + EnablePrometheus bool + PrometheusPort int + EnableProbes bool + ProbesPort int + ProbesEndpoint string + EmitKubernetesEvents bool + KubernetesEventsExtraAnnotations string + AWSRegion string + AWSEndpoint string + QueueURL string + Workers int + UseProviderId bool + CompleteLifecycleActionDelaySeconds int } -//ParseCliArgs parses cli arguments and uses environment variables as fallback values +// ParseCliArgs parses cli arguments 
and uses environment variables as fallback values func ParseCliArgs() (config Config, err error) { var gracePeriod int defer func() { @@ -208,6 +210,7 @@ func ParseCliArgs() (config Config, err error) { flag.StringVar(&config.QueueURL, "queue-url", getEnv(queueURLConfigKey, ""), "Listens for messages on the specified SQS queue URL") flag.IntVar(&config.Workers, "workers", getIntEnv(workersConfigKey, workersDefault), "The amount of parallel event processors.") flag.BoolVar(&config.UseProviderId, "use-provider-id", getBoolEnv(useProviderIdConfigKey, useProviderIdDefault), "If true, fetch node name through Kubernetes node spec ProviderID instead of AWS event PrivateDnsHostname.") + flag.IntVar(&config.CompleteLifecycleActionDelaySeconds, "complete-lifecycle-action-delay-seconds", getIntEnv(completeLifecycleActionDelaySecondsKey, -1), "Delay completing the Autoscaling lifecycle action after a node has been drained.") flag.Parse() if isConfigProvided("pod-termination-grace-period", podTerminationGracePeriodConfigKey) && isConfigProvided("grace-period", gracePeriodConfigKey) { diff --git a/pkg/monitor/sqsevent/asg-lifecycle-event.go b/pkg/monitor/sqsevent/asg-lifecycle-event.go index ca4b4899..e5411cfb 100644 --- a/pkg/monitor/sqsevent/asg-lifecycle-event.go +++ b/pkg/monitor/sqsevent/asg-lifecycle-event.go @@ -91,7 +91,7 @@ func (m SQSMonitor) asgTerminationToInterruptionEvent(event *EventBridgeEvent, m } interruptionEvent.PostDrainTask = func(interruptionEvent monitor.InterruptionEvent, _ node.Node) error { - _, err := m.ASG.CompleteLifecycleAction(&autoscaling.CompleteLifecycleActionInput{ + _, err := m.completeLifecycleAction(&autoscaling.CompleteLifecycleActionInput{ AutoScalingGroupName: &lifecycleDetail.AutoScalingGroupName, LifecycleActionResult: aws.String("CONTINUE"), LifecycleHookName: &lifecycleDetail.LifecycleHookName, diff --git a/pkg/monitor/sqsevent/sqs-monitor.go b/pkg/monitor/sqsevent/sqs-monitor.go index 61e757b8..5ed440a8 100644 --- 
a/pkg/monitor/sqsevent/sqs-monitor.go +++ b/pkg/monitor/sqsevent/sqs-monitor.go @@ -21,6 +21,7 @@ import ( "github.com/aws/aws-node-termination-handler/pkg/monitor" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/awserr" + "github.com/aws/aws-sdk-go/service/autoscaling" "github.com/aws/aws-sdk-go/service/autoscaling/autoscalingiface" "github.com/aws/aws-sdk-go/service/ec2" "github.com/aws/aws-sdk-go/service/ec2/ec2iface" @@ -41,14 +42,15 @@ const ( // SQSMonitor is a struct definition that knows how to process events from Amazon EventBridge type SQSMonitor struct { - InterruptionChan chan<- monitor.InterruptionEvent - CancelChan chan<- monitor.InterruptionEvent - QueueURL string - SQS sqsiface.SQSAPI - ASG autoscalingiface.AutoScalingAPI - EC2 ec2iface.EC2API - CheckIfManaged bool - ManagedTag string + InterruptionChan chan<- monitor.InterruptionEvent + CancelChan chan<- monitor.InterruptionEvent + QueueURL string + SQS sqsiface.SQSAPI + ASG autoscalingiface.AutoScalingAPI + EC2 ec2iface.EC2API + CheckIfManaged bool + ManagedTag string + BeforeCompleteLifecycleAction func() } // InterruptionEventWrapper is a convenience wrapper for associating an interruption event with its error, if any @@ -283,6 +285,14 @@ func (m SQSMonitor) deleteMessages(messages []*sqs.Message) []error { return errs } +// completeLifecycleAction completes the lifecycle action after calling the "before" hook. 
+func (m SQSMonitor) completeLifecycleAction(input *autoscaling.CompleteLifecycleActionInput) (*autoscaling.CompleteLifecycleActionOutput, error) { + if m.BeforeCompleteLifecycleAction != nil { + m.BeforeCompleteLifecycleAction() + } + return m.ASG.CompleteLifecycleAction(input) +} + // NodeInfo is relevant information about a single node type NodeInfo struct { AsgName string diff --git a/pkg/monitor/sqsevent/sqs-monitor_test.go b/pkg/monitor/sqsevent/sqs-monitor_test.go index f08b8125..1db36a02 100644 --- a/pkg/monitor/sqsevent/sqs-monitor_test.go +++ b/pkg/monitor/sqsevent/sqs-monitor_test.go @@ -345,6 +345,51 @@ func TestMonitor_DrainTasks(t *testing.T) { } } +func TestMonitor_DrainTasks_Delay(t *testing.T) { + msg, err := getSQSMessageFromEvent(asgLifecycleEvent) + h.Ok(t, err) + + sqsMock := h.MockedSQS{ + ReceiveMessageResp: sqs.ReceiveMessageOutput{Messages: []*sqs.Message{&msg}}, + ReceiveMessageErr: nil, + DeleteMessageResp: sqs.DeleteMessageOutput{}, + } + dnsNodeName := "ip-10-0-0-157.us-east-2.compute.internal" + ec2Mock := h.MockedEC2{ + DescribeInstancesResp: getDescribeInstancesResp(dnsNodeName, true, true), + } + asgMock := h.MockedASG{ + CompleteLifecycleActionResp: autoscaling.CompleteLifecycleActionOutput{}, + } + drainChan := make(chan monitor.InterruptionEvent, 1) + + hookCalled := false + sqsMonitor := sqsevent.SQSMonitor{ + SQS: sqsMock, + EC2: ec2Mock, + ManagedTag: "aws-node-termination-handler/managed", + ASG: mockIsManagedTrue(&asgMock), + CheckIfManaged: true, + QueueURL: "https://test-queue", + InterruptionChan: drainChan, + BeforeCompleteLifecycleAction: func() { hookCalled = true }, + } + + err = sqsMonitor.Monitor() + h.Ok(t, err) + + t.Run(asgLifecycleEvent.DetailType, func(st *testing.T) { + result := <-drainChan + h.Equals(st, sqsevent.SQSTerminateKind, result.Kind) + h.Equals(st, result.NodeName, dnsNodeName) + h.Assert(st, result.PostDrainTask != nil, "PostDrainTask should have been set") + h.Assert(st, result.PreDrainTask != 
nil, "PreDrainTask should have been set") + err := result.PostDrainTask(result, node.Node{}) + h.Ok(st, err) + h.Assert(st, hookCalled, "BeforeCompleteLifecycleAction hook not called") + }) +} + func TestMonitor_DrainTasks_Errors(t *testing.T) { testEvents := []sqsevent.EventBridgeEvent{spotItnEvent, asgLifecycleEvent, {}, rebalanceRecommendationEvent} messages := make([]*sqs.Message, 0, len(testEvents))