Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 19 additions & 16 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,24 +23,27 @@ is a generic gRPC reverse proxy handler.
The package [`proxy`](proxy/) contains a generic gRPC reverse proxy handler that allows a gRPC server to
not know about registered handlers or their data types. Please consult the docs, here's an exaple usage.

Defining a `StreamDirector` that decides where (if at all) to send the request
Defining a `StreamDirector` that decides where (if at all) to send the request (see
example_test.go):
```go
director = func(ctx context.Context, fullMethodName string) (*grpc.ClientConn, error) {
// Make sure we never forward internal services.
if strings.HasPrefix(fullMethodName, "/com.example.internal.") {
return nil, grpc.Errorf(codes.Unimplemented, "Unknown method")
func (d *ExampleDirector) Connect(ctx context.Context, method string) (context.Context, *grpc.ClientConn, error) {
// Make sure we never forward internal services.
if strings.HasPrefix(method, "/com.example.internal.") {
return nil, nil, grpc.Errorf(codes.Unimplemented, "Unknown method")
}
md, ok := metadata.FromIncomingContext(ctx)
if ok {
// Decide on which backend to dial
if val, exists := md[":authority"]; exists && val[0] == "staging.api.example.com" {
// Make sure we use DialContext so the dialing can be cancelled/time out together with the context.
conn, err := grpc.DialContext(ctx, "api-service.staging.svc.local", grpc.WithCodec(proxy.Codec()))
return ctx, conn, err
} else if val, exists := md[":authority"]; exists && val[0] == "api.example.com" {
conn, err := grpc.DialContext(ctx, "api-service.prod.svc.local", grpc.WithCodec(proxy.Codec()))
return ctx, conn, err
}
md, ok := metadata.FromContext(ctx)
if ok {
// Decide on which backend to dial
if val, exists := md[":authority"]; exists && val[0] == "staging.api.example.com" {
// Make sure we use DialContext so the dialing can be cancelled/time out together with the context.
return grpc.DialContext(ctx, "api-service.staging.svc.local", grpc.WithCodec(proxy.Codec()))
} else if val, exists := md[":authority"]; exists && val[0] == "api.example.com" {
return grpc.DialContext(ctx, "api-service.prod.svc.local", grpc.WithCodec(proxy.Codec()))
}
}
return nil, grpc.Errorf(codes.Unimplemented, "Unknown method")
}
return nil, nil, grpc.Errorf(codes.Unimplemented, "Unknown method")
}
```
Then you need to register it with a `grpc.Server`. The server may have other handlers that will be served
Expand Down
7 changes: 6 additions & 1 deletion proxy/DOC.md
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,15 @@ ServerOption.
#### type StreamDirector

```go
type StreamDirector func(ctx context.Context, fullMethodName string) (*grpc.ClientConn, error)
type StreamDirector interface {
Connect(ctx context.Context, method string) (context.Context, *grpc.ClientConn, error)
Release(conn *grpc.ClientConn, method string)
}
```

StreamDirector returns a gRPC ClientConn to be used to forward the call to.
The Release method provides connection management, allowing the director to
cache connections.

The presence of the `Context` allows for rich filtering, e.g. based on Metadata
(headers). If no handling is meant to be done, a `codes.NotImplemented` gRPC
Expand Down
34 changes: 26 additions & 8 deletions proxy/director.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,31 @@ import (
"google.golang.org/grpc"
)

// StreamDirector returns a gRPC ClientConn to be used to forward the call to.
// StreamDirector manages gRPC Client connections used to forward requests.
//
// The presence of the `Context` allows for rich filtering, e.g. based on Metadata (headers).
// If no handling is meant to be done, a `codes.NotImplemented` gRPC error should be returned.
// The presence of the `Context` allows for rich filtering, e.g. based on
// Metadata (headers). If no handling is meant to be done, a
// `codes.NotImplemented` gRPC error should be returned.
//
// It is worth noting that the StreamDirector will be fired *after* all server-side stream interceptors
// are invoked. So decisions around authorization, monitoring etc. are better to be handled there.
//
// See the rather rich example.
type StreamDirector func(ctx context.Context, fullMethodName string) (*grpc.ClientConn, error)
// It is worth noting that the Connect will be called *after* all server-side
// stream interceptors are invoked. So decisions around authorization,
// monitoring etc. are better to be handled there.
type StreamDirector interface {
// Connect returns a connection to use for the given method,
// or an error if the call should not be handled.
//
// The provided context may be inspected for filtering on request
// metadata.
//
// The returned context is used as the basis for the outgoing connection.
Connect(ctx context.Context, method string) (context.Context, *grpc.ClientConn, error)

// Release is called when a connection is longer being used. This is called
// once for every call to Connect that does not return an error.
//
// The provided context is the one returned from Connect.
//
// This can be used by the director to pool connections or close unused
// connections.
Release(ctx context.Context, conn *grpc.ClientConn)
}
41 changes: 25 additions & 16 deletions proxy/examples_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,22 +34,31 @@ func ExampleTransparentHandler() {

// Provide sa simple example of a director that shields internal services and dials a staging or production backend.
// This is a *very naive* implementation that creates a new connection on every request. Consider using pooling.
func ExampleStreamDirector() {
director = func(ctx context.Context, fullMethodName string) (*grpc.ClientConn, error) {
// Make sure we never forward internal services.
if strings.HasPrefix(fullMethodName, "/com.example.internal.") {
return nil, grpc.Errorf(codes.Unimplemented, "Unknown method")
}
md, ok := metadata.FromContext(ctx)
if ok {
// Decide on which backend to dial
if val, exists := md[":authority"]; exists && val[0] == "staging.api.example.com" {
// Make sure we use DialContext so the dialing can be cancelled/time out together with the context.
return grpc.DialContext(ctx, "api-service.staging.svc.local", grpc.WithCodec(proxy.Codec()))
} else if val, exists := md[":authority"]; exists && val[0] == "api.example.com" {
return grpc.DialContext(ctx, "api-service.prod.svc.local", grpc.WithCodec(proxy.Codec()))
}
type ExampleDirector struct {
}

func (d *ExampleDirector) Connect(ctx context.Context, method string) (context.Context, *grpc.ClientConn, error) {
// Make sure we never forward internal services.
if strings.HasPrefix(method, "/com.example.internal.") {
return nil, nil, grpc.Errorf(codes.Unimplemented, "Unknown method")
}
md, ok := metadata.FromIncomingContext(ctx)
if ok {
// Decide on which backend to dial
if val, exists := md[":authority"]; exists && val[0] == "staging.api.example.com" {
// Make sure we use DialContext so the dialing can be cancelled/time out together with the context.
conn, err := grpc.DialContext(ctx, "api-service.staging.svc.local", grpc.WithCodec(proxy.Codec()))
return ctx, conn, err
} else if val, exists := md[":authority"]; exists && val[0] == "api.example.com" {
conn, err := grpc.DialContext(ctx, "api-service.prod.svc.local", grpc.WithCodec(proxy.Codec()))
return ctx, conn, err
}
return nil, grpc.Errorf(codes.Unimplemented, "Unknown method")
}
return nil, nil, grpc.Errorf(codes.Unimplemented, "Unknown method")
}

func (d *ExampleDirector) Release(ctx context.Context, conn *grpc.ClientConn) {
conn.Close()
}

var _ proxy.StreamDirector = &ExampleDirector{}
44 changes: 33 additions & 11 deletions proxy/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ import (
"golang.org/x/net/context"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/peer"
"google.golang.org/grpc/transport"
)

Expand Down Expand Up @@ -59,27 +61,34 @@ type handler struct {
// It is invoked like any gRPC server stream and uses the gRPC server framing to get and receive bytes from the wire,
// forwarding it to a ClientStream established against the relevant ClientConn.
func (s *handler) handler(srv interface{}, serverStream grpc.ServerStream) error {
// little bit of gRPC internals never hurt anyone
lowLevelServerStream, ok := transport.StreamFromContext(serverStream.Context())
serverCtx := serverStream.Context()
lowLevelServerStream, ok := transport.StreamFromContext(serverCtx)
if !ok {
return grpc.Errorf(codes.Internal, "lowLevelServerStream not exists in context")
}
fullMethodName := lowLevelServerStream.Method()
clientCtx, clientCancel := context.WithCancel(serverStream.Context())
backendConn, err := s.director(serverStream.Context(), fullMethodName)
outCtx, backendConn, err := s.director.Connect(serverCtx, fullMethodName)
if err != nil {
return err
}
// TODO(mwitkow): Add a `forwarded` header to metadata, https://en.wikipedia.org/wiki/X-Forwarded-For.
defer s.director.Release(outCtx, backendConn)

clientCtx, clientCancel := context.WithCancel(outCtx)
defer clientCancel()
if _, ok := metadata.FromOutgoingContext(outCtx); !ok {
// Add a `forwarded` header to metadata, https://en.wikipedia.org/wiki/X-Forwarded-For.
clientCtx = addMetadata(clientCtx, outCtx)
}
clientStream, err := grpc.NewClientStream(clientCtx, clientStreamDescForProxying, backendConn, fullMethodName)
if err != nil {
return err
}

// Explicitly *do not close* s2cErrChan and c2sErrChan, otherwise the select below will not terminate.
// Channels do not have to be closed, it is just a control flow mechanism, see
// https://groups.google.com/forum/#!msg/golang-nuts/pZwdYRGxCIk/qpbHxRRPJdUJ
s2cErrChan := s.forwardServerToClient(serverStream, clientStream)
c2sErrChan := s.forwardClientToServer(clientStream, serverStream)
s2cErrChan := forwardServerToClient(serverStream, clientStream)
c2sErrChan := forwardClientToServer(clientStream, serverStream)
// We don't know which side is going to stop sending first, so we need a select between the two.
for i := 0; i < 2; i++ {
select {
Expand All @@ -88,12 +97,11 @@ func (s *handler) handler(srv interface{}, serverStream grpc.ServerStream) error
// this is the happy case where the sender has encountered io.EOF, and won't be sending anymore./
// the clientStream>serverStream may continue pumping though.
clientStream.CloseSend()
break
} else {
// however, we may have gotten a receive error (stream disconnected, a read error etc) in which case we need
// to cancel the clientStream to the backend, let all of its goroutines be freed up by the CancelFunc and
// exit with an error to the stack
clientCancel()
// clientCancel()
return grpc.Errorf(codes.Internal, "failed proxying s2c: %v", s2cErr)
}
case c2sErr := <-c2sErrChan:
Expand All @@ -111,7 +119,21 @@ func (s *handler) handler(srv interface{}, serverStream grpc.ServerStream) error
return grpc.Errorf(codes.Internal, "gRPC proxying should never reach this stage.")
}

func (s *handler) forwardClientToServer(src grpc.ClientStream, dst grpc.ServerStream) chan error {
func addMetadata(ctx context.Context, serverCtx context.Context) context.Context {
source := "unknown"
if peer, ok := peer.FromContext(serverCtx); ok && peer.Addr != nil {
source = peer.Addr.String()
}
forwardMD := metadata.Pairs("X-Forwarded-For", source)

md, ok := metadata.FromIncomingContext(serverCtx)
if ok {
return metadata.NewOutgoingContext(ctx, metadata.Join(md, forwardMD))
}
return metadata.NewOutgoingContext(ctx, forwardMD)
}

func forwardClientToServer(src grpc.ClientStream, dst grpc.ServerStream) chan error {
ret := make(chan error, 1)
go func() {
f := &frame{}
Expand Down Expand Up @@ -143,7 +165,7 @@ func (s *handler) forwardClientToServer(src grpc.ClientStream, dst grpc.ServerSt
return ret
}

func (s *handler) forwardServerToClient(src grpc.ServerStream, dst grpc.ClientStream) chan error {
func forwardServerToClient(src grpc.ServerStream, dst grpc.ClientStream) chan error {
ret := make(chan error, 1)
go func() {
f := &frame{}
Expand Down
35 changes: 22 additions & 13 deletions proxy/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,10 @@ type assertingService struct {

func (s *assertingService) PingEmpty(ctx context.Context, _ *pb.Empty) (*pb.PingResponse, error) {
// Check that this call has client's metadata.
md, ok := metadata.FromContext(ctx)
md, ok := metadata.FromIncomingContext(ctx)
assert.True(s.t, ok, "PingEmpty call must have metadata in context")
_, ok = md[clientMdKey]
assert.True(s.t, ok, "PingEmpty call must have clients's custom headers in metadata")
assert.True(s.t, ok, "PingEmpty call must have clients's custom headers in metadata: %+v", md)
return &pb.PingResponse{Value: pingDefaultValue, Counter: 42}, nil
}

Expand Down Expand Up @@ -116,7 +116,7 @@ func (s *ProxyHappySuite) ctx() context.Context {
}

func (s *ProxyHappySuite) TestPingEmptyCarriesClientMetadata() {
ctx := metadata.NewContext(s.ctx(), metadata.Pairs(clientMdKey, "true"))
ctx := metadata.NewOutgoingContext(s.ctx(), metadata.Pairs(clientMdKey, "true"))
out, err := s.testClient.PingEmpty(ctx, &pb.Empty{})
require.NoError(s.T(), err, "PingEmpty should succeed without errors")
require.Equal(s.T(), &pb.PingResponse{Value: pingDefaultValue, Counter: 42}, out)
Expand Down Expand Up @@ -148,7 +148,7 @@ func (s *ProxyHappySuite) TestPingErrorPropagatesAppError() {

func (s *ProxyHappySuite) TestDirectorErrorIsPropagated() {
// See SetupSuite where the StreamDirector has a special case.
ctx := metadata.NewContext(s.ctx(), metadata.Pairs(rejectingMdKey, "true"))
ctx := metadata.NewOutgoingContext(s.ctx(), metadata.Pairs(rejectingMdKey, "true"))
_, err := s.testClient.Ping(ctx, &pb.PingRequest{Value: "foo"})
require.Error(s.T(), err, "Director should reject this RPC")
assert.Equal(s.T(), codes.PermissionDenied, grpc.Code(err))
Expand Down Expand Up @@ -188,6 +188,23 @@ func (s *ProxyHappySuite) TestPingStream_StressTest() {
}
}

type checkingDirector struct {
conn *grpc.ClientConn
}

func (c *checkingDirector) Connect(ctx context.Context, method string) (context.Context, *grpc.ClientConn, error) {
md, ok := metadata.FromIncomingContext(ctx)
if ok {
if _, exists := md[rejectingMdKey]; exists {
return ctx, nil, grpc.Errorf(codes.PermissionDenied, "testing rejection")
}
}
return ctx, c.conn, nil
}

func (c *checkingDirector) Release(ctx context.Context, conn *grpc.ClientConn) {
}

func (s *ProxyHappySuite) SetupSuite() {
var err error

Expand All @@ -204,15 +221,7 @@ func (s *ProxyHappySuite) SetupSuite() {
// Setup of the proxy's Director.
s.serverClientConn, err = grpc.Dial(s.serverListener.Addr().String(), grpc.WithInsecure(), grpc.WithCodec(proxy.Codec()))
require.NoError(s.T(), err, "must not error on deferred client Dial")
director := func(ctx context.Context, fullName string) (*grpc.ClientConn, error) {
md, ok := metadata.FromContext(ctx)
if ok {
if _, exists := md[rejectingMdKey]; exists {
return nil, grpc.Errorf(codes.PermissionDenied, "testing rejection")
}
}
return s.serverClientConn, nil
}
director := &checkingDirector{conn: s.serverClientConn}
s.proxy = grpc.NewServer(
grpc.CustomCodec(proxy.Codec()),
grpc.UnknownServiceHandler(proxy.TransparentHandler(director)),
Expand Down