Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 42 additions & 11 deletions rpc/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,24 @@ type httpConn struct {
headers http.Header
}

type HTTPTransportOption func(*httpConn)

// WithHeaders merges the default headers with the headers provided.
func WithHeaders(headers map[string]string) HTTPTransportOption {
return func(h *httpConn) {
for k, v := range headers {
h.headers.Set(k, v)
}
}
}

// WithHTTPClient sets the http client on the HTTP Transport
func WithHTTPClient(c *http.Client) HTTPTransportOption {
return func(h *httpConn) {
h.client = c
}
}

// httpConn implements ServerCodec, but it is treated specially by Client
// and some methods don't work. The panic() stubs here exist to ensure
// this special treatment is correct.
Expand Down Expand Up @@ -112,32 +130,45 @@ var DefaultHTTPTimeouts = HTTPTimeouts{
// DialHTTPWithClient creates a new RPC client that connects to an RPC server over HTTP
// using the provided HTTP Client.
func DialHTTPWithClient(endpoint string, client *http.Client) (*Client, error) {
return DialHTTPWithOptions(endpoint, WithHTTPClient(client))
}

func newHttpTransport(endpoint string) (*httpConn, error) {
// Sanity check URL so we don't end up with a client that will fail every request.
_, err := url.Parse(endpoint)
if err != nil {
return nil, err
}

initctx := context.Background()
headers := make(http.Header, 2)
headers.Set("accept", contentType)
headers.Set("content-type", contentType)
return newClient(initctx, func(context.Context) (ServerCodec, error) {
hc := &httpConn{
client: client,
headers: headers,
url: endpoint,
closeCh: make(chan interface{}),
}
return hc, nil
})

return &httpConn{
client: &http.Client{},
headers: headers,
url: endpoint,
closeCh: make(chan interface{}),
}, nil
}

// DialHTTP creates a new RPC client that connects to an RPC server over HTTP.
func DialHTTP(endpoint string) (*Client, error) {
return DialHTTPWithClient(endpoint, new(http.Client))
}

func DialHTTPWithOptions(endpoint string, options ...HTTPTransportOption) (*Client, error) {
hc, err := newHttpTransport(endpoint)
if err != nil {
return nil, err
}
for _, opt := range options {
opt(hc)
}
return newClient(context.Background(), func(context.Context) (ServerCodec, error) {
return hc, nil
})
}

func (c *Client) sendHTTP(ctx context.Context, op *requestOp, msg interface{}) error {
hc := c.writeConn.(*httpConn)
respBody, err := hc.doRequest(ctx, msg)
Expand Down
111 changes: 79 additions & 32 deletions rpc/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,57 @@ const (

var wsBufferPool = new(sync.Pool)

type websocketTransport struct {
d *websocket.Dialer
headers http.Header
ctx context.Context
}

type WebsocketTransportOption func(*websocketTransport) error

func WithDialHeaders(h http.Header) WebsocketTransportOption {
return func(t *websocketTransport) error {
for k, vals := range h {
for _, v := range vals {
t.headers.Add(k, v)
}
}
return nil
}
}

func WithBasicAuth(endpoint, origin string) WebsocketTransportOption {
return func(t *websocketTransport) error {
endpointURL, err := url.Parse(endpoint)
if err != nil {
return err
}
if origin != "" {
t.headers.Add("origin", origin)
}
if endpointURL.User != nil {
b64auth := base64.StdEncoding.EncodeToString([]byte(endpointURL.User.String()))
t.headers.Add("authorization", "Basic "+b64auth)
endpointURL.User = nil
}
return nil
}
}

func WithDialContext(ctx context.Context) WebsocketTransportOption {
return func(t *websocketTransport) error {
t.ctx = ctx
return nil
}
}

func WithDialer(d *websocket.Dialer) WebsocketTransportOption {
return func(t *websocketTransport) error {
t.d = d
return nil
}
}

// WebsocketHandler returns a handler that serves JSON-RPC to WebSocket connections.
//
// allowedOrigins should be a comma-separated list of allowed origin URLs.
Expand Down Expand Up @@ -181,55 +232,51 @@ func parseOriginURL(origin string) (string, string, string, error) {
return scheme, hostname, port, nil
}

// DialWebsocketWithDialer creates a new RPC client that communicates with a JSON-RPC server
// that is listening on the given endpoint using the provided dialer.
func DialWebsocketWithDialer(ctx context.Context, endpoint, origin string, dialer websocket.Dialer) (*Client, error) {
endpoint, header, err := wsClientHeaders(endpoint, origin)
if err != nil {
return nil, err
func newWebsocketTransport(endpoint string) *websocketTransport {
dialer := &websocket.Dialer{
ReadBufferSize: wsReadBuffer,
WriteBufferSize: wsWriteBuffer,
WriteBufferPool: wsBufferPool,
}
return &websocketTransport{
d: dialer,
headers: http.Header{},
}
return newClient(ctx, func(ctx context.Context) (ServerCodec, error) {
conn, resp, err := dialer.DialContext(ctx, endpoint, header)
}

func DialWebsocketWithOptions(endpoint string, options ...WebsocketTransportOption) (*Client, error) {
transport := newWebsocketTransport(endpoint)
for _, option := range options {
if err := option(transport); err != nil {
return nil, err
}
}
return newClient(transport.ctx, func(ctx context.Context) (ServerCodec, error) {
conn, resp, err := transport.d.DialContext(ctx, endpoint, transport.headers)
if err != nil {
hErr := wsHandshakeError{err: err}
if resp != nil {
hErr.status = resp.Status
}
return nil, hErr
}
return newWebsocketCodec(conn, endpoint, header), nil
return newWebsocketCodec(conn, endpoint, transport.headers), nil
})
}

// DialWebsocketWithDialer creates a new RPC client that communicates with a JSON-RPC server
// that is listening on the given endpoint using the provided dialer.
func DialWebsocketWithDialer(ctx context.Context, endpoint, origin string, dialer websocket.Dialer) (*Client, error) {
return DialWebsocketWithOptions(endpoint, WithBasicAuth(endpoint, origin), WithDialer(&dialer), WithDialContext(ctx))
}

// DialWebsocket creates a new RPC client that communicates with a JSON-RPC server
// that is listening on the given endpoint.
//
// The context is used for the initial connection establishment. It does not
// affect subsequent interactions with the client.
func DialWebsocket(ctx context.Context, endpoint, origin string) (*Client, error) {
dialer := websocket.Dialer{
ReadBufferSize: wsReadBuffer,
WriteBufferSize: wsWriteBuffer,
WriteBufferPool: wsBufferPool,
}
return DialWebsocketWithDialer(ctx, endpoint, origin, dialer)
}

func wsClientHeaders(endpoint, origin string) (string, http.Header, error) {
endpointURL, err := url.Parse(endpoint)
if err != nil {
return endpoint, nil, err
}
header := make(http.Header)
if origin != "" {
header.Add("origin", origin)
}
if endpointURL.User != nil {
b64auth := base64.StdEncoding.EncodeToString([]byte(endpointURL.User.String()))
header.Add("authorization", "Basic "+b64auth)
endpointURL.User = nil
}
return endpointURL.String(), header, nil
return DialWebsocketWithOptions(endpoint, WithBasicAuth(endpoint, origin), WithDialContext(ctx))
}

type websocketCodec struct {
Expand Down