Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 17 additions & 6 deletions machine.go
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,10 @@ func NewMachine(ctx context.Context, cfg Config, opts ...Opt) (*Machine, error)
return m, nil
}

// Start will iterate through the handler list and call each handler. If an
// Start actually starts a Firecracker microVM.
// The context must not be cancelled while the microVM is running.
//
// It will iterate through the handler list and call each handler. If an
// error occurred during handler execution, that error will be returned. If the
// handlers succeed, then this will start the VMM instance.
// Start may only be called once per Machine. Subsequent calls will return
Expand Down Expand Up @@ -516,14 +519,22 @@ func (m *Machine) startVMM(ctx context.Context) error {

return err
}

// This goroutine is used to kill the process by context cancellation,
// but doesn't tell anyone about that.
go func() {
select {
case <-ctx.Done():
m.fatalErr = ctx.Err()
case err := <-errCh:
m.fatalErr = err
<-ctx.Done()
err := m.stopVMM()
if err != nil {
m.logger.WithError(err).Errorf("failed to stop vm %q", m.Cfg.VMID)
}
}()

// This goroutine is used to tell clients that the process is stopped
// (gracefully or not).
go func() {
m.fatalErr = <-errCh
m.logger.Debugf("closing the exitCh %v", m.fatalErr)
close(m.exitCh)
}()

Expand Down
104 changes: 81 additions & 23 deletions machine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1081,37 +1081,95 @@ func TestCaptureFifoToFile_leak(t *testing.T) {
assert.Contains(t, loggerBuffer.String(), `file already closed`, "log")
}

func TestWaitWithKill(t *testing.T) {
// Replace filesystem-unsafe characters (such as /) which are often seen in Go's test names
var fsSafeTestName = strings.NewReplacer("/", "_")

func TestWait(t *testing.T) {
fctesting.RequiresRoot(t)
ctx := context.Background()

socketPath := filepath.Join(testDataPath, t.Name())
defer os.Remove(socketPath)
cases := []struct {
name string
stop func(m *Machine, cancel context.CancelFunc)
}{
{
name: "StopVMM",
stop: func(m *Machine, _ context.CancelFunc) {
err := m.StopVMM()
require.NoError(t, err)
},
},
{
name: "Kill",
stop: func(m *Machine, cancel context.CancelFunc) {
pid, err := m.PID()
require.NoError(t, err)

process, err := os.FindProcess(pid)
err = process.Kill()
require.NoError(t, err)
},
},
{
name: "Context Cancel",
stop: func(m *Machine, cancel context.CancelFunc) {
cancel()
},
},
{
name: "StopVMM + Context Cancel",
stop: func(m *Machine, cancel context.CancelFunc) {
m.StopVMM()
time.Sleep(1 * time.Second)
cancel()
},
},
}

cfg := createValidConfig(t, socketPath)
cmd := VMCommandBuilder{}.
WithSocketPath(cfg.SocketPath).
WithBin(getFirecrackerBinaryPath()).
Build(ctx)
m, err := NewMachine(ctx, cfg, WithProcessRunner(cmd))
require.NoError(t, err)
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
ctx := context.Background()
vmContext, vmCancel := context.WithCancel(context.Background())

err = m.Start(ctx)
require.NoError(t, err)
socketPath := filepath.Join(testDataPath, fsSafeTestName.Replace(t.Name()))
defer os.Remove(socketPath)

go func() {
pid, err := m.PID()
require.NoError(t, err)
cfg := createValidConfig(t, socketPath)
m, err := NewMachine(ctx, cfg, func(m *Machine) {
// Rewriting m.cmd partially wouldn't work since Cmd has
// some unexported members
args := m.cmd.Args[1:]
m.cmd = exec.Command(getFirecrackerBinaryPath(), args...)
})
require.NoError(t, err)

process, err := os.FindProcess(pid)
require.NoError(t, err)
err = m.Start(vmContext)
require.NoError(t, err)

err = process.Kill()
require.NoError(t, err)
}()
pid, err := m.PID()
require.NoError(t, err)

err = m.Wait(ctx)
require.Error(t, err, "Firecracker was killed and it must be reported")
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
c.stop(m, vmCancel)
}()

err = m.Wait(ctx)
require.Error(t, err, "Firecracker was killed and it must be reported")
t.Logf("err = %v", err)

proc, err := os.FindProcess(pid)
// A nil error here doesn't mean the process is still there.
// In fact, on Unix systems os.FindProcess always succeeds, so err is never non-nil.
require.NoError(t, err)

err = proc.Signal(syscall.Signal(0))
require.Equal(t, "os: process already finished", err.Error())

wg.Wait()
})
}
}

func TestWaitWithInvalidBinary(t *testing.T) {
Expand Down