@@ -14,13 +14,15 @@ import (
14
14
"net/http"
15
15
"net/url"
16
16
"os"
17
+ "path/filepath"
18
+ "runtime"
17
19
"strconv"
18
20
"strings"
19
21
"time"
20
22
21
23
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
22
24
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
23
- "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime"
25
+ azruntime "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime"
24
26
"github.com/Azure/azure-sdk-for-go/sdk/azcore/streaming"
25
27
"github.com/Azure/azure-sdk-for-go/sdk/internal/log"
26
28
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/confidential"
@@ -65,6 +67,18 @@ type managedIdentityClient struct {
65
67
probeIMDS bool
66
68
}
67
69
70
+ // arcKeyDirectory returns the directory expected to contain Azure Arc keys
71
+ var arcKeyDirectory = func () (string , error ) {
72
+ switch runtime .GOOS {
73
+ case "linux" :
74
+ return "/var/opt/azcmagent/tokens" , nil
75
+ case "windows" :
76
+ return filepath .Join (os .Getenv ("ProgramData" ), "AzureConnectedMachineAgent" , "Tokens" ), nil
77
+ default :
78
+ return "" , fmt .Errorf ("unsupported OS %q" , runtime .GOOS )
79
+ }
80
+ }
81
+
68
82
type wrappedNumber json.Number
69
83
70
84
func (n * wrappedNumber ) UnmarshalJSON (b []byte ) error {
@@ -152,8 +166,8 @@ func newManagedIdentityClient(options *ManagedIdentityCredentialOptions) (*manag
152
166
setIMDSRetryOptionDefaults (& cp .Retry )
153
167
}
154
168
155
- client , err := azcore .NewClient (module , version , runtime .PipelineOptions {
156
- Tracing : runtime .TracingOptions {
169
+ client , err := azcore .NewClient (module , version , azruntime .PipelineOptions {
170
+ Tracing : azruntime .TracingOptions {
157
171
Namespace : traceNamespace ,
158
172
},
159
173
}, & cp )
@@ -188,7 +202,7 @@ func (c *managedIdentityClient) authenticate(ctx context.Context, id ManagedIDKi
188
202
cx , cancel := context .WithTimeout (ctx , imdsProbeTimeout )
189
203
defer cancel ()
190
204
cx = policy .WithRetryOptions (cx , policy.RetryOptions {MaxRetries : - 1 })
191
- req , err := runtime .NewRequest (cx , http .MethodGet , c .endpoint )
205
+ req , err := azruntime .NewRequest (cx , http .MethodGet , c .endpoint )
192
206
if err == nil {
193
207
_ , err = c .azClient .Pipeline ().Do (req )
194
208
}
@@ -213,7 +227,7 @@ func (c *managedIdentityClient) authenticate(ctx context.Context, id ManagedIDKi
213
227
return azcore.AccessToken {}, newAuthenticationFailedError (credNameManagedIdentity , err .Error (), nil , err )
214
228
}
215
229
216
- if runtime .HasStatusCode (resp , http .StatusOK , http .StatusCreated ) {
230
+ if azruntime .HasStatusCode (resp , http .StatusOK , http .StatusCreated ) {
217
231
return c .createAccessToken (resp )
218
232
}
219
233
@@ -224,14 +238,14 @@ func (c *managedIdentityClient) authenticate(ctx context.Context, id ManagedIDKi
224
238
return azcore.AccessToken {}, newAuthenticationFailedError (credNameManagedIdentity , "the requested identity isn't assigned to this resource" , resp , nil )
225
239
}
226
240
msg := "failed to authenticate a system assigned identity"
227
- if body , err := runtime .Payload (resp ); err == nil && len (body ) > 0 {
241
+ if body , err := azruntime .Payload (resp ); err == nil && len (body ) > 0 {
228
242
msg += fmt .Sprintf (". The endpoint responded with %s" , body )
229
243
}
230
244
return azcore.AccessToken {}, newCredentialUnavailableError (credNameManagedIdentity , msg )
231
245
case http .StatusForbidden :
232
246
// Docker Desktop runs a proxy that responds 403 to IMDS token requests. If we get that response,
233
247
// we return credentialUnavailableError so credential chains continue to their next credential
234
- body , err := runtime .Payload (resp )
248
+ body , err := azruntime .Payload (resp )
235
249
if err == nil && strings .Contains (string (body ), "unreachable" ) {
236
250
return azcore.AccessToken {}, newCredentialUnavailableError (credNameManagedIdentity , fmt .Sprintf ("unexpected response %q" , string (body )))
237
251
}
@@ -249,7 +263,7 @@ func (c *managedIdentityClient) createAccessToken(res *http.Response) (azcore.Ac
249
263
ExpiresIn wrappedNumber `json:"expires_in,omitempty"` // this field should always return the number of seconds for which a token is valid
250
264
ExpiresOn interface {} `json:"expires_on,omitempty"` // the value returned in this field varies between a number and a date string
251
265
}{}
252
- if err := runtime .UnmarshalAsJSON (res , & value ); err != nil {
266
+ if err := azruntime .UnmarshalAsJSON (res , & value ); err != nil {
253
267
return azcore.AccessToken {}, fmt .Errorf ("internal AccessToken: %v" , err )
254
268
}
255
269
if value .ExpiresIn != "" {
@@ -299,7 +313,7 @@ func (c *managedIdentityClient) createAuthRequest(ctx context.Context, id Manage
299
313
}
300
314
301
315
func (c * managedIdentityClient ) createIMDSAuthRequest (ctx context.Context , id ManagedIDKind , scopes []string ) (* policy.Request , error ) {
302
- request , err := runtime .NewRequest (ctx , http .MethodGet , c .endpoint )
316
+ request , err := azruntime .NewRequest (ctx , http .MethodGet , c .endpoint )
303
317
if err != nil {
304
318
return nil , err
305
319
}
@@ -319,7 +333,7 @@ func (c *managedIdentityClient) createIMDSAuthRequest(ctx context.Context, id Ma
319
333
}
320
334
321
335
func (c * managedIdentityClient ) createAppServiceAuthRequest (ctx context.Context , id ManagedIDKind , scopes []string ) (* policy.Request , error ) {
322
- request , err := runtime .NewRequest (ctx , http .MethodGet , c .endpoint )
336
+ request , err := azruntime .NewRequest (ctx , http .MethodGet , c .endpoint )
323
337
if err != nil {
324
338
return nil , err
325
339
}
@@ -339,7 +353,7 @@ func (c *managedIdentityClient) createAppServiceAuthRequest(ctx context.Context,
339
353
}
340
354
341
355
func (c * managedIdentityClient ) createAzureMLAuthRequest (ctx context.Context , id ManagedIDKind , scopes []string ) (* policy.Request , error ) {
342
- request , err := runtime .NewRequest (ctx , http .MethodGet , c .endpoint )
356
+ request , err := azruntime .NewRequest (ctx , http .MethodGet , c .endpoint )
343
357
if err != nil {
344
358
return nil , err
345
359
}
@@ -362,7 +376,7 @@ func (c *managedIdentityClient) createAzureMLAuthRequest(ctx context.Context, id
362
376
}
363
377
364
378
func (c * managedIdentityClient ) createServiceFabricAuthRequest (ctx context.Context , id ManagedIDKind , scopes []string ) (* policy.Request , error ) {
365
- request , err := runtime .NewRequest (ctx , http .MethodGet , c .endpoint )
379
+ request , err := azruntime .NewRequest (ctx , http .MethodGet , c .endpoint )
366
380
if err != nil {
367
381
return nil , err
368
382
}
@@ -385,7 +399,7 @@ func (c *managedIdentityClient) createServiceFabricAuthRequest(ctx context.Conte
385
399
386
400
func (c * managedIdentityClient ) getAzureArcSecretKey (ctx context.Context , resources []string ) (string , error ) {
387
401
// create the request to retreive the secret key challenge provided by the HIMDS service
388
- request , err := runtime .NewRequest (ctx , http .MethodGet , c .endpoint )
402
+ request , err := azruntime .NewRequest (ctx , http .MethodGet , c .endpoint )
389
403
if err != nil {
390
404
return "" , err
391
405
}
@@ -407,22 +421,36 @@ func (c *managedIdentityClient) getAzureArcSecretKey(ctx context.Context, resour
407
421
}
408
422
header := response .Header .Get ("WWW-Authenticate" )
409
423
if len (header ) == 0 {
410
- return "" , errors .New ("did not receive a value from WWW-Authenticate header" )
424
+ return "" , errors .New ("response has no WWW-Authenticate header" )
411
425
}
412
426
// the WWW-Authenticate header is expected in the following format: Basic realm=/some/file/path.key
413
- pos := strings .LastIndex (header , "=" )
414
- if pos == - 1 {
415
- return "" , fmt .Errorf ("did not receive a correct value from WWW-Authenticate header: %s" , header )
427
+ _ , p , found := strings .Cut (header , "=" )
428
+ if ! found {
429
+ return "" , fmt .Errorf ("unexpected WWW-Authenticate header: %s" , header )
430
+ }
431
+ expected , err := arcKeyDirectory ()
432
+ if err != nil {
433
+ return "" , err
434
+ }
435
+ if filepath .Dir (p ) != expected || ! strings .HasSuffix (p , ".key" ) {
436
+ return "" , fmt .Errorf ("unexpected file path from HIMDS service: %s" , p )
437
+ }
438
+ f , err := os .Stat (p )
439
+ if err != nil {
440
+ return "" , fmt .Errorf ("could not stat %q: %v" , p , err )
441
+ }
442
+ if s := f .Size (); s > 4096 {
443
+ return "" , fmt .Errorf ("key is too large (%d bytes)" , s )
416
444
}
417
- key , err := os .ReadFile (header [ pos + 1 :] )
445
+ key , err := os .ReadFile (p )
418
446
if err != nil {
419
- return "" , fmt .Errorf ("could not read file (%s) contents : %v" , header [ pos + 1 :] , err )
447
+ return "" , fmt .Errorf ("could not read %q : %v" , p , err )
420
448
}
421
449
return string (key ), nil
422
450
}
423
451
424
452
func (c * managedIdentityClient ) createAzureArcAuthRequest (ctx context.Context , id ManagedIDKind , resources []string , key string ) (* policy.Request , error ) {
425
- request , err := runtime .NewRequest (ctx , http .MethodGet , c .endpoint )
453
+ request , err := azruntime .NewRequest (ctx , http .MethodGet , c .endpoint )
426
454
if err != nil {
427
455
return nil , err
428
456
}
@@ -444,7 +472,7 @@ func (c *managedIdentityClient) createAzureArcAuthRequest(ctx context.Context, i
444
472
}
445
473
446
474
func (c * managedIdentityClient ) createCloudShellAuthRequest (ctx context.Context , id ManagedIDKind , scopes []string ) (* policy.Request , error ) {
447
- request , err := runtime .NewRequest (ctx , http .MethodPost , c .endpoint )
475
+ request , err := azruntime .NewRequest (ctx , http .MethodPost , c .endpoint )
448
476
if err != nil {
449
477
return nil , err
450
478
}
0 commit comments