Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions cmd/node-termination-handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,10 +115,15 @@ func main() {
}
imdsDisabled := nthConfig.EnableSQSTerminationDraining

imds := ec2metadata.New(nthConfig.MetadataURL, nthConfig.MetadataTries)

interruptionEventStore := interruptioneventstore.New(nthConfig)
nodeMetadata := imds.GetNodeMetadata(imdsDisabled)
var imds *ec2metadata.Service
var nodeMetadata ec2metadata.NodeMetadata

if !imdsDisabled {
imds = ec2metadata.New(nthConfig.MetadataURL, nthConfig.MetadataTries)
nodeMetadata = imds.GetNodeMetadata()
}

// Populate the aws region if available from node metadata and not already explicitly configured
if nthConfig.AWSRegion == "" && nodeMetadata.Region != "" {
nthConfig.AWSRegion = nodeMetadata.Region
Expand Down
44 changes: 21 additions & 23 deletions pkg/ec2metadata/ec2metadata.go
Original file line number Diff line number Diff line change
Expand Up @@ -325,32 +325,30 @@ func retry(attempts int, sleep time.Duration, httpReq func() (*http.Response, er
}

// GetNodeMetadata attempts to gather additional ec2 instance information from the metadata service
func (e *Service) GetNodeMetadata(imdsDisabled bool) NodeMetadata {
func (e *Service) GetNodeMetadata() NodeMetadata {
metadata := NodeMetadata{}
if !imdsDisabled {
identityDoc, err := e.GetMetadataInfo(IdentityDocPath)
if err != nil {
log.Err(err).Msg("Unable to fetch metadata from IMDS")
return metadata
}
err = json.NewDecoder(strings.NewReader(identityDoc)).Decode(&metadata)
if err != nil {
log.Warn().Msg("Unable to fetch instance identity document from ec2 metadata")
metadata.InstanceID, _ = e.GetMetadataInfo(InstanceIDPath)
metadata.InstanceType, _ = e.GetMetadataInfo(InstanceTypePath)
metadata.LocalIP, _ = e.GetMetadataInfo(LocalIPPath)
metadata.AvailabilityZone, _ = e.GetMetadataInfo(AZPlacementPath)
if len(metadata.AvailabilityZone) > 1 {
metadata.Region = metadata.AvailabilityZone[0 : len(metadata.AvailabilityZone)-1]
}
identityDoc, err := e.GetMetadataInfo(IdentityDocPath)
if err != nil {
log.Err(err).Msg("Unable to fetch metadata from IMDS")
return metadata
}
err = json.NewDecoder(strings.NewReader(identityDoc)).Decode(&metadata)
if err != nil {
log.Warn().Msg("Unable to fetch instance identity document from ec2 metadata")
metadata.InstanceID, _ = e.GetMetadataInfo(InstanceIDPath)
metadata.InstanceType, _ = e.GetMetadataInfo(InstanceTypePath)
metadata.LocalIP, _ = e.GetMetadataInfo(LocalIPPath)
metadata.AvailabilityZone, _ = e.GetMetadataInfo(AZPlacementPath)
if len(metadata.AvailabilityZone) > 1 {
metadata.Region = metadata.AvailabilityZone[0 : len(metadata.AvailabilityZone)-1]
}
metadata.InstanceLifeCycle, _ = e.GetMetadataInfo(InstanceLifeCycle)
metadata.LocalHostname, _ = e.GetMetadataInfo(LocalHostnamePath)
metadata.PublicHostname, _ = e.GetMetadataInfo(PublicHostnamePath)
metadata.PublicIP, _ = e.GetMetadataInfo(PublicIPPath)

log.Info().Interface("metadata", metadata).Msg("Startup Metadata Retrieved")
}
metadata.InstanceLifeCycle, _ = e.GetMetadataInfo(InstanceLifeCycle)
metadata.LocalHostname, _ = e.GetMetadataInfo(LocalHostnamePath)
metadata.PublicHostname, _ = e.GetMetadataInfo(PublicHostnamePath)
metadata.PublicIP, _ = e.GetMetadataInfo(PublicIPPath)

log.Info().Interface("metadata", metadata).Msg("Startup Metadata Retrieved")

return metadata
}
24 changes: 1 addition & 23 deletions pkg/ec2metadata/ec2metadata_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -580,7 +580,7 @@ func TestGetNodeMetadata(t *testing.T) {

// Use URL from our local test server
imds := ec2metadata.New(server.URL, 1)
nodeMetadata := imds.GetNodeMetadata(false)
nodeMetadata := imds.GetNodeMetadata()

h.Assert(t, nodeMetadata.AccountId == "", `AccountId should be empty string (only present in SQS events)`)
h.Assert(t, nodeMetadata.InstanceID == `metadata`, `Missing required NodeMetadata field InstanceID`)
Expand All @@ -593,25 +593,3 @@ func TestGetNodeMetadata(t *testing.T) {
h.Assert(t, nodeMetadata.AvailabilityZone == `metadata`, `Missing required NodeMetadata field AvailabilityZone`)
h.Assert(t, nodeMetadata.Region == `metadat`, `Region should equal AvailabilityZone with the final character truncated`)
}

func TestGetNodeMetadataWithIMDSDisabled(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
h.Ok(t, fmt.Errorf("IMDS was called when using Queue Processor mode"))
}))
defer server.Close()

// Use URL from our local test server that throws errors when called
imds := ec2metadata.New(server.URL, 1)
nodeMetadata := imds.GetNodeMetadata(true)

h.Assert(t, nodeMetadata.AccountId == "", "AccountId should be empty string")
h.Assert(t, nodeMetadata.InstanceID == "", "InstanceID should be empty string")
h.Assert(t, nodeMetadata.InstanceLifeCycle == "", "InstanceLifeCycle should be empty string")
h.Assert(t, nodeMetadata.InstanceType == "", "InstanceType should be empty string")
h.Assert(t, nodeMetadata.PublicHostname == "", "PublicHostname should be empty string")
h.Assert(t, nodeMetadata.PublicIP == "", "PublicIP should be empty string")
h.Assert(t, nodeMetadata.LocalHostname == "", "LocalHostname should be empty string")
h.Assert(t, nodeMetadata.LocalIP == "", "LocalIP should be empty string")
h.Assert(t, nodeMetadata.AvailabilityZone == "", "AvailabilityZone should be empty string")
h.Assert(t, nodeMetadata.Region == "", "Region should be empty string")
}