@@ -752,7 +752,6 @@ class Connection(object):
752752 _socket = None
753753
754754 _socket_impl = socket
755- _ssl_impl = ssl
756755
757756 _check_hostname = False
758757 _product_type = None
@@ -780,7 +779,7 @@ def __init__(self, host='127.0.0.1', port=9042, authenticator=None,
780779 self .endpoint = host if isinstance (host , EndPoint ) else DefaultEndPoint (host , port )
781780
782781 self .authenticator = authenticator
783- self .ssl_options = ssl_options .copy () if ssl_options else None
782+ self .ssl_options = ssl_options .copy () if ssl_options else {}
784783 self .ssl_context = ssl_context
785784 self .sockopts = sockopts
786785 self .compression = compression
@@ -800,15 +799,20 @@ def __init__(self, host='127.0.0.1', port=9042, authenticator=None,
800799 self ._on_orphaned_stream_released = on_orphaned_stream_released
801800
802801 if ssl_options :
803- self ._check_hostname = bool (self .ssl_options .pop ('check_hostname' , False ))
804- if self ._check_hostname :
805- if not getattr (ssl , 'match_hostname' , None ):
806- raise RuntimeError ("ssl_options specify 'check_hostname', but ssl.match_hostname is not provided. "
807- "Patch or upgrade Python to use this option." )
808802 self .ssl_options .update (self .endpoint .ssl_options or {})
809803 elif self .endpoint .ssl_options :
810804 self .ssl_options = self .endpoint .ssl_options
811805
806+ # PYTHON-1331
807+ #
808+ # We always use SSLContext.wrap_socket() now but legacy configs may have other params that were passed to ssl.wrap_socket()...
809+ # and either could have 'check_hostname'. Remove these params into a separate map and use them to build an SSLContext if
810+ # we need to do so.
811+ #
812+ # Note the use of pop() here; we are very deliberately removing these params from ssl_options if they're present. After this
813+ # operation ssl_options should contain only args needed for the ssl_context.wrap_socket() call.
814+ if not self .ssl_context and self .ssl_options :
815+ self .ssl_context = self ._build_ssl_context_from_options ()
812816
813817 if protocol_version >= 3 :
814818 self .max_request_id = min (self .max_in_flight - 1 , (2 ** 15 ) - 1 )
@@ -882,15 +886,48 @@ def factory(cls, endpoint, timeout, host_conn = None, *args, **kwargs):
882886 else :
883887 return conn
884888
889+ def _build_ssl_context_from_options (self ):
890+
891+ # Extract a subset of names from self.ssl_options which apply to SSLContext creation
892+ ssl_context_opt_names = ['ssl_version' , 'cert_reqs' , 'check_hostname' , 'keyfile' , 'certfile' , 'ca_certs' , 'ciphers' ]
893+ opts = {k :self .ssl_options .get (k , None ) for k in ssl_context_opt_names if k in self .ssl_options }
894+
895+ # Python >= 3.10 requires either PROTOCOL_TLS_CLIENT or PROTOCOL_TLS_SERVER so we'll get ahead of things by always
896+ # being explicit
897+ ssl_version = opts .get ('ssl_version' , None ) or ssl .PROTOCOL_TLS_CLIENT
898+ cert_reqs = opts .get ('cert_reqs' , None ) or ssl .CERT_REQUIRED
899+ rv = ssl .SSLContext (protocol = int (ssl_version ))
900+ rv .check_hostname = bool (opts .get ('check_hostname' , False ))
901+ rv .options = int (cert_reqs )
902+
903+ certfile = opts .get ('certfile' , None )
904+ keyfile = opts .get ('keyfile' , None )
905+ if certfile :
906+ rv .load_cert_chain (certfile , keyfile )
907+ ca_certs = opts .get ('ca_certs' , None )
908+ if ca_certs :
909+ rv .load_verify_locations (ca_certs )
910+ ciphers = opts .get ('ciphers' , None )
911+ if ciphers :
912+ rv .set_ciphers (ciphers )
913+
914+ return rv
915+
885916 def _wrap_socket_from_context (self ):
886- ssl_options = self .ssl_options or {}
917+
918+ # Extract a subset of names from self.ssl_options which apply to SSLContext.wrap_socket (or at least the parts
919+ # of it that don't involve building an SSLContext under the covers)
920+ wrap_socket_opt_names = ['server_side' , 'do_handshake_on_connect' , 'suppress_ragged_eofs' , 'server_hostname' ]
921+ opts = {k :self .ssl_options .get (k , None ) for k in wrap_socket_opt_names if k in self .ssl_options }
922+
887923 # PYTHON-1186: set the server_hostname only if the SSLContext has
888924 # check_hostname enabled and it is not already provided by the EndPoint ssl options
889- if (self .ssl_context .check_hostname and
890- 'server_hostname' not in ssl_options ):
891- ssl_options = ssl_options .copy ()
892- ssl_options ['server_hostname' ] = self .endpoint .address
893- self ._socket = self .ssl_context .wrap_socket (self ._socket , ** ssl_options )
925+ #opts['server_hostname'] = self.endpoint.address
926+ if (self .ssl_context .check_hostname and 'server_hostname' not in opts ):
927+ server_hostname = self .endpoint .address
928+ opts ['server_hostname' ] = server_hostname
929+
930+ return self .ssl_context .wrap_socket (self ._socket , ** opts )
894931
895932 def _initiate_connection (self , sockaddr ):
896933 if self .features .shard_id is not None :
@@ -904,8 +941,11 @@ def _initiate_connection(self, sockaddr):
904941
905942 self ._socket .connect (sockaddr )
906943
907- def _match_hostname (self ):
908- ssl .match_hostname (self ._socket .getpeercert (), self .endpoint .address )
944+ # PYTHON-1331
945+ #
946+ # Allow implementations specific to an event loop to add additional behaviours
947+ def _validate_hostname (self ):
948+ pass
909949
910950 def _get_socket_addresses (self ):
911951 address , port = self .endpoint .resolve ()
@@ -927,18 +967,21 @@ def _connect_socket(self):
927967 try :
928968 self ._socket = self ._socket_impl .socket (af , socktype , proto )
929969 if self .ssl_context :
930- self ._wrap_socket_from_context ()
931- elif self .ssl_options :
932- if not self ._ssl_impl :
933- raise RuntimeError ("This version of Python was not compiled with SSL support" )
934- self ._socket = self ._ssl_impl .wrap_socket (self ._socket , ** self .ssl_options )
970+ self ._socket = self ._wrap_socket_from_context ()
935971 self ._socket .settimeout (self .connect_timeout )
936972 self ._initiate_connection (sockaddr )
937973 self ._socket .settimeout (None )
974+
938975 local_addr = self ._socket .getsockname ()
939976 log .debug ("Connection %s: '%s' -> '%s'" , id (self ), local_addr , sockaddr )
977+
978+ # PYTHON-1331
979+ #
980+ # Most checking is done via the check_hostname param on the SSLContext.
981+ # Subclasses can add additional behaviours via _validate_hostname() so
982+ # run that here.
940983 if self ._check_hostname :
941- self ._match_hostname ()
984+ self ._validate_hostname ()
942985 sockerr = None
943986 break
944987 except socket .error as err :
0 commit comments