Skip to content
This repository was archived by the owner on Apr 30, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions mbus/subscriber.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import (
type RegistryMessage struct {
Host string `json:"host"`
Port uint16 `json:"port"`
Protocol string `json:"protocol"`
TLSPort uint16 `json:"tls_port"`
Uris []route.Uri `json:"uris"`
Tags map[string]string `json:"tags"`
Expand All @@ -51,10 +52,16 @@ func (rm *RegistryMessage) makeEndpoint() (*route.Endpoint, error) {
updatedAt = time.Unix(0, rm.EndpointUpdatedAtNs).UTC()
}

protocol := rm.Protocol
if protocol == "" {
protocol = "http1"
}

return route.NewEndpoint(&route.EndpointOpts{
AppId: rm.App,
Host: rm.Host,
Port: port,
Protocol: protocol,
ServerCertDomainSAN: rm.ServerCertDomainSAN,
PrivateInstanceId: rm.PrivateInstanceID,
PrivateInstanceIndex: rm.PrivateInstanceIndex,
Expand Down
8 changes: 8 additions & 0 deletions mbus/subscriber_easyjson.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

66 changes: 66 additions & 0 deletions mbus/subscriber_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,69 @@ var _ = Describe("Subscriber", func() {
})
})

Context("when the message does not contain a protocol", func() {
BeforeEach(func() {
sub = mbus.NewSubscriber(natsClient, registry, cfg, reconnected, l)
process = ifrit.Invoke(sub)
Eventually(process.Ready()).Should(BeClosed())
})
It("endpoint is constructed with protocol http1", func() {
msg := mbus.RegistryMessage{
Host: "host",
App: "app",
Uris: []route.Uri{"test.example.com"},
}

data, err := json.Marshal(msg)
Expect(err).NotTo(HaveOccurred())

err = natsClient.Publish("router.register", data)
Expect(err).ToNot(HaveOccurred())

Eventually(registry.RegisterCallCount).Should(Equal(1))
_, originalEndpoint := registry.RegisterArgsForCall(0)
expectedEndpoint := route.NewEndpoint(&route.EndpointOpts{
Host: "host",
AppId: "app",
Protocol: "http1",
})

Expect(originalEndpoint).To(Equal(expectedEndpoint))
})
})

Context("when the message contains a protocol", func() {
BeforeEach(func() {
sub = mbus.NewSubscriber(natsClient, registry, cfg, reconnected, l)
process = ifrit.Invoke(sub)
Eventually(process.Ready()).Should(BeClosed())
})
It("endpoint is constructed with the protocol", func() {
msg := mbus.RegistryMessage{
Host: "host",
App: "app",
Protocol: "http2",
Uris: []route.Uri{"test.example.com"},
}

data, err := json.Marshal(msg)
Expect(err).NotTo(HaveOccurred())

err = natsClient.Publish("router.register", data)
Expect(err).ToNot(HaveOccurred())

Eventually(registry.RegisterCallCount).Should(Equal(1))
_, originalEndpoint := registry.RegisterArgsForCall(0)
expectedEndpoint := route.NewEndpoint(&route.EndpointOpts{
Host: "host",
AppId: "app",
Protocol: "http2",
})

Expect(originalEndpoint).To(Equal(expectedEndpoint))
})
})

Context("when the message contains a tls port for route", func() {
BeforeEach(func() {
sub = mbus.NewSubscriber(natsClient, registry, cfg, reconnected, l)
Expand Down Expand Up @@ -380,6 +443,7 @@ var _ = Describe("Subscriber", func() {
Host: "host",
AppId: "app",
Port: 1999,
Protocol: "http1",
UseTLS: true,
ServerCertDomainSAN: "san",
PrivateInstanceId: "id",
Expand Down Expand Up @@ -413,6 +477,7 @@ var _ = Describe("Subscriber", func() {
expectedEndpoint := route.NewEndpoint(&route.EndpointOpts{
Host: "host",
Port: 1111,
Protocol: "http1",
UpdatedAt: time.Unix(0, 1234).UTC(),
})

Expand Down Expand Up @@ -451,6 +516,7 @@ var _ = Describe("Subscriber", func() {
Host: "host",
AppId: "app",
Port: 1111,
Protocol: "http1",
UseTLS: false,
ServerCertDomainSAN: "san",
PrivateInstanceId: "id",
Expand Down
23 changes: 8 additions & 15 deletions proxy/backend_tls_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,9 @@ var _ = Describe("Backend TLS", func() {
Expect(err).NotTo(HaveOccurred())
return caCertPool
}
// createCertAndAddCA creates a signed cert with a root CA and adds the CA
// to the specified cert pool
createCertAndAddCA := func(cn test_util.CertNames, cp *x509.CertPool) test_util.CertChain {
certChain := test_util.CreateSignedCertWithRootCA(cn)
cp.AddCert(certChain.CACert)
return certChain
}

registerAppAndTest := func() *http.Response {
ln := test_util.RegisterHandler(r, "test", func(conn *test_util.HttpConn) {
ln := test_util.RegisterConnHandler(r, "test", func(conn *test_util.HttpConn) {
req, err := http.ReadRequest(conn.Reader)
if err != nil {
conn.WriteResponse(test_util.NewResponse(http.StatusInternalServerError))
Expand Down Expand Up @@ -62,8 +55,8 @@ var _ = Describe("Backend TLS", func() {
// Clear backend app's CA cert pool
backendCACertPool := x509.NewCertPool()

backendCertChain := createCertAndAddCA(test_util.CertNames{CommonName: serverCertDomainSAN}, proxyCertPool)
clientCertChain := createCertAndAddCA(test_util.CertNames{CommonName: "gorouter"}, backendCACertPool)
backendCertChain := test_util.CreateCertAndAddCA(test_util.CertNames{CommonName: serverCertDomainSAN}, proxyCertPool)
clientCertChain := test_util.CreateCertAndAddCA(test_util.CertNames{CommonName: "gorouter"}, backendCACertPool)

backendTLSConfig := backendCertChain.AsTLSConfig()
backendTLSConfig.ClientCAs = backendCACertPool
Expand Down Expand Up @@ -162,7 +155,7 @@ var _ = Describe("Backend TLS", func() {
Context("when the backend instance returns a cert that only has a DNS SAN", func() {
BeforeEach(func() {
proxyCertPool := freshProxyCACertPool()
backendCertChain := createCertAndAddCA(test_util.CertNames{
backendCertChain := test_util.CreateCertAndAddCA(test_util.CertNames{
SANs: test_util.SubjectAltNames{DNS: registerConfig.ServerCertDomainSAN},
}, proxyCertPool)
registerConfig.TLSConfig = backendCertChain.AsTLSConfig()
Expand All @@ -178,7 +171,7 @@ var _ = Describe("Backend TLS", func() {
Context("when the backend instance returns a cert that has a matching CommonName but non-matching DNS SAN", func() {
BeforeEach(func() {
proxyCertPool := freshProxyCACertPool()
backendCertChain := createCertAndAddCA(test_util.CertNames{
backendCertChain := test_util.CreateCertAndAddCA(test_util.CertNames{
CommonName: registerConfig.InstanceId,
SANs: test_util.SubjectAltNames{DNS: "foo"},
}, proxyCertPool)
Expand All @@ -194,7 +187,7 @@ var _ = Describe("Backend TLS", func() {
Context("when the backend instance returns a cert that has a non-matching CommonName but matching DNS SAN", func() {
BeforeEach(func() {
proxyCertPool := freshProxyCACertPool()
backendCertChain := createCertAndAddCA(test_util.CertNames{
backendCertChain := test_util.CreateCertAndAddCA(test_util.CertNames{
CommonName: "foo",
SANs: test_util.SubjectAltNames{DNS: registerConfig.ServerCertDomainSAN},
}, proxyCertPool)
Expand All @@ -210,7 +203,7 @@ var _ = Describe("Backend TLS", func() {
Context("when the backend instance returns a cert that has a matching CommonName but non-matching IP SAN", func() {
BeforeEach(func() {
proxyCertPool := freshProxyCACertPool()
backendCertChain := createCertAndAddCA(test_util.CertNames{
backendCertChain := test_util.CreateCertAndAddCA(test_util.CertNames{
CommonName: registerConfig.InstanceId,
SANs: test_util.SubjectAltNames{IP: "192.0.2.1"},
}, proxyCertPool)
Expand All @@ -226,7 +219,7 @@ var _ = Describe("Backend TLS", func() {
Context("when the backend instance returns a cert that has a non-matching CommonName but matching IP SAN", func() {
BeforeEach(func() {
proxyCertPool := freshProxyCACertPool()
backendCertChain := createCertAndAddCA(test_util.CertNames{
backendCertChain := test_util.CreateCertAndAddCA(test_util.CertNames{
CommonName: "foo",
SANs: test_util.SubjectAltNames{IP: "127.0.0.1"},
}, proxyCertPool)
Expand Down
13 changes: 10 additions & 3 deletions proxy/proxy_suite_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,8 +132,15 @@ var _ = JustBeforeEach(func() {

p = proxy.NewProxy(testLogger, al, ew, conf, r, fakeReporter, routeServiceConfig, tlsConfig, tlsConfig, healthStatus, fakeRouteServicesClient)

server := http.Server{Handler: p}
go server.Serve(proxyServer)
if conf.EnableHTTP2 {
server := http.Server{Handler: p}
tlsConfig.NextProtos = []string{"h2", "http/1.1"}
tlsListener := tls.NewListener(proxyServer, tlsConfig)
go server.Serve(tlsListener)
} else {
server := http.Server{Handler: p}
go server.Serve(proxyServer)
}
})

var _ = AfterEach(func() {
Expand All @@ -144,7 +151,7 @@ var _ = AfterEach(func() {
})

func shouldEcho(input string, expected string) {
ln := test_util.RegisterHandler(r, "encoding", func(x *test_util.HttpConn) {
ln := test_util.RegisterConnHandler(r, "encoding", func(x *test_util.HttpConn) {
x.CheckLine("GET " + expected + " HTTP/1.1")
resp := test_util.NewResponse(http.StatusOK)
x.WriteResponse(resp)
Expand Down
Loading