Skip to content

Commit a578ac2

Browse files
authored
Sarah's updates to radek/pss-read-write (#37642)
* Update code to not expect a chunk when EOF encountered * Return early if any grpc errors are encountered during ReadStateBytes * Close the stream with CloseSend once everything's read without error. Add test case about handling grpc errors from CloseSend. * Fix test case about warnings: We would expect to receive a chunk with data alongside the warning and have a normal closing of the stream after EOF * Add log line, remove unneeded type info
1 parent d4cb951 commit a578ac2

File tree

2 files changed

+106
-16
lines changed

2 files changed

+106
-16
lines changed

internal/plugin6/grpc_provider.go

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1515,11 +1515,8 @@ func (p *GRPCProvider) ReadStateBytes(r providers.ReadStateBytesRequest) (resp p
15151515
for {
15161516
chunk, err := client.Recv()
15171517
if err == io.EOF {
1518-
// End of stream, we're done
1519-
if chunk != nil {
1520-
// TODO: The EOF error could be just returned alongside the last chunk?
1521-
resp.Diagnostics = resp.Diagnostics.Append(convert.ProtoToDiagnostics(chunk.Diagnostics))
1522-
}
1518+
// No chunk is returned alongside an EOF.
1519+
// And as EOF is a sentinel error it isn't wrapped as a diagnostic.
15231520
break
15241521
}
15251522
if err != nil {
@@ -1545,13 +1542,29 @@ func (p *GRPCProvider) ReadStateBytes(r providers.ReadStateBytesRequest) (resp p
15451542
logger.Trace("GRPCProvider.v6: ReadStateBytes: read bytes of a chunk", n)
15461543
}
15471544

1548-
logger.Trace("GRPCProvider.v6: ReadStateBytes: received all chunks", buf.Len())
1545+
if resp.Diagnostics.HasErrors() {
1546+
logger.Trace("GRPCProvider.v6: ReadStateBytes: experienced an error when receiving state data from the provider", resp.Diagnostics.Err())
1547+
return resp
1548+
}
1549+
15491550
if buf.Len() != expectedTotalLength {
1551+
logger.Trace("GRPCProvider.v6: ReadStateBytes: received %d bytes but expected the total bytes to be %d.", buf.Len(), expectedTotalLength)
1552+
15501553
err = fmt.Errorf("expected state file of total %d bytes, received %d bytes",
15511554
expectedTotalLength, buf.Len())
15521555
resp.Diagnostics = resp.Diagnostics.Append(err)
15531556
return resp
15541557
}
1558+
1559+
// We're done, so close the stream
1560+
logger.Trace("GRPCProvider.v6: ReadStateBytes: received all chunks, total bytes: ", buf.Len())
1561+
err = client.CloseSend()
1562+
if err != nil {
1563+
resp.Diagnostics = resp.Diagnostics.Append(grpcErr(err))
1564+
return resp
1565+
}
1566+
1567+
// Add the state data in the response once we know there are no errors
15551568
resp.Bytes = buf.Bytes()
15561569

15571570
return resp

internal/plugin6/grpc_provider_test.go

Lines changed: 87 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3556,6 +3556,9 @@ func TestGRPCProvider_ReadStateBytes(t *testing.T) {
35563556
return ret.resp, ret.err
35573557
}).Times(3)
35583558

3559+
// There will be a call to CloseSend to close the stream
3560+
mockReadBytesClient.EXPECT().CloseSend().Return(nil).Times(1)
3561+
35593562
// Act
35603563
request := providers.ReadStateBytesRequest{
35613564
TypeName: expectedReq.TypeName,
@@ -3704,7 +3707,7 @@ func TestGRPCProvider_ReadStateBytes(t *testing.T) {
37043707
// Define what will be returned by each call to Recv
37053708
mockReadBytesClient.EXPECT().Recv().Return(&proto.ReadStateBytes_Response{
37063709
Diagnostics: []*proto.Diagnostic{
3707-
&proto.Diagnostic{
3710+
{
37083711
Severity: proto.Diagnostic_ERROR,
37093712
Summary: "Error from test",
37103713
Detail: "This error is forced by the test case",
@@ -3752,15 +3755,44 @@ func TestGRPCProvider_ReadStateBytes(t *testing.T) {
37523755
).Return(mockReadBytesClient, nil)
37533756

37543757
// Define what will be returned by each call to Recv
3755-
mockReadBytesClient.EXPECT().Recv().Return(&proto.ReadStateBytes_Response{
3756-
Diagnostics: []*proto.Diagnostic{
3757-
&proto.Diagnostic{
3758-
Severity: proto.Diagnostic_WARNING,
3759-
Summary: "Warning from test",
3760-
Detail: "This warning is forced by the test case",
3758+
chunk := "hello world"
3759+
totalLength := len(chunk)
3760+
mockResp := map[int]struct {
3761+
resp *proto.ReadStateBytes_Response
3762+
err error
3763+
}{
3764+
0: {
3765+
resp: &proto.ReadStateBytes_Response{
3766+
Bytes: []byte(chunk),
3767+
TotalLength: int64(totalLength),
3768+
Range: &proto.StateRange{
3769+
Start: 0,
3770+
End: int64(len(chunk)),
3771+
},
3772+
Diagnostics: []*proto.Diagnostic{
3773+
{
3774+
Severity: proto.Diagnostic_WARNING,
3775+
Summary: "Warning from test",
3776+
Detail: "This warning is forced by the test case",
3777+
},
3778+
},
37613779
},
3780+
err: nil,
3781+
},
3782+
1: {
3783+
resp: &proto.ReadStateBytes_Response{},
3784+
err: io.EOF,
37623785
},
3763-
}, io.EOF)
3786+
}
3787+
var count int
3788+
mockReadBytesClient.EXPECT().Recv().DoAndReturn(func() (*proto.ReadStateBytes_Response, error) {
3789+
ret := mockResp[count]
3790+
count++
3791+
return ret.resp, ret.err
3792+
}).Times(2)
3793+
3794+
// There will be a call to CloseSend to close the stream
3795+
mockReadBytesClient.EXPECT().CloseSend().Return(nil).Times(1)
37643796

37653797
// Act
37663798
request := providers.ReadStateBytesRequest{
@@ -3775,8 +3807,8 @@ func TestGRPCProvider_ReadStateBytes(t *testing.T) {
37753807
if resp.Diagnostics.ErrWithWarnings().Error() != expectedWarn {
37763808
t.Fatalf("expected warning diagnostic %q, but got: %q", expectedWarn, resp.Diagnostics.ErrWithWarnings().Error())
37773809
}
3778-
if len(resp.Bytes) != 0 {
3779-
t.Fatalf("expected data to be omitted in error condition, but got: %q", string(resp.Bytes))
3810+
if len(resp.Bytes) == 0 {
3811+
t.Fatal("expected data to included despite warnings, but got no bytes")
37803812
}
37813813
})
37823814

@@ -3820,6 +3852,51 @@ func TestGRPCProvider_ReadStateBytes(t *testing.T) {
38203852
t.Fatalf("expected data to be omitted in error condition, but got: %q", string(resp.Bytes))
38213853
}
38223854
})
3855+
3856+
t.Run("when closing the stream, grpc errors are surfaced via diagnostics", func(t *testing.T) {
3857+
client := mockProviderClient(t)
3858+
p := &GRPCProvider{
3859+
client: client,
3860+
ctx: context.Background(),
3861+
}
3862+
3863+
// Call to ReadStateBytes
3864+
// > Assert the arguments received
3865+
// > Define the returned mock client
3866+
mockClient := mockReadStateBytesClient(t)
3867+
expectedReq := &proto.ReadStateBytes_Request{
3868+
TypeName: "mock_store",
3869+
StateId: backend.DefaultStateName,
3870+
}
3871+
client.EXPECT().ReadStateBytes(
3872+
gomock.Any(),
3873+
gomock.Eq(expectedReq),
3874+
).Return(mockClient, nil)
3875+
3876+
// Sufficient mocking of Recv to get to the call to CloseSend
3877+
mockClient.EXPECT().Recv().Return(&proto.ReadStateBytes_Response{}, io.EOF)
3878+
3879+
// Force a gRPC error from CloseSend
3880+
mockError := errors.New("grpc error forced in test")
3881+
mockClient.EXPECT().CloseSend().Return(mockError).Times(1)
3882+
3883+
// Act
3884+
request := providers.ReadStateBytesRequest{
3885+
TypeName: expectedReq.TypeName,
3886+
StateId: expectedReq.StateId,
3887+
}
3888+
resp := p.ReadStateBytes(request)
3889+
3890+
// Assert returned values
3891+
checkDiagsHasError(t, resp.Diagnostics)
3892+
wantErr := fmt.Sprintf("Plugin error: The plugin returned an unexpected error from plugin6.(*GRPCProvider).ReadStateBytes: %s", mockError)
3893+
if resp.Diagnostics.Err().Error() != wantErr {
3894+
t.Fatalf("expected error diagnostic %q, but got: %q", wantErr, resp.Diagnostics.Err())
3895+
}
3896+
if len(resp.Bytes) != 0 {
3897+
t.Fatalf("expected data to be omitted in error condition, but got: %q", string(resp.Bytes))
3898+
}
3899+
})
38233900
}
38243901

38253902
func TestGRPCProvider_WriteStateBytes(t *testing.T) {

0 commit comments

Comments
 (0)