@@ -30,6 +30,7 @@ import (
3030 "github.com/google/uuid"
3131 "github.com/stretchr/testify/assert"
3232 "github.com/stretchr/testify/require"
33+ "github.com/tidwall/gjson"
3334
3435 "github.com/openai/openai-go/v2"
3536 oaissestream "github.com/openai/openai-go/v2/packages/ssestream"
4748 antSingleInjectedTool []byte
4849 //go:embed fixtures/anthropic/fallthrough.txtar
4950 antFallthrough []byte
50- //go:embed fixtures/anthropic/error.txtar
51- antErr []byte
51+ //go:embed fixtures/anthropic/stream_error.txtar
52+ antMidStreamErr []byte
53+ //go:embed fixtures/anthropic/non_stream_error.txtar
54+ antNonStreamErr []byte
5255
5356 //go:embed fixtures/openai/simple.txtar
5457 oaiSimple []byte
5861 oaiSingleInjectedTool []byte
5962 //go:embed fixtures/openai/fallthrough.txtar
6063 oaiFallthrough []byte
61- //go:embed fixtures/openai/error.txtar
62- oaiErr []byte
64+ //go:embed fixtures/openai/stream_error.txtar
65+ oaiMidStreamErr []byte
66+ //go:embed fixtures/openai/non_stream_error.txtar
67+ oaiNonStreamErr []byte
6368)
6469
6570const (
@@ -676,11 +681,11 @@ func TestFallthrough(t *testing.T) {
676681 t .FailNow ()
677682 }
678683
684+ receivedHeaders = & r .Header
685+
679686 w .Header ().Set ("Content-Type" , "application/json" )
680687 w .WriteHeader (http .StatusOK )
681688 _ , _ = w .Write (respBody )
682-
683- receivedHeaders = & r .Header
684689 }))
685690 t .Cleanup (upstream .Close )
686691
@@ -1009,48 +1014,147 @@ func setupInjectedToolTest(t *testing.T, fixture []byte, streaming bool, configu
10091014func TestErrorHandling (t * testing.T ) {
10101015 t .Parallel ()
10111016
1012- cases := []struct {
1013- name string
1014- fixture []byte
1015- createRequestFunc createRequestFunc
1016- configureFunc configureFunc
1017- responseHandlerFn func (streaming bool , resp * http.Response )
1018- }{
1019- {
1020- name : aibridge .ProviderAnthropic ,
1021- fixture : antErr ,
1022- createRequestFunc : createAnthropicMessagesReq ,
1023- configureFunc : func (addr string , client aibridge.Recorder , srvProxyMgr * mcp.ServerProxyManager ) (* aibridge.RequestBridge , error ) {
1024- logger := slogtest .Make (t , & slogtest.Options {IgnoreErrors : false }).Leveled (slog .LevelDebug )
1025- return aibridge .NewRequestBridge (t .Context (), []aibridge.Provider {aibridge .NewAnthropicProvider (anthropicCfg (addr , apiKey ), nil )}, logger , client , srvProxyMgr )
1017+ // Tests that errors which occur *before* a streaming response begins, or in non-streaming requests, are handled as expected.
1018+ t .Run ("non-stream error" , func (t * testing.T ) {
1019+ cases := []struct {
1020+ name string
1021+ fixture []byte
1022+ createRequestFunc createRequestFunc
1023+ configureFunc configureFunc
1024+ responseHandlerFn func (resp * http.Response )
1025+ }{
1026+ {
1027+ name : aibridge .ProviderAnthropic ,
1028+ fixture : antNonStreamErr ,
1029+ createRequestFunc : createAnthropicMessagesReq ,
1030+ configureFunc : func (addr string , client aibridge.Recorder , srvProxyMgr * mcp.ServerProxyManager ) (* aibridge.RequestBridge , error ) {
1031+ logger := slogtest .Make (t , & slogtest.Options {IgnoreErrors : false }).Leveled (slog .LevelDebug )
1032+ return aibridge .NewRequestBridge (t .Context (), []aibridge.Provider {aibridge .NewAnthropicProvider (anthropicCfg (addr , apiKey ), nil )}, logger , client , srvProxyMgr )
1033+ },
1034+ responseHandlerFn : func (resp * http.Response ) {
1035+ require .Equal (t , http .StatusBadRequest , resp .StatusCode )
1036+ body , err := io .ReadAll (resp .Body )
1037+ require .NoError (t , err )
1038+ require .Equal (t , "error" , gjson .GetBytes (body , "type" ).Str )
1039+ require .Equal (t , "invalid_request_error" , gjson .GetBytes (body , "error.type" ).Str )
1040+ require .Contains (t , gjson .GetBytes (body , "error.message" ).Str , "prompt is too long" )
1041+ },
10261042 },
1027- responseHandlerFn : func (streaming bool , resp * http.Response ) {
1028- if streaming {
1043+ {
1044+ name : aibridge .ProviderOpenAI ,
1045+ fixture : oaiNonStreamErr ,
1046+ createRequestFunc : createOpenAIChatCompletionsReq ,
1047+ configureFunc : func (addr string , client aibridge.Recorder , srvProxyMgr * mcp.ServerProxyManager ) (* aibridge.RequestBridge , error ) {
1048+ logger := slogtest .Make (t , & slogtest.Options {IgnoreErrors : false }).Leveled (slog .LevelDebug )
1049+ return aibridge .NewRequestBridge (t .Context (), []aibridge.Provider {aibridge .NewOpenAIProvider (aibridge .OpenAIConfig (anthropicCfg (addr , apiKey )))}, logger , client , srvProxyMgr )
1050+ },
1051+ responseHandlerFn : func (resp * http.Response ) {
1052+ require .Equal (t , http .StatusBadRequest , resp .StatusCode )
1053+ body , err := io .ReadAll (resp .Body )
1054+ require .NoError (t , err )
1055+ require .Equal (t , "context_length_exceeded" , gjson .GetBytes (body , "error.code" ).Str )
1056+ require .Equal (t , "invalid_request_error" , gjson .GetBytes (body , "error.type" ).Str )
1057+ require .Contains (t , gjson .GetBytes (body , "error.message" ).Str , "Input tokens exceed the configured limit" )
1058+ },
1059+ },
1060+ }
1061+
1062+ for _ , tc := range cases {
1063+ t .Run (tc .name , func (t * testing.T ) {
1064+ t .Parallel ()
1065+
1066+ for _ , streaming := range []bool {true , false } {
1067+ t .Run (fmt .Sprintf ("streaming=%v" , streaming ), func (t * testing.T ) {
1068+ t .Parallel ()
1069+
1070+ ctx , cancel := context .WithTimeout (t .Context (), time .Second * 30 )
1071+ t .Cleanup (cancel )
1072+
1073+ arc := txtar .Parse (tc .fixture )
1074+ t .Logf ("%s: %s" , t .Name (), arc .Comment )
1075+
1076+ files := filesMap (arc )
1077+ require .Len (t , files , 3 )
1078+ require .Contains (t , files , fixtureRequest )
1079+ require .Contains (t , files , fixtureStreamingResponse )
1080+ require .Contains (t , files , fixtureNonStreamingResponse )
1081+
1082+ reqBody := files [fixtureRequest ]
1083+ // Add the stream param to the request.
1084+ newBody , err := setJSON (reqBody , "stream" , streaming )
1085+ require .NoError (t , err )
1086+ reqBody = newBody
1087+
1088+ // Setup mock server.
1089+ mockResp := files [fixtureStreamingResponse ]
1090+ if ! streaming {
1091+ mockResp = files [fixtureNonStreamingResponse ]
1092+ }
1093+ mockSrv := newMockHTTPReflector (ctx , t , mockResp )
1094+ t .Cleanup (mockSrv .Close )
1095+
1096+ recorderClient := & mockRecorderClient {}
1097+
1098+ b , err := tc .configureFunc (mockSrv .URL , recorderClient , mcp .NewServerProxyManager (nil ))
1099+ require .NoError (t , err )
1100+
1101+ // Invoke request to mocked API via aibridge.
1102+ bridgeSrv := httptest .NewUnstartedServer (b )
1103+ bridgeSrv .Config .BaseContext = func (_ net.Listener ) context.Context {
1104+ return aibridge .AsActor (ctx , userID , nil )
1105+ }
1106+ bridgeSrv .Start ()
1107+ t .Cleanup (bridgeSrv .Close )
1108+
1109+ req := tc .createRequestFunc (t , bridgeSrv .URL , reqBody )
1110+ resp , err := http .DefaultClient .Do (req )
1111+ t .Cleanup (func () { _ = resp .Body .Close () })
1112+ require .NoError (t , err )
1113+
1114+ tc .responseHandlerFn (resp )
1115+ recorderClient .verifyAllInterceptionsEnded (t )
1116+ })
1117+ }
1118+ })
1119+ }
1120+ })
1121+
1122+ // Tests that errors which occur *during* a streaming response are handled as expected.
1123+ t .Run ("mid-stream error" , func (t * testing.T ) {
1124+ cases := []struct {
1125+ name string
1126+ fixture []byte
1127+ createRequestFunc createRequestFunc
1128+ configureFunc configureFunc
1129+ responseHandlerFn func (resp * http.Response )
1130+ }{
1131+ {
1132+ name : aibridge .ProviderAnthropic ,
1133+ fixture : antMidStreamErr ,
1134+ createRequestFunc : createAnthropicMessagesReq ,
1135+ configureFunc : func (addr string , client aibridge.Recorder , srvProxyMgr * mcp.ServerProxyManager ) (* aibridge.RequestBridge , error ) {
1136+ logger := slogtest .Make (t , & slogtest.Options {IgnoreErrors : false }).Leveled (slog .LevelDebug )
1137+ return aibridge .NewRequestBridge (t .Context (), []aibridge.Provider {aibridge .NewAnthropicProvider (anthropicCfg (addr , apiKey ), nil )}, logger , client , srvProxyMgr )
1138+ },
1139+ responseHandlerFn : func (resp * http.Response ) {
10291140 // Server responds first with 200 OK then starts streaming.
10301141 require .Equal (t , http .StatusOK , resp .StatusCode )
10311142
10321143 sp := aibridge .NewSSEParser ()
10331144 require .NoError (t , sp .Parse (resp .Body ))
10341145 require .Len (t , sp .EventsByType ("error" ), 1 )
10351146 require .Contains (t , sp .EventsByType ("error" )[0 ].Data , "Overloaded" )
1036- } else {
1037- require .Equal (t , resp .StatusCode , http .StatusInternalServerError )
1038- body , err := io .ReadAll (resp .Body )
1039- require .NoError (t , err )
1040- require .Contains (t , string (body ), "Overloaded" )
1041- }
1147+ },
10421148 },
1043- },
1044- {
1045- name : aibridge .ProviderOpenAI ,
1046- fixture : oaiErr ,
1047- createRequestFunc : createOpenAIChatCompletionsReq ,
1048- configureFunc : func (addr string , client aibridge.Recorder , srvProxyMgr * mcp.ServerProxyManager ) (* aibridge.RequestBridge , error ) {
1049- logger := slogtest .Make (t , & slogtest.Options {IgnoreErrors : false }).Leveled (slog .LevelDebug )
1050- return aibridge .NewRequestBridge (t .Context (), []aibridge.Provider {aibridge .NewOpenAIProvider (aibridge .OpenAIConfig (anthropicCfg (addr , apiKey )))}, logger , client , srvProxyMgr )
1051- },
1052- responseHandlerFn : func (streaming bool , resp * http.Response ) {
1053- if streaming {
1149+ {
1150+ name : aibridge .ProviderOpenAI ,
1151+ fixture : oaiMidStreamErr ,
1152+ createRequestFunc : createOpenAIChatCompletionsReq ,
1153+ configureFunc : func (addr string , client aibridge.Recorder , srvProxyMgr * mcp.ServerProxyManager ) (* aibridge.RequestBridge , error ) {
1154+ logger := slogtest .Make (t , & slogtest.Options {IgnoreErrors : false }).Leveled (slog .LevelDebug )
1155+ return aibridge .NewRequestBridge (t .Context (), []aibridge.Provider {aibridge .NewOpenAIProvider (aibridge .OpenAIConfig (anthropicCfg (addr , apiKey )))}, logger , client , srvProxyMgr )
1156+ },
1157+ responseHandlerFn : func (resp * http.Response ) {
10541158 // Server responds first with 200 OK then starts streaming.
10551159 require .Equal (t , http .StatusOK , resp .StatusCode )
10561160
@@ -1063,72 +1167,55 @@ func TestErrorHandling(t *testing.T) {
10631167 errEvent := sp .MessageEvents ()[len (sp .MessageEvents ())- 2 ] // Last event is termination marker ("[DONE]").
10641168 require .NotEmpty (t , errEvent )
10651169 require .Contains (t , errEvent .Data , "The server had an error while processing your request. Sorry about that!" )
1066- } else {
1067- require .Equal (t , resp .StatusCode , http .StatusInternalServerError )
1068- body , err := io .ReadAll (resp .Body )
1069- require .NoError (t , err )
1070- require .Contains (t , string (body ), "The server had an error while processing your request. Sorry about that" )
1071- }
1170+ },
10721171 },
1073- },
1074- }
1075-
1076- for _ , tc := range cases {
1077- t .Run (tc .name , func (t * testing.T ) {
1078- t .Parallel ()
1079-
1080- for _ , streaming := range []bool {true , false } {
1081- t .Run (fmt .Sprintf ("streaming=%v" , streaming ), func (t * testing.T ) {
1082- t .Parallel ()
1172+ }
10831173
1084- ctx , cancel := context .WithTimeout (t .Context (), time .Second * 30 )
1085- t .Cleanup (cancel )
1174+ for _ , tc := range cases {
1175+ t .Run (tc .name , func (t * testing.T ) {
1176+ t .Parallel ()
10861177
1087- arc := txtar . Parse ( tc . fixture )
1088- t . Logf ( "%s: %s" , t . Name (), arc . Comment )
1178+ ctx , cancel := context . WithTimeout ( t . Context (), time . Second * 30 )
1179+ t . Cleanup ( cancel )
10891180
1090- files := filesMap (arc )
1091- require .Len (t , files , 3 )
1092- require .Contains (t , files , fixtureRequest )
1093- require .Contains (t , files , fixtureStreamingResponse )
1094- require .Contains (t , files , fixtureNonStreamingResponse )
1181+ arc := txtar .Parse (tc .fixture )
1182+ t .Logf ("%s: %s" , t .Name (), arc .Comment )
10951183
1096- reqBody := files [fixtureRequest ]
1184+ files := filesMap (arc )
1185+ require .Len (t , files , 2 )
1186+ require .Contains (t , files , fixtureRequest )
1187+ require .Contains (t , files , fixtureStreamingResponse )
10971188
1098- // Add the stream param to the request.
1099- newBody , err := setJSON (reqBody , "stream" , streaming )
1100- require .NoError (t , err )
1101- reqBody = newBody
1189+ reqBody := files [fixtureRequest ]
11021190
1103- // Setup mock server.
1104- mockSrv := newMockServer (ctx , t , files , nil )
1105- mockSrv .statusCode = http .StatusInternalServerError
1106- t .Cleanup (mockSrv .Close )
1191+ // Setup mock server.
1192+ mockSrv := newMockServer (ctx , t , files , nil )
1193+ mockSrv .statusCode = http .StatusInternalServerError
1194+ t .Cleanup (mockSrv .Close )
11071195
1108- recorderClient := & mockRecorderClient {}
1196+ recorderClient := & mockRecorderClient {}
11091197
1110- b , err := tc .configureFunc (mockSrv .URL , recorderClient , mcp .NewServerProxyManager (nil ))
1111- require .NoError (t , err )
1198+ b , err := tc .configureFunc (mockSrv .URL , recorderClient , mcp .NewServerProxyManager (nil ))
1199+ require .NoError (t , err )
11121200
1113- // Invoke request to mocked API via aibridge.
1114- bridgeSrv := httptest .NewUnstartedServer (b )
1115- bridgeSrv .Config .BaseContext = func (_ net.Listener ) context.Context {
1116- return aibridge .AsActor (ctx , userID , nil )
1117- }
1118- bridgeSrv .Start ()
1119- t .Cleanup (bridgeSrv .Close )
1201+ // Invoke request to mocked API via aibridge.
1202+ bridgeSrv := httptest .NewUnstartedServer (b )
1203+ bridgeSrv .Config .BaseContext = func (_ net.Listener ) context.Context {
1204+ return aibridge .AsActor (ctx , userID , nil )
1205+ }
1206+ bridgeSrv .Start ()
1207+ t .Cleanup (bridgeSrv .Close )
11201208
1121- req := tc .createRequestFunc (t , bridgeSrv .URL , reqBody )
1122- resp , err := http .DefaultClient .Do (req )
1123- t .Cleanup (func () { _ = resp .Body .Close () })
1124- require .NoError (t , err )
1209+ req := tc .createRequestFunc (t , bridgeSrv .URL , reqBody )
1210+ resp , err := http .DefaultClient .Do (req )
1211+ t .Cleanup (func () { _ = resp .Body .Close () })
1212+ require .NoError (t , err )
11251213
1126- tc .responseHandlerFn (streaming , resp )
1127- recorderClient .verifyAllInterceptionsEnded (t )
1128- })
1129- }
1130- })
1131- }
1214+ tc .responseHandlerFn (resp )
1215+ recorderClient .verifyAllInterceptionsEnded (t )
1216+ })
1217+ }
1218+ })
11321219}
11331220
11341221// TestStableRequestEncoding validates that a given intercepted request and a
@@ -1297,6 +1384,44 @@ func createOpenAIChatCompletionsReq(t *testing.T, baseURL string, input []byte)
12971384 return req
12981385}
12991386
1387+ type mockHTTPReflector struct {
1388+ * httptest.Server
1389+ }
1390+
1391+ func newMockHTTPReflector (ctx context.Context , t * testing.T , resp []byte ) * mockHTTPReflector {
1392+ ref := & mockHTTPReflector {}
1393+
1394+ srv := httptest .NewUnstartedServer (http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
1395+ mock , err := http .ReadResponse (bufio .NewReader (bytes .NewBuffer (resp )), r )
1396+ require .NoError (t , err )
1397+ defer mock .Body .Close ()
1398+
1399+ // Copy headers from the mocked response.
1400+ for key , values := range mock .Header {
1401+ for _ , value := range values {
1402+ w .Header ().Add (key , value )
1403+ }
1404+ }
1405+
1406+ // Write the status code.
1407+ w .WriteHeader (mock .StatusCode )
1408+
1409+ // Copy the body.
1410+ _ , err = io .Copy (w , mock .Body )
1411+ require .NoError (t , err )
1412+ }))
1413+ srv .Config .BaseContext = func (_ net.Listener ) context.Context {
1414+ return ctx
1415+ }
1416+
1417+ srv .Start ()
1418+ t .Cleanup (srv .Close )
1419+
1420+ ref .Server = srv
1421+ return ref
1422+ }
1423+
1424+ // TODO: replace this with mockHTTPReflector.
13001425type mockServer struct {
13011426 * httptest.Server
13021427
0 commit comments