Skip to content

Commit 00c5fbd

Browse files
committed
feat: MDS connections use mTLS
1 parent f2708b2 commit 00c5fbd

File tree

5 files changed

+398
-34
lines changed

5 files changed

+398
-34
lines changed

google/auth/compute_engine/_metadata.py

Lines changed: 52 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@
3030
from google.auth import metrics
3131
from google.auth import transport
3232
from google.auth._exponential_backoff import ExponentialBackoff
33+
from google.auth.compute_engine import _mtls
34+
from google.auth.transport import requests
3335

3436
_LOGGER = logging.getLogger(__name__)
3537

@@ -42,13 +44,24 @@
4244
_GCE_METADATA_HOST = os.getenv(
4345
environment_vars.GCE_METADATA_ROOT, "metadata.google.internal"
4446
)
45-
_METADATA_ROOT = "http://{}/computeMetadata/v1/".format(_GCE_METADATA_HOST)
4647

47-
# This is used to ping the metadata server, it avoids the cost of a DNS
48-
# lookup.
49-
_METADATA_IP_ROOT = "http://{}".format(
50-
os.getenv(environment_vars.GCE_METADATA_IP, "169.254.169.254")
51-
)
48+
GCE_MDS_HOSTS = ["metadata.google.internal", "169.254.169.254"]
49+
50+
51+
def _get_metadata_root(use_mtls):
52+
"""Returns the metadata server root URL."""
53+
scheme = "https" if use_mtls else "http"
54+
return "{}://{}/computeMetadata/v1/".format(scheme, _GCE_METADATA_HOST)
55+
56+
57+
def _get_metadata_ip_root(use_mtls):
58+
"""Returns the metadata server IP root URL."""
59+
scheme = "https" if use_mtls else "http"
60+
return "{}://{}".format(
61+
scheme, os.getenv(environment_vars.GCE_METADATA_IP, "169.254.169.254")
62+
)
63+
64+
5265
_METADATA_FLAVOR_HEADER = "metadata-flavor"
5366
_METADATA_FLAVOR_VALUE = "Google"
5467
_METADATA_HEADERS = {_METADATA_FLAVOR_HEADER: _METADATA_FLAVOR_VALUE}
@@ -102,6 +115,24 @@ def detect_gce_residency_linux():
102115
return content.startswith(_GOOGLE)
103116

104117

118+
def _prepare_request_for_mds(request, use_mtls=False):
119+
"""Prepares a request for the metadata server.
120+
121+
This will check if mTLS should be used and return a new request object if so.
122+
123+
Args:
124+
request (google.auth.transport.Request): A callable used to make
125+
HTTP requests.
126+
127+
Returns:
128+
google.auth.transport.Request: Request
129+
object to use.
130+
"""
131+
if use_mtls:
132+
request = requests.Request(_mtls.create_session())
133+
return request
134+
135+
105136
def ping(request, timeout=_METADATA_DEFAULT_TIMEOUT, retry_count=3):
106137
"""Checks to see if the metadata server is available.
107138
@@ -115,6 +146,8 @@ def ping(request, timeout=_METADATA_DEFAULT_TIMEOUT, retry_count=3):
115146
Returns:
116147
bool: True if the metadata server is reachable, False otherwise.
117148
"""
149+
use_mtls = _mtls.should_use_mds_mtls()
150+
request = _prepare_request_for_mds(request, use_mtls=use_mtls)
118151
# NOTE: The explicit ``timeout`` is a workaround. The underlying
119152
# issue is that resolving an unknown host on some networks will take
120153
# 20-30 seconds; making this timeout short fixes the issue, but
@@ -129,7 +162,10 @@ def ping(request, timeout=_METADATA_DEFAULT_TIMEOUT, retry_count=3):
129162
for attempt in backoff:
130163
try:
131164
response = request(
132-
url=_METADATA_IP_ROOT, method="GET", headers=headers, timeout=timeout
165+
url=_get_metadata_ip_root(use_mtls),
166+
method="GET",
167+
headers=headers,
168+
timeout=timeout,
133169
)
134170

135171
metadata_flavor = response.headers.get(_METADATA_FLAVOR_HEADER)
@@ -153,7 +189,7 @@ def ping(request, timeout=_METADATA_DEFAULT_TIMEOUT, retry_count=3):
153189
def get(
154190
request,
155191
path,
156-
root=_METADATA_ROOT,
192+
root=None,
157193
params=None,
158194
recursive=False,
159195
retry_count=5,
@@ -190,6 +226,14 @@ def get(
190226
google.auth.exceptions.TransportError: if an error occurred while
191227
retrieving metadata.
192228
"""
229+
use_mtls = _mtls.should_use_mds_mtls()
230+
# Prepare the request object for mTLS if needed.
231+
# This will create a new request object with the mTLS session.
232+
request = _prepare_request_for_mds(request, use_mtls=use_mtls)
233+
234+
if root is None:
235+
root = _get_metadata_root(use_mtls)
236+
193237
base_url = urljoin(root, path)
194238
query_params = {} if params is None else params
195239

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
# -*- coding: utf-8 -*-
2+
#
3+
# Copyright 2024 Google LLC
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
#
17+
"""Mutual TLS for Google Compute Engine metadata server."""
18+
import enum
19+
import os
20+
import ssl
21+
from dataclasses import dataclass
22+
23+
from pathlib import Path
24+
import requests
25+
from requests.adapters import HTTPAdapter
26+
27+
from google.auth import environment_vars, exceptions
28+
29+
@dataclass
30+
class MdsMtlsConfig:
31+
ca_cert_path: str = os.path.join(Path.home(), "mtls_mds_certificates", "root.crt") # path to CA certificate
32+
client_combined_cert_path: str = os.path.join(Path.home(), "mtls_mds_certificates", "client_creds.key") # path to file containing client certificate and key
33+
34+
class MdsMtlsMode(enum.Enum):
35+
"""MDS mTLS mode."""
36+
STRICT = "strict"
37+
NONE = "none"
38+
DEFAULT = "default"
39+
40+
def _parse_mds_mode():
41+
"""Parses the GCE_METADATA_MTLS_MODE environment variable."""
42+
mode_str = os.environ.get(environment_vars.GCE_METADATA_MTLS_MODE, "default").lower()
43+
try:
44+
return MdsMtlsMode(mode_str)
45+
except ValueError:
46+
raise ValueError(
47+
"Invalid value for GCE_METADATA_MTLS_MODE. Must be one of 'strict', 'none', or 'default'."
48+
)
49+
50+
51+
def _certs_exist(mds_mtls_config: MdsMtlsConfig):
52+
"""Checks if the mTLS certificates exist."""
53+
return os.path.exists(mds_mtls_config.ca_cert_path) and os.path.exists(
54+
mds_mtls_config.client_combined_cert_path
55+
)
56+
57+
58+
class MdsMtlsAdapter(HTTPAdapter):
59+
"""An HTTP adapter that uses mTLS for the metadata server."""
60+
61+
def __init__(self, mds_mtls_config: MdsMtlsConfig, *args, **kwargs):
62+
self.ssl_context = ssl.create_default_context()
63+
self.ssl_context.load_verify_locations(cafile=mds_mtls_config.ca_cert_path)
64+
self.ssl_context.load_cert_chain(certfile=mds_mtls_config.client_combined_cert_path)
65+
super(MdsMtlsAdapter, self).__init__(*args, **kwargs)
66+
67+
def init_poolmanager(self, *args, **kwargs):
68+
kwargs["ssl_context"] = self.ssl_context
69+
return super(MdsMtlsAdapter, self).init_poolmanager(*args, **kwargs)
70+
71+
def proxy_manager_for(self, *args, **kwargs):
72+
kwargs["ssl_context"] = self.ssl_context
73+
return super(MdsMtlsAdapter, self).proxy_manager_for(*args, **kwargs)
74+
75+
def create_session(mds_mtls_config: MdsMtlsConfig = MdsMtlsConfig()):
76+
"""Creates a requests.Session configured for mTLS."""
77+
session = requests.Session()
78+
adapter = MdsMtlsAdapter(mds_mtls_config)
79+
session.mount("https://", adapter)
80+
return session
81+
82+
83+
def should_use_mds_mtls(mds_mtls_config: MdsMtlsConfig = MdsMtlsConfig()):
84+
"""Determines if mTLS should be used for the metadata server."""
85+
mode = _parse_mds_mode()
86+
if mode == MdsMtlsMode.STRICT:
87+
if not _certs_exist(mds_mtls_config):
88+
raise exceptions.MutualTLSChannelError(
89+
"mTLS certificates not found in strict mode."
90+
)
91+
return True
92+
elif mode == MdsMtlsMode.NONE:
93+
return False
94+
else: # Default mode
95+
return _certs_exist(mds_mtls_config)

google/auth/environment_vars.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,12 @@
6060
"""Environment variable providing an alternate ip:port to be used for ip-only
6161
GCE metadata requests."""
6262

63+
GCE_METADATA_MTLS_MODE = "GCE_METADATA_MTLS_MODE"
64+
"""Environment variable controlling the mTLS behavior for GCE metadata requests.
65+
66+
Can be one of "strict", "none", or "default".
67+
"""
68+
6369
GOOGLE_API_USE_CLIENT_CERTIFICATE = "GOOGLE_API_USE_CLIENT_CERTIFICATE"
6470
"""Environment variable controlling whether to use client certificate or not.
6571

0 commit comments

Comments
 (0)