Skip to content

Commit d7a4964

Browse files
SarahFrenchradeksimko
authored andcommitted
Sarah's updates to radek/pss-read-write (#37642)
* Update code to not expect a chunk when EOF encountered * Return early if any grpc errors are encountered during ReadStateBytes * Close the stream with CloseSend once everything's read without error. Add test case about handling grpc errors from CloseSend. * Fix test case about warnings: We would expect to receive a chunk with data alongside the warning and have a normal closing of the stream after EOF * Add log line, remove unneeded type info
1 parent 37d7c38 commit d7a4964

File tree

2 files changed

+106
-16
lines changed

2 files changed

+106
-16
lines changed

internal/plugin6/grpc_provider.go

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1555,11 +1555,8 @@ func (p *GRPCProvider) ReadStateBytes(r providers.ReadStateBytesRequest) (resp p
15551555
for {
15561556
chunk, err := client.Recv()
15571557
if err == io.EOF {
1558-
// End of stream, we're done
1559-
if chunk != nil {
1560-
// TODO: The EOF error could be just returned alongside the last chunk?
1561-
resp.Diagnostics = resp.Diagnostics.Append(convert.ProtoToDiagnostics(chunk.Diagnostics))
1562-
}
1558+
// No chunk is returned alongside an EOF.
1559+
// And as EOF is a sentinel error it isn't wrapped as a diagnostic.
15631560
break
15641561
}
15651562
if err != nil {
@@ -1585,13 +1582,29 @@ func (p *GRPCProvider) ReadStateBytes(r providers.ReadStateBytesRequest) (resp p
15851582
logger.Trace("GRPCProvider.v6: ReadStateBytes: read bytes of a chunk", n)
15861583
}
15871584

1588-
logger.Trace("GRPCProvider.v6: ReadStateBytes: received all chunks", buf.Len())
1585+
if resp.Diagnostics.HasErrors() {
1586+
logger.Trace("GRPCProvider.v6: ReadStateBytes: experienced an error when receiving state data from the provider", resp.Diagnostics.Err())
1587+
return resp
1588+
}
1589+
15891590
if buf.Len() != expectedTotalLength {
1591+
logger.Trace("GRPCProvider.v6: ReadStateBytes: received %d bytes but expected the total bytes to be %d.", buf.Len(), expectedTotalLength)
1592+
15901593
err = fmt.Errorf("expected state file of total %d bytes, received %d bytes",
15911594
expectedTotalLength, buf.Len())
15921595
resp.Diagnostics = resp.Diagnostics.Append(err)
15931596
return resp
15941597
}
1598+
1599+
// We're done, so close the stream
1600+
logger.Trace("GRPCProvider.v6: ReadStateBytes: received all chunks, total bytes: ", buf.Len())
1601+
err = client.CloseSend()
1602+
if err != nil {
1603+
resp.Diagnostics = resp.Diagnostics.Append(grpcErr(err))
1604+
return resp
1605+
}
1606+
1607+
// Add the state data in the response once we know there are no errors
15951608
resp.Bytes = buf.Bytes()
15961609

15971610
return resp

internal/plugin6/grpc_provider_test.go

Lines changed: 87 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2624,6 +2624,9 @@ func TestGRPCProvider_ReadStateBytes(t *testing.T) {
26242624
return ret.resp, ret.err
26252625
}).Times(3)
26262626

2627+
// There will be a call to CloseSend to close the stream
2628+
mockReadBytesClient.EXPECT().CloseSend().Return(nil).Times(1)
2629+
26272630
// Act
26282631
request := providers.ReadStateBytesRequest{
26292632
TypeName: expectedReq.TypeName,
@@ -2772,7 +2775,7 @@ func TestGRPCProvider_ReadStateBytes(t *testing.T) {
27722775
// Define what will be returned by each call to Recv
27732776
mockReadBytesClient.EXPECT().Recv().Return(&proto.ReadStateBytes_Response{
27742777
Diagnostics: []*proto.Diagnostic{
2775-
&proto.Diagnostic{
2778+
{
27762779
Severity: proto.Diagnostic_ERROR,
27772780
Summary: "Error from test",
27782781
Detail: "This error is forced by the test case",
@@ -2820,15 +2823,44 @@ func TestGRPCProvider_ReadStateBytes(t *testing.T) {
28202823
).Return(mockReadBytesClient, nil)
28212824

28222825
// Define what will be returned by each call to Recv
2823-
mockReadBytesClient.EXPECT().Recv().Return(&proto.ReadStateBytes_Response{
2824-
Diagnostics: []*proto.Diagnostic{
2825-
&proto.Diagnostic{
2826-
Severity: proto.Diagnostic_WARNING,
2827-
Summary: "Warning from test",
2828-
Detail: "This warning is forced by the test case",
2826+
chunk := "hello world"
2827+
totalLength := len(chunk)
2828+
mockResp := map[int]struct {
2829+
resp *proto.ReadStateBytes_Response
2830+
err error
2831+
}{
2832+
0: {
2833+
resp: &proto.ReadStateBytes_Response{
2834+
Bytes: []byte(chunk),
2835+
TotalLength: int64(totalLength),
2836+
Range: &proto.StateRange{
2837+
Start: 0,
2838+
End: int64(len(chunk)),
2839+
},
2840+
Diagnostics: []*proto.Diagnostic{
2841+
{
2842+
Severity: proto.Diagnostic_WARNING,
2843+
Summary: "Warning from test",
2844+
Detail: "This warning is forced by the test case",
2845+
},
2846+
},
28292847
},
2848+
err: nil,
2849+
},
2850+
1: {
2851+
resp: &proto.ReadStateBytes_Response{},
2852+
err: io.EOF,
28302853
},
2831-
}, io.EOF)
2854+
}
2855+
var count int
2856+
mockReadBytesClient.EXPECT().Recv().DoAndReturn(func() (*proto.ReadStateBytes_Response, error) {
2857+
ret := mockResp[count]
2858+
count++
2859+
return ret.resp, ret.err
2860+
}).Times(2)
2861+
2862+
// There will be a call to CloseSend to close the stream
2863+
mockReadBytesClient.EXPECT().CloseSend().Return(nil).Times(1)
28322864

28332865
// Act
28342866
request := providers.ReadStateBytesRequest{
@@ -2843,8 +2875,8 @@ func TestGRPCProvider_ReadStateBytes(t *testing.T) {
28432875
if resp.Diagnostics.ErrWithWarnings().Error() != expectedWarn {
28442876
t.Fatalf("expected warning diagnostic %q, but got: %q", expectedWarn, resp.Diagnostics.ErrWithWarnings().Error())
28452877
}
2846-
if len(resp.Bytes) != 0 {
2847-
t.Fatalf("expected data to be omitted in error condition, but got: %q", string(resp.Bytes))
2878+
if len(resp.Bytes) == 0 {
2879+
t.Fatal("expected data to included despite warnings, but got no bytes")
28482880
}
28492881
})
28502882

@@ -2888,6 +2920,51 @@ func TestGRPCProvider_ReadStateBytes(t *testing.T) {
28882920
t.Fatalf("expected data to be omitted in error condition, but got: %q", string(resp.Bytes))
28892921
}
28902922
})
2923+
2924+
t.Run("when closing the stream, grpc errors are surfaced via diagnostics", func(t *testing.T) {
2925+
client := mockProviderClient(t)
2926+
p := &GRPCProvider{
2927+
client: client,
2928+
ctx: context.Background(),
2929+
}
2930+
2931+
// Call to ReadStateBytes
2932+
// > Assert the arguments received
2933+
// > Define the returned mock client
2934+
mockClient := mockReadStateBytesClient(t)
2935+
expectedReq := &proto.ReadStateBytes_Request{
2936+
TypeName: "mock_store",
2937+
StateId: backend.DefaultStateName,
2938+
}
2939+
client.EXPECT().ReadStateBytes(
2940+
gomock.Any(),
2941+
gomock.Eq(expectedReq),
2942+
).Return(mockClient, nil)
2943+
2944+
// Sufficient mocking of Recv to get to the call to CloseSend
2945+
mockClient.EXPECT().Recv().Return(&proto.ReadStateBytes_Response{}, io.EOF)
2946+
2947+
// Force a gRPC error from CloseSend
2948+
mockError := errors.New("grpc error forced in test")
2949+
mockClient.EXPECT().CloseSend().Return(mockError).Times(1)
2950+
2951+
// Act
2952+
request := providers.ReadStateBytesRequest{
2953+
TypeName: expectedReq.TypeName,
2954+
StateId: expectedReq.StateId,
2955+
}
2956+
resp := p.ReadStateBytes(request)
2957+
2958+
// Assert returned values
2959+
checkDiagsHasError(t, resp.Diagnostics)
2960+
wantErr := fmt.Sprintf("Plugin error: The plugin returned an unexpected error from plugin6.(*GRPCProvider).ReadStateBytes: %s", mockError)
2961+
if resp.Diagnostics.Err().Error() != wantErr {
2962+
t.Fatalf("expected error diagnostic %q, but got: %q", wantErr, resp.Diagnostics.Err())
2963+
}
2964+
if len(resp.Bytes) != 0 {
2965+
t.Fatalf("expected data to be omitted in error condition, but got: %q", string(resp.Bytes))
2966+
}
2967+
})
28912968
}
28922969

28932970
func TestGRPCProvider_WriteStateBytes(t *testing.T) {

0 commit comments

Comments
 (0)