Skip to content

Commit 783db6a

Browse files
authored
feat: add support for unix sockets (GoogleCloudPlatform#44)
This is an adaptation from GoogleCloudPlatform/cloud-sql-proxy#1182
1 parent 7069117 commit 783db6a

File tree

5 files changed

+321
-38
lines changed

5 files changed

+321
-38
lines changed

cmd/root.go

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,8 @@ without having to manage any client SSL certificates.`,
128128
"Address on which to bind AlloyDB instance listeners.")
129129
cmd.PersistentFlags().IntVarP(&c.conf.Port, "port", "p", 5432,
130130
"Initial port to use for listeners. Subsequent listeners increment from this value.")
131+
cmd.PersistentFlags().StringVarP(&c.conf.UnixSocket, "unix-socket", "u", "",
132+
`Enables Unix sockets for all listeners using the provided directory.`)
131133

132134
c.Command = cmd
133135
return c
@@ -138,6 +140,15 @@ func parseConfig(cmd *cobra.Command, conf *proxy.Config, args []string) error {
138140
if len(args) == 0 {
139141
return newBadCommandError("missing instance uri (e.g., /projects/$PROJECTS/locations/$LOCTION/clusters/$CLUSTER/instances/$INSTANCES)")
140142
}
143+
userHasSet := func(f string) bool {
144+
return cmd.PersistentFlags().Lookup(f).Changed
145+
}
146+
if userHasSet("address") && userHasSet("unix-socket") {
147+
return newBadCommandError("cannot specify --unix-socket and --address together")
148+
}
149+
if userHasSet("port") && userHasSet("unix-socket") {
150+
return newBadCommandError("cannot specify --unix-socket and --port together")
151+
}
141152
// First, validate global config.
142153
if ip := net.ParseIP(conf.Addr); ip == nil {
143154
return newBadCommandError(fmt.Sprintf("not a valid IP address: %q", conf.Addr))
@@ -171,7 +182,18 @@ func parseConfig(cmd *cobra.Command, conf *proxy.Config, args []string) error {
171182
return newBadCommandError(fmt.Sprintf("could not parse query: %q", res[1]))
172183
}
173184

174-
if a, ok := q["address"]; ok {
185+
a, aok := q["address"]
186+
p, pok := q["port"]
187+
u, uok := q["unix-socket"]
188+
189+
if aok && uok {
190+
return newBadCommandError("cannot specify both address and unix-socket query params")
191+
}
192+
if pok && uok {
193+
return newBadCommandError("cannot specify both port and unix-socket query params")
194+
}
195+
196+
if aok {
175197
if len(a) != 1 {
176198
return newBadCommandError(fmt.Sprintf("address query param should be only one value: %q", a))
177199
}
@@ -184,7 +206,7 @@ func parseConfig(cmd *cobra.Command, conf *proxy.Config, args []string) error {
184206
ic.Addr = a[0]
185207
}
186208

187-
if p, ok := q["port"]; ok {
209+
if pok {
188210
if len(p) != 1 {
189211
return newBadCommandError(fmt.Sprintf("port query param should be only one value: %q", a))
190212
}
@@ -197,6 +219,14 @@ func parseConfig(cmd *cobra.Command, conf *proxy.Config, args []string) error {
197219
}
198220
ic.Port = pp
199221
}
222+
223+
if uok {
224+
if len(u) != 1 {
225+
return newBadCommandError(fmt.Sprintf("unix query param should be only one value: %q", a))
226+
}
227+
ic.UnixSocket = u[0]
228+
229+
}
200230
}
201231
ics = append(ics, ic)
202232
}

cmd/root_test.go

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,29 @@ func TestNewCommandArguments(t *testing.T) {
137137
CredentialsFile: "/path/to/file",
138138
}),
139139
},
140+
{
141+
desc: "using the unix socket flag",
142+
args: []string{"--unix-socket", "/path/to/dir/", "/projects/proj/locations/region/clusters/clust/instances/inst"},
143+
want: withDefaults(&proxy.Config{
144+
UnixSocket: "/path/to/dir/",
145+
}),
146+
},
147+
{
148+
desc: "using the (short) unix socket flag",
149+
args: []string{"-u", "/path/to/dir/", "/projects/proj/locations/region/clusters/clust/instances/inst"},
150+
want: withDefaults(&proxy.Config{
151+
UnixSocket: "/path/to/dir/",
152+
}),
153+
},
154+
{
155+
desc: "using the unix socket query param",
156+
args: []string{"/projects/proj/locations/region/clusters/clust/instances/inst?unix-socket=/path/to/dir/"},
157+
want: withDefaults(&proxy.Config{
158+
Instances: []proxy.InstanceConnConfig{{
159+
UnixSocket: "/path/to/dir/",
160+
}},
161+
}),
162+
},
140163
}
141164

142165
for _, tc := range tcs {
@@ -210,6 +233,26 @@ func TestNewCommandWithErrors(t *testing.T) {
210233
"--token", "my-token",
211234
"--credentials-file", "/path/to/file", "/projects/proj/locations/region/clusters/clust/instances/inst"},
212235
},
236+
{
237+
desc: "when the unix socket query param contains multiple values",
238+
args: []string{"/projects/proj/locations/region/clusters/clust/instances/inst?unix-socket=/one&unix-socket=/two"},
239+
},
240+
{
241+
desc: "using the unix socket flag with addr",
242+
args: []string{"-u", "/path/to/dir/", "-a", "127.0.0.1", "/projects/proj/locations/region/clusters/clust/instances/inst"},
243+
},
244+
{
245+
desc: "using the unix socket flag with port",
246+
args: []string{"-u", "/path/to/dir/", "-p", "5432", "/projects/proj/locations/region/clusters/clust/instances/inst"},
247+
},
248+
{
249+
desc: "using the unix socket and addr query params",
250+
args: []string{"/projects/proj/locations/region/clusters/clust/instances/inst?unix-socket=/path&address=127.0.0.1"},
251+
},
252+
{
253+
desc: "using the unix socket and port query params",
254+
args: []string{"/projects/proj/locations/region/clusters/clust/instances/inst?unix-socket=/path&port=5000"},
255+
},
213256
}
214257

215258
for _, tc := range tcs {

internal/proxy/proxy.go

Lines changed: 97 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,10 @@ import (
1919
"fmt"
2020
"io"
2121
"net"
22+
"os"
23+
"path/filepath"
24+
"regexp"
25+
"strings"
2226
"sync"
2327
"time"
2428

@@ -37,6 +41,10 @@ type InstanceConnConfig struct {
3741
Addr string
3842
// Port is the port on which to bind a listener for the instance.
3943
Port int
44+
// UnixSocket is the directory where a Unix socket will be created,
45+
// connected to the Cloud SQL instance. If set, takes precedence over Addr
46+
// and Port.
47+
UnixSocket string
4048
}
4149

4250
// Config contains all the configuration provided by the caller.
@@ -54,6 +62,10 @@ type Config struct {
5462
// increments from this value.
5563
Port int
5664

65+
// UnixSocket is the directory where Unix sockets will be created,
66+
// connected to any Instances. If set, takes precedence over Addr and Port.
67+
UnixSocket string
68+
5769
// Instances are configuration for individual instances. Instance
5870
// configuration takes precedence over global configuration.
5971
Instances []InstanceConnConfig
@@ -95,6 +107,28 @@ func (c *portConfig) nextPort() int {
95107
return p
96108
}
97109

110+
var (
111+
// Instance URI is in the format:
112+
// '/projects/<PROJECT>/locations/<REGION>/clusters/<CLUSTER>/instances/<INSTANCE>'
113+
// Additionally, we have to support legacy "domain-scoped" projects (e.g. "google.com:PROJECT")
114+
instURIRegex = regexp.MustCompile("projects/([^:]+(:[^:]+)?)/locations/([^:]+)/clusters/([^:]+)/instances/([^:]+)")
115+
)
116+
117+
// UnixSocketDir returns a shorted instance connection name to prevent exceeding
118+
// the Unix socket length.
119+
func UnixSocketDir(dir, inst string) (string, error) {
120+
m := instURIRegex.FindSubmatch([]byte(inst))
121+
if m == nil {
122+
return "", fmt.Errorf("invalid instance name: %v", inst)
123+
}
124+
project := string(m[1])
125+
region := string(m[3])
126+
cluster := string(m[4])
127+
name := string(m[5])
128+
shortName := strings.Join([]string{project, region, cluster, name}, ".")
129+
return filepath.Join(dir, shortName), nil
130+
}
131+
98132
// Client represents the state of the current instantiation of the proxy.
99133
type Client struct {
100134
cmd *cobra.Command
@@ -106,31 +140,79 @@ type Client struct {
106140

107141
// NewClient completes the initial setup required to get the proxy to a "steady" state.
108142
func NewClient(ctx context.Context, d alloydb.Dialer, cmd *cobra.Command, conf *Config) (*Client, error) {
109-
var mnts []*socketMount
110143
pc := newPortConfig(conf.Port)
144+
var mnts []*socketMount
111145
for _, inst := range conf.Instances {
112-
m := &socketMount{inst: inst.Name}
113-
a := conf.Addr
114-
if inst.Addr != "" {
115-
a = inst.Addr
116-
}
117-
var np int
118-
switch {
119-
case inst.Port != 0:
120-
np = inst.Port
121-
default: // use next increment from conf.Port
122-
np = pc.nextPort()
146+
var (
147+
// network is one of "tcp" or "unix"
148+
network string
149+
// address is either a TCP host port, or a Unix socket
150+
address string
151+
)
152+
// IF
153+
// a global Unix socket directory is NOT set AND
154+
// an instance-level Unix socket is NOT set
155+
// (e.g., I didn't set a Unix socket globally or for this instance)
156+
// OR
157+
// an instance-level TCP address or port IS set
158+
// (e.g., I'm overriding any global settings to use TCP for this
159+
// instance)
160+
// use a TCP listener.
161+
// Otherwise, use a Unix socket.
162+
if (conf.UnixSocket == "" && inst.UnixSocket == "") ||
163+
(inst.Addr != "" || inst.Port != 0) {
164+
network = "tcp"
165+
166+
a := conf.Addr
167+
if inst.Addr != "" {
168+
a = inst.Addr
169+
}
170+
171+
var np int
172+
switch {
173+
case inst.Port != 0:
174+
np = inst.Port
175+
case conf.Port != 0:
176+
np = pc.nextPort()
177+
default:
178+
np = pc.nextPort()
179+
}
180+
181+
address = net.JoinHostPort(a, fmt.Sprint(np))
182+
} else {
183+
network = "unix"
184+
185+
dir := conf.UnixSocket
186+
if dir == "" {
187+
dir = inst.UnixSocket
188+
}
189+
ud, err := UnixSocketDir(dir, inst.Name)
190+
if err != nil {
191+
return nil, err
192+
}
193+
// Create the parent directory that will hold the socket.
194+
if _, err := os.Stat(ud); err != nil {
195+
if err = os.Mkdir(ud, 0777); err != nil {
196+
return nil, err
197+
}
198+
}
199+
// use the Postgres-specific socket name
200+
address = filepath.Join(ud, ".s.PGSQL.5432")
123201
}
124-
addr, err := m.listen(ctx, "tcp", net.JoinHostPort(a, fmt.Sprint(np)))
202+
203+
m := &socketMount{inst: inst.Name}
204+
addr, err := m.listen(ctx, network, address)
125205
if err != nil {
126206
for _, m := range mnts {
127207
m.close()
128208
}
129209
return nil, fmt.Errorf("[%v] Unable to mount socket: %v", inst.Name, err)
130210
}
211+
131212
cmd.Printf("[%s] Listening on %s\n", inst.Name, addr.String())
132213
mnts = append(mnts, m)
133214
}
215+
134216
return &Client{mnts: mnts, cmd: cmd, dialer: d}, nil
135217
}
136218

@@ -210,9 +292,9 @@ type socketMount struct {
210292
}
211293

212294
// listen causes a socketMount to create a Listener at the specified network address.
213-
func (s *socketMount) listen(ctx context.Context, network string, host string) (net.Addr, error) {
295+
func (s *socketMount) listen(ctx context.Context, network string, address string) (net.Addr, error) {
214296
lc := net.ListenConfig{KeepAlive: 30 * time.Second}
215-
l, err := lc.Listen(ctx, network, host)
297+
l, err := lc.Listen(ctx, network, address)
216298
if err != nil {
217299
return nil, err
218300
}

0 commit comments

Comments
 (0)