Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions lib/mongo.rb
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
require 'mongo/semaphore'
require 'mongo/distinguishing_semaphore'
require 'mongo/condition_variable'
require 'mongo/csot_timeout_holder'
require 'mongo/options'
require 'mongo/loggable'
require 'mongo/cluster_time'
Expand Down
67 changes: 30 additions & 37 deletions lib/mongo/auth/aws/credentials_retriever.rb
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,7 @@ def initialize(user = nil, credentials_cache: CredentialsCache.instance)
# Retrieves a valid set of credentials, if possible, or raises
# Auth::InvalidConfiguration.
#
# @param [ Operation::Context | nil ] context Context of the operation
# credentials are retrieved for.
# @param [ CsotTimeoutHolder | nil ] timeout_holder CSOT timeout, if any.
#
# @return [ Auth::Aws::Credentials ] A valid set of credentials.
#
Expand All @@ -80,14 +79,14 @@ def initialize(user = nil, credentials_cache: CredentialsCache.instance)
# retrieved from any source.
# @raise Error::TimeoutError if credentials cannot be retrieved within
# the timeout defined on the operation context.
def credentials(context = nil)
def credentials(timeout_holder = nil)
credentials = credentials_from_user(user)
return credentials unless credentials.nil?

credentials = credentials_from_environment
return credentials unless credentials.nil?

credentials = @credentials_cache.fetch { obtain_credentials_from_endpoints(context) }
credentials = @credentials_cache.fetch { obtain_credentials_from_endpoints(timeout_holder) }
return credentials unless credentials.nil?

raise Auth::Aws::CredentialsNotFound
Expand Down Expand Up @@ -132,8 +131,7 @@ def credentials_from_environment

# Returns credentials from the AWS metadata endpoints.
#
# @param [ Operation::Context | nil ] context Context of the operation
# credentials are retrieved for.
# @param [ CsotTimeoutHolder ] timeout_holder CSOT timeout.
#
# @return [ Auth::Aws::Credentials | nil ] A set of credentials, or nil
# if retrieval failed or the obtained credentials are invalid.
Expand All @@ -142,48 +140,47 @@ def credentials_from_environment
# of credentials.
# @ raise Error::TimeoutError if credentials cannot be retrieved within
# the timeout defined on the operation context.
def obtain_credentials_from_endpoints(context = nil)
if (credentials = web_identity_credentials(context)) && credentials_valid?(credentials, 'Web identity token')
def obtain_credentials_from_endpoints(timeout_holder = nil)
if (credentials = web_identity_credentials(timeout_holder)) && credentials_valid?(credentials, 'Web identity token')
credentials
elsif (credentials = ecs_metadata_credentials(context)) && credentials_valid?(credentials, 'ECS task metadata')
elsif (credentials = ecs_metadata_credentials(timeout_holder)) && credentials_valid?(credentials, 'ECS task metadata')
credentials
elsif (credentials = ec2_metadata_credentials(context)) && credentials_valid?(credentials, 'EC2 instance metadata')
elsif (credentials = ec2_metadata_credentials(timeout_holder)) && credentials_valid?(credentials, 'EC2 instance metadata')
credentials
end
end

# Returns credentials from the EC2 metadata endpoint. The credentials
# could be empty, partial or invalid.
#
# @param [ Operation::Context | nil ] context Context of the operation
# credentials are retrieved for.
# @param [ CsotTimeoutHolder ] timeout_holder CSOT timeout.
#
# @return [ Auth::Aws::Credentials | nil ] A set of credentials, or nil
# if retrieval failed.
# @ raise Error::TimeoutError if credentials cannot be retrieved within
# the timeout defined on the operation context.
def ec2_metadata_credentials(context = nil)
context&.check_timeout!
# the timeout.
def ec2_metadata_credentials(timeout_holder = nil)
timeout_holder&.check_timeout!
http = Net::HTTP.new('169.254.169.254')
req = Net::HTTP::Put.new('/latest/api/token',
# The TTL is required in order to obtain the metadata token.
{'x-aws-ec2-metadata-token-ttl-seconds' => '30'})
resp = with_timeout(context) do
resp = with_timeout(timeout_holder) do
http.request(req)
end
if resp.code != '200'
return nil
end
metadata_token = resp.body
resp = with_timeout(context) do
resp = with_timeout(timeout_holder) do
http_get(http, '/latest/meta-data/iam/security-credentials', metadata_token)
end
if resp.code != '200'
return nil
end
role_name = resp.body
escaped_role_name = CGI.escape(role_name).gsub('+', '%20')
resp = with_timeout(context) do
resp = with_timeout(timeout_holder) do
http_get(http, "/latest/meta-data/iam/security-credentials/#{escaped_role_name}", metadata_token)
end
if resp.code != '200'
Expand All @@ -208,15 +205,14 @@ def ec2_metadata_credentials(context = nil)
# Returns credentials from the ECS metadata endpoint. The credentials
# could be empty, partial or invalid.
#
# @param [ Operation::Context | nil ] context Context of the operation
# credentials are retrieved for.
# @param [ CsotTimeoutHolder | nil ] timeout_holder CSOT timeout.
#
# @return [ Auth::Aws::Credentials | nil ] A set of credentials, or nil
# if retrieval failed.
# @ raise Error::TimeoutError if credentials cannot be retrieved within
# the timeout defined on the operation context.
def ecs_metadata_credentials(context = nil)
context&.check_timeout!
def ecs_metadata_credentials(timeout_holder = nil)
timeout_holder&.check_timeout!
relative_uri = ENV['AWS_CONTAINER_CREDENTIALS_RELATIVE_URI']
if relative_uri.nil? || relative_uri.empty?
return nil
Expand All @@ -230,7 +226,7 @@ def ecs_metadata_credentials(context = nil)
# a leading slash must be added by the driver, but this is not
# in fact needed.
req = Net::HTTP::Get.new(relative_uri)
resp = with_timeout(context) do
resp = with_timeout(timeout_holder) do
http.request(req)
end
if resp.code != '200'
Expand All @@ -252,16 +248,15 @@ def ecs_metadata_credentials(context = nil)
# inside EKS. See https://docs.aws.amazon.com/eks/latest/userguide/iam-roles-for-service-accounts.html
# for further details.
#
# @param [ Operation::Context | nil ] context Context of the operation
# credentials are retrieved for.
# @param [ CsotTimeoutHolder | nil ] timeout_holder CSOT timeout.
#
# @return [ Auth::Aws::Credentials | nil ] A set of credentials, or nil
# if retrieval failed.
def web_identity_credentials(context = nil)
def web_identity_credentials(timeout_holder = nil)
web_identity_token, role_arn, role_session_name = prepare_web_identity_inputs
return nil if web_identity_token.nil?
response = request_web_identity_credentials(
web_identity_token, role_arn, role_session_name, context
web_identity_token, role_arn, role_session_name, timeout_holder
)
return if response.nil?
credentials_from_web_identity_response(response)
Expand Down Expand Up @@ -296,16 +291,15 @@ def prepare_web_identity_inputs
# that the caller is assuming.
# @param [ String ] role_session_name An identifier for the assumed
# role session.
# @param [ Operation::Context | nil ] context Context of the operation
# credentials are retrieved for.
# @param [ CsotTimeoutHolder | nil ] timeout_holder CSOT timeout.
#
# @return [ Net::HTTPResponse | nil ] AWS API response if successful,
# otherwise nil.
#
# @ raise Error::TimeoutError if credentials cannot be retrieved within
# the timeout defined on the operation context.
def request_web_identity_credentials(token, role_arn, role_session_name, context)
context&.check_timeout!
def request_web_identity_credentials(token, role_arn, role_session_name, timeout_holder)
timeout_holder&.check_timeout!
uri = URI('https://sts.amazonaws.com/')
params = {
'Action' => 'AssumeRoleWithWebIdentity',
Expand All @@ -317,7 +311,7 @@ def request_web_identity_credentials(token, role_arn, role_session_name, context
uri.query = ::URI.encode_www_form(params)
req = Net::HTTP::Post.new(uri)
req['Accept'] = 'application/json'
resp = with_timeout(context) do
resp = with_timeout(timeout_holder) do
Net::HTTP.start(uri.hostname, uri.port, use_ssl: true) do |https|
https.request(req)
end
Expand Down Expand Up @@ -396,13 +390,12 @@ def credentials_valid?(credentials, source)
# We use +Timeout.timeout+ here because there is no other acceptable easy
# way to time limit http requests.
#
# @param [ Operation::Context | nil ] context Context of the operation
# @param [ CsotTimeoutHolder | nil ] timeout_holder CSOT timeout.
#
# @ raise Error::TimeoutError if deadline exceeded.
def with_timeout(context)
context&.check_timeout!
timeout = context&.remaining_timeout_sec || METADATA_TIMEOUT
exception_class = if context&.csot?
def with_timeout(timeout_holder)
timeout = timeout_holder&.remaining_timeout_sec! || METADATA_TIMEOUT
exception_class = if timeout_holder&.csot?
Error::TimeoutError
else
nil
Expand Down
10 changes: 8 additions & 2 deletions lib/mongo/collection.rb
Original file line number Diff line number Diff line change
Expand Up @@ -441,20 +441,26 @@ def create(opts = {})
# @option opts [ Hash ] :write_concern The write concern options.
# @option opts [ Hash | nil ] :encrypted_fields Encrypted fields hash that
# was provided to `create` collection helper.
# @option opts [ Integer ] :timeout_ms The per-operation timeout in milliseconds.
# Must a positive integer. The default value is unset which means infinite.
#
# @return [ Result ] The result of the command.
#
# @since 2.0.0
def drop(opts = {})
client.send(:with_session, opts) do |session|
client.with_session(opts) do |session|
maybe_drop_emm_collections(opts[:encrypted_fields], client, session) do
temp_write_concern = write_concern
write_concern = if opts[:write_concern]
WriteConcern.get(opts[:write_concern])
else
temp_write_concern
end
context = Operation::Context.new(client: client, session: session)
context = Operation::Context.new(
client: client,
session: session,
operation_timeouts: operation_timeouts(opts)
)
operation = Operation::Drop.new({
selector: { :drop => name },
db_name: database.name,
Expand Down
8 changes: 4 additions & 4 deletions lib/mongo/crypt/auto_encrypter.rb
Original file line number Diff line number Diff line change
Expand Up @@ -187,26 +187,26 @@ def encrypt?
# @param [ Hash ] command The command to be encrypted.
#
# @return [ BSON::Document ] The encrypted command.
def encrypt(database_name, command, context)
def encrypt(database_name, command, timeout_holder)
AutoEncryptionContext.new(
@crypt_handle,
@encryption_io,
database_name,
command
).run_state_machine(context)
).run_state_machine(timeout_holder)
end

# Decrypt a database command.
#
# @param [ Hash ] command The command with encrypted fields.
#
# @return [ BSON::Document ] The decrypted command.
def decrypt(command, context)
def decrypt(command, timeout_holder)
AutoDecryptionContext.new(
@crypt_handle,
@encryption_io,
command
).run_state_machine(context)
).run_state_machine(timeout_holder)
end

# Close the resources created by the AutoEncrypter.
Expand Down
30 changes: 14 additions & 16 deletions lib/mongo/crypt/context.rb
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,8 @@ def state
# Runs the mongocrypt_ctx_t state machine and handles
# all I/O on behalf of
#
# @param [ Operation::Context ] context Context of the operation the state
# machine is run for.
# @param [ CsotTimeoutHolder ] timeout_holder CSOT timeouts for the
# operation the state.
#
# @return [ BSON::Document ] A BSON document representing the outcome
# of the state machine. Contents can differ depending on how the
Expand All @@ -78,10 +78,9 @@ def state
#
# This method is not currently unit tested. It is integration tested
# in spec/integration/explicit_encryption_spec.rb
def run_state_machine(context)
def run_state_machine(timeout_holder)
while true
context.check_timeout!
timeout_ms = context.remaining_timeout_ms
timeout_ms = timeout_holder.remaining_timeout_ms!
case state
when :error
Binding.check_ctx_status(self)
Expand Down Expand Up @@ -123,7 +122,7 @@ def run_state_machine(context)
when :need_kms_credentials
Binding.ctx_provide_kms_providers(
self,
retrieve_kms_credentials(context).to_document
retrieve_kms_credentials(timeout_holder).to_document
)
else
raise Error::CryptError.new(
Expand Down Expand Up @@ -152,16 +151,15 @@ def mongocrypt_feed(doc)
# Retrieves KMS credentials for providers that are configured
# for automatic credentials retrieval.
#
# @param [ Operation::Context ] context Context of the operation credentials
# are retrieved for.
# @param [ CsotTimeoutHolder ] timeout_holder CSOT timeout.
#
# @return [ Crypt::KMS::Credentials ] Credentials for the configured
# KMS providers.
def retrieve_kms_credentials(context)
def retrieve_kms_credentials(timeout_holder)
providers = {}
if kms_providers.aws&.empty?
begin
aws_credentials = Mongo::Auth::Aws::CredentialsRetriever.new.credentials(context)
aws_credentials = Mongo::Auth::Aws::CredentialsRetriever.new.credentials(timeout_holder)
rescue Auth::Aws::CredentialsNotFound
raise Error::CryptError.new(
"Could not locate AWS credentials (checked environment variables, ECS and EC2 metadata)"
Expand All @@ -170,10 +168,10 @@ def retrieve_kms_credentials(context)
providers[:aws] = aws_credentials.to_h
end
if kms_providers.gcp&.empty?
providers[:gcp] = { access_token: gcp_access_token }
providers[:gcp] = { access_token: gcp_access_token(timeout_holder) }
end
if kms_providers.azure&.empty?
providers[:azure] = { access_token: azure_access_token }
providers[:azure] = { access_token: azure_access_token(timeout_holder) }
end
KMS::Credentials.new(providers)
end
Expand All @@ -183,8 +181,8 @@ def retrieve_kms_credentials(context)
# @return [ String ] A GCP access token.
#
# @raise [ Error::CryptError ] If the GCP access token could not be
def gcp_access_token
KMS::GCP::CredentialsRetriever.fetch_access_token
def gcp_access_token(timeout_holder)
KMS::GCP::CredentialsRetriever.fetch_access_token(timeout_holder)
rescue KMS::CredentialsNotFound => e
raise Error::CryptError.new(
"Could not locate GCP credentials: #{e.class}: #{e.message}"
Expand All @@ -197,9 +195,9 @@ def gcp_access_token
#
# @raise [ Error::CryptError ] If the Azure access token could not be
# retrieved.
def azure_access_token
def azure_access_token(timeout_holder)
if @cached_azure_token.nil? || @cached_azure_token.expired?
@cached_azure_token = KMS::Azure::CredentialsRetriever.fetch_access_token
@cached_azure_token = KMS::Azure::CredentialsRetriever.fetch_access_token(timeout_holder: timeout_holder)
end
@cached_azure_token.access_token
rescue KMS::CredentialsNotFound => e
Expand Down
Loading