Skip to content

Commit bba73e9

Browse files
RUBY-3375 CSOT for CSFLE (#2868)
1 parent d0fe3fb commit bba73e9

34 files changed

+962
-601
lines changed

lib/mongo/auth/aws/credentials_retriever.rb

Lines changed: 77 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -69,20 +69,25 @@ def initialize(user = nil, credentials_cache: CredentialsCache.instance)
6969
# Retrieves a valid set of credentials, if possible, or raises
7070
# Auth::InvalidConfiguration.
7171
#
72+
# @param [ Operation::Context | nil ] context Context of the operation
73+
# credentials are retrieved for.
74+
#
7275
# @return [ Auth::Aws::Credentials ] A valid set of credentials.
7376
#
7477
# @raise Auth::InvalidConfiguration if a source contains an invalid set
7578
# of credentials.
7679
# @raise Auth::Aws::CredentialsNotFound if credentials could not be
7780
# retrieved from any source.
78-
def credentials
81+
# @raise Error::TimeoutError if credentials cannot be retrieved within
82+
# the timeout defined on the operation context.
83+
def credentials(context = nil)
7984
credentials = credentials_from_user(user)
8085
return credentials unless credentials.nil?
8186

8287
credentials = credentials_from_environment
8388
return credentials unless credentials.nil?
8489

85-
credentials = @credentials_cache.fetch { obtain_credentials_from_endpoints }
90+
credentials = @credentials_cache.fetch { obtain_credentials_from_endpoints(context) }
8691
return credentials unless credentials.nil?
8792

8893
raise Auth::Aws::CredentialsNotFound
@@ -127,47 +132,58 @@ def credentials_from_environment
127132

128133
# Returns credentials from the AWS metadata endpoints.
129134
#
135+
# @param [ Operation::Context | nil ] context Context of the operation
136+
# credentials are retrieved for.
137+
#
130138
# @return [ Auth::Aws::Credentials | nil ] A set of credentials, or nil
131139
# if retrieval failed or the obtained credentials are invalid.
132140
#
133141
# @raise Auth::InvalidConfiguration if a source contains an invalid set
134142
# of credentials.
135-
def obtain_credentials_from_endpoints
136-
if (credentials = web_identity_credentials) && credentials_valid?(credentials, 'Web identity token')
143+
# @ raise Error::TimeoutError if credentials cannot be retrieved within
144+
# the timeout defined on the operation context.
145+
def obtain_credentials_from_endpoints(context = nil)
146+
if (credentials = web_identity_credentials(context)) && credentials_valid?(credentials, 'Web identity token')
137147
credentials
138-
elsif (credentials = ecs_metadata_credentials) && credentials_valid?(credentials, 'ECS task metadata')
148+
elsif (credentials = ecs_metadata_credentials(context)) && credentials_valid?(credentials, 'ECS task metadata')
139149
credentials
140-
elsif (credentials = ec2_metadata_credentials) && credentials_valid?(credentials, 'EC2 instance metadata')
150+
elsif (credentials = ec2_metadata_credentials(context)) && credentials_valid?(credentials, 'EC2 instance metadata')
141151
credentials
142152
end
143153
end
144154

145155
# Returns credentials from the EC2 metadata endpoint. The credentials
146156
# could be empty, partial or invalid.
147157
#
158+
# @param [ Operation::Context | nil ] context Context of the operation
159+
# credentials are retrieved for.
160+
#
148161
# @return [ Auth::Aws::Credentials | nil ] A set of credentials, or nil
149162
# if retrieval failed.
150-
def ec2_metadata_credentials
163+
# @ raise Error::TimeoutError if credentials cannot be retrieved within
164+
# the timeout defined on the operation context.
165+
def ec2_metadata_credentials(context = nil)
166+
context&.check_timeout!
151167
http = Net::HTTP.new('169.254.169.254')
152168
req = Net::HTTP::Put.new('/latest/api/token',
153169
# The TTL is required in order to obtain the metadata token.
154170
{'x-aws-ec2-metadata-token-ttl-seconds' => '30'})
155-
resp = ::Timeout.timeout(METADATA_TIMEOUT) do
171+
resp = with_timeout(context) do
156172
http.request(req)
157173
end
158174
if resp.code != '200'
159175
return nil
160176
end
161177
metadata_token = resp.body
162-
resp = ::Timeout.timeout(METADATA_TIMEOUT) do
178+
resp = with_timeout(context) do
163179
http_get(http, '/latest/meta-data/iam/security-credentials', metadata_token)
164180
end
165181
if resp.code != '200'
166182
return nil
167183
end
168184
role_name = resp.body
169185
escaped_role_name = CGI.escape(role_name).gsub('+', '%20')
170-
resp = ::Timeout.timeout(METADATA_TIMEOUT) do
186+
resp = with_timeout(context) do
171187
http_get(http, "/latest/meta-data/iam/security-credentials/#{escaped_role_name}", metadata_token)
172188
end
173189
if resp.code != '200'
@@ -189,7 +205,18 @@ def ec2_metadata_credentials
189205
return nil
190206
end
191207

192-
def ecs_metadata_credentials
208+
# Returns credentials from the ECS metadata endpoint. The credentials
209+
# could be empty, partial or invalid.
210+
#
211+
# @param [ Operation::Context | nil ] context Context of the operation
212+
# credentials are retrieved for.
213+
#
214+
# @return [ Auth::Aws::Credentials | nil ] A set of credentials, or nil
215+
# if retrieval failed.
216+
# @ raise Error::TimeoutError if credentials cannot be retrieved within
217+
# the timeout defined on the operation context.
218+
def ecs_metadata_credentials(context = nil)
219+
context&.check_timeout!
193220
relative_uri = ENV['AWS_CONTAINER_CREDENTIALS_RELATIVE_URI']
194221
if relative_uri.nil? || relative_uri.empty?
195222
return nil
@@ -203,7 +230,7 @@ def ecs_metadata_credentials
203230
# a leading slash must be added by the driver, but this is not
204231
# in fact needed.
205232
req = Net::HTTP::Get.new(relative_uri)
206-
resp = ::Timeout.timeout(METADATA_TIMEOUT) do
233+
resp = with_timeout(context) do
207234
http.request(req)
208235
end
209236
if resp.code != '200'
@@ -225,13 +252,16 @@ def ecs_metadata_credentials
225252
# inside EKS. See https://docs.aws.amazon.com/eks/latest/userguide/iam-roles-for-service-accounts.html
226253
# for further details.
227254
#
255+
# @param [ Operation::Context | nil ] context Context of the operation
256+
# credentials are retrieved for.
257+
#
228258
# @return [ Auth::Aws::Credentials | nil ] A set of credentials, or nil
229259
# if retrieval failed.
230-
def web_identity_credentials
260+
def web_identity_credentials(context = nil)
231261
web_identity_token, role_arn, role_session_name = prepare_web_identity_inputs
232262
return nil if web_identity_token.nil?
233263
response = request_web_identity_credentials(
234-
web_identity_token, role_arn, role_session_name
264+
web_identity_token, role_arn, role_session_name, context
235265
)
236266
return if response.nil?
237267
credentials_from_web_identity_response(response)
@@ -266,10 +296,16 @@ def prepare_web_identity_inputs
266296
# that the caller is assuming.
267297
# @param [ String ] role_session_name An identifier for the assumed
268298
# role session.
299+
# @param [ Operation::Context | nil ] context Context of the operation
300+
# credentials are retrieved for.
269301
#
270302
# @return [ Net::HTTPResponse | nil ] AWS API response if successful,
271303
# otherwise nil.
272-
def request_web_identity_credentials(token, role_arn, role_session_name)
304+
#
305+
# @ raise Error::TimeoutError if credentials cannot be retrieved within
306+
# the timeout defined on the operation context.
307+
def request_web_identity_credentials(token, role_arn, role_session_name, context)
308+
context&.check_timeout!
273309
uri = URI('https://sts.amazonaws.com/')
274310
params = {
275311
'Action' => 'AssumeRoleWithWebIdentity',
@@ -281,8 +317,10 @@ def request_web_identity_credentials(token, role_arn, role_session_name)
281317
uri.query = ::URI.encode_www_form(params)
282318
req = Net::HTTP::Post.new(uri)
283319
req['Accept'] = 'application/json'
284-
resp = Net::HTTP.start(uri.hostname, uri.port, use_ssl: true) do |https|
285-
https.request(req)
320+
resp = with_timeout(context) do
321+
Net::HTTP.start(uri.hostname, uri.port, use_ssl: true) do |https|
322+
https.request(req)
323+
end
286324
end
287325
if resp.code != '200'
288326
return nil
@@ -351,6 +389,28 @@ def credentials_valid?(credentials, source)
351389

352390
true
353391
end
392+
393+
# Execute the given block considering the timeout defined on the context,
394+
# or the default timeout value.
395+
#
396+
# We use +Timeout.timeout+ here because there is no other acceptable easy
397+
# way to time limit http requests.
398+
#
399+
# @param [ Operation::Context | nil ] context Context of the operation
400+
#
401+
# @ raise Error::TimeoutError if deadline exceeded.
402+
def with_timeout(context)
403+
context&.check_timeout!
404+
timeout = context&.remaining_timeout_sec || METADATA_TIMEOUT
405+
exception_class = if context&.csot?
406+
Error::TimeoutError
407+
else
408+
nil
409+
end
410+
::Timeout.timeout(timeout, exception_class) do
411+
yield
412+
end
413+
end
354414
end
355415
end
356416
end

lib/mongo/client_encryption.rb

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@ class ClientEncryption
4040
# should be hashes of TLS connection options. The options are equivalent
4141
# to TLS connection options of Mongo::Client.
4242
# @see Mongo::Client#initialize for list of TLS options.
43+
# @option options [ Integer ] :timeout_ms Timeout that will be applied to all
44+
# operations on this instance.
4345
#
4446
# @raise [ ArgumentError ] If required options are missing or incorrectly
4547
# formatted.

lib/mongo/crypt/auto_encrypter.rb

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -119,8 +119,6 @@ def initialize(options)
119119
@options[:extra_options][:crypt_shared_lib_required]
120120

121121
unless @options[:extra_options][:crypt_shared_lib_required] || @crypt_handle.crypt_shared_lib_available? || @options[:bypass_query_analysis]
122-
# Set server selection timeout to 1 to prevent the client waiting for a
123-
# long timeout before spawning mongocryptd
124122
@mongocryptd_client = Client.new(
125123
@options[:extra_options][:mongocryptd_uri],
126124
monitoring_io: @options[:client].options[:monitoring_io],
@@ -189,26 +187,26 @@ def encrypt?
189187
# @param [ Hash ] command The command to be encrypted.
190188
#
191189
# @return [ BSON::Document ] The encrypted command.
192-
def encrypt(database_name, command)
190+
def encrypt(database_name, command, context)
193191
AutoEncryptionContext.new(
194192
@crypt_handle,
195193
@encryption_io,
196194
database_name,
197195
command
198-
).run_state_machine
196+
).run_state_machine(context)
199197
end
200198

201199
# Decrypt a database command.
202200
#
203201
# @param [ Hash ] command The command with encrypted fields.
204202
#
205203
# @return [ BSON::Document ] The decrypted command.
206-
def decrypt(command)
204+
def decrypt(command, context)
207205
AutoDecryptionContext.new(
208206
@crypt_handle,
209207
@encryption_io,
210208
command
211-
).run_state_machine
209+
).run_state_machine(context)
212210
end
213211

214212
# Close the resources created by the AutoEncrypter.

lib/mongo/crypt/context.rb

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,10 @@ def state
6464
end
6565

6666
# Runs the mongocrypt_ctx_t state machine and handles
67-
# all I/O on behalf of libmongocrypt
67+
# all I/O on behalf of
68+
#
69+
# @param [ Operation::Context ] context Context of the operation the state
70+
# machine is run for.
6871
#
6972
# @return [ BSON::Document ] A BSON document representing the outcome
7073
# of the state machine. Contents can differ depending on how the
@@ -75,8 +78,10 @@ def state
7578
#
7679
# This method is not currently unit tested. It is integration tested
7780
# in spec/integration/explicit_encryption_spec.rb
78-
def run_state_machine
81+
def run_state_machine(context)
7982
while true
83+
context.check_timeout!
84+
timeout_ms = context.remaining_timeout_ms
8085
case state
8186
when :error
8287
Binding.check_ctx_status(self)
@@ -88,22 +93,22 @@ def run_state_machine
8893
when :need_mongo_keys
8994
filter = Binding.ctx_mongo_op(self)
9095

91-
@encryption_io.find_keys(filter).each do |key|
96+
@encryption_io.find_keys(filter, timeout_ms: timeout_ms).each do |key|
9297
mongocrypt_feed(key) if key
9398
end
9499

95100
mongocrypt_done
96101
when :need_mongo_collinfo
97102
filter = Binding.ctx_mongo_op(self)
98103

99-
result = @encryption_io.collection_info(@db_name, filter)
104+
result = @encryption_io.collection_info(@db_name, filter, timeout_ms: timeout_ms)
100105
mongocrypt_feed(result) if result
101106

102107
mongocrypt_done
103108
when :need_mongo_markings
104109
cmd = Binding.ctx_mongo_op(self)
105110

106-
result = @encryption_io.mark_command(cmd)
111+
result = @encryption_io.mark_command(cmd, timeout_ms: timeout_ms)
107112
mongocrypt_feed(result)
108113

109114
mongocrypt_done
@@ -118,7 +123,7 @@ def run_state_machine
118123
when :need_kms_credentials
119124
Binding.ctx_provide_kms_providers(
120125
self,
121-
retrieve_kms_credentials.to_document
126+
retrieve_kms_credentials(context).to_document
122127
)
123128
else
124129
raise Error::CryptError.new(
@@ -147,13 +152,16 @@ def mongocrypt_feed(doc)
147152
# Retrieves KMS credentials for providers that are configured
148153
# for automatic credentials retrieval.
149154
#
155+
# @param [ Operation::Context ] context Context of the operation credentials
156+
# are retrieved for.
157+
#
150158
# @return [ Crypt::KMS::Credentials ] Credentials for the configured
151159
# KMS providers.
152-
def retrieve_kms_credentials
160+
def retrieve_kms_credentials(context)
153161
providers = {}
154162
if kms_providers.aws&.empty?
155163
begin
156-
aws_credentials = Mongo::Auth::Aws::CredentialsRetriever.new.credentials
164+
aws_credentials = Mongo::Auth::Aws::CredentialsRetriever.new.credentials(context)
157165
rescue Auth::Aws::CredentialsNotFound
158166
raise Error::CryptError.new(
159167
"Could not locate AWS credentials (checked environment variables, ECS and EC2 metadata)"

0 commit comments

Comments
 (0)