Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 19 additions & 6 deletions internal/plugin6/grpc_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -1515,11 +1515,8 @@ func (p *GRPCProvider) ReadStateBytes(r providers.ReadStateBytesRequest) (resp p
for {
chunk, err := client.Recv()
if err == io.EOF {
// End of stream, we're done
if chunk != nil {
// TODO: The EOF error could be just returned alongside the last chunk?
resp.Diagnostics = resp.Diagnostics.Append(convert.ProtoToDiagnostics(chunk.Diagnostics))
}
// No chunk is returned alongside an EOF.
// And as EOF is a sentinel error it isn't wrapped as a diagnostic.
break
}
Comment on lines 1517 to 1521
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please see this comment in our terraform-plugin-go PR explaining why we will never see a chunk alongside the EOF error here: https://github.com/hashicorp/terraform-plugin-go/pull/563/files#r2362735975

if err != nil {
Expand All @@ -1545,13 +1542,29 @@ func (p *GRPCProvider) ReadStateBytes(r providers.ReadStateBytesRequest) (resp p
logger.Trace("GRPCProvider.v6: ReadStateBytes: read bytes of a chunk", n)
}

logger.Trace("GRPCProvider.v6: ReadStateBytes: received all chunks", buf.Len())
if resp.Diagnostics.HasErrors() {
logger.Trace("GRPCProvider.v6: ReadStateBytes: experienced an error when receiving state data from the provider", resp.Diagnostics.Err())
return resp
}

if buf.Len() != expectedTotalLength {
logger.Trace("GRPCProvider.v6: ReadStateBytes: received %d bytes but expected the total bytes to be %d.", buf.Len(), expectedTotalLength)

err = fmt.Errorf("expected state file of total %d bytes, received %d bytes",
expectedTotalLength, buf.Len())
resp.Diagnostics = resp.Diagnostics.Append(err)
return resp
}

// We're done, so close the stream
logger.Trace("GRPCProvider.v6: ReadStateBytes: received all chunks, total bytes: ", buf.Len())
err = client.CloseSend()
if err != nil {
resp.Diagnostics = resp.Diagnostics.Append(grpcErr(err))
return resp
}

// Add the state data in the response once we know there are no errors
resp.Bytes = buf.Bytes()

return resp
Expand Down
97 changes: 87 additions & 10 deletions internal/plugin6/grpc_provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3556,6 +3556,9 @@ func TestGRPCProvider_ReadStateBytes(t *testing.T) {
return ret.resp, ret.err
}).Times(3)

// There will be a call to CloseSend to close the stream
mockReadBytesClient.EXPECT().CloseSend().Return(nil).Times(1)

// Act
request := providers.ReadStateBytesRequest{
TypeName: expectedReq.TypeName,
Expand Down Expand Up @@ -3704,7 +3707,7 @@ func TestGRPCProvider_ReadStateBytes(t *testing.T) {
// Define what will be returned by each call to Recv
mockReadBytesClient.EXPECT().Recv().Return(&proto.ReadStateBytes_Response{
Diagnostics: []*proto.Diagnostic{
&proto.Diagnostic{
{
Severity: proto.Diagnostic_ERROR,
Summary: "Error from test",
Detail: "This error is forced by the test case",
Expand Down Expand Up @@ -3752,15 +3755,44 @@ func TestGRPCProvider_ReadStateBytes(t *testing.T) {
).Return(mockReadBytesClient, nil)

// Define what will be returned by each call to Recv
mockReadBytesClient.EXPECT().Recv().Return(&proto.ReadStateBytes_Response{
Diagnostics: []*proto.Diagnostic{
&proto.Diagnostic{
Severity: proto.Diagnostic_WARNING,
Summary: "Warning from test",
Detail: "This warning is forced by the test case",
chunk := "hello world"
totalLength := len(chunk)
mockResp := map[int]struct {
resp *proto.ReadStateBytes_Response
err error
}{
0: {
resp: &proto.ReadStateBytes_Response{
Bytes: []byte(chunk),
TotalLength: int64(totalLength),
Range: &proto.StateRange{
Start: 0,
End: int64(len(chunk)),
},
Diagnostics: []*proto.Diagnostic{
{
Severity: proto.Diagnostic_WARNING,
Summary: "Warning from test",
Detail: "This warning is forced by the test case",
},
},
},
err: nil,
},
1: {
resp: &proto.ReadStateBytes_Response{},
err: io.EOF,
},
}, io.EOF)
}
var count int
mockReadBytesClient.EXPECT().Recv().DoAndReturn(func() (*proto.ReadStateBytes_Response, error) {
ret := mockResp[count]
count++
return ret.resp, ret.err
}).Times(2)

// There will be a call to CloseSend to close the stream
mockReadBytesClient.EXPECT().CloseSend().Return(nil).Times(1)

// Act
request := providers.ReadStateBytesRequest{
Expand All @@ -3775,8 +3807,8 @@ func TestGRPCProvider_ReadStateBytes(t *testing.T) {
if resp.Diagnostics.ErrWithWarnings().Error() != expectedWarn {
t.Fatalf("expected warning diagnostic %q, but got: %q", expectedWarn, resp.Diagnostics.ErrWithWarnings().Error())
}
if len(resp.Bytes) != 0 {
t.Fatalf("expected data to be omitted in error condition, but got: %q", string(resp.Bytes))
if len(resp.Bytes) == 0 {
t.Fatal("expected data to included despite warnings, but got no bytes")
}
})

Expand Down Expand Up @@ -3820,6 +3852,51 @@ func TestGRPCProvider_ReadStateBytes(t *testing.T) {
t.Fatalf("expected data to be omitted in error condition, but got: %q", string(resp.Bytes))
}
})

t.Run("when closing the stream, grpc errors are surfaced via diagnostics", func(t *testing.T) {
client := mockProviderClient(t)
p := &GRPCProvider{
client: client,
ctx: context.Background(),
}

// Call to ReadStateBytes
// > Assert the arguments received
// > Define the returned mock client
mockClient := mockReadStateBytesClient(t)
expectedReq := &proto.ReadStateBytes_Request{
TypeName: "mock_store",
StateId: backend.DefaultStateName,
}
client.EXPECT().ReadStateBytes(
gomock.Any(),
gomock.Eq(expectedReq),
).Return(mockClient, nil)

// Sufficient mocking of Recv to get to the call to CloseSend
mockClient.EXPECT().Recv().Return(&proto.ReadStateBytes_Response{}, io.EOF)

// Force a gRPC error from CloseSend
mockError := errors.New("grpc error forced in test")
mockClient.EXPECT().CloseSend().Return(mockError).Times(1)

// Act
request := providers.ReadStateBytesRequest{
TypeName: expectedReq.TypeName,
StateId: expectedReq.StateId,
}
resp := p.ReadStateBytes(request)

// Assert returned values
checkDiagsHasError(t, resp.Diagnostics)
wantErr := fmt.Sprintf("Plugin error: The plugin returned an unexpected error from plugin6.(*GRPCProvider).ReadStateBytes: %s", mockError)
if resp.Diagnostics.Err().Error() != wantErr {
t.Fatalf("expected error diagnostic %q, but got: %q", wantErr, resp.Diagnostics.Err())
}
if len(resp.Bytes) != 0 {
t.Fatalf("expected data to be omitted in error condition, but got: %q", string(resp.Bytes))
}
})
}

func TestGRPCProvider_WriteStateBytes(t *testing.T) {
Expand Down