1919package org .apache .flink .model .openai ;
2020
2121import org .apache .flink .configuration .ReadableConfig ;
22+ import org .apache .flink .table .api .DataTypes ;
2223import org .apache .flink .table .api .config .ExecutionConfigOptions ;
2324import org .apache .flink .table .catalog .Column ;
2425import org .apache .flink .table .catalog .ResolvedSchema ;
26+ import org .apache .flink .table .data .ArrayData ;
27+ import org .apache .flink .table .data .GenericArrayData ;
28+ import org .apache .flink .table .data .GenericMapData ;
29+ import org .apache .flink .table .data .GenericRowData ;
2530import org .apache .flink .table .data .RowData ;
31+ import org .apache .flink .table .data .StringData ;
32+ import org .apache .flink .table .data .binary .BinaryStringData ;
2633import org .apache .flink .table .factories .ModelProviderFactory ;
2734import org .apache .flink .table .functions .AsyncPredictFunction ;
2835import org .apache .flink .table .functions .FunctionContext ;
36+ import org .apache .flink .table .types .DataType ;
2937import org .apache .flink .table .types .logical .LogicalType ;
3038import org .apache .flink .table .types .logical .VarCharType ;
3139import org .apache .flink .util .ExceptionUtils ;
3240import org .apache .flink .util .Preconditions ;
3341
3442import com .openai .client .OpenAIClientAsync ;
43+ import com .openai .core .http .Headers ;
3544import com .openai .errors .OpenAIIoException ;
3645import com .openai .errors .OpenAIServiceException ;
3746import org .slf4j .Logger ;
4150
4251import java .io .IOException ;
4352import java .time .Duration ;
53+ import java .util .Arrays ;
4454import java .util .Collection ;
4555import java .util .Collections ;
56+ import java .util .HashMap ;
4657import java .util .HashSet ;
4758import java .util .List ;
59+ import java .util .Map ;
4860import java .util .Optional ;
4961import java .util .Set ;
5062import java .util .concurrent .CompletableFuture ;
@@ -78,6 +90,7 @@ public abstract class AbstractOpenAIModelFunction extends AsyncPredictFunction {
7890 private final String model ;
7991 @ Nullable private final Integer maxContextSize ;
8092 private final ContextOverflowAction contextOverflowAction ;
93+ protected final List <String > outputColumnNames ;
8194
8295 public AbstractOpenAIModelFunction (
8396 ModelProviderFactory .Context factoryContext , ReadableConfig config ) {
@@ -140,6 +153,9 @@ public AbstractOpenAIModelFunction(
140153 factoryContext .getCatalogModel ().getResolvedInputSchema (),
141154 new VarCharType (VarCharType .MAX_LENGTH ),
142155 "input" );
156+
157+ this .outputColumnNames =
158+ factoryContext .getCatalogModel ().getResolvedOutputSchema ().getColumnNames ();
143159 }
144160
145161 @ Override
@@ -184,23 +200,19 @@ public void close() throws Exception {
184200 protected void validateSingleColumnSchema (
185201 ResolvedSchema schema , LogicalType expectedType , String inputOrOutput ) {
186202 List <Column > columns = schema .getColumns ();
187- if (columns .size () != 1 ) {
203+ List <String > physicalColumnNames =
204+ columns .stream ()
205+ .filter (Column ::isPhysical )
206+ .map (Column ::getName )
207+ .collect (Collectors .toList ());
208+ if (physicalColumnNames .size () != 1 ) {
188209 throw new IllegalArgumentException (
189210 String .format (
190- "Model should have exactly one %s column, but actually has %s columns: %s" ,
191- inputOrOutput ,
192- columns .size (),
193- columns .stream ().map (Column ::getName ).collect (Collectors .toList ())));
194- }
195-
196- Column column = columns .get (0 );
197- if (!column .isPhysical ()) {
198- throw new IllegalArgumentException (
199- String .format (
200- "%s column %s should be a physical column, but is a %s." ,
201- inputOrOutput , column .getName (), column .getClass ()));
211+ "Model should have exactly one %s physical column, but actually has %s physical columns: %s" ,
212+ inputOrOutput , physicalColumnNames .size (), physicalColumnNames ));
202213 }
203214
215+ Column column = schema .getColumn (physicalColumnNames .get (0 )).get ();
204216 if (!expectedType .equals (column .getDataType ().getLogicalType ())) {
205217 throw new IllegalArgumentException (
206218 String .format (
@@ -210,6 +222,33 @@ protected void validateSingleColumnSchema(
210222 expectedType ,
211223 column .getDataType ().getLogicalType ()));
212224 }
225+
226+ List <Column > metadataColumns =
227+ columns .stream ()
228+ .filter (x -> x instanceof Column .MetadataColumn )
229+ .collect (Collectors .toList ());
230+ if (!metadataColumns .isEmpty ()) {
231+ Preconditions .checkArgument (
232+ "output" .equals (inputOrOutput ), "Only output schema supports metadata column" );
233+
234+ for (Column metadataColumn : metadataColumns ) {
235+ ErrorMessageMetadata errorMessageMetadata =
236+ ErrorMessageMetadata .get (metadataColumn .getName ());
237+ Preconditions .checkNotNull (
238+ errorMessageMetadata ,
239+ String .format (
240+ "Unexpected metadata column %s. Supported metadata columns:\n %s" ,
241+ metadataColumn .getName (),
242+ ErrorMessageMetadata .getAllKeysAndDescriptions ()));
243+ Preconditions .checkArgument (
244+ errorMessageMetadata .dataType .equals (metadataColumn .getDataType ()),
245+ String .format (
246+ "Expected metadata column %s to be of type %s, but is of type %s" ,
247+ metadataColumn .getName (),
248+ errorMessageMetadata .dataType ,
249+ metadataColumn .getDataType ()));
250+ }
251+ }
213252 }
214253
215254 /**
@@ -223,30 +262,52 @@ protected void validateSingleColumnSchema(
223262 * appropriate retry and error handling applied, or a null value if the request failed in
224263 * the middle and the failure should be ignored.
225264 */
226- protected <T > CompletableFuture <T > sendAsyncOpenAIRequest (
227- Supplier <CompletableFuture <T >> requestSender ) {
265+ protected <T > CompletableFuture <Collection <RowData >> sendAsyncOpenAIRequest (
266+ Supplier <CompletableFuture <T >> requestSender ,
267+ Function <T , Collection <RowData >> converter ) {
228268 CompletableFuture <T > result =
229269 retryAsync (
230270 requestSender ,
231271 numRetry ,
232272 retryBackoffBaseIntervalMs ,
233273 retryBackoffStrategy ,
234274 null );
235- ErrorHandlingStrategy finalErrorHandlingStrategy =
236- this .errorHandlingStrategy == ErrorHandlingStrategy .RETRY
237- ? this .retryFallbackStrategy
238- : this .errorHandlingStrategy ;
239- if (finalErrorHandlingStrategy == ErrorHandlingStrategy .IGNORE ) {
240- result =
241- result .exceptionally (
242- (e ) -> {
243- LOG .warn (
244- "The input row data failed to acquire a valid response. Ignoring the input." ,
245- e );
246- return null ;
247- });
275+ return result .handle ((x , throwable ) -> this .convertToRowData (x , throwable , converter ));
276+ }
277+
278+ private <T > Collection <RowData > convertToRowData (
279+ @ Nullable T t ,
280+ @ Nullable Throwable throwable ,
281+ Function <T , Collection <RowData >> converter ) {
282+ if (throwable != null ) {
283+ ErrorHandlingStrategy finalErrorHandlingStrategy =
284+ this .errorHandlingStrategy == ErrorHandlingStrategy .RETRY
285+ ? this .retryFallbackStrategy
286+ : this .errorHandlingStrategy ;
287+ if (finalErrorHandlingStrategy == ErrorHandlingStrategy .FAILOVER ) {
288+ throw new RuntimeException (throwable );
289+ } else {
290+ LOG .warn (
291+ "The input row data failed to acquire a valid response. Ignoring the input." ,
292+ throwable );
293+ GenericRowData rowData = new GenericRowData (this .outputColumnNames .size ());
294+ boolean isMetadataSet = false ;
295+ for (int i = 0 ; i < this .outputColumnNames .size (); i ++) {
296+ String columnName = this .outputColumnNames .get (i );
297+ ErrorMessageMetadata errorMessageMetadata =
298+ ErrorMessageMetadata .get (columnName );
299+ if (errorMessageMetadata != null ) {
300+ rowData .setField (i , errorMessageMetadata .converter .apply (throwable ));
301+ isMetadataSet = true ;
302+ }
303+ }
304+ return isMetadataSet ? Collections .singletonList (rowData ) : Collections .emptyList ();
305+ }
306+ } else if (t == null ) {
307+ return Collections .emptyList ();
308+ } else {
309+ return converter .apply (t );
248310 }
249- return result ;
250311 }
251312
252313 private <T > CompletableFuture <T > retryAsync (
@@ -348,4 +409,78 @@ public long getMinRetryTotalTime(long baseRetryInterval, int numRetry) {
348409
349410 public abstract long getMinRetryTotalTime (long baseRetryInterval , int numRetry );
350411 }
412+
413+ /**
414+ * Metadata that can be read from the output row about error messages. Referenced from Flink
415+ * HTTP Connector's ReadableMetadata.
416+ */
417+ protected enum ErrorMessageMetadata {
418+ ERROR_STRING (
419+ "error-string" ,
420+ DataTypes .STRING (),
421+ x -> BinaryStringData .fromString (x .getMessage ()),
422+ "A message associated with the error" ),
423+ HTTP_STATUS_CODE (
424+ "http-status-code" ,
425+ DataTypes .INT (),
426+ e ->
427+ ExceptionUtils .findThrowable (e , OpenAIServiceException .class )
428+ .map (OpenAIServiceException ::statusCode )
429+ .orElse (null ),
430+ "The HTTP status code" ),
431+ HTTP_HEADERS_MAP (
432+ "http-headers-map" ,
433+ DataTypes .MAP (DataTypes .STRING (), DataTypes .ARRAY (DataTypes .STRING ())),
434+ e ->
435+ ExceptionUtils .findThrowable (e , OpenAIServiceException .class )
436+ .map (
437+ e1 -> {
438+ Map <StringData , ArrayData > map = new HashMap <>();
439+ Headers headers = e1 .headers ();
440+ for (String name : headers .names ()) {
441+ map .put (
442+ BinaryStringData .fromString (name ),
443+ new GenericArrayData (
444+ headers .values (name ).stream ()
445+ .map (
446+ BinaryStringData
447+ ::fromString )
448+ .toArray ()));
449+ }
450+ return new GenericMapData (map );
451+ })
452+ .orElse (null ),
453+ "The headers returned with the response" );
454+
455+ final String key ;
456+ final DataType dataType ;
457+ final Function <Throwable , Object > converter ;
458+ final String description ;
459+
460+ ErrorMessageMetadata (
461+ String key ,
462+ DataType dataType ,
463+ Function <Throwable , Object > converter ,
464+ String description ) {
465+ this .key = key ;
466+ this .dataType = dataType ;
467+ this .converter = converter ;
468+ this .description = description ;
469+ }
470+
471+ static @ Nullable ErrorMessageMetadata get (String key ) {
472+ for (ErrorMessageMetadata value : values ()) {
473+ if (value .key .equals (key )) {
474+ return value ;
475+ }
476+ }
477+ return null ;
478+ }
479+
480+ static String getAllKeysAndDescriptions () {
481+ return Arrays .stream (values ())
482+ .map (value -> value .key + ":\t " + value .description )
483+ .collect (Collectors .joining ("\n " ));
484+ }
485+ }
351486}
0 commit comments