@@ -748,20 +748,13 @@ def test_deprecated_mmds_config(uvm_plain):
748748 )
749749
750750
751- @pytest .mark .parametrize ("version" , MMDS_VERSIONS )
752- @pytest .mark .parametrize ("imds_compat" , [None , False , True ])
753- @pytest .mark .parametrize ("sdk" , ["py" , "go" ])
754- def test_aws_credential_provider (uvm_plain , version , imds_compat , sdk ):
755- """
756- Test AWS SDK's credential provider works on MMDS
757- """
758- test_microvm = uvm_plain
759- test_microvm .spawn ()
760- test_microvm .basic_config ()
761- test_microvm .add_net_iface ()
751+ def _configure_with_aws_credentials (microvm , version , imds_compat ):
752+ microvm .spawn ()
753+ microvm .basic_config ()
754+ microvm .add_net_iface ()
762755 # V2 requires session tokens for GET requests
763756 configure_mmds (
764- test_microvm , iface_ids = ["eth0" ], version = version , imds_compat = imds_compat
757+ microvm , iface_ids = ["eth0" ], version = version , imds_compat = imds_compat
765758 )
766759 now = datetime .now (timezone .utc )
767760 credentials = {
@@ -783,13 +776,24 @@ def test_aws_credential_provider(uvm_plain, version, imds_compat, sdk):
783776 }
784777 }
785778 }
786- populate_data_store (test_microvm , data_store )
787- test_microvm .start ()
788-
789- ssh_connection = test_microvm .ssh
779+ populate_data_store (microvm , data_store )
780+ microvm .start ()
790781
782+ ssh_connection = microvm .ssh
791783 run_guest_cmd (ssh_connection , f"ip route add { DEFAULT_IPV4 } dev eth0" , "" )
792784
785+ return ssh_connection
786+
787+
788+ @pytest .mark .parametrize ("version" , MMDS_VERSIONS )
789+ @pytest .mark .parametrize ("imds_compat" , [None , False , True ])
790+ @pytest .mark .parametrize ("sdk" , ["py" , "go" ])
791+ def test_aws_credential_provider (uvm_plain , version , imds_compat , sdk ):
792+ """
793+ Test AWS SDK's credential provider works on MMDS
794+ """
795+ ssh_connection = _configure_with_aws_credentials (uvm_plain , version , imds_compat )
796+
793797 match sdk :
794798 case "py" :
795799 cmd = r"""python3 - <<EOF
@@ -813,5 +817,38 @@ def test_aws_credential_provider(uvm_plain, version, imds_compat, sdk):
813817"""
814818 case "go" :
815819 cmd = "/usr/local/bin/go_sdk_cred_provider"
816- _ , stdout , stderr = ssh_connection .check_output (cmd )
820+ ret , stdout , stderr = ssh_connection .check_output (cmd )
821+ assert ret == 0
817822 assert stdout == "AAA,BBB,CCC\n " , stderr
823+
824+
825+ @pytest .mark .parametrize ("version" , MMDS_VERSIONS )
826+ @pytest .mark .parametrize ("imds_compat" , [None , False , True ])
827+ def test_go_sdk_credential_provider_with_custom_endpoint (
828+ uvm_plain , version , imds_compat
829+ ):
830+ """
831+ Test AWS SDK's credential provider with custom endpoint.
832+
833+ It sets "Accept: application/json" in a request to retrieve AWS credentials.
834+ If imds_compat is True, it should work. If False, it should NOT work,
835+ because MMDS responds a string of a JSON object containing the credentials
836+ (i.e. wrapped with doublequotes) with "Content-Type: application/json" but
837+ AWS SDK for Go expects only the inner JSON object.
838+ """
839+ ssh_connection = _configure_with_aws_credentials (uvm_plain , version , imds_compat )
840+
841+ cmd = "/usr/local/bin/go_sdk_cred_provider_with_custom_endpoint"
842+ ret , stdout , stderr = ssh_connection .run (cmd )
843+ if imds_compat :
844+ assert ret == 0
845+ assert stdout == "AAA,BBB,CCC\n " , stderr
846+ else :
847+ assert ret == 1
848+ assert (
849+ "Unable to retrieve credentials: "
850+ "failed to refresh cached credentials, "
851+ "failed to load credentials, deserialization failed, "
852+ "failed to deserialize json response, "
853+ "json: cannot unmarshal string into Go value of type client.GetCredentialsOutput"
854+ ) in stderr
0 commit comments