Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 72 additions & 0 deletions certgen/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,11 @@ import (
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"fmt"
"math/big"
"os"
"path"
Expand Down Expand Up @@ -40,11 +42,46 @@ func writeKey(path string, keyx interface{}) {
if err != nil {
panic(err)
}
case *rsa.PrivateKey:
m := x509.MarshalPKCS1PrivateKey(key)
err := pem.Encode(file, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: m})
if err != nil {
panic(err)
}
default:
panic("Unknown key type")
}
}

func writeKeyEncrypted(password, path string, keyx interface{}) {
file := createFile(path)
defer file.Close()

switch key := keyx.(type) {
case *rsa.PrivateKey:
m := x509.MarshalPKCS1PrivateKey(key)

block := &pem.Block{
Type: "RSA PRIVATE KEY",
Bytes: m,
}

block, err := x509.EncryptPEMBlock(rand.Reader, block.Type, block.Bytes,
[]byte(password), x509.PEMCipherAES256)
if err != nil {
panic(err)
}

err = pem.Encode(file, block)
if err != nil {
panic(err)
}
default:
panic("Unknown key type")
}

}

func writeCert(path string, der []byte) {
file := createFile(path)
defer file.Close()
Expand Down Expand Up @@ -113,6 +150,30 @@ func generateServer(parent *x509.Certificate, parentPrivate interface{}, notBefo
return key, derBytes
}

func generateRsa4096Sha512(notBefore, notAfter time.Time, commonName string) (interface{}, []byte) {
key, err := rsa.GenerateKey(rand.Reader, 4096)
if err != nil {
panic(err)
}

template := x509.Certificate{
SerialNumber: newSerialNumber(),
NotBefore: notBefore,
NotAfter: notAfter,
Subject: pkix.Name{CommonName: commonName},
KeyUsage: x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth},
BasicConstraintsValid: true,
IsCA: true,
}

derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &key.PublicKey, key)
if err != nil {
panic(err)
}
return key, derBytes
}

func main() {
basePath := path.Join(os.Args[1], "certs")

Expand Down Expand Up @@ -177,4 +238,15 @@ func main() {
untrustedRoot_server1Key, untrustedRoot_server1Der := generateServer(untrustedRootCert, untrustedRootKey, anHourAgo, tenYearsFromNow, "untrustedRoot_thehost", "thehost")
writeKey(path.Join(basePath, "server", "untrustedRoot_thehost.key"), untrustedRoot_server1Key)
writeCert(path.Join(basePath, "server", "untrustedRoot_thehost.pem"), untrustedRoot_server1Der)

// Generate client's certificates
for i := 1; i <= 2; i++ {
clientKey, clientDer := generateRsa4096Sha512(anHourAgo, tenYearsFromNow, "client")
writeCert(path.Join(basePath, "driver", fmt.Sprintf("certificate%d.pem", i)), clientDer)
writeKey(path.Join(basePath, "driver", fmt.Sprintf("privatekey%d.pem", i)), clientKey)
writeKeyEncrypted(fmt.Sprintf("thepassword%d", i),
path.Join(basePath, "driver", fmt.Sprintf("privatekey%d_with_thepassword%d.pem", i, i)), clientKey)
// Copy to
writeCert(path.Join(basePath, "server", "bolt", "trusted", fmt.Sprintf("client%d.pem", i)), clientDer)
}
}
4 changes: 4 additions & 0 deletions nutkit/frontend/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@
BookmarkManager,
Neo4jBookmarkManagerConfig,
)
from .client_certificate_provider import (
ClientCertificateHolder,
ClientCertificateProvider,
)
from .driver import Driver
from .exceptions import ApplicationCodeError
from .fake_time import FakeTime
Expand Down
4 changes: 0 additions & 4 deletions nutkit/frontend/auth_token_manager.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import (
Any,
Callable,
Expand Down Expand Up @@ -42,7 +41,6 @@
]


@dataclass
class AuthTokenManager:
_registry: ClassVar[Dict[Any, AuthTokenManager]] = {}

Expand Down Expand Up @@ -105,7 +103,6 @@ def close(self, hooks=None):
del self._registry[self.id]


@dataclass
class BasicAuthTokenManager:
_registry: ClassVar[Dict[Any, BasicAuthTokenManager]] = {}

Expand Down Expand Up @@ -163,7 +160,6 @@ def close(self, hooks=None):
del self._registry[self.id]


@dataclass
class BearerAuthTokenManager:
_registry: ClassVar[Dict[Any, BearerAuthTokenManager]] = {}

Expand Down
1 change: 0 additions & 1 deletion nutkit/frontend/bookmark_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ class Neo4jBookmarkManagerConfig:
bookmarks_consumer: Optional[Callable[[List[str]], None]] = None


@dataclass
class BookmarkManager:
_registry: ClassVar[Dict[Any, BookmarkManager]] = {}

Expand Down
85 changes: 85 additions & 0 deletions nutkit/frontend/client_certificate_provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import (
Any,
Callable,
ClassVar,
Dict,
)

from ..backend import Backend
from ..protocol import ClientCertificate
from ..protocol import (
ClientCertificateProvider as ClientCertificateProviderMessage,
)
from ..protocol import (
ClientCertificateProviderClose,
ClientCertificateProviderCompleted,
ClientCertificateProviderRequest,
NewClientCertificateProvider,
)

__all__ = [
"ClientCertificateHolder",
"ClientCertificateProvider",
]


@dataclass
class ClientCertificateHolder:
cert: ClientCertificate
has_update: bool = True


class ClientCertificateProvider:
_registry: ClassVar[Dict[Any, ClientCertificateProvider]] = {}
_backend: Any
_handler: Callable[[], ClientCertificateHolder]

def __init__(
self,
backend: Backend,
handler: Callable[[], ClientCertificateHolder],
):
self._backend = backend
self._handler = handler

req = NewClientCertificateProvider()
res = backend.send_and_receive(req)
if not isinstance(res, ClientCertificateProviderMessage):
raise Exception(
f"Should be ClientCertificateProvider but was {res}"
)

self._client_certificate_provider = res
self._registry[self._client_certificate_provider.id] = self

@property
def id(self):
return self._client_certificate_provider.id

@classmethod
def process_callbacks(cls, request):
if isinstance(request, ClientCertificateProviderRequest):
if request.client_certificate_provider_id not in cls._registry:
raise Exception(
"Backend provided unknown Client Certificate Provider "
f"id: {request.client_certificate_provider_id} not found"
)
manager = cls._registry[request.client_certificate_provider_id]
cert_holder = manager._handler()
return ClientCertificateProviderCompleted(
request.id, cert_holder.has_update, cert_holder.cert
)

def close(self, hooks=None):
res = self._backend.send_and_receive(
ClientCertificateProviderClose(self.id),
hooks=hooks
)
if not isinstance(res, ClientCertificateProviderMessage):
raise Exception(
f"Should be ClientCertificateProvider but was {res}"
)
del self._registry[self.id]
27 changes: 21 additions & 6 deletions nutkit/frontend/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
BearerAuthTokenManager,
)
from .bookmark_manager import BookmarkManager
from .client_certificate_provider import ClientCertificateProvider
from .session import Session


Expand All @@ -18,7 +19,8 @@ def __init__(self, backend, uri, auth_token, user_agent=None,
connection_acquisition_timeout_ms=None,
notifications_min_severity=None,
notifications_disabled_categories=None,
telemetry_disabled=None):
telemetry_disabled=None,
client_certificate=None):
self._backend = backend
self._resolver_fn = resolver_fn
self._domain_name_resolver_fn = domain_name_resolver_fn
Expand All @@ -37,6 +39,16 @@ def __init__(self, backend, uri, auth_token, user_agent=None,
)
self._auth_token_manager = auth_token
auth_token_manager_id = auth_token.id
client_certificate_, client_certificate_provider_id_ = None, None
if client_certificate is not None:
assert isinstance(
client_certificate,
(protocol.ClientCertificate, ClientCertificateProvider)
)
if isinstance(client_certificate, protocol.ClientCertificate):
client_certificate_ = client_certificate
else:
client_certificate_provider_id_ = client_certificate.id

req = protocol.NewDriver(
uri, self._auth_token, auth_token_manager_id,
Expand All @@ -50,7 +62,9 @@ def __init__(self, backend, uri, auth_token, user_agent=None,
connection_acquisition_timeout_ms=connection_acquisition_timeout_ms, # noqa: E501
notifications_min_severity=notifications_min_severity,
notifications_disabled_categories=notifications_disabled_categories, # noqa: E501
telemetry_disabled=telemetry_disabled
telemetry_disabled=telemetry_disabled,
client_certificate=client_certificate_,
client_certificate_provider_id=client_certificate_provider_id_,
)
res = backend.send_and_receive(req)
if not isinstance(res, protocol.Driver):
Expand Down Expand Up @@ -78,10 +92,11 @@ def receive(self, timeout=None, hooks=None, *, allow_resolution):
)
continue
for cb_processor in (
AuthTokenManager,
BasicAuthTokenManager,
BearerAuthTokenManager,
BookmarkManager,
AuthTokenManager,
BasicAuthTokenManager,
BearerAuthTokenManager,
BookmarkManager,
ClientCertificateProvider,
):
cb_response = cb_processor.process_callbacks(res)
if cb_response is not None:
Expand Down
2 changes: 2 additions & 0 deletions nutkit/protocol/feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@ class Feature(Enum):
# The session supports notification filters configuration.
API_SESSION_NOTIFICATIONS_CONFIG = \
"Feature:API:Session:NotificationsConfig"
# The driver implements configuration for client certificates.
API_SSL_CLIENT_CERTIFICATE = "Feature:API:SSLClientCertificate"
# The driver implements explicit configuration options for SSL.
# - enable / disable SSL
# - verify signature against system store / custom cert / not at all
Expand Down
59 changes: 58 additions & 1 deletion nutkit/protocol/requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,8 @@ def __init__(
connection_acquisition_timeout_ms=None,
notifications_min_severity=None,
notifications_disabled_categories=None,
telemetry_disabled=None
telemetry_disabled=None,
client_certificate=None, client_certificate_provider_id=None,
):
# Neo4j URI to connect to
self.uri = uri
Expand All @@ -93,6 +94,10 @@ def __init__(
self.livenessCheckTimeoutMs = liveness_check_timeout_ms
self.maxConnectionPoolSize = max_connection_pool_size
self.connectionAcquisitionTimeoutMs = connection_acquisition_timeout_ms
assert (client_certificate is None
or client_certificate_provider_id is None)
self.clientCertificate = client_certificate
self.clientCertificateProviderId = client_certificate_provider_id
if notifications_min_severity is not None:
self.notificationsMinSeverity = notifications_min_severity
if notifications_disabled_categories is not None:
Expand Down Expand Up @@ -259,6 +264,58 @@ def __init__(self, request_id, auth):
self.auth = auth


class ClientCertificate:
"""
Not a request but used in `NewDriver`.

This property is used for configuring client certificates
for mutual TLS configuration.
"""

def __init__(self, certfile, keyfile, password=None):
self.certfile = certfile
self.keyfile = keyfile
self.password = password


class NewClientCertificateProvider:
"""
Create a new client certificate provider on the backend.

The backend should respond with `ClientCertificateProvider`.
"""

def __init__(self):
pass


class ClientCertificateProviderClose:
"""
Request to remove a client certificate provider from the backend.

The backend may free any resources associated with the provider and respond
with `ClientCertificateProvider` echoing back the given id.
"""

def __init__(self, id):
# Id of the client certificate provider to close.
self.id = id


class ClientCertificateProviderCompleted:
"""
Result of a completed client certificate provider call.

No response is expected.
"""

def __init__(self, request_id, has_update, client_certificate):
self.requestId = request_id
assert isinstance(client_certificate, ClientCertificate)
self.clientCertificate = client_certificate
self.hasUpdate = bool(has_update)


class VerifyConnectivity:
"""
Request to verify connectivity on the driver.
Expand Down
Loading