diff --git a/.gitignore b/.gitignore index b0683ab..460fade 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,7 @@ ### Go ### +# IDE +.idea + # Binaries for programs and plugins *.exe *.exe~ diff --git a/cmd/main.go b/cmd/main.go index b6067c3..f537a72 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -18,6 +18,7 @@ import ( "log" "os" "strings" + "sync" commandline "github.com/aws/amazon-ec2-instance-selector/v2/pkg/cli" "github.com/aws/amazon-ec2-instance-selector/v2/pkg/selector" @@ -187,11 +188,29 @@ Full docs can be found at github.com/aws/amazon-` + binName flags[region] = sess.Config.Region instanceSelector := selector.New(sess) - if _, ok := flags[pricePerHour]; ok { - if flags[usageClass] == nil || *flags[usageClass].(*string) == "on-demand" { - instanceSelector.EC2Pricing.HydrateOndemandCache() + outputFlag := cli.StringMe(flags[output]) + if outputFlag != nil && *outputFlag == tableWideOutput { + // If output type is `table-wide`, simply print both prices for better comparison, + // even if the actual filter is applied on any one of those based on usage class + + // Save time by hydrating in parallel + wg := &sync.WaitGroup{} + wg.Add(2) + go func(waitGroup *sync.WaitGroup) { + defer waitGroup.Done() + _ = instanceSelector.EC2Pricing.HydrateOndemandCache() + }(wg) + go func(waitGroup *sync.WaitGroup) { + defer waitGroup.Done() + _ = instanceSelector.EC2Pricing.HydrateSpotCache(30) + }(wg) + wg.Wait() + } else if flags[pricePerHour] != nil { + // Else, if price filters are applied, only hydrate the respective cache as we don't have to print the prices + if flags[usageClass] == nil || *cli.StringMe(flags[usageClass]) == "on-demand" { + _ = instanceSelector.EC2Pricing.HydrateOndemandCache() } else { - instanceSelector.EC2Pricing.HydrateSpotCache(30) + _ = instanceSelector.EC2Pricing.HydrateSpotCache(30) } } @@ -252,7 +271,6 @@ Full docs can be found at github.com/aws/amazon-` + binName } } - outputFlag := cli.StringMe(flags[output]) outputFn := getOutputFn(outputFlag, selector.InstanceTypesOutputFn(resultsOutputFn)) instanceTypes, itemsTruncated, err := instanceSelector.FilterWithOutput(filters, outputFn) diff --git a/go.mod b/go.mod index d31add8..67087d1 100644 --- a/go.mod +++ b/go.mod @@ -9,14 +9,16 @@ require ( github.com/hashicorp/hcl v1.0.0 github.com/imdario/mergo v0.3.11 github.com/mitchellh/go-homedir v1.1.0 - github.com/smartystreets/goconvey v1.6.4 // indirect github.com/spf13/cobra v0.0.7 github.com/spf13/pflag v1.0.3 + go.uber.org/multierr v1.1.0 gopkg.in/ini.v1 v1.57.0 ) require ( github.com/inconshreveable/mousetrap v1.0.0 // indirect github.com/jmespath/go-jmespath v0.4.0 // indirect + github.com/smartystreets/goconvey v1.6.4 // indirect + go.uber.org/atomic v1.4.0 // indirect gopkg.in/yaml.v2 v2.3.0 // indirect ) diff --git a/go.sum b/go.sum index 385017f..eae6bc0 100644 --- a/go.sum +++ b/go.sum @@ -109,13 +109,16 @@ github.com/spf13/pflag v1.0.3/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnIn github.com/spf13/viper v1.4.0/go.mod h1:PTJ7Z/lr49W6bUbkmS1V3by4uWynFiR9p7+dSq/yZzE= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.2.2 h1:bSDNvY7ZPG5RlJ8otE/7V6gMiyenm9RtJ7IUVIAoJ1w= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/tmc/grpc-websocket-proxy v0.0.0-20190109142713-0ad062ec5ee5/go.mod h1:ncp9v5uamzpCO7NfCPTXjqaC+bZgJeR0sMTm6dMHP7U= github.com/ugorji/go v1.1.4/go.mod h1:uQMGLiO92mf5W77hV/PUCpI3pbzQx3CRekS0kk+RGrc= github.com/xiang90/probing v0.0.0-20190116061207-43a291ad63a2/go.mod h1:UETIi67q53MR2AWcXfiuqkDkRtnGDLqkBTpCHuJHxtU= github.com/xordataexchange/crypt v0.0.3-0.20170626215501-b2862e3d0a77/go.mod h1:aYKd//L2LvnjZzWKhF00oedf4jCCReLcmhLdhm1A27Q= go.etcd.io/bbolt v1.3.2/go.mod h1:IbVyRI1SCnLcuJnV2u8VeU0CEYM7e686BmAb1XKL+uU= +go.uber.org/atomic v1.4.0 h1:cxzIVoETapQEqDhQu3QfnvXAV4AlzcvUCxkVUFw3+EU= go.uber.org/atomic v1.4.0/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= +go.uber.org/multierr v1.1.0 h1:HoEmRHQPVSqub6w2z2d2EOVs2fjyFRGyofhKuyDq0QI= go.uber.org/multierr v1.1.0/go.mod h1:wR5kodmAFQ0UK8QlbwjlSNy0Z68gJhDJUG5sjR94q/0= go.uber.org/zap v1.10.0/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q= golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= diff --git a/pkg/ec2pricing/ec2pricing.go b/pkg/ec2pricing/ec2pricing.go index 3c96638..efba374 100644 --- a/pkg/ec2pricing/ec2pricing.go +++ b/pkg/ec2pricing/ec2pricing.go @@ -8,8 +8,10 @@ import ( "strings" "time" - "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/endpoints" + "go.uber.org/multierr" + + "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/ec2" "github.com/aws/aws-sdk-go/service/ec2/ec2iface" @@ -25,19 +27,25 @@ const ( // EC2Pricing is the public struct to interface with AWS pricing APIs type EC2Pricing struct { - PricingClient pricingiface.PricingAPI - EC2Client ec2iface.EC2API - AWSSession *session.Session - cache map[string]float64 - spotCache map[string]map[string][]spotPricingEntry + PricingClient pricingiface.PricingAPI + EC2Client ec2iface.EC2API + AWSSession *session.Session + onDemandCache map[string]float64 + spotCache map[string]map[string][]spotPricingEntry + lastOnDemandCacheUTC *time.Time // Updated on successful cache write + lastSpotCacheUTC *time.Time // Updated on successful cache write } // EC2PricingIface is the EC2Pricing interface mainly used to mock out ec2pricing during testing type EC2PricingIface interface { GetOndemandInstanceTypeCost(instanceType string) (float64, error) GetSpotInstanceTypeNDayAvgCost(instanceType string, availabilityZones []string, days int) (float64, error) + // Keep hydrate functions thread safe by keeping different write data points + // In simple words, make sure they don't write the same variable/file/row etc. which they don't (they have different cache maps) HydrateOndemandCache() error HydrateSpotCache(days int) error + LastOnDemandCacheUTC() *time.Time + LastSpotCacheUTC() *time.Time } type spotPricingEntry struct { @@ -49,12 +57,26 @@ type spotPricingEntry struct { func New(sess *session.Session) *EC2Pricing { return &EC2Pricing{ // use us-east-1 since pricing only has endpoints in us-east-1 and ap-south-1 - PricingClient: pricing.New(sess.Copy(aws.NewConfig().WithRegion("us-east-1"))), - EC2Client: ec2.New(sess), - AWSSession: sess, + PricingClient: pricing.New(sess.Copy(aws.NewConfig().WithRegion("us-east-1"))), + EC2Client: ec2.New(sess), + AWSSession: sess, + lastOnDemandCacheUTC: nil, + lastSpotCacheUTC: nil, } } +// LastOnDemandCacheUTC returns the UTC timestamp when the onDemandCache was last refreshed +// Returns nil if the onDemandCache has not been initialized +func (p *EC2Pricing) LastOnDemandCacheUTC() *time.Time { + return p.lastOnDemandCacheUTC +} + +// LastSpotCacheUTC returns the UTC timestamp when the spotCache was last refreshed +// Returns nil if the spotCache has not been initialized +func (p *EC2Pricing) LastSpotCacheUTC() *time.Time { + return p.lastSpotCacheUTC +} + // GetSpotInstanceTypeNDayAvgCost retrieves the spot price history for a given AZ from the past N days and averages the price // Passing an empty list for availabilityZones will retrieve avg cost for all AZs in the current AWSSession's region func (p *EC2Pricing) GetSpotInstanceTypeNDayAvgCost(instanceType string, availabilityZones []string, days int) (float64, error) { @@ -67,16 +89,19 @@ func (p *EC2Pricing) GetSpotInstanceTypeNDayAvgCost(instanceType string, availab EndTime: &endTime, InstanceTypes: []*string{&instanceType}, } - zoneToPriceEntries := map[string][]spotPricingEntry{} + zoneToPriceEntries := make(map[string][]spotPricingEntry) if _, ok := p.spotCache[instanceType]; !ok { var processingErr error - err := p.EC2Client.DescribeSpotPriceHistoryPages(&spotPriceHistInput, func(dspho *ec2.DescribeSpotPriceHistoryOutput, b bool) bool { + errAPI := p.EC2Client.DescribeSpotPriceHistoryPages(&spotPriceHistInput, func(dspho *ec2.DescribeSpotPriceHistoryOutput, b bool) bool { for _, history := range dspho.SpotPriceHistory { var spotPrice float64 - spotPrice, processingErr = strconv.ParseFloat(*history.SpotPrice, 64) + spotPrice, errParse := strconv.ParseFloat(*history.SpotPrice, 64) + if errParse != nil { + processingErr = multierr.Append(processingErr, errParse) + continue + } zone := *history.AvailabilityZone - zoneToPriceEntries[zone] = append(zoneToPriceEntries[zone], spotPricingEntry{ Timestamp: *history.Timestamp, SpotPrice: spotPrice, @@ -84,11 +109,11 @@ func (p *EC2Pricing) GetSpotInstanceTypeNDayAvgCost(instanceType string, availab } return true }) - if err != nil { - return float64(0), err + if errAPI != nil { + return float64(-1), errAPI } if processingErr != nil { - return float64(0), processingErr + return float64(-1), processingErr } } else { for zone, priceEntries := range p.spotCache[instanceType] { @@ -113,7 +138,7 @@ func (p *EC2Pricing) GetSpotInstanceTypeNDayAvgCost(instanceType string, availab aggregateZonePriceSum += p.calculateSpotAggregate(priceEntries) } - return (aggregateZonePriceSum / float64(numOfZones)), nil + return aggregateZonePriceSum / float64(numOfZones), nil } func (p *EC2Pricing) calculateSpotAggregate(spotPriceEntries []spotPricingEntry) float64 { @@ -134,11 +159,16 @@ func (p *EC2Pricing) calculateSpotAggregate(spotPriceEntries []spotPricingEntry) duration := spotPriceEntries[int(math.Max(float64(i-1), 0))].Timestamp.Sub(entry.Timestamp).Minutes() priceSum += duration * entry.SpotPrice } - return (priceSum / totalDuration) + return priceSum / totalDuration } // GetOndemandInstanceTypeCost retrieves the on-demand hourly cost for the specified instance type func (p *EC2Pricing) GetOndemandInstanceTypeCost(instanceType string) (float64, error) { + // Check cache first and return it if available + if price, ok := p.onDemandCache[instanceType]; ok { + return price, nil + } + regionDescription := p.getRegionForPricingAPI() // TODO: mac.metal instances cannot be found with the below filters productInput := pricing.GetProductsInput{ @@ -154,25 +184,25 @@ func (p *EC2Pricing) GetOndemandInstanceTypeCost(instanceType string) (float64, }, } - // Check cache first and return it if available - if price, ok := p.cache[instanceType]; ok { - return price, nil - } - pricePerUnitInUSD := float64(-1) - err := p.PricingClient.GetProductsPages(&productInput, func(pricingOutput *pricing.GetProductsOutput, nextPage bool) bool { - var err error + var processingErr error + errAPI := p.PricingClient.GetProductsPages(&productInput, func(pricingOutput *pricing.GetProductsOutput, nextPage bool) bool { + var errParse error for _, priceDoc := range pricingOutput.PriceList { - _, pricePerUnitInUSD, err = parseOndemandUnitPrice(priceDoc) - } - if err != nil { - // keep going through pages if we can't parse the pricing doc - return true + _, pricePerUnitInUSD, errParse = parseOndemandUnitPrice(priceDoc) + if errParse != nil { + processingErr = multierr.Append(processingErr, errParse) + // keep going through pages if we can't parse the pricing doc + return true + } } return false }) - if err != nil { - return -1, err + if errAPI != nil { + return -1, errAPI + } + if processingErr != nil { + return -1, processingErr } return pricePerUnitInUSD, nil } @@ -182,7 +212,7 @@ func (p *EC2Pricing) GetOndemandInstanceTypeCost(instanceType string) (float64, // There is no TTL on cache entries // You'll only want to use this if you don't mind a long startup time (around 30 seconds) and will query the cache often after that. func (p *EC2Pricing) HydrateSpotCache(days int) error { - newCache := map[string]map[string][]spotPricingEntry{} + newCache := make(map[string]map[string][]spotPricingEntry) endTime := time.Now().UTC() startTime := endTime.Add(time.Hour * time.Duration(24*-1*days)) @@ -192,14 +222,17 @@ func (p *EC2Pricing) HydrateSpotCache(days int) error { EndTime: &endTime, } var processingErr error - err := p.EC2Client.DescribeSpotPriceHistoryPages(&spotPriceHistInput, func(dspho *ec2.DescribeSpotPriceHistoryOutput, b bool) bool { + errAPI := p.EC2Client.DescribeSpotPriceHistoryPages(&spotPriceHistInput, func(dspho *ec2.DescribeSpotPriceHistoryOutput, b bool) bool { for _, history := range dspho.SpotPriceHistory { - var spotPrice float64 - spotPrice, processingErr = strconv.ParseFloat(*history.SpotPrice, 64) + spotPrice, errFloat := strconv.ParseFloat(*history.SpotPrice, 64) + if errFloat != nil { + processingErr = multierr.Append(processingErr, errFloat) + continue + } instanceType := *history.InstanceType zone := *history.AvailabilityZone if _, ok := newCache[instanceType]; !ok { - newCache[instanceType] = map[string][]spotPricingEntry{} + newCache[instanceType] = make(map[string][]spotPricingEntry) } newCache[instanceType][zone] = append(newCache[instanceType][zone], spotPricingEntry{ Timestamp: *history.Timestamp, @@ -208,10 +241,12 @@ func (p *EC2Pricing) HydrateSpotCache(days int) error { } return true }) - if err != nil { - return err + if errAPI != nil { + return errAPI } + cTime := time.Now().UTC() p.spotCache = newCache + p.lastSpotCacheUTC = &cTime return processingErr } @@ -219,9 +254,8 @@ func (p *EC2Pricing) HydrateSpotCache(days int) error { // If HydrateOndemandCache is called more than once, the cache will be fully refreshed // There is no TTL on cache entries func (p *EC2Pricing) HydrateOndemandCache() error { - if p.cache == nil { - p.cache = make(map[string]float64) - } + newOnDemandCache := make(map[string]float64) + regionDescription := p.getRegionForPricingAPI() productInput := pricing.GetProductsInput{ ServiceCode: aws.String(serviceCode), @@ -234,17 +268,25 @@ func (p *EC2Pricing) HydrateOndemandCache() error { {Type: aws.String(pricing.FilterTypeTermMatch), Field: aws.String("tenancy"), Value: aws.String("shared")}, }, } - err := p.PricingClient.GetProductsPages(&productInput, func(pricingOutput *pricing.GetProductsOutput, nextPage bool) bool { + var processingErr error + errAPI := p.PricingClient.GetProductsPages(&productInput, func(pricingOutput *pricing.GetProductsOutput, nextPage bool) bool { for _, priceDoc := range pricingOutput.PriceList { - instanceTypeName, price, err := parseOndemandUnitPrice(priceDoc) - if err != nil { + instanceTypeName, price, errParse := parseOndemandUnitPrice(priceDoc) + if errParse != nil { + processingErr = multierr.Append(processingErr, errParse) continue } - p.cache[instanceTypeName] = price + newOnDemandCache[instanceTypeName] = price } return true }) - return err + if errAPI != nil { + return errAPI + } + cTime := time.Now().UTC() + p.onDemandCache = newOnDemandCache + p.lastOnDemandCacheUTC = &cTime + return processingErr } // getRegionForPricingAPI attempts to retrieve the region description based on the AWS session used to create diff --git a/pkg/selector/outputs/outputs.go b/pkg/selector/outputs/outputs.go index 39aeb05..4a9b89e 100644 --- a/pkg/selector/outputs/outputs.go +++ b/pkg/selector/outputs/outputs.go @@ -199,10 +199,8 @@ func TableOutputWide(instanceTypeInfoSlice []instancetypes.Details) []string { w.Init(buf, 8, 8, 2, ' ', 0) defer w.Flush() - pricePerHourHeader := "On-Demand Price/Hr" - if instanceTypeInfoSlice[0].SpotPrice != nil { - pricePerHourHeader = "Spot Price/Hr (30 days)" - } + onDemandPricePerHourHeader := "On-Demand Price/Hr" + spotPricePerHourHeader := "Spot Price/Hr (30d avg)" headers := []interface{}{ "Instance Type", @@ -217,9 +215,10 @@ func TableOutputWide(instanceTypeInfoSlice []instancetypes.Details) []string { "GPUs", "GPU Mem (GiB)", "GPU Info", - pricePerHourHeader, + onDemandPricePerHourHeader, + spotPricePerHourHeader, } - separators := []interface{}{} + separators := make([]interface{}, 0) headerFormat := "" for _, header := range headers { @@ -249,17 +248,16 @@ func TableOutputWide(instanceTypeInfoSlice []instancetypes.Details) []string { } } - pricePerHour := instanceTypeInfo.OndemandPricePerHour - if instanceTypeInfo.SpotPrice != nil { - pricePerHour = instanceTypeInfo.SpotPrice + onDemandPricePerHourStr := "-Not Fetched-" + spotPricePerHourStr := "-Not Fetched-" + if instanceTypeInfo.OndemandPricePerHour != nil { + onDemandPricePerHourStr = fmt.Sprintf("$%s", formatFloat(*instanceTypeInfo.OndemandPricePerHour)) } - specifyPriceFilter := "-No Price Filter Specified-" - pricePerHourStr := specifyPriceFilter - if pricePerHour != nil { - pricePerHourStr = fmt.Sprintf("$%s", formatFloat(*pricePerHour)) + if instanceTypeInfo.SpotPrice != nil { + spotPricePerHourStr = fmt.Sprintf("$%s", formatFloat(*instanceTypeInfo.SpotPrice)) } - fmt.Fprintf(w, "\n%s\t%d\t%s\t%s\t%t\t%t\t%s\t%s\t%d\t%d\t%s\t%s\t%s\t", + fmt.Fprintf(w, "\n%s\t%d\t%s\t%s\t%t\t%t\t%s\t%s\t%d\t%d\t%s\t%s\t%s\t%s\t", *instanceTypeInfo.InstanceType, *instanceTypeInfo.VCpuInfo.DefaultVCpus, formatFloat(float64(*instanceTypeInfo.MemoryInfo.SizeInMiB)/1024.0), @@ -272,7 +270,8 @@ func TableOutputWide(instanceTypeInfoSlice []instancetypes.Details) []string { gpus, formatFloat(float64(gpuMemory)/1024.0), strings.Join(gpuType, ", "), - pricePerHourStr, + onDemandPricePerHourStr, + spotPricePerHourStr, ) } w.Flush() diff --git a/pkg/selector/selector.go b/pkg/selector/selector.go index d4d4fbb..247b7bf 100644 --- a/pkg/selector/selector.go +++ b/pkg/selector/selector.go @@ -157,7 +157,7 @@ func (itf Selector) rawFilter(filters Filters) ([]instancetypes.Details, error) if err != nil { return nil, err } - var locations []string + var locations, availabilityZones []string if filters.CPUArchitecture != nil && *filters.CPUArchitecture == cpuArchitectureAMD64 { *filters.CPUArchitecture = cpuArchitectureX8664 @@ -165,8 +165,8 @@ func (itf Selector) rawFilter(filters Filters) ([]instancetypes.Details, error) if filters.VirtualizationType != nil && *filters.VirtualizationType == virtualizationTypePV { *filters.VirtualizationType = virtualizationTypeParaVirtual } - if filters.AvailabilityZones != nil { + availabilityZones = *filters.AvailabilityZones locations = *filters.AvailabilityZones } else if filters.Region != nil { locations = []string{*filters.Region} @@ -186,24 +186,34 @@ func (itf Selector) rawFilter(filters Filters) ([]instancetypes.Details, error) instanceTypeName := *instanceTypeInfo.InstanceType instanceTypeCandidates[instanceTypeName] = &instancetypes.Details{InstanceTypeInfo: *instanceTypeInfo} isFpga := instanceTypeInfo.FpgaInfo != nil - instanceTypeHourlyPrice := float64(0.0) - if filters.PricePerHour != nil { - if filters.UsageClass != nil && *filters.UsageClass == "spot" { - azs := []string{} - if filters.AvailabilityZones != nil { - azs = *filters.AvailabilityZones - } - instanceTypeHourlyPrice, err = itf.EC2Pricing.GetSpotInstanceTypeNDayAvgCost(instanceTypeName, azs, 30) - if err != nil { - fmt.Printf("Could not retrieve 30 day avg spot price for instance type %s\n", instanceTypeName) - } - instanceTypeCandidates[instanceTypeName].SpotPrice = &instanceTypeHourlyPrice + var instanceTypeHourlyPriceForFilter float64 // Price used to filter based on usage class + var instanceTypeHourlyPriceOnDemand, instanceTypeHourlyPriceSpot *float64 + // If prices are fetched, populate the fields irrespective of the price filters + if itf.EC2Pricing.LastOnDemandCacheUTC() != nil { + price, err := itf.EC2Pricing.GetOndemandInstanceTypeCost(instanceTypeName) + if err != nil { + fmt.Printf("Could not retrieve instantaneous hourly on-demand price for instance type %s\n", instanceTypeName) + } else { + instanceTypeHourlyPriceOnDemand = &price + instanceTypeCandidates[instanceTypeName].OndemandPricePerHour = instanceTypeHourlyPriceOnDemand + } + } + if itf.EC2Pricing.LastSpotCacheUTC() != nil { + price, err := itf.EC2Pricing.GetSpotInstanceTypeNDayAvgCost(instanceTypeName, availabilityZones, 30) + if err != nil { + fmt.Printf("Could not retrieve 30 day avg hourly spot price for instance type %s\n", instanceTypeName) } else { - instanceTypeHourlyPrice, err = itf.EC2Pricing.GetOndemandInstanceTypeCost(instanceTypeName) - if err != nil { - fmt.Printf("Could not retrieve hourly price for instance type %s\n", instanceTypeName) - } - instanceTypeCandidates[instanceTypeName].OndemandPricePerHour = &instanceTypeHourlyPrice + instanceTypeHourlyPriceSpot = &price + instanceTypeCandidates[instanceTypeName].SpotPrice = instanceTypeHourlyPriceSpot + } + } + if filters.PricePerHour != nil { + // If price filter is present, prices should be already fetched + // If prices are not fetched, filter should fail and the corresponding error is already printed + if filters.UsageClass != nil && *filters.UsageClass == "spot" && instanceTypeHourlyPriceSpot != nil { + instanceTypeHourlyPriceForFilter = *instanceTypeHourlyPriceSpot + } else if instanceTypeHourlyPriceOnDemand != nil { + instanceTypeHourlyPriceForFilter = *instanceTypeHourlyPriceOnDemand } } @@ -231,7 +241,7 @@ func (itf Selector) rawFilter(filters Filters) ([]instancetypes.Details, error) networkPerformance: {filters.NetworkPerformance, getNetworkPerformance(instanceTypeInfo.NetworkInfo.NetworkPerformance)}, instanceTypes: {filters.InstanceTypes, instanceTypeInfo.InstanceType}, virtualizationType: {filters.VirtualizationType, instanceTypeInfo.SupportedVirtualizationTypes}, - pricePerHour: {filters.PricePerHour, &instanceTypeHourlyPrice}, + pricePerHour: {filters.PricePerHour, &instanceTypeHourlyPriceForFilter}, } if isInDenyList(filters.DenyList, instanceTypeName) || !isInAllowList(filters.AllowList, instanceTypeName) { diff --git a/pkg/selector/selector_test.go b/pkg/selector/selector_test.go index 2ae5104..28ea980 100644 --- a/pkg/selector/selector_test.go +++ b/pkg/selector/selector_test.go @@ -21,6 +21,7 @@ import ( "regexp" "strconv" "testing" + "time" "github.com/aws/amazon-ec2-instance-selector/v2/pkg/bytequantity" "github.com/aws/amazon-ec2-instance-selector/v2/pkg/selector" @@ -148,7 +149,8 @@ func TestNew(t *testing.T) { func TestFilterVerbose(t *testing.T) { ec2Mock := setupMock(t, describeInstanceTypesPages, "t3_micro.json") itf := selector.Selector{ - EC2: ec2Mock, + EC2: ec2Mock, + EC2Pricing: &ec2PricingMock{}, } filters := selector.Filters{ VCpusRange: &selector.IntRangeFilter{LowerBound: 2, UpperBound: 2}, @@ -162,7 +164,8 @@ func TestFilterVerbose(t *testing.T) { func TestFilterVerbose_NoResults(t *testing.T) { ec2Mock := setupMock(t, describeInstanceTypesPages, "t3_micro.json") itf := selector.Selector{ - EC2: ec2Mock, + EC2: ec2Mock, + EC2Pricing: &ec2PricingMock{}, } filters := selector.Filters{ VCpusRange: &selector.IntRangeFilter{LowerBound: 4, UpperBound: 4}, @@ -194,7 +197,8 @@ func TestFilterVerbose_AZFilteredIn(t *testing.T) { DescribeAvailabilityZonesResp: setupMock(t, describeAvailabilityZones, "us-east-2.json").DescribeAvailabilityZonesResp, } itf := selector.Selector{ - EC2: ec2Mock, + EC2: ec2Mock, + EC2Pricing: &ec2PricingMock{}, } filters := selector.Filters{ VCpusRange: &selector.IntRangeFilter{LowerBound: 2, UpperBound: 2}, @@ -213,7 +217,8 @@ func TestFilterVerbose_AZFilteredOut(t *testing.T) { DescribeAvailabilityZonesResp: setupMock(t, describeAvailabilityZones, "us-east-2.json").DescribeAvailabilityZonesResp, } itf := selector.Selector{ - EC2: ec2Mock, + EC2: ec2Mock, + EC2Pricing: &ec2PricingMock{}, } filters := selector.Filters{ AvailabilityZones: &[]string{"us-east-2a"}, @@ -239,7 +244,8 @@ func TestFilterVerboseAZ_FilteredErr(t *testing.T) { func TestFilterVerbose_Gpus(t *testing.T) { ec2Mock := setupMock(t, describeInstanceTypesPages, "t3_micro_and_p3_16xl.json") itf := selector.Selector{ - EC2: ec2Mock, + EC2: ec2Mock, + EC2Pricing: &ec2PricingMock{}, } gpuMemory, err := bytequantity.ParseToByteQuantity("128g") h.Ok(t, err) @@ -259,7 +265,8 @@ func TestFilterVerbose_Gpus(t *testing.T) { func TestFilter(t *testing.T) { ec2Mock := setupMock(t, describeInstanceTypesPages, "t3_micro.json") itf := selector.Selector{ - EC2: ec2Mock, + EC2: ec2Mock, + EC2Pricing: &ec2PricingMock{}, } filters := selector.Filters{ VCpusRange: &selector.IntRangeFilter{LowerBound: 2, UpperBound: 2}, @@ -273,7 +280,8 @@ func TestFilter(t *testing.T) { func TestFilter_MoreFilters(t *testing.T) { ec2Mock := setupMock(t, describeInstanceTypesPages, "t3_micro.json") itf := selector.Selector{ - EC2: ec2Mock, + EC2: ec2Mock, + EC2Pricing: &ec2PricingMock{}, } filters := selector.Filters{ VCpusRange: &selector.IntRangeFilter{LowerBound: 2, UpperBound: 2}, @@ -291,7 +299,8 @@ func TestFilter_MoreFilters(t *testing.T) { func TestFilter_TruncateToMaxResults(t *testing.T) { ec2Mock := setupMock(t, describeInstanceTypesPages, "25_instances.json") itf := selector.Selector{ - EC2: ec2Mock, + EC2: ec2Mock, + EC2Pricing: &ec2PricingMock{}, } filters := selector.Filters{ VCpusRange: &selector.IntRangeFilter{LowerBound: 0, UpperBound: 100}, @@ -323,7 +332,8 @@ func TestFilter_Failure(t *testing.T) { } itf := selector.Selector{ - EC2: ec2Mock, + EC2: ec2Mock, + EC2Pricing: &ec2PricingMock{}, } filters := selector.Filters{ VCpusRange: &selector.IntRangeFilter{LowerBound: 4, UpperBound: 4}, @@ -432,7 +442,8 @@ func TestFilter_InstanceTypeBase(t *testing.T) { DescribeInstanceTypeOfferingsResp: setupMock(t, describeInstanceTypeOfferings, "us-east-2a.json").DescribeInstanceTypeOfferingsResp, } itf := selector.Selector{ - EC2: ec2Mock, + EC2: ec2Mock, + EC2Pricing: &ec2PricingMock{}, } c4Large := "c4.large" filters := selector.Filters{ @@ -507,7 +518,8 @@ func TestFilter_AllowList(t *testing.T) { DescribeInstanceTypeOfferingsResp: setupMock(t, describeInstanceTypeOfferings, "us-east-2a.json").DescribeInstanceTypeOfferingsResp, } itf := selector.Selector{ - EC2: ec2Mock, + EC2: ec2Mock, + EC2Pricing: &ec2PricingMock{}, } allowRegex, err := regexp.Compile("c4.large") h.Ok(t, err) @@ -525,7 +537,8 @@ func TestFilter_DenyList(t *testing.T) { DescribeInstanceTypeOfferingsResp: setupMock(t, describeInstanceTypeOfferings, "us-east-2a.json").DescribeInstanceTypeOfferingsResp, } itf := selector.Selector{ - EC2: ec2Mock, + EC2: ec2Mock, + EC2Pricing: &ec2PricingMock{}, } denyRegex, err := regexp.Compile("c4.large") h.Ok(t, err) @@ -543,7 +556,8 @@ func TestFilter_AllowAndDenyList(t *testing.T) { DescribeInstanceTypeOfferingsResp: setupMock(t, describeInstanceTypeOfferings, "us-east-2a.json").DescribeInstanceTypeOfferingsResp, } itf := selector.Selector{ - EC2: ec2Mock, + EC2: ec2Mock, + EC2Pricing: &ec2PricingMock{}, } allowRegex, err := regexp.Compile("c4.*") h.Ok(t, err) @@ -561,7 +575,8 @@ func TestFilter_AllowAndDenyList(t *testing.T) { func TestFilter_X8664_AMD64(t *testing.T) { ec2Mock := setupMock(t, describeInstanceTypesPages, "t3_micro.json") itf := selector.Selector{ - EC2: ec2Mock, + EC2: ec2Mock, + EC2Pricing: &ec2PricingMock{}, } filters := selector.Filters{ CPUArchitecture: aws.String("amd64"), @@ -575,7 +590,8 @@ func TestFilter_X8664_AMD64(t *testing.T) { func TestFilter_VirtType_PV(t *testing.T) { ec2Mock := setupMock(t, describeInstanceTypesPages, "pv_instances.json") itf := selector.Selector{ - EC2: ec2Mock, + EC2: ec2Mock, + EC2Pricing: &ec2PricingMock{}, } filters := selector.Filters{ VirtualizationType: aws.String("pv"), @@ -599,6 +615,8 @@ type ec2PricingMock struct { GetSpotInstanceTypeNDayAvgCostErr error HydrateOndemandCacheErr error HydrateSpotCacheErr error + lastOnDemandCacheUTC *time.Time + lastSpotCacheUTC *time.Time } func (p *ec2PricingMock) GetOndemandInstanceTypeCost(instanceType string) (float64, error) { @@ -617,12 +635,22 @@ func (p *ec2PricingMock) HydrateSpotCache(days int) error { return p.HydrateSpotCacheErr } +func (p *ec2PricingMock) LastOnDemandCacheUTC() *time.Time { + return p.lastOnDemandCacheUTC +} + +func (p *ec2PricingMock) LastSpotCacheUTC() *time.Time { + return p.lastSpotCacheUTC +} + func TestFilter_PricePerHour(t *testing.T) { ec2Mock := setupMock(t, describeInstanceTypesPages, "t3_micro.json") + now := time.Now() itf := selector.Selector{ EC2: ec2Mock, EC2Pricing: &ec2PricingMock{ GetOndemandInstanceTypeCostResp: 0.0104, + lastOnDemandCacheUTC: &now, }, } filters := selector.Filters{ @@ -633,15 +661,17 @@ func TestFilter_PricePerHour(t *testing.T) { } results, err := itf.Filter(filters) h.Ok(t, err) - h.Assert(t, len(results) == 1, "Should return 1 instance type") + h.Assert(t, len(results) == 1, fmt.Sprintf("Should return 1 instance type; got %d", len(results))) } func TestFilter_PricePerHour_NoResults(t *testing.T) { ec2Mock := setupMock(t, describeInstanceTypesPages, "t3_micro.json") + now := time.Now() itf := selector.Selector{ EC2: ec2Mock, EC2Pricing: &ec2PricingMock{ GetOndemandInstanceTypeCostResp: 0.0104, + lastOnDemandCacheUTC: &now, }, } filters := selector.Filters{ @@ -657,10 +687,12 @@ func TestFilter_PricePerHour_NoResults(t *testing.T) { func TestFilter_PricePerHour_OD(t *testing.T) { ec2Mock := setupMock(t, describeInstanceTypesPages, "t3_micro.json") + now := time.Now() itf := selector.Selector{ EC2: ec2Mock, EC2Pricing: &ec2PricingMock{ GetOndemandInstanceTypeCostResp: 0.0104, + lastOnDemandCacheUTC: &now, }, } filters := selector.Filters{ @@ -672,15 +704,17 @@ func TestFilter_PricePerHour_OD(t *testing.T) { } results, err := itf.Filter(filters) h.Ok(t, err) - h.Assert(t, len(results) == 1, "Should return 1 instance type") + h.Assert(t, len(results) == 1, fmt.Sprintf("Should return 1 instance type; got %d", len(results))) } func TestFilter_PricePerHour_Spot(t *testing.T) { ec2Mock := setupMock(t, describeInstanceTypesPages, "t3_micro.json") + now := time.Now() itf := selector.Selector{ EC2: ec2Mock, EC2Pricing: &ec2PricingMock{ GetSpotInstanceTypeNDayAvgCostResp: 0.0104, + lastSpotCacheUTC: &now, }, } filters := selector.Filters{ @@ -692,5 +726,5 @@ func TestFilter_PricePerHour_Spot(t *testing.T) { } results, err := itf.Filter(filters) h.Ok(t, err) - h.Assert(t, len(results) == 1, "Should return 1 instance type") + h.Assert(t, len(results) == 1, fmt.Sprintf("Should return 1 instance type; got %d", len(results))) }