diff --git a/cmd/functions.go b/cmd/functions.go index 6b5684bc2..bddb114ba 100644 --- a/cmd/functions.go +++ b/cmd/functions.go @@ -49,7 +49,10 @@ var ( Long: "Download the source code for a Function from the linked Supabase project.", Args: cobra.ExactArgs(1), RunE: func(cmd *cobra.Command, args []string) error { - return download.Run(cmd.Context(), args[0], flags.ProjectRef, useLegacyBundle, afero.NewOsFs()) + if useApi { + useDocker = false + } + return download.Run(cmd.Context(), args[0], flags.ProjectRef, useLegacyBundle, useDocker, afero.NewOsFs()) }, } @@ -138,6 +141,7 @@ func init() { deployFlags.BoolVar(&useLegacyBundle, "legacy-bundle", false, "Use legacy bundling mechanism.") functionsDeployCmd.MarkFlagsMutuallyExclusive("use-api", "use-docker", "legacy-bundle") cobra.CheckErr(deployFlags.MarkHidden("legacy-bundle")) + cobra.CheckErr(deployFlags.MarkHidden("use-docker")) deployFlags.UintVarP(&maxJobs, "jobs", "j", 1, "Maximum number of parallel jobs.") deployFlags.BoolVar(noVerifyJWT, "no-verify-jwt", false, "Disable JWT verification for the Function.") deployFlags.BoolVar(&prune, "prune", false, "Delete Functions that exist in Supabase project but not locally.") @@ -152,8 +156,14 @@ func init() { functionsServeCmd.MarkFlagsMutuallyExclusive("inspect", "inspect-mode") functionsServeCmd.Flags().Bool("all", true, "Serve all Functions.") cobra.CheckErr(functionsServeCmd.Flags().MarkHidden("all")) - functionsDownloadCmd.Flags().StringVar(&flags.ProjectRef, "project-ref", "", "Project ref of the Supabase project.") - functionsDownloadCmd.Flags().BoolVar(&useLegacyBundle, "legacy-bundle", false, "Use legacy bundling mechanism.") + downloadFlags := functionsDownloadCmd.Flags() + downloadFlags.StringVar(&flags.ProjectRef, "project-ref", "", "Project ref of the Supabase project.") + downloadFlags.BoolVar(&useLegacyBundle, "legacy-bundle", false, "Use legacy bundling mechanism.") + downloadFlags.BoolVar(&useApi, "use-api", false, "Use Management API to unbundle functions server-side.") + downloadFlags.BoolVar(&useDocker, "use-docker", true, "Use Docker to unbundle functions client-side.") + functionsDownloadCmd.MarkFlagsMutuallyExclusive("use-api", "use-docker", "legacy-bundle") + cobra.CheckErr(downloadFlags.MarkHidden("legacy-bundle")) + cobra.CheckErr(downloadFlags.MarkHidden("use-docker")) functionsCmd.AddCommand(functionsListCmd) functionsCmd.AddCommand(functionsDeleteCmd) functionsCmd.AddCommand(functionsDeployCmd) diff --git a/internal/functions/download/download.go b/internal/functions/download/download.go index e75ec7d9f..0c19c605d 100644 --- a/internal/functions/download/download.go +++ b/internal/functions/download/download.go @@ -6,7 +6,10 @@ import ( "context" "fmt" "io" + "mime" + "mime/multipart" "net/http" + "net/url" "os" "os/exec" "path" @@ -112,15 +115,30 @@ func downloadFunction(ctx context.Context, projectRef, slug, extractScriptPath s return nil } -func Run(ctx context.Context, slug string, projectRef string, useLegacyBundle bool, fsys afero.Fs) error { +func Run(ctx context.Context, slug, projectRef string, useLegacyBundle, useDocker bool, fsys afero.Fs) error { + // Sanity check + if err := flags.LoadConfig(fsys); err != nil { + return err + } + if useLegacyBundle { return RunLegacy(ctx, slug, projectRef, fsys) } - // 1. Sanity check - if err := flags.LoadConfig(fsys); err != nil { - return err + + if useDocker { + if utils.IsDockerRunning(ctx) { + // download eszip file for client-side unbundling with edge-runtime + return downloadWithDockerUnbundle(ctx, slug, projectRef, fsys) + } else { + fmt.Fprintln(os.Stderr, utils.Yellow("WARNING:"), "Docker is not running") + } } - // 2. Download eszip to temp file + + // Use server-side unbundling with multipart/form-data + return downloadWithServerSideUnbundle(ctx, slug, projectRef, fsys) +} + +func downloadWithDockerUnbundle(ctx context.Context, slug string, projectRef string, fsys afero.Fs) error { eszipPath, err := downloadOne(ctx, slug, projectRef, fsys) if err != nil { return err @@ -238,3 +256,253 @@ deno_version = 2 func suggestLegacyBundle(slug string) string { return fmt.Sprintf("\nIf your function is deployed using CLI < 1.120.0, trying running %s instead.", utils.Aqua("supabase functions download --legacy-bundle "+slug)) } + +// New server-side unbundle implementation that mirrors Studio's entrypoint-based +// base-dir + relative path behaviour. +func downloadWithServerSideUnbundle(ctx context.Context, slug, projectRef string, fsys afero.Fs) error { + fmt.Fprintln(os.Stderr, "Downloading "+utils.Bold(slug)) + + metadata, err := getFunctionMetadata(ctx, projectRef, slug) + if err != nil { + return errors.Errorf("failed to get function metadata: %w", err) + } + + entrypointUrl, err := url.Parse(*metadata.EntrypointPath) + if err != nil { + return errors.Errorf("failed to parse entrypoint URL: %w", err) + } + + // Request multipart/form-data response using RequestEditorFn + resp, err := utils.GetSupabase().V1GetAFunctionBody(ctx, projectRef, slug, func(ctx context.Context, req *http.Request) error { + req.Header.Set("Accept", "multipart/form-data") + return nil + }) + if err != nil { + return errors.Errorf("failed to download function: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + return errors.Errorf("Error status %d: %w", resp.StatusCode, err) + } + return errors.Errorf("Error status %d: %s", resp.StatusCode, string(body)) + } + + // Parse the multipart response + mediaType, params, err := mime.ParseMediaType(resp.Header.Get("Content-Type")) + if err != nil { + return errors.Errorf("failed to parse content type: %w", err) + } + + if !strings.HasPrefix(mediaType, "multipart/") { + return errors.Errorf("expected multipart response, got %s", mediaType) + } + + // Root directory on disk: supabase/functions/ + funcDir := filepath.Join(utils.FunctionsDir, slug) + if err := utils.MkdirIfNotExistFS(fsys, funcDir); err != nil { + return err + } + + type partEntry struct { + path string + data []byte + } + + var parts []partEntry + + // Parse multipart form and buffer parts in memory. + mr := multipart.NewReader(resp.Body, params["boundary"]) + for { + part, err := mr.NextPart() + if errors.Is(err, io.EOF) { + break + } + if err != nil { + return errors.Errorf("failed to read multipart: %w", err) + } + + partPath, err := getPartPath(part) + if err != nil { + return err + } + + data, err := io.ReadAll(part) + if err != nil { + return errors.Errorf("failed to read part data: %w", err) + } + + if partPath == "" { + fmt.Fprintln(utils.GetDebugLogger(), "Skipping part without filename") + } else { + parts = append(parts, partEntry{path: partPath, data: data}) + } + } + + // Collect file paths (excluding empty ones) to infer the base directory. + var filepaths []string + for _, p := range parts { + if p.path != "" { + filepaths = append(filepaths, p.path) + } + } + + baseDir := getBaseDirFromEntrypoint(entrypointUrl, filepaths) + fmt.Println("Function base directory: " + utils.Aqua(baseDir)) + + // Place each file under funcDir using a path relative to baseDir, + // mirroring Studio's getBasePath + relative() behavior. + for _, p := range parts { + if p.path == "" { + continue + } + + relPath := getRelativePathFromBase(baseDir, p.path) + filePath, err := joinWithinDir(funcDir, relPath) + if err != nil { + return err + } + + if err := utils.MkdirIfNotExistFS(fsys, filepath.Dir(filePath)); err != nil { + return err + } + + if err := afero.WriteReader(fsys, filePath, bytes.NewReader(p.data)); err != nil { + return errors.Errorf("failed to write file: %w", err) + } + } + + fmt.Println("Downloaded Function " + utils.Aqua(slug) + " from project " + utils.Aqua(projectRef) + ".") + return nil +} + +// getPartPath extracts the filename for a multipart part, allowing for +// relative paths via the custom Supabase-Path header. +func getPartPath(part *multipart.Part) (string, error) { + // dedicated header to specify relative path, not expected to be used + if relPath := part.Header.Get("Supabase-Path"); relPath != "" { + return relPath, nil + } + + // part.FileName() does not allow us to handle relative paths, so we parse Content-Disposition manually + cd := part.Header.Get("Content-Disposition") + if cd == "" { + return "", nil + } + + _, params, err := mime.ParseMediaType(cd) + if err != nil { + return "", errors.Errorf("failed to parse content disposition: %w", err) + } + + if filename := params["filename"]; filename != "" { + return filename, nil + } + return "", nil +} + +// joinWithinDir safely joins base and rel ensuring the result stays within base directory +func joinWithinDir(base, rel string) (string, error) { + cleanRel := filepath.Clean(rel) + // Be forgiving: treat a rooted path as relative to base (e.g. "/foo" -> "foo") + if filepath.IsAbs(cleanRel) { + cleanRel = strings.TrimLeft(cleanRel, "/\\") + } + if cleanRel == ".." || strings.HasPrefix(cleanRel, "../") { + return "", errors.Errorf("invalid file path outside function directory: %s", rel) + } + joined := filepath.Join(base, cleanRel) + cleanJoined := filepath.Clean(joined) + cleanBase := filepath.Clean(base) + if cleanJoined != cleanBase && !strings.HasPrefix(cleanJoined, cleanBase+"/") { + return "", errors.Errorf("refusing to write outside function directory: %s", rel) + } + return joined, nil +} + +// getBaseDirFromEntrypoint tries to infer the "base" directory for function +// files from the entrypoint URL and the list of filenames, similar to Studio's +// getBasePath logic. +func getBaseDirFromEntrypoint(entrypointUrl *url.URL, filenames []string) string { + if entrypointUrl.Path == "" { + return "/" + } + + entryPath := filepath.ToSlash(entrypointUrl.Path) + + // First, prefer relative filenames (no leading slash) when matching the entrypoint. + var baseDir string + for _, filename := range filenames { + if filename == "" { + continue + } + clean := filepath.ToSlash(filename) + if strings.HasPrefix(clean, "/") { + // Skip absolute paths like /tmp/... + continue + } + if strings.HasSuffix(entryPath, clean) { + baseDir = filepath.Dir(clean) + break + } + } + + // If nothing matched among relative paths, fall back to any filename. + if baseDir == "" { + for _, filename := range filenames { + if filename == "" { + continue + } + clean := filepath.ToSlash(filename) + if strings.HasSuffix(entryPath, clean) { + baseDir = filepath.Dir(clean) + break + } + } + } + + if baseDir != "" { + return baseDir + } + + // Final fallback: derive from the entrypoint URL path itself. + baseDir = filepath.Dir(entrypointUrl.Path) + if baseDir != "" && baseDir != "." { + return baseDir + } + return "/" +} + +// getRelativePathFromBase mirrors the Studio behaviour of making file paths +// relative to the "base" directory inferred from the entrypoint. +func getRelativePathFromBase(baseDir, filename string) string { + if filename == "" { + return "" + } + + cleanBase := filepath.ToSlash(filepath.Clean(baseDir)) + cleanFile := filepath.ToSlash(filepath.Clean(filename)) + + // If we don't have a meaningful base, just normalize to a relative path. + if cleanBase == "" || cleanBase == "/" || cleanBase == "." { + return strings.TrimLeft(cleanFile, "/") + } + + // Try a straightforward relative path first (e.g. source/index.ts -> index.ts). + if rel, err := filepath.Rel(cleanBase, cleanFile); err == nil && rel != "." && !strings.HasPrefix(rel, "..") { + return filepath.ToSlash(rel) + } + + // If the file path contains "//" somewhere (e.g. /tmp/.../source/index.ts), + // strip everything up to and including that segment so we get a stable relative path + // like "index.ts" or "dir/file.ts". + segment := "/" + cleanBase + "/" + if idx := strings.Index(cleanFile, segment); idx >= 0 { + return cleanFile[idx+len(segment):] + } + + // Last resort: return a normalized, slash-stripped path. + return strings.TrimLeft(cleanFile, "/") +} diff --git a/internal/functions/download/download_test.go b/internal/functions/download/download_test.go index b727f95c3..10b2dce74 100644 --- a/internal/functions/download/download_test.go +++ b/internal/functions/download/download_test.go @@ -1,12 +1,18 @@ package download import ( + "bytes" "context" "errors" "fmt" "log" + "mime/multipart" "net/http" + "net/textproto" + "net/url" "os" + "path/filepath" + "strings" "testing" "github.com/h2non/gock" @@ -15,6 +21,7 @@ import ( "github.com/stretchr/testify/require" "github.com/supabase/cli/internal/testing/apitest" "github.com/supabase/cli/internal/utils" + "github.com/supabase/cli/internal/utils/flags" "github.com/supabase/cli/pkg/api" ) @@ -37,201 +44,613 @@ func TestMain(m *testing.M) { os.Exit(m.Run()) } -func TestDownloadCommand(t *testing.T) { +func writeConfig(t *testing.T, fsys afero.Fs) { + t.Helper() + require.NoError(t, utils.WriteConfig(fsys, false)) +} + +func newFunctionMetadata(slug string) api.FunctionSlugResponse { + entrypoint := "file:///src/index.ts" + status := api.FunctionSlugResponseStatus("ACTIVE") + return api.FunctionSlugResponse{ + Id: "1", + Name: slug, + Slug: slug, + Status: status, + Version: 1, + CreatedAt: 0, + UpdatedAt: 0, + EntrypointPath: &entrypoint, + } +} + +func mockFunctionMetadata(projectRef, slug string, meta api.FunctionSlugResponse) { + gock.New(utils.DefaultApiHost). + Get(fmt.Sprintf("/v1/projects/%s/functions/%s", projectRef, slug)). + Reply(http.StatusOK). + JSON(meta) +} + +type multipartPart struct { + filename string + supabasePath string + contents string +} + +func mockMultipartBody(t *testing.T, projectRef, slug string, parts []multipartPart) { + t.Helper() + var buf bytes.Buffer + writer := multipart.NewWriter(&buf) + for _, part := range parts { + headers := textproto.MIMEHeader{} + headers.Set("Content-Disposition", fmt.Sprintf(`form-data; name="file"; filename="%s"`, part.filename)) + if part.supabasePath != "" { + headers.Set("Supabase-Path", part.supabasePath) + } + pw, err := writer.CreatePart(headers) + require.NoError(t, err) + _, err = pw.Write([]byte(part.contents)) + require.NoError(t, err) + } + require.NoError(t, writer.Close()) + + gock.New(utils.DefaultApiHost). + Get(fmt.Sprintf("/v1/projects/%s/functions/%s/body", projectRef, slug)). + Reply(http.StatusOK). + SetHeader("Content-Type", writer.FormDataContentType()). + Body(&buf) +} + +func mustParseURL(t *testing.T, raw string) *url.URL { + t.Helper() + u, err := url.Parse(raw) + require.NoError(t, err) + return u +} + +func TestRun(t *testing.T) { const slug = "test-func" - t.Run("downloads eszip bundle", func(t *testing.T) { - // Setup in-memory fs + t.Run("downloads legacy bundle", func(t *testing.T) { fsys := afero.NewMemMapFs() - // Setup valid project ref + writeConfig(t, fsys) + t.Cleanup(func() { + gock.OffAll() + utils.CmdSuggestion = "" + }) project := apitest.RandomProjectRef() - // Setup valid access token + flags.ProjectRef = project + t.Cleanup(func() { flags.ProjectRef = "" }) + token := apitest.RandomAccessToken(t) t.Setenv("SUPABASE_ACCESS_TOKEN", string(token)) - // Setup valid deno path + _, err := fsys.Create(utils.DenoPathOverride) require.NoError(t, err) - // Setup mock api - defer gock.OffAll() - gock.New(utils.DefaultApiHost). - Get("/v1/projects/" + project + "/functions/" + slug). - Reply(http.StatusOK). - JSON(api.FunctionResponse{Id: "1"}) + + meta := newFunctionMetadata(slug) + mockFunctionMetadata(project, slug, meta) gock.New(utils.DefaultApiHost). - Get("/v1/projects/" + project + "/functions/" + slug + "/body"). + Get(fmt.Sprintf("/v1/projects/%s/functions/%s/body", project, slug)). Reply(http.StatusOK) - // Run test - err = Run(context.Background(), slug, project, true, fsys) - // Check error + + err = Run(context.Background(), slug, project, true, false, fsys) assert.NoError(t, err) assert.Empty(t, apitest.ListUnmatchedRequests()) }) t.Run("throws error on malformed slug", func(t *testing.T) { - // Setup in-memory fs fsys := afero.NewMemMapFs() - // Setup valid project ref + writeConfig(t, fsys) project := apitest.RandomProjectRef() - // Run test - err := Run(context.Background(), "@", project, true, fsys) - // Check error + flags.ProjectRef = project + t.Cleanup(func() { flags.ProjectRef = "" }) + + err := Run(context.Background(), "@", project, true, false, fsys) assert.ErrorContains(t, err, "Invalid Function name.") }) t.Run("throws error on failure to install deno", func(t *testing.T) { - // Setup in-memory fs - fsys := afero.NewReadOnlyFs(afero.NewMemMapFs()) - // Setup valid project ref + base := afero.NewMemMapFs() + writeConfig(t, base) project := apitest.RandomProjectRef() - // Run test - err := Run(context.Background(), slug, project, true, fsys) - // Check error + flags.ProjectRef = project + t.Cleanup(func() { flags.ProjectRef = "" }) + + err := Run(context.Background(), slug, project, true, false, afero.NewReadOnlyFs(base)) assert.ErrorContains(t, err, "operation not permitted") }) t.Run("throws error on copy failure", func(t *testing.T) { - // Setup in-memory fs - fsys := afero.NewMemMapFs() - // Setup valid project ref - project := apitest.RandomProjectRef() - // Setup valid deno path - _, err := fsys.Create(utils.DenoPathOverride) + base := afero.NewMemMapFs() + writeConfig(t, base) + _, err := base.Create(utils.DenoPathOverride) require.NoError(t, err) - // Run test - err = Run(context.Background(), slug, project, true, afero.NewReadOnlyFs(fsys)) - // Check error + project := apitest.RandomProjectRef() + flags.ProjectRef = project + t.Cleanup(func() { flags.ProjectRef = "" }) + + err = Run(context.Background(), slug, project, true, false, afero.NewReadOnlyFs(base)) assert.ErrorContains(t, err, "operation not permitted") }) t.Run("throws error on missing function", func(t *testing.T) { - // Setup in-memory fs fsys := afero.NewMemMapFs() - // Setup valid project ref + writeConfig(t, fsys) + t.Cleanup(func() { + gock.OffAll() + utils.CmdSuggestion = "" + }) project := apitest.RandomProjectRef() - // Setup valid access token + flags.ProjectRef = project + t.Cleanup(func() { flags.ProjectRef = "" }) + token := apitest.RandomAccessToken(t) t.Setenv("SUPABASE_ACCESS_TOKEN", string(token)) - // Setup valid deno path + _, err := fsys.Create(utils.DenoPathOverride) require.NoError(t, err) - // Setup mock api - defer gock.OffAll() + gock.New(utils.DefaultApiHost). - Get("/v1/projects/" + project + "/functions/" + slug). + Get(fmt.Sprintf("/v1/projects/%s/functions/%s", project, slug)). Reply(http.StatusNotFound). JSON(map[string]string{"message": "Function not found"}) - // Run test - err = Run(context.Background(), slug, project, true, fsys) - // Check error + + err = Run(context.Background(), slug, project, true, false, fsys) assert.ErrorContains(t, err, "Function test-func does not exist on the Supabase project.") + assert.Empty(t, apitest.ListUnmatchedRequests()) + }) + + t.Run("downloads bundle with docker when available", func(t *testing.T) { + const slugDocker = "demo" + fsys := afero.NewMemMapFs() + writeConfig(t, fsys) + project := apitest.RandomProjectRef() + flags.ProjectRef = project + t.Cleanup(func() { flags.ProjectRef = "" }) + require.NoError(t, flags.LoadConfig(fsys)) + + token := apitest.RandomAccessToken(t) + t.Setenv("SUPABASE_ACCESS_TOKEN", string(token)) + + require.NoError(t, apitest.MockDocker(utils.Docker)) + dockerHost := utils.Docker.DaemonHost() + + defer func() { + gock.OffAll() + utils.CmdSuggestion = "" + }() + + gock.New(dockerHost). + Head("/_ping"). + Reply(http.StatusOK) + + imageURL := utils.GetRegistryImageUrl(utils.Config.EdgeRuntime.Image) + containerID := "docker-unbundle-test" + apitest.MockDockerStart(utils.Docker, imageURL, containerID) + require.NoError(t, apitest.MockDockerLogs(utils.Docker, containerID, "unbundle ok")) + + gock.New(utils.DefaultApiHost). + Get(fmt.Sprintf("/v1/projects/%s/functions/%s/body", project, slugDocker)). + Reply(http.StatusOK). + BodyString("fake eszip payload") + + err := Run(context.Background(), slugDocker, project, false, true, fsys) + require.NoError(t, err) + + eszipPath := filepath.Join(utils.TempDir, fmt.Sprintf("output_%s.eszip", slugDocker)) + exists, err := afero.Exists(fsys, eszipPath) + require.NoError(t, err) + assert.False(t, exists, "temporary eszip file should be removed after extraction") + + assert.Empty(t, apitest.ListUnmatchedRequests()) + }) + + t.Run("falls back to server-side unbundle when docker unavailable", func(t *testing.T) { + const slugDocker = "demo-fallback" + fsys := afero.NewMemMapFs() + writeConfig(t, fsys) + project := apitest.RandomProjectRef() + flags.ProjectRef = project + t.Cleanup(func() { flags.ProjectRef = "" }) + require.NoError(t, flags.LoadConfig(fsys)) + + token := apitest.RandomAccessToken(t) + t.Setenv("SUPABASE_ACCESS_TOKEN", string(token)) + + require.NoError(t, apitest.MockDocker(utils.Docker)) + dockerHost := utils.Docker.DaemonHost() + + defer func() { + gock.OffAll() + utils.CmdSuggestion = "" + }() + + gock.New(dockerHost). + Head("/_ping"). + ReplyError(errors.New("docker unavailable")) + + meta := newFunctionMetadata(slugDocker) + entrypoint := "file:///source/index.ts" + meta.EntrypointPath = &entrypoint + mockFunctionMetadata(project, slugDocker, meta) + mockMultipartBody(t, project, slugDocker, []multipartPart{ + {filename: "source/index.ts", contents: "console.log('hello')"}, + }) + + err := Run(context.Background(), slugDocker, project, false, true, fsys) + require.NoError(t, err) + + data, err := afero.ReadFile(fsys, filepath.Join(utils.FunctionsDir, slugDocker, "index.ts")) + require.NoError(t, err) + assert.Equal(t, "console.log('hello')", string(data)) + + assert.Empty(t, apitest.ListUnmatchedRequests()) }) } func TestDownloadFunction(t *testing.T) { const slug = "test-func" - // Setup valid project ref project := apitest.RandomProjectRef() - // Setup valid access token token := apitest.RandomAccessToken(t) t.Setenv("SUPABASE_ACCESS_TOKEN", string(token)) t.Run("throws error on network error", func(t *testing.T) { - // Setup mock api - defer gock.OffAll() + t.Cleanup(gock.OffAll) + mockFunctionMetadata(project, slug, newFunctionMetadata(slug)) gock.New(utils.DefaultApiHost). - Get("/v1/projects/" + project + "/functions/" + slug). - Reply(http.StatusOK). - JSON(api.FunctionResponse{Id: "1"}) - gock.New(utils.DefaultApiHost). - Get("/v1/projects/" + project + "/functions/" + slug + "/body"). + Get(fmt.Sprintf("/v1/projects/%s/functions/%s/body", project, slug)). ReplyError(errors.New("network error")) - // Run test + err := downloadFunction(context.Background(), project, slug, "") - // Check error + assert.ErrorContains(t, err, "failed to get function body") assert.ErrorContains(t, err, "network error") }) t.Run("throws error on service unavailable", func(t *testing.T) { - // Setup mock api - defer gock.OffAll() + t.Cleanup(gock.OffAll) + mockFunctionMetadata(project, slug, newFunctionMetadata(slug)) gock.New(utils.DefaultApiHost). - Get("/v1/projects/" + project + "/functions/" + slug). - Reply(http.StatusOK). - JSON(api.FunctionResponse{Id: "1"}) - gock.New(utils.DefaultApiHost). - Get("/v1/projects/" + project + "/functions/" + slug + "/body"). + Get(fmt.Sprintf("/v1/projects/%s/functions/%s/body", project, slug)). Reply(http.StatusServiceUnavailable) - // Run test + err := downloadFunction(context.Background(), project, slug, "") - // Check error - assert.ErrorContains(t, err, "Unexpected error downloading Function:") + assert.ErrorContains(t, err, "Unexpected error downloading Function") }) t.Run("throws error on extract failure", func(t *testing.T) { - // Setup deno error + t.Cleanup(gock.OffAll) t.Setenv("TEST_DENO_ERROR", "extract failed") - // Setup mock api - defer gock.OffAll() - gock.New(utils.DefaultApiHost). - Get("/v1/projects/" + project + "/functions/" + slug). - Reply(http.StatusOK). - JSON(api.FunctionResponse{Id: "1"}) + mockFunctionMetadata(project, slug, newFunctionMetadata(slug)) gock.New(utils.DefaultApiHost). - Get("/v1/projects/" + project + "/functions/" + slug + "/body"). + Get(fmt.Sprintf("/v1/projects/%s/functions/%s/body", project, slug)). Reply(http.StatusOK) - // Run test + err := downloadFunction(context.Background(), project, slug, "") - // Check error assert.ErrorContains(t, err, "Error downloading function: exit status 1\nextract failed\n") assert.Empty(t, apitest.ListUnmatchedRequests()) }) } -func TestGetMetadata(t *testing.T) { +func TestGetFunctionMetadata(t *testing.T) { const slug = "test-func" project := apitest.RandomProjectRef() - // Setup valid access token - token := apitest.RandomAccessToken(t) - t.Setenv("SUPABASE_ACCESS_TOKEN", string(token)) t.Run("fallback to default paths", func(t *testing.T) { - // Setup mock api - defer gock.OffAll() - gock.New(utils.DefaultApiHost). - Get("/v1/projects/" + project + "/functions/" + slug). - Reply(http.StatusOK). - JSON(api.FunctionResponse{Id: "1"}) - // Run test - meta, err := getFunctionMetadata(context.Background(), project, slug) - // Check error + t.Cleanup(gock.OffAll) + meta := newFunctionMetadata(slug) + meta.EntrypointPath = nil + meta.ImportMapPath = nil + mockFunctionMetadata(project, slug, meta) + + got, err := getFunctionMetadata(context.Background(), project, slug) assert.NoError(t, err) - assert.Equal(t, legacyEntrypointPath, *meta.EntrypointPath) - assert.Equal(t, legacyImportMapPath, *meta.ImportMapPath) + require.NotNil(t, got) + assert.Equal(t, legacyEntrypointPath, *got.EntrypointPath) + assert.Equal(t, legacyImportMapPath, *got.ImportMapPath) }) t.Run("throws error on network error", func(t *testing.T) { - // Setup mock api - defer gock.OffAll() + t.Cleanup(gock.OffAll) gock.New(utils.DefaultApiHost). - Get("/v1/projects/" + project + "/functions/" + slug). + Get(fmt.Sprintf("/v1/projects/%s/functions/%s", project, slug)). ReplyError(errors.New("network error")) - // Run test + meta, err := getFunctionMetadata(context.Background(), project, slug) - // Check error - assert.ErrorContains(t, err, "network error") + assert.ErrorContains(t, err, "failed to get function metadata") assert.Nil(t, meta) }) t.Run("throws error on service unavailable", func(t *testing.T) { - // Setup mock api - defer gock.OffAll() + t.Cleanup(gock.OffAll) gock.New(utils.DefaultApiHost). - Get("/v1/projects/" + project + "/functions/" + slug). + Get(fmt.Sprintf("/v1/projects/%s/functions/%s", project, slug)). Reply(http.StatusServiceUnavailable) - // Run test + meta, err := getFunctionMetadata(context.Background(), project, slug) - // Check error assert.ErrorContains(t, err, "Failed to download Function test-func on the Supabase project:") assert.Nil(t, meta) }) } + +func TestDownloadWithServerSideUnbundle(t *testing.T) { + const slug = "test-func" + token := apitest.RandomAccessToken(t) + t.Setenv("SUPABASE_ACCESS_TOKEN", string(token)) + + t.Run("writes files using inferred base directory", func(t *testing.T) { + fsys := afero.NewMemMapFs() + project := apitest.RandomProjectRef() + t.Cleanup(func() { + gock.OffAll() + utils.CmdSuggestion = "" + }) + + meta := newFunctionMetadata(slug) + entrypoint := "file:///source/index.ts" + meta.EntrypointPath = &entrypoint + mockFunctionMetadata(project, slug, meta) + mockMultipartBody(t, project, slug, []multipartPart{ + {filename: "source/index.ts", contents: "console.log('hello')"}, + {filename: "source/utils.ts", contents: "export const value = 1;"}, + }) + + err := downloadWithServerSideUnbundle(context.Background(), slug, project, fsys) + require.NoError(t, err) + + data, err := afero.ReadFile(fsys, filepath.Join(utils.FunctionsDir, slug, "index.ts")) + require.NoError(t, err) + assert.Equal(t, "console.log('hello')", string(data)) + + data, err = afero.ReadFile(fsys, filepath.Join(utils.FunctionsDir, slug, "utils.ts")) + require.NoError(t, err) + assert.Equal(t, "export const value = 1;", string(data)) + + assert.Empty(t, apitest.ListUnmatchedRequests()) + }) + + t.Run("fails when response not multipart", func(t *testing.T) { + fsys := afero.NewMemMapFs() + project := apitest.RandomProjectRef() + t.Cleanup(func() { + gock.OffAll() + utils.CmdSuggestion = "" + }) + mockFunctionMetadata(project, slug, newFunctionMetadata(slug)) + gock.New(utils.DefaultApiHost). + Get(fmt.Sprintf("/v1/projects/%s/functions/%s/body", project, slug)). + Reply(http.StatusOK). + SetHeader("Content-Type", "application/json"). + BodyString(`{"error":"no multipart"}`) + + err := downloadWithServerSideUnbundle(context.Background(), slug, project, fsys) + assert.ErrorContains(t, err, "expected multipart response") + }) + + t.Run("fails when part escapes base dir", func(t *testing.T) { + fsys := afero.NewMemMapFs() + project := apitest.RandomProjectRef() + t.Cleanup(func() { + gock.OffAll() + utils.CmdSuggestion = "" + }) + + meta := newFunctionMetadata(slug) + entrypoint := "file:///source/index.ts" + meta.EntrypointPath = &entrypoint + mockFunctionMetadata(project, slug, meta) + mockMultipartBody(t, project, slug, []multipartPart{ + {filename: "source/index.ts", contents: "console.log('hello')"}, + {filename: "source/secret.env", supabasePath: "../secret.env", contents: "SECRET=1"}, + }) + + err := downloadWithServerSideUnbundle(context.Background(), slug, project, fsys) + assert.ErrorContains(t, err, "invalid file path outside function directory") + }) +} + +func TestGetPartPath(t *testing.T) { + t.Parallel() + + newPart := func(headers map[string]string) *multipart.Part { + mh := make(textproto.MIMEHeader, len(headers)) + for k, v := range headers { + mh.Set(k, v) + } + return &multipart.Part{Header: mh} + } + + t.Run("returns path from Supabase header", func(t *testing.T) { + part := newPart(map[string]string{ + "Supabase-Path": "dir/file.ts", + }) + got, err := getPartPath(part) + require.NoError(t, err) + assert.Equal(t, "dir/file.ts", got) + }) + + t.Run("returns filename from content disposition", func(t *testing.T) { + part := newPart(map[string]string{ + "Content-Disposition": `form-data; name="file"; filename="test-func/index.ts"`, + }) + got, err := getPartPath(part) + require.NoError(t, err) + assert.Equal(t, "test-func/index.ts", got) + }) + + t.Run("returns filename from editor-originated content disposition", func(t *testing.T) { + part := newPart(map[string]string{ + "Content-Disposition": `form-data; name="file"; filename="source/index.ts"`, + }) + got, err := getPartPath(part) + require.NoError(t, err) + assert.Equal(t, "source/index.ts", got) + }) + + t.Run("writes file of arbitrary depth", func(t *testing.T) { + part := newPart(map[string]string{ + "Content-Disposition": `form-data; name="file"; filename="test-func/dir/subdir/file.ts"`, + }) + got, err := getPartPath(part) + require.NoError(t, err) + assert.Equal(t, "test-func/dir/subdir/file.ts", got) + }) + + t.Run("returns empty when no filename provided", func(t *testing.T) { + part := newPart(map[string]string{ + "Content-Disposition": `form-data; name="file"`, + }) + got, err := getPartPath(part) + require.NoError(t, err) + assert.Equal(t, "", got) + }) + + t.Run("returns error on invalid content disposition", func(t *testing.T) { + part := newPart(map[string]string{ + "Content-Disposition": `form-data; filename="unterminated`, + }) + got, err := getPartPath(part) + require.ErrorContains(t, err, "failed to parse content disposition") + assert.Equal(t, "", got) + }) +} + +func TestGetBaseDirFromEntrypoint(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + entrypoint string + filenames []string + want string + }{ + { + name: "prefers relative match", + entrypoint: "file:///source/index.ts", + filenames: []string{"source/index.ts", "source/utils.ts"}, + want: "source", + }, + { + name: "falls back to absolute match", + entrypoint: "file:///src/index.ts", + filenames: []string{"/tmp/project/src/index.ts"}, + want: "/tmp/project/src", + }, + { + name: "falls back to entrypoint directory", + entrypoint: "file:///dir/api/index.ts", + filenames: []string{"/tmp/project/api/index.ts"}, + want: "/dir/api", + }, + { + name: "empty entrypoint returns root", + entrypoint: "file:///", + filenames: nil, + want: "/", + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got := getBaseDirFromEntrypoint(mustParseURL(t, tt.entrypoint), tt.filenames) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestGetRelativePathFromBase(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + base string + filename string + want string + }{ + { + name: "strips relative base", + base: "source", + filename: "source/index.ts", + want: "index.ts", + }, + { + name: "trims leading slash when base empty", + base: "", + filename: "/tmp/source/index.ts", + want: "tmp/source/index.ts", + }, + { + name: "trims leading slash when base root", + base: "/", + filename: "/index.ts", + want: "index.ts", + }, + { + name: "handles absolute base prefix", + base: "/tmp/source", + filename: "/tmp/source/dir/file.ts", + want: "dir/file.ts", + }, + { + name: "strips embedded base segment", + base: "source", + filename: "/Users/foo/project/source/utils.ts", + want: "utils.ts", + }, + { + name: "preserves escaping path when outside base", + base: "source", + filename: "../secret.ts", + want: "../secret.ts", + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got := getRelativePathFromBase(tt.base, tt.filename) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestJoinWithinDir(t *testing.T) { + t.Parallel() + + base := filepath.Join(os.TempDir(), "base-dir") + + t.Run("joins path within base directory", func(t *testing.T) { + got, err := joinWithinDir(base, filepath.Join("sub", "file.ts")) + require.NoError(t, err) + cleanBase := filepath.Clean(base) + cleanGot := filepath.Clean(got) + assert.True(t, cleanGot == cleanBase || strings.HasPrefix(cleanGot, cleanBase+string(os.PathSeparator))) + }) + + t.Run("normalizes leading slash", func(t *testing.T) { + got, err := joinWithinDir(base, "/foo/bar.ts") + require.NoError(t, err) + assert.Equal(t, filepath.Join(filepath.Clean(base), "foo", "bar.ts"), filepath.Clean(got)) + }) + + t.Run("rejects parent directory traversal", func(t *testing.T) { + got, err := joinWithinDir(base, filepath.Join("..", "escape")) + require.Error(t, err) + assert.Equal(t, "", got) + }) + + t.Run("accepts internal traversal", func(t *testing.T) { + got, err := joinWithinDir(base, filepath.Join("dir", "..", "file.ts")) + require.NoError(t, err) + assert.Equal(t, filepath.Join(filepath.Clean(base), "file.ts"), filepath.Clean(got)) + }) + + t.Run("rejects traversal beginning with ../", func(t *testing.T) { + got, err := joinWithinDir(base, filepath.Join("..", "..", "file.ts")) + require.Error(t, err) + assert.Equal(t, "", got) + }) +}