Skip to content

Commit 9882e21

Browse files
committed
Add confugurable cachePath
1 parent 8b137fc commit 9882e21

File tree

9 files changed

+97
-39
lines changed

9 files changed

+97
-39
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ This library aims to require as little configuration as possible, favouring over
4242
| Password | postgres |
4343
| Database | postgres |
4444
| Version | 12.1.0 |
45+
| CachePath | $USER_HOME/.embedded-postgres-go/ |
4546
| RuntimePath | $USER_HOME/.embedded-postgres-go/extracted |
4647
| DataPath | $USER_HOME/.embedded-postgres-go/extracted/data |
4748
| BinariesPath | $USER_HOME/.embedded-postgres-go/extracted |

cache_locator.go

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,15 @@ import (
88

99
// CacheLocator retrieves the location of the Postgres binary cache returning it to location.
1010
// The result of whether this cache is present will be returned to exists.
11-
type CacheLocator func() (location string, exists bool)
11+
type CacheLocator func(cachePath string) (location string, exists bool)
1212

1313
func defaultCacheLocator(versionStrategy VersionStrategy) CacheLocator {
14-
return func() (string, bool) {
15-
cacheDirectory := ".embedded-postgres-go"
16-
if userHome, err := os.UserHomeDir(); err == nil {
17-
cacheDirectory = filepath.Join(userHome, ".embedded-postgres-go")
14+
return func(cacheDirectory string) (string, bool) {
15+
if cacheDirectory == "" {
16+
cacheDirectory = ".embedded-postgres-go"
17+
if userHome, err := os.UserHomeDir(); err == nil {
18+
cacheDirectory = filepath.Join(userHome, ".embedded-postgres-go")
19+
}
1820
}
1921

2022
operatingSystem, architecture, version := versionStrategy()

cache_locator_test.go

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,19 @@ func Test_defaultCacheLocator_NotExists(t *testing.T) {
1111
return "a", "b", "1.2.3"
1212
})
1313

14-
cacheLocation, exists := locator()
14+
cacheLocation, exists := locator("")
1515

1616
assert.Contains(t, cacheLocation, ".embedded-postgres-go/embedded-postgres-binaries-a-b-1.2.3.txz")
1717
assert.False(t, exists)
1818
}
19+
20+
func Test_defaultCacheLocator_CustomPath(t *testing.T) {
21+
locator := defaultCacheLocator(func() (string, string, PostgresVersion) {
22+
return "a", "b", "1.2.3"
23+
})
24+
25+
cacheLocation, exists := locator("/custom/path")
26+
27+
assert.Equal(t, cacheLocation, "/custom/path/embedded-postgres-binaries-a-b-1.2.3.txz")
28+
assert.False(t, exists)
29+
}

config.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ type Config struct {
1414
database string
1515
username string
1616
password string
17+
cachePath string
1718
runtimePath string
1819
dataPath string
1920
binariesPath string
@@ -82,6 +83,13 @@ func (c Config) RuntimePath(path string) Config {
8283
return c
8384
}
8485

86+
// CachePath sets the path that will be used for storing Postgres binaries archive.
87+
// If this option is not set, ~/.go-embedded-postgres will be used.
88+
func (c Config) CachePath(path string) Config {
89+
c.cachePath = path
90+
return c
91+
}
92+
8593
// DataPath sets the path that will be used for the Postgres data directory.
8694
// If this option is set, a previously initialized data directory will be reused if possible.
8795
func (c Config) DataPath(path string) Config {

embedded_postgres.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ func newDatabaseWithConfig(config Config) *EmbeddedPostgres {
4545
shouldUseAlpineLinuxBuild,
4646
)
4747
cacheLocator := defaultCacheLocator(versionStrategy)
48-
remoteFetchStrategy := defaultRemoteFetchStrategy(config.binaryRepositoryURL, versionStrategy, cacheLocator)
48+
remoteFetchStrategy := defaultRemoteFetchStrategy(config.binaryRepositoryURL, versionStrategy, cacheLocator, config.cachePath)
4949

5050
return &EmbeddedPostgres{
5151
config: config,
@@ -77,7 +77,7 @@ func (ep *EmbeddedPostgres) Start() error {
7777

7878
ep.syncedLogger = logger
7979

80-
cacheLocation, cacheExists := ep.cacheLocator()
80+
cacheLocation, cacheExists := ep.cacheLocator(ep.config.cachePath)
8181

8282
if ep.config.runtimePath == "" {
8383
ep.config.runtimePath = filepath.Join(filepath.Dir(cacheLocation), "extracted")

embedded_postgres_test.go

Lines changed: 32 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ func Test_ErrorWhenPortAlreadyTaken(t *testing.T) {
6666

6767
func Test_ErrorWhenRemoteFetchError(t *testing.T) {
6868
database := NewDatabase()
69-
database.cacheLocator = func() (string, bool) {
69+
database.cacheLocator = func(string) (string, bool) {
7070
return "", false
7171
}
7272
database.remoteFetchStrategy = func() error {
@@ -88,7 +88,7 @@ func Test_ErrorWhenUnableToUnArchiveFile_WrongFormat(t *testing.T) {
8888
Database("beer").
8989
StartTimeout(10 * time.Second))
9090

91-
database.cacheLocator = func() (string, bool) {
91+
database.cacheLocator = func(string) (string, bool) {
9292
return jarFile, true
9393
}
9494

@@ -119,7 +119,7 @@ func Test_ErrorWhenUnableToInitDatabase(t *testing.T) {
119119
RuntimePath(extractPath).
120120
StartTimeout(10 * time.Second))
121121

122-
database.cacheLocator = func() (string, bool) {
122+
database.cacheLocator = func(string) (string, bool) {
123123
return jarFile, true
124124
}
125125

@@ -222,7 +222,7 @@ func Test_ErrorWhenCannotStartPostgresProcess(t *testing.T) {
222222
database := NewDatabase(DefaultConfig().
223223
RuntimePath(extractPath))
224224

225-
database.cacheLocator = func() (string, bool) {
225+
database.cacheLocator = func(string) (string, bool) {
226226
return jarFile, true
227227
}
228228

@@ -360,7 +360,7 @@ func Test_ConcurrentStart(t *testing.T) {
360360
var wg sync.WaitGroup
361361

362362
database := NewDatabase()
363-
cacheLocation, _ := database.cacheLocator()
363+
cacheLocation, _ := database.cacheLocator("")
364364
err := os.RemoveAll(cacheLocation)
365365
require.NoError(t, err)
366366

@@ -644,7 +644,7 @@ func Test_CustomBinariesLocation(t *testing.T) {
644644
}
645645

646646
// Delete cache to make sure unarchive doesn't happen again.
647-
cacheLocation, _ := database.cacheLocator()
647+
cacheLocation, _ := database.cacheLocator("")
648648
if err := os.RemoveAll(cacheLocation); err != nil {
649649
panic(err)
650650
}
@@ -658,6 +658,30 @@ func Test_CustomBinariesLocation(t *testing.T) {
658658
}
659659
}
660660

661+
func Test_CachePath(t *testing.T) {
662+
cacheTempDir, err := os.MkdirTemp("", "prepare_database_test_cache")
663+
if err != nil {
664+
panic(err)
665+
}
666+
667+
defer func() {
668+
if err := os.RemoveAll(cacheTempDir); err != nil {
669+
panic(err)
670+
}
671+
}()
672+
673+
database := NewDatabase(DefaultConfig().
674+
CachePath(cacheTempDir))
675+
676+
if err := database.Start(); err != nil {
677+
shutdownDBAndFail(t, err, database)
678+
}
679+
680+
if err := database.Stop(); err != nil {
681+
shutdownDBAndFail(t, err, database)
682+
}
683+
}
684+
661685
func Test_PrefetchedBinaries(t *testing.T) {
662686
binTempDir, err := os.MkdirTemp("", "prepare_database_test_bin")
663687
if err != nil {
@@ -688,13 +712,13 @@ func Test_PrefetchedBinaries(t *testing.T) {
688712
panic(err)
689713
}
690714

691-
cacheLocation, _ := database.cacheLocator()
715+
cacheLocation, _ := database.cacheLocator("")
692716
if err := decompressTarXz(defaultTarReader, cacheLocation, binTempDir); err != nil {
693717
panic(err)
694718
}
695719

696720
// Expect everything to work without cacheLocator and/or remoteFetch abilities.
697-
database.cacheLocator = func() (string, bool) {
721+
database.cacheLocator = func(string) (string, bool) {
698722
return "", false
699723
}
700724
database.remoteFetchStrategy = func() error {

remote_fetch.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ import (
1919
type RemoteFetchStrategy func() error
2020

2121
//nolint:funlen
22-
func defaultRemoteFetchStrategy(remoteFetchHost string, versionStrategy VersionStrategy, cacheLocator CacheLocator) RemoteFetchStrategy {
22+
func defaultRemoteFetchStrategy(remoteFetchHost string, versionStrategy VersionStrategy, cacheLocator CacheLocator, cachePath string) RemoteFetchStrategy {
2323
return func() error {
2424
operatingSystem, architecture, version := versionStrategy()
2525

@@ -62,7 +62,7 @@ func defaultRemoteFetchStrategy(remoteFetchHost string, versionStrategy VersionS
6262
}
6363
}
6464

65-
return decompressResponse(jarBodyBytes, jarDownloadResponse.ContentLength, cacheLocator, jarDownloadURL)
65+
return decompressResponse(jarBodyBytes, jarDownloadResponse.ContentLength, cacheLocator, cachePath, jarDownloadURL)
6666
}
6767
}
6868

@@ -74,13 +74,13 @@ func closeBody(resp *http.Response) func() {
7474
}
7575
}
7676

77-
func decompressResponse(bodyBytes []byte, contentLength int64, cacheLocator CacheLocator, downloadURL string) error {
77+
func decompressResponse(bodyBytes []byte, contentLength int64, cacheLocator CacheLocator, cachePath string, downloadURL string) error {
7878
zipReader, err := zip.NewReader(bytes.NewReader(bodyBytes), contentLength)
7979
if err != nil {
8080
return errorFetchingPostgres(err)
8181
}
8282

83-
cacheLocation, _ := cacheLocator()
83+
cacheLocation, _ := cacheLocator(cachePath)
8484

8585
if err := os.MkdirAll(filepath.Dir(cacheLocation), 0755); err != nil {
8686
return errorExtractingPostgres(err)

remote_fetch_test.go

Lines changed: 30 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@ import (
1818
func Test_defaultRemoteFetchStrategy_ErrorWhenHttpGet(t *testing.T) {
1919
remoteFetchStrategy := defaultRemoteFetchStrategy("http://localhost:1234/maven2",
2020
testVersionStrategy(),
21-
testCacheLocator())
21+
testCacheLocator(),
22+
"")
2223

2324
err := remoteFetchStrategy()
2425

@@ -33,7 +34,8 @@ func Test_defaultRemoteFetchStrategy_ErrorWhenHttpStatusNot200(t *testing.T) {
3334

3435
remoteFetchStrategy := defaultRemoteFetchStrategy(server.URL,
3536
testVersionStrategy(),
36-
testCacheLocator())
37+
testCacheLocator(),
38+
"")
3739

3840
err := remoteFetchStrategy()
3941

@@ -48,7 +50,8 @@ func Test_defaultRemoteFetchStrategy_ErrorWhenBodyReadIssue(t *testing.T) {
4850

4951
remoteFetchStrategy := defaultRemoteFetchStrategy(server.URL+"/maven2",
5052
testVersionStrategy(),
51-
testCacheLocator())
53+
testCacheLocator(),
54+
"")
5255

5356
err := remoteFetchStrategy()
5457

@@ -66,7 +69,8 @@ func Test_defaultRemoteFetchStrategy_ErrorWhenCannotUnzipSubFile(t *testing.T) {
6669

6770
remoteFetchStrategy := defaultRemoteFetchStrategy(server.URL+"/maven2",
6871
testVersionStrategy(),
69-
testCacheLocator())
72+
testCacheLocator(),
73+
"")
7074

7175
err := remoteFetchStrategy()
7276

@@ -88,7 +92,8 @@ func Test_defaultRemoteFetchStrategy_ErrorWhenCannotUnzip(t *testing.T) {
8892

8993
remoteFetchStrategy := defaultRemoteFetchStrategy(server.URL+"/maven2",
9094
testVersionStrategy(),
91-
testCacheLocator())
95+
testCacheLocator(),
96+
"")
9297

9398
err := remoteFetchStrategy()
9499

@@ -112,7 +117,8 @@ func Test_defaultRemoteFetchStrategy_ErrorWhenNoSubTarArchive(t *testing.T) {
112117

113118
remoteFetchStrategy := defaultRemoteFetchStrategy(server.URL+"/maven2",
114119
testVersionStrategy(),
115-
testCacheLocator())
120+
testCacheLocator(),
121+
"")
116122

117123
err := remoteFetchStrategy()
118124

@@ -141,9 +147,10 @@ func Test_defaultRemoteFetchStrategy_ErrorWhenCannotExtractSubArchive(t *testing
141147

142148
remoteFetchStrategy := defaultRemoteFetchStrategy(server.URL+"/maven2",
143149
testVersionStrategy(),
144-
func() (s string, b bool) {
150+
func(string) (s string, b bool) {
145151
return filepath.FromSlash("/invalid"), false
146-
})
152+
},
153+
"")
147154

148155
err := remoteFetchStrategy()
149156

@@ -181,9 +188,10 @@ func Test_defaultRemoteFetchStrategy_ErrorWhenCannotCreateCacheDirectory(t *test
181188

182189
remoteFetchStrategy := defaultRemoteFetchStrategy(server.URL+"/maven2",
183190
testVersionStrategy(),
184-
func() (s string, b bool) {
191+
func(string) (s string, b bool) {
185192
return cacheLocation, false
186-
})
193+
},
194+
"")
187195

188196
err := remoteFetchStrategy()
189197

@@ -218,9 +226,10 @@ func Test_defaultRemoteFetchStrategy_ErrorWhenCannotCreateSubArchiveFile(t *test
218226

219227
remoteFetchStrategy := defaultRemoteFetchStrategy(server.URL+"/maven2",
220228
testVersionStrategy(),
221-
func() (s string, b bool) {
229+
func(string) (s string, b bool) {
222230
return "/\\000", false
223-
})
231+
},
232+
"")
224233

225234
err := remoteFetchStrategy()
226235

@@ -256,9 +265,10 @@ func Test_defaultRemoteFetchStrategy_ErrorWhenSHA256NotMatch(t *testing.T) {
256265

257266
remoteFetchStrategy := defaultRemoteFetchStrategy(server.URL+"/maven2",
258267
testVersionStrategy(),
259-
func() (s string, b bool) {
268+
func(string) (s string, b bool) {
260269
return cacheLocation, false
261-
})
270+
},
271+
"")
262272

263273
err := remoteFetchStrategy()
264274

@@ -295,9 +305,10 @@ func Test_defaultRemoteFetchStrategy(t *testing.T) {
295305

296306
remoteFetchStrategy := defaultRemoteFetchStrategy(server.URL+"/maven2",
297307
testVersionStrategy(),
298-
func() (s string, b bool) {
308+
func(string) (s string, b bool) {
299309
return cacheLocation, false
300-
})
310+
},
311+
"")
301312

302313
err := remoteFetchStrategy()
303314

@@ -347,9 +358,10 @@ func Test_defaultRemoteFetchStrategyWithExistingDownload(t *testing.T) {
347358

348359
remoteFetchStrategy := defaultRemoteFetchStrategy(server.URL+"/maven2",
349360
testVersionStrategy(),
350-
func() (s string, b bool) {
361+
func(string) (s string, b bool) {
351362
return cacheLocation, false
352-
})
363+
},
364+
"")
353365

354366
// call it the remoteFetchStrategy(). The output location should be empty and a new file created
355367
err = remoteFetchStrategy()

test_util_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ func testVersionStrategy() VersionStrategy {
5555
}
5656

5757
func testCacheLocator() CacheLocator {
58-
return func() (s string, b bool) {
58+
return func(string) (s string, b bool) {
5959
return "", false
6060
}
6161
}

0 commit comments

Comments
 (0)