diff --git a/cmd/go-mutesting/main.go b/cmd/go-mutesting/main.go index ee786f5..1b84c2d 100644 --- a/cmd/go-mutesting/main.go +++ b/cmd/go-mutesting/main.go @@ -13,6 +13,7 @@ import ( "io/ioutil" "os" "os/exec" + "os/signal" "path/filepath" "regexp" "strings" @@ -158,6 +159,10 @@ func (ms *mutationStats) Total() int { return ms.passed + ms.failed + ms.skipped } +// Used by OS signal handler. +var tmpFiles = make(map[string]struct{}) +var tmpDir string + func mainCmd(args []string) int { var opts = &options{} var mutationBlackList = map[string]struct{}{} @@ -166,6 +171,28 @@ func mainCmd(args []string) int { return exitCode } + sigCh := make(chan os.Signal, 1) + signal.Notify(sigCh, os.Interrupt, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP, syscall.SIGQUIT) + go func() { + for { + sig := <-sigCh + verbose(opts, "Signal (%s) received, exiting...", sig) + + for file := range tmpFiles { + // "smth.tmp" -> "smth" + if err := os.Rename(file, file[:len(file)-4]); err != nil { + fmt.Printf("Failed to restore original version of file %s, it is stored in %s.\n", file[:len(file)-4], file) + fmt.Println(err) + } + } + if err := os.RemoveAll(tmpDir); err != nil { + fmt.Printf("Failed to remove temporary directory (%s): %v.\n", tmpDir, err) + } + + os.Exit(1) + } + }() + files := importing.FilesOfArgs(opts.Remaining.Targets) if len(files) == 0 { return exitError("Could not find any suitable Go source files") @@ -238,7 +265,8 @@ MUTATOR: }) } - tmpDir, err := ioutil.TempDir("", "go-mutesting-") + var err error + tmpDir, err = ioutil.TempDir("", "go-mutesting-") if err != nil { panic(err) } @@ -390,6 +418,9 @@ func mutateExec(opts *options, pkg *types.Package, file string, src ast.Node, mu } defer func() { + // Remove tmp file from clean list for signal handler. + delete(tmpFiles, file+".tmp") + _ = os.Rename(file+".tmp", file) }() @@ -402,6 +433,9 @@ func mutateExec(opts *options, pkg *types.Package, file string, src ast.Node, mu panic(err) } + // Add tmp file to clean list for signal handler. + tmpFiles[file+".tmp"] = struct{}{} + pkgName := pkg.Path() if opts.Test.Recursive { pkgName += "/..."