
Commit 64eff22

fix: handle errors which occur prior to stream initiation (#57)
* chore: renaming fixtures and making tests more specific
* chore: add test for upstream errs that occur before stream starts (anthropic only)
* chore: implement openai error handling
* chore: self-review
* chore: refactor away from atomic to compound mutex + bool
* chore: drive-by flake fix
* chore: simplify approach
* chore: fix flake due to order of operations
* chore: fixing race
* chore: headers drive-by race fix

---------

Signed-off-by: Danny Kopping <[email protected]>
1 parent 471c745 commit 64eff22

15 files changed (+505, -244 lines)
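For context, here is a minimal sketch of the behaviour the new tests in this commit pin down. The helper name below is hypothetical and is not the provider code changed elsewhere in this commit: when the upstream API rejects a request before any SSE events have been emitted (or the request is non-streaming), the bridge should relay the upstream status code and the provider-native error body as-is, rather than opening a stream or collapsing the failure into a generic 500.

package sketch

import (
    "io"
    "net/http"
)

// relayUpstreamError is a hypothetical helper: it copies an upstream error
// response (e.g. Anthropic's {"type":"error",...} envelope or OpenAI's
// {"error":{...}}) back to the caller verbatim, preserving the status code and
// Content-Type, so clients see the same error they would get when calling the
// provider directly.
func relayUpstreamError(w http.ResponseWriter, upstream *http.Response) error {
    body, err := io.ReadAll(upstream.Body)
    if err != nil {
        return err
    }
    if ct := upstream.Header.Get("Content-Type"); ct != "" {
        w.Header().Set("Content-Type", ct)
    }
    // e.g. 400 for Anthropic's "prompt is too long" or OpenAI's context_length_exceeded.
    w.WriteHeader(upstream.StatusCode)
    _, err = w.Write(body)
    return err
}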

bridge_integration_test.go

Lines changed: 218 additions & 93 deletions
@@ -30,6 +30,7 @@ import (
     "github.com/google/uuid"
     "github.com/stretchr/testify/assert"
     "github.com/stretchr/testify/require"
+    "github.com/tidwall/gjson"

     "github.com/openai/openai-go/v2"
     oaissestream "github.com/openai/openai-go/v2/packages/ssestream"
@@ -47,8 +48,10 @@ var (
     antSingleInjectedTool []byte
     //go:embed fixtures/anthropic/fallthrough.txtar
     antFallthrough []byte
-    //go:embed fixtures/anthropic/error.txtar
-    antErr []byte
+    //go:embed fixtures/anthropic/stream_error.txtar
+    antMidStreamErr []byte
+    //go:embed fixtures/anthropic/non_stream_error.txtar
+    antNonStreamErr []byte

     //go:embed fixtures/openai/simple.txtar
     oaiSimple []byte
@@ -58,8 +61,10 @@ var (
     oaiSingleInjectedTool []byte
     //go:embed fixtures/openai/fallthrough.txtar
     oaiFallthrough []byte
-    //go:embed fixtures/openai/error.txtar
-    oaiErr []byte
+    //go:embed fixtures/openai/stream_error.txtar
+    oaiMidStreamErr []byte
+    //go:embed fixtures/openai/non_stream_error.txtar
+    oaiNonStreamErr []byte
 )

 const (
@@ -676,11 +681,11 @@ func TestFallthrough(t *testing.T) {
             t.FailNow()
         }

+        receivedHeaders = &r.Header
+
         w.Header().Set("Content-Type", "application/json")
         w.WriteHeader(http.StatusOK)
         _, _ = w.Write(respBody)
-
-        receivedHeaders = &r.Header
     }))
     t.Cleanup(upstream.Close)

@@ -1009,48 +1014,147 @@ func setupInjectedToolTest(t *testing.T, fixture []byte, streaming bool, configu
 func TestErrorHandling(t *testing.T) {
     t.Parallel()

-    cases := []struct {
-        name              string
-        fixture           []byte
-        createRequestFunc createRequestFunc
-        configureFunc     configureFunc
-        responseHandlerFn func(streaming bool, resp *http.Response)
-    }{
-        {
-            name:              aibridge.ProviderAnthropic,
-            fixture:           antErr,
-            createRequestFunc: createAnthropicMessagesReq,
-            configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) {
-                logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
-                return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), nil)}, logger, client, srvProxyMgr)
+    // Tests that errors which occur *before* a streaming response begins, or in non-streaming requests, are handled as expected.
+    t.Run("non-stream error", func(t *testing.T) {
+        cases := []struct {
+            name              string
+            fixture           []byte
+            createRequestFunc createRequestFunc
+            configureFunc     configureFunc
+            responseHandlerFn func(resp *http.Response)
+        }{
+            {
+                name:              aibridge.ProviderAnthropic,
+                fixture:           antNonStreamErr,
+                createRequestFunc: createAnthropicMessagesReq,
+                configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) {
+                    logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
+                    return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), nil)}, logger, client, srvProxyMgr)
+                },
+                responseHandlerFn: func(resp *http.Response) {
+                    require.Equal(t, http.StatusBadRequest, resp.StatusCode)
+                    body, err := io.ReadAll(resp.Body)
+                    require.NoError(t, err)
+                    require.Equal(t, "error", gjson.GetBytes(body, "type").Str)
+                    require.Equal(t, "invalid_request_error", gjson.GetBytes(body, "error.type").Str)
+                    require.Contains(t, gjson.GetBytes(body, "error.message").Str, "prompt is too long")
+                },
             },
-            responseHandlerFn: func(streaming bool, resp *http.Response) {
-                if streaming {
+            {
+                name:              aibridge.ProviderOpenAI,
+                fixture:           oaiNonStreamErr,
+                createRequestFunc: createOpenAIChatCompletionsReq,
+                configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) {
+                    logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
+                    return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewOpenAIProvider(aibridge.OpenAIConfig(anthropicCfg(addr, apiKey)))}, logger, client, srvProxyMgr)
+                },
+                responseHandlerFn: func(resp *http.Response) {
+                    require.Equal(t, http.StatusBadRequest, resp.StatusCode)
+                    body, err := io.ReadAll(resp.Body)
+                    require.NoError(t, err)
+                    require.Equal(t, "context_length_exceeded", gjson.GetBytes(body, "error.code").Str)
+                    require.Equal(t, "invalid_request_error", gjson.GetBytes(body, "error.type").Str)
+                    require.Contains(t, gjson.GetBytes(body, "error.message").Str, "Input tokens exceed the configured limit")
+                },
+            },
+        }
+
+        for _, tc := range cases {
+            t.Run(tc.name, func(t *testing.T) {
+                t.Parallel()
+
+                for _, streaming := range []bool{true, false} {
+                    t.Run(fmt.Sprintf("streaming=%v", streaming), func(t *testing.T) {
+                        t.Parallel()
+
+                        ctx, cancel := context.WithTimeout(t.Context(), time.Second*30)
+                        t.Cleanup(cancel)
+
+                        arc := txtar.Parse(tc.fixture)
+                        t.Logf("%s: %s", t.Name(), arc.Comment)
+
+                        files := filesMap(arc)
+                        require.Len(t, files, 3)
+                        require.Contains(t, files, fixtureRequest)
+                        require.Contains(t, files, fixtureStreamingResponse)
+                        require.Contains(t, files, fixtureNonStreamingResponse)
+
+                        reqBody := files[fixtureRequest]
+                        // Add the stream param to the request.
+                        newBody, err := setJSON(reqBody, "stream", streaming)
+                        require.NoError(t, err)
+                        reqBody = newBody
+
+                        // Setup mock server.
+                        mockResp := files[fixtureStreamingResponse]
+                        if !streaming {
+                            mockResp = files[fixtureNonStreamingResponse]
+                        }
+                        mockSrv := newMockHTTPReflector(ctx, t, mockResp)
+                        t.Cleanup(mockSrv.Close)
+
+                        recorderClient := &mockRecorderClient{}
+
+                        b, err := tc.configureFunc(mockSrv.URL, recorderClient, mcp.NewServerProxyManager(nil))
+                        require.NoError(t, err)
+
+                        // Invoke request to mocked API via aibridge.
+                        bridgeSrv := httptest.NewUnstartedServer(b)
+                        bridgeSrv.Config.BaseContext = func(_ net.Listener) context.Context {
+                            return aibridge.AsActor(ctx, userID, nil)
+                        }
+                        bridgeSrv.Start()
+                        t.Cleanup(bridgeSrv.Close)
+
+                        req := tc.createRequestFunc(t, bridgeSrv.URL, reqBody)
+                        resp, err := http.DefaultClient.Do(req)
+                        t.Cleanup(func() { _ = resp.Body.Close() })
+                        require.NoError(t, err)
+
+                        tc.responseHandlerFn(resp)
+                        recorderClient.verifyAllInterceptionsEnded(t)
+                    })
+                }
+            })
+        }
+    })
+
+    // Tests that errors which occur *during* a streaming response are handled as expected.
+    t.Run("mid-stream error", func(t *testing.T) {
+        cases := []struct {
+            name              string
+            fixture           []byte
+            createRequestFunc createRequestFunc
+            configureFunc     configureFunc
+            responseHandlerFn func(resp *http.Response)
+        }{
+            {
+                name:              aibridge.ProviderAnthropic,
+                fixture:           antMidStreamErr,
+                createRequestFunc: createAnthropicMessagesReq,
+                configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) {
+                    logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
+                    return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), nil)}, logger, client, srvProxyMgr)
+                },
+                responseHandlerFn: func(resp *http.Response) {
                     // Server responds first with 200 OK then starts streaming.
                     require.Equal(t, http.StatusOK, resp.StatusCode)

                     sp := aibridge.NewSSEParser()
                     require.NoError(t, sp.Parse(resp.Body))
                     require.Len(t, sp.EventsByType("error"), 1)
                     require.Contains(t, sp.EventsByType("error")[0].Data, "Overloaded")
-                } else {
-                    require.Equal(t, resp.StatusCode, http.StatusInternalServerError)
-                    body, err := io.ReadAll(resp.Body)
-                    require.NoError(t, err)
-                    require.Contains(t, string(body), "Overloaded")
-                }
+                },
             },
-        },
-        {
-            name:              aibridge.ProviderOpenAI,
-            fixture:           oaiErr,
-            createRequestFunc: createOpenAIChatCompletionsReq,
-            configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) {
-                logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
-                return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewOpenAIProvider(aibridge.OpenAIConfig(anthropicCfg(addr, apiKey)))}, logger, client, srvProxyMgr)
-            },
-            responseHandlerFn: func(streaming bool, resp *http.Response) {
-                if streaming {
+            {
+                name:              aibridge.ProviderOpenAI,
+                fixture:           oaiMidStreamErr,
+                createRequestFunc: createOpenAIChatCompletionsReq,
+                configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) {
+                    logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
+                    return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewOpenAIProvider(aibridge.OpenAIConfig(anthropicCfg(addr, apiKey)))}, logger, client, srvProxyMgr)
+                },
+                responseHandlerFn: func(resp *http.Response) {
                     // Server responds first with 200 OK then starts streaming.
                     require.Equal(t, http.StatusOK, resp.StatusCode)

@@ -1063,72 +1167,55 @@ func TestErrorHandling(t *testing.T) {
                     errEvent := sp.MessageEvents()[len(sp.MessageEvents())-2] // Last event is termination marker ("[DONE]").
                     require.NotEmpty(t, errEvent)
                     require.Contains(t, errEvent.Data, "The server had an error while processing your request. Sorry about that!")
-                } else {
-                    require.Equal(t, resp.StatusCode, http.StatusInternalServerError)
-                    body, err := io.ReadAll(resp.Body)
-                    require.NoError(t, err)
-                    require.Contains(t, string(body), "The server had an error while processing your request. Sorry about that")
-                }
+                },
             },
-        },
-    }
-
-    for _, tc := range cases {
-        t.Run(tc.name, func(t *testing.T) {
-            t.Parallel()
-
-            for _, streaming := range []bool{true, false} {
-                t.Run(fmt.Sprintf("streaming=%v", streaming), func(t *testing.T) {
-                    t.Parallel()
+        }

-                    ctx, cancel := context.WithTimeout(t.Context(), time.Second*30)
-                    t.Cleanup(cancel)
+        for _, tc := range cases {
+            t.Run(tc.name, func(t *testing.T) {
+                t.Parallel()

-                    arc := txtar.Parse(tc.fixture)
-                    t.Logf("%s: %s", t.Name(), arc.Comment)
+                ctx, cancel := context.WithTimeout(t.Context(), time.Second*30)
+                t.Cleanup(cancel)

-                    files := filesMap(arc)
-                    require.Len(t, files, 3)
-                    require.Contains(t, files, fixtureRequest)
-                    require.Contains(t, files, fixtureStreamingResponse)
-                    require.Contains(t, files, fixtureNonStreamingResponse)
+                arc := txtar.Parse(tc.fixture)
+                t.Logf("%s: %s", t.Name(), arc.Comment)

-                    reqBody := files[fixtureRequest]
+                files := filesMap(arc)
+                require.Len(t, files, 2)
+                require.Contains(t, files, fixtureRequest)
+                require.Contains(t, files, fixtureStreamingResponse)

-                    // Add the stream param to the request.
-                    newBody, err := setJSON(reqBody, "stream", streaming)
-                    require.NoError(t, err)
-                    reqBody = newBody
+                reqBody := files[fixtureRequest]

-                    // Setup mock server.
-                    mockSrv := newMockServer(ctx, t, files, nil)
-                    mockSrv.statusCode = http.StatusInternalServerError
-                    t.Cleanup(mockSrv.Close)
+                // Setup mock server.
+                mockSrv := newMockServer(ctx, t, files, nil)
+                mockSrv.statusCode = http.StatusInternalServerError
+                t.Cleanup(mockSrv.Close)

-                    recorderClient := &mockRecorderClient{}
+                recorderClient := &mockRecorderClient{}

-                    b, err := tc.configureFunc(mockSrv.URL, recorderClient, mcp.NewServerProxyManager(nil))
-                    require.NoError(t, err)
+                b, err := tc.configureFunc(mockSrv.URL, recorderClient, mcp.NewServerProxyManager(nil))
+                require.NoError(t, err)

-                    // Invoke request to mocked API via aibridge.
-                    bridgeSrv := httptest.NewUnstartedServer(b)
-                    bridgeSrv.Config.BaseContext = func(_ net.Listener) context.Context {
-                        return aibridge.AsActor(ctx, userID, nil)
-                    }
-                    bridgeSrv.Start()
-                    t.Cleanup(bridgeSrv.Close)
+                // Invoke request to mocked API via aibridge.
+                bridgeSrv := httptest.NewUnstartedServer(b)
+                bridgeSrv.Config.BaseContext = func(_ net.Listener) context.Context {
+                    return aibridge.AsActor(ctx, userID, nil)
+                }
+                bridgeSrv.Start()
+                t.Cleanup(bridgeSrv.Close)

-                    req := tc.createRequestFunc(t, bridgeSrv.URL, reqBody)
-                    resp, err := http.DefaultClient.Do(req)
-                    t.Cleanup(func() { _ = resp.Body.Close() })
-                    require.NoError(t, err)
+                req := tc.createRequestFunc(t, bridgeSrv.URL, reqBody)
+                resp, err := http.DefaultClient.Do(req)
+                t.Cleanup(func() { _ = resp.Body.Close() })
+                require.NoError(t, err)

-                    tc.responseHandlerFn(streaming, resp)
-                    recorderClient.verifyAllInterceptionsEnded(t)
-                })
-            }
-        })
-    }
+                tc.responseHandlerFn(resp)
+                recorderClient.verifyAllInterceptionsEnded(t)
+            })
+        }
+    })
 }

 // TestStableRequestEncoding validates that a given intercepted request and a
@@ -1297,6 +1384,44 @@ func createOpenAIChatCompletionsReq(t *testing.T, baseURL string, input []byte)
     return req
 }

+type mockHTTPReflector struct {
+    *httptest.Server
+}
+
+func newMockHTTPReflector(ctx context.Context, t *testing.T, resp []byte) *mockHTTPReflector {
+    ref := &mockHTTPReflector{}
+
+    srv := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+        mock, err := http.ReadResponse(bufio.NewReader(bytes.NewBuffer(resp)), r)
+        require.NoError(t, err)
+        defer mock.Body.Close()
+
+        // Copy headers from the mocked response.
+        for key, values := range mock.Header {
+            for _, value := range values {
+                w.Header().Add(key, value)
+            }
+        }
+
+        // Write the status code.
+        w.WriteHeader(mock.StatusCode)
+
+        // Copy the body.
+        _, err = io.Copy(w, mock.Body)
+        require.NoError(t, err)
+    }))
+    srv.Config.BaseContext = func(_ net.Listener) context.Context {
+        return ctx
+    }
+
+    srv.Start()
+    t.Cleanup(srv.Close)
+
+    ref.Server = srv
+    return ref
+}
+
+// TODO: replace this with mockHTTPReflector.
 type mockServer struct {
     *httptest.Server

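As a usage note (an illustrative sketch, not part of the commit): newMockHTTPReflector replays whatever raw HTTP/1.1 response bytes it is given, parsed with http.ReadResponse, so a fixture only needs to be a complete response with a status line, headers, and a body. Assuming the sketch lives in the same test package as bridge_integration_test.go (so newMockHTTPReflector, require, gjson, io, and net/http are already in scope), a self-contained check might look like this; the fixture body mirrors the Anthropic error shape asserted in the non-stream case above.

func TestMockHTTPReflectorSketch(t *testing.T) {
    t.Parallel()

    // A complete raw HTTP/1.1 response; http.ReadResponse tolerates a missing
    // Content-Length on responses and reads the body until EOF.
    raw := []byte("HTTP/1.1 400 Bad Request\r\n" +
        "Content-Type: application/json\r\n" +
        "\r\n" +
        `{"type":"error","error":{"type":"invalid_request_error","message":"prompt is too long"}}`)

    srv := newMockHTTPReflector(t.Context(), t, raw)

    resp, err := http.Get(srv.URL)
    require.NoError(t, err)
    defer resp.Body.Close()

    // The reflector relays the canned status, headers, and body untouched.
    require.Equal(t, http.StatusBadRequest, resp.StatusCode)
    body, err := io.ReadAll(resp.Body)
    require.NoError(t, err)
    require.Equal(t, "invalid_request_error", gjson.GetBytes(body, "error.type").Str)
}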