Skip to content

Commit 50774cd

Browse files
committed
managed identity bug fixes
1 parent 390f9a3 commit 50774cd

File tree

3 files changed

+106
-22
lines changed

3 files changed

+106
-22
lines changed

sdk/azidentity/CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,9 @@
1111
* Removed `AzurePipelinesCredential` and the persistent token caching API.
1212
They will return in v1.7.0-beta.1
1313

14+
### Bugs Fixed
15+
* Managed identity bug fixes
16+
1417
## 1.6.0-beta.4 (2024-05-14)
1518

1619
### Features Added

sdk/azidentity/managed_identity_client.go

Lines changed: 49 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,15 @@ import (
1414
"net/http"
1515
"net/url"
1616
"os"
17+
"path/filepath"
18+
"runtime"
1719
"strconv"
1820
"strings"
1921
"time"
2022

2123
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
2224
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
23-
"github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime"
25+
azruntime "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime"
2426
"github.com/Azure/azure-sdk-for-go/sdk/azcore/streaming"
2527
"github.com/Azure/azure-sdk-for-go/sdk/internal/log"
2628
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/confidential"
@@ -65,6 +67,18 @@ type managedIdentityClient struct {
6567
probeIMDS bool
6668
}
6769

70+
// arcKeyDirectory returns the directory expected to contain Azure Arc keys
71+
var arcKeyDirectory = func() (string, error) {
72+
switch runtime.GOOS {
73+
case "linux":
74+
return "/var/opt/azcmagent/tokens", nil
75+
case "windows":
76+
return filepath.Join(os.Getenv("ProgramData"), "AzureConnectedMachineAgent", "Tokens"), nil
77+
default:
78+
return "", fmt.Errorf("unsupported OS %q", runtime.GOOS)
79+
}
80+
}
81+
6882
type wrappedNumber json.Number
6983

7084
func (n *wrappedNumber) UnmarshalJSON(b []byte) error {
@@ -152,8 +166,8 @@ func newManagedIdentityClient(options *ManagedIdentityCredentialOptions) (*manag
152166
setIMDSRetryOptionDefaults(&cp.Retry)
153167
}
154168

155-
client, err := azcore.NewClient(module, version, runtime.PipelineOptions{
156-
Tracing: runtime.TracingOptions{
169+
client, err := azcore.NewClient(module, version, azruntime.PipelineOptions{
170+
Tracing: azruntime.TracingOptions{
157171
Namespace: traceNamespace,
158172
},
159173
}, &cp)
@@ -188,7 +202,7 @@ func (c *managedIdentityClient) authenticate(ctx context.Context, id ManagedIDKi
188202
cx, cancel := context.WithTimeout(ctx, imdsProbeTimeout)
189203
defer cancel()
190204
cx = policy.WithRetryOptions(cx, policy.RetryOptions{MaxRetries: -1})
191-
req, err := runtime.NewRequest(cx, http.MethodGet, c.endpoint)
205+
req, err := azruntime.NewRequest(cx, http.MethodGet, c.endpoint)
192206
if err == nil {
193207
_, err = c.azClient.Pipeline().Do(req)
194208
}
@@ -213,7 +227,7 @@ func (c *managedIdentityClient) authenticate(ctx context.Context, id ManagedIDKi
213227
return azcore.AccessToken{}, newAuthenticationFailedError(credNameManagedIdentity, err.Error(), nil, err)
214228
}
215229

216-
if runtime.HasStatusCode(resp, http.StatusOK, http.StatusCreated) {
230+
if azruntime.HasStatusCode(resp, http.StatusOK, http.StatusCreated) {
217231
return c.createAccessToken(resp)
218232
}
219233

@@ -224,14 +238,14 @@ func (c *managedIdentityClient) authenticate(ctx context.Context, id ManagedIDKi
224238
return azcore.AccessToken{}, newAuthenticationFailedError(credNameManagedIdentity, "the requested identity isn't assigned to this resource", resp, nil)
225239
}
226240
msg := "failed to authenticate a system assigned identity"
227-
if body, err := runtime.Payload(resp); err == nil && len(body) > 0 {
241+
if body, err := azruntime.Payload(resp); err == nil && len(body) > 0 {
228242
msg += fmt.Sprintf(". The endpoint responded with %s", body)
229243
}
230244
return azcore.AccessToken{}, newCredentialUnavailableError(credNameManagedIdentity, msg)
231245
case http.StatusForbidden:
232246
// Docker Desktop runs a proxy that responds 403 to IMDS token requests. If we get that response,
233247
// we return credentialUnavailableError so credential chains continue to their next credential
234-
body, err := runtime.Payload(resp)
248+
body, err := azruntime.Payload(resp)
235249
if err == nil && strings.Contains(string(body), "unreachable") {
236250
return azcore.AccessToken{}, newCredentialUnavailableError(credNameManagedIdentity, fmt.Sprintf("unexpected response %q", string(body)))
237251
}
@@ -249,7 +263,7 @@ func (c *managedIdentityClient) createAccessToken(res *http.Response) (azcore.Ac
249263
ExpiresIn wrappedNumber `json:"expires_in,omitempty"` // this field should always return the number of seconds for which a token is valid
250264
ExpiresOn interface{} `json:"expires_on,omitempty"` // the value returned in this field varies between a number and a date string
251265
}{}
252-
if err := runtime.UnmarshalAsJSON(res, &value); err != nil {
266+
if err := azruntime.UnmarshalAsJSON(res, &value); err != nil {
253267
return azcore.AccessToken{}, fmt.Errorf("internal AccessToken: %v", err)
254268
}
255269
if value.ExpiresIn != "" {
@@ -299,7 +313,7 @@ func (c *managedIdentityClient) createAuthRequest(ctx context.Context, id Manage
299313
}
300314

301315
func (c *managedIdentityClient) createIMDSAuthRequest(ctx context.Context, id ManagedIDKind, scopes []string) (*policy.Request, error) {
302-
request, err := runtime.NewRequest(ctx, http.MethodGet, c.endpoint)
316+
request, err := azruntime.NewRequest(ctx, http.MethodGet, c.endpoint)
303317
if err != nil {
304318
return nil, err
305319
}
@@ -319,7 +333,7 @@ func (c *managedIdentityClient) createIMDSAuthRequest(ctx context.Context, id Ma
319333
}
320334

321335
func (c *managedIdentityClient) createAppServiceAuthRequest(ctx context.Context, id ManagedIDKind, scopes []string) (*policy.Request, error) {
322-
request, err := runtime.NewRequest(ctx, http.MethodGet, c.endpoint)
336+
request, err := azruntime.NewRequest(ctx, http.MethodGet, c.endpoint)
323337
if err != nil {
324338
return nil, err
325339
}
@@ -339,7 +353,7 @@ func (c *managedIdentityClient) createAppServiceAuthRequest(ctx context.Context,
339353
}
340354

341355
func (c *managedIdentityClient) createAzureMLAuthRequest(ctx context.Context, id ManagedIDKind, scopes []string) (*policy.Request, error) {
342-
request, err := runtime.NewRequest(ctx, http.MethodGet, c.endpoint)
356+
request, err := azruntime.NewRequest(ctx, http.MethodGet, c.endpoint)
343357
if err != nil {
344358
return nil, err
345359
}
@@ -362,7 +376,7 @@ func (c *managedIdentityClient) createAzureMLAuthRequest(ctx context.Context, id
362376
}
363377

364378
func (c *managedIdentityClient) createServiceFabricAuthRequest(ctx context.Context, id ManagedIDKind, scopes []string) (*policy.Request, error) {
365-
request, err := runtime.NewRequest(ctx, http.MethodGet, c.endpoint)
379+
request, err := azruntime.NewRequest(ctx, http.MethodGet, c.endpoint)
366380
if err != nil {
367381
return nil, err
368382
}
@@ -385,7 +399,7 @@ func (c *managedIdentityClient) createServiceFabricAuthRequest(ctx context.Conte
385399

386400
func (c *managedIdentityClient) getAzureArcSecretKey(ctx context.Context, resources []string) (string, error) {
387401
// create the request to retreive the secret key challenge provided by the HIMDS service
388-
request, err := runtime.NewRequest(ctx, http.MethodGet, c.endpoint)
402+
request, err := azruntime.NewRequest(ctx, http.MethodGet, c.endpoint)
389403
if err != nil {
390404
return "", err
391405
}
@@ -407,22 +421,36 @@ func (c *managedIdentityClient) getAzureArcSecretKey(ctx context.Context, resour
407421
}
408422
header := response.Header.Get("WWW-Authenticate")
409423
if len(header) == 0 {
410-
return "", errors.New("did not receive a value from WWW-Authenticate header")
424+
return "", errors.New("response has no WWW-Authenticate header")
411425
}
412426
// the WWW-Authenticate header is expected in the following format: Basic realm=/some/file/path.key
413-
pos := strings.LastIndex(header, "=")
414-
if pos == -1 {
415-
return "", fmt.Errorf("did not receive a correct value from WWW-Authenticate header: %s", header)
427+
_, p, found := strings.Cut(header, "=")
428+
if !found {
429+
return "", fmt.Errorf("unexpected WWW-Authenticate header: %s", header)
430+
}
431+
expected, err := arcKeyDirectory()
432+
if err != nil {
433+
return "", err
434+
}
435+
if filepath.Dir(p) != expected || !strings.HasSuffix(p, ".key") {
436+
return "", fmt.Errorf("unexpected file path from HIMDS service: %s", p)
437+
}
438+
f, err := os.Stat(p)
439+
if err != nil {
440+
return "", fmt.Errorf("could not stat %q: %v", p, err)
441+
}
442+
if s := f.Size(); s > 4096 {
443+
return "", fmt.Errorf("key is too large (%d bytes)", s)
416444
}
417-
key, err := os.ReadFile(header[pos+1:])
445+
key, err := os.ReadFile(p)
418446
if err != nil {
419-
return "", fmt.Errorf("could not read file (%s) contents: %v", header[pos+1:], err)
447+
return "", fmt.Errorf("could not read %q: %v", p, err)
420448
}
421449
return string(key), nil
422450
}
423451

424452
func (c *managedIdentityClient) createAzureArcAuthRequest(ctx context.Context, id ManagedIDKind, resources []string, key string) (*policy.Request, error) {
425-
request, err := runtime.NewRequest(ctx, http.MethodGet, c.endpoint)
453+
request, err := azruntime.NewRequest(ctx, http.MethodGet, c.endpoint)
426454
if err != nil {
427455
return nil, err
428456
}
@@ -444,7 +472,7 @@ func (c *managedIdentityClient) createAzureArcAuthRequest(ctx context.Context, i
444472
}
445473

446474
func (c *managedIdentityClient) createCloudShellAuthRequest(ctx context.Context, id ManagedIDKind, scopes []string) (*policy.Request, error) {
447-
request, err := runtime.NewRequest(ctx, http.MethodPost, c.endpoint)
475+
request, err := azruntime.NewRequest(ctx, http.MethodPost, c.endpoint)
448476
if err != nil {
449477
return nil, err
450478
}

sdk/azidentity/managed_identity_credential_test.go

Lines changed: 54 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
package azidentity
88

99
import (
10+
"bytes"
1011
"context"
1112
"fmt"
1213
"net/http"
@@ -32,7 +33,11 @@ const (
3233
)
3334

3435
func TestManagedIdentityCredential_AzureArc(t *testing.T) {
35-
file, err := os.Create(filepath.Join(t.TempDir(), "arc.key"))
36+
d := t.TempDir()
37+
before := arcKeyDirectory
38+
arcKeyDirectory = func() (string, error) { return d, nil }
39+
defer func() { arcKeyDirectory = before }()
40+
file, err := os.Create(filepath.Join(d, "arc.key"))
3641
if err != nil {
3742
t.Fatal(err)
3843
}
@@ -150,6 +155,54 @@ func TestManagedIdentityCredential_AzureArcErrors(t *testing.T) {
150155
t.Fatal("expected an error")
151156
}
152157
})
158+
t.Run("key too large", func(t *testing.T) {
159+
d := t.TempDir()
160+
f := filepath.Join(d, "test.key")
161+
err := os.WriteFile(f, bytes.Repeat([]byte("."), 4097), 0600)
162+
require.NoError(t, err)
163+
before := arcKeyDirectory
164+
arcKeyDirectory = func() (string, error) { return d, nil }
165+
defer func() { arcKeyDirectory = before }()
166+
srv, close := mock.NewServer(mock.WithTransformAllRequestsToTestServerUrl())
167+
defer close()
168+
srv.AppendResponse(
169+
mock.WithHeader("WWW-Authenticate", "Basic realm="+f),
170+
mock.WithStatusCode(http.StatusUnauthorized),
171+
)
172+
cred, err := NewManagedIdentityCredential(&ManagedIdentityCredentialOptions{ClientOptions: azcore.ClientOptions{Transport: srv}})
173+
require.NoError(t, err)
174+
_, err = cred.GetToken(context.Background(), testTRO)
175+
require.ErrorContains(t, err, "too large")
176+
})
177+
t.Run("unexpected file paths", func(t *testing.T) {
178+
d, err := arcKeyDirectory()
179+
if err != nil {
180+
// test is running on an unsupported OS e.g. darwin
181+
t.Skip(err)
182+
}
183+
srv, close := mock.NewServer(mock.WithTransformAllRequestsToTestServerUrl())
184+
defer close()
185+
srv.AppendResponse(
186+
// unexpected directory
187+
mock.WithHeader("WWW-Authenticate", "Basic realm="+filepath.Join("foo", "bar.key")),
188+
mock.WithStatusCode(http.StatusUnauthorized),
189+
)
190+
o := ManagedIdentityCredentialOptions{ClientOptions: azcore.ClientOptions{Transport: srv}}
191+
cred, err := NewManagedIdentityCredential(&o)
192+
require.NoError(t, err)
193+
_, err = cred.GetToken(context.Background(), testTRO)
194+
require.ErrorContains(t, err, "unexpected file path")
195+
196+
srv.AppendResponse(
197+
// unexpected extension
198+
mock.WithHeader("WWW-Authenticate", "Basic realm="+filepath.Join(d, "foo")),
199+
mock.WithStatusCode(http.StatusUnauthorized),
200+
)
201+
cred, err = NewManagedIdentityCredential(&o)
202+
require.NoError(t, err)
203+
_, err = cred.GetToken(context.Background(), testTRO)
204+
require.ErrorContains(t, err, "unexpected file path")
205+
})
153206
}
154207

155208
func TestManagedIdentityCredential_AzureContainerInstanceLive(t *testing.T) {

0 commit comments

Comments
 (0)