Skip to content

Commit 3927931

Browse files
[FLINK-38581][model] Support surfacing error message
1 parent b4e212e commit 3927931

File tree

7 files changed

+304
-60
lines changed

7 files changed

+304
-60
lines changed

docs/content.zh/docs/connectors/models/openai.md

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,13 @@ FROM ML_PREDICT(
168168
<ul>
169169
<li><code>retry</code>: 重试发送请求。重试行为受 retry-num、retry-fallback-strategy、retry-backoff-strategy 和 retry-backoff-base-interval 限制。</li>
170170
<li><code>failover</code>: 抛出异常并使 Flink 作业失败。</li>
171-
<li><code>ignore</code>: 忽略导致错误的输入并继续。错误本身将记录在日志中。</li>
171+
<li><code>ignore</code>: 忽略导致错误的输入并继续执行。错误本身将被记录在日志中。您还可以指定以下元数据列,以便在输出流中显示错误信息。
172+
<ul>
173+
<li><code>error-string</code>: 与错误相关的消息</li>
174+
<li><code>http-status-code</code>: HTTP 状态码</li>
175+
<li><code>http-headers-map</code>: 响应返回的 Header</li>
176+
</ul>
177+
</li>
172178
</ul>
173179
</td>
174180
</tr>

docs/content/docs/connectors/models/openai.md

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,13 @@ FROM ML_PREDICT(
168168
<ul>
169169
<li><code>retry</code>: Retry sending the request. The retrying behavior is limited by retry-num, retry-fallback-strategy, retry-backoff-strategy and retry-backoff-base-interval.</li>
170170
<li><code>failover</code>: Throw exceptions and fail the Flink job.</li>
171-
<li><code>ignore</code>: Ignore the input that caused the error and continue. The error itself would be recorded in log.</li>
171+
<li><code>ignore</code>: Ignore the input that caused the error and continue. The error itself is recorded in the log. You can also specify the following metadata columns to surface the error information in the output stream.
172+
<ul>
173+
<li><code>error-string</code>: A message associated with the error</li>
174+
<li><code>http-status-code</code>: The HTTP status code</li>
175+
<li><code>http-headers-map</code>: The headers returned with the response</li>
176+
</ul>
177+
</li>
172178
</ul>
173179
</td>
174180
</tr>

flink-models/flink-model-openai/src/main/java/org/apache/flink/model/openai/AbstractOpenAIModelFunction.java

Lines changed: 164 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -19,19 +19,28 @@
1919
package org.apache.flink.model.openai;
2020

2121
import org.apache.flink.configuration.ReadableConfig;
22+
import org.apache.flink.table.api.DataTypes;
2223
import org.apache.flink.table.api.config.ExecutionConfigOptions;
2324
import org.apache.flink.table.catalog.Column;
2425
import org.apache.flink.table.catalog.ResolvedSchema;
26+
import org.apache.flink.table.data.ArrayData;
27+
import org.apache.flink.table.data.GenericArrayData;
28+
import org.apache.flink.table.data.GenericMapData;
29+
import org.apache.flink.table.data.GenericRowData;
2530
import org.apache.flink.table.data.RowData;
31+
import org.apache.flink.table.data.StringData;
32+
import org.apache.flink.table.data.binary.BinaryStringData;
2633
import org.apache.flink.table.factories.ModelProviderFactory;
2734
import org.apache.flink.table.functions.AsyncPredictFunction;
2835
import org.apache.flink.table.functions.FunctionContext;
36+
import org.apache.flink.table.types.DataType;
2937
import org.apache.flink.table.types.logical.LogicalType;
3038
import org.apache.flink.table.types.logical.VarCharType;
3139
import org.apache.flink.util.ExceptionUtils;
3240
import org.apache.flink.util.Preconditions;
3341

3442
import com.openai.client.OpenAIClientAsync;
43+
import com.openai.core.http.Headers;
3544
import com.openai.errors.OpenAIIoException;
3645
import com.openai.errors.OpenAIServiceException;
3746
import org.slf4j.Logger;
@@ -41,10 +50,13 @@
4150

4251
import java.io.IOException;
4352
import java.time.Duration;
53+
import java.util.Arrays;
4454
import java.util.Collection;
4555
import java.util.Collections;
56+
import java.util.HashMap;
4657
import java.util.HashSet;
4758
import java.util.List;
59+
import java.util.Map;
4860
import java.util.Optional;
4961
import java.util.Set;
5062
import java.util.concurrent.CompletableFuture;
@@ -78,6 +90,7 @@ public abstract class AbstractOpenAIModelFunction extends AsyncPredictFunction {
7890
private final String model;
7991
@Nullable private final Integer maxContextSize;
8092
private final ContextOverflowAction contextOverflowAction;
93+
protected final List<String> outputColumnNames;
8194

8295
public AbstractOpenAIModelFunction(
8396
ModelProviderFactory.Context factoryContext, ReadableConfig config) {
@@ -140,6 +153,9 @@ public AbstractOpenAIModelFunction(
140153
factoryContext.getCatalogModel().getResolvedInputSchema(),
141154
new VarCharType(VarCharType.MAX_LENGTH),
142155
"input");
156+
157+
this.outputColumnNames =
158+
factoryContext.getCatalogModel().getResolvedOutputSchema().getColumnNames();
143159
}
144160

145161
@Override
@@ -184,23 +200,19 @@ public void close() throws Exception {
184200
protected void validateSingleColumnSchema(
185201
ResolvedSchema schema, LogicalType expectedType, String inputOrOutput) {
186202
List<Column> columns = schema.getColumns();
187-
if (columns.size() != 1) {
203+
List<String> physicalColumnNames =
204+
columns.stream()
205+
.filter(Column::isPhysical)
206+
.map(Column::getName)
207+
.collect(Collectors.toList());
208+
if (physicalColumnNames.size() != 1) {
188209
throw new IllegalArgumentException(
189210
String.format(
190-
"Model should have exactly one %s column, but actually has %s columns: %s",
191-
inputOrOutput,
192-
columns.size(),
193-
columns.stream().map(Column::getName).collect(Collectors.toList())));
194-
}
195-
196-
Column column = columns.get(0);
197-
if (!column.isPhysical()) {
198-
throw new IllegalArgumentException(
199-
String.format(
200-
"%s column %s should be a physical column, but is a %s.",
201-
inputOrOutput, column.getName(), column.getClass()));
211+
"Model should have exactly one %s physical column, but actually has %s physical columns: %s",
212+
inputOrOutput, physicalColumnNames.size(), physicalColumnNames));
202213
}
203214

215+
Column column = schema.getColumn(physicalColumnNames.get(0)).get();
204216
if (!expectedType.equals(column.getDataType().getLogicalType())) {
205217
throw new IllegalArgumentException(
206218
String.format(
@@ -210,6 +222,33 @@ protected void validateSingleColumnSchema(
210222
expectedType,
211223
column.getDataType().getLogicalType()));
212224
}
225+
226+
List<Column> metadataColumns =
227+
columns.stream()
228+
.filter(x -> x instanceof Column.MetadataColumn)
229+
.collect(Collectors.toList());
230+
if (!metadataColumns.isEmpty()) {
231+
Preconditions.checkArgument(
232+
"output".equals(inputOrOutput), "Only output schema supports metadata column");
233+
234+
for (Column metadataColumn : metadataColumns) {
235+
ErrorMessageMetadata errorMessageMetadata =
236+
ErrorMessageMetadata.get(metadataColumn.getName());
237+
Preconditions.checkNotNull(
238+
errorMessageMetadata,
239+
String.format(
240+
"Unexpected metadata column %s. Supported metadata columns:\n%s",
241+
metadataColumn.getName(),
242+
ErrorMessageMetadata.getAllKeysAndDescriptions()));
243+
Preconditions.checkArgument(
244+
errorMessageMetadata.dataType.equals(metadataColumn.getDataType()),
245+
String.format(
246+
"Expected metadata column %s to be of type %s, but is of type %s",
247+
metadataColumn.getName(),
248+
errorMessageMetadata.dataType,
249+
metadataColumn.getDataType()));
250+
}
251+
}
213252
}
214253

215254
/**
@@ -223,30 +262,52 @@ protected void validateSingleColumnSchema(
223262
* appropriate retry and error handling applied. If the request failed and the failure should
224263
* be ignored, the result is an empty collection or a single row holding only error metadata.
225264
*/
226-
protected <T> CompletableFuture<T> sendAsyncOpenAIRequest(
227-
Supplier<CompletableFuture<T>> requestSender) {
265+
protected <T> CompletableFuture<Collection<RowData>> sendAsyncOpenAIRequest(
266+
Supplier<CompletableFuture<T>> requestSender,
267+
Function<T, Collection<RowData>> converter) {
228268
CompletableFuture<T> result =
229269
retryAsync(
230270
requestSender,
231271
numRetry,
232272
retryBackoffBaseIntervalMs,
233273
retryBackoffStrategy,
234274
null);
235-
ErrorHandlingStrategy finalErrorHandlingStrategy =
236-
this.errorHandlingStrategy == ErrorHandlingStrategy.RETRY
237-
? this.retryFallbackStrategy
238-
: this.errorHandlingStrategy;
239-
if (finalErrorHandlingStrategy == ErrorHandlingStrategy.IGNORE) {
240-
result =
241-
result.exceptionally(
242-
(e) -> {
243-
LOG.warn(
244-
"The input row data failed to acquire a valid response. Ignoring the input.",
245-
e);
246-
return null;
247-
});
275+
return result.handle((x, throwable) -> this.convertToRowData(x, throwable, converter));
276+
}
277+
278+
private <T> Collection<RowData> convertToRowData(
279+
@Nullable T t,
280+
@Nullable Throwable throwable,
281+
Function<T, Collection<RowData>> converter) {
282+
if (throwable != null) {
283+
ErrorHandlingStrategy finalErrorHandlingStrategy =
284+
this.errorHandlingStrategy == ErrorHandlingStrategy.RETRY
285+
? this.retryFallbackStrategy
286+
: this.errorHandlingStrategy;
287+
if (finalErrorHandlingStrategy == ErrorHandlingStrategy.FAILOVER) {
288+
throw new RuntimeException(throwable);
289+
} else {
290+
LOG.warn(
291+
"The input row data failed to acquire a valid response. Ignoring the input.",
292+
throwable);
293+
GenericRowData rowData = new GenericRowData(this.outputColumnNames.size());
294+
boolean isMetadataSet = false;
295+
for (int i = 0; i < this.outputColumnNames.size(); i++) {
296+
String columnName = this.outputColumnNames.get(i);
297+
ErrorMessageMetadata errorMessageMetadata =
298+
ErrorMessageMetadata.get(columnName);
299+
if (errorMessageMetadata != null) {
300+
rowData.setField(i, errorMessageMetadata.converter.apply(throwable));
301+
isMetadataSet = true;
302+
}
303+
}
304+
return isMetadataSet ? Collections.singletonList(rowData) : Collections.emptyList();
305+
}
306+
} else if (t == null) {
307+
return Collections.emptyList();
308+
} else {
309+
return converter.apply(t);
248310
}
249-
return result;
250311
}
251312

252313
private <T> CompletableFuture<T> retryAsync(
@@ -348,4 +409,78 @@ public long getMinRetryTotalTime(long baseRetryInterval, int numRetry) {
348409

349410
public abstract long getMinRetryTotalTime(long baseRetryInterval, int numRetry);
350411
}
412+
413+
/**
414+
* Metadata that can be read from the output row about error messages. Referenced from Flink
415+
* HTTP Connector's ReadableMetadata.
416+
*/
417+
protected enum ErrorMessageMetadata {
418+
ERROR_STRING(
419+
"error-string",
420+
DataTypes.STRING(),
421+
x -> BinaryStringData.fromString(x.getMessage()),
422+
"A message associated with the error"),
423+
HTTP_STATUS_CODE(
424+
"http-status-code",
425+
DataTypes.INT(),
426+
e ->
427+
ExceptionUtils.findThrowable(e, OpenAIServiceException.class)
428+
.map(OpenAIServiceException::statusCode)
429+
.orElse(null),
430+
"The HTTP status code"),
431+
HTTP_HEADERS_MAP(
432+
"http-headers-map",
433+
DataTypes.MAP(DataTypes.STRING(), DataTypes.ARRAY(DataTypes.STRING())),
434+
e ->
435+
ExceptionUtils.findThrowable(e, OpenAIServiceException.class)
436+
.map(
437+
e1 -> {
438+
Map<StringData, ArrayData> map = new HashMap<>();
439+
Headers headers = e1.headers();
440+
for (String name : headers.names()) {
441+
map.put(
442+
BinaryStringData.fromString(name),
443+
new GenericArrayData(
444+
headers.values(name).stream()
445+
.map(
446+
BinaryStringData
447+
::fromString)
448+
.toArray()));
449+
}
450+
return new GenericMapData(map);
451+
})
452+
.orElse(null),
453+
"The headers returned with the response");
454+
455+
final String key;
456+
final DataType dataType;
457+
final Function<Throwable, Object> converter;
458+
final String description;
459+
460+
ErrorMessageMetadata(
461+
String key,
462+
DataType dataType,
463+
Function<Throwable, Object> converter,
464+
String description) {
465+
this.key = key;
466+
this.dataType = dataType;
467+
this.converter = converter;
468+
this.description = description;
469+
}
470+
471+
static @Nullable ErrorMessageMetadata get(String key) {
472+
for (ErrorMessageMetadata value : values()) {
473+
if (value.key.equals(key)) {
474+
return value;
475+
}
476+
}
477+
return null;
478+
}
479+
480+
static String getAllKeysAndDescriptions() {
481+
return Arrays.stream(values())
482+
.map(value -> value.key + ":\t" + value.description)
483+
.collect(Collectors.joining("\n"));
484+
}
485+
}
351486
}

flink-models/flink-model-openai/src/main/java/org/apache/flink/model/openai/OpenAIChatModelFunction.java

Lines changed: 27 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -33,11 +33,8 @@
3333
import com.openai.models.chat.completions.ChatCompletionCreateParams.ResponseFormat;
3434
import com.openai.services.async.chat.ChatCompletionServiceAsync;
3535

36-
import javax.annotation.Nullable;
37-
3836
import java.util.Arrays;
3937
import java.util.Collection;
40-
import java.util.Collections;
4138
import java.util.List;
4239
import java.util.concurrent.CompletableFuture;
4340
import java.util.stream.Collectors;
@@ -53,6 +50,7 @@ public class OpenAIChatModelFunction extends AbstractOpenAIModelFunction {
5350
private final String model;
5451
private final String systemPrompt;
5552
private final Configuration config;
53+
private final int outputColumnIndex;
5654

5755
public OpenAIChatModelFunction(
5856
ModelProviderFactory.Context factoryContext, ReadableConfig config) {
@@ -64,6 +62,21 @@ public OpenAIChatModelFunction(
6462
factoryContext.getCatalogModel().getResolvedOutputSchema(),
6563
new VarCharType(VarCharType.MAX_LENGTH),
6664
"output");
65+
this.outputColumnIndex = getOutputColumnIndex();
66+
}
67+
68+
private int getOutputColumnIndex() {
69+
for (int i = 0; i < this.outputColumnNames.size(); i++) {
70+
String columnName = this.outputColumnNames.get(i);
71+
if (ErrorMessageMetadata.get(columnName) == null) {
72+
// Prior checks have guaranteed that there is one and only one physical output
73+
// column.
74+
return i;
75+
}
76+
}
77+
throw new IllegalArgumentException(
78+
"There should be one and only one physical output column. Actual columns: "
79+
+ this.outputColumnNames);
6780
}
6881

6982
@Override
@@ -93,21 +106,21 @@ public CompletableFuture<Collection<RowData>> asyncPredictInternal(String input)
93106

94107
ChatCompletionCreateParams params = builder.build();
95108
ChatCompletionServiceAsync serviceAsync = client.chat().completions();
96-
return sendAsyncOpenAIRequest(() -> serviceAsync.create(params))
97-
.thenApply(this::convertToRowData);
109+
return sendAsyncOpenAIRequest(() -> serviceAsync.create(params), this::convertToRowData);
98110
}
99111

100-
private List<RowData> convertToRowData(@Nullable ChatCompletion chatCompletion) {
101-
if (chatCompletion == null) {
102-
return Collections.emptyList();
103-
}
104-
112+
private List<RowData> convertToRowData(ChatCompletion chatCompletion) {
105113
return chatCompletion.choices().stream()
106114
.map(
107-
choice ->
108-
GenericRowData.of(
109-
BinaryStringData.fromString(
110-
choice.message().content().orElse(""))))
115+
choice -> {
116+
GenericRowData rowData =
117+
new GenericRowData(this.outputColumnNames.size());
118+
rowData.setField(
119+
this.outputColumnIndex,
120+
BinaryStringData.fromString(
121+
choice.message().content().orElse("")));
122+
return rowData;
123+
})
111124
.collect(Collectors.toList());
112125
}
113126

0 commit comments

Comments
 (0)